Dynavera/apps/onboarding/mcp.py
2026-03-21 00:03:55 +00:00

247 lines
9 KiB
Python

import logging
import random
import httpx
from channels.db import database_sync_to_async
from django.conf import settings
from django.db.models import Q
from pgvector.django import CosineDistance
from apps.accounts.models import Role
from apps.knowledge.models import RoleRagDocument, TrainingFile
from apps.onboarding.models import OnboardingSession
# Module-level logger used by the MCP router to trace tool calls and results.
logger = logging.getLogger(__name__)
# Attribute name under which @mcp_tool stores a function's tool metadata;
# _collect_tools reads the same attribute back when building the registry.
_MCP_TOOL_META = 'mcp_tool_meta'
def mcp_tool(name, description, input_schema):
    """Return a decorator that tags a function as an MCP tool.

    The decorated function is returned unchanged; a fresh metadata dict with
    ``name``, ``description`` and ``inputSchema`` keys is attached to it under
    the ``_MCP_TOOL_META`` attribute so that ``_collect_tools`` can discover it.

    Args:
        name: Public tool name exposed to MCP clients.
        description: Human-readable description of the tool.
        input_schema: JSON-schema dict describing the tool's arguments.
    """
    def _register(target):
        # Build a new dict on every application so tagged functions never
        # share (and can never mutate) each other's metadata.
        tool_meta = dict(
            name=name,
            description=description,
            inputSchema=input_schema,
        )
        setattr(target, _MCP_TOOL_META, tool_meta)
        return target

    return _register
def _collect_tools(class_namespace):
    """Build tool descriptors for every ``@mcp_tool``-tagged entry in a namespace.

    Args:
        class_namespace: Mapping of attribute name -> value (e.g. a class body's
            ``locals()``). Entries without truthy tool metadata are skipped.

    Returns:
        A list, in namespace iteration order, of dicts with keys ``name``,
        ``method`` (the attribute name), ``description`` and ``inputSchema``.
    """
    def _describe(attr_name, meta):
        return {
            'name': meta['name'],
            'method': attr_name,
            'description': meta['description'],
            'inputSchema': meta['inputSchema'],
        }

    tagged = (
        (attr_name, getattr(member, _MCP_TOOL_META, None))
        for attr_name, member in class_namespace.items()
    )
    return [_describe(attr_name, meta) for attr_name, meta in tagged if meta]
class MCPRouter:
    """Dispatch MCP tool calls to the ``@mcp_tool``-decorated methods of this class.

    The registry (``tools``) and the tool-name -> method-name map
    (``_tool_name_to_method``) are built once, when the class body executes,
    from the class namespace.
    """

    def get_tool_definitions(self):
        """Return the tool descriptors collected from the ``@mcp_tool`` methods.

        Each entry has ``name``, ``method`` (Python attribute name),
        ``description`` and ``inputSchema`` keys. The list is the shared
        class-level registry, so callers should treat it as read-only.
        """
        return self.tools

    async def handle_tool_call(self, name, arguments):
        """Execute the tool registered as *name* with *arguments*.

        Args:
            name: Tool name as sent by the client.
            arguments: Argument dict for the tool; a falsy value becomes ``{}``.

        Returns:
            The tool's own result, ``{'error': 'Tool <name> not found'}`` for an
            unknown tool, or ``{'error': 'Tool <name> failed: <exc>'}`` when the
            tool raises. Tool exceptions are logged and not propagated.
        """
        logger.info('MCP tool call received: tool=%s args=%s', name, arguments)
        arguments = arguments or {}
        # The name comes straight from the client: a non-string (possibly
        # unhashable, e.g. a list) must be treated as unknown instead of letting
        # the dict lookup raise TypeError outside the error handling below.
        method_name = (
            self._tool_name_to_method.get(name) if isinstance(name, str) else None
        )
        method = getattr(self, method_name, None) if method_name else None
        if method is None:
            logger.warning('MCP tool call rejected: unknown tool=%s', name)
            return {'error': f'Tool {name} not found'}
        try:
            result = await method(arguments)
        except Exception as exc:  # boundary: surface any tool failure to the client
            logger.exception('MCP tool call failed: tool=%s error=%s', name, exc)
            return {'error': f'Tool {name} failed: {exc}'}
        logger.info('MCP tool call completed: tool=%s result=%s', name, result)
        return result

    async def _get_embedding(self, text):
        """Embed *text* via ``settings.INFERENCE_EMBEDDINGS_ENDPOINT``.

        Sends ``{'input': text}`` and returns ``data[0].embedding`` from the JSON
        response. Raises ``httpx.HTTPStatusError`` on a non-2xx response; a
        malformed body surfaces as ``KeyError``/``IndexError``.
        """
        logger.info('MCP embedding request started')
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                settings.INFERENCE_EMBEDDINGS_ENDPOINT,
                json={'input': text},
            )
            response.raise_for_status()
            embedding = response.json()['data'][0]['embedding']
        logger.info('MCP embedding request completed')
        return embedding

    @mcp_tool(
        name='search_knowledge',
        description='Search the RAG database for role-specific training content.',
        input_schema={
            'type': 'object',
            'properties': {
                'query': {'type': 'string'},
                'role_uuid': {'type': 'string'},
            },
            'required': ['query', 'role_uuid'],
        },
    )
    async def _search_knowledge(self, args):
        """Embed ``args['query']`` and return the closest documents for the role.

        Returns ``[]`` when ``query`` or ``role_uuid`` is missing or empty.
        """
        query = args.get('query')
        role_uuid = args.get('role_uuid')
        if not query or not role_uuid:
            logger.warning('MCP search_knowledge missing query or role_uuid')
            return []
        query_vector = await self._get_embedding(query)
        return await self._search_knowledge_documents(role_uuid, query_vector)

    @database_sync_to_async
    def _search_knowledge_documents(self, role_uuid, query_vector):
        """Return the 5 active documents nearest to *query_vector* (cosine distance).

        Candidates belong to the role's organization and are either attached to
        this role or role-less. Each result has ``content``, ``source`` and
        ``relevance`` (``1 - cosine distance`` rounded to 4 places). Returns
        ``[]`` when the role does not exist.
        """
        role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
        if role is None:
            logger.warning('MCP search_knowledge_documents role not found: role_uuid=%s', role_uuid)
            return []
        docs = RoleRagDocument.objects.filter(
            organization=role.organization,
            embedding__isnull=False,
            is_active=True,
        ).filter(
            Q(role__uuid=role_uuid) | Q(role__isnull=True),
        ).annotate(
            distance=CosineDistance('embedding', query_vector)
        ).order_by('distance')[:5]
        results = []
        for doc in docs:
            # Guard against a null metadata value, and fall back to the
            # placeholder whenever neither file_name nor source is usable
            # (not only when the 'source' key is absent).
            metadata = doc.metadata or {}
            results.append({
                'content': doc.content,
                'source': metadata.get('file_name') or metadata.get('source') or 'Unknown Source',
                'relevance': round(1 - doc.distance, 4),
            })
        logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results))
        return results

    @mcp_tool(
        name='update_progress',
        description="Update the user's score or current module in their session.",
        input_schema={
            'type': 'object',
            'properties': {
                'session_uuid': {'type': 'string'},
                'score': {'type': 'integer'},
                'completed_module': {'type': 'string'},
            },
            'required': ['session_uuid'],
        },
    )
    @database_sync_to_async
    def _update_progress(self, args):
        """Merge ``score``/``completed_module`` from *args* into the session state.

        ``score`` overwrites ``state['last_score']``; ``completed_module`` is
        appended to ``state['completed_modules']``. The session is then saved.

        Returns:
            ``{'status': 'success', 'new_state': state}``, or
            ``{'error': 'Session not found'}``.
        """
        # NOTE(review): this is an unlocked read-modify-write, and repeating a
        # call with the same module appends a duplicate entry — confirm both
        # are acceptable for how the tool is invoked.
        session_uuid = args.get('session_uuid')
        session = OnboardingSession.objects.filter(uuid=session_uuid).first()
        if session is None:
            logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid)
            return {'error': 'Session not found'}
        state = session.state or {}
        if 'score' in args:
            state['last_score'] = args['score']
        if 'completed_module' in args:
            state.setdefault('completed_modules', []).append(args['completed_module'])
        session.state = state
        session.save()
        logger.info('MCP update_progress completed: session_uuid=%s', session_uuid)
        return {'status': 'success', 'new_state': state}

    @mcp_tool(
        name='get_role_context',
        description='Get the name, description, and organization for a role. Call this first to understand what the role involves before generating content.',
        input_schema={
            'type': 'object',
            'properties': {
                'role_uuid': {'type': 'string', 'description': 'The UUID of the role'},
            },
            'required': ['role_uuid'],
        },
    )
    @database_sync_to_async
    def _get_role_context(self, args):
        """Return the role's name, description, organization name and member count.

        Returns ``{'error': 'Role not found'}`` when no role matches.
        """
        role_uuid = args.get('role_uuid')
        role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
        if role is None:
            logger.warning('MCP get_role_context role not found: role_uuid=%s', role_uuid)
            return {'error': 'Role not found'}
        logger.info('MCP get_role_context completed: role_uuid=%s', role_uuid)
        return {
            'name': role.name,
            'description': role.description or '',
            'organization': role.organization.name,
            'member_count': role.members.count(),
        }

    @mcp_tool(
        name='list_training_files',
        description='List processed training files available for a role. Use this to understand what source materials exist before generating curriculum or content.',
        input_schema={
            'type': 'object',
            'properties': {
                'role_uuid': {'type': 'string', 'description': 'The UUID of the role'},
            },
            'required': ['role_uuid'],
        },
    )
    @database_sync_to_async
    def _list_training_files(self, args):
        """List up to 20 processed training files for the role or its organization.

        Files attached to this role or to no role are included. Returns
        ``{'files': [...], 'count': n}`` where ``count`` is the number returned
        (capped at 20), or ``{'error': 'Role not found'}``.
        """
        role_uuid = args.get('role_uuid')
        role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
        if role is None:
            logger.warning('MCP list_training_files role not found: role_uuid=%s', role_uuid)
            return {'error': 'Role not found'}
        # NOTE(review): no explicit order_by here, so which 20 rows are returned
        # depends on the model's default ordering — confirm that is intended.
        files = list(
            TrainingFile.objects.filter(
                organization=role.organization,
                is_processed=True,
            ).filter(
                Q(role__uuid=role_uuid) | Q(role__isnull=True)
            ).values('file_name', 'description', 'file_type')[:20]
        )
        logger.info('MCP list_training_files completed: role_uuid=%s count=%s', role_uuid, len(files))
        return {'files': files, 'count': len(files)}

    @mcp_tool(
        name='random_int',
        description='Generate a random integer in an inclusive range.',
        input_schema={
            'type': 'object',
            'properties': {
                'min': {'type': 'integer'},
                'max': {'type': 'integer'},
            },
            'required': ['min', 'max'],
        },
    )
    async def _random_int(self, args):
        """Return a random integer in ``[min, max]`` (bounds swapped if reversed).

        Returns ``{'value', 'min', 'max'}``, or
        ``{'error': 'min and max must be integers'}`` when either bound is
        missing or not convertible to a finite integer.
        """
        try:
            min_value = int(args.get('min'))
            max_value = int(args.get('max'))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float('inf')) — reject it like any other bad bound.
            logger.warning('MCP random_int invalid args: %s', args)
            return {'error': 'min and max must be integers'}
        if min_value > max_value:
            min_value, max_value = max_value, min_value
        value = random.randint(min_value, max_value)
        logger.info('MCP random_int generated value=%s range=[%s,%s]', value, min_value, max_value)
        return {'value': value, 'min': min_value, 'max': max_value}

    # Executed while the class body runs: locals() is the class namespace, so
    # every @mcp_tool-tagged method defined above is collected.
    tools = _collect_tools(locals())
    _tool_name_to_method = {tool['name']: tool['method'] for tool in tools}
# Module-level MCPRouter instance. The class defines no __init__, so creating
# it here does no setup beyond instantiation.
mcp_router = MCPRouter()