Dynavera/apps/mlstore/tests/test_api.py

91 lines
4.3 KiB
Python

from unittest.mock import patch
from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
from apps.orgs.models import Organization, Role
from apps.mlstore.models import AgentModel, Agent, AgentRun, AgentEvent, RoleRagDocument
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
# Resolve the user model via Django's get_user_model() rather than importing a
# concrete class; the tests below create accounts through User.objects.
User = get_user_model()
class MLStoreAPITests(TestCase):
    """API-level tests for ``AgentViewSet`` and ``AgentRunViewSet``.

    Requests are built with ``APIRequestFactory`` and handed directly to the
    callables returned by ``as_view``, so no URL routing is involved.
    """

    def setUp(self):
        """Create the users, organization, role, model and agent the tests share."""
        self.factory = APIRequestFactory()

        accounts = User.objects
        self.user = accounts.create_user(
            email_address='user@example.com',
            password='pass',
        )
        self.other = accounts.create_user(
            email_address='other@example.com',
            password='pass',
        )
        self.manager = accounts.create_user(
            email_address='manager@example.com',
            password='pass',
            is_manager=True,
        )

        # The manager account owns the organization; the role belongs to it.
        self.org = Organization.objects.create(name='Org', owner=self.manager)
        self.role = Role.objects.create(name='Engineer', organization=self.org)

        self.model = AgentModel.objects.create(
            name='test-model',
            version='v1',
            path='model.gguf',
        )
        self.agent = Agent.objects.create(model=self.model, organization=self.org)

    def _send(self, viewset, actions, request, as_user=None, **url_kwargs):
        """Dispatch ``request`` to ``viewset`` and return the response.

        ``actions`` is the HTTP-method-to-action mapping given to ``as_view``.
        When ``as_user`` is provided the request is force-authenticated as
        that user; otherwise the request is sent exactly as built.
        ``url_kwargs`` are forwarded to the view callable as keyword arguments.
        """
        if as_user is not None:
            force_authenticate(request, user=as_user)
        handler = viewset.as_view(actions)
        return handler(request, **url_kwargs)

    def test_agents_list_requires_auth(self):
        """Listing agents without authentication is expected to return 403."""
        anonymous_request = self.factory.get('/')
        response = self._send(AgentViewSet, {'get': 'list'}, anonymous_request)
        self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)

    def test_agents_list_authenticated(self):
        """Listing agents as an authenticated user is expected to return 200."""
        response = self._send(
            AgentViewSet,
            {'get': 'list'},
            self.factory.get('/'),
            as_user=self.user,
        )
        self.assertEqual(response.status_code, HTTP_200_OK)

    def test_agent_runs_scoped_to_user(self):
        """With one run per user, the run list is expected to hold a single entry."""
        for owner in (self.user, self.other):
            AgentRun.objects.create(agent=self.agent, user=owner)

        response = self._send(
            AgentRunViewSet,
            {'get': 'list'},
            self.factory.get('/'),
            as_user=self.user,
        )
        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertEqual(len(response.data), 1)

    def test_agent_run_events(self):
        """The ``events`` action is expected to return the run's single event."""
        run = AgentRun.objects.create(agent=self.agent, user=self.user)
        AgentEvent.objects.create(
            execution=run,
            event_type='message',
            content={'msg': 'hi'},
        )

        # The run is addressed through the ``uuid`` keyword argument of the view.
        response = self._send(
            AgentRunViewSet,
            {'get': 'events'},
            self.factory.get('/'),
            as_user=self.user,
            uuid=str(run.uuid),
        )
        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertEqual(len(response.data), 1)

    def test_retrieve_context_missing_params(self):
        """Posting an empty payload to ``retrieve_context`` is expected to return 400."""
        empty_post = self.factory.post('/', {})
        response = self._send(
            AgentRunViewSet,
            {'post': 'retrieve_context'},
            empty_post,
            as_user=self.user,
        )
        self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)

    def test_retrieve_context_role_not_found(self):
        """An all-zero ``role_uuid`` is expected to produce a 404."""
        body = {'query': 'q', 'role_uuid': '00000000-0000-0000-0000-000000000000'}
        response = self._send(
            AgentRunViewSet,
            {'post': 'retrieve_context'},
            self.factory.post('/', body),
            as_user=self.user,
        )
        self.assertEqual(response.status_code, HTTP_404_NOT_FOUND)

    @patch('apps.mlstore.viewsets.services.search_similar_documents')
    @patch('apps.mlstore.viewsets.services.get_context_for_query')
    def test_retrieve_context_success(self, mock_context, mock_search):
        """With both service calls mocked, the response reports one result and the context."""
        # Decorators apply bottom-up: ``mock_context`` replaces
        # ``get_context_for_query`` and ``mock_search`` replaces
        # ``search_similar_documents``.
        document = RoleRagDocument.objects.create(
            role=self.role,
            content='chunk',
            content_hash='hash',
            chunk_index=0,
        )
        mock_search.return_value = [(document, 0.9)]
        mock_context.return_value = 'context text'

        body = {'query': 'hello', 'role_uuid': str(self.role.uuid)}
        json_post = self.factory.post('/', body, format='json')
        response = self._send(
            AgentRunViewSet,
            {'post': 'retrieve_context'},
            json_post,
            as_user=self.user,
        )

        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertEqual(response.data.get('num_results'), 1)
        self.assertEqual(response.data.get('context'), 'context text')