"""Celery tasks that fine-tune models and run inference for mlstore agents.

The ``*_run_task`` variants operate on an ``AgentRun`` row: they update the
run and its agent, persist ``AgentEvent`` rows and broadcast progress to the
agent's channel-layer group.
"""

import logging
import traceback  # no longer used here; kept so the module's import surface is unchanged
from typing import Optional

from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from django.utils import timezone

from . import services
from .models import Agent, AgentEvent, AgentModel, AgentRun

logger = logging.getLogger(__name__)


def _register_model(result: dict, name: str, default_version: str):
    """Create an ``AgentModel`` row from a completed fine-tune result.

    Args:
        result: Result dict from ``services.fine_tune_model``; ``model_path``
            (or ``path``) and ``version`` are read from it when present.
        name: Name stored on the new model.
        default_version: Version used when the result carries none.

    Returns:
        A ``(model, model_path)`` tuple.
    """
    model_path = result.get("model_path") or result.get("path") or ""
    model_version = result.get("version") or default_version
    model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
    return model, model_path


def _preload_model(model_path: str) -> None:
    """Ask the service to load the model ahead of inference (best effort).

    Failures are logged and otherwise ignored so the inference call that
    follows still gets its own chance to succeed or report the real error.
    """
    try:
        logger.info(f"Loading model: {model_path}")
        services.load_model_for_inference(model_path)
    except Exception as exc:
        logger.warning(f"Failed to preload model: {str(exc)}")


@shared_task
def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str):
    """Start a fine-tune via MCP, and register the resulting model on success.

    This task calls `services.fine_tune_model`, expects a dict result with
    `status` and on success `model_path` and optionally `version`.

    Returns:
        ``{"status": "ok", ...}`` with the new model's identifiers when the
        result status is ``"completed"``; ``{"status": "failed", "result": ...}``
        for any other result; ``{"status": "error", "error": ...}`` when an
        exception is raised.
    """
    try:
        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
        if isinstance(result, dict) and result.get("status") == "completed":
            model, model_path = _register_model(result, name, version)
            return {
                "status": "ok",
                "model_id": model.id,
                "model_uuid": str(model.uuid),
                "model_path": model_path,
                "result": result,
            }
        return {"status": "failed", "result": result}
    except Exception as exc:
        logger.exception("Fine-tune task failed for model %s", name)
        return {"status": "error", "error": str(exc)}


@shared_task
def infer_with_model_task(model_id: int, prompt: str, options: Optional[dict] = None):
    """Run inference by requesting the MCP server to use the stored model.

    Looks up the `AgentModel` by `model_id`, calls `services.infer_with_model`,
    and returns the response.

    Returns:
        ``{"status": "completed", "model_id": ..., "response": ...}`` on
        success, ``{"status": "error", "error": "model_not_found", ...}`` when
        the model row does not exist, or ``{"status": "failed", "error": ...}``
        when inference raises.
    """
    try:
        model = AgentModel.objects.get(id=model_id)
    except AgentModel.DoesNotExist:
        return {"status": "error", "error": "model_not_found", "model_id": model_id}

    _preload_model(model.path)

    try:
        out = services.infer_with_model(model.path, prompt, options or {})
        return {"status": "completed", "model_id": model_id, "response": out}
    except Exception as exc:
        logger.exception("Inference task failed for model %s", model_id)
        return {"status": "failed", "error": str(exc)}


def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Broadcast an ``mlstore_event`` message to a channel-layer group.

    The message carries ``event_type``, ``content`` and an ISO-8601 timestamp
    taken at send time.
    """
    channel_layer = get_channel_layer()
    async_to_sync(channel_layer.group_send)(
        room_group_name,
        {
            "type": "mlstore_event",
            "event_type": event_type,
            "content": content,
            "timestamp": timezone.now().isoformat(),
        },
    )


def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Store an ``AgentEvent`` row for the given execution."""
    AgentEvent.objects.create(
        execution=execution,
        event_type=event_type,
        content=content,
    )


def _update_agent_status(agent: Agent, status: str):
    """Set the agent's status, stamp the matching timestamp, and save.

    ``"running"`` sets ``started_at``; ``"completed"`` and ``"failed"`` set
    ``completed_at``. Any other status only updates ``status``.
    """
    agent.status = status
    if status == "running":
        agent.started_at = timezone.now()
    elif status in ("completed", "failed"):
        agent.completed_at = timezone.now()
    agent.save()


def _load_execution(execution_id: str) -> Optional[AgentRun]:
    """Return the ``AgentRun`` with this uuid, or ``None`` (logged) if absent."""
    try:
        return AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return None


def _mark_execution_running(execution: AgentRun, agent: Agent) -> None:
    """Flag both the execution and its agent as running."""
    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")


def _complete_execution(
    execution: AgentRun,
    agent: Agent,
    room_group_name: str,
    output_data: dict,
    event_content: dict,
) -> None:
    """Record a successful run and notify the agent's group.

    Saves the execution as completed with ``output_data``, marks the agent
    completed, emits and persists a ``"completed"`` event carrying
    ``event_content``, then sends the final ``mlstore_completed`` message.
    """
    execution.status = "completed"
    execution.output_data = output_data
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "completed")

    _send_group_event(room_group_name, "completed", event_content)
    _persist_event(execution, "completed", event_content)
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_completed",
            "execution_id": str(execution.uuid),
            "output_data": execution.output_data,
        },
    )


def _fail_execution(
    execution: AgentRun,
    agent: Agent,
    room_group_name: str,
    error,
    error_message: str,
) -> None:
    """Record a failed run and notify the agent's group.

    Args:
        execution: Run to mark as failed.
        agent: Agent owning the run; also marked failed.
        room_group_name: Channel-layer group to notify.
        error: Value placed under ``"error"`` in the emitted/persisted event.
        error_message: Text stored on the execution and sent in the final
            ``mlstore_error`` message.
    """
    execution.status = "failed"
    execution.error_message = error_message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    event_content = {"execution_id": str(execution.uuid), "error": error}
    _send_group_event(room_group_name, "error", event_content)
    _persist_event(execution, "error", event_content)
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": error_message,
        },
    )


@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Fine-tune a model for an ``AgentRun`` and attach it to the run's agent.

    Parameters are read from ``execution.input_data`` (``base_model``,
    ``training_files``, ``hyperparams``, ``name``, ``version``) with fallbacks
    to the agent's current model and, for training files, the unprocessed
    ``TrainingFile`` rows of the agent's organization.

    Returns:
        A dict whose ``status`` is ``"completed"`` (with ``model_id``),
        ``"failed"`` (fine-tune returned a non-completed result) or
        ``"error"`` (execution not found, or an exception was raised).
    """
    logger.info(f"Fine-tune run task started for execution: {execution_id}")

    execution = _load_execution(execution_id)
    if execution is None:
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    _mark_execution_running(execution, agent)
    logger.info(f"Execution {execution_id} status updated to 'running'")

    # Everything after the run is flagged "running" lives inside the try so
    # that any failure (bad input data, missing model, DB or channel errors)
    # is recorded on the run instead of leaving it stuck in "running".
    try:
        from apps.mlstore.services import BASE_MODEL_CACHE

        logger.info(f"Base model cache directory: {BASE_MODEL_CACHE}")

        input_data = execution.input_data or {}
        base_model = input_data.get("base_model") or agent.model.name

        training_files = input_data.get("training_files") or []
        if not training_files and agent.organization:
            from apps.orgs.models import TrainingFile

            org_training_files = TrainingFile.objects.filter(
                organization=agent.organization,
                is_processed=False,
            ).select_related('uploaded_by')
            training_files = [tf.file.path for tf in org_training_files if tf.file]
            logger.info(
                f"Fetched {len(training_files)} training files from organization {agent.organization.name}"
            )

        hyperparams = input_data.get("hyperparams") or {}
        name = input_data.get("name") or f"{agent.model.name}-ft"
        version = input_data.get("version") or "v1"
        logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")

        started_content = {"execution_id": str(execution.uuid), "action": "fine_tune"}
        _send_group_event(room_group_name, "started", started_content)
        _persist_event(execution, "started", started_content)

        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)

        # A non-dict result has no status; treat it as "not completed" rather
        # than letting ``.get`` raise and hide the real result.
        result_status = result.get("status") if isinstance(result, dict) else None
        logger.info(f"Fine-tune result received: {result_status}")
        logger.debug(f"Full fine-tune result: {result}")

        if result_status == "completed":
            new_model, model_path = _register_model(result, name, version)
            agent.model = new_model
            agent.save()
            logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")

            _complete_execution(
                execution,
                agent,
                room_group_name,
                output_data={
                    "result": result,
                    "model_id": new_model.id,
                    "model_uuid": str(new_model.uuid),
                },
                event_content={
                    "execution_id": str(execution.uuid),
                    "model_id": new_model.id,
                    "model_path": model_path,
                },
            )
            logger.info(f"Execution {execution_id} completed successfully")
            return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}

        logger.warning(f"Fine-tune did not complete successfully. Status: {result_status}")
        _fail_execution(execution, agent, room_group_name, error=result, error_message=str(result))
        return {"status": "failed", "execution_id": execution_id, "result": result}
    except Exception as exc:
        logger.exception(
            f"Fine-tune task failed with exception for execution {execution_id}: {str(exc)}"
        )
        _fail_execution(execution, agent, room_group_name, error=str(exc), error_message=str(exc))
        return {"status": "error", "execution_id": execution_id, "error": str(exc)}


@shared_task
def infer_run_task(execution_id: str):
    """Run inference for an ``AgentRun`` using the agent's current model.

    The prompt is read from ``execution.input_data`` (``prompt``, falling
    back to ``query``) and options from ``options``. An empty prompt fails
    the run with ``"prompt_required"``.

    Returns:
        A dict whose ``status`` is ``"completed"``, ``"failed"`` (missing
        prompt or an exception was raised) or ``"error"`` (execution not
        found).
    """
    logger.info(f"Inference run task started for execution: {execution_id}")

    execution = _load_execution(execution_id)
    if execution is None:
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    _mark_execution_running(execution, agent)
    logger.info(f"Execution {execution_id} status updated to 'running'")

    # As in the fine-tune run task: once the run is "running", every failure
    # must end up recorded on it, so input handling is inside the try too.
    try:
        input_data = execution.input_data or {}
        prompt = input_data.get("prompt") or input_data.get("query") or ""
        options = input_data.get("options") or {}
        logger.info(f"Prompt length: {len(prompt)} characters")

        if not prompt:
            logger.warning(f"No prompt provided for inference run {execution_id}")
            _fail_execution(
                execution,
                agent,
                room_group_name,
                error="prompt_required",
                error_message="prompt_required",
            )
            return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

        started_content = {"execution_id": str(execution.uuid), "action": "infer"}
        _send_group_event(room_group_name, "started", started_content)
        _persist_event(execution, "started", started_content)

        model_path = agent.model.path
        _preload_model(model_path)

        logger.info(f"Starting inference with model: {model_path}")
        result = services.infer_with_model(model_path, prompt, options)

        _complete_execution(
            execution,
            agent,
            room_group_name,
            output_data={"result": result},
            event_content={"execution_id": str(execution.uuid), "result": result},
        )
        logger.info(f"Inference execution {execution_id} completed successfully")
        return {"status": "completed", "execution_id": execution_id}
    except Exception as exc:
        logger.exception(
            f"Inference task failed with exception for execution {execution_id}: {str(exc)}"
        )
        _fail_execution(execution, agent, room_group_name, error=str(exc), error_message=str(exc))
        return {"status": "failed", "execution_id": execution_id, "error": str(exc)}