# Dynavera/apps/mlstore/consumers.py
# 2026-01-20 17:21:28 +00:00
#
# 193 lines
# 6.1 KiB
# Python

import json
import logging

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.core.exceptions import ValidationError
from django.utils import timezone

from .models import Agent, AgentRun, AgentEvent
from .tasks import start_fine_tune_run_task, infer_run_task
class MLStoreConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer bound to a single mlstore ``Agent``.

    The agent is selected by the ``agent_id`` URL route kwarg. Once an
    authenticated user connects to an existing agent, the socket joins the
    channel-layer group ``mlstore_agent_<agent uuid>`` and relays the group
    messages handled by ``mlstore_event``, ``mlstore_completed`` and
    ``mlstore_error`` to the client.

    Client -> server frames are JSON objects with an ``"action"`` key:

    * ``"fine_tune"`` / ``"infer"`` -- create an ``AgentRun`` carrying the
      frame's ``input_data`` and queue the matching background task via
      ``.delay()``.
    * ``"stop"`` / ``"stop_agent"`` -- mark run ``execution_id`` of this
      agent as ``"failed"``.

    Problems with a request are reported back as ``{"type": "error"}`` frames.
    """

    _logger = logging.getLogger(__name__)

    async def connect(self):
        """Authenticate the user, resolve the agent and join its group.

        Unauthenticated users and unknown (or malformed) agent ids are
        rejected by closing the socket before it is accepted.
        """
        self.user = self.scope["user"]
        self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
        # Stays None until the group has actually been joined, so that
        # disconnect() only leaves a group this socket is a member of.
        self.room_group_name = None
        if not self.user.is_authenticated:
            await self.close()
            return
        # NOTE(review): any authenticated user may subscribe to any agent's
        # stream; add an ownership/permission check here if agents are not
        # meant to be shared between users.
        agent = await self.get_agent(self.agent_id)
        if agent is None:
            await self.close()
            return
        # Use the stored identifier rather than the raw URL value: if ``uuid``
        # is a UUIDField, alternative spellings (upper case, braces,
        # "urn:uuid:") would otherwise yield an invalid channel-layer group
        # name or one that senders using the stored uuid (assumed -- they are
        # not in this file) never target.
        group_name = f"mlstore_agent_{agent.uuid}"
        await self.channel_layer.group_add(group_name, self.channel_name)
        self.room_group_name = group_name
        await self.accept()
        await self._send_json({
            "type": "connection",
            "message": "Connected to mlstore agent stream",
            "agent_id": str(self.agent_id),
        })

    async def disconnect(self, close_code):
        """Leave the agent group if ``connect`` joined it."""
        group_name = getattr(self, "room_group_name", None)
        if group_name:
            await self.channel_layer.group_discard(group_name, self.channel_name)

    async def receive(self, text_data=None, bytes_data=None):
        """Parse one client frame and dispatch it to the matching handler.

        Channels calls this with ``text_data`` for text frames and with
        ``bytes_data`` for binary frames; both keywords must be accepted,
        otherwise a binary frame raises ``TypeError`` and kills the consumer.
        Binary frames are rejected with an error frame.
        """
        if text_data is None:
            await self._send_error("Binary frames are not supported")
            return
        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            await self._send_error("Invalid JSON")
            return
        if not isinstance(data, dict):
            await self._send_error("Expected a JSON object")
            return

        action = data.get("action")
        if action == "fine_tune":
            handler = self.handle_fine_tune
        elif action == "infer":
            handler = self.handle_infer
        elif action in ("stop_agent", "stop"):
            handler = self.handle_stop
        else:
            await self._send_error(f"Unknown action: {action}")
            return

        try:
            await handler(data)
        except Exception:
            # Keep the traceback server-side; raw exception text (which can
            # include SQL or other internals) is not echoed to the client.
            self._logger.exception("Unhandled error in mlstore action %r", action)
            await self._send_error("Internal server error")

    async def handle_fine_tune(self, data):
        """Create a run from ``data["input_data"]`` and queue fine-tuning."""
        await self._start_run(data, start_fine_tune_run_task, "Fine-tune")

    async def handle_infer(self, data):
        """Create a run from ``data["input_data"]`` and queue inference."""
        await self._start_run(data, infer_run_task, "Inference")

    async def _start_run(self, data, task, label):
        """Shared body of the fine-tune/infer actions.

        Re-fetches the agent (it may have been deleted since connect),
        creates an ``AgentRun``, tells the client it was queued and then
        enqueues ``task`` with the run uuid. ``label`` prefixes the
        human-readable "queued" message.
        """
        agent = await self.get_agent(self.agent_id)
        if agent is None:
            await self._send_error("Agent not found")
            return
        input_data = data.get("input_data") or {}
        execution = await self.create_run(agent, self.user, input_data)
        execution_id = str(execution.uuid)
        # Announce before enqueueing so the client knows the id before any
        # group events for this run can arrive.
        await self._send_json({
            "type": "execution_started",
            "execution_id": execution_id,
            "agent_id": str(agent.uuid),
            "message": f"{label} run {execution.uuid} queued",
        })
        try:
            # .delay() talks to the broker synchronously; run it off the
            # event loop so a slow broker does not stall every socket.
            await database_sync_to_async(task.delay)(execution_id)
        except Exception:
            # Without this the run would stay pending forever even though
            # the client was told it was queued.
            self._logger.exception(
                "Failed to enqueue %s run %s", label, execution_id
            )
            await self.update_execution_status(execution, "failed")
            await self._send_json({
                "type": "execution_error",
                "execution_id": execution_id,
                "error_message": "Failed to queue execution",
            })

    async def handle_stop(self, data):
        """Mark one of this agent's runs as ``"failed"`` on client request.

        The lookup is restricted to runs of the agent this socket is bound
        to, and the result is broadcast to the agent group so every
        subscriber (including the requester) sees the same
        ``execution_error`` frame.

        NOTE(review): this only updates database rows; nothing here cancels
        the queued/running background task, and a run that already finished
        is overwritten as "failed". Any subscriber of the agent may stop any
        of its runs -- confirm whether runs should be restricted to their
        ``user``.
        """
        execution_id = data.get("execution_id")
        if not execution_id:
            await self._send_error("execution_id required to stop")
            return
        execution = await self.get_execution(execution_id, agent_id=self.agent_id)
        if execution is None:
            await self._send_error("Execution not found")
            return
        await self.update_execution_status(execution, "failed")
        await self.channel_layer.group_send(self.room_group_name, {
            "type": "mlstore.error",
            "execution_id": str(execution.uuid),
            "error_message": "Execution stopped by user",
        })

    async def mlstore_event(self, event):
        """Relay a group progress event (``event_type``, ``content``, ``timestamp``)."""
        await self._send_json({
            "type": "mlstore_event",
            "event_type": event["event_type"],
            "content": event["content"],
            "timestamp": event["timestamp"],
        })

    async def mlstore_completed(self, event):
        """Relay a group completion message (``execution_id``, ``output_data``)."""
        await self._send_json({
            "type": "execution_completed",
            "execution_id": event["execution_id"],
            "output_data": event["output_data"],
            "message": "Execution completed",
        })

    async def mlstore_error(self, event):
        """Relay a group error message (``execution_id``, ``error_message``)."""
        await self._send_json({
            "type": "execution_error",
            "execution_id": event["execution_id"],
            "error_message": event["error_message"],
        })

    async def _send_json(self, payload):
        """Serialize ``payload`` to JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(payload))

    async def _send_error(self, message):
        """Send a ``{"type": "error", "message": message}`` frame to this client."""
        await self._send_json({"type": "error", "message": message})

    @database_sync_to_async
    def get_agent(self, agent_id):
        """Return the ``Agent`` with this uuid, or None if absent or malformed.

        Django raises ``ValidationError`` (not ``DoesNotExist``) when a
        UUID-typed lookup value cannot be parsed; both mean "no such agent".
        """
        try:
            return Agent.objects.get(uuid=agent_id)
        except (Agent.DoesNotExist, ValidationError):
            return None

    @database_sync_to_async
    def create_run(self, agent, user, input_data):
        """Create and return an ``AgentRun`` for ``agent``/``user`` with ``input_data``."""
        return AgentRun.objects.create(
            agent=agent,
            user=user,
            input_data=input_data,
        )

    @database_sync_to_async
    def get_execution(self, execution_id, agent_id=None):
        """Return the ``AgentRun`` with this uuid, or None if absent or malformed.

        When ``agent_id`` is given, only runs of that agent are considered.
        """
        try:
            runs = AgentRun.objects.all()
            if agent_id is not None:
                # filter() validates the lookup value eagerly, so it must
                # sit inside the try as well.
                runs = runs.filter(agent__uuid=agent_id)
            return runs.get(uuid=execution_id)
        except (AgentRun.DoesNotExist, ValidationError):
            return None

    @database_sync_to_async
    def update_execution_status(self, execution, status):
        """Set ``status`` and ``completed_at`` on the run, then on its agent.

        Updating the agent is best-effort: a failure there is logged and the
        (already saved) run is still returned.
        """
        now = timezone.now()
        execution.status = status
        execution.completed_at = now
        execution.save()
        try:
            agent = execution.agent
            agent.status = status
            agent.completed_at = now
            agent.save()
        except Exception:
            self._logger.exception(
                "Failed to propagate status %r to the agent of run %s",
                status,
                execution.uuid,
            )
        return execution

    @database_sync_to_async
    def create_event(self, execution, event_type, content):
        """Create and return an ``AgentEvent`` attached to ``execution``."""
        return AgentEvent.objects.create(
            execution=execution,
            event_type=event_type,
            content=content,
        )