From 20ac7f471c98b73ab036b524fe4599a64c6d314d Mon Sep 17 00:00:00 2001 From: Viswamedha Nalabotu Date: Wed, 18 Mar 2026 01:04:16 +0000 Subject: [PATCH] Added mcp tweaks and fixed failing tests with profanities --- apps/onboarding/mcp.py | 74 ++++++++------------- apps/onboarding/tests/test_consumers.py | 86 ++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 49 deletions(-) diff --git a/apps/onboarding/mcp.py b/apps/onboarding/mcp.py index 37a3f94..2a7ef00 100644 --- a/apps/onboarding/mcp.py +++ b/apps/onboarding/mcp.py @@ -1,7 +1,7 @@ -import httpx import logging import random +import httpx from channels.db import database_sync_to_async from django.conf import settings from django.db.models import Q @@ -12,40 +12,35 @@ from apps.knowledge.models import RoleRagDocument from apps.onboarding.models import OnboardingSession logger = logging.getLogger(__name__) -mcp_meta_value = 'mcp_tool_meta' + +_MCP_TOOL_META = 'mcp_tool_meta' def mcp_tool(name, description, input_schema): def decorator(func): - setattr(func, mcp_meta_value, { + setattr(func, _MCP_TOOL_META, { 'name': name, 'description': description, 'inputSchema': input_schema, }) return func - return decorator - def _collect_tools(class_namespace): tools = [] for method_name, value in class_namespace.items(): - metadata = getattr(value, mcp_meta_value, None) + metadata = getattr(value, _MCP_TOOL_META, None) if not metadata: continue - - tools.append( - { - 'name': metadata['name'], - 'method': method_name, - 'description': metadata['description'], - 'inputSchema': metadata['inputSchema'], - } - ) - + tools.append({ + 'name': metadata['name'], + 'method': method_name, + 'description': metadata['description'], + 'inputSchema': metadata['inputSchema'], + }) return tools - class MCPRouter: + def get_tool_definitions(self): return self.tools @@ -58,11 +53,7 @@ class MCPRouter: method = getattr(self, method_name, None) if method: result = await method(arguments) - logger.info( - 'MCP tool call completed: tool=%s result=%s', - name, - result, - ) + logger.info('MCP tool call completed: tool=%s result=%s', name, result) return result logger.warning('MCP tool call rejected: unknown tool=%s', name) @@ -73,9 +64,7 @@ class MCPRouter: async with httpx.AsyncClient() as client: response = await client.post( settings.INFERENCE_EMBEDDINGS_ENDPOINT, - json={ - 'input': text, - }, + json={'input': text}, ) response.raise_for_status() embedding = response.json()['data'][0]['embedding'] @@ -130,11 +119,7 @@ class MCPRouter: } for d in docs ] - logger.info( - 'MCP search_knowledge_documents completed: role_uuid=%s results=%s', - role_uuid, - len(results), - ) + logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results)) return results @mcp_tool( @@ -152,7 +137,11 @@ class MCPRouter: ) @database_sync_to_async def _update_progress(self, args): - session = OnboardingSession.objects.get(uuid=args.get('session_uuid')) + session_uuid = args.get('session_uuid') + session = OnboardingSession.objects.filter(uuid=session_uuid).first() + if session is None: + logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid) + return {'error': 'Session not found'} state = session.state or {} if 'score' in args: @@ -162,10 +151,7 @@ class MCPRouter: session.state = state session.save() - logger.info( - 'MCP update_progress completed: session_uuid=%s', - args.get('session_uuid'), - ) + logger.info('MCP update_progress completed: session_uuid=%s', session_uuid) return {'status': 'success', 'new_state': state} @mcp_tool( @@ -181,12 +167,10 @@ class MCPRouter: }, ) async def _random_int(self, args): - min_value = args.get('min') - max_value = args.get('max') try: - min_value = int(min_value) - max_value = int(max_value) - except Exception: + min_value = int(args.get('min')) + max_value = int(args.get('max')) + except (TypeError, ValueError): logger.warning('MCP random_int invalid args: %s', args) return {'error': 'min and max must be integers'} @@ -194,15 +178,11 @@ class MCPRouter: min_value, max_value = max_value, min_value value = random.randint(min_value, max_value) - logger.info( - 'MCP random_int generated value=%s range=[%s,%s]', - value, - min_value, - max_value, - ) + logger.info('MCP random_int generated value=%s range=[%s,%s]', value, min_value, max_value) return {'value': value, 'min': min_value, 'max': max_value} tools = _collect_tools(locals()) _tool_name_to_method = {tool['name']: tool['method'] for tool in tools} -mcp_router = MCPRouter() \ No newline at end of file + +mcp_router = MCPRouter() diff --git a/apps/onboarding/tests/test_consumers.py b/apps/onboarding/tests/test_consumers.py index 60c3ef0..19d7792 100644 --- a/apps/onboarding/tests/test_consumers.py +++ b/apps/onboarding/tests/test_consumers.py @@ -1,13 +1,95 @@ from asgiref.sync import async_to_sync from django.contrib.auth import get_user_model from django.test import TestCase +from unittest.mock import AsyncMock, patch from apps.accounts.models import Organization, Role -from apps.onboarding.consumers import OnboardingConsumer +from apps.onboarding.consumers import OnboardingGenerateConsumer from apps.onboarding.models import AgentConfig +from apps.onboarding.utils.content_moderator import ContentModerator User = get_user_model() +_PROFANE = "fuck" # Common profanity... + +class ContentModeratorTests(TestCase): + def setUp(self): + self.moderator = ContentModerator() + + def test_clean_text_passes(self): + self.assertTrue(self.moderator.is_clean("What is the onboarding process?")) + + def test_profane_text_blocked(self): + self.assertFalse(self.moderator.is_clean(f"this is {_PROFANE} content")) + + def test_non_string_input_passes(self): + self.assertTrue(self.moderator.is_clean(None)) + self.assertTrue(self.moderator.is_clean(123)) + + def test_censor_replaces_profanity(self): + result = self.moderator.censor(f"this is {_PROFANE} content") + self.assertNotIn(_PROFANE, result) + + def test_censor_passes_clean_text_unchanged(self): + text = "Please review the onboarding materials." + self.assertEqual(self.moderator.censor(text), text) + + def test_censor_non_string_returned_as_is(self): + self.assertIsNone(self.moderator.censor(None)) + self.assertEqual(self.moderator.censor(42), 42) + + +class ConsumerModerationTests(TestCase): + def setUp(self): + self.user = User.objects.create_user( + email_address='moderation-test@example.com', + password='pass1234', + first_name='Mod', + last_name='Tester', + date_of_birth='1995-05-05', + is_manager=True, + ) + self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user) + self.org.members.add(self.user) + self.role = Role.objects.create(name='Mod Role', organization=self.org) + self.consumer = OnboardingGenerateConsumer() + self.consumer.user = self.user + + def _run_receive(self, payload: str): + return async_to_sync(self.consumer.receive)(payload) + + def test_clean_query_is_dispatched(self): + import json + self.consumer.send_error = AsyncMock() + with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action: + self._run_receive(json.dumps({"action": "start_full_onboarding", "query": "Tell me about onboarding"})) + mock_action.assert_called_once() + self.consumer.send_error.assert_not_called() + + def test_profane_query_is_blocked(self): + import json + self.consumer.send_error = AsyncMock() + self._run_receive(json.dumps({"action": "start_full_onboarding", "query": f"this is {_PROFANE} content"})) + self.consumer.send_error.assert_called_once() + args = self.consumer.send_error.call_args[0] + self.assertIn("moderation", args[0].lower()) + + def test_profane_message_field_is_blocked(self): + import json + self.consumer.send_error = AsyncMock() + self._run_receive(json.dumps({"action": "start_full_onboarding", "message": f"this is {_PROFANE} content"})) + self.consumer.send_error.assert_called_once() + args = self.consumer.send_error.call_args[0] + self.assertIn("moderation", args[0].lower()) + + def test_clean_message_field_is_dispatched(self): + import json + self.consumer.send_error = AsyncMock() + with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action: + self._run_receive(json.dumps({"action": "start_full_onboarding", "message": "begin onboarding"})) + mock_action.assert_called_once() + self.consumer.send_error.assert_not_called() + class OnboardingConsumerConfigSelectionTests(TestCase): def setUp(self): self.user = User.objects.create_user( @@ -24,7 +106,7 @@ class OnboardingConsumerConfigSelectionTests(TestCase): self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org) self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org) - self.consumer = OnboardingConsumer() + self.consumer = OnboardingGenerateConsumer() def test_get_config_by_type_prefers_exact_role(self): quant_cfg = AgentConfig.objects.create(