import json
import logging
import re
from uuid import uuid4

import httpx
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings
from django.db.models import Q
from django.utils import timezone

from apps.onboarding.mcp import MCPRouter
from apps.onboarding.models import AgentConfig, OnboardingFlow, OnboardingSession

logger = logging.getLogger(__name__)

# Matches ``` or ```json fenced blocks in model output; group 1 is the fence body.
_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*([\s\S]*?)```', re.IGNORECASE)


class OnboardingConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer that drives the AI onboarding agents.

    Supported client actions (JSON ``action`` key):

    * ``start_full_onboarding`` -- generate and persist an onboarding flow for
      ``role_uuid`` (caller must pass ``can_manage_role``).
    * ``progress_monitor`` -- summarize the caller's progress for ``role_uuid``
      (falls back to the URL's ``session_uuid`` kwarg).
    * anything else -- free-form chat: ``query``/``message`` is answered by the
      ``AgentConfig`` whose UUID is the URL's ``session_uuid`` kwarg.
    """

    # Upper bound applied to a client-supplied ``max_tokens`` override. The value
    # comes straight from the WebSocket payload, so it must not be unbounded.
    CLIENT_MAX_TOKENS_CEILING = 4096

    async def connect(self):
        """Accept the socket for authenticated users; close it otherwise."""
        self.user = self.scope["user"]
        # NOTE(review): the URL kwarg is named ``session_uuid`` but is used both as
        # an AgentConfig UUID (chat) and as a role UUID fallback (progress monitor).
        self.context_uuid = self.scope["url_route"]["kwargs"].get("session_uuid")
        if not self.user.is_authenticated:
            await self.close()
            return
        self.router = MCPRouter()
        await self.accept()

    async def disconnect(self, close_code):
        """Nothing to release: this consumer holds no per-connection resources."""

    def _build_system_prompt(self, config):
        """Return the config's system prompt, listing its allowed tools if any."""
        base_prompt = config.system_prompt or "You are a helpful onboarding assistant."
        permissions = config.tool_permissions or []
        if permissions:
            return f"{base_prompt}\n\nAllowed tools: {', '.join(str(p) for p in permissions)}"
        return base_prompt

    def _clamp_client_max_tokens(self, value):
        """Turn an untrusted ``max_tokens`` payload value into a safe override.

        Returns None (use the config default) when absent or not convertible to an
        int, otherwise the value capped at ``CLIENT_MAX_TOKENS_CEILING``. The lower
        bound is applied later by ``orchestrate_ai``.
        """
        if value is None:
            return None
        try:
            requested = int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: Python's json accepts ``Infinity`` by default.
            return None
        return min(requested, self.CLIENT_MAX_TOKENS_CEILING)

    async def receive(self, text_data):
        """Dispatch one client message to the matching handler.

        Any unexpected error is logged with its traceback and reported to the
        client as an ``error`` event; the socket stays open.
        """
        try:
            data = json.loads(text_data)
            if not isinstance(data, dict):
                await self.send_log("error", "Invalid payload: expected a JSON object")
                return
            action = data.get("action")

            if action == "start_full_onboarding":
                role_uuid = data.get("role_uuid")
                if not role_uuid:
                    await self.send_log("error", "Missing role_uuid for full onboarding generation")
                    return
                if not await self.can_manage_role(role_uuid, self.user.id):
                    await self.send_log("error", "Forbidden")
                    return
                await self.run_full_onboarding_generation(role_uuid)

            elif action == "progress_monitor":
                role_uuid = data.get("role_uuid") or self.context_uuid
                if not role_uuid:
                    await self.send_log("error", "Missing role_uuid for progress monitoring")
                    return
                if not await self.can_access_role(role_uuid, self.user.id):
                    await self.send_log("error", "Forbidden")
                    return
                await self.run_progress_monitor(role_uuid)

            else:
                user_message = data.get("query") or data.get("message")
                requested_max_tokens = self._clamp_client_max_tokens(data.get("max_tokens"))
                if not user_message:
                    await self.send_log("error", "Missing query/message payload")
                    return
                config = await self.get_config_for_user(self.context_uuid, self.user.id)
                if config is None:
                    await self.send_log("error", "Forbidden")
                    return
                ai_response = await self.orchestrate_ai(
                    user_message,
                    config,
                    max_tokens=requested_max_tokens,
                )
                await self.send(json.dumps({
                    "type": "completed",
                    "timestamp": timezone.now().isoformat(),
                    "message": "Inference complete.",
                    "content": {
                        "response": ai_response,
                    },
                }))

        except Exception as e:
            logger.exception("WebSocket Receive Error: %s", e)
            await self.send_log("error", f"Consumer encountered an error: {str(e)}")

    async def run_full_onboarding_generation(self, role_uuid):
        """
        The Master Script that builds the JSON structure sequentially.
        Pipeline: Curriculum Agent -> Knowledge Agent -> Assessment Agent

        All three AgentConfigs are resolved before any inference runs, so a
        missing config fails fast instead of after several expensive LLM calls.
        The result (one page per curriculum topic plus a final quiz page) is
        persisted with ``save_full_flow``.
        """
        await self.send_log("status", "Phase 1: Generating Curriculum...", "curriculum")
        ca_config = await self.get_config_by_type(role_uuid, 'curriculum')
        if not ca_config:
            await self.send_log("error", "Missing curriculum AgentConfig for this role")
            return
        ka_config = await self.get_config_by_type(role_uuid, 'knowledge')
        if not ka_config:
            await self.send_log("error", "Missing knowledge AgentConfig for this role")
            return
        aa_config = await self.get_config_by_type(role_uuid, 'assessment')
        if not aa_config:
            await self.send_log("error", "Missing assessment AgentConfig for this role")
            return

        ca_prompt = (
            "Based on available documentation, create an onboarding curriculum for this role. "
            "Output ONLY a valid JSON array of 3-5 strings representing module titles. "
            "Example: [\"Introduction\", \"Safety\", \"Operations\"]"
        )
        ca_response = await self.orchestrate_ai(
            ca_prompt,
            ca_config,
            min_internal_turns=1,
            max_tokens=384,
        )
        topics = self._extract_json_list(ca_response)
        if not topics:
            await self.send_log("error", "Curriculum generation returned no topics")
            return

        topic_titles = [str(item) for item in topics]
        full_structure = []
        module_briefs = []

        for index, topic in enumerate(topics):
            await self.send_log("status", f"Phase 2: Researching {topic}...", "knowledge")
            knowledge_hits = await self.fetch_knowledge_context(role_uuid, topic)
            context_markdown = self.format_knowledge_context(knowledge_hits)
            page_content = await self.orchestrate_ai(
                self._build_module_prompt(role_uuid, topic, context_markdown),
                ka_config,
                min_internal_turns=2,
                max_tokens=2400,
            )
            full_structure.append({
                "title": topic,
                "body": page_content,
                "order": index,
                "fields": [],
                "meta": {
                    "topic_index": index,
                    "table_of_contents": list(topic_titles),
                },
            })
            # Truncated excerpt keeps the assessment prompt bounded in size.
            module_briefs.append({
                "topic": str(topic),
                "summary_excerpt": str(page_content)[:1200],
            })

        await self.send_log("status", "Phase 3: Creating final assessment quiz...", "assessment")
        quiz_fields = await self._generate_quiz_fields(aa_config, module_briefs, topic_titles)

        full_structure.append({
            "title": "Final Assessment Quiz",
            "body": (
                "### Final Quiz\n"
                "Answer all questions below. You need **80%** to complete onboarding. "
                "You can update answers and submit when ready."
            ),
            "order": len(full_structure),
            "fields": quiz_fields,
            "meta": {
                "page_type": "final_quiz",
                "pass_mark": 80,
            },
        })

        await self.save_full_flow(role_uuid, full_structure)
        await self.send(json.dumps({
            "type": "completed",
            "timestamp": timezone.now().isoformat(),
            "message": "Onboarding pipeline complete and structure saved.",
        }))

    def _build_module_prompt(self, role_uuid, topic, context_markdown):
        """Return the knowledge-agent prompt for one curriculum topic."""
        return (
            f"Write a practical onboarding training guide for the topic '{topic}'. "
            "Think step-by-step internally before writing the final answer. "
            "Use the MCP search context below as your primary source, and call additional tools if needed. "
            "If no indexed documents are available, provide a concise best-practice overview and clearly say no indexed documents were found. "
            "Use Markdown formatting and do NOT include a table of contents in this section. "
            "Generate substantial depth: target 900-1400 words. "
            "Include these sections in order: Overview, Core Concepts, Role-Specific Workflow, Practical Examples, Common Pitfalls, and Action Checklist. "
            "In Practical Examples, provide at least 2 concrete examples relevant to this role/topic. "
            "In Action Checklist, provide at least 8 actionable checklist items.\n\n"
            f"Role UUID: {role_uuid}\n"
            f"Topic: {topic}\n"
            f"MCP search context:\n{context_markdown}"
        )

    async def _generate_quiz_fields(self, aa_config, module_briefs, topic_titles):
        """Produce sanitized final-quiz fields: ask once, retry once, then fall back.

        Always returns a non-empty list; the last resort is
        ``_build_fallback_quiz_fields(topic_titles)``.
        """
        quiz_prompt = (
            "Create a final onboarding quiz that assesses all generated modules. "
            "Output ONLY a valid JSON array of 8 multiple-choice question objects. "
            "Each object MUST include: 'key' (snake_case), 'label', 'field_type' ('select'), "
            "'options' (array of 4 unique strings), 'required' (true), and 'validation' with "
            "'correct_option' (exactly matching one option) and 'explanation' (short rationale). "
            "Cover all topics with balanced difficulty and avoid ambiguous choices.\n\n"
            f"Modules JSON:\n{json.dumps(module_briefs, ensure_ascii=False)}"
        )
        quiz_response = await self.orchestrate_ai(
            quiz_prompt,
            aa_config,
            min_internal_turns=1,
            max_tokens=1600,
        )
        quiz_fields = self._sanitize_quiz_fields(self._extract_json_list(quiz_response))
        if quiz_fields:
            return quiz_fields

        await self.send_log("status", "Assessment output invalid, retrying quiz generation...", "assessment")
        retry_response = await self.orchestrate_ai(
            f"{quiz_prompt}\n\nReturn ONLY raw JSON. Do not use markdown fences. Do not include explanations outside JSON.",
            aa_config,
            min_internal_turns=1,
            max_tokens=1600,
        )
        quiz_fields = self._sanitize_quiz_fields(self._extract_json_list(retry_response))
        if quiz_fields:
            return quiz_fields

        await self.send_log("status", "Assessment output still invalid. Using fallback final quiz.", "assessment")
        return self._build_fallback_quiz_fields(list(topic_titles))

    async def run_progress_monitor(self, role_uuid):
        """Ask the role's ``monitor`` agent for feedback on the caller's progress and send it."""
        await self.send_log("status", "Progress Monitor is analyzing your onboarding progress...", "monitor")
        monitor_config = await self.get_config_by_type(role_uuid, 'monitor')
        if not monitor_config:
            await self.send_log("error", "Missing Progress Monitor AgentConfig for this role")
            return

        progress_context = await self.get_role_progress_context(role_uuid, self.user.id)
        monitor_prompt = (
            "You are a progress monitoring agent for onboarding. "
            "Analyze the role onboarding data below and provide concise feedback with:\n"
            "1) current status\n2) strengths\n3) gaps\n4) next actions\n"
            "Keep it short and practical.\n\n"
            f"Progress context JSON:\n{json.dumps(progress_context)}"
        )
        feedback = await self.orchestrate_ai(
            monitor_prompt,
            monitor_config,
            min_internal_turns=1,
            max_tokens=640,
        )

        await self.send(json.dumps({
            "type": "completed",
            "timestamp": timezone.now().isoformat(),
            "message": "Progress analysis complete.",
            "content": {
                "role_uuid": role_uuid,
                "feedback": feedback,
                "status": progress_context.get("latest_status", "unknown"),
            },
        }))

    def _parse_tool_arguments(self, raw_arguments):
        """Accept tool-call arguments as a JSON string, an already-decoded dict, or empty."""
        if isinstance(raw_arguments, dict):
            return raw_arguments
        if not raw_arguments:
            return {}
        return json.loads(raw_arguments)

    async def orchestrate_ai(
        self,
        user_message,
        config,
        min_internal_turns=2,
        max_turns=6,
        max_tokens=None,
    ):
        """
        Handles the multi-turn ReAct loop (Reasoning + Tool Use).

        Each turn posts the running conversation to
        ``settings.INFERENCE_CHAT_COMPLETIONS_ENDPOINT`` with the router's tool
        definitions. Requested tool calls are executed through ``self.router`` and
        their results appended before the next turn. A plain answer is accepted
        once at least ``min_internal_turns`` turns have run; an earlier answer
        triggers a "one more pass" nudge.

        Args:
            user_message: Initial user prompt.
            config: AgentConfig supplying the system prompt and ``llm_config``.
            min_internal_turns: Minimum turns before an answer is accepted (>= 1).
            max_turns: Hard cap on turns (raised to ``min_internal_turns`` if lower).
            max_tokens: Completion budget; defaults to ``llm_config["max_tokens"]``
                or 1024, and is floored at 64.

        Returns:
            The final assistant text. If the turn budget runs out, it is the last
            plain answer seen (possibly ""). On an inference error it is
            ``"Error: ..."``.
        """
        messages = [
            {"role": "system", "content": self._build_system_prompt(config)},
            {"role": "user", "content": user_message},
        ]

        llm_config = config.llm_config if isinstance(config.llm_config, dict) else {}
        resolved_max_tokens = max_tokens
        if resolved_max_tokens is None:
            resolved_max_tokens = llm_config.get("max_tokens", 1024)
        try:
            resolved_max_tokens = max(64, int(resolved_max_tokens))
        except (TypeError, ValueError, OverflowError):
            resolved_max_tokens = 1024

        last_content = ""
        min_internal_turns = max(1, int(min_internal_turns or 1))
        max_turns = max(min_internal_turns, int(max_turns or 1))

        async with httpx.AsyncClient(timeout=60.0) as client:
            for turn in range(max_turns):
                await self.send_log("thought", f"Agent is thinking (Turn {turn+1})...")

                try:
                    response = await client.post(
                        settings.INFERENCE_CHAT_COMPLETIONS_ENDPOINT,
                        json={
                            "model": llm_config.get("model_id", "meta-llama-3.1-8b"),
                            "messages": messages,
                            "tools": self.router.get_tool_definitions(),
                            "tool_choice": "auto",
                            "max_tokens": resolved_max_tokens,
                        },
                    )
                    response.raise_for_status()
                    ai_message = response.json()["choices"][0]["message"]
                    messages.append(ai_message)

                    tool_calls = ai_message.get("tool_calls")
                    if tool_calls:
                        for tool_call in tool_calls:
                            fn_name = tool_call["function"]["name"]
                            fn_args = self._parse_tool_arguments(tool_call["function"].get("arguments"))

                            await self.send(json.dumps({
                                "type": "tool_start",
                                "message": f"Accessing knowledge base: {fn_name}...",
                                "content": fn_args,
                            }))

                            result = await self.router.handle_tool_call(fn_name, fn_args)
                            messages.append({
                                "role": "tool",
                                "tool_call_id": tool_call["id"],
                                "name": fn_name,
                                "content": json.dumps(result),
                            })
                        # Feed tool results back to the model on the next turn.
                        continue

                    last_content = str(ai_message.get("content") or "").strip()
                    if (turn + 1) < min_internal_turns:
                        messages.append({
                            "role": "user",
                            "content": (
                                "Run one more internal reasoning pass before finalizing. "
                                "If additional evidence is needed, call tools. "
                                "Then return only the improved final answer."
                            ),
                        })
                        continue
                    return last_content

                except Exception as e:
                    logger.exception("Inference failed")
                    await self.send_log("error", f"Inference failed: {str(e)}")
                    return f"Error: {str(e)}"

        return last_content

    async def fetch_knowledge_context(self, role_uuid, topic):
        """Run the ``search_knowledge`` tool for a topic; return its list result or []."""
        query = f"onboarding training content for {topic}"
        await self.send(json.dumps({
            "type": "tool_start",
            "message": "Accessing knowledge base: search_knowledge...",
            "content": {"query": query, "role_uuid": role_uuid},
        }))

        try:
            result = await self.router.handle_tool_call(
                "search_knowledge",
                {
                    "query": query,
                    "role_uuid": role_uuid,
                },
            )
            await self.send(json.dumps({
                "type": "tool_result",
                "message": f"Retrieved {len(result) if isinstance(result, list) else 0} knowledge chunk(s)",
                "content": result,
                "timestamp": timezone.now().isoformat(),
            }))
            return result if isinstance(result, list) else []
        except Exception as exc:
            # Best-effort: generation continues without retrieved context.
            await self.send_log("error", f"Knowledge retrieval failed for topic '{topic}': {str(exc)}")
            return []

    def format_knowledge_context(self, knowledge_hits):
        """Render up to five hits as numbered text blocks (content capped at 1600 chars each)."""
        if not knowledge_hits:
            return "No indexed MCP documents found for this role/topic."

        lines = []
        for idx, item in enumerate(knowledge_hits[:5]):
            source = item.get("source", "Unknown Source") if isinstance(item, dict) else "Unknown Source"
            relevance = item.get("relevance") if isinstance(item, dict) else None
            content = item.get("content", "") if isinstance(item, dict) else ""
            safe_content = str(content).strip()[:1600]
            lines.append(
                f"[{idx + 1}] Source: {source} | Relevance: {relevance}\n{safe_content}"
            )

        return "\n\n".join(lines)

    def _coerce_list_payload(self, payload):
        """Return payload if it is a list, else the first list under a known wrapper key, else []."""
        if isinstance(payload, list):
            return payload
        if isinstance(payload, dict):
            for key in ('questions', 'items', 'fields', 'quiz'):
                value = payload.get(key)
                if isinstance(value, list):
                    return value
        return []

    def _extract_json_list(self, text):
        """Extracts a JSON list from model output, tolerating wrappers and markdown fences.

        Candidates are the whole text followed by each fenced block. For each
        candidate, a whole-string parse is tried first, then decoding from every
        ``[``/``{`` position. The first non-empty list is returned; [] if none.
        """
        if not text:
            return []

        candidate_texts = [str(text).strip()]
        candidate_texts.extend(block.strip() for block in _JSON_FENCE_RE.findall(str(text)))

        decoder = json.JSONDecoder()
        for candidate in candidate_texts:
            if not candidate:
                continue

            try:
                coerced = self._coerce_list_payload(json.loads(candidate))
            except (ValueError, RecursionError):
                coerced = []
            if coerced:
                return coerced

            for idx, char in enumerate(candidate):
                if char not in '[{':
                    continue
                try:
                    parsed, _ = decoder.raw_decode(candidate[idx:])
                except (ValueError, RecursionError):
                    continue
                coerced = self._coerce_list_payload(parsed)
                if coerced:
                    return coerced

        return []

    def _sanitize_quiz_fields(self, raw_fields, max_options=5):
        """Normalize model-generated quiz questions into ``select`` fields.

        Drops entries that are not dicts, have no label, or have fewer than two
        distinct non-empty options. Keys are lower-cased, spaces become ``_``,
        and keys are made unique. At most ``max_options`` options are kept, and
        the declared correct answer is always retained. If the model gave no
        usable answer, the first option is marked correct.
        """
        max_options = max(2, int(max_options))
        sanitized = []
        seen_keys = set()

        for index, field in enumerate(raw_fields or []):
            if not isinstance(field, dict):
                continue

            label = str(field.get('label') or '').strip()
            if not label:
                continue

            raw_options = field.get('options') if isinstance(field.get('options'), list) else []
            # dict.fromkeys de-duplicates while preserving first-seen order.
            options = list(dict.fromkeys(
                text for text in (str(option).strip() for option in raw_options) if text
            ))
            if len(options) < 2:
                continue

            validation = field.get('validation') if isinstance(field.get('validation'), dict) else {}
            correct_option = str(validation.get('correct_option') or '').strip()
            if correct_option in options[max_options:]:
                # Keep the declared answer even when it sits past the cap, so the
                # stored options always contain the stored correct_option.
                options = options[:max_options - 1] + [correct_option]
            else:
                options = options[:max_options]
            if correct_option not in options:
                correct_option = options[0]

            fallback_key = f'final_quiz_q_{index + 1}'
            base_key = str(field.get('key') or fallback_key).strip().lower().replace(' ', '_') or fallback_key
            key = base_key
            suffix = index + 1
            # A single "_N" suffix can still collide with another emitted key.
            while key in seen_keys:
                key = f'{base_key}_{suffix}'
                suffix += 1
            seen_keys.add(key)

            sanitized.append({
                'key': key,
                'label': label,
                'field_type': 'select',
                'options': options,
                'required': True,
                'validation': {
                    'correct_option': correct_option,
                    'explanation': str(validation.get('explanation') or ''),
                },
            })

        return sanitized

    def _build_fallback_quiz_fields(self, topics):
        """Build eight generic questions cycling over topics (first option is always correct)."""
        safe_topics = [str(topic).strip() for topic in (topics or []) if str(topic).strip()]
        if not safe_topics:
            safe_topics = ['onboarding fundamentals']

        fallback_fields = []
        for index in range(8):
            topic = safe_topics[index % len(safe_topics)]
            key = f'final_quiz_q_{index + 1}'
            correct = f"Use documented best practices for {topic}."
            options = [
                correct,
                f"Skip review steps for {topic} to move faster.",
                f"Rely only on assumptions instead of evidence for {topic}.",
                f"Ignore quality and compliance checks in {topic} tasks.",
            ]
            fallback_fields.append({
                'key': key,
                'label': f"Which approach is most appropriate when working on {topic}?",
                'field_type': 'select',
                'options': options,
                'required': True,
                'validation': {
                    'correct_option': correct,
                    'explanation': f"{correct} balances reliability, quality, and role expectations.",
                },
            })

        return fallback_fields

    def _normalize_structure(self, structure):
        """Coerce generated pages/fields into the persisted flow shape.

        Every page and field gets a fresh ``uuid4``. Non-dict pages become empty
        placeholder pages titled ``Module N``, and non-dict fields are dropped. A
        ``correct_option`` that does not appear in ``options`` is stored as None.
        """
        normalized_pages = []
        for index, page in enumerate(structure or []):
            if not isinstance(page, dict):
                page = {}

            fields = []
            for field_index, field in enumerate(page.get('fields') or []):
                if not isinstance(field, dict):
                    continue

                key = str(field.get('key') or f'field_{field_index + 1}')
                raw_options = field.get('options') if isinstance(field.get('options'), list) else []
                options = [str(option) for option in raw_options if str(option).strip()]

                validation = field.get('validation') if isinstance(field.get('validation'), dict) else {}
                correct_option = validation.get('correct_option')
                if correct_option is not None:
                    correct_option = str(correct_option)

                normalized_validation = {
                    'correct_option': correct_option if correct_option in options else None,
                    'explanation': str(validation.get('explanation') or ''),
                }

                fields.append({
                    'uuid': str(uuid4()),
                    'key': key,
                    'label': str(field.get('label') or key.replace('_', ' ').title()),
                    'field_type': str(field.get('field_type') or 'text'),
                    'required': bool(field.get('required', False)),
                    'options': options,
                    'default_value': field.get('default_value', ''),
                    'validation': normalized_validation,
                })

            page_order = page.get('order')
            meta = page.get('meta')
            normalized_pages.append({
                'uuid': str(uuid4()),
                'title': str(page.get('title') or f'Module {index + 1}'),
                'body': str(page.get('body') or ''),
                'order': int(page_order if isinstance(page_order, int) else index),
                'fields': fields,
                'meta': meta if isinstance(meta, dict) else {},
            })

        return normalized_pages

    @database_sync_to_async
    def save_full_flow(self, role_uuid, structure):
        """Saves the final nested structure to the OnboardingFlow model.

        Creates or replaces the role's flow (keyed by ``role``) and marks it
        active. Raises ``Role.DoesNotExist`` if the role UUID is unknown.
        """
        from apps.accounts.models import Role

        role = Role.objects.get(uuid=role_uuid)
        normalized_structure = self._normalize_structure(structure)
        flow, _ = OnboardingFlow.objects.update_or_create(
            role=role,
            defaults={
                'title': f"AI Onboarding: {role.name}",
                'structure': normalized_structure,
                'is_active': True,
            },
        )
        return flow

    async def send_log(self, log_type, message, content=None):
        """Send a ``{type, message, content, timestamp}`` JSON event to the client."""
        await self.send(json.dumps({
            "type": log_type,
            "message": message,
            "content": content,
            "timestamp": timezone.now().isoformat(),
        }))

    @database_sync_to_async
    def get_config(self, config_uuid):
        """Fetch an AgentConfig by UUID with no access check (not called within this class)."""
        return AgentConfig.objects.get(uuid=config_uuid)

    @database_sync_to_async
    def get_config_for_user(self, config_uuid, user_id):
        """Return the AgentConfig if the user owns or belongs to its organization, else None."""
        return AgentConfig.objects.filter(
            uuid=config_uuid,
        ).filter(
            Q(organization__owner__id=user_id) | Q(organization__members__id=user_id)
        ).first()

    @database_sync_to_async
    def can_access_role(self, role_uuid, user_id):
        """True if the user owns or is a member of the role's organization."""
        from apps.accounts.models import Role

        role = Role.objects.filter(uuid=role_uuid).first()
        if role is None:
            return False
        if role.organization.owner.id == user_id:
            return True
        return role.organization.members.filter(id=user_id).exists()

    @database_sync_to_async
    def can_manage_role(self, role_uuid, user_id):
        """True if the user owns the role's organization, or is a manager and a member of it."""
        from apps.accounts.models import Role, User

        role = Role.objects.filter(uuid=role_uuid).first()
        user = User.objects.filter(id=user_id).first()
        if role is None or user is None:
            return False
        if role.organization.owner.id == user_id:
            return True
        return bool(user.is_manager) and role.organization.members.filter(id=user_id).exists()

    @database_sync_to_async
    def get_config_by_type(self, role_uuid, agent_type):
        """Newest AgentConfig of ``agent_type`` for the role, else the org-wide (role-less) one."""
        role_specific = AgentConfig.objects.filter(
            role__uuid=role_uuid,
            agent_type=agent_type,
        ).order_by('-updated_at').first()
        if role_specific:
            return role_specific

        return AgentConfig.objects.filter(
            organization__roles__uuid=role_uuid,
            role__isnull=True,
            agent_type=agent_type,
        ).order_by('-updated_at').first()

    @database_sync_to_async
    def get_role_progress_context(self, role_uuid, user_id):
        """Summarize the user's latest onboarding session for a role as a JSON-safe dict."""
        from apps.accounts.models import Role

        role = Role.objects.get(uuid=role_uuid)
        sessions = OnboardingSession.objects.filter(user_id=user_id, role=role).order_by('-updated_at')
        latest_session = sessions.first()
        active_flow = OnboardingFlow.objects.filter(role=role, is_active=True).order_by('-updated_at').first()

        if not latest_session:
            return {
                "role_uuid": str(role.uuid),
                "role_name": role.name,
                "latest_status": "not_started",
                "session_count": 0,
                "flow_exists": bool(active_flow),
                "progress": 0,
                "responses_count": 0,
                "completed_modules": [],
            }

        state = latest_session.state or {}
        responses = state.get("responses", {})
        completed_modules = state.get("completed_modules", [])
        progress = state.get("progress_percentage", state.get("progress", 0))

        return {
            "role_uuid": str(role.uuid),
            "role_name": role.name,
            "latest_status": latest_session.status,
            "session_count": sessions.count(),
            "flow_exists": bool(active_flow),
            "progress": progress,
            "responses_count": len(responses) if isinstance(responses, dict) else 0,
            "completed_modules": completed_modules if isinstance(completed_modules, list) else [],
            "updated_at": latest_session.updated_at.isoformat() if latest_session.updated_at else None,
        }