diff --git a/apps/onboarding/consumers.py b/apps/onboarding/consumers.py index e3cab24..ec5f49b 100644 --- a/apps/onboarding/consumers.py +++ b/apps/onboarding/consumers.py @@ -1,12 +1,14 @@ import json -import httpx -import re import logging +import re from uuid import uuid4 -from channels.generic.websocket import AsyncWebsocketConsumer + +import httpx from channels.db import database_sync_to_async -from django.utils import timezone +from channels.generic.websocket import AsyncWebsocketConsumer from django.conf import settings +from django.db.models import Q +from django.utils import timezone from apps.onboarding.mcp import MCPRouter from apps.onboarding.models import AgentConfig, OnboardingFlow, OnboardingSession @@ -46,21 +48,35 @@ class OnboardingConsumer(AsyncWebsocketConsumer): if not role_uuid: await self.send_log("error", "Missing role_uuid for full onboarding generation") return + if not await self.can_manage_role(role_uuid, self.user.id): + await self.send_log("error", "Forbidden") + return await self.run_full_onboarding_generation(role_uuid) elif action == "progress_monitor": role_uuid = data.get("role_uuid") or self.context_uuid if not role_uuid: await self.send_log("error", "Missing role_uuid for progress monitoring") return + if not await self.can_access_role(role_uuid, self.user.id): + await self.send_log("error", "Forbidden") + return await self.run_progress_monitor(role_uuid) else: user_message = data.get("query") or data.get("message") + requested_max_tokens = data.get("max_tokens") if not user_message: await self.send_log("error", "Missing query/message payload") return - config = await self.get_config(self.context_uuid) - ai_response = await self.orchestrate_ai(user_message, config) + config = await self.get_config_for_user(self.context_uuid, self.user.id) + if config is None: + await self.send_log("error", "Forbidden") + return + ai_response = await self.orchestrate_ai( + user_message, + config, + max_tokens=requested_max_tokens, + ) await self.send(json.dumps({ "type": "completed", @@ -91,16 +107,19 @@ class OnboardingConsumer(AsyncWebsocketConsumer): "Output ONLY a valid JSON array of 3-5 strings representing module titles. " "Example: [\"Introduction\", \"Safety\", \"Operations\"]" ) - ca_response = await self.orchestrate_ai(ca_prompt, ca_config) + ca_response = await self.orchestrate_ai( + ca_prompt, + ca_config, + min_internal_turns=1, + max_tokens=384, + ) topics = self._extract_json_list(ca_response) if not topics: await self.send_log("error", "Curriculum generation returned no topics") return - toc_lines = [f"{idx + 1}. {title}" for idx, title in enumerate(topics)] - toc_markdown = "## Table of Contents\n\n" + "\n".join(toc_lines) - full_structure = [] + module_briefs = [] for index, topic in enumerate(topics): @@ -117,40 +136,92 @@ class OnboardingConsumer(AsyncWebsocketConsumer): page_content = await self.orchestrate_ai( ( f"Write a practical onboarding training guide for the topic '{topic}'. " - "Use the MCP search context provided below as the primary source. " - "If the context is empty, provide a concise best-practice overview and clearly say no indexed documents were found. " - "Use Markdown formatting and do NOT include a table of contents in this section.\n\n" + "Think step-by-step internally before writing the final answer. " + "Use the MCP search context below as your primary source, and call additional tools if needed. " + "If no indexed documents are available, provide a concise best-practice overview and clearly say no indexed documents were found. " + "Use Markdown formatting and do NOT include a table of contents in this section. " + "Generate substantial depth: target 900-1400 words. " + "Include these sections in order: Overview, Core Concepts, Role-Specific Workflow, Practical Examples, Common Pitfalls, and Action Checklist. " + "In Practical Examples, provide at least 2 concrete examples relevant to this role/topic. " + "In Action Checklist, provide at least 8 actionable checklist items.\n\n" f"Role UUID: {role_uuid}\n" + f"Topic: {topic}\n" f"MCP search context:\n{context_markdown}" ), - ka_config + ka_config, + min_internal_turns=2, + max_tokens=2400, ) - if index == 0: - page_content = f"{toc_markdown}\n\n---\n\n{page_content}" - - await self.send_log("status", f"Phase 3: Creating quiz for {topic}...", "assessment") - aa_config = await self.get_config_by_type(role_uuid, 'assessment') - if not aa_config: - await self.send_log("error", "Missing assessment AgentConfig for this role") - return - aa_prompt = ( - f"Based on this content: '{page_content[:1000]}', create 2 multiple choice questions. " - "Output ONLY a JSON array of objects with keys: 'key', 'label', 'field_type' (use 'select'), " - "'options' (array of strings), and 'required' (true)." - ) - quiz_response = await self.orchestrate_ai(aa_prompt, aa_config) - quiz_fields = self._extract_json_list(quiz_response) - - full_structure.append({ "title": topic, "body": page_content, "order": index, - "fields": quiz_fields + "fields": [], + "meta": { + "topic_index": index, + "table_of_contents": [str(item) for item in topics], + }, }) + module_briefs.append({ + "topic": str(topic), + "summary_excerpt": str(page_content)[:1200], + }) + + await self.send_log("status", "Phase 3: Creating final assessment quiz...", "assessment") + aa_config = await self.get_config_by_type(role_uuid, 'assessment') + if not aa_config: + await self.send_log("error", "Missing assessment AgentConfig for this role") + return + + quiz_prompt = ( + "Create a final onboarding quiz that assesses all generated modules. " + "Output ONLY a valid JSON array of 8 multiple-choice question objects. " + "Each object MUST include: 'key' (snake_case), 'label', 'field_type' ('select'), " + "'options' (array of 4 unique strings), 'required' (true), and 'validation' with " + "'correct_option' (exactly matching one option) and 'explanation' (short rationale). " + "Cover all topics with balanced difficulty and avoid ambiguous choices.\n\n" + f"Modules JSON:\n{json.dumps(module_briefs, ensure_ascii=False)}" + ) + quiz_response = await self.orchestrate_ai( + quiz_prompt, + aa_config, + min_internal_turns=1, + max_tokens=1600, + ) + quiz_fields = self._sanitize_quiz_fields(self._extract_json_list(quiz_response)) + + if not quiz_fields: + await self.send_log("status", "Assessment output invalid, retrying quiz generation...", "assessment") + retry_response = await self.orchestrate_ai( + f"{quiz_prompt}\n\nReturn ONLY raw JSON. Do not use markdown fences. Do not include explanations outside JSON.", + aa_config, + min_internal_turns=1, + max_tokens=1600, + ) + quiz_fields = self._sanitize_quiz_fields(self._extract_json_list(retry_response)) + + if not quiz_fields: + await self.send_log("status", "Assessment output still invalid. Using fallback final quiz.", "assessment") + quiz_fields = self._build_fallback_quiz_fields([str(topic) for topic in topics]) + + full_structure.append({ + "title": "Final Assessment Quiz", + "body": ( + "### Final Quiz\n" + "Answer all questions below. You need **80%** to complete onboarding. " + "You can update answers and submit when ready." + ), + "order": len(full_structure), + "fields": quiz_fields, + "meta": { + "page_type": "final_quiz", + "pass_mark": 80, + }, + }) + await self.save_full_flow(role_uuid, full_structure) @@ -179,7 +250,12 @@ class OnboardingConsumer(AsyncWebsocketConsumer): f"Progress context JSON:\n{json.dumps(progress_context)}" ) - feedback = await self.orchestrate_ai(monitor_prompt, monitor_config) + feedback = await self.orchestrate_ai( + monitor_prompt, + monitor_config, + min_internal_turns=1, + max_tokens=640, + ) await self.send(json.dumps({ "type": "completed", @@ -192,7 +268,14 @@ class OnboardingConsumer(AsyncWebsocketConsumer): } })) - async def orchestrate_ai(self, user_message, config): + async def orchestrate_ai( + self, + user_message, + config, + min_internal_turns=2, + max_turns=6, + max_tokens=None, + ): """ Handles the multi-turn ReAct loop (Reasoning + Tool Use). """ @@ -201,18 +284,34 @@ class OnboardingConsumer(AsyncWebsocketConsumer): {"role": "user", "content": user_message} ] + llm_config = config.llm_config if isinstance(config.llm_config, dict) else {} + + resolved_max_tokens = max_tokens + if resolved_max_tokens is None: + resolved_max_tokens = llm_config.get("max_tokens", 1024) + + try: + resolved_max_tokens = max(64, int(resolved_max_tokens)) + except Exception: + resolved_max_tokens = 1024 + + last_content = "" + min_internal_turns = max(1, int(min_internal_turns or 1)) + max_turns = max(min_internal_turns, int(max_turns or 1)) + async with httpx.AsyncClient(timeout=60.0) as client: - for turn in range(5): + for turn in range(max_turns): await self.send_log("thought", f"Agent is thinking (Turn {turn+1})...") try: response = await client.post( f"{settings.INFERENCE_URL}/v1/chat/completions", json={ - "model": config.llm_config.get("model_id", "meta-llama-3.1-8b"), + "model": llm_config.get("model_id", "meta-llama-3.1-8b"), "messages": messages, "tools": self.router.get_tool_definitions(), - "tool_choice": "auto" + "tool_choice": "auto", + "max_tokens": resolved_max_tokens, } ) response.raise_for_status() @@ -244,12 +343,27 @@ class OnboardingConsumer(AsyncWebsocketConsumer): continue else: - return ai_message["content"] + last_content = str(ai_message.get("content") or "").strip() + + if (turn + 1) < min_internal_turns: + messages.append({ + "role": "user", + "content": ( + "Run one more internal reasoning pass before finalizing. " + "If additional evidence is needed, call tools. " + "Then return only the improved final answer." + ), + }) + continue + + return last_content except Exception as e: await self.send_log("error", f"Inference failed: {str(e)}") return f"Error: {str(e)}" + return last_content + async def fetch_knowledge_context(self, role_uuid, topic): @@ -297,18 +411,133 @@ class OnboardingConsumer(AsyncWebsocketConsumer): return "\n\n".join(lines) + def _coerce_list_payload(self, payload): + if isinstance(payload, list): + return payload + if isinstance(payload, dict): + for key in ('questions', 'items', 'fields', 'quiz'): + value = payload.get(key) + if isinstance(value, list): + return value + return [] + def _extract_json_list(self, text): - """Regex helper to pull JSON out of LLM conversational filler.""" - try: - if not text: - return [] - match = re.search(r'\[.*\]', text, re.DOTALL) - if match: - return json.loads(match.group()) - return [] - except Exception: + """Extracts a JSON list from model output, tolerating wrappers and markdown fences.""" + if not text: return [] + candidate_texts = [str(text).strip()] + + for block in re.findall(r'```(?:json)?\s*([\s\S]*?)```', str(text), re.IGNORECASE): + candidate_texts.append(block.strip()) + + decoder = json.JSONDecoder() + + for candidate in candidate_texts: + if not candidate: + continue + + try: + parsed = json.loads(candidate) + coerced = self._coerce_list_payload(parsed) + if coerced: + return coerced + except Exception: + pass + + for idx, char in enumerate(candidate): + if char not in '[{': + continue + try: + parsed, _ = decoder.raw_decode(candidate[idx:]) + except Exception: + continue + + coerced = self._coerce_list_payload(parsed) + if coerced: + return coerced + + return [] + + def _sanitize_quiz_fields(self, raw_fields): + sanitized = [] + seen_keys = set() + + for index, field in enumerate(raw_fields or []): + if not isinstance(field, dict): + continue + + key = str(field.get('key') or f'final_quiz_q_{index + 1}').strip().lower().replace(' ', '_') + if not key: + key = f'final_quiz_q_{index + 1}' + + if key in seen_keys: + key = f'{key}_{index + 1}' + seen_keys.add(key) + + label = str(field.get('label') or '').strip() + if not label: + continue + + raw_options = field.get('options') if isinstance(field.get('options'), list) else [] + options = [] + for option in raw_options: + option_text = str(option).strip() + if option_text and option_text not in options: + options.append(option_text) + + if len(options) < 2: + continue + + validation = field.get('validation') if isinstance(field.get('validation'), dict) else {} + correct_option = str(validation.get('correct_option') or '').strip() + if correct_option not in options: + correct_option = options[0] + + sanitized.append({ + 'key': key, + 'label': label, + 'field_type': 'select', + 'options': options[:5], + 'required': True, + 'validation': { + 'correct_option': correct_option, + 'explanation': str(validation.get('explanation') or ''), + }, + }) + + return sanitized + + def _build_fallback_quiz_fields(self, topics): + safe_topics = [str(topic).strip() for topic in (topics or []) if str(topic).strip()] + if not safe_topics: + safe_topics = ['onboarding fundamentals'] + + fallback_fields = [] + for index in range(8): + topic = safe_topics[index % len(safe_topics)] + key = f'final_quiz_q_{index + 1}' + correct = f"Use documented best practices for {topic}." + options = [ + correct, + f"Skip review steps for {topic} to move faster.", + f"Rely only on assumptions instead of evidence for {topic}.", + f"Ignore quality and compliance checks in {topic} tasks.", + ] + fallback_fields.append({ + 'key': key, + 'label': f"Which approach is most appropriate when working on {topic}?", + 'field_type': 'select', + 'options': options, + 'required': True, + 'validation': { + 'correct_option': correct, + 'explanation': f"{correct} balances reliability, quality, and role expectations.", + }, + }) + + return fallback_fields + def _normalize_structure(self, structure): normalized_pages = [] for index, page in enumerate(structure or []): @@ -317,14 +546,28 @@ class OnboardingConsumer(AsyncWebsocketConsumer): if not isinstance(field, dict): continue key = str(field.get('key') or f'field_{field_index + 1}') + raw_options = field.get('options') if isinstance(field.get('options'), list) else [] + options = [str(option) for option in raw_options if str(option).strip()] + + validation = field.get('validation') if isinstance(field.get('validation'), dict) else {} + correct_option = validation.get('correct_option') + if correct_option is not None: + correct_option = str(correct_option) + + normalized_validation = { + 'correct_option': correct_option if correct_option in options else None, + 'explanation': str(validation.get('explanation') or ''), + } + fields.append({ 'uuid': str(uuid4()), 'key': key, 'label': str(field.get('label') or key.replace('_', ' ').title()), 'field_type': str(field.get('field_type') or 'text'), 'required': bool(field.get('required', False)), - 'options': field.get('options') if isinstance(field.get('options'), list) else [], + 'options': options, 'default_value': field.get('default_value', ''), + 'validation': normalized_validation, }) page_title = page.get('title') if isinstance(page, dict) else None @@ -336,6 +579,7 @@ class OnboardingConsumer(AsyncWebsocketConsumer): 'body': str(page_body or ''), 'order': int(page_order if isinstance(page_order, int) else index), 'fields': fields, + 'meta': page.get('meta') if isinstance(page.get('meta'), dict) else {}, }) return normalized_pages @@ -367,10 +611,54 @@ class OnboardingConsumer(AsyncWebsocketConsumer): def get_config(self, config_uuid): return AgentConfig.objects.get(uuid=config_uuid) + @database_sync_to_async + def get_config_for_user(self, config_uuid, user_id): + return AgentConfig.objects.filter( + uuid=config_uuid, + ).filter( + Q(organization__owner__id=user_id) | Q(organization__members__id=user_id) + ).first() + + @database_sync_to_async + def can_access_role(self, role_uuid, user_id): + from apps.accounts.models import Role + + role = Role.objects.filter(uuid=role_uuid).first() + if role is None: + return False + + if role.organization.owner.id == user_id: + return True + + return role.organization.members.filter(id=user_id).exists() + + @database_sync_to_async + def can_manage_role(self, role_uuid, user_id): + from apps.accounts.models import Role, User + + role = Role.objects.filter(uuid=role_uuid).first() + user = User.objects.filter(id=user_id).first() + if role is None or user is None: + return False + + if role.organization.owner.id == user_id: + return True + + return bool(user.is_manager) and role.organization.members.filter(id=user_id).exists() + @database_sync_to_async def get_config_by_type(self, role_uuid, agent_type): + role_specific = AgentConfig.objects.filter( + role__uuid=role_uuid, + agent_type=agent_type, + ).order_by('-updated_at').first() + + if role_specific: + return role_specific + return AgentConfig.objects.filter( organization__roles__uuid=role_uuid, + role__isnull=True, agent_type=agent_type, ).order_by('-updated_at').first()