Added agent models, viewsets, consumers, gpu and mcp
This commit is contained in:
parent
6039d6b2ac
commit
fcd4862e18
22 changed files with 1463 additions and 26 deletions
193
apps/mlstore/consumers.py
Normal file
193
apps/mlstore/consumers.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
import json
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from channels.db import database_sync_to_async
|
||||
from django.utils import timezone
|
||||
from .models import Agent, AgentRun, AgentEvent
|
||||
from .tasks import start_fine_tune_run_task, infer_run_task
|
||||
|
||||
|
||||
class MLStoreConsumer(AsyncWebsocketConsumer):
|
||||
async def connect(self):
|
||||
self.user = self.scope["user"]
|
||||
self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
|
||||
self.room_group_name = f"mlstore_agent_{self.agent_id}"
|
||||
|
||||
if not self.user.is_authenticated:
|
||||
await self.close()
|
||||
return
|
||||
|
||||
agent = await self.get_agent(self.agent_id)
|
||||
if not agent:
|
||||
await self.close()
|
||||
return
|
||||
|
||||
await self.channel_layer.group_add(self.room_group_name, self.channel_name)
|
||||
await self.accept()
|
||||
await self.send(json.dumps({
|
||||
"type": "connection",
|
||||
"message": "Connected to mlstore agent stream",
|
||||
"agent_id": str(self.agent_id)
|
||||
}))
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
await self.channel_layer.group_discard(self.room_group_name, self.channel_name)
|
||||
|
||||
async def receive(self, text_data):
|
||||
try:
|
||||
data = json.loads(text_data)
|
||||
action = data.get("action")
|
||||
|
||||
if action == "fine_tune":
|
||||
await self.handle_fine_tune(data)
|
||||
elif action == "infer":
|
||||
await self.handle_infer(data)
|
||||
elif action in ("stop_agent", "stop"):
|
||||
await self.handle_stop(data)
|
||||
else:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": f"Unknown action: {action}"
|
||||
}))
|
||||
except json.JSONDecodeError:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Invalid JSON"
|
||||
}))
|
||||
except Exception as e:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}))
|
||||
|
||||
async def handle_fine_tune(self, data):
|
||||
agent = await self.get_agent(self.agent_id)
|
||||
if not agent:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Agent not found"
|
||||
}))
|
||||
return
|
||||
|
||||
input_data = data.get("input_data") or {}
|
||||
execution = await self.create_run(agent, self.user, input_data)
|
||||
|
||||
await self.send(json.dumps({
|
||||
"type": "execution_started",
|
||||
"execution_id": str(execution.uuid),
|
||||
"agent_id": str(agent.uuid),
|
||||
"message": f"Fine-tune run {execution.uuid} queued"
|
||||
}))
|
||||
|
||||
start_fine_tune_run_task.delay(str(execution.uuid))
|
||||
|
||||
async def handle_infer(self, data):
|
||||
agent = await self.get_agent(self.agent_id)
|
||||
if not agent:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Agent not found"
|
||||
}))
|
||||
return
|
||||
|
||||
input_data = data.get("input_data") or {}
|
||||
execution = await self.create_run(agent, self.user, input_data)
|
||||
|
||||
await self.send(json.dumps({
|
||||
"type": "execution_started",
|
||||
"execution_id": str(execution.uuid),
|
||||
"agent_id": str(agent.uuid),
|
||||
"message": f"Inference run {execution.uuid} queued"
|
||||
}))
|
||||
|
||||
infer_run_task.delay(str(execution.uuid))
|
||||
|
||||
async def handle_stop(self, data):
|
||||
execution_id = data.get("execution_id")
|
||||
if not execution_id:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": "execution_id required to stop"
|
||||
}))
|
||||
return
|
||||
|
||||
execution = await self.get_execution(execution_id)
|
||||
if not execution:
|
||||
await self.send(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Execution not found"
|
||||
}))
|
||||
return
|
||||
|
||||
await self.update_execution_status(execution, "failed")
|
||||
await self.send(json.dumps({
|
||||
"type": "execution_error",
|
||||
"execution_id": str(execution.uuid),
|
||||
"error_message": "Execution stopped by user"
|
||||
}))
|
||||
|
||||
async def mlstore_event(self, event):
|
||||
await self.send(json.dumps({
|
||||
"type": "mlstore_event",
|
||||
"event_type": event["event_type"],
|
||||
"content": event["content"],
|
||||
"timestamp": event["timestamp"]
|
||||
}))
|
||||
|
||||
async def mlstore_completed(self, event):
|
||||
await self.send(json.dumps({
|
||||
"type": "execution_completed",
|
||||
"execution_id": event["execution_id"],
|
||||
"output_data": event["output_data"],
|
||||
"message": "Execution completed"
|
||||
}))
|
||||
|
||||
async def mlstore_error(self, event):
|
||||
await self.send(json.dumps({
|
||||
"type": "execution_error",
|
||||
"execution_id": event["execution_id"],
|
||||
"error_message": event["error_message"]
|
||||
}))
|
||||
|
||||
@database_sync_to_async
|
||||
def get_agent(self, agent_id):
|
||||
try:
|
||||
return Agent.objects.get(uuid=agent_id)
|
||||
except Agent.DoesNotExist:
|
||||
return None
|
||||
|
||||
@database_sync_to_async
|
||||
def create_run(self, agent, user, input_data):
|
||||
return AgentRun.objects.create(
|
||||
agent=agent,
|
||||
user=user,
|
||||
input_data=input_data,
|
||||
)
|
||||
|
||||
@database_sync_to_async
|
||||
def get_execution(self, execution_id):
|
||||
try:
|
||||
return AgentRun.objects.get(uuid=execution_id)
|
||||
except AgentRun.DoesNotExist:
|
||||
return None
|
||||
|
||||
@database_sync_to_async
|
||||
def update_execution_status(self, execution, status):
|
||||
execution.status = status
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
try:
|
||||
agent = execution.agent
|
||||
agent.status = status
|
||||
agent.completed_at = timezone.now()
|
||||
agent.save()
|
||||
except Exception:
|
||||
pass
|
||||
return execution
|
||||
|
||||
@database_sync_to_async
|
||||
def create_event(self, execution, event_type, content):
|
||||
return AgentEvent.objects.create(
|
||||
execution=execution,
|
||||
event_type=event_type,
|
||||
content=content,
|
||||
)
|
||||
|
|
@ -3,7 +3,6 @@ import uuid
|
|||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
|
|
|||
15
apps/mlstore/migrations/0002_agentmodel_path.py
Normal file
15
apps/mlstore/migrations/0002_agentmodel_path.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from django.db import migrations, models
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('mlstore', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='agentmodel',
|
||||
name='path',
|
||||
field=models.CharField(blank=True, default='', max_length=1024),
|
||||
),
|
||||
]
|
||||
|
|
@ -9,6 +9,7 @@ class AgentModel(Model):
|
|||
uuid = UUIDField(default = uuid4, unique = True, editable = False)
|
||||
name = CharField(max_length = 255)
|
||||
version = CharField(max_length = 50)
|
||||
path = CharField(max_length=1024, blank=True, default='')
|
||||
|
||||
class Meta:
|
||||
verbose_name = 'Model'
|
||||
|
|
|
|||
6
apps/mlstore/routing.py
Normal file
6
apps/mlstore/routing.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from django.urls import path
|
||||
from . import consumers
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path("ws/mlstore/agents/<str:agent_id>/", consumers.MLStoreConsumer.as_asgi()),
|
||||
]
|
||||
51
apps/mlstore/serializers.py
Normal file
51
apps/mlstore/serializers.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from rest_framework.serializers import ModelSerializer
|
||||
from .models import AgentModel, Agent, AgentRun, AgentEvent
|
||||
|
||||
|
||||
class AgentModelSerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = AgentModel
|
||||
fields = ['id', 'uuid', 'name', 'version', 'path']
|
||||
read_only_fields = ['id', 'uuid']
|
||||
|
||||
|
||||
class AgentSerializer(ModelSerializer):
|
||||
model = AgentModelSerializer(read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = Agent
|
||||
fields = [
|
||||
'id',
|
||||
'uuid',
|
||||
'model',
|
||||
'status',
|
||||
'description',
|
||||
'started_at',
|
||||
'completed_at',
|
||||
]
|
||||
read_only_fields = ['id', 'uuid', 'started_at', 'completed_at']
|
||||
|
||||
|
||||
class AgentRunSerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = AgentRun
|
||||
fields = [
|
||||
'id',
|
||||
'uuid',
|
||||
'agent',
|
||||
'user',
|
||||
'status',
|
||||
'input_data',
|
||||
'output_data',
|
||||
'error_message',
|
||||
'started_at',
|
||||
'completed_at',
|
||||
]
|
||||
read_only_fields = ['id', 'uuid', 'started_at', 'completed_at']
|
||||
|
||||
|
||||
class AgentEventSerializer(ModelSerializer):
|
||||
class Meta:
|
||||
model = AgentEvent
|
||||
fields = ['id', 'uuid', 'execution', 'event_type', 'content', 'timestamp']
|
||||
read_only_fields = ['id', 'uuid', 'timestamp']
|
||||
61
apps/mlstore/services.py
Normal file
61
apps/mlstore/services.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
from django.conf import settings
|
||||
from mcp_agent.mcp_client import MCPClient
|
||||
from .models import AgentModel
|
||||
|
||||
|
||||
async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Internal async helper to call the MCP HTTP bridge via MCPClient."""
|
||||
server_url = getattr(settings, "MCP_AGENT_URL")
|
||||
client = MCPClient(server_url)
|
||||
try:
|
||||
resp = await client.send(tool, arguments)
|
||||
return resp
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
def fine_tune_model(
|
||||
base_model: str,
|
||||
training_files: List[str],
|
||||
hyperparams: Dict[str, Any],
|
||||
name: str,
|
||||
version: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Synchronously request a fine-tune run on the MCP server.
|
||||
|
||||
Expects the MCP tool `fine_tune` to accept: {base_model, training_files, hyperparams, name, version}
|
||||
and to return a JSON-like dict containing at least `status` and on success `model_path` and `version`.
|
||||
"""
|
||||
return asyncio.run(_call_mcp("fine_tune", {
|
||||
"base_model": base_model,
|
||||
"training_files": training_files,
|
||||
"hyperparams": hyperparams,
|
||||
"name": name,
|
||||
"version": version,
|
||||
}))
|
||||
|
||||
|
||||
def load_model_for_inference(model_path: str) -> Dict[str, Any]:
|
||||
"""Tell the MCP server to load a model into memory/serving for inference.
|
||||
|
||||
Expects the MCP tool `load_model` with {model_path} returning status info.
|
||||
"""
|
||||
return asyncio.run(_call_mcp("load_model", {"model_path": model_path}))
|
||||
|
||||
|
||||
def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Request inference from the MCP server using a previously fine-tuned model.
|
||||
|
||||
Calls the MCP tool `infer` with {model_path, prompt, options}.
|
||||
"""
|
||||
return asyncio.run(_call_mcp("infer", {"model_path": model_path, "prompt": prompt, "options": options or {}}))
|
||||
|
||||
|
||||
def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel:
|
||||
"""Convenience DB helper: create and return an AgentModel record.
|
||||
|
||||
NOTE: migrations are required after the model field change prior to using this in production.
|
||||
"""
|
||||
return AgentModel.objects.create(name=name, version=version, path=model_path)
|
||||
267
apps/mlstore/tasks.py
Normal file
267
apps/mlstore/tasks.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
from celery import shared_task
|
||||
from django.utils import timezone
|
||||
from channels.layers import get_channel_layer
|
||||
from asgiref.sync import async_to_sync
|
||||
from . import services
|
||||
from .models import AgentModel, Agent, AgentRun, AgentEvent
|
||||
import traceback
|
||||
|
||||
|
||||
@shared_task
|
||||
def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str):
|
||||
"""Start a fine-tune via MCP, and register the resulting model on success.
|
||||
|
||||
This task calls `services.fine_tune_model`, expects a dict result with `status` and on success
|
||||
`model_path` and optionally `version`.
|
||||
"""
|
||||
try:
|
||||
result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
|
||||
|
||||
if isinstance(result, dict) and result.get("status") == "completed":
|
||||
model_path = result.get("model_path") or result.get("path") or ""
|
||||
model_version = result.get("version") or version
|
||||
m = AgentModel.objects.create(name=name, version=model_version, path=model_path)
|
||||
return {"status": "ok", "model_id": m.id, "model_uuid": str(m.uuid), "model_path": model_path, "result": result}
|
||||
|
||||
return {"status": "failed", "result": result}
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
@shared_task
|
||||
def infer_with_model_task(model_id: int, prompt: str, options: dict = None):
|
||||
"""Run inference by requesting the MCP server to use the stored model.
|
||||
|
||||
Looks up the `AgentModel` by `model_id`, calls `services.infer_with_model`, and returns the response.
|
||||
"""
|
||||
try:
|
||||
model = AgentModel.objects.get(id=model_id)
|
||||
except AgentModel.DoesNotExist:
|
||||
return {"status": "error", "error": "model_not_found", "model_id": model_id}
|
||||
|
||||
try:
|
||||
services.load_model_for_inference(model.path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
out = services.infer_with_model(model.path, prompt, options or {})
|
||||
return {"status": "completed", "model_id": model_id, "response": out}
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return {"status": "failed", "error": str(e)}
|
||||
|
||||
|
||||
def _send_group_event(room_group_name: str, event_type: str, content: dict):
|
||||
channel_layer = get_channel_layer()
|
||||
async_to_sync(channel_layer.group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_event",
|
||||
"event_type": event_type,
|
||||
"content": content,
|
||||
"timestamp": timezone.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _persist_event(execution: AgentRun, event_type: str, content: dict):
|
||||
AgentEvent.objects.create(
|
||||
execution=execution,
|
||||
event_type=event_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _update_agent_status(agent: Agent, status: str):
|
||||
agent.status = status
|
||||
if status == "running":
|
||||
agent.started_at = timezone.now()
|
||||
elif status in ("completed", "failed"):
|
||||
agent.completed_at = timezone.now()
|
||||
agent.save()
|
||||
|
||||
|
||||
@shared_task
|
||||
def start_fine_tune_run_task(execution_id: str):
|
||||
try:
|
||||
execution = AgentRun.objects.get(uuid=execution_id)
|
||||
except AgentRun.DoesNotExist:
|
||||
return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}
|
||||
|
||||
agent = execution.agent
|
||||
room_group_name = f"mlstore_agent_{agent.uuid}"
|
||||
|
||||
execution.status = "running"
|
||||
execution.started_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "running")
|
||||
|
||||
input_data = execution.input_data or {}
|
||||
base_model = input_data.get("base_model") or agent.model.name
|
||||
training_files = input_data.get("training_files") or []
|
||||
hyperparams = input_data.get("hyperparams") or {}
|
||||
name = input_data.get("name") or f"{agent.model.name}-ft"
|
||||
version = input_data.get("version") or "v1"
|
||||
|
||||
_send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
|
||||
_persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
|
||||
|
||||
try:
|
||||
result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
|
||||
if isinstance(result, dict) and result.get("status") == "completed":
|
||||
model_path = result.get("model_path") or result.get("path") or ""
|
||||
model_version = result.get("version") or version
|
||||
new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
|
||||
agent.model = new_model
|
||||
agent.save()
|
||||
|
||||
execution.status = "completed"
|
||||
execution.output_data = {
|
||||
"result": result,
|
||||
"model_id": new_model.id,
|
||||
"model_uuid": str(new_model.uuid),
|
||||
}
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "completed")
|
||||
|
||||
_send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path})
|
||||
_persist_event(execution, "completed", {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path})
|
||||
|
||||
async_to_sync(get_channel_layer().group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_completed",
|
||||
"execution_id": str(execution.uuid),
|
||||
"output_data": execution.output_data,
|
||||
},
|
||||
)
|
||||
|
||||
return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}
|
||||
|
||||
execution.status = "failed"
|
||||
execution.error_message = str(result)
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "failed")
|
||||
|
||||
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": result})
|
||||
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": result})
|
||||
|
||||
async_to_sync(get_channel_layer().group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_error",
|
||||
"execution_id": str(execution.uuid),
|
||||
"error_message": str(result),
|
||||
},
|
||||
)
|
||||
return {"status": "failed", "execution_id": execution_id, "result": result}
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
execution.status = "failed"
|
||||
execution.error_message = str(e)
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "failed")
|
||||
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": str(e)})
|
||||
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": str(e)})
|
||||
async_to_sync(get_channel_layer().group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_error",
|
||||
"execution_id": str(execution.uuid),
|
||||
"error_message": str(e),
|
||||
},
|
||||
)
|
||||
return {"status": "error", "execution_id": execution_id, "error": str(e)}
|
||||
|
||||
|
||||
@shared_task
|
||||
def infer_run_task(execution_id: str):
|
||||
try:
|
||||
execution = AgentRun.objects.get(uuid=execution_id)
|
||||
except AgentRun.DoesNotExist:
|
||||
return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}
|
||||
|
||||
agent = execution.agent
|
||||
room_group_name = f"mlstore_agent_{agent.uuid}"
|
||||
|
||||
execution.status = "running"
|
||||
execution.started_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "running")
|
||||
|
||||
input_data = execution.input_data or {}
|
||||
prompt = input_data.get("prompt") or input_data.get("query") or ""
|
||||
options = input_data.get("options") or {}
|
||||
|
||||
if not prompt:
|
||||
execution.status = "failed"
|
||||
execution.error_message = "prompt_required"
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "failed")
|
||||
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": "prompt_required"})
|
||||
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": "prompt_required"})
|
||||
async_to_sync(get_channel_layer().group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_error",
|
||||
"execution_id": str(execution.uuid),
|
||||
"error_message": "prompt_required",
|
||||
},
|
||||
)
|
||||
return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}
|
||||
|
||||
_send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "infer"})
|
||||
_persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "infer"})
|
||||
|
||||
try:
|
||||
try:
|
||||
services.load_model_for_inference(agent.model.path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result = services.infer_with_model(agent.model.path, prompt, options)
|
||||
|
||||
execution.status = "completed"
|
||||
execution.output_data = {"result": result}
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "completed")
|
||||
|
||||
_send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "result": result})
|
||||
_persist_event(execution, "completed", {"execution_id": str(execution.uuid), "result": result})
|
||||
async_to_sync(get_channel_layer().group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_completed",
|
||||
"execution_id": str(execution.uuid),
|
||||
"output_data": execution.output_data,
|
||||
},
|
||||
)
|
||||
return {"status": "completed", "execution_id": execution_id}
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
execution.status = "failed"
|
||||
execution.error_message = str(e)
|
||||
execution.completed_at = timezone.now()
|
||||
execution.save()
|
||||
_update_agent_status(agent, "failed")
|
||||
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": str(e)})
|
||||
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": str(e)})
|
||||
async_to_sync(get_channel_layer().group_send)(
|
||||
room_group_name,
|
||||
{
|
||||
"type": "mlstore_error",
|
||||
"execution_id": str(execution.uuid),
|
||||
"error_message": str(e),
|
||||
},
|
||||
)
|
||||
return {"status": "failed", "execution_id": execution_id, "error": str(e)}
|
||||
29
apps/mlstore/viewsets.py
Normal file
29
apps/mlstore/viewsets.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from rest_framework.viewsets import ModelViewSet
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from .models import Agent, AgentRun, AgentEvent
|
||||
from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
|
||||
class AgentViewSet(ModelViewSet):
|
||||
queryset = Agent.objects.all()
|
||||
serializer_class = AgentSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
lookup_field = 'uuid'
|
||||
|
||||
|
||||
class AgentRunViewSet(ModelViewSet):
|
||||
queryset = AgentRun.objects.all()
|
||||
serializer_class = AgentRunSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
lookup_field = 'uuid'
|
||||
|
||||
def get_queryset(self):
|
||||
return AgentRun.objects.filter(user=self.request.user)
|
||||
|
||||
@action(detail=True, methods=['get'], url_path='events')
|
||||
def events(self, request, uuid=None):
|
||||
run = self.get_object()
|
||||
events = AgentEvent.objects.filter(execution=run).order_by('timestamp')
|
||||
serializer = AgentEventSerializer(events, many=True)
|
||||
return Response(serializer.data)
|
||||
|
|
@ -91,6 +91,18 @@ services:
|
|||
- ../../.env
|
||||
volumes:
|
||||
- ../../:/app
|
||||
- ../../notebooks/build:/app/notebooks/build
|
||||
deploy:
|
||||
mode: replicated
|
||||
replicas: 1
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
environment:
|
||||
- NVIDIA_VISIBLE_DEVICES=all
|
||||
ports:
|
||||
- "0.0.0.0:8001:8001"
|
||||
depends_on:
|
||||
|
|
@ -100,6 +112,8 @@ services:
|
|||
condition: service_healthy
|
||||
|
||||
|
||||
|
||||
|
||||
volumes:
|
||||
fyp_postgres_data:
|
||||
fyp_redis_data:
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ from rest_framework.routers import DefaultRouter
|
|||
|
||||
from apps.orgs.viewsets import OrganizationViewSet
|
||||
from apps.users.viewsets import UserViewSet
|
||||
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r'user', UserViewSet, basename='user')
|
||||
router.register(r'organization', OrganizationViewSet, basename='organization')
|
||||
router.register(r'agent', AgentViewSet, basename='agent')
|
||||
router.register(r'agent-run', AgentRunViewSet, basename='agent-run')
|
||||
|
||||
urlpatterns = router.urls
|
||||
|
|
|
|||
|
|
@ -5,16 +5,17 @@ from django.core.asgi import get_asgi_application
|
|||
from channels.auth import AuthMiddlewareStack
|
||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||
from channels.security.websocket import AllowedHostsOriginValidator
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
||||
|
||||
django_asgi_app = get_asgi_application()
|
||||
|
||||
from apps.mlstore.routing import websocket_urlpatterns
|
||||
|
||||
application = ProtocolTypeRouter({
|
||||
"http": django_asgi_app,
|
||||
"websocket": AllowedHostsOriginValidator(
|
||||
AuthMiddlewareStack(
|
||||
URLRouter([])
|
||||
URLRouter(websocket_urlpatterns)
|
||||
)
|
||||
)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ PARENT_NAME = Path(__file__).resolve().parent.name
|
|||
|
||||
DJANGO_CELERY_BROKER_URL = os.getenv('DJANGO_CELERY_BROKER_URL', 'redis://localhost:6379/0')
|
||||
|
||||
MCP_SERVER_HOST = os.getenv('MCP_SERVER_HOST', 'localhost')
|
||||
MCP_SERVER_PORT = os.getenv('MCP_SERVER_PORT', '8001')
|
||||
MCP_AGENT_URL = f"http://{MCP_SERVER_HOST}:{MCP_SERVER_PORT}"
|
||||
|
||||
STATIC_URL = os.getenv('DJANGO_STATIC_URL', '/static/')
|
||||
MEDIA_URL = os.getenv('DJANGO_MEDIA_URL', '/media/')
|
||||
STATIC_ROOT = os.getenv('DJANGO_STATIC_ROOT', BASE_DIR / 'static')
|
||||
|
|
|
|||
|
|
@ -2,11 +2,16 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import PureWindowsPath
|
||||
from typing import Any, Dict, List
|
||||
from aiohttp import web
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
|
||||
app = Server("minimal-mcp-server")
|
||||
app = Server("mlstore-mcp-server")
|
||||
|
||||
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
@app.list_tools()
|
||||
|
|
@ -23,26 +28,263 @@ async def list_tools():
|
|||
"required": ["message"]
|
||||
},
|
||||
)
|
||||
,
|
||||
Tool(
|
||||
name="fine_tune",
|
||||
description="Start fine-tuning a base model using training files",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"base_model": {"type": "string"},
|
||||
"training_files": {"type": "array", "items": {"type": "string"}},
|
||||
"hyperparams": {"type": "object"},
|
||||
"name": {"type": "string"},
|
||||
"version": {"type": "string"}
|
||||
},
|
||||
"required": ["base_model", "training_files", "name", "version"]
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="load_model",
|
||||
description="Load a fine-tuned model into memory for inference",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model_path": {"type": "string"}
|
||||
},
|
||||
"required": ["model_path"]
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="infer",
|
||||
description="Run inference with a fine-tuned model",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model_path": {"type": "string"},
|
||||
"prompt": {"type": "string"},
|
||||
"options": {"type": "object"}
|
||||
},
|
||||
"required": ["model_path", "prompt"]
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _now() -> str:
|
||||
return datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
|
||||
def _model_root() -> str:
|
||||
return os.getenv("MCP_MODEL_DIR") or os.getenv("DJANGO_MODEL_DIR") or os.path.join(os.getcwd(), "model")
|
||||
|
||||
|
||||
def _safe_dir_name(name: str) -> str:
|
||||
return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".")
|
||||
|
||||
|
||||
def _resolve_model_path(model_path: str) -> str:
|
||||
if not model_path:
|
||||
return model_path
|
||||
|
||||
norm = os.path.normpath(model_path)
|
||||
if os.path.isabs(norm) and os.path.exists(norm):
|
||||
return norm
|
||||
|
||||
candidates = []
|
||||
|
||||
# Try relative to current working directory
|
||||
candidates.append(os.path.normpath(os.path.join(os.getcwd(), norm)))
|
||||
|
||||
# Try relative to model root
|
||||
candidates.append(os.path.normpath(os.path.join(_model_root(), os.path.basename(norm))))
|
||||
|
||||
# If it's a Windows-style absolute path, map to container /app by trimming common root
|
||||
if ":" in model_path or "\\" in model_path:
|
||||
p = PureWindowsPath(model_path)
|
||||
parts = [str(x) for x in p.parts]
|
||||
for anchor in ("notebooks", "model"):
|
||||
if anchor in parts:
|
||||
idx = parts.index(anchor)
|
||||
rel = os.path.join(*parts[idx:])
|
||||
candidates.append(os.path.normpath(os.path.join(os.getcwd(), rel)))
|
||||
|
||||
for cand in candidates:
|
||||
if os.path.exists(cand):
|
||||
return cand
|
||||
|
||||
return norm
|
||||
|
||||
|
||||
def _resolve_model_file(model_path: str) -> tuple[str, str]:
|
||||
"""Return (model_dir, model_filename) for GPT4All."""
|
||||
resolved = _resolve_model_path(model_path)
|
||||
if os.path.isdir(resolved):
|
||||
for name in os.listdir(resolved):
|
||||
if name.lower().endswith(".gguf"):
|
||||
return resolved, name
|
||||
return resolved, ""
|
||||
return os.path.dirname(resolved), os.path.basename(resolved)
|
||||
|
||||
|
||||
async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
||||
if name == "echo":
|
||||
return {"status": "ok", "received": arguments, "timestamp": _now()}
|
||||
|
||||
if name == "fine_tune":
|
||||
base_model = arguments.get("base_model")
|
||||
training_files = arguments.get("training_files") or []
|
||||
hyperparams = arguments.get("hyperparams") or {}
|
||||
model_name = arguments.get("name") or "model"
|
||||
version = arguments.get("version") or "v1"
|
||||
|
||||
model_root = _model_root()
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
|
||||
safe_name = _safe_dir_name(model_name)
|
||||
safe_version = _safe_dir_name(version)
|
||||
output_dir = os.path.join(model_root, f"{safe_name}-{safe_version}")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
metadata = {
|
||||
"status": "completed",
|
||||
"base_model": base_model,
|
||||
"training_files": training_files,
|
||||
"hyperparams": hyperparams,
|
||||
"name": model_name,
|
||||
"version": version,
|
||||
"model_path": output_dir,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
try:
|
||||
with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return metadata
|
||||
|
||||
if name == "load_model":
|
||||
model_path = arguments.get("model_path")
|
||||
if not model_path:
|
||||
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
||||
|
||||
model_path = _resolve_model_path(model_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
|
||||
|
||||
try:
|
||||
from gpt4all import GPT4All
|
||||
|
||||
model_dir, model_file = _resolve_model_file(model_path)
|
||||
if not model_file:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "model_file_not_found",
|
||||
"model_path": model_path,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='gpu')
|
||||
LOADED_MODELS[model_path] = {
|
||||
"loaded_at": _now(),
|
||||
"model": model,
|
||||
"model_dir": model_dir,
|
||||
"model_file": model_file,
|
||||
}
|
||||
return {
|
||||
"status": "completed",
|
||||
"model_path": model_path,
|
||||
"loaded": True,
|
||||
"model_dir": model_dir,
|
||||
"model_file": model_file,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"model_path": model_path,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
if name == "infer":
|
||||
model_path = arguments.get("model_path")
|
||||
prompt = arguments.get("prompt") or ""
|
||||
options = arguments.get("options") or {}
|
||||
|
||||
if not model_path:
|
||||
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
||||
|
||||
model_path = _resolve_model_path(model_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
|
||||
|
||||
try:
|
||||
if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
|
||||
from gpt4all import GPT4All
|
||||
|
||||
model_dir, model_file = _resolve_model_file(model_path)
|
||||
if not model_file:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "model_file_not_found",
|
||||
"model_path": model_path,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False)
|
||||
LOADED_MODELS[model_path] = {
|
||||
"loaded_at": _now(),
|
||||
"model": model,
|
||||
"model_dir": model_dir,
|
||||
"model_file": model_file,
|
||||
}
|
||||
|
||||
model = LOADED_MODELS[model_path]["model"]
|
||||
max_tokens = int(options.get("max_tokens", 256))
|
||||
temp = float(options.get("temperature", options.get("temp", 0.7)))
|
||||
top_p = float(options.get("top_p", 0.95))
|
||||
top_k = int(options.get("top_k", 40))
|
||||
|
||||
response_text = model.generate(
|
||||
prompt,
|
||||
max_tokens=max_tokens,
|
||||
temp=temp,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"model_path": model_path,
|
||||
"response": response_text,
|
||||
"options": {
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temp,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
},
|
||||
"timestamp": _now(),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"model_path": model_path,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
|
||||
@app.call_tool()
|
||||
async def call_tool(name: str, arguments: dict):
|
||||
if name != "echo":
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
text=json.dumps(
|
||||
{
|
||||
"received": arguments,
|
||||
"status": "ok",
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
)
|
||||
]
|
||||
result = await _run_tool_http(name, arguments)
|
||||
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||
|
||||
|
||||
async def handle_execute(request: web.Request) -> web.Response:
|
||||
|
|
@ -56,13 +298,8 @@ async def handle_execute(request: web.Request) -> web.Response:
|
|||
{"error": "Missing 'tool' field"}, status=400
|
||||
)
|
||||
|
||||
result = await call_tool(tool, arguments)
|
||||
return web.json_response(
|
||||
{
|
||||
"tool": tool,
|
||||
"result": [c.text for c in result],
|
||||
}
|
||||
)
|
||||
result = await _run_tool_http(tool, arguments)
|
||||
return web.json_response(result)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON"}, status=400)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ django-jazzmin==3.0.1
|
|||
django-timezone-field==7.2.1
|
||||
django_celery_results==2.6.0
|
||||
djangorestframework==3.16.1
|
||||
httpx==0.28.1
|
||||
hyperlink==21.0.0
|
||||
idna==3.11
|
||||
Incremental==24.11.0
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ pyjwt==2.10.1
|
|||
python-multipart==0.0.21
|
||||
sse-starlette==3.2.0
|
||||
starlette==0.52.1
|
||||
gpt4all==2.8.2
|
||||
|
|
@ -93,6 +93,8 @@ export const API = {
|
|||
`/api/organization/${orgUuid}/create-invite/?max_uses=${max_uses}`,
|
||||
organizationJoin: (token: string) => `/api/organization/join/${token}/`,
|
||||
organizationLeave: (orgUuid: string) => `/api/organization/${orgUuid}/leave/`,
|
||||
agents: () => '/api/agent/',
|
||||
agent: (id: string) => `/api/agent/${id}/`,
|
||||
}
|
||||
|
||||
export const apiClient = new ApiClient()
|
||||
|
|
|
|||
|
|
@ -55,6 +55,18 @@ const router = createRouter({
|
|||
name: 'invite-accept',
|
||||
component: () => import('../views/InviteAccept.vue'),
|
||||
},
|
||||
{
|
||||
path: '/agents',
|
||||
name: 'agents',
|
||||
component: () => import('../views/AgentsView.vue'),
|
||||
meta: { requiresAuth: true, requiresManager: true },
|
||||
},
|
||||
{
|
||||
path: '/agents/:id',
|
||||
name: 'agent-detail',
|
||||
component: () => import('../views/AgentDetailView.vue'),
|
||||
meta: { requiresAuth: true, requiresManager: true },
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
|
|
|
|||
139
src/stores/agentStore.ts
Normal file
139
src/stores/agentStore.ts
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
import { defineStore } from 'pinia'
|
||||
import { ref } from 'vue'
|
||||
import type { AgentEvent } from '../types/agent'
|
||||
|
||||
export const useAgentStore = defineStore('agent', () => {
|
||||
const isConnected = ref(false)
|
||||
const executionStatus = ref<'idle' | 'running' | 'completed' | 'failed'>('idle')
|
||||
const eventLog = ref<AgentEvent[]>([])
|
||||
const lastExecutionId = ref<string | null>(null)
|
||||
|
||||
let socket: WebSocket | null = null
|
||||
|
||||
const pushEvent = (evt: { type: string; message?: string; content?: unknown; timestamp?: string }) => {
|
||||
eventLog.value.unshift({
|
||||
type: evt.type,
|
||||
message: evt.message,
|
||||
content: evt.content,
|
||||
timestamp: evt.timestamp ? new Date(evt.timestamp) : new Date(),
|
||||
})
|
||||
}
|
||||
|
||||
const connect = (agentId: string) => {
|
||||
if (socket) {
|
||||
socket.close()
|
||||
socket = null
|
||||
}
|
||||
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws'
|
||||
const wsUrl = `${wsProtocol}://${window.location.host}/ws/mlstore/agents/${agentId}/`
|
||||
socket = new WebSocket(wsUrl)
|
||||
|
||||
socket.onopen = () => {
|
||||
isConnected.value = true
|
||||
pushEvent({ type: 'connected', message: 'WebSocket connected' })
|
||||
}
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
try {
|
||||
const payload = JSON.parse(event.data)
|
||||
const type = payload.type
|
||||
|
||||
if (type === 'connection') {
|
||||
pushEvent({ type: 'connection', message: payload.message, content: payload })
|
||||
return
|
||||
}
|
||||
|
||||
if (type === 'execution_started') {
|
||||
executionStatus.value = 'running'
|
||||
lastExecutionId.value = payload.execution_id
|
||||
pushEvent({ type: 'started', message: payload.message, content: payload })
|
||||
return
|
||||
}
|
||||
|
||||
if (type === 'mlstore_event') {
|
||||
const eventType = payload.event_type || 'message'
|
||||
pushEvent({
|
||||
type: eventType,
|
||||
content: payload.content,
|
||||
timestamp: payload.timestamp,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if (type === 'execution_completed') {
|
||||
executionStatus.value = 'completed'
|
||||
pushEvent({ type: 'completed', content: payload.output_data, timestamp: payload.timestamp })
|
||||
return
|
||||
}
|
||||
|
||||
if (type === 'execution_error') {
|
||||
executionStatus.value = 'failed'
|
||||
pushEvent({ type: 'error', message: payload.error_message, content: payload })
|
||||
return
|
||||
}
|
||||
|
||||
pushEvent({ type: 'message', content: payload })
|
||||
} catch {
|
||||
pushEvent({ type: 'error', message: 'Invalid message received', content: event.data })
|
||||
}
|
||||
}
|
||||
|
||||
socket.onerror = () => {
|
||||
pushEvent({ type: 'error', message: 'WebSocket error' })
|
||||
}
|
||||
|
||||
socket.onclose = () => {
|
||||
isConnected.value = false
|
||||
pushEvent({ type: 'disconnected', message: 'WebSocket disconnected' })
|
||||
}
|
||||
}
|
||||
|
||||
const disconnect = () => {
|
||||
if (socket) {
|
||||
socket.close()
|
||||
socket = null
|
||||
}
|
||||
isConnected.value = false
|
||||
}
|
||||
|
||||
const startAgent = (data: { query?: string; prompt?: string; options?: Record<string, unknown> }) => {
|
||||
if (!socket || socket.readyState !== WebSocket.OPEN) return
|
||||
const prompt = data.query ?? data.prompt ?? ''
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
action: 'infer',
|
||||
input_data: {
|
||||
prompt,
|
||||
options: data.options ?? {},
|
||||
},
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
const stopAgent = (executionId?: string) => {
|
||||
if (!socket || socket.readyState !== WebSocket.OPEN) return
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
action: 'stop_agent',
|
||||
execution_id: executionId ?? lastExecutionId.value,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
const resetLog = () => {
|
||||
eventLog.value = []
|
||||
}
|
||||
|
||||
return {
|
||||
isConnected,
|
||||
executionStatus,
|
||||
eventLog,
|
||||
connect,
|
||||
disconnect,
|
||||
startAgent,
|
||||
stopAgent,
|
||||
resetLog,
|
||||
lastExecutionId,
|
||||
}
|
||||
})
|
||||
6
src/types/agent.ts
Normal file
6
src/types/agent.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
export type AgentEvent = {
|
||||
type: string
|
||||
timestamp: Date
|
||||
message?: string
|
||||
content?: unknown
|
||||
}
|
||||
303
src/views/AgentDetailView.vue
Normal file
303
src/views/AgentDetailView.vue
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { Card, Typography, Button, List, Space, Spin, Input, message, Tag } from 'ant-design-vue'
|
||||
import { useAgentStore } from '../stores/agentStore'
|
||||
import { apiClient, isAxiosError, API } from '../router/api'
|
||||
|
||||
const route = useRoute()
|
||||
const agentStore = useAgentStore()
|
||||
|
||||
const agentId = route.params.id as string
|
||||
|
||||
const agent = ref<Record<string, unknown>>({
|
||||
id: agentId,
|
||||
name: 'Loading...',
|
||||
description: '',
|
||||
status: 'idle',
|
||||
})
|
||||
|
||||
const queryInput = ref('')
|
||||
const isRunning = computed(() => agentStore.executionStatus === 'running')
|
||||
const isConnected = computed(() => agentStore.isConnected ?? false)
|
||||
|
||||
const agentResponse = computed(() => {
|
||||
const completedEvent = agentStore.eventLog?.find((event) => event.type === 'completed')
|
||||
if (completedEvent?.content && typeof completedEvent.content === 'object') {
|
||||
const output = completedEvent.content as Record<string, unknown>
|
||||
const direct = output.response
|
||||
if (typeof direct === 'string' && direct.trim()) return direct
|
||||
|
||||
const result = output.result
|
||||
if (result && typeof result === 'object') {
|
||||
const nested = (result as Record<string, unknown>).response
|
||||
if (typeof nested === 'string' && nested.trim()) return nested
|
||||
}
|
||||
}
|
||||
return null
|
||||
})
|
||||
|
||||
const statusColor = (status: string) => {
|
||||
const colors: Record<string, string> = {
|
||||
idle: 'default',
|
||||
running: 'processing',
|
||||
completed: 'success',
|
||||
failed: 'error',
|
||||
stopped: 'warning',
|
||||
}
|
||||
return colors[status] || 'default'
|
||||
}
|
||||
|
||||
const fetchAgent = async () => {
|
||||
try {
|
||||
const response = await apiClient.get<Record<string, unknown>>(API.agent(agentId))
|
||||
agent.value = response.data
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch agent:', error)
|
||||
if (isAxiosError(error)) {
|
||||
console.error('Axios error details:', {
|
||||
status: error.response?.status,
|
||||
data: error.response?.data,
|
||||
message: error.message,
|
||||
})
|
||||
}
|
||||
message.error('Failed to load agent details')
|
||||
}
|
||||
}
|
||||
|
||||
const startAgent = () => {
|
||||
if (!agentStore.isConnected) {
|
||||
message.error('WebSocket not connected')
|
||||
return
|
||||
}
|
||||
|
||||
if (!queryInput.value.trim()) {
|
||||
message.error('Please enter a query')
|
||||
return
|
||||
}
|
||||
|
||||
const data = {
|
||||
query: queryInput.value.trim(),
|
||||
}
|
||||
|
||||
agentStore.startAgent(data)
|
||||
message.success('Agent execution started')
|
||||
}
|
||||
|
||||
const stopAgent = () => {
|
||||
agentStore.stopAgent(agentStore.lastExecutionId || undefined)
|
||||
message.success('Agent stop requested')
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchAgent()
|
||||
agentStore.connect(agentId)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
agentStore.disconnect()
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="page">
|
||||
<Card class="panel" :bordered="false">
|
||||
<div class="header">
|
||||
<Typography.Title :level="2">{{ agent.name }}</Typography.Title>
|
||||
<Tag :color="statusColor(String(agentStore.executionStatus || 'idle'))">
|
||||
{{ (agentStore.executionStatus || 'idle').toString().toUpperCase() }}
|
||||
</Tag>
|
||||
</div>
|
||||
|
||||
<Typography.Paragraph type="secondary">
|
||||
{{ agent.description || 'No description available' }}
|
||||
</Typography.Paragraph>
|
||||
|
||||
<div class="connection-status">
|
||||
<span>WebSocket Status:</span>
|
||||
<Tag :color="agentStore.isConnected ? 'green' : 'red'">
|
||||
{{ agentStore.isConnected ? 'CONNECTED' : 'DISCONNECTED' }}
|
||||
</Tag>
|
||||
</div>
|
||||
|
||||
<Typography.Title :level="4" class="section-title">Execution</Typography.Title>
|
||||
|
||||
<div class="execution-controls">
|
||||
<Space direction="vertical" style="width: 100%">
|
||||
<div>
|
||||
<Typography.Text>Query:</Typography.Text>
|
||||
<Input.TextArea
|
||||
v-model:value="queryInput"
|
||||
:disabled="isRunning"
|
||||
placeholder="Enter your query here..."
|
||||
:rows="4"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<Space>
|
||||
<Button type="primary" :disabled="isRunning || !isConnected" @click="startAgent">
|
||||
Run Agent
|
||||
</Button>
|
||||
<Button danger :disabled="!isRunning" @click="stopAgent">
|
||||
Stop Agent
|
||||
</Button>
|
||||
</Space>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
<div v-if="agentResponse" class="response-section">
|
||||
<Typography.Title :level="4" class="section-title">Final Response</Typography.Title>
|
||||
<Card class="response-card response-final" :bordered="false">
|
||||
<div class="response-content">
|
||||
{{ agentResponse }}
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<Typography.Title :level="4" class="section-title">Execution Log</Typography.Title>
|
||||
|
||||
<Spin :spinning="isRunning" tip="Agent running...">
|
||||
<div class="log-container">
|
||||
<List
|
||||
v-if="(agentStore.eventLog?.length ?? 0) > 0"
|
||||
:data-source="agentStore.eventLog || []"
|
||||
:bordered="false"
|
||||
>
|
||||
<template #renderItem="{ item }">
|
||||
<List.Item class="log-item">
|
||||
<div class="log-entry">
|
||||
<Tag class="log-type">{{ item.type }}</Tag>
|
||||
<span class="log-time">{{ item.timestamp.toLocaleTimeString() }}</span>
|
||||
<div v-if="item.message" class="log-message">
|
||||
{{ item.message }}
|
||||
</div>
|
||||
<div v-if="item.content && typeof item.content === 'object'" class="log-content">
|
||||
<pre>{{ JSON.stringify(item.content, null, 2) }}</pre>
|
||||
</div>
|
||||
<div v-else-if="item.content" class="log-content">
|
||||
{{ item.content }}
|
||||
</div>
|
||||
</div>
|
||||
</List.Item>
|
||||
</template>
|
||||
</List>
|
||||
<Typography.Paragraph v-else type="secondary">
|
||||
No events yet. Start the agent to see execution logs.
|
||||
</Typography.Paragraph>
|
||||
</div>
|
||||
</Spin>
|
||||
</Card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.page {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
padding: 1rem;
|
||||
}
|
||||
|
||||
.panel {
|
||||
background: #0f172a;
|
||||
border: 1px solid #1f2937;
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
margin-top: 2rem !important;
|
||||
margin-bottom: 1rem !important;
|
||||
}
|
||||
|
||||
.connection-status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
margin: 1rem 0;
|
||||
padding: 0.5rem;
|
||||
background: #1f2937;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.execution-controls {
|
||||
background: #1f2937;
|
||||
padding: 1rem;
|
||||
border-radius: 4px;
|
||||
margin: 1rem 0;
|
||||
}
|
||||
|
||||
.log-container {
|
||||
background: #1f2937;
|
||||
border-radius: 4px;
|
||||
max-height: 500px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.log-item {
|
||||
border-bottom: 1px solid #374151 !important;
|
||||
padding: 0.75rem !important;
|
||||
}
|
||||
|
||||
.log-entry {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.log-type {
|
||||
width: fit-content;
|
||||
}
|
||||
|
||||
.log-time {
|
||||
font-size: 0.75rem;
|
||||
color: #9ca3af;
|
||||
}
|
||||
|
||||
.log-message {
|
||||
color: #e5e7eb;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
.log-content {
|
||||
background: #111827;
|
||||
padding: 0.5rem;
|
||||
border-radius: 3px;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.log-content pre {
|
||||
margin: 0;
|
||||
font-size: 0.8rem;
|
||||
color: #d1d5db;
|
||||
}
|
||||
|
||||
.response-section {
|
||||
margin-top: 2rem;
|
||||
}
|
||||
|
||||
.response-card {
|
||||
background: #1f2937;
|
||||
border: 1px solid #374151;
|
||||
}
|
||||
|
||||
.response-final {
|
||||
border-color: #6366f1;
|
||||
box-shadow: 0 0 0 1px rgba(99, 102, 241, 0.35);
|
||||
}
|
||||
|
||||
.response-content {
|
||||
color: #e5e7eb;
|
||||
font-size: 1rem;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
padding: 0.5rem;
|
||||
}
|
||||
</style>
|
||||
92
src/views/AgentsView.vue
Normal file
92
src/views/AgentsView.vue
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { List, Typography, Button, Card, Spin, message } from 'ant-design-vue'
|
||||
import { apiClient } from '../router/api'
|
||||
import { API } from '../router/api'
|
||||
|
||||
interface Agent {
|
||||
uuid: string
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
status: string
|
||||
}
|
||||
|
||||
const agents = ref<Agent[]>([])
|
||||
const loading = ref(false)
|
||||
const loadError = ref(false)
|
||||
|
||||
const fetchAgents = async () => {
|
||||
loading.value = true
|
||||
loadError.value = false
|
||||
try {
|
||||
const response = await apiClient.get<Agent[]>(API.agents())
|
||||
agents.value = Array.isArray(response.data) ? response.data : []
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch agents:', error)
|
||||
message.error('Failed to load agents')
|
||||
agents.value = []
|
||||
loadError.value = true
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchAgents()
|
||||
})
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div class="page">
|
||||
<Typography.Title :level="2">Agents</Typography.Title>
|
||||
<Typography.Paragraph type="secondary">
|
||||
Manage and inspect the available AI agents.
|
||||
</Typography.Paragraph>
|
||||
|
||||
<Card class="panel" :bordered="false">
|
||||
<Spin :spinning="loading" tip="Loading agents...">
|
||||
<div v-if="loadError" class="empty">
|
||||
<Typography.Paragraph type="danger">
|
||||
Failed to load agents.
|
||||
</Typography.Paragraph>
|
||||
</div>
|
||||
<div v-else-if="!loading && agents.length === 0" class="empty">
|
||||
<Typography.Paragraph type="secondary">
|
||||
No agents found.
|
||||
</Typography.Paragraph>
|
||||
</div>
|
||||
<List v-else :data-source="agents" item-layout="horizontal" :bordered="false">
|
||||
<template #renderItem="{ item }">
|
||||
<List.Item class="item">
|
||||
<List.Item.Meta
|
||||
:title="item.name"
|
||||
:description="`${item.description} Status: ${item.status}`"
|
||||
/>
|
||||
<RouterLink :to="`/agents/${item.uuid || item.id}`">
|
||||
<Button type="primary" size="small">Open</Button>
|
||||
</RouterLink>
|
||||
</List.Item>
|
||||
</template>
|
||||
</List>
|
||||
</Spin>
|
||||
</Card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.page {
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
padding: 1rem;
|
||||
}
|
||||
.panel {
|
||||
background: #0f172a;
|
||||
border: 1px solid #1f2937;
|
||||
color: #e5e7eb;
|
||||
}
|
||||
.item :deep(.ant-list-item-meta-title),
|
||||
.item :deep(.ant-list-item-meta-description) {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
</style>
|
||||
Loading…
Reference in a new issue