2026-01-20 17:21:28 +00:00
|
|
|
import json
import logging

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.core.exceptions import ValidationError
from django.utils import timezone

from .models import Agent, AgentRun, AgentEvent
from .tasks import start_fine_tune_run_task, infer_run_task
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLStoreConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer streaming events for a single mlstore agent.

    A client connects on a URL route carrying an ``agent_id`` kwarg and is
    subscribed to the ``mlstore_agent_<agent_id>`` channel-layer group.
    Incoming text frames are JSON objects whose ``action`` selects a handler:

    * ``fine_tune`` / ``infer`` -- create an ``AgentRun`` and enqueue
      ``start_fine_tune_run_task`` / ``infer_run_task`` via ``.delay()``.
    * ``onboarding_progress`` -- persist a ``progress`` ``AgentEvent`` and
      broadcast it to the agent's group.
    * ``stop`` / ``stop_agent`` -- mark a run (and its agent) as ``"failed"``.

    Group messages of type ``mlstore_event``, ``mlstore_completed`` and
    ``mlstore_error`` are forwarded to the client as JSON text frames.
    """

    # Class-level logger; ``__name__`` here resolves to the module's name.
    _logger = logging.getLogger(__name__)

    async def connect(self):
        """Authenticate the user, validate the agent, then join its group.

        The handshake is rejected (closed before ``accept``) when the user is
        not authenticated or the ``agent_id`` route kwarg does not resolve to
        an existing ``Agent``.
        """
        self.user = self.scope["user"]
        self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
        # Set before any early return so ``disconnect`` can always use it.
        self.room_group_name = f"mlstore_agent_{self.agent_id}"

        if not self.user.is_authenticated:
            await self.close()
            return

        # NOTE(review): only agent existence is checked here; any
        # authenticated user may subscribe to any agent's stream. Confirm
        # that is intended.
        agent = await self.get_agent(self.agent_id)
        if not agent:
            await self.close()
            return

        await self.channel_layer.group_add(self.room_group_name, self.channel_name)
        await self.accept()
        await self._send_json({
            "type": "connection",
            "message": "Connected to mlstore agent stream",
            "agent_id": str(self.agent_id),
        })

    async def disconnect(self, close_code):
        """Leave the agent's channel-layer group."""
        await self.channel_layer.group_discard(self.room_group_name, self.channel_name)

    async def receive(self, text_data=None, bytes_data=None):
        """Parse a client frame and dispatch it on its ``action`` field.

        Args:
            text_data: Text frame payload; expected to be a JSON object.
            bytes_data: Binary frame payload. Channels passes this keyword for
                binary frames, so it must be accepted; binary frames are
                answered with an error rather than crashing the consumer.
        """
        if text_data is None:
            await self._send_error("Binary frames are not supported")
            return

        try:
            data = json.loads(text_data)
        except (ValueError, RecursionError):
            # JSONDecodeError subclasses ValueError; RecursionError covers
            # pathologically nested input.
            await self._send_error("Invalid JSON")
            return

        if not isinstance(data, dict):
            await self._send_error("Message must be a JSON object")
            return

        action = data.get("action")
        handlers = {
            "fine_tune": self.handle_fine_tune,
            "infer": self.handle_infer,
            "onboarding_progress": self.handle_onboarding_progress,
            "stop_agent": self.handle_stop,
            "stop": self.handle_stop,
        }
        # Guard the lookup: a non-string (possibly unhashable) action from
        # the client must not raise TypeError on dict lookup.
        handler = handlers.get(action) if isinstance(action, str) else None
        if handler is None:
            await self._send_error(f"Unknown action: {action}")
            return

        try:
            await handler(data)
        except Exception as e:
            # Top-level boundary for one client message: log the traceback and
            # report to the client instead of tearing down the socket.
            # NOTE(review): str(e) may expose internal details to the client.
            self._logger.exception("mlstore action %r failed", action)
            await self._send_error(str(e))

    async def handle_fine_tune(self, data):
        """Create a fine-tune run for this agent and enqueue its task.

        Args:
            data: Parsed client message; ``data["input_data"]`` (optional) is
                stored on the run, defaulting to ``{}``.
        """
        agent = await self.get_agent(self.agent_id)
        if not agent:
            await self._send_error("Agent not found")
            return

        input_data = data.get("input_data") or {}
        execution = await self.create_run(agent, self.user, input_data)

        await self._send_json({
            "type": "execution_started",
            "execution_id": str(execution.uuid),
            "agent_id": str(agent.uuid),
            "message": f"Fine-tune run {execution.uuid} queued",
        })
        await self._enqueue(start_fine_tune_run_task, execution)

    async def handle_infer(self, data):
        """Create an inference run for this agent and enqueue its task.

        ``input_data.role_uuid`` (or ``input_data.options.role_uuid``) is
        required; the run is not created without it.

        Args:
            data: Parsed client message; ``data["input_data"]`` must be a JSON
                object when present.
        """
        agent = await self.get_agent(self.agent_id)
        if not agent:
            await self._send_error("Agent not found")
            return

        input_data = data.get("input_data") or {}
        if not isinstance(input_data, dict):
            await self._send_error("input_data must be a JSON object")
            return

        role_uuid = input_data.get("role_uuid")
        if not role_uuid:
            options = input_data.get("options") or {}
            # Non-object ``options`` is treated as absent rather than raising.
            if isinstance(options, dict):
                role_uuid = options.get("role_uuid")
        if not role_uuid:
            await self._send_error("role_uuid is required for inference to enable RAG")
            return

        execution = await self.create_run(agent, self.user, input_data)

        await self._send_json({
            "type": "execution_started",
            "execution_id": str(execution.uuid),
            "agent_id": str(agent.uuid),
            "message": f"Inference run {execution.uuid} queued",
        })
        await self._enqueue(infer_run_task, execution)

    async def handle_stop(self, data):
        """Mark one of this agent's runs as stopped (status ``"failed"``).

        Args:
            data: Parsed client message carrying ``execution_id``.
        """
        execution_id = data.get("execution_id")
        if not execution_id:
            await self._send_error("execution_id required to stop")
            return

        # Scoped to this connection's agent so a client cannot stop runs
        # belonging to other agents.
        # NOTE(review): runs are not scoped to ``self.user``; any subscriber
        # of this agent can stop another user's run. Confirm intended.
        execution = await self.get_execution(execution_id, agent_id=self.agent_id)
        if not execution:
            await self._send_error("Execution not found")
            return

        # NOTE(review): this only records the stop in the database; the
        # enqueued task is not revoked here.
        await self.update_execution_status(execution, "failed")
        await self._send_json({
            "type": "execution_error",
            "execution_id": str(execution.uuid),
            "error_message": "Execution stopped by user",
        })

    async def handle_onboarding_progress(self, data):
        """Persist a progress event for one of this agent's runs and broadcast it.

        Args:
            data: Parsed client message carrying ``execution_id`` and the
                progress payload under ``content`` (or ``progress``).
        """
        execution_id = data.get("execution_id")
        if not execution_id:
            await self._send_error("execution_id required for onboarding_progress")
            return

        # Scoped to this agent so progress for another agent's run is never
        # broadcast into this agent's group.
        execution = await self.get_execution(execution_id, agent_id=self.agent_id)
        if not execution:
            await self._send_error("Execution not found")
            return

        content = data.get("content") or data.get("progress") or {}
        await self.create_event(execution, "progress", content)
        await self.channel_layer.group_send(
            self.room_group_name,
            {
                "type": "mlstore_event",
                "event_type": "progress",
                "content": content,
                "timestamp": timezone.now().isoformat(),
            },
        )

    async def mlstore_event(self, event):
        """Forward an ``mlstore_event`` group message to the client."""
        await self._send_json({
            "type": "mlstore_event",
            "event_type": event["event_type"],
            "content": event["content"],
            "timestamp": event["timestamp"],
        })

    async def mlstore_completed(self, event):
        """Forward an ``mlstore_completed`` group message as ``execution_completed``."""
        await self._send_json({
            "type": "execution_completed",
            "execution_id": event["execution_id"],
            "output_data": event["output_data"],
            "message": "Execution completed",
        })

    async def mlstore_error(self, event):
        """Forward an ``mlstore_error`` group message as ``execution_error``."""
        await self._send_json({
            "type": "execution_error",
            "execution_id": event["execution_id"],
            "error_message": event["error_message"],
        })

    async def _send_json(self, payload):
        """Serialize ``payload`` with ``json.dumps`` and send it as a text frame."""
        await self.send(text_data=json.dumps(payload))

    async def _send_error(self, message):
        """Send an ``{"type": "error", "message": ...}`` frame to the client."""
        await self._send_json({"type": "error", "message": message})

    async def _enqueue(self, task, execution):
        """Call ``task.delay(execution uuid)`` without blocking the event loop.

        ``task.delay`` is a synchronous call (it presumably publishes to the
        task broker), so it runs in a worker thread. If enqueueing fails, the
        run is marked ``"failed"`` and the client is told. Otherwise the run
        would stay announced as queued but never start.
        """
        execution_id = str(execution.uuid)
        try:
            await database_sync_to_async(task.delay)(execution_id)
        except Exception:
            self._logger.exception("Failed to enqueue run %s", execution_id)
            await self.update_execution_status(execution, "failed")
            await self._send_json({
                "type": "execution_error",
                "execution_id": execution_id,
                "error_message": "Failed to queue execution",
            })

    @database_sync_to_async
    def get_agent(self, agent_id):
        """Return the ``Agent`` whose ``uuid`` equals ``agent_id``, or ``None``.

        Malformed identifiers (``ValidationError``/``ValueError`` raised by the
        field's lookup conversion) are treated as "not found" rather than
        propagating. Previously they crashed ``connect``.
        """
        try:
            return Agent.objects.get(uuid=agent_id)
        except (Agent.DoesNotExist, ValidationError, ValueError):
            return None

    @database_sync_to_async
    def create_run(self, agent, user, input_data):
        """Create and return a new ``AgentRun`` for ``agent`` owned by ``user``."""
        return AgentRun.objects.create(
            agent=agent,
            user=user,
            input_data=input_data,
        )

    @database_sync_to_async
    def get_execution(self, execution_id, agent_id=None):
        """Return the ``AgentRun`` with ``uuid == execution_id``, or ``None``.

        Args:
            execution_id: The run's uuid as supplied by the client.
            agent_id: When given, only runs whose ``agent.uuid`` matches are
                considered. Omitting it keeps the original unscoped lookup.

        Malformed identifiers are treated as "not found".
        """
        runs = AgentRun.objects.all()
        if agent_id is not None:
            runs = runs.filter(agent__uuid=agent_id)
        try:
            return runs.get(uuid=execution_id)
        except (AgentRun.DoesNotExist, ValidationError, ValueError):
            return None

    @database_sync_to_async
    def update_execution_status(self, execution, status):
        """Set ``status`` and ``completed_at`` on the run, then mirror them onto its agent.

        The run is always saved. Updating the agent is best-effort: a failure
        there is logged and does not undo or abort the run update.

        Returns:
            The updated ``execution``.
        """
        now = timezone.now()
        execution.status = status
        execution.completed_at = now
        execution.save()
        try:
            agent = execution.agent
            agent.status = status
            agent.completed_at = now
            agent.save()
        except Exception:
            # Deliberately best-effort, but no longer silent.
            self._logger.exception(
                "Failed to mirror status %r onto agent of run %s",
                status,
                execution.uuid,
            )
        return execution

    @database_sync_to_async
    def create_event(self, execution, event_type, content):
        """Create and return an ``AgentEvent`` attached to ``execution``."""
        return AgentEvent.objects.create(
            execution=execution,
            event_type=event_type,
            content=content,
        )
|