Added mcp tweaks and fixed failing tests with profanities
This commit is contained in:
parent
6aa98b2839
commit
20ac7f471c
2 changed files with 111 additions and 49 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import httpx
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import httpx
|
||||||
from channels.db import database_sync_to_async
|
from channels.db import database_sync_to_async
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
|
|
@ -12,40 +12,35 @@ from apps.knowledge.models import RoleRagDocument
|
||||||
from apps.onboarding.models import OnboardingSession
|
from apps.onboarding.models import OnboardingSession
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
mcp_meta_value = 'mcp_tool_meta'
|
|
||||||
|
_MCP_TOOL_META = 'mcp_tool_meta'
|
||||||
|
|
||||||
def mcp_tool(name, description, input_schema):
|
def mcp_tool(name, description, input_schema):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
setattr(func, mcp_meta_value, {
|
setattr(func, _MCP_TOOL_META, {
|
||||||
'name': name,
|
'name': name,
|
||||||
'description': description,
|
'description': description,
|
||||||
'inputSchema': input_schema,
|
'inputSchema': input_schema,
|
||||||
})
|
})
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _collect_tools(class_namespace):
|
def _collect_tools(class_namespace):
|
||||||
tools = []
|
tools = []
|
||||||
for method_name, value in class_namespace.items():
|
for method_name, value in class_namespace.items():
|
||||||
metadata = getattr(value, mcp_meta_value, None)
|
metadata = getattr(value, _MCP_TOOL_META, None)
|
||||||
if not metadata:
|
if not metadata:
|
||||||
continue
|
continue
|
||||||
|
tools.append({
|
||||||
tools.append(
|
'name': metadata['name'],
|
||||||
{
|
'method': method_name,
|
||||||
'name': metadata['name'],
|
'description': metadata['description'],
|
||||||
'method': method_name,
|
'inputSchema': metadata['inputSchema'],
|
||||||
'description': metadata['description'],
|
})
|
||||||
'inputSchema': metadata['inputSchema'],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
class MCPRouter:
|
class MCPRouter:
|
||||||
|
|
||||||
def get_tool_definitions(self):
|
def get_tool_definitions(self):
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
|
|
@ -58,11 +53,7 @@ class MCPRouter:
|
||||||
method = getattr(self, method_name, None)
|
method = getattr(self, method_name, None)
|
||||||
if method:
|
if method:
|
||||||
result = await method(arguments)
|
result = await method(arguments)
|
||||||
logger.info(
|
logger.info('MCP tool call completed: tool=%s result=%s', name, result)
|
||||||
'MCP tool call completed: tool=%s result=%s',
|
|
||||||
name,
|
|
||||||
result,
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
logger.warning('MCP tool call rejected: unknown tool=%s', name)
|
logger.warning('MCP tool call rejected: unknown tool=%s', name)
|
||||||
|
|
@ -73,9 +64,7 @@ class MCPRouter:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
settings.INFERENCE_EMBEDDINGS_ENDPOINT,
|
settings.INFERENCE_EMBEDDINGS_ENDPOINT,
|
||||||
json={
|
json={'input': text},
|
||||||
'input': text,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
embedding = response.json()['data'][0]['embedding']
|
embedding = response.json()['data'][0]['embedding']
|
||||||
|
|
@ -130,11 +119,7 @@ class MCPRouter:
|
||||||
}
|
}
|
||||||
for d in docs
|
for d in docs
|
||||||
]
|
]
|
||||||
logger.info(
|
logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results))
|
||||||
'MCP search_knowledge_documents completed: role_uuid=%s results=%s',
|
|
||||||
role_uuid,
|
|
||||||
len(results),
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@mcp_tool(
|
@mcp_tool(
|
||||||
|
|
@ -152,7 +137,11 @@ class MCPRouter:
|
||||||
)
|
)
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def _update_progress(self, args):
|
def _update_progress(self, args):
|
||||||
session = OnboardingSession.objects.get(uuid=args.get('session_uuid'))
|
session_uuid = args.get('session_uuid')
|
||||||
|
session = OnboardingSession.objects.filter(uuid=session_uuid).first()
|
||||||
|
if session is None:
|
||||||
|
logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid)
|
||||||
|
return {'error': 'Session not found'}
|
||||||
|
|
||||||
state = session.state or {}
|
state = session.state or {}
|
||||||
if 'score' in args:
|
if 'score' in args:
|
||||||
|
|
@ -162,10 +151,7 @@ class MCPRouter:
|
||||||
|
|
||||||
session.state = state
|
session.state = state
|
||||||
session.save()
|
session.save()
|
||||||
logger.info(
|
logger.info('MCP update_progress completed: session_uuid=%s', session_uuid)
|
||||||
'MCP update_progress completed: session_uuid=%s',
|
|
||||||
args.get('session_uuid'),
|
|
||||||
)
|
|
||||||
return {'status': 'success', 'new_state': state}
|
return {'status': 'success', 'new_state': state}
|
||||||
|
|
||||||
@mcp_tool(
|
@mcp_tool(
|
||||||
|
|
@ -181,12 +167,10 @@ class MCPRouter:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def _random_int(self, args):
|
async def _random_int(self, args):
|
||||||
min_value = args.get('min')
|
|
||||||
max_value = args.get('max')
|
|
||||||
try:
|
try:
|
||||||
min_value = int(min_value)
|
min_value = int(args.get('min'))
|
||||||
max_value = int(max_value)
|
max_value = int(args.get('max'))
|
||||||
except Exception:
|
except (TypeError, ValueError):
|
||||||
logger.warning('MCP random_int invalid args: %s', args)
|
logger.warning('MCP random_int invalid args: %s', args)
|
||||||
return {'error': 'min and max must be integers'}
|
return {'error': 'min and max must be integers'}
|
||||||
|
|
||||||
|
|
@ -194,15 +178,11 @@ class MCPRouter:
|
||||||
min_value, max_value = max_value, min_value
|
min_value, max_value = max_value, min_value
|
||||||
|
|
||||||
value = random.randint(min_value, max_value)
|
value = random.randint(min_value, max_value)
|
||||||
logger.info(
|
logger.info('MCP random_int generated value=%s range=[%s,%s]', value, min_value, max_value)
|
||||||
'MCP random_int generated value=%s range=[%s,%s]',
|
|
||||||
value,
|
|
||||||
min_value,
|
|
||||||
max_value,
|
|
||||||
)
|
|
||||||
return {'value': value, 'min': min_value, 'max': max_value}
|
return {'value': value, 'min': min_value, 'max': max_value}
|
||||||
|
|
||||||
tools = _collect_tools(locals())
|
tools = _collect_tools(locals())
|
||||||
_tool_name_to_method = {tool['name']: tool['method'] for tool in tools}
|
_tool_name_to_method = {tool['name']: tool['method'] for tool in tools}
|
||||||
|
|
||||||
mcp_router = MCPRouter()
|
|
||||||
|
mcp_router = MCPRouter()
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,95 @@
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from apps.accounts.models import Organization, Role
|
from apps.accounts.models import Organization, Role
|
||||||
from apps.onboarding.consumers import OnboardingConsumer
|
from apps.onboarding.consumers import OnboardingGenerateConsumer
|
||||||
from apps.onboarding.models import AgentConfig
|
from apps.onboarding.models import AgentConfig
|
||||||
|
from apps.onboarding.utils.content_moderator import ContentModerator
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
|
|
||||||
|
_PROFANE = "fuck" # Common profanity...
|
||||||
|
|
||||||
|
class ContentModeratorTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.moderator = ContentModerator()
|
||||||
|
|
||||||
|
def test_clean_text_passes(self):
|
||||||
|
self.assertTrue(self.moderator.is_clean("What is the onboarding process?"))
|
||||||
|
|
||||||
|
def test_profane_text_blocked(self):
|
||||||
|
self.assertFalse(self.moderator.is_clean(f"this is {_PROFANE} content"))
|
||||||
|
|
||||||
|
def test_non_string_input_passes(self):
|
||||||
|
self.assertTrue(self.moderator.is_clean(None))
|
||||||
|
self.assertTrue(self.moderator.is_clean(123))
|
||||||
|
|
||||||
|
def test_censor_replaces_profanity(self):
|
||||||
|
result = self.moderator.censor(f"this is {_PROFANE} content")
|
||||||
|
self.assertNotIn(_PROFANE, result)
|
||||||
|
|
||||||
|
def test_censor_passes_clean_text_unchanged(self):
|
||||||
|
text = "Please review the onboarding materials."
|
||||||
|
self.assertEqual(self.moderator.censor(text), text)
|
||||||
|
|
||||||
|
def test_censor_non_string_returned_as_is(self):
|
||||||
|
self.assertIsNone(self.moderator.censor(None))
|
||||||
|
self.assertEqual(self.moderator.censor(42), 42)
|
||||||
|
|
||||||
|
|
||||||
|
class ConsumerModerationTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
email_address='moderation-test@example.com',
|
||||||
|
password='pass1234',
|
||||||
|
first_name='Mod',
|
||||||
|
last_name='Tester',
|
||||||
|
date_of_birth='1995-05-05',
|
||||||
|
is_manager=True,
|
||||||
|
)
|
||||||
|
self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user)
|
||||||
|
self.org.members.add(self.user)
|
||||||
|
self.role = Role.objects.create(name='Mod Role', organization=self.org)
|
||||||
|
self.consumer = OnboardingGenerateConsumer()
|
||||||
|
self.consumer.user = self.user
|
||||||
|
|
||||||
|
def _run_receive(self, payload: str):
|
||||||
|
return async_to_sync(self.consumer.receive)(payload)
|
||||||
|
|
||||||
|
def test_clean_query_is_dispatched(self):
|
||||||
|
import json
|
||||||
|
self.consumer.send_error = AsyncMock()
|
||||||
|
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
|
||||||
|
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": "Tell me about onboarding"}))
|
||||||
|
mock_action.assert_called_once()
|
||||||
|
self.consumer.send_error.assert_not_called()
|
||||||
|
|
||||||
|
def test_profane_query_is_blocked(self):
|
||||||
|
import json
|
||||||
|
self.consumer.send_error = AsyncMock()
|
||||||
|
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": f"this is {_PROFANE} content"}))
|
||||||
|
self.consumer.send_error.assert_called_once()
|
||||||
|
args = self.consumer.send_error.call_args[0]
|
||||||
|
self.assertIn("moderation", args[0].lower())
|
||||||
|
|
||||||
|
def test_profane_message_field_is_blocked(self):
|
||||||
|
import json
|
||||||
|
self.consumer.send_error = AsyncMock()
|
||||||
|
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": f"this is {_PROFANE} content"}))
|
||||||
|
self.consumer.send_error.assert_called_once()
|
||||||
|
args = self.consumer.send_error.call_args[0]
|
||||||
|
self.assertIn("moderation", args[0].lower())
|
||||||
|
|
||||||
|
def test_clean_message_field_is_dispatched(self):
|
||||||
|
import json
|
||||||
|
self.consumer.send_error = AsyncMock()
|
||||||
|
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
|
||||||
|
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": "begin onboarding"}))
|
||||||
|
mock_action.assert_called_once()
|
||||||
|
self.consumer.send_error.assert_not_called()
|
||||||
|
|
||||||
class OnboardingConsumerConfigSelectionTests(TestCase):
|
class OnboardingConsumerConfigSelectionTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.user = User.objects.create_user(
|
self.user = User.objects.create_user(
|
||||||
|
|
@ -24,7 +106,7 @@ class OnboardingConsumerConfigSelectionTests(TestCase):
|
||||||
self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
|
self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
|
||||||
self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
|
self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
|
||||||
|
|
||||||
self.consumer = OnboardingConsumer()
|
self.consumer = OnboardingGenerateConsumer()
|
||||||
|
|
||||||
def test_get_config_by_type_prefers_exact_role(self):
|
def test_get_config_by_type_prefers_exact_role(self):
|
||||||
quant_cfg = AgentConfig.objects.create(
|
quant_cfg = AgentConfig.objects.create(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue