102 lines
No EOL
3.6 KiB
Python
102 lines
No EOL
3.6 KiB
Python
import httpx
|
|
from channels.db import database_sync_to_async
|
|
from django.conf import settings
|
|
from pgvector.django import CosineDistance
|
|
|
|
from apps.knowledge.models import RoleRagDocument
|
|
from apps.onboarding.models import OnboardingSession
|
|
|
|
|
|
class MCPRouter:
    """Define and dispatch the ``search_knowledge`` and ``update_progress`` tools.

    ``get_tool_definitions`` advertises each tool with a JSON ``inputSchema``;
    ``handle_tool_call`` routes a call by tool name to the matching handler.
    """

    def get_tool_definitions(self):
        """Return the tool descriptors (name, description, JSON input schema)."""
        return [
            {
                "name": "search_knowledge",
                "description": "Search the RAG database for role-specific training content.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "role_uuid": {"type": "string"},
                    },
                    "required": ["query", "role_uuid"],
                },
            },
            {
                "name": "update_progress",
                "description": "Update the user's score or current module in their session.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "session_uuid": {"type": "string"},
                        "score": {"type": "integer"},
                        "completed_module": {"type": "string"},
                    },
                    "required": ["session_uuid"],
                },
            },
        ]

    async def handle_tool_call(self, name, arguments):
        """Dispatch a tool call to its handler.

        Args:
            name: Tool name as advertised by ``get_tool_definitions``.
            arguments: Mapping of tool arguments; ``None`` is treated as empty.

        Returns:
            The handler's result, or ``{"error": ...}`` for an unknown tool.
        """
        # Callers may omit arguments entirely; both handlers call ``.get`` on it.
        arguments = arguments or {}
        if name == "search_knowledge":
            return await self._search_knowledge(arguments)
        if name == "update_progress":
            return await self._update_progress(arguments)
        return {"error": f"Tool {name} not found"}

    async def _get_embedding(self, text):
        """Return the embedding vector for *text* from the inference endpoint.

        The response body is read as ``{"data": [{"embedding": [...]}, ...]}``
        and the first embedding is returned.

        Raises:
            httpx.HTTPStatusError: if the endpoint replies with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.INFERENCE_EMBEDDINGS_ENDPOINT,
                json={"input": text},
            )
            # Surface HTTP failures as a clear error instead of a KeyError or
            # JSON decode error raised while parsing an error body.
            response.raise_for_status()
            return response.json()["data"][0]["embedding"]

    async def _search_knowledge(self, args):
        """Embed ``args["query"]`` and return the closest documents for the role.

        Returns an empty list when ``query`` or ``role_uuid`` is missing or empty.
        """
        query = args.get("query")
        role_uuid = args.get("role_uuid")
        if not query or not role_uuid:
            return []

        query_vector = await self._get_embedding(query)
        return await self._search_knowledge_documents(role_uuid, query_vector)

    @database_sync_to_async
    def _search_knowledge_documents(self, role_uuid, query_vector, limit=5):
        """Return up to *limit* active documents of the role, nearest first.

        Args:
            role_uuid: UUID of the role whose documents are searched.
            query_vector: Embedding of the search query.
            limit: Maximum number of documents to return (default 5).

        Returns:
            A list of dicts with ``content``, ``source`` (the metadata
            ``file_name`` or ``"Unknown Source"``) and ``relevance``
            (``1 - cosine distance``, rounded to 4 decimals).
        """
        docs = (
            RoleRagDocument.objects.filter(
                role__uuid=role_uuid,
                is_active=True,
                # A NULL embedding yields a NULL distance, which would make the
                # ``1 - distance`` below raise TypeError.
                embedding__isnull=False,
            )
            .annotate(distance=CosineDistance("embedding", query_vector))
            .order_by("distance")[:limit]
        )

        return [
            {
                "content": doc.content,
                # metadata may be unset on some rows; fall back to the default.
                "source": (doc.metadata or {}).get("file_name", "Unknown Source"),
                "relevance": round(1 - doc.distance, 4),
            }
            for doc in docs
        ]

    @database_sync_to_async
    def _update_progress(self, args):
        """Record a score and/or a completed module in a session's ``state``.

        Args:
            args: Mapping with ``session_uuid`` (required), optional ``score``
                (stored as ``state["last_score"]``) and optional
                ``completed_module`` (added once to ``state["completed_modules"]``).

        Returns:
            ``{"status": "success", "new_state": state}`` on success, or
            ``{"error": ...}`` when the session id is missing or unknown.
        """
        session_uuid = args.get("session_uuid")
        if not session_uuid:
            return {"error": "session_uuid is required"}
        try:
            session = OnboardingSession.objects.get(uuid=session_uuid)
        except OnboardingSession.DoesNotExist:
            return {"error": f"Session {session_uuid} not found"}

        state = session.state or {}

        # An explicit null score must not wipe out the previously stored one.
        score = args.get("score")
        if score is not None:
            state["last_score"] = score

        module = args.get("completed_module")
        if module:
            completed = state.setdefault("completed_modules", [])
            # Keep the list duplicate-free so a retried call is idempotent.
            if module not in completed:
                completed.append(module)

        # NOTE(review): this read-modify-write of ``state`` takes no row lock;
        # concurrent updates from other workers could overwrite each other.
        # Consider select_for_update() inside a transaction if that matters.
        session.state = state
        session.save()
        return {"status": "success", "new_state": state}