Cleaned tasks file, updated field for model
This commit is contained in:
parent
688558b3c9
commit
2de25c4a0e
3 changed files with 53 additions and 61 deletions
15
apps/mlstore/migrations/0002_alter_agentrun_input_data.py
Normal file
15
apps/mlstore/migrations/0002_alter_agentrun_input_data.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('mlstore', '0001_initial'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='agentrun',
|
||||||
|
name='input_data',
|
||||||
|
field=models.JSONField(blank=True, default=dict),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
@ -63,7 +63,7 @@ class AgentRun(TimeStampMixin, Model):
|
||||||
user = ForeignKey(User, on_delete = CASCADE, related_name = 'agent_runs')
|
user = ForeignKey(User, on_delete = CASCADE, related_name = 'agent_runs')
|
||||||
status = CharField(max_length = 20, choices = RUN_CHOICES, default = 'queued')
|
status = CharField(max_length = 20, choices = RUN_CHOICES, default = 'queued')
|
||||||
|
|
||||||
input_data = JSONField(default = dict)
|
input_data = JSONField(default = dict, blank = True)
|
||||||
output_data = JSONField(default = dict, blank = True)
|
output_data = JSONField(default = dict, blank = True)
|
||||||
error_message = TextField(blank = True, default = "")
|
error_message = TextField(blank = True, default = "")
|
||||||
started_at = DateTimeField(null = True, blank = True)
|
started_at = DateTimeField(null = True, blank = True)
|
||||||
|
|
|
||||||
|
|
@ -1,62 +1,16 @@
|
||||||
from celery import shared_task
|
|
||||||
from django.utils import timezone
|
|
||||||
from channels.layers import get_channel_layer
|
|
||||||
from asgiref.sync import async_to_sync
|
|
||||||
from . import services
|
|
||||||
from .models import AgentModel, Agent, AgentRun, AgentEvent
|
|
||||||
import traceback
|
|
||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from celery import shared_task
|
||||||
|
from channels.layers import get_channel_layer
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from apps.orgs.models import TrainingFile
|
||||||
|
from . import services
|
||||||
|
from .models import Agent, AgentEvent, AgentModel, AgentRun
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
|
||||||
def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str):
|
|
||||||
"""Start a fine-tune via MCP, and register the resulting model on success.
|
|
||||||
|
|
||||||
This task calls `services.fine_tune_model`, expects a dict result with `status` and on success
|
|
||||||
`model_path` and optionally `version`.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
|
|
||||||
|
|
||||||
if isinstance(result, dict) and result.get("status") == "completed":
|
|
||||||
model_path = result.get("model_path") or result.get("path") or ""
|
|
||||||
model_version = result.get("version") or version
|
|
||||||
m = AgentModel.objects.create(name=name, version=model_version, path=model_path)
|
|
||||||
return {"status": "ok", "model_id": m.id, "model_uuid": str(m.uuid), "model_path": model_path, "result": result}
|
|
||||||
|
|
||||||
return {"status": "failed", "result": result}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"status": "error", "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
|
||||||
def infer_with_model_task(model_id: int, prompt: str, options: dict = None):
|
|
||||||
"""Run inference by requesting the MCP server to use the stored model.
|
|
||||||
|
|
||||||
Looks up the `AgentModel` by `model_id`, calls `services.infer_with_model`, and returns the response.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
model = AgentModel.objects.get(id=model_id)
|
|
||||||
except AgentModel.DoesNotExist:
|
|
||||||
return {"status": "error", "error": "model_not_found", "model_id": model_id}
|
|
||||||
|
|
||||||
try:
|
|
||||||
services.load_model_for_inference(model.path)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
out = services.infer_with_model(model.path, prompt, options or {})
|
|
||||||
return {"status": "completed", "model_id": model_id, "response": out}
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
return {"status": "failed", "error": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
def _send_group_event(room_group_name: str, event_type: str, content: dict):
|
def _send_group_event(room_group_name: str, event_type: str, content: dict):
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
async_to_sync(channel_layer.group_send)(
|
async_to_sync(channel_layer.group_send)(
|
||||||
|
|
@ -113,18 +67,35 @@ def start_fine_tune_run_task(execution_id: str):
|
||||||
base_model = input_data.get("base_model") or agent.model.name
|
base_model = input_data.get("base_model") or agent.model.name
|
||||||
|
|
||||||
training_files = input_data.get("training_files") or []
|
training_files = input_data.get("training_files") or []
|
||||||
|
org_training_files = []
|
||||||
if not training_files and agent.organization:
|
if not training_files and agent.organization:
|
||||||
from apps.orgs.models import TrainingFile
|
org_training_files = list(TrainingFile.objects.filter(
|
||||||
org_training_files = TrainingFile.objects.filter(
|
|
||||||
organization=agent.organization,
|
organization=agent.organization,
|
||||||
is_processed=False
|
is_processed=False
|
||||||
).select_related('uploaded_by')
|
).select_related('uploaded_by'))
|
||||||
training_files = [tf.file.path for tf in org_training_files if tf.file]
|
training_files = [tf.file.path for tf in org_training_files if tf.file]
|
||||||
logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")
|
logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")
|
||||||
|
|
||||||
hyperparams = input_data.get("hyperparams") or {}
|
hyperparams = input_data.get("hyperparams") or {}
|
||||||
name = input_data.get("name") or f"{agent.model.name}-ft"
|
name = input_data.get("name") or agent.model.name
|
||||||
version = input_data.get("version") or "v1"
|
|
||||||
|
if not input_data.get("version"):
|
||||||
|
existing_models = AgentModel.objects.filter(name=name).order_by('-version')
|
||||||
|
if existing_models.exists():
|
||||||
|
last_version = existing_models.first().version
|
||||||
|
try:
|
||||||
|
if last_version.startswith('v'):
|
||||||
|
num = int(last_version[1:])
|
||||||
|
version = f"v{num + 1}"
|
||||||
|
else:
|
||||||
|
version = f"v1"
|
||||||
|
except:
|
||||||
|
version = "v1"
|
||||||
|
else:
|
||||||
|
version = "v1"
|
||||||
|
else:
|
||||||
|
version = input_data.get("version")
|
||||||
|
|
||||||
logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")
|
logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")
|
||||||
|
|
||||||
_send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
|
_send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
|
||||||
|
|
@ -143,6 +114,12 @@ def start_fine_tune_run_task(execution_id: str):
|
||||||
agent.save()
|
agent.save()
|
||||||
logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")
|
logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")
|
||||||
|
|
||||||
|
if org_training_files:
|
||||||
|
file_ids = [tf.id for tf in org_training_files]
|
||||||
|
TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True)
|
||||||
|
logger.info(f"Marked {len(org_training_files)} training files as processed")
|
||||||
|
|
||||||
|
|
||||||
execution.status = "completed"
|
execution.status = "completed"
|
||||||
execution.output_data = {
|
execution.output_data = {
|
||||||
"result": result,
|
"result": result,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue