116 lines
4.6 KiB
Python
116 lines
4.6 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
from typing import Any, Dict, List, Optional
|
|
from django.conf import settings
|
|
from mcp_agent.mcp_client import MCPClient
|
|
from .models import AgentModel
|
|
|
|
logger = logging.getLogger(__name__)


# Resolve the directory holding the cached base models. Prefer the constant
# published by the MCP server module so both sides share a single source of truth.
try:
    from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR
except ImportError:
    # mcp_agent.mcp_server is not importable here: rebuild the path by hand as
    # <two levels above this file's directory>/model/base-model.
    project_root = os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))
        )
    )
    BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")
else:
    BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR

logger.info("Base model cache directory reference: %s", BASE_MODEL_CACHE)
|
|
|
|
async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Call ``tool`` on the MCP HTTP bridge via ``MCPClient`` and return its response.

    A fresh ``MCPClient`` is created against ``settings.MCP_AGENT_URL`` for
    this single call and is always closed afterwards.

    Args:
        tool: Name of the MCP tool to invoke (e.g. ``"fine_tune"``).
        arguments: Payload forwarded unchanged to ``MCPClient.send``.

    Returns:
        Whatever ``MCPClient.send`` returns. It is annotated as a dict; the exact
        shape is defined by the server-side tool.

    Raises:
        AttributeError: if ``settings.MCP_AGENT_URL`` is not defined.
        Exception: any error raised by ``MCPClient.send`` is logged and re-raised.
    """
    # Plain attribute access (same as getattr without a default): a missing
    # setting fails loudly with AttributeError.
    server_url = settings.MCP_AGENT_URL
    client = MCPClient(server_url)

    logger.info("MCP: Calling tool '%s' on %s", tool, server_url)
    # Lazy %-args: arguments/responses may be large, so they are only
    # formatted when DEBUG logging is actually enabled.
    logger.debug("MCP: Arguments for '%s': %s", tool, arguments)

    try:
        resp = await client.send(tool, arguments)
    except Exception as e:
        # Callers log the traceback themselves; here we just record the failure.
        # Some exceptions stringify to "", so fall back to the type name.
        logger.error("MCP: Tool '%s' failed with error: %s", tool, str(e) or type(e).__name__)
        raise
    finally:
        try:
            await client.close()
        except Exception:
            # A failure while closing must not mask the real outcome of send()
            # (its exception, or its successful response).
            logger.warning("MCP: Failed to close client after tool '%s'", tool, exc_info=True)

    logger.info("MCP: Tool '%s' completed successfully", tool)
    logger.debug("MCP: Response from '%s': %s", tool, resp)
    return resp
|
|
|
|
|
|
def fine_tune_model(
    base_model: str,
    training_files: List[str],
    hyperparams: Dict[str, Any],
    name: str,
    version: str,
) -> Dict[str, Any]:
    """Synchronously request a fine-tune run on the MCP server.

    Sends ``{base_model, training_files, hyperparams, name, version}`` to the
    MCP tool ``fine_tune``. The tool is expected to return a JSON-like dict
    containing at least ``status`` and, on success, ``model_path`` and ``version``.

    The async MCP call is driven with ``asyncio.run``, so this must be called
    from synchronous code with no event loop already running in the current
    thread. Otherwise ``asyncio.run`` raises ``RuntimeError``, which is reported
    as a failed response below.

    Failures are returned rather than raised.

    Returns:
        The MCP tool's response dict on success. On failure, a dict of the form
        ``{"status": "failed", "error": <message>, "error_type": <exception name>}``.
        This happens when the call raises or when the tool returns something
        other than a dict.
    """
    logger.info("Fine-tuning model: name=%s, version=%s, base_model=%s", name, version, base_model)
    logger.info("Training files count: %d", len(training_files))
    logger.debug("Training files: %s", training_files)

    payload = {
        "base_model": base_model,
        "training_files": training_files,
        "hyperparams": hyperparams,
        "name": name,
        "version": version,
    }

    logger.info("Calling MCP fine_tune tool...")
    # Keep the try body to the remote call only, so that a post-processing
    # problem is not misreported as a failure of the fine-tune call itself.
    try:
        result = asyncio.run(_call_mcp("fine_tune", payload))
    except Exception as e:
        # str() of some exceptions (e.g. timeouts) is empty; keep the message useful.
        error_msg = str(e) or f"Unknown error: {type(e).__name__}"
        logger.exception("Fine-tune failed: %s", error_msg)
        # Return a failed response instead of raising
        return {
            "status": "failed",
            "error": error_msg,
            "error_type": type(e).__name__,
        }

    if not isinstance(result, dict):
        # The documented contract is a dict with a `status` key; anything else
        # previously crashed on `.get` and surfaced as a confusing AttributeError.
        error_msg = f"Unexpected fine_tune response type: {type(result).__name__}"
        logger.error("Fine-tune failed: %s", error_msg)
        return {
            "status": "failed",
            "error": error_msg,
            "error_type": "TypeError",
        }

    logger.info("Fine-tune completed: status=%s", result.get("status"))
    logger.debug("Fine-tune result: %s", result)
    return result
|
|
|
|
|
|
def load_model_for_inference(model_path: str) -> Dict[str, Any]:
    """Ask the MCP server to load the model at ``model_path`` for inference.

    Invokes the MCP tool ``load_model`` with ``{"model_path": model_path}``
    and returns its status response unchanged.

    Raises:
        Exception: anything raised by the MCP call is logged together with its
            traceback and then re-raised to the caller.
    """
    logger.info(f"Loading model for inference: {model_path}")
    request = {"model_path": model_path}

    try:
        response = asyncio.run(_call_mcp("load_model", request))
    except Exception as err:
        logger.error(f"Failed to load model: {err}", exc_info=True)
        raise

    logger.info("Model loaded successfully")
    return response
|
|
|
|
|
|
def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Run inference through the MCP server with a previously fine-tuned model.

    Calls the MCP tool ``infer`` with ``{model_path, prompt, options}``.
    If ``options`` is ``None`` or empty, an empty dict is sent.

    Returns:
        The MCP tool's response, unchanged.

    Raises:
        Exception: anything raised by the MCP call is logged together with its
            traceback and then re-raised to the caller.
    """
    logger.info(f"Running inference with model: {model_path}")
    logger.debug(f"Prompt length: {len(prompt)} characters")
    logger.debug(f"Inference options: {options}")

    request = {
        "model_path": model_path,
        "prompt": prompt,
        # The tool always receives a mapping, never None.
        "options": options if options else {},
    }

    try:
        result = asyncio.run(_call_mcp("infer", request))
    except Exception as exc:
        logger.error(f"Inference failed: {exc}", exc_info=True)
        raise

    logger.info("Inference completed successfully")
    key_summary = list(result.keys()) if isinstance(result, dict) else 'not a dict'
    logger.debug(f"Inference result keys: {key_summary}")
    return result
|
|
|
|
|
|
def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel:
    """Create an ``AgentModel`` row for a trained model and return it.

    ``model_path`` is stored in the model's ``path`` field.

    NOTE: migrations for the model field change must be applied before this
    is used in production.
    """
    record = AgentModel.objects.create(
        name=name,
        version=version,
        path=model_path,
    )
    return record
|