"""Tests for the mlstore API viewsets (``AgentViewSet`` and ``AgentRunViewSet``)."""

from unittest.mock import patch

from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework.status import (
    HTTP_200_OK,
    HTTP_400_BAD_REQUEST,
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
)
from rest_framework.test import APIRequestFactory, force_authenticate

from apps.orgs.models import Organization, Role
from apps.mlstore.models import Agent, AgentEvent, AgentModel, AgentRun, RoleRagDocument
from apps.mlstore.viewsets import AgentRunViewSet, AgentViewSet

# Resolve the project's active user model once at import time; the tests
# create their users through its manager via ``create_user(email_address=...)``.
User = get_user_model()


class MLStoreAPITests(TestCase):
    """API-level tests for the mlstore ``AgentViewSet`` and ``AgentRunViewSet``.

    Views are invoked directly through ``APIRequestFactory`` (no URL routing);
    ``force_authenticate`` attaches a user to the request where a test needs one.
    """

    def setUp(self):
        """Create users, an organization with one role, and an agent fixture."""
        self.factory = APIRequestFactory()

        # Two ordinary users (so per-user scoping can be checked) and a
        # manager who owns the organization.
        self.user = User.objects.create_user(email_address='user@example.com', password='pass')
        self.other = User.objects.create_user(email_address='other@example.com', password='pass')
        self.manager = User.objects.create_user(
            email_address='manager@example.com',
            password='pass',
            is_manager=True,
        )

        self.org = Organization.objects.create(name='Org', owner=self.manager)
        self.role = Role.objects.create(name='Engineer', organization=self.org)

        self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
        self.agent = Agent.objects.create(model=self.model, organization=self.org)

    def _call_view(self, viewset, actions, request, *, user=None, **view_kwargs):
        """Dispatch *request* to *viewset* bound to the given HTTP-method/action map.

        Args:
            viewset: The ViewSet class under test.
            actions: Mapping of HTTP method to ViewSet action, passed to ``as_view``.
            request: A request built with ``self.factory``.
            user: When given, the request is force-authenticated as this user.
            **view_kwargs: URL keyword arguments forwarded to the view (e.g. ``uuid``).

        Returns:
            The response produced by the view.
        """
        if user is not None:
            force_authenticate(request, user=user)
        handler = viewset.as_view(actions)
        return handler(request, **view_kwargs)

    def test_agents_list_requires_auth(self):
        """Listing agents without an authenticated user is rejected with 403."""
        response = self._call_view(AgentViewSet, {'get': 'list'}, self.factory.get('/'))
        self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)

    def test_agents_list_authenticated(self):
        """An authenticated user can list agents."""
        response = self._call_view(
            AgentViewSet, {'get': 'list'}, self.factory.get('/'), user=self.user
        )
        self.assertEqual(response.status_code, HTTP_200_OK)

    def test_agent_runs_scoped_to_user(self):
        """The run list only contains runs belonging to the requesting user."""
        for owner in (self.user, self.other):
            AgentRun.objects.create(agent=self.agent, user=owner)

        response = self._call_view(
            AgentRunViewSet, {'get': 'list'}, self.factory.get('/'), user=self.user
        )

        self.assertEqual(response.status_code, HTTP_200_OK)
        # One run per user exists; only self.user's should come back.
        self.assertEqual(len(response.data), 1)

    def test_agent_run_events(self):
        """The ``events`` action returns the events recorded for a run."""
        run = AgentRun.objects.create(agent=self.agent, user=self.user)
        AgentEvent.objects.create(execution=run, event_type='message', content={'msg': 'hi'})

        response = self._call_view(
            AgentRunViewSet,
            {'get': 'events'},
            self.factory.get('/'),
            user=self.user,
            uuid=str(run.uuid),
        )

        self.assertEqual(response.status_code, HTTP_200_OK)
        self.assertEqual(len(response.data), 1)

    def test_retrieve_context_missing_params(self):
        """``retrieve_context`` with an empty body responds 400."""
        request = self.factory.post('/', {})
        response = self._call_view(
            AgentRunViewSet, {'post': 'retrieve_context'}, request, user=self.user
        )
        self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)

    def test_retrieve_context_role_not_found(self):
        """``retrieve_context`` with an unknown role UUID responds 404."""
        body = {'query': 'q', 'role_uuid': '00000000-0000-0000-0000-000000000000'}
        request = self.factory.post('/', body)
        response = self._call_view(
            AgentRunViewSet, {'post': 'retrieve_context'}, request, user=self.user
        )
        self.assertEqual(response.status_code, HTTP_404_NOT_FOUND)

    @patch('apps.mlstore.viewsets.services.search_similar_documents')
    @patch('apps.mlstore.viewsets.services.get_context_for_query')
    def test_retrieve_context_success(self, mock_context, mock_search):
        """A valid query/role pair returns the mocked context and a result count of 1."""
        # Stacked patches are applied bottom-up, so the first mock argument
        # replaces get_context_for_query and the second search_similar_documents.
        doc = RoleRagDocument.objects.create(
            role=self.role,
            content='chunk',
            content_hash='hash',
            chunk_index=0,
        )
        mock_search.return_value = [(doc, 0.9)]
        mock_context.return_value = 'context text'

        body = {'query': 'hello', 'role_uuid': str(self.role.uuid)}
        request = self.factory.post('/', body, format='json')
        response = self._call_view(
            AgentRunViewSet, {'post': 'retrieve_context'}, request, user=self.user
        )

        self.assertEqual(response.status_code, HTTP_200_OK)
        payload = response.data
        self.assertEqual(payload.get('num_results'), 1)
        self.assertEqual(payload.get('context'), 'context text')