Dynavera/mcp_agent/mcp_server.py

333 lines
11 KiB
Python
Raw Normal View History

import asyncio
import json
import os
import sys
from datetime import datetime
from pathlib import PureWindowsPath
from typing import Any, Dict, List
from aiohttp import web
from mcp.server import Server
from mcp.types import Tool, TextContent
# MCP server instance; the decorated handlers in this module register on it.
app = Server("mlstore-mcp-server")
# In-process cache of loaded models, keyed by the resolved model path.
# Each entry stores "loaded_at", "model", "model_dir" and "model_file".
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
@app.list_tools()
async def list_tools():
    """Advertise the tools this server exposes.

    Returns a list with one ``Tool`` descriptor each for ``echo``,
    ``fine_tune``, ``load_model`` and ``infer``. Every descriptor carries a
    JSON-Schema object describing the arguments the tool accepts and which
    of them are required.
    """

    def object_schema(properties, required):
        # All tools take a JSON object; only the fields and required keys vary.
        return {"type": "object", "properties": properties, "required": required}

    echo_tool = Tool(
        name="echo",
        description="Echo back the provided input",
        inputSchema=object_schema(
            {"message": {"type": "string"}},
            ["message"],
        ),
    )

    fine_tune_tool = Tool(
        name="fine_tune",
        description="Start fine-tuning a base model using training files",
        inputSchema=object_schema(
            {
                "base_model": {"type": "string"},
                "training_files": {"type": "array", "items": {"type": "string"}},
                "hyperparams": {"type": "object"},
                "name": {"type": "string"},
                "version": {"type": "string"},
            },
            ["base_model", "training_files", "name", "version"],
        ),
    )

    load_model_tool = Tool(
        name="load_model",
        description="Load a fine-tuned model into memory for inference",
        inputSchema=object_schema(
            {"model_path": {"type": "string"}},
            ["model_path"],
        ),
    )

    infer_tool = Tool(
        name="infer",
        description="Run inference with a fine-tuned model",
        inputSchema=object_schema(
            {
                "model_path": {"type": "string"},
                "prompt": {"type": "string"},
                "options": {"type": "object"},
            },
            ["model_path", "prompt"],
        ),
    )

    return [echo_tool, fine_tune_tool, load_model_tool, infer_tool]
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix.

    The format matches what ``datetime.utcnow().isoformat() + "Z"`` produced
    (for example ``2024-05-01T12:00:00.123456Z``), but the value is derived
    from a timezone-aware clock because ``datetime.utcnow()`` is deprecated
    since Python 3.12.
    """
    # Local import: the module header only imports ``datetime`` itself.
    from datetime import timezone

    # Drop tzinfo before formatting so isoformat() does not append "+00:00";
    # the explicit "Z" below is the UTC marker callers already receive.
    utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
    return utc_now.isoformat() + "Z"
def _model_root() -> str:
    """Return the directory under which models are stored.

    The first non-empty value among the ``MCP_MODEL_DIR`` and
    ``DJANGO_MODEL_DIR`` environment variables wins; when neither is set
    (or both are empty) the ``model`` folder inside the current working
    directory is used.
    """
    for env_var in ("MCP_MODEL_DIR", "DJANGO_MODEL_DIR"):
        configured = os.getenv(env_var)
        if configured:
            return configured
    return os.path.join(os.getcwd(), "model")
def _safe_dir_name(name: str) -> str:
    """Reduce *name* to characters that are safe in a directory name.

    Only alphanumeric characters plus ``-``, ``_`` and ``.`` are kept, and
    any leading or trailing dots are then removed. The result may be empty.
    """
    allowed_punctuation = ("-", "_", ".")
    kept_chars = [
        ch for ch in name
        if ch.isalnum() or ch in allowed_punctuation
    ]
    sanitized = "".join(kept_chars)
    return sanitized.strip(".")
def _resolve_model_path(model_path: str) -> str:
    """Map a caller-supplied model path onto a path that exists locally.

    A falsy *model_path* is returned unchanged. An absolute path that
    already exists is returned in normalized form. Otherwise these
    candidates are probed in order and the first existing one is returned:

    1. the path taken relative to the current working directory;
    2. the path's base name inside the model root directory;
    3. for Windows-looking input (contains ``:`` or a backslash), the tail
       of the path starting at a ``notebooks`` or ``model`` component,
       taken relative to the current working directory.

    When nothing exists, the normalized input is returned as-is.
    """
    if not model_path:
        return model_path

    normalized = os.path.normpath(model_path)
    if os.path.isabs(normalized) and os.path.exists(normalized):
        return normalized

    search_order = [
        # Relative to the current working directory.
        os.path.normpath(os.path.join(os.getcwd(), normalized)),
        # Same file name, but inside the model root.
        os.path.normpath(
            os.path.join(_model_root(), os.path.basename(normalized))
        ),
    ]

    looks_like_windows = ":" in model_path or "\\" in model_path
    if looks_like_windows:
        # Re-root a Windows-style path at a well-known folder name so it can
        # be found below the working directory.
        segments = [str(part) for part in PureWindowsPath(model_path).parts]
        for anchor in ("notebooks", "model"):
            if anchor not in segments:
                continue
            tail = os.path.join(*segments[segments.index(anchor):])
            search_order.append(
                os.path.normpath(os.path.join(os.getcwd(), tail))
            )

    return next(
        (candidate for candidate in search_order if os.path.exists(candidate)),
        normalized,
    )
def _resolve_model_file(model_path: str) -> tuple[str, str]:
    """Return (model_dir, model_filename) for GPT4All.

    *model_path* is first resolved with ``_resolve_model_path``. If the
    result is a directory, the alphabetically first regular file whose name
    ends in ``.gguf`` (case-insensitive) is selected; when the directory
    holds no such file the filename is ``""``. If the result is not a
    directory, it is split into its parent directory and base name.
    """
    resolved = _resolve_model_path(model_path)
    if not os.path.isdir(resolved):
        return os.path.dirname(resolved), os.path.basename(resolved)

    # os.listdir() returns entries in arbitrary order; sort so the same file
    # is chosen on every call when a directory contains several .gguf files.
    for entry in sorted(os.listdir(resolved)):
        if not entry.lower().endswith(".gguf"):
            continue
        # Skip sub-directories that merely happen to be named "*.gguf".
        if os.path.isfile(os.path.join(resolved, entry)):
            return resolved, entry
    return resolved, ""
def _tool_failure(error: str, **details: Any) -> Dict[str, Any]:
    """Build a ``status: failed`` response with *error*, *details* and a timestamp."""
    return {"status": "failed", "error": error, **details, "timestamp": _now()}


def _cache_gpt4all_model(model_path: str, device: str = "") -> bool:
    """Load the model at *model_path* with GPT4All and store it in ``LOADED_MODELS``.

    Args:
        model_path: Already-resolved path to a model file or a directory
            containing one.
        device: Forwarded to ``GPT4All(device=...)`` when non-empty; when
            empty the library's own default device is used.

    Returns:
        True when the model was loaded and cached, False when no model file
        could be located for *model_path*.

    Raises:
        Whatever importing or constructing ``GPT4All`` raises; callers turn
        that into a failed tool response.
    """
    # Imported lazily so a missing gpt4all package surfaces as a failed tool
    # response from the caller's ``except`` rather than an import-time crash.
    from gpt4all import GPT4All

    model_dir, model_file = _resolve_model_file(model_path)
    if not model_file:
        return False

    load_kwargs: Dict[str, Any] = {"model_path": model_dir, "allow_download": False}
    if device:
        load_kwargs["device"] = device
    model = GPT4All(model_file, **load_kwargs)
    LOADED_MODELS[model_path] = {
        "loaded_at": _now(),
        "model": model,
        "model_dir": model_dir,
        "model_file": model_file,
    }
    return True


def _run_fine_tune(arguments: dict) -> Dict[str, Any]:
    """Record a fine-tune request and return its metadata.

    Creates ``<model_root>/<name>-<version>`` (both parts sanitized with
    ``_safe_dir_name``) and writes ``metadata.json`` into it. No training is
    performed here; the returned metadata reports ``status: completed``.
    """
    base_model = arguments.get("base_model")
    training_files = arguments.get("training_files") or []
    hyperparams = arguments.get("hyperparams") or {}
    model_name = arguments.get("name") or "model"
    version = arguments.get("version") or "v1"

    model_root = _model_root()
    os.makedirs(model_root, exist_ok=True)
    # A name made only of stripped characters (e.g. "///") would otherwise
    # yield a directory such as "-v1"; fall back to the same defaults used
    # for missing values.
    safe_name = _safe_dir_name(model_name) or "model"
    safe_version = _safe_dir_name(version) or "v1"
    output_dir = os.path.join(model_root, f"{safe_name}-{safe_version}")
    os.makedirs(output_dir, exist_ok=True)

    metadata = {
        "status": "completed",
        "base_model": base_model,
        "training_files": training_files,
        "hyperparams": hyperparams,
        "name": model_name,
        "version": version,
        "model_path": output_dir,
        "timestamp": _now(),
    }
    metadata_file = os.path.join(output_dir, "metadata.json")
    try:
        with open(metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
    except Exception as exc:
        # Persisting metadata is best-effort and must not fail the request,
        # but the failure should at least be visible in the server log.
        print(f"Could not write {metadata_file}: {exc}", file=sys.stderr)
    return metadata


def _run_load_model(arguments: dict) -> Dict[str, Any]:
    """Load a model into ``LOADED_MODELS`` and report where it was found."""
    model_path = arguments.get("model_path")
    if not model_path:
        return _tool_failure("model_path_required")
    model_path = _resolve_model_path(model_path)
    if not os.path.exists(model_path):
        return _tool_failure("model_not_found", model_path=model_path)

    try:
        # NOTE(review): this path requests device="gpu" while the lazy load in
        # "infer" uses the library default; kept as-is, confirm it is intended.
        if not _cache_gpt4all_model(model_path, device="gpu"):
            return _tool_failure("model_file_not_found", model_path=model_path)
        entry = LOADED_MODELS[model_path]
        return {
            "status": "completed",
            "model_path": model_path,
            "loaded": True,
            "model_dir": entry["model_dir"],
            "model_file": entry["model_file"],
            "timestamp": _now(),
        }
    except Exception as e:
        return _tool_failure(
            str(e), error_type=type(e).__name__, model_path=model_path
        )


def _run_infer(arguments: dict) -> Dict[str, Any]:
    """Generate text for ``prompt`` with the model at ``model_path``.

    The model is loaded on first use and cached. Sampling options are read
    from ``options`` (``max_tokens``, ``temperature`` or ``temp``, ``top_p``,
    ``top_k``) with defaults 256, 0.7, 0.95 and 40.
    """
    model_path = arguments.get("model_path")
    prompt = arguments.get("prompt") or ""
    options = arguments.get("options") or {}
    if not model_path:
        return _tool_failure("model_path_required")
    model_path = _resolve_model_path(model_path)
    if not os.path.exists(model_path):
        return _tool_failure("model_not_found", model_path=model_path)

    try:
        if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
            if not _cache_gpt4all_model(model_path):
                return _tool_failure("model_file_not_found", model_path=model_path)
        model = LOADED_MODELS[model_path]["model"]

        max_tokens = int(options.get("max_tokens", 256))
        temp = float(options.get("temperature", options.get("temp", 0.7)))
        top_p = float(options.get("top_p", 0.95))
        top_k = int(options.get("top_k", 40))
        # NOTE(review): generate() is a synchronous call made from a
        # coroutine, so the event loop does not service other requests while
        # it runs.
        response_text = model.generate(
            prompt,
            max_tokens=max_tokens,
            temp=temp,
            top_p=top_p,
            top_k=top_k,
        )
        return {
            "status": "completed",
            "model_path": model_path,
            "response": response_text,
            "options": {
                "max_tokens": max_tokens,
                "temperature": temp,
                "top_p": top_p,
                "top_k": top_k,
            },
            "timestamp": _now(),
        }
    except Exception as e:
        return _tool_failure(
            str(e), error_type=type(e).__name__, model_path=model_path
        )


async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
    """Execute the tool called *name* with *arguments* and return its result.

    Supported tools are ``echo``, ``fine_tune``, ``load_model`` and
    ``infer``. Tool-level problems (missing path, model not found, load or
    generation errors) are reported as a dict with ``status: failed``.

    Raises:
        ValueError: if *name* is not one of the supported tools.
    """
    if name == "echo":
        return {"status": "ok", "received": arguments, "timestamp": _now()}
    if name == "fine_tune":
        return _run_fine_tune(arguments)
    if name == "load_model":
        return _run_load_model(arguments)
    if name == "infer":
        return _run_infer(arguments)
    raise ValueError(f"Unknown tool: {name}")
@app.call_tool()
async def call_tool(name: str, arguments: dict):
    """MCP entry point: run the named tool and wrap its result as text.

    The tool result dict is serialized to indented JSON and returned as a
    single ``TextContent`` item.
    """
    outcome = await _run_tool_http(name, arguments)
    serialized = json.dumps(outcome, indent=2)
    return [TextContent(type="text", text=serialized)]
async def handle_execute(request: web.Request) -> web.Response:
    """Handle ``POST /execute``: run one tool described by a JSON body.

    The body must be a JSON object with a ``tool`` name and an optional
    ``arguments`` object (missing or ``null`` means no arguments).

    Responses:
        200 with the tool's result dict on success (tool-level failures are
        reported inside that dict);
        400 when the body is not valid JSON, is not an object, lacks
        ``tool``, or carries non-object ``arguments``;
        500 with the error text for any other exception.
    """
    try:
        payload = await request.json()
        # Valid JSON is not necessarily an object (e.g. a list or a string);
        # reject it up front instead of failing later on ``.get``.
        if not isinstance(payload, dict):
            return web.json_response(
                {"error": "Request body must be a JSON object"}, status=400
            )
        tool = payload.get("tool")
        if not tool:
            return web.json_response(
                {"error": "Missing 'tool' field"}, status=400
            )
        arguments = payload.get("arguments")
        if arguments is None:
            # An explicit null is treated the same as an absent field.
            arguments = {}
        elif not isinstance(arguments, dict):
            return web.json_response(
                {"error": "'arguments' must be a JSON object"}, status=400
            )
        result = await _run_tool_http(tool, arguments)
        return web.json_response(result)
    except json.JSONDecodeError:
        return web.json_response({"error": "Invalid JSON"}, status=400)
    except Exception as e:
        return web.json_response({"error": str(e)}, status=500)
async def handle_health(request: web.Request) -> web.Response:
    """Handle ``GET /health`` by answering ``{"status": "healthy"}``."""
    health_report = {"status": "healthy"}
    return web.json_response(health_report)
async def run_http_server():
    """Start the HTTP facade and serve until the task is cancelled.

    Binds to ``MCP_HTTP_HOST`` (default ``0.0.0.0``) and ``MCP_HTTP_PORT``
    (default ``8001``), exposing ``POST /execute`` and ``GET /health``. The
    runner is cleaned up when the coroutine exits, including on
    cancellation, so the listening socket is released.
    """
    host = os.getenv("MCP_HTTP_HOST", "0.0.0.0")
    port = int(os.getenv("MCP_HTTP_PORT", "8001"))

    app_http = web.Application()
    app_http.router.add_post("/execute", handle_execute)
    app_http.router.add_get("/health", handle_health)

    runner = web.AppRunner(app_http)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
        print(f"HTTP server running on {host}:{port}", file=sys.stderr)
        # Nothing ever sets this event: block here until the task is cancelled.
        await asyncio.Event().wait()
    finally:
        # Runs on cancellation and on startup errors (e.g. port in use).
        await runner.cleanup()
# Script entry point: when executed directly, run the HTTP server until the
# process is stopped.
if __name__ == "__main__":
    asyncio.run(run_http_server())