Added agent models, viewsets, consumers, gpu and mcp

This commit is contained in:
Viswamedha Nalabotu 2026-01-20 17:21:28 +00:00
parent 6039d6b2ac
commit fcd4862e18
22 changed files with 1463 additions and 26 deletions

193
apps/mlstore/consumers.py Normal file
View file

@ -0,0 +1,193 @@
import json
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
from django.utils import timezone
from .models import Agent, AgentRun, AgentEvent
from .tasks import start_fine_tune_run_task, infer_run_task
class MLStoreConsumer(AsyncWebsocketConsumer):
async def connect(self):
self.user = self.scope["user"]
self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
self.room_group_name = f"mlstore_agent_{self.agent_id}"
if not self.user.is_authenticated:
await self.close()
return
agent = await self.get_agent(self.agent_id)
if not agent:
await self.close()
return
await self.channel_layer.group_add(self.room_group_name, self.channel_name)
await self.accept()
await self.send(json.dumps({
"type": "connection",
"message": "Connected to mlstore agent stream",
"agent_id": str(self.agent_id)
}))
async def disconnect(self, close_code):
await self.channel_layer.group_discard(self.room_group_name, self.channel_name)
async def receive(self, text_data):
try:
data = json.loads(text_data)
action = data.get("action")
if action == "fine_tune":
await self.handle_fine_tune(data)
elif action == "infer":
await self.handle_infer(data)
elif action in ("stop_agent", "stop"):
await self.handle_stop(data)
else:
await self.send(json.dumps({
"type": "error",
"message": f"Unknown action: {action}"
}))
except json.JSONDecodeError:
await self.send(json.dumps({
"type": "error",
"message": "Invalid JSON"
}))
except Exception as e:
await self.send(json.dumps({
"type": "error",
"message": str(e)
}))
async def handle_fine_tune(self, data):
agent = await self.get_agent(self.agent_id)
if not agent:
await self.send(json.dumps({
"type": "error",
"message": "Agent not found"
}))
return
input_data = data.get("input_data") or {}
execution = await self.create_run(agent, self.user, input_data)
await self.send(json.dumps({
"type": "execution_started",
"execution_id": str(execution.uuid),
"agent_id": str(agent.uuid),
"message": f"Fine-tune run {execution.uuid} queued"
}))
start_fine_tune_run_task.delay(str(execution.uuid))
async def handle_infer(self, data):
agent = await self.get_agent(self.agent_id)
if not agent:
await self.send(json.dumps({
"type": "error",
"message": "Agent not found"
}))
return
input_data = data.get("input_data") or {}
execution = await self.create_run(agent, self.user, input_data)
await self.send(json.dumps({
"type": "execution_started",
"execution_id": str(execution.uuid),
"agent_id": str(agent.uuid),
"message": f"Inference run {execution.uuid} queued"
}))
infer_run_task.delay(str(execution.uuid))
async def handle_stop(self, data):
execution_id = data.get("execution_id")
if not execution_id:
await self.send(json.dumps({
"type": "error",
"message": "execution_id required to stop"
}))
return
execution = await self.get_execution(execution_id)
if not execution:
await self.send(json.dumps({
"type": "error",
"message": "Execution not found"
}))
return
await self.update_execution_status(execution, "failed")
await self.send(json.dumps({
"type": "execution_error",
"execution_id": str(execution.uuid),
"error_message": "Execution stopped by user"
}))
async def mlstore_event(self, event):
await self.send(json.dumps({
"type": "mlstore_event",
"event_type": event["event_type"],
"content": event["content"],
"timestamp": event["timestamp"]
}))
async def mlstore_completed(self, event):
await self.send(json.dumps({
"type": "execution_completed",
"execution_id": event["execution_id"],
"output_data": event["output_data"],
"message": "Execution completed"
}))
async def mlstore_error(self, event):
await self.send(json.dumps({
"type": "execution_error",
"execution_id": event["execution_id"],
"error_message": event["error_message"]
}))
@database_sync_to_async
def get_agent(self, agent_id):
try:
return Agent.objects.get(uuid=agent_id)
except Agent.DoesNotExist:
return None
@database_sync_to_async
def create_run(self, agent, user, input_data):
return AgentRun.objects.create(
agent=agent,
user=user,
input_data=input_data,
)
@database_sync_to_async
def get_execution(self, execution_id):
try:
return AgentRun.objects.get(uuid=execution_id)
except AgentRun.DoesNotExist:
return None
@database_sync_to_async
def update_execution_status(self, execution, status):
execution.status = status
execution.completed_at = timezone.now()
execution.save()
try:
agent = execution.agent
agent.status = status
agent.completed_at = timezone.now()
agent.save()
except Exception:
pass
return execution
@database_sync_to_async
def create_event(self, execution, event_type, content):
return AgentEvent.objects.create(
execution=execution,
event_type=event_type,
content=content,
)

View file

@ -3,7 +3,6 @@ import uuid
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
initial = True initial = True

View file

@ -0,0 +1,15 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('mlstore', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='agentmodel',
name='path',
field=models.CharField(blank=True, default='', max_length=1024),
),
]

View file

@ -9,6 +9,7 @@ class AgentModel(Model):
uuid = UUIDField(default = uuid4, unique = True, editable = False) uuid = UUIDField(default = uuid4, unique = True, editable = False)
name = CharField(max_length = 255) name = CharField(max_length = 255)
version = CharField(max_length = 50) version = CharField(max_length = 50)
path = CharField(max_length=1024, blank=True, default='')
class Meta: class Meta:
verbose_name = 'Model' verbose_name = 'Model'

6
apps/mlstore/routing.py Normal file
View file

@ -0,0 +1,6 @@
from django.urls import path
from . import consumers
websocket_urlpatterns = [
path("ws/mlstore/agents/<str:agent_id>/", consumers.MLStoreConsumer.as_asgi()),
]

View file

@ -0,0 +1,51 @@
from rest_framework.serializers import ModelSerializer
from .models import AgentModel, Agent, AgentRun, AgentEvent
class AgentModelSerializer(ModelSerializer):
class Meta:
model = AgentModel
fields = ['id', 'uuid', 'name', 'version', 'path']
read_only_fields = ['id', 'uuid']
class AgentSerializer(ModelSerializer):
model = AgentModelSerializer(read_only=True)
class Meta:
model = Agent
fields = [
'id',
'uuid',
'model',
'status',
'description',
'started_at',
'completed_at',
]
read_only_fields = ['id', 'uuid', 'started_at', 'completed_at']
class AgentRunSerializer(ModelSerializer):
class Meta:
model = AgentRun
fields = [
'id',
'uuid',
'agent',
'user',
'status',
'input_data',
'output_data',
'error_message',
'started_at',
'completed_at',
]
read_only_fields = ['id', 'uuid', 'started_at', 'completed_at']
class AgentEventSerializer(ModelSerializer):
class Meta:
model = AgentEvent
fields = ['id', 'uuid', 'execution', 'event_type', 'content', 'timestamp']
read_only_fields = ['id', 'uuid', 'timestamp']

61
apps/mlstore/services.py Normal file
View file

@ -0,0 +1,61 @@
import asyncio
from typing import Any, Dict, List, Optional
from django.conf import settings
from mcp_agent.mcp_client import MCPClient
from .models import AgentModel
async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Internal async helper to call the MCP HTTP bridge via MCPClient."""
server_url = getattr(settings, "MCP_AGENT_URL")
client = MCPClient(server_url)
try:
resp = await client.send(tool, arguments)
return resp
finally:
await client.close()
def fine_tune_model(
base_model: str,
training_files: List[str],
hyperparams: Dict[str, Any],
name: str,
version: str,
) -> Dict[str, Any]:
"""Synchronously request a fine-tune run on the MCP server.
Expects the MCP tool `fine_tune` to accept: {base_model, training_files, hyperparams, name, version}
and to return a JSON-like dict containing at least `status` and on success `model_path` and `version`.
"""
return asyncio.run(_call_mcp("fine_tune", {
"base_model": base_model,
"training_files": training_files,
"hyperparams": hyperparams,
"name": name,
"version": version,
}))
def load_model_for_inference(model_path: str) -> Dict[str, Any]:
"""Tell the MCP server to load a model into memory/serving for inference.
Expects the MCP tool `load_model` with {model_path} returning status info.
"""
return asyncio.run(_call_mcp("load_model", {"model_path": model_path}))
def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Request inference from the MCP server using a previously fine-tuned model.
Calls the MCP tool `infer` with {model_path, prompt, options}.
"""
return asyncio.run(_call_mcp("infer", {"model_path": model_path, "prompt": prompt, "options": options or {}}))
def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel:
"""Convenience DB helper: create and return an AgentModel record.
NOTE: migrations are required after the model field change prior to using this in production.
"""
return AgentModel.objects.create(name=name, version=version, path=model_path)

267
apps/mlstore/tasks.py Normal file
View file

@ -0,0 +1,267 @@
from celery import shared_task
from django.utils import timezone
from channels.layers import get_channel_layer
from asgiref.sync import async_to_sync
from . import services
from .models import AgentModel, Agent, AgentRun, AgentEvent
import traceback
@shared_task
def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str):
"""Start a fine-tune via MCP, and register the resulting model on success.
This task calls `services.fine_tune_model`, expects a dict result with `status` and on success
`model_path` and optionally `version`.
"""
try:
result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
if isinstance(result, dict) and result.get("status") == "completed":
model_path = result.get("model_path") or result.get("path") or ""
model_version = result.get("version") or version
m = AgentModel.objects.create(name=name, version=model_version, path=model_path)
return {"status": "ok", "model_id": m.id, "model_uuid": str(m.uuid), "model_path": model_path, "result": result}
return {"status": "failed", "result": result}
except Exception as e:
traceback.print_exc()
return {"status": "error", "error": str(e)}
@shared_task
def infer_with_model_task(model_id: int, prompt: str, options: dict = None):
"""Run inference by requesting the MCP server to use the stored model.
Looks up the `AgentModel` by `model_id`, calls `services.infer_with_model`, and returns the response.
"""
try:
model = AgentModel.objects.get(id=model_id)
except AgentModel.DoesNotExist:
return {"status": "error", "error": "model_not_found", "model_id": model_id}
try:
services.load_model_for_inference(model.path)
except Exception:
pass
try:
out = services.infer_with_model(model.path, prompt, options or {})
return {"status": "completed", "model_id": model_id, "response": out}
except Exception as e:
traceback.print_exc()
return {"status": "failed", "error": str(e)}
def _send_group_event(room_group_name: str, event_type: str, content: dict):
channel_layer = get_channel_layer()
async_to_sync(channel_layer.group_send)(
room_group_name,
{
"type": "mlstore_event",
"event_type": event_type,
"content": content,
"timestamp": timezone.now().isoformat(),
}
)
def _persist_event(execution: AgentRun, event_type: str, content: dict):
AgentEvent.objects.create(
execution=execution,
event_type=event_type,
content=content,
)
def _update_agent_status(agent: Agent, status: str):
agent.status = status
if status == "running":
agent.started_at = timezone.now()
elif status in ("completed", "failed"):
agent.completed_at = timezone.now()
agent.save()
@shared_task
def start_fine_tune_run_task(execution_id: str):
try:
execution = AgentRun.objects.get(uuid=execution_id)
except AgentRun.DoesNotExist:
return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}
agent = execution.agent
room_group_name = f"mlstore_agent_{agent.uuid}"
execution.status = "running"
execution.started_at = timezone.now()
execution.save()
_update_agent_status(agent, "running")
input_data = execution.input_data or {}
base_model = input_data.get("base_model") or agent.model.name
training_files = input_data.get("training_files") or []
hyperparams = input_data.get("hyperparams") or {}
name = input_data.get("name") or f"{agent.model.name}-ft"
version = input_data.get("version") or "v1"
_send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
_persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
try:
result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
if isinstance(result, dict) and result.get("status") == "completed":
model_path = result.get("model_path") or result.get("path") or ""
model_version = result.get("version") or version
new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
agent.model = new_model
agent.save()
execution.status = "completed"
execution.output_data = {
"result": result,
"model_id": new_model.id,
"model_uuid": str(new_model.uuid),
}
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "completed")
_send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path})
_persist_event(execution, "completed", {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_completed",
"execution_id": str(execution.uuid),
"output_data": execution.output_data,
},
)
return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}
execution.status = "failed"
execution.error_message = str(result)
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "failed")
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": result})
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": result})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_error",
"execution_id": str(execution.uuid),
"error_message": str(result),
},
)
return {"status": "failed", "execution_id": execution_id, "result": result}
except Exception as e:
traceback.print_exc()
execution.status = "failed"
execution.error_message = str(e)
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "failed")
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": str(e)})
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": str(e)})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_error",
"execution_id": str(execution.uuid),
"error_message": str(e),
},
)
return {"status": "error", "execution_id": execution_id, "error": str(e)}
@shared_task
def infer_run_task(execution_id: str):
try:
execution = AgentRun.objects.get(uuid=execution_id)
except AgentRun.DoesNotExist:
return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}
agent = execution.agent
room_group_name = f"mlstore_agent_{agent.uuid}"
execution.status = "running"
execution.started_at = timezone.now()
execution.save()
_update_agent_status(agent, "running")
input_data = execution.input_data or {}
prompt = input_data.get("prompt") or input_data.get("query") or ""
options = input_data.get("options") or {}
if not prompt:
execution.status = "failed"
execution.error_message = "prompt_required"
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "failed")
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": "prompt_required"})
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": "prompt_required"})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_error",
"execution_id": str(execution.uuid),
"error_message": "prompt_required",
},
)
return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}
_send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "infer"})
_persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "infer"})
try:
try:
services.load_model_for_inference(agent.model.path)
except Exception:
pass
result = services.infer_with_model(agent.model.path, prompt, options)
execution.status = "completed"
execution.output_data = {"result": result}
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "completed")
_send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "result": result})
_persist_event(execution, "completed", {"execution_id": str(execution.uuid), "result": result})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_completed",
"execution_id": str(execution.uuid),
"output_data": execution.output_data,
},
)
return {"status": "completed", "execution_id": execution_id}
except Exception as e:
traceback.print_exc()
execution.status = "failed"
execution.error_message = str(e)
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "failed")
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": str(e)})
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": str(e)})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_error",
"execution_id": str(execution.uuid),
"error_message": str(e),
},
)
return {"status": "failed", "execution_id": execution_id, "error": str(e)}

29
apps/mlstore/viewsets.py Normal file
View file

@ -0,0 +1,29 @@
from rest_framework.viewsets import ModelViewSet
from rest_framework.permissions import IsAuthenticated
from .models import Agent, AgentRun, AgentEvent
from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer
from rest_framework.decorators import action
from rest_framework.response import Response
class AgentViewSet(ModelViewSet):
queryset = Agent.objects.all()
serializer_class = AgentSerializer
permission_classes = [IsAuthenticated]
lookup_field = 'uuid'
class AgentRunViewSet(ModelViewSet):
queryset = AgentRun.objects.all()
serializer_class = AgentRunSerializer
permission_classes = [IsAuthenticated]
lookup_field = 'uuid'
def get_queryset(self):
return AgentRun.objects.filter(user=self.request.user)
@action(detail=True, methods=['get'], url_path='events')
def events(self, request, uuid=None):
run = self.get_object()
events = AgentEvent.objects.filter(execution=run).order_by('timestamp')
serializer = AgentEventSerializer(events, many=True)
return Response(serializer.data)

View file

@ -91,6 +91,18 @@ services:
- ../../.env - ../../.env
volumes: volumes:
- ../../:/app - ../../:/app
- ../../notebooks/build:/app/notebooks/build
deploy:
mode: replicated
replicas: 1
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
environment:
- NVIDIA_VISIBLE_DEVICES=all
ports: ports:
- "0.0.0.0:8001:8001" - "0.0.0.0:8001:8001"
depends_on: depends_on:
@ -100,6 +112,8 @@ services:
condition: service_healthy condition: service_healthy
volumes: volumes:
fyp_postgres_data: fyp_postgres_data:
fyp_redis_data: fyp_redis_data:

View file

@ -2,9 +2,12 @@ from rest_framework.routers import DefaultRouter
from apps.orgs.viewsets import OrganizationViewSet from apps.orgs.viewsets import OrganizationViewSet
from apps.users.viewsets import UserViewSet from apps.users.viewsets import UserViewSet
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
router = DefaultRouter() router = DefaultRouter()
router.register(r'user', UserViewSet, basename='user') router.register(r'user', UserViewSet, basename='user')
router.register(r'organization', OrganizationViewSet, basename='organization') router.register(r'organization', OrganizationViewSet, basename='organization')
router.register(r'agent', AgentViewSet, basename='agent')
router.register(r'agent-run', AgentRunViewSet, basename='agent-run')
urlpatterns = router.urls urlpatterns = router.urls

View file

@ -5,16 +5,17 @@ from django.core.asgi import get_asgi_application
from channels.auth import AuthMiddlewareStack from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator from channels.security.websocket import AllowedHostsOriginValidator
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
django_asgi_app = get_asgi_application() django_asgi_app = get_asgi_application()
from apps.mlstore.routing import websocket_urlpatterns
application = ProtocolTypeRouter({ application = ProtocolTypeRouter({
"http": django_asgi_app, "http": django_asgi_app,
"websocket": AllowedHostsOriginValidator( "websocket": AllowedHostsOriginValidator(
AuthMiddlewareStack( AuthMiddlewareStack(
URLRouter([]) URLRouter(websocket_urlpatterns)
) )
) )
}) })

View file

@ -24,6 +24,10 @@ PARENT_NAME = Path(__file__).resolve().parent.name
DJANGO_CELERY_BROKER_URL = os.getenv('DJANGO_CELERY_BROKER_URL', 'redis://localhost:6379/0') DJANGO_CELERY_BROKER_URL = os.getenv('DJANGO_CELERY_BROKER_URL', 'redis://localhost:6379/0')
MCP_SERVER_HOST = os.getenv('MCP_SERVER_HOST', 'localhost')
MCP_SERVER_PORT = os.getenv('MCP_SERVER_PORT', '8001')
MCP_AGENT_URL = f"http://{MCP_SERVER_HOST}:{MCP_SERVER_PORT}"
STATIC_URL = os.getenv('DJANGO_STATIC_URL', '/static/') STATIC_URL = os.getenv('DJANGO_STATIC_URL', '/static/')
MEDIA_URL = os.getenv('DJANGO_MEDIA_URL', '/media/') MEDIA_URL = os.getenv('DJANGO_MEDIA_URL', '/media/')
STATIC_ROOT = os.getenv('DJANGO_STATIC_ROOT', BASE_DIR / 'static') STATIC_ROOT = os.getenv('DJANGO_STATIC_ROOT', BASE_DIR / 'static')

View file

@ -2,11 +2,16 @@ import asyncio
import json import json
import os import os
import sys import sys
from datetime import datetime
from pathlib import PureWindowsPath
from typing import Any, Dict, List
from aiohttp import web from aiohttp import web
from mcp.server import Server from mcp.server import Server
from mcp.types import Tool, TextContent from mcp.types import Tool, TextContent
app = Server("minimal-mcp-server") app = Server("mlstore-mcp-server")
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
@app.list_tools() @app.list_tools()
@ -23,26 +28,263 @@ async def list_tools():
"required": ["message"] "required": ["message"]
}, },
) )
,
Tool(
name="fine_tune",
description="Start fine-tuning a base model using training files",
inputSchema={
"type": "object",
"properties": {
"base_model": {"type": "string"},
"training_files": {"type": "array", "items": {"type": "string"}},
"hyperparams": {"type": "object"},
"name": {"type": "string"},
"version": {"type": "string"}
},
"required": ["base_model", "training_files", "name", "version"]
},
),
Tool(
name="load_model",
description="Load a fine-tuned model into memory for inference",
inputSchema={
"type": "object",
"properties": {
"model_path": {"type": "string"}
},
"required": ["model_path"]
},
),
Tool(
name="infer",
description="Run inference with a fine-tuned model",
inputSchema={
"type": "object",
"properties": {
"model_path": {"type": "string"},
"prompt": {"type": "string"},
"options": {"type": "object"}
},
"required": ["model_path", "prompt"]
},
),
] ]
def _now() -> str:
return datetime.utcnow().isoformat() + "Z"
def _model_root() -> str:
return os.getenv("MCP_MODEL_DIR") or os.getenv("DJANGO_MODEL_DIR") or os.path.join(os.getcwd(), "model")
def _safe_dir_name(name: str) -> str:
return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".")
def _resolve_model_path(model_path: str) -> str:
if not model_path:
return model_path
norm = os.path.normpath(model_path)
if os.path.isabs(norm) and os.path.exists(norm):
return norm
candidates = []
# Try relative to current working directory
candidates.append(os.path.normpath(os.path.join(os.getcwd(), norm)))
# Try relative to model root
candidates.append(os.path.normpath(os.path.join(_model_root(), os.path.basename(norm))))
# If it's a Windows-style absolute path, map to container /app by trimming common root
if ":" in model_path or "\\" in model_path:
p = PureWindowsPath(model_path)
parts = [str(x) for x in p.parts]
for anchor in ("notebooks", "model"):
if anchor in parts:
idx = parts.index(anchor)
rel = os.path.join(*parts[idx:])
candidates.append(os.path.normpath(os.path.join(os.getcwd(), rel)))
for cand in candidates:
if os.path.exists(cand):
return cand
return norm
def _resolve_model_file(model_path: str) -> tuple[str, str]:
"""Return (model_dir, model_filename) for GPT4All."""
resolved = _resolve_model_path(model_path)
if os.path.isdir(resolved):
for name in os.listdir(resolved):
if name.lower().endswith(".gguf"):
return resolved, name
return resolved, ""
return os.path.dirname(resolved), os.path.basename(resolved)
async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
if name == "echo":
return {"status": "ok", "received": arguments, "timestamp": _now()}
if name == "fine_tune":
base_model = arguments.get("base_model")
training_files = arguments.get("training_files") or []
hyperparams = arguments.get("hyperparams") or {}
model_name = arguments.get("name") or "model"
version = arguments.get("version") or "v1"
model_root = _model_root()
os.makedirs(model_root, exist_ok=True)
safe_name = _safe_dir_name(model_name)
safe_version = _safe_dir_name(version)
output_dir = os.path.join(model_root, f"{safe_name}-{safe_version}")
os.makedirs(output_dir, exist_ok=True)
metadata = {
"status": "completed",
"base_model": base_model,
"training_files": training_files,
"hyperparams": hyperparams,
"name": model_name,
"version": version,
"model_path": output_dir,
"timestamp": _now(),
}
try:
with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
except Exception:
pass
return metadata
if name == "load_model":
model_path = arguments.get("model_path")
if not model_path:
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
model_path = _resolve_model_path(model_path)
if not os.path.exists(model_path):
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
try:
from gpt4all import GPT4All
model_dir, model_file = _resolve_model_file(model_path)
if not model_file:
return {
"status": "failed",
"error": "model_file_not_found",
"model_path": model_path,
"timestamp": _now(),
}
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='gpu')
LOADED_MODELS[model_path] = {
"loaded_at": _now(),
"model": model,
"model_dir": model_dir,
"model_file": model_file,
}
return {
"status": "completed",
"model_path": model_path,
"loaded": True,
"model_dir": model_dir,
"model_file": model_file,
"timestamp": _now(),
}
except Exception as e:
return {
"status": "failed",
"error": str(e),
"error_type": type(e).__name__,
"model_path": model_path,
"timestamp": _now(),
}
if name == "infer":
model_path = arguments.get("model_path")
prompt = arguments.get("prompt") or ""
options = arguments.get("options") or {}
if not model_path:
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
model_path = _resolve_model_path(model_path)
if not os.path.exists(model_path):
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
try:
if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
from gpt4all import GPT4All
model_dir, model_file = _resolve_model_file(model_path)
if not model_file:
return {
"status": "failed",
"error": "model_file_not_found",
"model_path": model_path,
"timestamp": _now(),
}
model = GPT4All(model_file, model_path=model_dir, allow_download=False)
LOADED_MODELS[model_path] = {
"loaded_at": _now(),
"model": model,
"model_dir": model_dir,
"model_file": model_file,
}
model = LOADED_MODELS[model_path]["model"]
max_tokens = int(options.get("max_tokens", 256))
temp = float(options.get("temperature", options.get("temp", 0.7)))
top_p = float(options.get("top_p", 0.95))
top_k = int(options.get("top_k", 40))
response_text = model.generate(
prompt,
max_tokens=max_tokens,
temp=temp,
top_p=top_p,
top_k=top_k,
)
return {
"status": "completed",
"model_path": model_path,
"response": response_text,
"options": {
"max_tokens": max_tokens,
"temperature": temp,
"top_p": top_p,
"top_k": top_k,
},
"timestamp": _now(),
}
except Exception as e:
return {
"status": "failed",
"error": str(e),
"error_type": type(e).__name__,
"model_path": model_path,
"timestamp": _now(),
}
raise ValueError(f"Unknown tool: {name}")
@app.call_tool() @app.call_tool()
async def call_tool(name: str, arguments: dict): async def call_tool(name: str, arguments: dict):
if name != "echo": result = await _run_tool_http(name, arguments)
raise ValueError(f"Unknown tool: {name}") return [TextContent(type="text", text=json.dumps(result, indent=2))]
return [
TextContent(
type="text",
text=json.dumps(
{
"received": arguments,
"status": "ok",
},
indent=2,
),
)
]
async def handle_execute(request: web.Request) -> web.Response: async def handle_execute(request: web.Request) -> web.Response:
@ -56,13 +298,8 @@ async def handle_execute(request: web.Request) -> web.Response:
{"error": "Missing 'tool' field"}, status=400 {"error": "Missing 'tool' field"}, status=400
) )
result = await call_tool(tool, arguments) result = await _run_tool_http(tool, arguments)
return web.json_response( return web.json_response(result)
{
"tool": tool,
"result": [c.text for c in result],
}
)
except json.JSONDecodeError: except json.JSONDecodeError:
return web.json_response({"error": "Invalid JSON"}, status=400) return web.json_response({"error": "Invalid JSON"}, status=400)

View file

@ -25,6 +25,7 @@ django-jazzmin==3.0.1
django-timezone-field==7.2.1 django-timezone-field==7.2.1
django_celery_results==2.6.0 django_celery_results==2.6.0
djangorestframework==3.16.1 djangorestframework==3.16.1
httpx==0.28.1
hyperlink==21.0.0 hyperlink==21.0.0
idna==3.11 idna==3.11
Incremental==24.11.0 Incremental==24.11.0

View file

@ -4,3 +4,4 @@ pyjwt==2.10.1
python-multipart==0.0.21 python-multipart==0.0.21
sse-starlette==3.2.0 sse-starlette==3.2.0
starlette==0.52.1 starlette==0.52.1
gpt4all==2.8.2

View file

@ -93,6 +93,8 @@ export const API = {
`/api/organization/${orgUuid}/create-invite/?max_uses=${max_uses}`, `/api/organization/${orgUuid}/create-invite/?max_uses=${max_uses}`,
organizationJoin: (token: string) => `/api/organization/join/${token}/`, organizationJoin: (token: string) => `/api/organization/join/${token}/`,
organizationLeave: (orgUuid: string) => `/api/organization/${orgUuid}/leave/`, organizationLeave: (orgUuid: string) => `/api/organization/${orgUuid}/leave/`,
agents: () => '/api/agent/',
agent: (id: string) => `/api/agent/${id}/`,
} }
export const apiClient = new ApiClient() export const apiClient = new ApiClient()

View file

@ -55,6 +55,18 @@ const router = createRouter({
name: 'invite-accept', name: 'invite-accept',
component: () => import('../views/InviteAccept.vue'), component: () => import('../views/InviteAccept.vue'),
}, },
{
path: '/agents',
name: 'agents',
component: () => import('../views/AgentsView.vue'),
meta: { requiresAuth: true, requiresManager: true },
},
{
path: '/agents/:id',
name: 'agent-detail',
component: () => import('../views/AgentDetailView.vue'),
meta: { requiresAuth: true, requiresManager: true },
},
], ],
}) })

139
src/stores/agentStore.ts Normal file
View file

@ -0,0 +1,139 @@
import { defineStore } from 'pinia'
import { ref } from 'vue'
import type { AgentEvent } from '../types/agent'
export const useAgentStore = defineStore('agent', () => {
const isConnected = ref(false)
const executionStatus = ref<'idle' | 'running' | 'completed' | 'failed'>('idle')
const eventLog = ref<AgentEvent[]>([])
const lastExecutionId = ref<string | null>(null)
let socket: WebSocket | null = null
const pushEvent = (evt: { type: string; message?: string; content?: unknown; timestamp?: string }) => {
eventLog.value.unshift({
type: evt.type,
message: evt.message,
content: evt.content,
timestamp: evt.timestamp ? new Date(evt.timestamp) : new Date(),
})
}
const connect = (agentId: string) => {
if (socket) {
socket.close()
socket = null
}
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws'
const wsUrl = `${wsProtocol}://${window.location.host}/ws/mlstore/agents/${agentId}/`
socket = new WebSocket(wsUrl)
socket.onopen = () => {
isConnected.value = true
pushEvent({ type: 'connected', message: 'WebSocket connected' })
}
socket.onmessage = (event) => {
try {
const payload = JSON.parse(event.data)
const type = payload.type
if (type === 'connection') {
pushEvent({ type: 'connection', message: payload.message, content: payload })
return
}
if (type === 'execution_started') {
executionStatus.value = 'running'
lastExecutionId.value = payload.execution_id
pushEvent({ type: 'started', message: payload.message, content: payload })
return
}
if (type === 'mlstore_event') {
const eventType = payload.event_type || 'message'
pushEvent({
type: eventType,
content: payload.content,
timestamp: payload.timestamp,
})
return
}
if (type === 'execution_completed') {
executionStatus.value = 'completed'
pushEvent({ type: 'completed', content: payload.output_data, timestamp: payload.timestamp })
return
}
if (type === 'execution_error') {
executionStatus.value = 'failed'
pushEvent({ type: 'error', message: payload.error_message, content: payload })
return
}
pushEvent({ type: 'message', content: payload })
} catch {
pushEvent({ type: 'error', message: 'Invalid message received', content: event.data })
}
}
socket.onerror = () => {
pushEvent({ type: 'error', message: 'WebSocket error' })
}
socket.onclose = () => {
isConnected.value = false
pushEvent({ type: 'disconnected', message: 'WebSocket disconnected' })
}
}
const disconnect = () => {
if (socket) {
socket.close()
socket = null
}
isConnected.value = false
}
const startAgent = (data: { query?: string; prompt?: string; options?: Record<string, unknown> }) => {
if (!socket || socket.readyState !== WebSocket.OPEN) return
const prompt = data.query ?? data.prompt ?? ''
socket.send(
JSON.stringify({
action: 'infer',
input_data: {
prompt,
options: data.options ?? {},
},
}),
)
}
const stopAgent = (executionId?: string) => {
if (!socket || socket.readyState !== WebSocket.OPEN) return
socket.send(
JSON.stringify({
action: 'stop_agent',
execution_id: executionId ?? lastExecutionId.value,
}),
)
}
const resetLog = () => {
eventLog.value = []
}
return {
isConnected,
executionStatus,
eventLog,
connect,
disconnect,
startAgent,
stopAgent,
resetLog,
lastExecutionId,
}
})

6
src/types/agent.ts Normal file
View file

@ -0,0 +1,6 @@
export type AgentEvent = {
type: string
timestamp: Date
message?: string
content?: unknown
}

View file

@ -0,0 +1,303 @@
<script setup lang="ts">
import { ref, onMounted, onUnmounted, computed } from 'vue'
import { useRoute } from 'vue-router'
import { Card, Typography, Button, List, Space, Spin, Input, message, Tag } from 'ant-design-vue'
import { useAgentStore } from '../stores/agentStore'
import { apiClient, isAxiosError, API } from '../router/api'
const route = useRoute()
const agentStore = useAgentStore()
const agentId = route.params.id as string
const agent = ref<Record<string, unknown>>({
id: agentId,
name: 'Loading...',
description: '',
status: 'idle',
})
const queryInput = ref('')
const isRunning = computed(() => agentStore.executionStatus === 'running')
const isConnected = computed(() => agentStore.isConnected ?? false)
const agentResponse = computed(() => {
const completedEvent = agentStore.eventLog?.find((event) => event.type === 'completed')
if (completedEvent?.content && typeof completedEvent.content === 'object') {
const output = completedEvent.content as Record<string, unknown>
const direct = output.response
if (typeof direct === 'string' && direct.trim()) return direct
const result = output.result
if (result && typeof result === 'object') {
const nested = (result as Record<string, unknown>).response
if (typeof nested === 'string' && nested.trim()) return nested
}
}
return null
})
const statusColor = (status: string) => {
const colors: Record<string, string> = {
idle: 'default',
running: 'processing',
completed: 'success',
failed: 'error',
stopped: 'warning',
}
return colors[status] || 'default'
}
const fetchAgent = async () => {
try {
const response = await apiClient.get<Record<string, unknown>>(API.agent(agentId))
agent.value = response.data
} catch (error) {
console.error('Failed to fetch agent:', error)
if (isAxiosError(error)) {
console.error('Axios error details:', {
status: error.response?.status,
data: error.response?.data,
message: error.message,
})
}
message.error('Failed to load agent details')
}
}
const startAgent = () => {
if (!agentStore.isConnected) {
message.error('WebSocket not connected')
return
}
if (!queryInput.value.trim()) {
message.error('Please enter a query')
return
}
const data = {
query: queryInput.value.trim(),
}
agentStore.startAgent(data)
message.success('Agent execution started')
}
const stopAgent = () => {
agentStore.stopAgent(agentStore.lastExecutionId || undefined)
message.success('Agent stop requested')
}
onMounted(() => {
fetchAgent()
agentStore.connect(agentId)
})
onUnmounted(() => {
agentStore.disconnect()
})
</script>
<template>
<div class="page">
<Card class="panel" :bordered="false">
<div class="header">
<Typography.Title :level="2">{{ agent.name }}</Typography.Title>
<Tag :color="statusColor(String(agentStore.executionStatus || 'idle'))">
{{ (agentStore.executionStatus || 'idle').toString().toUpperCase() }}
</Tag>
</div>
<Typography.Paragraph type="secondary">
{{ agent.description || 'No description available' }}
</Typography.Paragraph>
<div class="connection-status">
<span>WebSocket Status:</span>
<Tag :color="agentStore.isConnected ? 'green' : 'red'">
{{ agentStore.isConnected ? 'CONNECTED' : 'DISCONNECTED' }}
</Tag>
</div>
<Typography.Title :level="4" class="section-title">Execution</Typography.Title>
<div class="execution-controls">
<Space direction="vertical" style="width: 100%">
<div>
<Typography.Text>Query:</Typography.Text>
<Input.TextArea
v-model:value="queryInput"
:disabled="isRunning"
placeholder="Enter your query here..."
:rows="4"
/>
</div>
<Space>
<Button type="primary" :disabled="isRunning || !isConnected" @click="startAgent">
Run Agent
</Button>
<Button danger :disabled="!isRunning" @click="stopAgent">
Stop Agent
</Button>
</Space>
</Space>
</div>
<div v-if="agentResponse" class="response-section">
<Typography.Title :level="4" class="section-title">Final Response</Typography.Title>
<Card class="response-card response-final" :bordered="false">
<div class="response-content">
{{ agentResponse }}
</div>
</Card>
</div>
<Typography.Title :level="4" class="section-title">Execution Log</Typography.Title>
<Spin :spinning="isRunning" tip="Agent running...">
<div class="log-container">
<List
v-if="(agentStore.eventLog?.length ?? 0) > 0"
:data-source="agentStore.eventLog || []"
:bordered="false"
>
<template #renderItem="{ item }">
<List.Item class="log-item">
<div class="log-entry">
<Tag class="log-type">{{ item.type }}</Tag>
<span class="log-time">{{ item.timestamp.toLocaleTimeString() }}</span>
<div v-if="item.message" class="log-message">
{{ item.message }}
</div>
<div v-if="item.content && typeof item.content === 'object'" class="log-content">
<pre>{{ JSON.stringify(item.content, null, 2) }}</pre>
</div>
<div v-else-if="item.content" class="log-content">
{{ item.content }}
</div>
</div>
</List.Item>
</template>
</List>
<Typography.Paragraph v-else type="secondary">
No events yet. Start the agent to see execution logs.
</Typography.Paragraph>
</div>
</Spin>
</Card>
</div>
</template>
<style scoped>
.page {
max-width: 1200px;
margin: 0 auto;
padding: 1rem;
}
.panel {
background: #0f172a;
border: 1px solid #1f2937;
color: #e5e7eb;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 1rem;
}
.section-title {
margin-top: 2rem !important;
margin-bottom: 1rem !important;
}
.connection-status {
display: flex;
align-items: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 0.5rem;
background: #1f2937;
border-radius: 4px;
}
.execution-controls {
background: #1f2937;
padding: 1rem;
border-radius: 4px;
margin: 1rem 0;
}
.log-container {
background: #1f2937;
border-radius: 4px;
max-height: 500px;
overflow-y: auto;
}
.log-item {
border-bottom: 1px solid #374151 !important;
padding: 0.75rem !important;
}
.log-entry {
display: flex;
flex-direction: column;
gap: 0.5rem;
width: 100%;
}
.log-type {
width: fit-content;
}
.log-time {
font-size: 0.75rem;
color: #9ca3af;
}
.log-message {
color: #e5e7eb;
font-size: 0.9rem;
}
.log-content {
background: #111827;
padding: 0.5rem;
border-radius: 3px;
overflow-x: auto;
}
.log-content pre {
margin: 0;
font-size: 0.8rem;
color: #d1d5db;
}
.response-section {
margin-top: 2rem;
}
.response-card {
background: #1f2937;
border: 1px solid #374151;
}
.response-final {
border-color: #6366f1;
box-shadow: 0 0 0 1px rgba(99, 102, 241, 0.35);
}
.response-content {
color: #e5e7eb;
font-size: 1rem;
line-height: 1.6;
white-space: pre-wrap;
word-wrap: break-word;
padding: 0.5rem;
}
</style>

92
src/views/AgentsView.vue Normal file
View file

@ -0,0 +1,92 @@
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { List, Typography, Button, Card, Spin, message } from 'ant-design-vue'
import { apiClient } from '../router/api'
import { API } from '../router/api'
interface Agent {
uuid: string
id: string
name: string
description: string
status: string
}
const agents = ref<Agent[]>([])
const loading = ref(false)
const loadError = ref(false)
const fetchAgents = async () => {
loading.value = true
loadError.value = false
try {
const response = await apiClient.get<Agent[]>(API.agents())
agents.value = Array.isArray(response.data) ? response.data : []
} catch (error) {
console.error('Failed to fetch agents:', error)
message.error('Failed to load agents')
agents.value = []
loadError.value = true
} finally {
loading.value = false
}
}
onMounted(() => {
fetchAgents()
})
</script>
<template>
<div class="page">
<Typography.Title :level="2">Agents</Typography.Title>
<Typography.Paragraph type="secondary">
Manage and inspect the available AI agents.
</Typography.Paragraph>
<Card class="panel" :bordered="false">
<Spin :spinning="loading" tip="Loading agents...">
<div v-if="loadError" class="empty">
<Typography.Paragraph type="danger">
Failed to load agents.
</Typography.Paragraph>
</div>
<div v-else-if="!loading && agents.length === 0" class="empty">
<Typography.Paragraph type="secondary">
No agents found.
</Typography.Paragraph>
</div>
<List v-else :data-source="agents" item-layout="horizontal" :bordered="false">
<template #renderItem="{ item }">
<List.Item class="item">
<List.Item.Meta
:title="item.name"
:description="`${item.description} Status: ${item.status}`"
/>
<RouterLink :to="`/agents/${item.uuid || item.id}`">
<Button type="primary" size="small">Open</Button>
</RouterLink>
</List.Item>
</template>
</List>
</Spin>
</Card>
</div>
</template>
<style scoped>
.page {
max-width: 900px;
margin: 0 auto;
padding: 1rem;
}
.panel {
background: #0f172a;
border: 1px solid #1f2937;
color: #e5e7eb;
}
.item :deep(.ant-list-item-meta-title),
.item :deep(.ant-list-item-meta-description) {
color: #e5e7eb;
}
</style>