import logging
import os
import re
import time
import traceback
from hashlib import sha256

from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from django.utils import timezone
from django.db import transaction

from apps.orgs.models import TrainingFile, Role

from . import services
from .models import Agent, AgentEvent, AgentModel, AgentRun, RoleRagDocument

logger = logging.getLogger(__name__)

# Compiled once at import time; these patterns are used for every unit of
# every ingested document.
_TOKEN_RE = re.compile(r"\w+|[^\s\w]")
_PARAGRAPH_SPLIT_RE = re.compile(r"\n\s*\n+")
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")


def _get_mem_info() -> str:
    """Return a short memory summary of the current process for log lines.

    Reads ``/proc/self/status`` and reports its ``VmRSS``, ``VmHWM`` and
    ``VmSize`` fields (``?`` for a field that is absent). This is strictly
    best-effort diagnostics: any failure while reading or parsing the file
    yields the fixed string ``"mem_info_unavailable"`` instead of raising.
    """
    try:
        mem: dict[str, str] = {}
        with open('/proc/self/status', 'r', encoding='utf-8') as f:
            for line in f:
                key, sep, value = line.partition(':')
                if sep:
                    mem[key] = value.strip()
        return f"VmRSS={mem.get('VmRSS','?')}, VmHWM={mem.get('VmHWM','?')}, VmSize={mem.get('VmSize','?')}"
    except Exception:
        return "mem_info_unavailable"


def _estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    Counts runs of word characters plus each individual non-space, non-word
    character. Empty or falsy input counts as 0.
    """
    if not text:
        return 0
    return len(_TOKEN_RE.findall(text))


def _split_semantic_units(text: str) -> list[str]:
    """Split *text* into sentence-level units, paragraph by paragraph.

    Paragraphs are separated by blank lines; within a paragraph, sentences are
    split on whitespace that follows ``.``, ``!`` or ``?``. Empty pieces are
    dropped. Falls back to the paragraph list when no sentence was produced.
    """
    paragraphs = [p.strip() for p in _PARAGRAPH_SPLIT_RE.split(text) if p.strip()]
    units: list[str] = []
    for para in paragraphs:
        for sent in _SENTENCE_SPLIT_RE.split(para):
            sent = sent.strip()
            if sent:
                units.append(sent)
    return units or paragraphs


def _chunk_text(text: str, max_tokens: int = 400, overlap_tokens: int = 50) -> list[str]:
    """Greedily pack the semantic units of *text* into overlapping chunks.

    Units are appended to the current chunk until adding the next one would
    exceed ``max_tokens``; the chunk is then emitted and the next chunk is
    seeded with as many trailing units of the previous one as fit within
    ``overlap_tokens``. A single unit larger than ``max_tokens`` is never
    split, so such a chunk can exceed the limit.

    Args:
        text: Raw text to chunk.
        max_tokens: Soft upper bound on estimated tokens per chunk.
        overlap_tokens: Token budget carried over between consecutive chunks;
            ``0`` or less disables overlap.

    Returns:
        The list of chunk strings (units joined by single spaces); empty when
        *text* is empty.
    """
    if not text:
        return []

    units = _split_semantic_units(text)
    logger.info(
        "Semantic chunking units=%s max_tokens=%s overlap_tokens=%s mem=%s",
        len(units),
        max_tokens,
        overlap_tokens,
        _get_mem_info(),
    )

    chunks: list[str] = []
    # Each entry keeps the unit with its token count so the overlap pass does
    # not have to re-tokenize units it has already measured.
    current: list[tuple[str, int]] = []
    current_tokens = 0

    for unit in units:
        unit_tokens = _estimate_tokens(unit)
        if unit_tokens == 0:
            continue

        if current_tokens + unit_tokens > max_tokens and current:
            chunk = " ".join(piece for piece, _ in current).strip()
            if chunk:
                chunks.append(chunk)

            if overlap_tokens > 0:
                # Walk backwards over the chunk just emitted and keep the
                # trailing units that fit in the overlap budget.
                overlap: list[tuple[str, int]] = []
                overlap_count = 0
                for prev, prev_tokens in reversed(current):
                    if overlap_count + prev_tokens > overlap_tokens:
                        break
                    overlap.append((prev, prev_tokens))
                    overlap_count += prev_tokens
                overlap.reverse()  # restore document order
                current = overlap
                current_tokens = overlap_count
            else:
                current = []
                current_tokens = 0

        current.append((unit, unit_tokens))
        current_tokens += unit_tokens

    if current:
        chunk = " ".join(piece for piece, _ in current).strip()
        if chunk:
            chunks.append(chunk)

    return chunks


def _extract_text_from_file(file_path: str, file_type: str | None) -> str:
    """Return the text content of the file at *file_path*.

    ``pdf`` is parsed with PyPDF2 and ``docx``/``doc`` with python-docx (both
    imported lazily). Every other type, including an unknown or missing one,
    is read as UTF-8 text with undecodable bytes ignored.

    Raises:
        RuntimeError: If the parser package needed for the file type is not
            installed.
    """
    file_type = (file_type or '').lower()

    if file_type == 'pdf':
        try:
            from PyPDF2 import PdfReader
        except ImportError as e:
            raise RuntimeError('PyPDF2 is required to parse PDF files') from e
        reader = PdfReader(file_path)
        # extract_text() may return None for a page; treat that as empty.
        return "\n".join(page.extract_text() or "" for page in reader.pages)

    if file_type in {'docx', 'doc'}:
        # NOTE(review): python-docx targets .docx; legacy binary .doc files
        # presumably fail inside docx.Document — confirm and handle if needed.
        try:
            import docx
        except ImportError as e:
            raise RuntimeError('python-docx is required to parse DOCX files') from e
        doc = docx.Document(file_path)
        return "\n".join(p.text for p in doc.paragraphs)

    # txt / md / csv / json and anything unrecognized are read as plain text.
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()


def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Broadcast an ``mlstore_event`` message to a channels group.

    The message carries *event_type*, *content* and the current timestamp in
    ISO format.
    """
    channel_layer = get_channel_layer()
    async_to_sync(channel_layer.group_send)(
        room_group_name,
        {
            "type": "mlstore_event",
            "event_type": event_type,
            "content": content,
            "timestamp": timezone.now().isoformat(),
        },
    )


def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Store an ``AgentEvent`` row for *execution*."""
    AgentEvent.objects.create(
        execution=execution,
        event_type=event_type,
        content=content,
    )


def _update_agent_status(agent: Agent, status: str):
    """Set ``agent.status`` and stamp the matching timestamp, then save.

    ``"running"`` sets ``started_at``; ``"completed"`` and ``"failed"`` set
    ``completed_at``. Other statuses only update ``status``.
    """
    agent.status = status
    if status == "running":
        agent.started_at = timezone.now()
    elif status in ("completed", "failed"):
        agent.completed_at = timezone.now()
    agent.save()


def _mark_running(execution: AgentRun, agent: Agent):
    """Flag both the execution and its agent as running."""
    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")


def _fail_execution(execution: AgentRun, agent: Agent, room_group_name: str, error, error_message: str | None = None):
    """Record a failed run and notify listeners.

    Marks the execution and agent as failed, then emits and persists an
    ``error`` event and sends the ``mlstore_error`` group message.

    Args:
        execution: The run being failed.
        agent: The agent owning the run.
        room_group_name: Channels group to notify.
        error: Payload stored under ``"error"`` in the event content.
        error_message: Text stored on the execution and sent in the
            ``mlstore_error`` message; defaults to ``str(error)``.
    """
    message = str(error) if error_message is None else error_message
    execution_uuid = str(execution.uuid)

    execution.status = "failed"
    execution.error_message = message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    _send_group_event(room_group_name, "error", {"execution_id": execution_uuid, "error": error})
    _persist_event(execution, "error", {"execution_id": execution_uuid, "error": error})
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": execution_uuid,
            "error_message": message,
        },
    )


def _complete_execution(execution: AgentRun, agent: Agent, room_group_name: str, output_data: dict, event_extra: dict):
    """Record a successful run and notify listeners.

    Marks the execution and agent as completed, then emits and persists a
    ``completed`` event (``execution_id`` plus *event_extra*) and sends the
    ``mlstore_completed`` group message carrying *output_data*.
    """
    execution_uuid = str(execution.uuid)

    execution.status = "completed"
    execution.output_data = output_data
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "completed")

    _send_group_event(room_group_name, "completed", {"execution_id": execution_uuid, **event_extra})
    _persist_event(execution, "completed", {"execution_id": execution_uuid, **event_extra})
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_completed",
            "execution_id": execution_uuid,
            "output_data": execution.output_data,
        },
    )


def _next_model_version(existing_versions) -> str:
    """Return the next ``v<N>`` label after the highest one in *existing_versions*.

    Only strings of the form ``v<integer>`` are considered; anything else is
    ignored. The comparison is numeric, so ``v10`` correctly follows ``v9``
    (a plain string sort would rank ``v9`` above ``v10``). Returns ``"v1"``
    when no usable version exists.
    """
    highest = 0
    for raw in existing_versions:
        if not isinstance(raw, str) or not raw.startswith('v'):
            continue
        try:
            highest = max(highest, int(raw[1:]))
        except ValueError:
            continue
    return f"v{highest + 1}"


def _collect_org_training_files(agent: Agent, role_uuid) -> list:
    """Return the unprocessed training files of the agent's organization.

    When *role_uuid* names a role of that organization the result is limited
    to that role; an unknown role is logged and the filter is skipped.
    """
    training_files_qs = TrainingFile.objects.filter(
        role__organization=agent.organization,
        is_processed=False,
    ).select_related('uploaded_by', 'role')
    if role_uuid:
        try:
            role = Role.objects.get(uuid=role_uuid, organization=agent.organization)
            training_files_qs = training_files_qs.filter(role=role)
        except Role.DoesNotExist:
            logger.warning(f"Role {role_uuid} not found for organization {agent.organization.name}")
    return list(training_files_qs)


@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Run a fine-tuning job for the ``AgentRun`` identified by *execution_id*.

    Resolves the job parameters from ``execution.input_data`` (falling back to
    the agent's current model and, when no files are given, the organization's
    unprocessed training files), calls ``services.fine_tune_model`` and, on
    success, registers the new ``AgentModel``, attaches it to the agent and
    marks the organization files used as processed.

    Progress is reported through group events, persisted events and the
    execution/agent status fields.

    Returns:
        A dict with ``status`` (``"completed"``, ``"failed"`` or ``"error"``)
        and ``execution_id``, plus ``model_id``, ``result`` or ``error``
        depending on the outcome.
    """
    logger.info(f"Fine-tune run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    _mark_running(execution, agent)
    logger.info(f"Execution {execution_id} status updated to 'running'")

    # Everything after the run is flagged "running" happens inside the try so
    # that any failure moves the run to "failed" instead of leaving it stuck.
    try:
        from apps.mlstore.services import BASE_MODEL_CACHE
        logger.info(f"Base model cache directory: {BASE_MODEL_CACHE}")

        input_data = execution.input_data or {}
        base_model = input_data.get("base_model") or agent.model.name
        training_files = input_data.get("training_files") or []
        org_training_files = []
        role_uuid = input_data.get("role_uuid")

        if not training_files and agent.organization:
            org_training_files = _collect_org_training_files(agent, role_uuid)
            training_files = [tf.file.path for tf in org_training_files if tf.file]
            logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")

        hyperparams = input_data.get("hyperparams") or {}
        name = input_data.get("name") or agent.model.name
        version = input_data.get("version")
        if not version:
            version = _next_model_version(
                AgentModel.objects.filter(name=name).values_list('version', flat=True)
            )

        logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")

        _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
        _persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})

        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
        # Guard the lookup: a non-dict result is reported as a failed run
        # rather than raising AttributeError here.
        result_status = result.get("status") if isinstance(result, dict) else None
        logger.info(f"Fine-tune result received: {result_status}")
        logger.debug(f"Full fine-tune result: {result}")

        if result_status == "completed":
            model_path = result.get("model_path") or result.get("path") or ""
            model_version = result.get("version") or version

            new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
            agent.model = new_model
            agent.save()
            logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")

            if org_training_files:
                file_ids = [tf.id for tf in org_training_files]
                TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True)
                logger.info(f"Marked {len(org_training_files)} training files as processed")

            _complete_execution(
                execution,
                agent,
                room_group_name,
                output_data={
                    "result": result,
                    "model_id": new_model.id,
                    "model_uuid": str(new_model.uuid),
                },
                event_extra={"model_id": new_model.id, "model_path": model_path},
            )
            logger.info(f"Execution {execution_id} completed successfully")
            return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}

        logger.warning(f"Fine-tune did not complete successfully. Status: {result_status}")
        _fail_execution(execution, agent, room_group_name, error=result, error_message=str(result))
        return {"status": "failed", "execution_id": execution_id, "result": result}

    except Exception as e:
        logger.error(f"Fine-tune task failed with exception for execution {execution_id}: {str(e)}", exc_info=True)
        traceback.print_exc()
        _fail_execution(execution, agent, room_group_name, error=str(e))
        return {"status": "error", "execution_id": execution_id, "error": str(e)}


@shared_task
def ingest_training_file_task(training_file_uuid: str):
    """Extract, chunk and store a training file as ``RoleRagDocument`` rows.

    Skips files that are already processed or have no file attached. On
    success the previous documents of the file are replaced, the file is
    marked ``chunked``/processed and ``embed_training_file_task`` is enqueued.
    On any failure the file status is set to ``failed``.

    Returns:
        A dict with ``status`` (``"completed"``, ``"skipped"`` or ``"error"``)
        plus ``chunks``, ``reason`` or ``error`` depending on the outcome.
    """
    logger.info(f"Ingest task started for training_file_uuid={training_file_uuid}")
    started_at = time.time()

    try:
        training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
    except TrainingFile.DoesNotExist:
        logger.error(f"Training file not found: {training_file_uuid}")
        return {"status": "error", "error": "training_file_not_found"}

    if training_file.is_processed:
        logger.info(f"Training file already processed: {training_file_uuid}")
        return {"status": "skipped", "reason": "already_processed"}

    if not training_file.file:
        logger.error(f"Training file has no file attached: {training_file_uuid}")
        return {"status": "error", "error": "file_missing"}

    # Diagnostics only: a failure to stat the file must not abort ingestion.
    try:
        file_path = training_file.file.path
        file_size = os.path.getsize(file_path) if os.path.exists(file_path) else 0
        logger.info(
            "Ingesting file: name=%s type=%s size_bytes=%s role=%s path=%s",
            training_file.file_name,
            training_file.file_type,
            file_size,
            training_file.role_id,
            file_path,
        )
    except Exception as e:
        logger.warning(f"Failed to stat training file for {training_file_uuid}: {str(e)}")

    try:
        training_file.status = 'ingesting'
        training_file.save(update_fields=['status'])

        extract_started = time.time()
        text = _extract_text_from_file(training_file.file.path, training_file.file_type)
        logger.info(
            "Extracted text length=%s for training_file_uuid=%s in %.2fs mem=%s",
            len(text),
            training_file_uuid,
            time.time() - extract_started,
            _get_mem_info(),
        )

        chunk_started = time.time()
        chunks = _chunk_text(text)
        logger.info(
            "Chunked text into %s chunks in %.2fs (sample lengths: %s) mem=%s",
            len(chunks),
            time.time() - chunk_started,
            [len(c) for c in chunks[:5]],
            _get_mem_info(),
        )

        if not chunks:
            raise RuntimeError("No text extracted from file")

        # Replace the file's documents and flip its status in one
        # transaction so a failure cannot leave a half-replaced set behind.
        with transaction.atomic():
            logger.info(
                "Clearing existing RAG docs for training_file_uuid=%s mem=%s",
                training_file_uuid,
                _get_mem_info(),
            )
            RoleRagDocument.objects.filter(training_file=training_file).delete()

            logger.info("Preparing %s RAG docs for bulk_create mem=%s", len(chunks), _get_mem_info())
            # Hashes already stored for this role (from other files, since
            # this file's rows were just deleted).
            # NOTE(review): identical chunks within this same file are not
            # de-duplicated against each other — confirm that is intended.
            existing_hashes = set(
                RoleRagDocument.objects.filter(role=training_file.role)
                .values_list('content_hash', flat=True)
            )

            documents = []
            skipped = 0
            for index, chunk in enumerate(chunks):
                content_hash = sha256(chunk.encode('utf-8')).hexdigest()
                if content_hash in existing_hashes:
                    skipped += 1
                    continue
                documents.append(
                    RoleRagDocument(
                        role=training_file.role,
                        training_file=training_file,
                        content=chunk,
                        embedding=None,
                        content_hash=content_hash,
                        metadata={
                            "file_name": training_file.file_name,
                            "file_type": training_file.file_type,
                            "chunk_size": len(chunk),
                            "source": "training_file",
                        },
                        chunk_index=index,
                    )
                )

            logger.info("Skipped %s duplicate chunks based on content_hash", skipped)
            logger.info("Bulk creating RAG docs count=%s mem=%s", len(documents), _get_mem_info())
            RoleRagDocument.objects.bulk_create(documents, batch_size=500)

            training_file.status = 'chunked'
            training_file.is_processed = True
            training_file.save(update_fields=['status', 'is_processed'])

        elapsed = time.time() - started_at
        logger.info(
            "Ingested training file %s into %s RAG chunks in %.2fs",
            training_file_uuid,
            len(chunks),
            elapsed,
        )

        logger.info(f"Enqueueing embedding task for training_file_uuid={training_file_uuid}")
        embed_training_file_task.delay(training_file_uuid)

        return {"status": "completed", "chunks": len(chunks)}

    except Exception as e:
        elapsed = time.time() - started_at
        logger.error(f"Failed to ingest training file {training_file_uuid}: {str(e)}", exc_info=True)
        logger.error(f"Ingest task failed after {elapsed:.2f}s for training_file_uuid={training_file_uuid}")
        try:
            TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
        except Exception:
            # Still best-effort, but leave a trace instead of failing silently.
            logger.warning(
                "Could not mark training file %s as failed",
                training_file_uuid,
                exc_info=True,
            )
        return {"status": "error", "error": str(e)}


@shared_task
def embed_training_file_task(training_file_uuid: str):
    """Generate embeddings for all documents of a training file.

    This task is called after chunking to embed the document chunks
    using the configured embedding provider (OpenAI, Google, or local).

    The file status becomes ``embedded`` when at least one document was
    embedded and ``failed`` when none were or when an exception occurred.

    Returns:
        A dict with ``status`` (``"completed"``, ``"skipped"`` or ``"error"``)
        and, depending on the outcome, ``num_embedded``, ``num_failed``,
        ``elapsed``, ``reason`` or ``error``.
    """
    logger.info(f"Embedding task started for training_file_uuid={training_file_uuid}")
    started_at = time.time()

    try:
        training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
    except TrainingFile.DoesNotExist:
        logger.error(f"Training file not found: {training_file_uuid}")
        return {"status": "error", "error": "training_file_not_found"}

    try:
        documents = list(RoleRagDocument.objects.filter(training_file=training_file))
        if not documents:
            logger.warning(f"No RAG documents found for training_file_uuid={training_file_uuid}")
            return {"status": "skipped", "reason": "no_documents"}

        logger.info(
            f"Starting to embed {len(documents)} documents for training_file_uuid={training_file_uuid} "
            f"mem={_get_mem_info()}"
        )

        num_embedded, num_failed = services.batch_embed_documents(documents, batch_size=32)

        if num_failed == 0:
            training_file.status = 'embedded'
            training_file.save(update_fields=['status'])
            logger.info(f"Successfully embedded all documents for training_file_uuid={training_file_uuid}")
        elif num_embedded > 0:
            # A partial success still counts as embedded; the failures are
            # only reported in the log and the return value.
            training_file.status = 'embedded'
            training_file.save(update_fields=['status'])
            logger.warning(
                f"Partially embedded {num_embedded} documents, {num_failed} failed "
                f"for training_file_uuid={training_file_uuid}"
            )
        else:
            training_file.status = 'failed'
            training_file.save(update_fields=['status'])
            logger.error(f"Failed to embed any documents for training_file_uuid={training_file_uuid}")
            return {"status": "error", "error": "embedding_failed", "num_failed": num_failed}

        elapsed = time.time() - started_at
        logger.info(
            f"Embedding task completed for {training_file_uuid}: "
            f"embedded={num_embedded}, failed={num_failed}, time={elapsed:.2f}s"
        )
        return {
            "status": "completed",
            "num_embedded": num_embedded,
            "num_failed": num_failed,
            "elapsed": elapsed,
        }

    except Exception as e:
        elapsed = time.time() - started_at
        logger.error(
            f"Failed to embed training file {training_file_uuid}: {str(e)}",
            exc_info=True
        )
        try:
            TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
        except Exception:
            # Still best-effort, but leave a trace instead of failing silently.
            logger.warning(
                "Could not mark training file %s as failed",
                training_file_uuid,
                exc_info=True,
            )
        return {"status": "error", "error": str(e), "elapsed": elapsed}


def _build_rag_prompt(prompt: str, role_uuid, rag_top_k: int, rag_similarity_threshold: float) -> str:
    """Return *prompt* wrapped with retrieved RAG context, when any is found.

    Context is fetched through ``services.get_context_for_query``. If nothing
    is found, or retrieval raises, the original prompt is returned unchanged
    (retrieval problems are logged as warnings, never propagated).
    """
    try:
        context = services.get_context_for_query(
            query=prompt,
            role_uuid=str(role_uuid),
            top_k=rag_top_k,
            similarity_threshold=rag_similarity_threshold,
        )
    except Exception as e:
        logger.warning(f"RAG context retrieval failed for role={role_uuid}: {e}")
        return prompt

    if not context:
        logger.info(f"No RAG context found for role={role_uuid}")
        return prompt

    logger.info(f"RAG context retrieved for role={role_uuid} (top_k={rag_top_k})")
    return (
        "You are a technical assistant.\n\n"
        "Answer the question using ONLY the information in the context.\n"
        "Do NOT:\n"
        "- ask follow-up questions\n"
        "- include hashtags\n"
        "- include references or sources\n"
        "- repeat the question\n"
        "- add headings or sections\n"
        "- add information not present in the context\n\n"
        "Answer in 3-6 concise sentences.\n"
        "If the context is insufficient, say: \"The context does not provide enough information.\"\n\n"
        "Context:\n"
        f"{context}\n\n"
        "Question:\n"
        f"{prompt}\n\n"
        "Answer:"
    )


@shared_task
def infer_run_task(execution_id: str):
    """Run an inference request for the ``AgentRun`` identified by *execution_id*.

    Reads the prompt (``prompt`` or ``query``), generation ``options`` and RAG
    settings from ``execution.input_data``. A ``role_uuid`` and a non-empty
    prompt are required; otherwise the run fails with ``role_uuid_required``
    or ``prompt_required``. The prompt is augmented with RAG context when
    available, then passed to ``services.infer_with_model`` using the agent's
    model path.

    Progress is reported through group events, persisted events and the
    execution/agent status fields.

    Returns:
        A dict with ``status`` (``"completed"``, ``"failed"`` or ``"error"``)
        and ``execution_id``, plus ``error`` on failure.
    """
    logger.info(f"Inference run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    _mark_running(execution, agent)
    logger.info(f"Execution {execution_id} status updated to 'running'")

    # Input parsing is inside the try as well: malformed input (for example a
    # non-numeric rag_top_k) must fail the run, not leave it stuck "running".
    try:
        input_data = execution.input_data or {}
        prompt = input_data.get("prompt") or input_data.get("query") or ""
        options = dict(input_data.get("options") or {})
        role_uuid = input_data.get("role_uuid") or options.get("role_uuid")

        # An explicit null falls back to the default instead of crashing.
        raw_top_k = input_data.get("rag_top_k")
        rag_top_k = 5 if raw_top_k is None else int(raw_top_k)
        raw_threshold = input_data.get("rag_similarity_threshold")
        rag_similarity_threshold = 0.5 if raw_threshold is None else float(raw_threshold)

        options.setdefault("temperature", 0.2)
        options.setdefault("top_p", 0.9)
        options.setdefault("max_tokens", 200)
        options.setdefault("stop", ["\n\n", "References:", "Sources:"])

        logger.info(f"Prompt length: {len(prompt)} characters")

        if not role_uuid:
            logger.warning(f"No role_uuid provided for inference run {execution_id}")
            _fail_execution(execution, agent, room_group_name, error="role_uuid_required")
            return {"status": "failed", "execution_id": execution_id, "error": "role_uuid_required"}

        if not prompt:
            logger.warning(f"No prompt provided for inference run {execution_id}")
            _fail_execution(execution, agent, room_group_name, error="prompt_required")
            return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

        prompt = _build_rag_prompt(prompt, role_uuid, rag_top_k, rag_similarity_threshold)

        _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "infer"})
        _persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "infer"})

        # Preloading is an optimization only; inference is attempted even if
        # it fails.
        try:
            logger.info(f"Loading model: {agent.model.path}")
            services.load_model_for_inference(agent.model.path)
        except Exception as e:
            logger.warning(f"Failed to preload model: {str(e)}")

        logger.info(f"Starting inference with model: {agent.model.path}")
        result = services.infer_with_model(agent.model.path, prompt, options)

        _complete_execution(
            execution,
            agent,
            room_group_name,
            output_data={"result": result},
            event_extra={"result": result},
        )
        logger.info(f"Inference execution {execution_id} completed successfully")
        return {"status": "completed", "execution_id": execution_id}

    except Exception as e:
        logger.error(f"Inference task failed with exception for execution {execution_id}: {str(e)}", exc_info=True)
        traceback.print_exc()
        _fail_execution(execution, agent, room_group_name, error=str(e))
        return {"status": "failed", "execution_id": execution_id, "error": str(e)}