247 lines
9.1 KiB
Python
247 lines
9.1 KiB
Python
import logging
|
|
import random
|
|
|
|
import httpx
|
|
from channels.db import database_sync_to_async
|
|
from django.conf import settings
|
|
from django.db.models import Q
|
|
from pgvector.django import CosineDistance
|
|
|
|
from apps.accounts.models import Role
|
|
from apps.knowledge.models import KnowledgeChunk, TrainingFile
|
|
from apps.onboarding.models import OnboardingSession
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
# Attribute name under which mcp_tool stores a tool's metadata dict on the
# decorated callable; _collect_tools reads the same attribute back.
_MCP_TOOL_META = 'mcp_tool_meta'
|
|
|
|
def mcp_tool(name, description, input_schema):
    """Build a decorator that tags a callable as an MCP tool.

    The returned decorator stores a metadata dict with the keys ``name``,
    ``description`` and ``inputSchema`` on the callable, under the attribute
    named by ``_MCP_TOOL_META``, and hands back the very same callable
    (it is annotated, not wrapped).

    Args:
        name: Tool name to record.
        description: Tool description to record.
        input_schema: Object recorded as-is under the ``inputSchema`` key.

    Returns:
        A one-argument decorator.
    """

    def _attach_metadata(func):
        # A fresh dict per decorated callable, so no two callables share one.
        tool_meta = dict(
            name=name,
            description=description,
            inputSchema=input_schema,
        )
        setattr(func, _MCP_TOOL_META, tool_meta)
        return func

    return _attach_metadata
|
|
|
|
def _collect_tools(class_namespace):
    """Gather tool definitions from a mapping of attribute names to values.

    Every value carrying a truthy ``_MCP_TOOL_META`` attribute contributes one
    entry; values without it (or with a falsy one) are skipped. Entries keep
    the iteration order of ``class_namespace``.

    Args:
        class_namespace: Mapping of attribute name -> attribute value.

    Returns:
        A list of dicts with the keys ``name``, ``method`` (the attribute name
        the value was found under), ``description`` and ``inputSchema``.
    """
    tagged = (
        (attr_name, getattr(attr, _MCP_TOOL_META, None))
        for attr_name, attr in class_namespace.items()
    )
    return [
        {
            'name': meta['name'],
            'method': attr_name,
            'description': meta['description'],
            'inputSchema': meta['inputSchema'],
        }
        for attr_name, meta in tagged
        if meta
    ]
|
|
|
|
class MCPRouter:
    """Routes MCP tool calls by tool name to the methods tagged with ``mcp_tool``.

    The class attributes ``tools`` and ``_tool_name_to_method`` are computed at
    the end of the class body from every attribute that carries tool metadata.
    """

    def get_tool_definitions(self):
        """Return the list of tool definitions collected for this class."""
        return self.tools

    async def handle_tool_call(self, name, arguments):
        """Look up the tool called ``name`` and await it with ``arguments``.

        Args:
            name: Tool name as registered through ``mcp_tool``.
            arguments: Argument mapping for the tool; a falsy value is
                replaced by an empty dict.

        Returns:
            Whatever the tool returns, or a dict with an ``error`` key when the
            tool is unknown or raises.
        """
        logger.info('MCP tool call received: tool=%s args=%s', name, arguments)
        arguments = arguments or {}

        method_name = self._tool_name_to_method.get(name)
        handler = getattr(self, method_name, None) if method_name else None
        if not handler:
            logger.warning('MCP tool call rejected: unknown tool=%s', name)
            return {'error': f'Tool {name} not found'}

        try:
            result = await handler(arguments)
            logger.info('MCP tool call completed: tool=%s result=%s', name, result)
            return result
        except Exception as exc:
            # Boundary handler: failures are reported to the caller as data.
            logger.exception('MCP tool call failed: tool=%s error=%s', name, exc)
            return {'error': f'Tool {name} failed: {exc}'}

    async def _get_embedding(self, text):
        """POST ``text`` to the configured embeddings endpoint.

        Returns:
            The ``embedding`` value of the first item under ``data`` in the
            JSON response.

        Raises:
            httpx.HTTPStatusError: Via ``raise_for_status`` on an error status.
        """
        logger.info('MCP embedding request started')
        client_options = {
            'timeout': settings.INFERENCE_REQUEST_TIMEOUT,
            'auth': settings.INFERENCE_AUTH,
        }
        async with httpx.AsyncClient(**client_options) as client:
            response = await client.post(
                settings.INFERENCE_EMBEDDINGS_ENDPOINT,
                json={'input': text},
            )
            response.raise_for_status()
            embedding = response.json()['data'][0]['embedding']
        logger.info('MCP embedding request completed')
        return embedding

    @mcp_tool(
        name='search_knowledge',
        description='Search the RAG database for role-specific training content.',
        input_schema={
            'type': 'object',
            'properties': {
                'query': {'type': 'string'},
                'role_uuid': {'type': 'string'},
            },
            'required': ['query', 'role_uuid'],
        },
    )
    async def _search_knowledge(self, args):
        """Embed ``args['query']`` and search chunks visible to ``args['role_uuid']``.

        Returns an empty list when either argument is missing or falsy.
        """
        query = args.get('query')
        role_uuid = args.get('role_uuid')

        if not (query and role_uuid):
            logger.warning('MCP search_knowledge missing query or role_uuid')
            return []

        query_vector = await self._get_embedding(query)
        return await self._search_knowledge_documents(role_uuid, query_vector)

    @database_sync_to_async
    def _search_knowledge_documents(self, role_uuid, query_vector):
        """Return up to five active chunks ordered by ascending cosine distance.

        Candidates belong to the role's organization, have an embedding, and
        are either attached to this role or to no role at all.

        Returns:
            A list of dicts with ``content``, ``source`` and ``relevance``
            (``1 - distance`` rounded to four decimals); an empty list when
            the role does not exist.
        """
        role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
        if role is None:
            logger.warning('MCP search_knowledge_documents role not found: role_uuid=%s', role_uuid)
            return []

        role_scope = Q(role__uuid=role_uuid) | Q(role__isnull=True)
        candidates = KnowledgeChunk.objects.filter(
            organization=role.organization,
            embedding__isnull=False,
            is_active=True,
        ).filter(role_scope)
        nearest = candidates.annotate(
            distance=CosineDistance('embedding', query_vector)
        ).order_by('distance')[:5]

        results = [
            {
                'content': chunk.content,
                'source': chunk.metadata.get('file_name') or chunk.metadata.get('source', 'Unknown Source'),
                'relevance': round(1 - chunk.distance, 4),
            }
            for chunk in nearest
        ]
        logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results))
        return results

    @mcp_tool(
        name='update_progress',
        description="Update the user's score or current module in their session.",
        input_schema={
            'type': 'object',
            'properties': {
                'session_uuid': {'type': 'string'},
                'score': {'type': 'integer'},
                'completed_module': {'type': 'string'},
            },
            'required': ['session_uuid'],
        },
    )
    @database_sync_to_async
    def _update_progress(self, args):
        """Record a score and/or a completed module in the session's state.

        ``score`` is written to ``last_score``; ``completed_module`` is
        appended to the ``completed_modules`` list. The session is then saved.

        Returns:
            ``{'status': 'success', 'new_state': ...}``, or an ``error`` dict
            when no session matches ``session_uuid``.
        """
        session_uuid = args.get('session_uuid')
        session = OnboardingSession.objects.filter(uuid=session_uuid).first()
        if session is None:
            logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid)
            return {'error': 'Session not found'}

        progress = session.state or {}
        if 'score' in args:
            progress['last_score'] = args['score']
        if 'completed_module' in args:
            finished = progress.setdefault('completed_modules', [])
            finished.append(args['completed_module'])

        session.state = progress
        session.save()
        logger.info('MCP update_progress completed: session_uuid=%s', session_uuid)
        return {'status': 'success', 'new_state': progress}

    @mcp_tool(
        name='get_role_context',
        description='Get the name, description, and organization for a role. Call this first to understand what the role involves before generating content.',
        input_schema={
            'type': 'object',
            'properties': {
                'role_uuid': {'type': 'string', 'description': 'The UUID of the role'},
            },
            'required': ['role_uuid'],
        },
    )
    @database_sync_to_async
    def _get_role_context(self, args):
        """Return name, description, organization name and member count of a role.

        Returns an ``error`` dict when no role matches ``role_uuid``.
        """
        role_uuid = args.get('role_uuid')
        role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
        if role is None:
            logger.warning('MCP get_role_context role not found: role_uuid=%s', role_uuid)
            return {'error': 'Role not found'}

        logger.info('MCP get_role_context completed: role_uuid=%s', role_uuid)
        return {
            'name': role.name,
            'description': role.description or '',
            'organization': role.organization.name,
            'member_count': role.members.count(),
        }

    @mcp_tool(
        name='list_training_files',
        description='List processed training files available for a role. Use this to understand what source materials exist before generating curriculum or content.',
        input_schema={
            'type': 'object',
            'properties': {
                'role_uuid': {'type': 'string', 'description': 'The UUID of the role'},
            },
            'required': ['role_uuid'],
        },
    )
    @database_sync_to_async
    def _list_training_files(self, args):
        """List up to 20 files with status ``'embedded'`` visible to a role.

        Files belong to the role's organization and are attached either to
        this role or to no role.

        Returns:
            ``{'files': [...], 'count': n}`` where each file is a dict of
            ``file_name``, ``description`` and ``file_type``; or an ``error``
            dict when no role matches ``role_uuid``.
        """
        role_uuid = args.get('role_uuid')
        role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
        if role is None:
            logger.warning('MCP list_training_files role not found: role_uuid=%s', role_uuid)
            return {'error': 'Role not found'}

        role_scope = Q(role__uuid=role_uuid) | Q(role__isnull=True)
        embedded_files = TrainingFile.objects.filter(
            organization=role.organization,
            status='embedded',
        ).filter(role_scope)
        files = list(embedded_files.values('file_name', 'description', 'file_type')[:20])
        logger.info('MCP list_training_files completed: role_uuid=%s count=%s', role_uuid, len(files))
        return {'files': files, 'count': len(files)}

    @mcp_tool(
        name='random_int',
        description='Generate a random integer in an inclusive range.',
        input_schema={
            'type': 'object',
            'properties': {
                'min': {'type': 'integer'},
                'max': {'type': 'integer'},
            },
            'required': ['min', 'max'],
        },
    )
    async def _random_int(self, args):
        """Draw a random integer between ``min`` and ``max``, both inclusive.

        The bounds are coerced with ``int``; if they arrive reversed they are
        swapped. Returns an ``error`` dict when either cannot be coerced.
        """
        try:
            bounds = (int(args.get('min')), int(args.get('max')))
        except (TypeError, ValueError):
            logger.warning('MCP random_int invalid args: %s', args)
            return {'error': 'min and max must be integers'}

        # Normalise the order so a reversed range is still accepted.
        low, high = min(bounds), max(bounds)

        value = random.randint(low, high)
        logger.info('MCP random_int generated value=%s range=[%s,%s]', value, low, high)
        return {'value': value, 'min': low, 'max': high}

    # Must stay at the end of the class body: locals() has to already contain
    # every decorated method for it to be collected.
    tools = _collect_tools(locals())
    _tool_name_to_method = {entry['name']: entry['method'] for entry in tools}
|
|
|
|
|
|
# Module-level MCPRouter instance, created when this module is imported.
mcp_router = MCPRouter()
|