268 lines
10 KiB
Python
268 lines
10 KiB
Python
|
|
from celery import shared_task
|
||
|
|
from django.utils import timezone
|
||
|
|
from channels.layers import get_channel_layer
|
||
|
|
from asgiref.sync import async_to_sync
|
||
|
|
from . import services
|
||
|
|
from .models import AgentModel, Agent, AgentRun, AgentEvent
|
||
|
|
import traceback
|
||
|
|
|
||
|
|
|
||
|
|
@shared_task
def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str):
    """Run a fine-tune through the services layer and register the new model.

    The arguments are forwarded unchanged to ``services.fine_tune_model``.
    When that call returns a dict whose ``status`` is ``"completed"``, an
    ``AgentModel`` row is created from it: the path comes from ``model_path``
    (falling back to ``path``, then ``""``) and the version from the result's
    ``version`` (falling back to the ``version`` argument).

    Returns:
        A dict whose ``status`` is one of:

        * ``"ok"`` - the model was registered; also carries ``model_id``,
          ``model_uuid``, ``model_path`` and the raw ``result``.
        * ``"failed"`` - the service returned anything other than a completed
          dict; the raw value is under ``result``.
        * ``"error"`` - an exception was raised; its text is under ``error``.
    """
    try:
        outcome = services.fine_tune_model(base_model, training_files, hyperparams, name, version)

        finished = isinstance(outcome, dict) and outcome.get("status") == "completed"
        if not finished:
            return {"status": "failed", "result": outcome}

        stored_path = outcome.get("model_path") or outcome.get("path") or ""
        stored_version = outcome.get("version") or version
        registered = AgentModel.objects.create(name=name, version=stored_version, path=stored_path)
        return {
            "status": "ok",
            "model_id": registered.id,
            "model_uuid": str(registered.uuid),
            "model_path": stored_path,
            "result": outcome,
        }
    except Exception as exc:
        # Task boundary: report the failure in the result instead of raising.
        traceback.print_exc()
        return {"status": "error", "error": str(exc)}
|
||
|
|
|
||
|
|
|
||
|
|
@shared_task
def infer_with_model_task(model_id: int, prompt: str, options: dict = None):
    """Run inference against a stored ``AgentModel``.

    Looks up the ``AgentModel`` by ``model_id``, asks the services layer to
    load it, then calls ``services.infer_with_model`` with the model's path.

    Args:
        model_id: Primary key of the ``AgentModel`` to use.
        prompt: Prompt forwarded to ``services.infer_with_model``.
        options: Optional inference options; ``None`` is treated as ``{}``.

    Returns:
        A dict whose ``status`` is one of:

        * ``"completed"`` - carries ``model_id`` and the service ``response``.
        * ``"error"`` - no model with that id exists
          (``error == "model_not_found"``).
        * ``"failed"`` - inference raised; carries ``model_id`` and the
          exception text under ``error``.
    """
    try:
        model = AgentModel.objects.get(id=model_id)
    except AgentModel.DoesNotExist:
        return {"status": "error", "error": "model_not_found", "model_id": model_id}

    try:
        services.load_model_for_inference(model.path)
    except Exception:
        # Pre-loading is best-effort: inference is still attempted below.
        # Print the traceback so a failing pre-load is not invisible.
        traceback.print_exc()

    try:
        out = services.infer_with_model(model.path, prompt, options or {})
    except Exception as e:
        traceback.print_exc()
        # Include model_id so every result of this task identifies the model.
        return {"status": "failed", "model_id": model_id, "error": str(e)}

    return {"status": "completed", "model_id": model_id, "response": out}
|
||
|
|
|
||
|
|
|
||
|
|
def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Send an ``mlstore_event`` message to a channel-layer group.

    Args:
        room_group_name: Name of the channel-layer group to send to.
        event_type: Value placed under the message's ``event_type`` key.
        content: Payload placed under the message's ``content`` key.

    The message also carries the current time (ISO 8601) under ``timestamp``.
    """
    layer = get_channel_layer()
    # group_send is a coroutine function; wrap it so it can be called from
    # synchronous code.
    send_to_group = async_to_sync(layer.group_send)
    message = {
        "type": "mlstore_event",
        "event_type": event_type,
        "content": content,
        "timestamp": timezone.now().isoformat(),
    }
    send_to_group(room_group_name, message)
|
||
|
|
|
||
|
|
|
||
|
|
def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Store one event for a run as an ``AgentEvent`` row.

    Args:
        execution: The run the event belongs to.
        event_type: Value stored in the row's ``event_type`` field.
        content: Payload stored in the row's ``content`` field.
    """
    row_fields = {
        "execution": execution,
        "event_type": event_type,
        "content": content,
    }
    AgentEvent.objects.create(**row_fields)
|
||
|
|
|
||
|
|
|
||
|
|
def _update_agent_status(agent: Agent, status: str):
    """Set the agent's status, stamp the matching timestamp, and save.

    ``"running"`` stamps ``started_at``; ``"completed"`` and ``"failed"``
    stamp ``completed_at``. Any other status leaves both timestamps alone.
    """
    timestamp_field_for = {
        "running": "started_at",
        "completed": "completed_at",
        "failed": "completed_at",
    }

    agent.status = status
    field_name = timestamp_field_for.get(status)
    if field_name is not None:
        setattr(agent, field_name, timezone.now())
    agent.save()
|
||
|
|
|
||
|
|
|
||
|
|
@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Execute the fine-tune described by an ``AgentRun`` and track its state.

    The run (looked up by ``uuid``) and its agent are marked ``"running"``.
    Parameters are read from ``execution.input_data``, with fallbacks:
    ``base_model`` -> the agent's current model name, ``name`` ->
    ``"<model name>-ft"``, ``version`` -> ``"v1"``, ``training_files`` ->
    ``[]``, ``hyperparams`` -> ``{}``.

    On a completed result a new ``AgentModel`` is created and attached to the
    agent, and the run and agent are marked ``"completed"``. Otherwise both
    are marked ``"failed"``. Every transition is sent to the
    ``mlstore_agent_<agent uuid>`` group and persisted as an ``AgentEvent``.

    Args:
        execution_id: ``uuid`` of the ``AgentRun`` to execute.

    Returns:
        A dict with ``status`` ``"completed"`` (plus ``model_id``),
        ``"failed"`` (service returned a non-completed result, under
        ``result``) or ``"error"`` (run not found, or an exception was raised;
        message under ``error``).
    """
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")

    try:
        # Everything after the run is marked "running" happens inside this
        # try. A failure while resolving inputs (e.g. reading the agent's
        # model) or sending the "started" event must not leave the run and
        # agent stuck in "running".
        input_data = execution.input_data or {}
        base_model = input_data.get("base_model") or agent.model.name
        training_files = input_data.get("training_files") or []
        hyperparams = input_data.get("hyperparams") or {}
        name = input_data.get("name") or f"{agent.model.name}-ft"
        version = input_data.get("version") or "v1"

        started_payload = {"execution_id": str(execution.uuid), "action": "fine_tune"}
        _send_group_event(room_group_name, "started", started_payload)
        _persist_event(execution, "started", started_payload)

        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)

        if not (isinstance(result, dict) and result.get("status") == "completed"):
            _fail_fine_tune_run(execution, agent, room_group_name, str(result), result)
            return {"status": "failed", "execution_id": execution_id, "result": result}

        model_path = result.get("model_path") or result.get("path") or ""
        model_version = result.get("version") or version
        new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
        agent.model = new_model
        agent.save()

        execution.status = "completed"
        execution.output_data = {
            "result": result,
            "model_id": new_model.id,
            "model_uuid": str(new_model.uuid),
        }
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "completed")

        completed_payload = {
            "execution_id": str(execution.uuid),
            "model_id": new_model.id,
            "model_path": model_path,
        }
        _send_group_event(room_group_name, "completed", completed_payload)
        _persist_event(execution, "completed", completed_payload)

        # NOTE(review): if any notification above or below raises, the handler
        # at the bottom marks this already-completed run as failed. Kept as in
        # the original; confirm whether that is intended.
        async_to_sync(get_channel_layer().group_send)(
            room_group_name,
            {
                "type": "mlstore_completed",
                "execution_id": str(execution.uuid),
                "output_data": execution.output_data,
            },
        )

        return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}

    except Exception as e:
        traceback.print_exc()
        _fail_fine_tune_run(execution, agent, room_group_name, str(e), str(e))
        return {"status": "error", "execution_id": execution_id, "error": str(e)}


def _fail_fine_tune_run(execution, agent, room_group_name, message, detail):
    """Mark a fine-tune run and its agent as failed, then notify listeners.

    Args:
        execution: The ``AgentRun`` being failed.
        agent: The agent that owns the run.
        room_group_name: Channel-layer group to notify.
        message: Text stored in ``execution.error_message`` and sent as
            ``error_message`` in the ``mlstore_error`` message.
        detail: Value placed under ``error`` in the "error" event payload.
    """
    # Persist the failed state first so it survives a notification failure.
    execution.status = "failed"
    execution.error_message = message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    error_payload = {"execution_id": str(execution.uuid), "error": detail}
    _send_group_event(room_group_name, "error", error_payload)
    _persist_event(execution, "error", error_payload)

    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": message,
        },
    )
|
||
|
|
|
||
|
|
|
||
|
|
@shared_task
def infer_run_task(execution_id: str):
    """Execute the inference described by an ``AgentRun`` and track its state.

    The run (looked up by ``uuid``) and its agent are marked ``"running"``.
    The prompt is read from ``input_data["prompt"]`` (falling back to
    ``"query"``) and options from ``input_data["options"]``. A missing prompt
    fails the run with ``"prompt_required"``. Otherwise the agent's model is
    pre-loaded (best-effort) and ``services.infer_with_model`` is called with
    the model's path. Every transition is sent to the
    ``mlstore_agent_<agent uuid>`` group and persisted as an ``AgentEvent``.

    Args:
        execution_id: ``uuid`` of the ``AgentRun`` to execute.

    Returns:
        A dict with ``status`` ``"completed"``, ``"failed"`` (missing prompt
        or an exception was raised; message under ``error``) or ``"error"``
        (run not found).
    """
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")

    input_data = execution.input_data or {}
    prompt = input_data.get("prompt") or input_data.get("query") or ""
    options = input_data.get("options") or {}

    if not prompt:
        _fail_infer_run(execution, agent, room_group_name, "prompt_required")
        return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

    try:
        # The "started" notifications are inside the try so that a failure
        # while sending them does not leave the run stuck in "running".
        started_payload = {"execution_id": str(execution.uuid), "action": "infer"}
        _send_group_event(room_group_name, "started", started_payload)
        _persist_event(execution, "started", started_payload)

        model_path = agent.model.path

        try:
            services.load_model_for_inference(model_path)
        except Exception:
            # Pre-loading is best-effort: inference is still attempted below.
            # Print the traceback so a failing pre-load is not invisible.
            traceback.print_exc()

        result = services.infer_with_model(model_path, prompt, options)

        execution.status = "completed"
        execution.output_data = {"result": result}
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "completed")

        completed_payload = {"execution_id": str(execution.uuid), "result": result}
        _send_group_event(room_group_name, "completed", completed_payload)
        _persist_event(execution, "completed", completed_payload)
        async_to_sync(get_channel_layer().group_send)(
            room_group_name,
            {
                "type": "mlstore_completed",
                "execution_id": str(execution.uuid),
                "output_data": execution.output_data,
            },
        )
        return {"status": "completed", "execution_id": execution_id}

    except Exception as e:
        traceback.print_exc()
        _fail_infer_run(execution, agent, room_group_name, str(e))
        return {"status": "failed", "execution_id": execution_id, "error": str(e)}


def _fail_infer_run(execution, agent, room_group_name, message):
    """Mark an inference run and its agent as failed, then notify listeners.

    Args:
        execution: The ``AgentRun`` being failed.
        agent: The agent that owns the run.
        room_group_name: Channel-layer group to notify.
        message: Text stored in ``execution.error_message`` and sent in both
            the "error" event and the ``mlstore_error`` message.
    """
    # Persist the failed state first so it survives a notification failure.
    execution.status = "failed"
    execution.error_message = message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    error_payload = {"execution_id": str(execution.uuid), "error": message}
    _send_group_event(room_group_name, "error", error_payload)
    _persist_event(execution, "error", error_payload)

    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": message,
        },
    )
|