2026-01-20 17:21:28 +00:00
|
|
|
import asyncio
|
2026-01-25 17:29:37 +00:00
|
|
|
import logging
|
|
|
|
|
import os
|
2026-02-08 15:34:26 +00:00
|
|
|
import re
|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
2026-01-20 17:21:28 +00:00
|
|
|
from django.conf import settings
|
|
|
|
|
from mcp_agent.mcp_client import MCPClient
|
2026-02-08 15:34:26 +00:00
|
|
|
from .models import AgentModel, RoleRagDocument
|
2026-01-20 17:21:28 +00:00
|
|
|
|
2026-01-25 17:29:37 +00:00
|
|
|
logger = logging.getLogger(__name__)


# Resolve where base models are cached. When the MCP server module is
# importable, reuse its directory so both sides agree on the location;
# otherwise derive "<project_root>/model/base-model", where project_root is
# three directory levels above this file.
try:
    from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR

    BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR
except ImportError:
    project_root = os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))
        )
    )
    BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")

logger.info(f"Base model cache directory reference: {BASE_MODEL_CACHE}")
|
2026-01-20 17:21:28 +00:00
|
|
|
|
|
|
|
|
async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Invoke a single MCP tool through a short-lived ``MCPClient``.

    The server address is read from ``settings.MCP_AGENT_URL`` on each call.
    A fresh client is created for the request and closed on every exit path.

    Args:
        tool: Name of the MCP tool to invoke.
        arguments: Payload passed unchanged to ``MCPClient.send``.

    Returns:
        Whatever ``MCPClient.send`` returns (annotated as a dict).

    Raises:
        Any exception raised by ``MCPClient.send`` is logged and re-raised.
    """
    url = settings.MCP_AGENT_URL
    mcp = MCPClient(url)

    logger.info(f"MCP: Calling tool '{tool}' on {url}")
    logger.debug(f"MCP: Arguments for '{tool}': {arguments}")

    try:
        response = await mcp.send(tool, arguments)
        logger.info(f"MCP: Tool '{tool}' completed successfully")
        logger.debug(f"MCP: Response from '{tool}': {response}")
        return response
    except Exception as exc:
        logger.error(f"MCP: Tool '{tool}' failed with error: {str(exc)}")
        raise
    finally:
        # Release the client whether the call succeeded or not.
        await mcp.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fine_tune_model(
    base_model: str,
    training_files: List[str],
    hyperparams: Dict[str, Any],
    name: str,
    version: str,
) -> Dict[str, Any]:
    """Run the MCP ``fine_tune`` tool and block until it replies.

    The tool receives ``{base_model, training_files, hyperparams, name, version}``
    and is expected to return a dict containing at least ``status`` (plus
    ``model_path`` and ``version`` on success).

    This function never raises. Any exception is logged with its traceback
    and turned into ``{"status": "failed", "error": ..., "error_type": ...}``.
    ``error`` is the exception message, or ``"Unknown error: <ClassName>"``
    when the message is empty; ``error_type`` is the exception class name.

    Note: the call goes through ``asyncio.run``, which refuses to run inside an
    already-running event loop; that RuntimeError also ends up as a
    ``failed`` result.
    """
    logger.info(f"Fine-tuning model: name={name}, version={version}, base_model={base_model}")
    logger.info(f"Training files count: {len(training_files)}")
    logger.debug(f"Training files: {training_files}")

    payload = {
        "base_model": base_model,
        "training_files": training_files,
        "hyperparams": hyperparams,
        "name": name,
        "version": version,
    }

    try:
        logger.info("Calling MCP fine_tune tool...")
        outcome = asyncio.run(_call_mcp("fine_tune", payload))
        logger.info(f"Fine-tune completed: status={outcome.get('status')}")
        logger.debug(f"Fine-tune result: {outcome}")
        return outcome
    except Exception as exc:
        message = str(exc) or f"Unknown error: {type(exc).__name__}"
        logger.error(f"Fine-tune failed: {message}", exc_info=True)
        return {
            "status": "failed",
            "error": message,
            "error_type": type(exc).__name__,
        }
|
2026-01-20 17:21:28 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_for_inference(model_path: str) -> Dict[str, Any]:
    """Ask the MCP server to load ``model_path`` so it can serve inference.

    Sends ``{"model_path": model_path}`` to the MCP tool ``load_model`` and
    returns the tool's reply unchanged. Failures are logged with a traceback
    and re-raised.
    """
    logger.info(f"Loading model for inference: {model_path}")

    request = {"model_path": model_path}
    try:
        reply = asyncio.run(_call_mcp("load_model", request))
        logger.info("Model loaded successfully")
        return reply
    except Exception as exc:
        logger.error(f"Failed to load model: {str(exc)}", exc_info=True)
        raise
|
2026-01-20 17:21:28 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Run inference on the MCP server with a previously fine-tuned model.

    Calls the MCP tool ``infer`` with ``{model_path, prompt, options}``. A
    falsy ``options`` is sent as an empty dict. The tool's reply is returned
    unchanged; failures are logged with a traceback and re-raised.
    """
    logger.info(f"Running inference with model: {model_path}")
    logger.debug(f"Prompt length: {len(prompt)} characters")
    logger.debug(f"Inference options: {options}")

    request = {
        "model_path": model_path,
        "prompt": prompt,
        "options": options if options else {},
    }

    try:
        answer = asyncio.run(_call_mcp("infer", request))
        logger.info("Inference completed successfully")
        if isinstance(answer, dict):
            key_summary = list(answer.keys())
        else:
            key_summary = "not a dict"
        logger.debug(f"Inference result keys: {key_summary}")
        return answer
    except Exception as exc:
        logger.error(f"Inference failed: {str(exc)}", exc_info=True)
        raise
|
2026-01-20 17:21:28 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel:
    """Create and return a new ``AgentModel`` row.

    ``model_path`` is stored in the model's ``path`` field.

    NOTE: migrations are required after the model field change prior to using
    this in production.
    """
    record = AgentModel.objects.create(
        name=name,
        version=version,
        path=model_path,
    )
    return record
|
2026-02-08 15:34:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def embed_texts(texts: List[str]) -> List[List[float]]:
|
|
|
|
|
"""Generate embeddings for texts using the MCP embedding service.
|
|
|
|
|
|
|
|
|
|
Falls back to local sentence-transformers if MCP unavailable.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
texts: List of text strings to embed.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List of embedding vectors (list of floats).
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
RuntimeError: If both MCP and local embedding fail.
|
|
|
|
|
"""
|
|
|
|
|
logger.info(f"Embedding {len(texts)} texts")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
result = asyncio.run(_call_mcp("embed", {"texts": texts}))
|
|
|
|
|
embeddings = result.get("embeddings", [])
|
|
|
|
|
if embeddings and len(embeddings) == len(texts):
|
|
|
|
|
logger.info(f"Successfully embedded {len(texts)} texts via MCP")
|
|
|
|
|
return embeddings
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"MCP embedding failed, trying local fallback: {e}")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
|
|
|
embeddings = model.encode(texts).tolist()
|
|
|
|
|
logger.info(f"Successfully embedded {len(texts)} texts via local model")
|
|
|
|
|
return embeddings
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Local embedding also failed: {e}")
|
|
|
|
|
raise RuntimeError(f"Failed to embed texts: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def embed_text(text: str) -> List[float]:
    """Embed one string by delegating to ``embed_texts``.

    Args:
        text: The string to embed.

    Returns:
        The first (and only) vector ``embed_texts`` returns for ``[text]``.
    """
    vectors = embed_texts([text])
    return vectors[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def search_similar_documents(
    query: str,
    role_uuid: str,
    top_k: int = 5,
    similarity_threshold: float = 0.0,
) -> List[Tuple[RoleRagDocument, float]]:
    """Search for documents similar to the query using vector similarity.

    Ranking uses pgvector's cosine-distance operator ``<=>``. The reported
    score is ``1 - cosine_distance``.

    Args:
        query: Query text to embed and search for.
        role_uuid: UUID of role to scope search.
        top_k: Maximum number of results to return. Values <= 0 return ``[]``.
        similarity_threshold: Minimum similarity score to include a result.
            ``1 - cosine_distance`` lies in [-1, 1]; for typical text
            embeddings it falls in 0-1.

    Returns:
        List of (RoleRagDocument, similarity_score) tuples, ordered by
        similarity DESC, at most ``top_k`` long.

    Raises:
        ValueError: If the role is not found, or if embedding the query fails
            while the role has embedded documents to search.
    """
    from django.db import connection

    from apps.orgs.models import Role

    try:
        role = Role.objects.get(uuid=role_uuid)
    except Role.DoesNotExist:
        raise ValueError(f"Role with UUID {role_uuid} not found")

    if top_k <= 0:
        # LIMIT 0 can never match anything, and a negative LIMIT is a
        # PostgreSQL error.
        return []

    # Only documents that actually carry an embedding can be ranked.
    queryset = RoleRagDocument.objects.filter(role=role, embedding__isnull=False)
    if not queryset.exists():
        logger.warning(f"No documents with embeddings found for role {role.uuid}")
        return []

    # Embed only once we know there is something to search, so we don't waste
    # an MCP call or a local model load.
    try:
        query_embedding = embed_text(query)
        logger.info(f"Embedded query: '{query[:50]}...' to {len(query_embedding)}D vector")
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise ValueError(f"Failed to embed query: {e}") from e

    # pgvector parses the text form "[x1,x2,...]" when cast with ::vector.
    vector_literal = "[" + ",".join(str(float(x)) for x in query_embedding) + "]"

    # The table name comes from model metadata, not user input, so it is safe
    # to interpolate. Every value is passed as a query parameter.
    # Ordering by the distance expression (ASC == similarity DESC) lets a
    # pgvector index serve the query.
    table = connection.ops.quote_name(RoleRagDocument._meta.db_table)
    query_sql = f"""
        SELECT id, 1 - (embedding <=> %s::vector) AS similarity
        FROM {table}
        WHERE role_id = %s AND embedding IS NOT NULL
        ORDER BY embedding <=> %s::vector
        LIMIT %s
    """

    with connection.cursor() as cursor:
        cursor.execute(query_sql, [vector_literal, role.pk, vector_literal, top_k])
        doc_ids_with_scores = cursor.fetchall()

    if not doc_ids_with_scores:
        logger.info(f"No similar documents found for query in role {role.uuid}")
        return []

    # SQL already applied LIMIT top_k, so only the threshold is left to apply.
    doc_scores = {
        doc_id: score
        for doc_id, score in doc_ids_with_scores
        if score >= similarity_threshold
    }
    if not doc_scores:
        logger.info(f"No documents met similarity threshold {similarity_threshold}")
        return []

    documents = RoleRagDocument.objects.filter(id__in=list(doc_scores))
    results = [(doc, doc_scores[doc.id]) for doc in documents if doc.id in doc_scores]
    results.sort(key=lambda item: item[1], reverse=True)

    logger.info(
        f"Found {len(results)} similar documents for query "
        f"(threshold={similarity_threshold}, top_k={top_k})"
    )
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _has_embedding(value: Any) -> bool:
    """Return True when ``value`` holds a non-empty embedding.

    Avoids ``bool(value)``, which raises ``ValueError`` for multi-element
    NumPy arrays. Some vector field implementations return arrays; confirm
    against the field type in use.
    """
    if value is None:
        return False
    try:
        return len(value) > 0
    except TypeError:
        return bool(value)


def batch_embed_documents(
    documents: List[RoleRagDocument],
    batch_size: int = 32,
    force_reembed: bool = False,
) -> Tuple[int, int]:
    """Batch embed documents that don't have embeddings yet.

    Args:
        documents: List of RoleRagDocument instances to embed.
        batch_size: Number of documents to embed per ``embed_texts`` call.
            Must be >= 1.
        force_reembed: If True, re-embed documents that already have embeddings.

    Returns:
        Tuple of (num_embedded, num_failed). A document is counted as embedded
        only after its batch has been written with ``bulk_update``, so the two
        counts never overlap.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Note:
        Updates documents in-place with embedding values. If a batch's
        database write fails, those in-memory instances may already hold the
        new (unsaved) embeddings.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    to_embed = [
        doc for doc in documents
        if force_reembed or not _has_embedding(doc.embedding)
    ]
    if not to_embed:
        logger.info("No documents to embed")
        return 0, 0

    num_embedded = 0
    num_failed = 0

    for i in range(0, len(to_embed), batch_size):
        batch = to_embed[i : i + batch_size]
        logger.info(
            f"Embedding batch {i // batch_size + 1} "
            f"({len(batch)} documents)"
        )

        try:
            embeddings = embed_texts([doc.content for doc in batch])
            # zip() would silently truncate and leave some documents unembedded.
            if len(embeddings) != len(batch):
                raise RuntimeError(
                    f"Expected {len(batch)} embeddings, got {len(embeddings)}"
                )
            for doc, embedding in zip(batch, embeddings):
                doc.embedding = embedding
            RoleRagDocument.objects.bulk_update(batch, ["embedding"], batch_size=500)
        except Exception as e:
            logger.error(f"Failed to embed batch: {e}", exc_info=True)
            num_failed += len(batch)
            continue

        # Count only after the write succeeded, so a failing bulk_update is
        # not reported as both embedded and failed.
        num_embedded += len(batch)
        logger.info(f"Successfully embedded {len(batch)} documents")

    logger.info(
        f"Embedding complete: {num_embedded} embedded, {num_failed} failed"
    )
    return num_embedded, num_failed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_context_for_query(
    query: str,
    role_uuid: str,
    top_k: int = 5,
    similarity_threshold: float = 0.5,
) -> str:
    """Get context string from similar documents for a query.

    Useful for augmenting prompts with retrieved context. Retrieval errors
    are logged as warnings and yield an empty string instead of propagating.

    Args:
        query: Query text.
        role_uuid: UUID of role to search within.
        top_k: Number of top results to include.
        similarity_threshold: Minimum similarity score.

    Returns:
        Formatted context string with source attribution. Each chunk is headed
        by ``[Source: <training file name or 'unknown'>, Similarity: NN.NN%]``,
        and chunks are separated by a ``---`` line. Returns an empty string when
        nothing usable is found.
    """
    # Whole-word match. Plain substring checks also dropped ordinary lines such
    # as "resources ..." or "user preferences ...", because those words contain
    # "sources" / "references".
    citation_line = re.compile(r"\b(?:references|sources|wikipedia)\b")
    answer_tag = re.compile(r"\[\s*Answer\s*:.*?\]", re.IGNORECASE | re.DOTALL)
    paragraph_break = re.compile(r"\n\s*\n+")

    def _clean_chunk_text(text: str) -> str:
        """Strip junk and deduplicate paragraphs to keep context lean."""
        if not text:
            return ""

        # Remove inline "[Answer: ...]" annotations; they may span lines.
        text = answer_tag.sub("", text)

        lines: List[str] = []
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                # Keep blank lines: they delimit paragraphs below.
                lines.append("")
                continue
            lower = line.lower()
            if line.startswith("#"):
                continue
            if "do you have any questions" in lower or "feel free to ask" in lower:
                continue
            if citation_line.search(lower):
                continue
            lines.append(line)

        paragraphs = [
            p.strip() for p in paragraph_break.split("\n".join(lines)) if p.strip()
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return "\n\n".join(dict.fromkeys(paragraphs))

    try:
        results = search_similar_documents(
            query=query,
            role_uuid=role_uuid,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
        )
    except Exception as e:
        logger.warning(f"Failed to retrieve context: {e}")
        return ""

    context_parts = []
    for doc, similarity in results:
        cleaned = _clean_chunk_text(doc.content)
        if not cleaned:
            continue
        source = doc.training_file.file_name if doc.training_file else "unknown"
        context_parts.append(
            f"[Source: {source}, Similarity: {similarity:.2%}]\n{cleaned}\n"
        )

    # Joining an empty list yields "", covering the no-results case.
    return "\n---\n".join(context_parts)
|