Added mcp rag, onboarding app, frontend for learning and agent inference

This commit is contained in:
Viswamedha Nalabotu 2026-02-08 15:34:26 +00:00
parent a12d5f906c
commit 2f7b2001d4
41 changed files with 3305 additions and 134 deletions

View file

@ -1,6 +1,6 @@
from django.contrib import admin from django.contrib import admin
from django.contrib.admin import ModelAdmin, TabularInline from django.contrib.admin import ModelAdmin, TabularInline
from apps.mlstore.models import AgentModel, AgentRun, Agent, AgentEvent from apps.mlstore.models import AgentModel, AgentRun, Agent, AgentEvent, RoleRagDocument
class AgentInline(TabularInline): class AgentInline(TabularInline):
@ -55,3 +55,12 @@ class AgentEventAdmin(ModelAdmin):
search_fields = ('event_type', 'execution__uuid', 'execution__agent__model__name') search_fields = ('event_type', 'execution__uuid', 'execution__agent__model__name')
list_filter = ('event_type',) list_filter = ('event_type',)
raw_id_fields = ('execution',) raw_id_fields = ('execution',)
@admin.register(RoleRagDocument)
class RoleRagDocumentAdmin(ModelAdmin):
list_display = ('id', 'uuid', 'role', 'training_file', 'chunk_index', 'is_active', 'created_at')
search_fields = ('role__name', 'training_file__file_name')
list_filter = ('is_active', 'created_at')
raw_id_fields = ('role', 'training_file')
readonly_fields = ('uuid', 'created_at', 'updated_at')

View file

@ -41,6 +41,8 @@ class MLStoreConsumer(AsyncWebsocketConsumer):
await self.handle_fine_tune(data) await self.handle_fine_tune(data)
elif action == "infer": elif action == "infer":
await self.handle_infer(data) await self.handle_infer(data)
elif action == "onboarding_progress":
await self.handle_onboarding_progress(data)
elif action in ("stop_agent", "stop"): elif action in ("stop_agent", "stop"):
await self.handle_stop(data) await self.handle_stop(data)
else: else:
@ -90,6 +92,16 @@ class MLStoreConsumer(AsyncWebsocketConsumer):
return return
input_data = data.get("input_data") or {} input_data = data.get("input_data") or {}
role_uuid = input_data.get("role_uuid")
if not role_uuid:
options = input_data.get("options") or {}
role_uuid = options.get("role_uuid")
if not role_uuid:
await self.send(json.dumps({
"type": "error",
"message": "role_uuid is required for inference to enable RAG"
}))
return
execution = await self.create_run(agent, self.user, input_data) execution = await self.create_run(agent, self.user, input_data)
await self.send(json.dumps({ await self.send(json.dumps({
@ -125,6 +137,35 @@ class MLStoreConsumer(AsyncWebsocketConsumer):
"error_message": "Execution stopped by user" "error_message": "Execution stopped by user"
})) }))
async def handle_onboarding_progress(self, data):
execution_id = data.get("execution_id")
if not execution_id:
await self.send(json.dumps({
"type": "error",
"message": "execution_id required for onboarding_progress"
}))
return
execution = await self.get_execution(execution_id)
if not execution:
await self.send(json.dumps({
"type": "error",
"message": "Execution not found"
}))
return
content = data.get("content") or data.get("progress") or {}
await self.create_event(execution, "progress", content)
await self.channel_layer.group_send(
self.room_group_name,
{
"type": "mlstore_event",
"event_type": "progress",
"content": content,
"timestamp": timezone.now().isoformat(),
}
)
async def mlstore_event(self, event): async def mlstore_event(self, event):
await self.send(json.dumps({ await self.send(json.dumps({
"type": "mlstore_event", "type": "mlstore_event",

View file

@ -2,6 +2,22 @@ import django.db.models.deletion
import uuid import uuid
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
from pgvector.django import VectorField
def _create_vector_extension(apps, schema_editor):
if schema_editor.connection.vendor != 'postgresql':
return
with schema_editor.connection.cursor() as cursor:
cursor.execute('CREATE EXTENSION IF NOT EXISTS vector')
def _drop_vector_extension(apps, schema_editor):
if schema_editor.connection.vendor != 'postgresql':
return
with schema_editor.connection.cursor() as cursor:
cursor.execute('DROP EXTENSION IF EXISTS vector')
class Migration(migrations.Migration): class Migration(migrations.Migration):
@ -13,6 +29,10 @@ class Migration(migrations.Migration):
] ]
operations = [ operations = [
migrations.RunPython(
code=_create_vector_extension,
reverse_code=_drop_vector_extension,
),
migrations.CreateModel( migrations.CreateModel(
name='AgentModel', name='AgentModel',
fields=[ fields=[
@ -54,7 +74,7 @@ class Migration(migrations.Migration):
('id', models.BigAutoField(primary_key=True, serialize=False)), ('id', models.BigAutoField(primary_key=True, serialize=False)),
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)), ('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
('status', models.CharField(choices=[('queued', 'Queued'), ('running', 'Running'), ('completed', 'Completed'), ('failed', 'Failed')], default='queued', max_length=20)), ('status', models.CharField(choices=[('queued', 'Queued'), ('running', 'Running'), ('completed', 'Completed'), ('failed', 'Failed')], default='queued', max_length=20)),
('input_data', models.JSONField(default=dict)), ('input_data', models.JSONField(blank=True, default=dict)),
('output_data', models.JSONField(blank=True, default=dict)), ('output_data', models.JSONField(blank=True, default=dict)),
('error_message', models.TextField(blank=True, default='')), ('error_message', models.TextField(blank=True, default='')),
('started_at', models.DateTimeField(blank=True, null=True)), ('started_at', models.DateTimeField(blank=True, null=True)),
@ -83,4 +103,25 @@ class Migration(migrations.Migration):
'ordering': ['timestamp'], 'ordering': ['timestamp'],
}, },
), ),
migrations.CreateModel(
name='RoleRagDocument',
fields=[
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
('id', models.BigAutoField(primary_key=True, serialize=False)),
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
('content', models.TextField()),
('content_hash', models.CharField(db_index=True, max_length=64)),
('embedding', VectorField(blank=True, dimensions=1536, null=True)),
('metadata', models.JSONField(blank=True, default=dict)),
('chunk_index', models.IntegerField(default=0)),
('is_active', models.BooleanField(default=True)),
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='orgs.role')),
('training_file', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='orgs.trainingfile')),
],
options={
'verbose_name': 'Role RAG Document',
'verbose_name_plural': 'Role RAG Documents',
},
),
] ]

View file

@ -1,15 +0,0 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('mlstore', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='agentrun',
name='input_data',
field=models.JSONField(blank=True, default=dict),
),
]

View file

@ -1,7 +1,8 @@
from django.db.models import BigAutoField, CASCADE, CharField, DateTimeField, ForeignKey, JSONField, Model, TextField, UUIDField from django.db.models import BigAutoField, BooleanField, CASCADE, CharField, DateTimeField, ForeignKey, JSONField, Model, TextField, UUIDField, IntegerField
from pgvector.django import VectorField
from apps.users.mixins import TimeStampMixin from apps.users.mixins import TimeStampMixin
from apps.users.models import User from apps.users.models import User
from apps.orgs.models import Organization from apps.orgs.models import Organization, Role, TrainingFile
from uuid import uuid4 from uuid import uuid4
class AgentModel(Model): class AgentModel(Model):
@ -100,3 +101,25 @@ class AgentEvent(Model):
ordering = ['timestamp'] ordering = ['timestamp']
verbose_name = "Agent Event" verbose_name = "Agent Event"
verbose_name_plural = "Agent Events" verbose_name_plural = "Agent Events"
class RoleRagDocument(TimeStampMixin, Model):
id = BigAutoField(primary_key = True)
uuid = UUIDField(default = uuid4, editable = False, unique = True)
role = ForeignKey(Role, on_delete = CASCADE, related_name = 'rag_documents')
training_file = ForeignKey(TrainingFile, on_delete = CASCADE, related_name = 'rag_documents', null = True, blank = True)
content = TextField()
content_hash = CharField(max_length = 64, db_index = True)
embedding = VectorField(dimensions = 1536, null = True, blank = True)
metadata = JSONField(default = dict, blank = True)
chunk_index = IntegerField(default = 0)
is_active = BooleanField(default = True)
class Meta:
verbose_name = "Role RAG Document"
verbose_name_plural = "Role RAG Documents"
def __str__(self) -> str:
return f"{self.role.name} - chunk {self.chunk_index}"

View file

@ -1,5 +1,6 @@
from rest_framework.serializers import ModelSerializer from rest_framework.serializers import ModelSerializer
from .models import AgentModel, Agent, AgentRun, AgentEvent from .models import AgentModel, Agent, AgentRun, AgentEvent
from apps.orgs.serializers import OrganizationSerializer
class AgentModelSerializer(ModelSerializer): class AgentModelSerializer(ModelSerializer):
@ -11,6 +12,7 @@ class AgentModelSerializer(ModelSerializer):
class AgentSerializer(ModelSerializer): class AgentSerializer(ModelSerializer):
model = AgentModelSerializer(read_only=True) model = AgentModelSerializer(read_only=True)
organization = OrganizationSerializer(read_only=True)
class Meta: class Meta:
model = Agent model = Agent

View file

@ -1,19 +1,18 @@
import asyncio import asyncio
import logging import logging
import os import os
from typing import Any, Dict, List, Optional import re
from typing import Any, Dict, List, Optional, Tuple
from django.conf import settings from django.conf import settings
from mcp_agent.mcp_client import MCPClient from mcp_agent.mcp_client import MCPClient
from .models import AgentModel from .models import AgentModel, RoleRagDocument
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Get reference to the base model cache directory
try: try:
from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR
BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR
except ImportError: except ImportError:
# Fallback: construct the path manually
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model") BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")
@ -67,7 +66,6 @@ def fine_tune_model(
except Exception as e: except Exception as e:
error_msg = str(e) if str(e) else f"Unknown error: {type(e).__name__}" error_msg = str(e) if str(e) else f"Unknown error: {type(e).__name__}"
logger.error(f"Fine-tune failed: {error_msg}", exc_info=True) logger.error(f"Fine-tune failed: {error_msg}", exc_info=True)
# Return a failed response instead of raising
return { return {
"status": "failed", "status": "failed",
"error": error_msg, "error": error_msg,
@ -114,3 +112,294 @@ def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel
NOTE: migrations are required after the model field change prior to using this in production. NOTE: migrations are required after the model field change prior to using this in production.
""" """
return AgentModel.objects.create(name=name, version=version, path=model_path) return AgentModel.objects.create(name=name, version=version, path=model_path)
def embed_texts(texts: List[str]) -> List[List[float]]:
"""Generate embeddings for texts using the MCP embedding service.
Falls back to local sentence-transformers if MCP unavailable.
Args:
texts: List of text strings to embed.
Returns:
List of embedding vectors (list of floats).
Raises:
RuntimeError: If both MCP and local embedding fail.
"""
logger.info(f"Embedding {len(texts)} texts")
try:
result = asyncio.run(_call_mcp("embed", {"texts": texts}))
embeddings = result.get("embeddings", [])
if embeddings and len(embeddings) == len(texts):
logger.info(f"Successfully embedded {len(texts)} texts via MCP")
return embeddings
except Exception as e:
logger.warning(f"MCP embedding failed, trying local fallback: {e}")
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(texts).tolist()
logger.info(f"Successfully embedded {len(texts)} texts via local model")
return embeddings
except Exception as e:
logger.error(f"Local embedding also failed: {e}")
raise RuntimeError(f"Failed to embed texts: {e}")
def embed_text(text: str) -> List[float]:
"""Generate embedding for a single text.
Args:
text: Text string to embed.
Returns:
Embedding vector (list of floats).
"""
return embed_texts([text])[0]
def search_similar_documents(
query: str,
role_uuid: str,
top_k: int = 5,
similarity_threshold: float = 0.0,
) -> List[Tuple[RoleRagDocument, float]]:
"""Search for documents similar to the query using vector similarity.
Args:
query: Query text to embed and search for.
role_uuid: UUID of role to scope search.
top_k: Number of top results to return.
similarity_threshold: Minimum similarity score (0-1) to include results.
Returns:
List of (RoleRagDocument, similarity_score) tuples, ordered by similarity DESC.
Raises:
ValueError: If role not found or embedding fails.
"""
from apps.orgs.models import Role
try:
query_embedding = embed_text(query)
logger.info(f"Embedded query: '{query[:50]}...' to {len(query_embedding)}D vector")
except Exception as e:
logger.error(f"Failed to embed query: {e}")
raise ValueError(f"Failed to embed query: {e}")
try:
role = Role.objects.get(uuid=role_uuid)
except Role.DoesNotExist:
raise ValueError(f"Role with UUID {role_uuid} not found")
queryset = RoleRagDocument.objects.filter(
role=role,
)
if not queryset.exists():
logger.warning(f"No documents with embeddings found for role {role.uuid}")
return []
from django.db import connection
with connection.cursor() as cursor:
query_sql = """
SELECT id, 1 - (embedding <=> %s::vector) as similarity
FROM mlstore_roleragdocument
WHERE role_id = %s AND embedding IS NOT NULL
ORDER BY similarity DESC
LIMIT %s
"""
cursor.execute(
query_sql,
)
doc_ids_with_scores = cursor.fetchall()
if not doc_ids_with_scores:
logger.info(f"No similar documents found for query in role {role.uuid}")
return []
filtered_docs = [
(doc_id, score)
for doc_id, score in doc_ids_with_scores
if score >= similarity_threshold
][:top_k]
if not filtered_docs:
logger.info(
f"No documents met similarity threshold {similarity_threshold}"
)
return []
doc_ids = [doc_id for doc_id, _ in filtered_docs]
doc_scores = {doc_id: score for doc_id, score in filtered_docs}
documents = RoleRagDocument.objects.filter(id__in=doc_ids)
results = [
(doc, doc_scores[doc.id])
for doc in documents
if doc.id in doc_scores
]
results.sort(key=lambda x: x[1], reverse=True)
logger.info(
f"Found {len(results)} similar documents for query "
f"(threshold={similarity_threshold}, top_k={top_k})"
)
return results
def batch_embed_documents(
documents: List[RoleRagDocument],
batch_size: int = 32,
force_reembed: bool = False,
) -> Tuple[int, int]:
"""Batch embed documents that don't have embeddings yet.
Args:
documents: List of RoleRagDocument instances to embed.
batch_size: Number of documents to embed per API call.
force_reembed: If True, re-embed documents that already have embeddings.
Returns:
Tuple of (num_embedded, num_failed).
Note:
Updates documents in-place with embedding values.
"""
to_embed = [
doc for doc in documents
if force_reembed or not doc.embedding
]
if not to_embed:
logger.info("No documents to embed")
return 0, 0
num_embedded = 0
num_failed = 0
for i in range(0, len(to_embed), batch_size):
batch = to_embed[i : i + batch_size]
logger.info(
f"Embedding batch {i // batch_size + 1} "
f"({len(batch)} documents)"
)
try:
texts = [doc.content for doc in batch]
embeddings = embed_texts(texts)
for doc, embedding in zip(batch, embeddings):
doc.embedding = embedding
num_embedded += 1
RoleRagDocument.objects.bulk_update(batch, ["embedding"], batch_size=500)
logger.info(f"Successfully embedded {len(batch)} documents")
except Exception as e:
logger.error(f"Failed to embed batch: {e}")
num_failed += len(batch)
logger.info(
f"Embedding complete: {num_embedded} embedded, {num_failed} failed"
)
return num_embedded, num_failed
def get_context_for_query(
query: str,
role_uuid: str,
top_k: int = 5,
similarity_threshold: float = 0.5,
) -> str:
"""Get context string from similar documents for a query.
Useful for augmenting prompts with retrieved context.
Args:
query: Query text.
role_uuid: UUID of role to search within.
top_k: Number of top results to include.
similarity_threshold: Minimum similarity score.
Returns:
Formatted context string with source attribution.
"""
def _clean_chunk_text(text: str) -> str:
"""Strip junk and deduplicate paragraphs to keep context lean."""
if not text:
return ""
text = re.sub(r"\[\s*Answer\s*:.*?\]", "", text, flags=re.IGNORECASE | re.DOTALL)
lines = []
for raw_line in text.splitlines():
line = raw_line.strip()
if not line:
lines.append("")
continue
lower = line.lower()
if line.startswith("#"):
continue
if "do you have any questions" in lower:
continue
if "feel free to ask" in lower:
continue
if "references" in lower or "sources" in lower or "wikipedia" in lower:
continue
lines.append(line)
cleaned = "\n".join(lines)
paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", cleaned) if p.strip()]
seen = set()
unique_paragraphs: List[str] = []
for para in paragraphs:
if para in seen:
continue
seen.add(para)
unique_paragraphs.append(para)
return "\n\n".join(unique_paragraphs)
try:
results = search_similar_documents(
query=query,
role_uuid=role_uuid,
top_k=top_k,
similarity_threshold=similarity_threshold,
)
except Exception as e:
logger.warning(f"Failed to retrieve context: {e}")
return ""
if not results:
return ""
context_parts = []
for doc, similarity in results:
cleaned = _clean_chunk_text(doc.content)
if not cleaned:
continue
source = "unknown"
if doc.training_file:
source = doc.training_file.file_name
context_parts.append(
f"[Source: {source}, Similarity: {similarity:.2%}]\n{cleaned}\n"
)
context = "\n---\n".join(context_parts)
return context

View file

@ -1,16 +1,128 @@
import logging import logging
import os
import re
import time
import traceback import traceback
from hashlib import sha256
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from celery import shared_task from celery import shared_task
from channels.layers import get_channel_layer from channels.layers import get_channel_layer
from django.utils import timezone from django.utils import timezone
from django.db import transaction
from apps.orgs.models import TrainingFile from apps.orgs.models import TrainingFile, Role
from . import services from . import services
from .models import Agent, AgentEvent, AgentModel, AgentRun from .models import Agent, AgentEvent, AgentModel, AgentRun, RoleRagDocument
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_mem_info() -> str:
try:
with open('/proc/self/status', 'r', encoding='utf-8') as f:
lines = f.read().splitlines()
mem = {line.split(':', 1)[0]: line.split(':', 1)[1].strip() for line in lines if ':' in line}
return f"VmRSS={mem.get('VmRSS','?')}, VmHWM={mem.get('VmHWM','?')}, VmSize={mem.get('VmSize','?')}"
except Exception:
return "mem_info_unavailable"
def _estimate_tokens(text: str) -> int:
if not text:
return 0
return len(re.findall(r"\w+|[^\s\w]", text))
def _split_semantic_units(text: str) -> list[str]:
paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
units: list[str] = []
for para in paragraphs:
sentences = re.split(r"(?<=[.!?])\s+", para)
for sent in sentences:
sent = sent.strip()
if sent:
units.append(sent)
return units or paragraphs
def _chunk_text(text: str, max_tokens: int = 400, overlap_tokens: int = 50) -> list[str]:
if not text:
return []
units = _split_semantic_units(text)
logger.info(
"Semantic chunking units=%s max_tokens=%s overlap_tokens=%s mem=%s",
len(units),
max_tokens,
overlap_tokens,
_get_mem_info(),
)
chunks: list[str] = []
current: list[str] = []
current_tokens = 0
for unit in units:
unit_tokens = _estimate_tokens(unit)
if unit_tokens == 0:
continue
if current_tokens + unit_tokens > max_tokens and current:
chunk = " ".join(current).strip()
if chunk:
chunks.append(chunk)
if overlap_tokens > 0:
overlap: list[str] = []
overlap_count = 0
for prev in reversed(current):
prev_tokens = _estimate_tokens(prev)
if overlap_count + prev_tokens > overlap_tokens:
break
overlap.insert(0, prev)
overlap_count += prev_tokens
current = overlap
current_tokens = overlap_count
else:
current = []
current_tokens = 0
current.append(unit)
current_tokens += unit_tokens
if current:
chunk = " ".join(current).strip()
if chunk:
chunks.append(chunk)
return chunks
def _extract_text_from_file(file_path: str, file_type: str | None) -> str:
file_type = (file_type or '').lower()
if file_type in {'txt', 'md', 'csv', 'json'}:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
if file_type == 'pdf':
try:
from PyPDF2 import PdfReader
except Exception as e:
raise RuntimeError('PyPDF2 is required to parse PDF files') from e
reader = PdfReader(file_path)
return "\n".join(page.extract_text() or "" for page in reader.pages)
if file_type in {'docx', 'doc'}:
try:
import docx
except Exception as e:
raise RuntimeError('python-docx is required to parse DOCX files') from e
doc = docx.Document(file_path)
return "\n".join(p.text for p in doc.paragraphs)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
def _send_group_event(room_group_name: str, event_type: str, content: dict): def _send_group_event(room_group_name: str, event_type: str, content: dict):
channel_layer = get_channel_layer() channel_layer = get_channel_layer()
async_to_sync(channel_layer.group_send)( async_to_sync(channel_layer.group_send)(
@ -68,11 +180,21 @@ def start_fine_tune_run_task(execution_id: str):
training_files = input_data.get("training_files") or [] training_files = input_data.get("training_files") or []
org_training_files = [] org_training_files = []
role_uuid = input_data.get("role_uuid")
if not training_files and agent.organization: if not training_files and agent.organization:
org_training_files = list(TrainingFile.objects.filter( training_files_qs = TrainingFile.objects.filter(
organization=agent.organization, role__organization=agent.organization,
is_processed=False is_processed=False
).select_related('uploaded_by')) ).select_related('uploaded_by', 'role')
if role_uuid:
try:
role = Role.objects.get(uuid=role_uuid, organization=agent.organization)
training_files_qs = training_files_qs.filter(role=role)
except Role.DoesNotExist:
logger.warning(f"Role {role_uuid} not found for organization {agent.organization.name}")
org_training_files = list(training_files_qs)
training_files = [tf.file.path for tf in org_training_files if tf.file] training_files = [tf.file.path for tf in org_training_files if tf.file]
logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}") logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")
@ -186,6 +308,197 @@ def start_fine_tune_run_task(execution_id: str):
return {"status": "error", "execution_id": execution_id, "error": str(e)} return {"status": "error", "execution_id": execution_id, "error": str(e)}
@shared_task
def ingest_training_file_task(training_file_uuid: str):
logger.info(f"Ingest task started for training_file_uuid={training_file_uuid}")
started_at = time.time()
try:
training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
except TrainingFile.DoesNotExist:
logger.error(f"Training file not found: {training_file_uuid}")
return {"status": "error", "error": "training_file_not_found"}
if training_file.is_processed:
logger.info(f"Training file already processed: {training_file_uuid}")
return {"status": "skipped", "reason": "already_processed"}
if not training_file.file:
logger.error(f"Training file has no file attached: {training_file_uuid}")
return {"status": "error", "error": "file_missing"}
try:
file_path = training_file.file.path
file_size = os.path.getsize(file_path) if os.path.exists(file_path) else 0
logger.info(
"Ingesting file: name=%s type=%s size_bytes=%s role=%s path=%s",
training_file.file_name,
training_file.file_type,
file_size,
training_file.role_id,
file_path,
)
except Exception as e:
logger.warning(f"Failed to stat training file for {training_file_uuid}: {str(e)}")
try:
training_file.status = 'ingesting'
training_file.save(update_fields=['status'])
extract_started = time.time()
text = _extract_text_from_file(training_file.file.path, training_file.file_type)
logger.info(
"Extracted text length=%s for training_file_uuid=%s in %.2fs mem=%s",
len(text),
training_file_uuid,
time.time() - extract_started,
_get_mem_info(),
)
chunk_started = time.time()
chunks = _chunk_text(text)
logger.info(
"Chunked text into %s chunks in %.2fs (sample lengths: %s) mem=%s",
len(chunks),
time.time() - chunk_started,
[len(c) for c in chunks[:5]],
_get_mem_info(),
)
if not chunks:
raise RuntimeError("No text extracted from file")
with transaction.atomic():
logger.info("Clearing existing RAG docs for training_file_uuid=%s mem=%s", training_file_uuid, _get_mem_info())
RoleRagDocument.objects.filter(training_file=training_file).delete()
logger.info("Preparing %s RAG docs for bulk_create mem=%s", len(chunks), _get_mem_info())
existing_hashes = set(
RoleRagDocument.objects.filter(role=training_file.role)
.values_list('content_hash', flat=True)
)
documents = []
skipped = 0
for index, chunk in enumerate(chunks):
content_hash = sha256(chunk.encode('utf-8')).hexdigest()
if content_hash in existing_hashes:
skipped += 1
continue
documents.append(
RoleRagDocument(
role=training_file.role,
training_file=training_file,
content=chunk,
embedding=None,
content_hash=content_hash,
metadata={
"file_name": training_file.file_name,
"file_type": training_file.file_type,
"chunk_size": len(chunk),
"source": "training_file",
},
chunk_index=index,
)
)
logger.info("Skipped %s duplicate chunks based on content_hash", skipped)
logger.info("Bulk creating RAG docs count=%s mem=%s", len(documents), _get_mem_info())
RoleRagDocument.objects.bulk_create(documents, batch_size=500)
training_file.status = 'chunked'
training_file.is_processed = True
training_file.save(update_fields=['status', 'is_processed'])
elapsed = time.time() - started_at
logger.info(
"Ingested training file %s into %s RAG chunks in %.2fs",
training_file_uuid,
len(chunks),
elapsed,
)
logger.info(f"Enqueueing embedding task for training_file_uuid={training_file_uuid}")
embed_training_file_task.delay(training_file_uuid)
return {"status": "completed", "chunks": len(chunks)}
except Exception as e:
elapsed = time.time() - started_at
logger.error(f"Failed to ingest training file {training_file_uuid}: {str(e)}", exc_info=True)
logger.error(f"Ingest task failed after {elapsed:.2f}s for training_file_uuid={training_file_uuid}")
try:
TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
except Exception:
pass
return {"status": "error", "error": str(e)}
@shared_task
def embed_training_file_task(training_file_uuid: str):
"""Generate embeddings for all documents of a training file.
This task is called after chunking to embed the document chunks
using the configured embedding provider (OpenAI, Google, or local).
"""
logger.info(f"Embedding task started for training_file_uuid={training_file_uuid}")
started_at = time.time()
try:
training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
except TrainingFile.DoesNotExist:
logger.error(f"Training file not found: {training_file_uuid}")
return {"status": "error", "error": "training_file_not_found"}
try:
documents = list(RoleRagDocument.objects.filter(training_file=training_file))
if not documents:
logger.warning(f"No RAG documents found for training_file_uuid={training_file_uuid}")
return {"status": "skipped", "reason": "no_documents"}
logger.info(
f"Starting to embed {len(documents)} documents for training_file_uuid={training_file_uuid} "
f"mem={_get_mem_info()}"
)
num_embedded, num_failed = services.batch_embed_documents(documents, batch_size=32)
if num_failed == 0:
training_file.status = 'embedded'
training_file.save(update_fields=['status'])
logger.info(f"Successfully embedded all documents for training_file_uuid={training_file_uuid}")
elif num_embedded > 0:
training_file.status = 'embedded'
training_file.save(update_fields=['status'])
logger.warning(
f"Partially embedded {num_embedded} documents, {num_failed} failed "
f"for training_file_uuid={training_file_uuid}"
)
else:
training_file.status = 'failed'
training_file.save(update_fields=['status'])
logger.error(f"Failed to embed any documents for training_file_uuid={training_file_uuid}")
return {"status": "error", "error": "embedding_failed", "num_failed": num_failed}
elapsed = time.time() - started_at
logger.info(
f"Embedding task completed for {training_file_uuid}: "
f"embedded={num_embedded}, failed={num_failed}, time={elapsed:.2f}s"
)
return {
"status": "completed",
"num_embedded": num_embedded,
"num_failed": num_failed,
"elapsed": elapsed,
}
except Exception as e:
elapsed = time.time() - started_at
logger.error(
f"Failed to embed training file {training_file_uuid}: {str(e)}",
exc_info=True
)
try:
TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
except Exception:
pass
return {"status": "error", "error": str(e), "elapsed": elapsed}
@shared_task @shared_task
def infer_run_task(execution_id: str): def infer_run_task(execution_id: str):
logger.info(f"Inference run task started for execution: {execution_id}") logger.info(f"Inference run task started for execution: {execution_id}")
@ -207,9 +520,70 @@ def infer_run_task(execution_id: str):
input_data = execution.input_data or {} input_data = execution.input_data or {}
prompt = input_data.get("prompt") or input_data.get("query") or "" prompt = input_data.get("prompt") or input_data.get("query") or ""
options = input_data.get("options") or {} options = dict(input_data.get("options") or {})
role_uuid = input_data.get("role_uuid") or options.get("role_uuid")
rag_top_k = int(input_data.get("rag_top_k", 5))
rag_similarity_threshold = float(input_data.get("rag_similarity_threshold", 0.5))
options.setdefault("temperature", 0.2)
options.setdefault("top_p", 0.9)
options.setdefault("max_tokens", 200)
options.setdefault("stop", ["\n\n", "References:", "Sources:"])
logger.info(f"Prompt length: {len(prompt)} characters") logger.info(f"Prompt length: {len(prompt)} characters")
if not role_uuid:
logger.warning(f"No role_uuid provided for inference run {execution_id}")
execution.status = "failed"
execution.error_message = "role_uuid_required"
execution.completed_at = timezone.now()
execution.save()
_update_agent_status(agent, "failed")
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": "role_uuid_required"})
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": "role_uuid_required"})
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_error",
"execution_id": str(execution.uuid),
"error_message": "role_uuid_required",
},
)
return {"status": "failed", "execution_id": execution_id, "error": "role_uuid_required"}
if role_uuid and prompt:
try:
context = services.get_context_for_query(
query=prompt,
role_uuid=str(role_uuid),
top_k=rag_top_k,
similarity_threshold=rag_similarity_threshold,
)
if context:
logger.info(f"RAG context retrieved for role={role_uuid} (top_k={rag_top_k})")
prompt = (
"You are a technical assistant.\n\n"
"Answer the question using ONLY the information in the context.\n"
"Do NOT:\n"
"- ask follow-up questions\n"
"- include hashtags\n"
"- include references or sources\n"
"- repeat the question\n"
"- add headings or sections\n"
"- add information not present in the context\n\n"
"Answer in 3-6 concise sentences.\n"
"If the context is insufficient, say: \"The context does not provide enough information.\"\n\n"
"Context:\n"
f"{context}\n\n"
"Question:\n"
f"{prompt}\n\n"
"Answer:"
)
else:
logger.info(f"No RAG context found for role={role_uuid}")
except Exception as e:
logger.warning(f"RAG context retrieval failed for role={role_uuid}: {e}")
if not prompt: if not prompt:
logger.warning(f"No prompt provided for inference run {execution_id}") logger.warning(f"No prompt provided for inference run {execution_id}")
execution.status = "failed" execution.status = "failed"

View file

View file

@ -0,0 +1,91 @@
from unittest.mock import patch
from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
from apps.orgs.models import Organization, Role
from apps.mlstore.models import AgentModel, Agent, AgentRun, AgentEvent, RoleRagDocument
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
User = get_user_model()
class MLStoreAPITests(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
self.other = User.objects.create_user(email_address='other@example.com', password='pass')
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
self.org = Organization.objects.create(name='Org', owner=self.manager)
self.role = Role.objects.create(name='Engineer', organization=self.org)
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
self.agent = Agent.objects.create(model=self.model, organization=self.org)
def test_agents_list_requires_auth(self):
view = AgentViewSet.as_view({'get': 'list'})
request = self.factory.get('/')
response = view(request)
self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)
def test_agents_list_authenticated(self):
view = AgentViewSet.as_view({'get': 'list'})
request = self.factory.get('/')
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_200_OK)
def test_agent_runs_scoped_to_user(self):
AgentRun.objects.create(agent=self.agent, user=self.user)
AgentRun.objects.create(agent=self.agent, user=self.other)
view = AgentRunViewSet.as_view({'get': 'list'})
request = self.factory.get('/')
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertEqual(len(response.data), 1)
def test_agent_run_events(self):
run = AgentRun.objects.create(agent=self.agent, user=self.user)
AgentEvent.objects.create(execution=run, event_type='message', content={'msg': 'hi'})
view = AgentRunViewSet.as_view({'get': 'events'})
request = self.factory.get('/')
force_authenticate(request, user=self.user)
response = view(request, uuid=str(run.uuid))
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertEqual(len(response.data), 1)
def test_retrieve_context_missing_params(self):
view = AgentRunViewSet.as_view({'post': 'retrieve_context'})
request = self.factory.post('/', {})
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
def test_retrieve_context_role_not_found(self):
view = AgentRunViewSet.as_view({'post': 'retrieve_context'})
request = self.factory.post('/', {'query': 'q', 'role_uuid': '00000000-0000-0000-0000-000000000000'})
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_404_NOT_FOUND)
@patch('apps.mlstore.viewsets.services.search_similar_documents')
@patch('apps.mlstore.viewsets.services.get_context_for_query')
def test_retrieve_context_success(self, mock_context, mock_search):
doc = RoleRagDocument.objects.create(
role=self.role,
content='chunk',
content_hash='hash',
chunk_index=0,
)
mock_search.return_value = [(doc, 0.9)]
mock_context.return_value = 'context text'
view = AgentRunViewSet.as_view({'post': 'retrieve_context'})
payload = {'query': 'hello', 'role_uuid': str(self.role.uuid)}
request = self.factory.post('/', payload, format='json')
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertEqual(response.data.get('num_results'), 1)
self.assertEqual(response.data.get('context'), 'context text')

View file

@ -0,0 +1,41 @@
from django.test import TestCase
from django.contrib.auth import get_user_model
from apps.orgs.models import Organization, Role
from apps.mlstore.models import AgentModel, Agent, AgentRun, AgentEvent, RoleRagDocument
User = get_user_model()
class MLStoreModelTests(TestCase):
def setUp(self):
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
self.org = Organization.objects.create(name='Org', owner=self.manager)
self.role = Role.objects.create(name='Engineer', organization=self.org)
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
self.agent = Agent.objects.create(model=self.model, organization=self.org)
def test_agent_model_str(self):
self.assertEqual(str(self.model), 'test-model')
def test_agent_str(self):
self.assertIn(self.model.name, str(self.agent))
def test_agent_run_str(self):
run = AgentRun.objects.create(agent=self.agent, user=self.user)
self.assertIn(str(run.uuid), str(run))
self.assertIn(str(self.agent), str(run))
def test_agent_event_str(self):
run = AgentRun.objects.create(agent=self.agent, user=self.user)
evt = AgentEvent.objects.create(execution=run, event_type='message', content={'msg': 'hi'})
self.assertIn('message', str(evt))
def test_role_rag_document_str(self):
doc = RoleRagDocument.objects.create(
role=self.role,
content='chunk',
content_hash='hash',
chunk_index=0,
)
self.assertIn(self.role.name, str(doc))

View file

@ -1,9 +1,13 @@
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from .models import Agent, AgentRun, AgentEvent
from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework import status
from .models import Agent, AgentRun, AgentEvent
from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer
from . import services
from apps.orgs.models import Role
class AgentViewSet(ModelViewSet): class AgentViewSet(ModelViewSet):
queryset = Agent.objects.all() queryset = Agent.objects.all()
@ -27,3 +31,109 @@ class AgentRunViewSet(ModelViewSet):
events = AgentEvent.objects.filter(execution=run).order_by('timestamp') events = AgentEvent.objects.filter(execution=run).order_by('timestamp')
serializer = AgentEventSerializer(events, many=True) serializer = AgentEventSerializer(events, many=True)
return Response(serializer.data) return Response(serializer.data)
@action(detail=False, methods=['post'], url_path='retrieve-context')
def retrieve_context(self, request):
"""Retrieve context documents from RAG using semantic similarity.
Request body:
{
"query": "search query text",
"role_uuid": "role-uuid",
"top_k": 5, # optional, default 5
"similarity_threshold": 0.5 # optional, default 0.5
}
Returns:
{
"query": "search query text",
"context": "formatted context string with sources",
"documents": [
{
"id": 123,
"content": "chunk text",
"similarity": 0.87,
"source": "filename.pdf",
"chunk_index": 0
},
...
]
}
"""
query = request.data.get('query', '').strip()
role_uuid = request.data.get('role_uuid', '').strip()
top_k = request.data.get('top_k', 5)
similarity_threshold = request.data.get('similarity_threshold', 0.5)
if not query:
return Response(
{"error": "query is required"},
status=status.HTTP_400_BAD_REQUEST
)
if not role_uuid:
return Response(
{"error": "role_uuid is required"},
status=status.HTTP_400_BAD_REQUEST
)
try:
# Validate role exists and user has access
role = Role.objects.get(uuid=role_uuid)
# You can add additional permission checks here if needed
# Search for similar documents
results = services.search_similar_documents(
query=query,
role_uuid=role_uuid,
top_k=top_k,
similarity_threshold=similarity_threshold
)
# Format response
documents = []
for doc, similarity in results:
documents.append({
"id": doc.id,
"uuid": str(doc.uuid),
"content": doc.content,
"similarity": float(similarity),
"source": doc.training_file.file_name if doc.training_file else "unknown",
"chunk_index": doc.chunk_index,
"metadata": doc.metadata,
})
# Get formatted context string
context = services.get_context_for_query(
query=query,
role_uuid=role_uuid,
top_k=top_k,
similarity_threshold=similarity_threshold
)
return Response({
"query": query,
"role_uuid": role_uuid,
"num_results": len(documents),
"context": context,
"documents": documents,
})
except Role.DoesNotExist:
return Response(
{"error": f"Role with UUID {role_uuid} not found"},
status=status.HTTP_404_NOT_FOUND
)
except ValueError as e:
return Response(
{"error": str(e)},
status=status.HTTP_400_BAD_REQUEST
)
except Exception as e:
import logging
logging.exception("Error retrieving context")
return Response(
{"error": "Failed to retrieve context", "detail": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

View file

47
apps/onboarding/admin.py Normal file
View file

@ -0,0 +1,47 @@
from django.contrib import admin
from django.contrib.admin import ModelAdmin, TabularInline
from .models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
class OnboardingPageInline(TabularInline):
model = OnboardingPage
extra = 0
class OnboardingFieldInline(TabularInline):
model = OnboardingField
extra = 0
@admin.register(OnboardingFlow)
class OnboardingFlowAdmin(ModelAdmin):
list_display = ('id', 'uuid', 'title', 'role', 'status')
search_fields = ('title', 'role__name')
list_filter = ('status',)
inlines = (OnboardingPageInline,)
readonly_fields = ('uuid',)
@admin.register(OnboardingPage)
class OnboardingPageAdmin(ModelAdmin):
list_display = ('id', 'uuid', 'title', 'flow', 'order')
search_fields = ('title', 'flow__title')
list_filter = ('flow',)
inlines = (OnboardingFieldInline,)
readonly_fields = ('uuid',)
@admin.register(OnboardingField)
class OnboardingFieldAdmin(ModelAdmin):
list_display = ('id', 'uuid', 'label', 'page', 'field_type', 'required')
search_fields = ('label', 'page__title')
list_filter = ('field_type',)
readonly_fields = ('uuid',)
@admin.register(OnboardingSession)
class OnboardingSessionAdmin(ModelAdmin):
list_display = ('id', 'uuid', 'flow', 'user', 'status', 'current_page_order')
search_fields = ('flow__title', 'user__email_address')
list_filter = ('status',)
readonly_fields = ('uuid',)

6
apps/onboarding/apps.py Normal file
View file

@ -0,0 +1,6 @@
from django.apps import AppConfig
class OnboardingConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'apps.onboarding'

View file

@ -0,0 +1,100 @@
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
initial = True
dependencies = [
('orgs', '0001_initial'),
('mlstore', '0001_initial'),
('users', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='OnboardingFlow',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
('title', models.CharField(max_length=255)),
('description', models.TextField(blank=True, default='')),
('status', models.CharField(choices=[('draft', 'Draft'), ('published', 'Published'), ('archived', 'Archived')], default='draft', max_length=20)),
('agent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_flows', to='mlstore.agent')),
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_flows', to='orgs.role')),
],
options={
'verbose_name': 'Onboarding Flow',
'verbose_name_plural': 'Onboarding Flows',
'ordering': ['-created_at'],
},
),
migrations.CreateModel(
name='OnboardingPage',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
('order', models.IntegerField(default=0)),
('title', models.CharField(max_length=255)),
('body', models.TextField(blank=True, default='')),
('meta', models.JSONField(blank=True, default=dict)),
('flow', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='pages', to='onboarding.onboardingflow')),
],
options={
'verbose_name': 'Onboarding Page',
'verbose_name_plural': 'Onboarding Pages',
'ordering': ['order', 'created_at'],
},
),
migrations.CreateModel(
name='OnboardingField',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
('order', models.IntegerField(default=0)),
('key', models.CharField(max_length=120)),
('label', models.CharField(max_length=255)),
('field_type', models.CharField(choices=[('text', 'Text'), ('textarea', 'Textarea'), ('select', 'Select'), ('multiselect', 'Multi Select'), ('number', 'Number'), ('boolean', 'Boolean'), ('date', 'Date')], default='text', max_length=30)),
('required', models.BooleanField(default=False)),
('help_text', models.TextField(blank=True, default='')),
('placeholder', models.CharField(blank=True, default='', max_length=255)),
('options', models.JSONField(blank=True, default=list)),
('default_value', models.JSONField(blank=True, null=True, default=None)),
('validation', models.JSONField(blank=True, default=dict)),
('page', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='fields', to='onboarding.onboardingpage')),
],
options={
'verbose_name': 'Onboarding Field',
'verbose_name_plural': 'Onboarding Fields',
'ordering': ['order', 'created_at'],
'unique_together': {('page', 'key')},
},
),
migrations.CreateModel(
name='OnboardingSession',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
('status', models.CharField(choices=[('in_progress', 'In Progress'), ('completed', 'Completed'), ('abandoned', 'Abandoned')], default='in_progress', max_length=20)),
('current_page_order', models.IntegerField(default=0)),
('responses', models.JSONField(blank=True, default=dict)),
('completed_at', models.DateTimeField(blank=True, null=True)),
('agent_run', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_sessions', to='mlstore.agentrun')),
('flow', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sessions', to='onboarding.onboardingflow')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_sessions', to='users.user')),
],
options={
'verbose_name': 'Onboarding Session',
'verbose_name_plural': 'Onboarding Sessions',
'ordering': ['-created_at'],
},
),
]

View file

121
apps/onboarding/models.py Normal file
View file

@ -0,0 +1,121 @@
from uuid import uuid4
from django.db.models import (
BigAutoField,
BooleanField,
CASCADE,
CharField,
DateTimeField,
ForeignKey,
IntegerField,
JSONField,
Model,
TextField,
UUIDField,
)
from apps.users.mixins import TimeStampMixin
from apps.users.models import User
from apps.orgs.models import Role
from apps.mlstore.models import Agent, AgentRun
class OnboardingFlow(TimeStampMixin, Model):
STATUS_CHOICES = [
('draft', 'Draft'),
('published', 'Published'),
('archived', 'Archived'),
]
id = BigAutoField(primary_key=True)
uuid = UUIDField(default=uuid4, editable=False, unique=True)
role = ForeignKey(Role, on_delete=CASCADE, related_name='onboarding_flows')
agent = ForeignKey(Agent, on_delete=CASCADE, related_name='onboarding_flows', null=True, blank=True)
title = CharField(max_length=255)
description = TextField(blank=True, default='')
status = CharField(max_length=20, choices=STATUS_CHOICES, default='draft')
class Meta:
verbose_name = 'Onboarding Flow'
verbose_name_plural = 'Onboarding Flows'
ordering = ['-created_at']
def __str__(self) -> str:
return f'{self.title} ({self.role.name})'
class OnboardingPage(TimeStampMixin, Model):
id = BigAutoField(primary_key=True)
uuid = UUIDField(default=uuid4, editable=False, unique=True)
flow = ForeignKey(OnboardingFlow, on_delete=CASCADE, related_name='pages')
order = IntegerField(default=0)
title = CharField(max_length=255)
body = TextField(blank=True, default='')
meta = JSONField(default=dict, blank=True)
class Meta:
verbose_name = 'Onboarding Page'
verbose_name_plural = 'Onboarding Pages'
ordering = ['order', 'created_at']
def __str__(self) -> str:
return f'{self.flow.title} - {self.title}'
class OnboardingField(TimeStampMixin, Model):
FIELD_TYPES = [
('text', 'Text'),
('textarea', 'Textarea'),
('select', 'Select'),
('multiselect', 'Multi Select'),
('number', 'Number'),
('boolean', 'Boolean'),
('date', 'Date'),
]
id = BigAutoField(primary_key=True)
uuid = UUIDField(default=uuid4, editable=False, unique=True)
page = ForeignKey(OnboardingPage, on_delete=CASCADE, related_name='fields')
order = IntegerField(default=0)
key = CharField(max_length=120)
label = CharField(max_length=255)
field_type = CharField(max_length=30, choices=FIELD_TYPES, default='text')
required = BooleanField(default=False)
help_text = TextField(blank=True, default='')
placeholder = CharField(max_length=255, blank=True, default='')
options = JSONField(default=list, blank=True)
default_value = JSONField(null=True, blank=True, default=None)
validation = JSONField(default=dict, blank=True)
class Meta:
verbose_name = 'Onboarding Field'
verbose_name_plural = 'Onboarding Fields'
ordering = ['order', 'created_at']
unique_together = [['page', 'key']]
def __str__(self) -> str:
return f'{self.page.title} - {self.label}'
class OnboardingSession(TimeStampMixin, Model):
STATUS_CHOICES = [
('in_progress', 'In Progress'),
('completed', 'Completed'),
('abandoned', 'Abandoned'),
]
id = BigAutoField(primary_key=True)
uuid = UUIDField(default=uuid4, editable=False, unique=True)
flow = ForeignKey(OnboardingFlow, on_delete=CASCADE, related_name='sessions')
user = ForeignKey(User, on_delete=CASCADE, related_name='onboarding_sessions')
agent_run = ForeignKey(AgentRun, on_delete=CASCADE, related_name='onboarding_sessions', null=True, blank=True)
status = CharField(max_length=20, choices=STATUS_CHOICES, default='in_progress')
current_page_order = IntegerField(default=0)
responses = JSONField(default=dict, blank=True)
completed_at = DateTimeField(null=True, blank=True)
class Meta:
verbose_name = 'Onboarding Session'
verbose_name_plural = 'Onboarding Sessions'
ordering = ['-created_at']
def __str__(self) -> str:
return f'{self.user.email_address} - {self.flow.title}'

View file

@ -0,0 +1,105 @@
from rest_framework import serializers
from .models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
from apps.orgs.models import Role
from apps.mlstore.models import Agent
class OnboardingFieldSerializer(serializers.ModelSerializer):
page = serializers.SlugRelatedField(slug_field='uuid', queryset=OnboardingPage.objects.all())
class Meta:
model = OnboardingField
fields = [
'id',
'uuid',
'page',
'order',
'key',
'label',
'field_type',
'required',
'help_text',
'placeholder',
'options',
'default_value',
'validation',
]
read_only_fields = ['id', 'uuid']
class OnboardingPageSerializer(serializers.ModelSerializer):
fields = OnboardingFieldSerializer(many=True, read_only=True)
flow = serializers.SlugRelatedField(slug_field='uuid', queryset=OnboardingFlow.objects.all())
class Meta:
model = OnboardingPage
fields = [
'id',
'uuid',
'flow',
'order',
'title',
'body',
'meta',
'fields',
]
read_only_fields = ['id', 'uuid']
class OnboardingFlowSerializer(serializers.ModelSerializer):
role = serializers.SlugRelatedField(slug_field='uuid', queryset=Role.objects.all())
agent = serializers.SlugRelatedField(slug_field='uuid', queryset=Agent.objects.all(), allow_null=True, required=False)
class Meta:
model = OnboardingFlow
fields = [
'id',
'uuid',
'role',
'agent',
'title',
'description',
'status',
'created_at',
'updated_at',
]
read_only_fields = ['id', 'uuid', 'created_at', 'updated_at']
class OnboardingFlowDetailSerializer(OnboardingFlowSerializer):
pages = OnboardingPageSerializer(many=True, read_only=True)
class Meta(OnboardingFlowSerializer.Meta):
fields = OnboardingFlowSerializer.Meta.fields + ['pages']
class OnboardingSessionSerializer(serializers.ModelSerializer):
flow = serializers.SlugRelatedField(slug_field='uuid', queryset=OnboardingFlow.objects.all())
user = serializers.SlugRelatedField(slug_field='uuid', read_only=True)
agent_run = serializers.SlugRelatedField(slug_field='uuid', read_only=True)
class Meta:
model = OnboardingSession
fields = [
'id',
'uuid',
'flow',
'user',
'agent_run',
'status',
'current_page_order',
'responses',
'created_at',
'updated_at',
'completed_at',
]
read_only_fields = ['id', 'uuid', 'user', 'agent_run', 'created_at', 'updated_at', 'completed_at']
class OnboardingSubmissionSerializer(serializers.Serializer):
page_uuid = serializers.CharField()
responses = serializers.DictField()
mark_complete = serializers.BooleanField(required=False, default=False)
class OnboardingFeedbackSerializer(serializers.Serializer):
page_uuid = serializers.CharField()
responses = serializers.DictField()
question = serializers.CharField(required=False, allow_blank=True, default='')

View file

View file

@ -0,0 +1,124 @@
from django.contrib.auth import get_user_model
from django.test import TestCase
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_403_FORBIDDEN
from apps.orgs.models import Organization, Role
from apps.mlstore.models import AgentModel, Agent
from apps.onboarding.models import OnboardingFlow, OnboardingPage, OnboardingSession
from apps.onboarding.viewsets import OnboardingFlowViewSet, OnboardingSessionViewSet
User = get_user_model()
class OnboardingAPITests(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
self.org = Organization.objects.create(name='Org', owner=self.manager)
self.role = Role.objects.create(name='Engineer', organization=self.org)
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
self.agent = Agent.objects.create(model=self.model, organization=self.org)
def test_create_flow(self):
view = OnboardingFlowViewSet.as_view({'post': 'create'})
data = {
'role': str(self.role.uuid),
'agent': str(self.agent.uuid),
'title': 'Flow',
'description': 'Desc',
'status': 'draft',
}
request = self.factory.post('/', data)
force_authenticate(request, user=self.manager)
response = view(request)
self.assertIn(response.status_code, (HTTP_200_OK, HTTP_201_CREATED))
self.assertTrue(OnboardingFlow.objects.filter(title='Flow').exists())
def test_pages_action(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
view = OnboardingFlowViewSet.as_view({'get': 'pages'})
request = self.factory.get('/')
force_authenticate(request, user=self.manager)
response = view(request, uuid=str(flow.uuid))
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertEqual(len(response.data.get('pages', [])), 1)
def test_create_session(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
view = OnboardingSessionViewSet.as_view({'post': 'create'})
request = self.factory.post('/', {'flow': str(flow.uuid)})
force_authenticate(request, user=self.user)
response = view(request)
self.assertIn(response.status_code, (HTTP_200_OK, HTTP_201_CREATED))
self.assertTrue(OnboardingSession.objects.filter(flow=flow, user=self.user).exists())
def test_submit_updates_session(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
page = OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
session = OnboardingSession.objects.create(flow=flow, user=self.user)
view = OnboardingSessionViewSet.as_view({'post': 'submit'})
payload = {'page_uuid': str(page.uuid), 'responses': {'q1': 'a'}, 'mark_complete': True}
request = self.factory.post('/', payload, format='json')
force_authenticate(request, user=self.user)
response = view(request, uuid=str(session.uuid))
self.assertEqual(response.status_code, HTTP_200_OK)
session.refresh_from_db()
self.assertEqual(session.status, 'completed')
self.assertIn(str(page.uuid), session.responses)
def test_publish_flow_as_manager(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
self.assertEqual(flow.status, 'draft')
view = OnboardingFlowViewSet.as_view({'post': 'publish'})
request = self.factory.post('/')
force_authenticate(request, user=self.manager)
response = view(request, uuid=str(flow.uuid))
self.assertEqual(response.status_code, HTTP_200_OK)
flow.refresh_from_db()
self.assertEqual(flow.status, 'published')
def test_publish_flow_requires_manager(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
view = OnboardingFlowViewSet.as_view({'post': 'publish'})
request = self.factory.post('/')
force_authenticate(request, user=self.user)
response = view(request, uuid=str(flow.uuid))
self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)
def test_get_or_create_session_creates_when_missing(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
view = OnboardingSessionViewSet.as_view({'post': 'get_or_create'})
request = self.factory.post('/', {'flow': str(flow.uuid)})
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertTrue(OnboardingSession.objects.filter(flow=flow, user=self.user).exists())
def test_get_or_create_session_reuses_active(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
existing = OnboardingSession.objects.create(flow=flow, user=self.user, current_page_order=1)
view = OnboardingSessionViewSet.as_view({'post': 'get_or_create'})
request = self.factory.post('/', {'flow': str(flow.uuid)})
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertEqual(response.data.get('uuid'), str(existing.uuid))
self.assertEqual(response.data.get('current_page_order'), 1)
def test_get_or_create_session_creates_after_completion(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
completed = OnboardingSession.objects.create(flow=flow, user=self.user, status='completed')
view = OnboardingSessionViewSet.as_view({'post': 'get_or_create'})
request = self.factory.post('/', {'flow': str(flow.uuid)})
force_authenticate(request, user=self.user)
response = view(request)
self.assertEqual(response.status_code, HTTP_200_OK)
self.assertNotEqual(response.data.get('uuid'), str(completed.uuid))

View file

@ -0,0 +1,41 @@
from django.test import TestCase
from django.contrib.auth import get_user_model
from apps.orgs.models import Organization, Role
from apps.mlstore.models import AgentModel, Agent
from apps.onboarding.models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
User = get_user_model()
class OnboardingModelTests(TestCase):
def setUp(self):
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
self.org = Organization.objects.create(name='Org', owner=self.manager)
self.role = Role.objects.create(name='Engineer', organization=self.org)
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
self.agent = Agent.objects.create(model=self.model, organization=self.org)
def test_flow_str(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Welcome', description='Intro')
self.assertIn('Welcome', str(flow))
self.assertIn(self.role.name, str(flow))
def test_page_and_field_str(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
page = OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
field = OnboardingField.objects.create(page=page, order=0, key='q1', label='Question 1')
self.assertIn(flow.title, str(page))
self.assertIn(field.label, str(field))
def test_field_unique_key_per_page(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
page = OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
OnboardingField.objects.create(page=page, order=0, key='dup', label='Dup 1')
with self.assertRaises(Exception):
OnboardingField.objects.create(page=page, order=1, key='dup', label='Dup 2')
def test_session_str(self):
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
session = OnboardingSession.objects.create(flow=flow, user=self.user)
self.assertIn(self.user.email_address, str(session))
self.assertIn(flow.title, str(session))

451
apps/onboarding/viewsets.py Normal file
View file

@ -0,0 +1,451 @@
import json
import logging
import re
import html
from typing import Any
from django.db import transaction
from django.utils import timezone
from rest_framework import status
from rest_framework.exceptions import PermissionDenied
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from apps.mlstore.models import AgentEvent, AgentRun
from apps.mlstore import services as ml_services
from .models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
from .serializers import (
OnboardingFlowSerializer,
OnboardingFlowDetailSerializer,
OnboardingPageSerializer,
OnboardingFieldSerializer,
OnboardingSessionSerializer,
OnboardingSubmissionSerializer,
OnboardingFeedbackSerializer,
)
logger = logging.getLogger(__name__)
def _extract_json(text: str) -> dict[str, Any]:
if not text:
return {}
try:
return json.loads(text)
except Exception:
pass
# Prefer fenced json blocks
fenced = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", text, re.IGNORECASE)
if fenced:
try:
return json.loads(fenced.group(1))
except Exception:
return {}
# Fallback: find first balanced JSON object
start = text.find('{')
if start == -1:
return {}
depth = 0
for idx in range(start, len(text)):
char = text[idx]
if char == '{':
depth += 1
elif char == '}':
depth -= 1
if depth == 0:
candidate = text[start:idx + 1]
try:
return json.loads(candidate)
except Exception:
return {}
return {}
def _strip_html(text: str) -> str:
if not text:
return ""
cleaned = re.sub(r"<[^>]+>", " ", text)
cleaned = html.unescape(cleaned)
return re.sub(r"\s+", " ", cleaned).strip()
def _send_agent_progress_event(agent_run: AgentRun, content: dict):
try:
AgentEvent.objects.create(
execution=agent_run,
event_type='progress',
content=content,
)
room_group_name = f"mlstore_agent_{agent_run.agent.uuid}"
async_to_sync(get_channel_layer().group_send)(
room_group_name,
{
"type": "mlstore_event",
"event_type": "progress",
"content": content,
"timestamp": timezone.now().isoformat(),
},
)
except Exception as e:
logger.warning("Failed to send progress event: %s", e)
class OnboardingFlowViewSet(ModelViewSet):
queryset = OnboardingFlow.objects.select_related('role', 'agent').all()
serializer_class = OnboardingFlowSerializer
lookup_field = 'uuid'
def get_queryset(self):
qs = super().get_queryset()
role_uuid = self.request.query_params.get('role')
status_filter = self.request.query_params.get('status')
if role_uuid:
qs = qs.filter(role__uuid=role_uuid)
if status_filter:
qs = qs.filter(status=status_filter)
return qs
def get_serializer_class(self):
if self.action in ('retrieve', 'pages'):
return OnboardingFlowDetailSerializer
return super().get_serializer_class()
@action(detail=True, methods=['get'])
def pages(self, request, pk=None, uuid=None):
flow = self.get_object()
serializer = OnboardingFlowDetailSerializer(flow, context={'request': request})
return Response(serializer.data)
@action(detail=True, methods=['post'])
def generate(self, request, pk=None, uuid=None):
flow = self.get_object()
if not request.user.is_authenticated or not getattr(request.user, 'is_manager', False):
return Response({"error": "permission_denied"}, status=status.HTTP_403_FORBIDDEN)
if not flow.agent or not flow.agent.model or not flow.agent.model.path:
return Response(
{"error": "flow_agent_model_required"},
status=status.HTTP_400_BAD_REQUEST,
)
instructions = request.data.get('instructions') or ''
rag_context = ""
try:
rag_context = ml_services.get_context_for_query(
query=f"Create onboarding content for role {flow.role.name}",
role_uuid=str(flow.role.uuid),
top_k=6,
similarity_threshold=0.35,
)
except Exception as e:
logger.warning("Onboarding generation RAG lookup failed: %s", e)
prompt = (
"You are creating onboarding content as JSON. "
"Return ONLY valid JSON (no prose, no markdown, no code fences).\n"
"Do not include explanations or examples.\n"
"Do not include HTML tags. Use plain text only.\n"
"Each page body must be 3-6 paragraphs, at least 320 words total, and include 1 short list of 3-5 bullets.\n"
"Before writing the body, create a brief outline of the key points to cover and include it in meta.outline.\n"
"The outline should be a short list of 3-6 bullets, not chain-of-thought.\n"
"Do NOT ask about the learner's personal experience. Ask about what someone in the role may encounter.\n"
"Do NOT use any select or multiselect fields. Use only text, textarea, number, boolean, or date.\n"
"Use the provided context for accurate, role-specific content.\n"
"If context is insufficient, make reasonable assumptions without inventing tools or policies.\n"
"JSON shape:\n"
"{\n"
" \"title\": string,\n"
" \"description\": string,\n"
" \"pages\": [\n"
" {\n"
" \"title\": string,\n"
" \"body\": string,\n"
" \"meta\": { \"outline\": [string] },\n"
" \"fields\": [\n"
" {\n"
" \"key\": string,\n"
" \"label\": string,\n"
" \"type\": one of [text, textarea, number, boolean, date],\n"
" \"required\": boolean,\n"
" \"help_text\": string,\n"
" \"placeholder\": string,\n"
" \"options\": []\n"
" }\n"
" ]\n"
" }\n"
" ]\n"
"}\n"
f"Role: {flow.role.name}\n"
f"Role description: {flow.role.description}\n"
f"Flow title: {flow.title}\n"
f"Flow description: {flow.description}\n"
f"Extra instructions: {instructions}\n"
f"Context:\n{rag_context}\n"
)
try:
result = ml_services.infer_with_model(flow.agent.model.path, prompt, {
"max_tokens": 1800,
"temperature": 0.2,
})
except Exception as e:
logger.error("Onboarding generate inference failed: %s", e, exc_info=True)
return Response({"error": "generation_failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
response_text = ''
if isinstance(result, dict):
response_text = result.get('response') or result.get('result') or ''
payload = _extract_json(str(response_text))
if not payload or 'pages' not in payload:
return Response({"error": "invalid_generation_output", "raw": response_text}, status=status.HTTP_400_BAD_REQUEST)
with transaction.atomic():
flow.title = payload.get('title') or flow.title
# Keep existing description on regenerate unless explicitly empty
if not flow.description:
flow.description = payload.get('description') or flow.description
if flow.status != 'draft':
flow.status = 'draft'
flow.save(update_fields=['title', 'description', 'status'])
OnboardingPage.objects.filter(flow=flow).delete()
pages = payload.get('pages') or []
for page_index, page in enumerate(pages):
body_text = _strip_html(page.get('body') or '')
page_obj = OnboardingPage.objects.create(
flow=flow,
order=page_index,
title=page.get('title') or f"Page {page_index + 1}",
body=body_text,
meta=page.get('meta') or {},
)
for field_index, field in enumerate(page.get('fields') or []):
field_type = field.get('type') or 'text'
if field_type not in {"text", "textarea", "number", "boolean", "date"}:
field_type = 'text'
OnboardingField.objects.create(
page=page_obj,
order=field_index,
key=field.get('key') or f"field_{field_index + 1}",
label=field.get('label') or f"Field {field_index + 1}",
field_type=field_type,
required=bool(field.get('required')),
help_text=field.get('help_text') or '',
placeholder=field.get('placeholder') or '',
options=[],
default_value=field.get('default_value') if field.get('default_value') is not None else None,
validation=field.get('validation') or {},
)
serializer = OnboardingFlowDetailSerializer(flow, context={'request': request})
return Response(serializer.data)
@action(detail=True, methods=['post'])
def publish(self, request, pk=None, uuid=None):
flow = self.get_object()
if not request.user.is_authenticated or not getattr(request.user, 'is_manager', False):
return Response({"error": "permission_denied"}, status=status.HTTP_403_FORBIDDEN)
if flow.status != 'published':
flow.status = 'published'
flow.save(update_fields=['status'])
serializer = OnboardingFlowDetailSerializer(flow, context={'request': request})
return Response(serializer.data)
class OnboardingPageViewSet(ModelViewSet):
queryset = OnboardingPage.objects.select_related('flow').prefetch_related('fields').all()
serializer_class = OnboardingPageSerializer
lookup_field = 'uuid'
class OnboardingFieldViewSet(ModelViewSet):
queryset = OnboardingField.objects.select_related('page').all()
serializer_class = OnboardingFieldSerializer
lookup_field = 'uuid'
class OnboardingSessionViewSet(ModelViewSet):
queryset = OnboardingSession.objects.select_related('flow', 'user', 'agent_run', 'flow__agent').all()
serializer_class = OnboardingSessionSerializer
lookup_field = 'uuid'
def get_queryset(self):
qs = super().get_queryset()
user = self.request.user
if user.is_authenticated and not getattr(user, 'is_manager', False):
qs = qs.filter(user=user)
return qs
def perform_create(self, serializer):
if not self.request.user or not self.request.user.is_authenticated:
raise PermissionDenied("Authentication required")
flow = serializer.validated_data.get('flow')
agent_run = None
if flow and flow.agent:
agent_run = AgentRun.objects.create(
agent=flow.agent,
user=self.request.user,
input_data={
"type": "onboarding_session",
"flow_uuid": str(flow.uuid),
"role_uuid": str(flow.role.uuid),
},
)
serializer.save(user=self.request.user, agent_run=agent_run)
@action(detail=False, methods=['post'])
def get_or_create(self, request):
if not request.user or not request.user.is_authenticated:
raise PermissionDenied("Authentication required")
flow_uuid = request.data.get('flow')
if not flow_uuid:
return Response({"error": "flow_required"}, status=status.HTTP_400_BAD_REQUEST)
try:
flow = OnboardingFlow.objects.get(uuid=flow_uuid)
except OnboardingFlow.DoesNotExist:
return Response({"error": "flow_not_found"}, status=status.HTTP_404_NOT_FOUND)
session = (
OnboardingSession.objects
.filter(flow=flow, user=request.user)
.exclude(status='completed')
.order_by('-updated_at')
.first()
)
if not session:
agent_run = None
if flow.agent:
agent_run = AgentRun.objects.create(
agent=flow.agent,
user=request.user,
input_data={
"type": "onboarding_session",
"flow_uuid": str(flow.uuid),
"role_uuid": str(flow.role.uuid),
},
)
session = OnboardingSession.objects.create(
flow=flow,
user=request.user,
agent_run=agent_run,
)
return Response(OnboardingSessionSerializer(session, context={'request': request}).data)
@action(detail=True, methods=['post'])
def submit(self, request, pk=None, uuid=None):
session = self.get_object()
serializer = OnboardingSubmissionSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
page_uuid = serializer.validated_data['page_uuid']
responses = serializer.validated_data['responses']
mark_complete = serializer.validated_data.get('mark_complete')
try:
page = OnboardingPage.objects.get(flow=session.flow, uuid=page_uuid)
except OnboardingPage.DoesNotExist:
return Response({"error": "page_not_found"}, status=status.HTTP_404_NOT_FOUND)
responses_payload = dict(session.responses or {})
responses_payload[str(page.uuid)] = responses
session.responses = responses_payload
session.current_page_order = page.order
if mark_complete or page.order >= session.flow.pages.count() - 1:
session.status = 'completed'
session.completed_at = timezone.now()
session.save(update_fields=['responses', 'current_page_order', 'status', 'completed_at'])
if session.agent_run:
progress_payload = {
"flow_uuid": str(session.flow.uuid),
"session_uuid": str(session.uuid),
"page_uuid": str(page.uuid),
"page_order": page.order,
"status": session.status,
"responses": responses,
}
_send_agent_progress_event(session.agent_run, progress_payload)
session.agent_run.output_data = {
**(session.agent_run.output_data or {}),
"onboarding": session.responses,
}
session.agent_run.save(update_fields=['output_data'])
return Response(OnboardingSessionSerializer(session, context={'request': request}).data)
@action(detail=True, methods=['post'])
def feedback(self, request, pk=None, uuid=None):
session = self.get_object()
serializer = OnboardingFeedbackSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
page_uuid = serializer.validated_data['page_uuid']
responses = serializer.validated_data['responses']
question = serializer.validated_data.get('question') or ''
try:
page = OnboardingPage.objects.get(flow=session.flow, uuid=page_uuid)
except OnboardingPage.DoesNotExist:
return Response({"error": "page_not_found"}, status=status.HTTP_404_NOT_FOUND)
if not session.flow.agent or not session.flow.agent.model or not session.flow.agent.model.path:
return Response({"error": "flow_agent_model_required"}, status=status.HTTP_400_BAD_REQUEST)
prompt = (
"You are an onboarding assessor. Provide concise feedback addressed directly to the learner using second-person \"You\" statements.\n"
"Return ONLY valid JSON (no prose, no markdown, no code fences).\n"
"JSON shape:\n"
"{\n"
" \"summary\": string\n"
"}\n\n"
f"Page title: {page.title}\n"
f"Page body: {page.body}\n"
f"Responses: {json.dumps(responses)}\n"
)
if question:
prompt += f"Learner question: {question}\n"
try:
result = ml_services.infer_with_model(session.flow.agent.model.path, prompt, {
"max_tokens": 900,
"temperature": 0.2,
})
except Exception as e:
logger.error("Onboarding feedback inference failed: %s", e, exc_info=True)
return Response({"error": "feedback_failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
feedback_text = ''
if isinstance(result, dict):
feedback_text = result.get('response') or result.get('result') or ''
feedback_text = str(feedback_text).strip()
feedback_payload = _extract_json(feedback_text)
if not feedback_payload:
feedback_payload = {
"summary": feedback_text or "Feedback generated.",
}
responses_payload = dict(session.responses or {})
feedback_store = dict(responses_payload.get("__feedback__") or {})
feedback_store[str(page.uuid)] = {
"feedback": feedback_payload,
"question": question,
"updated_at": timezone.now().isoformat(),
}
responses_payload["__feedback__"] = feedback_store
session.responses = responses_payload
session.save(update_fields=['responses'])
return Response({
"feedback": feedback_payload,
"session": OnboardingSessionSerializer(session, context={'request': request}).data,
})

View file

@ -54,8 +54,8 @@ class RoleMembershipAdmin(ModelAdmin):
@register(TrainingFile) @register(TrainingFile)
class TrainingFileAdmin(ModelAdmin): class TrainingFileAdmin(ModelAdmin):
list_display = ('id', 'uuid', 'file_name', 'organization', 'uploaded_by', 'is_processed', 'created_at') list_display = ('id', 'uuid', 'file_name', 'role', 'uploaded_by', 'status', 'is_processed', 'created_at')
search_fields = ('file_name', 'organization__name', 'uploaded_by__email_address') search_fields = ('file_name', 'role__name', 'uploaded_by__email_address')
list_filter = ('is_processed', 'created_at') list_filter = ('status', 'is_processed', 'created_at')
raw_id_fields = ('organization', 'uploaded_by') raw_id_fields = ('role', 'uploaded_by')
readonly_fields = ('uuid', 'created_at', 'updated_at') readonly_fields = ('uuid', 'created_at', 'updated_at')

View file

@ -1,11 +1,8 @@
# Generated by Django 5.2.10 on 2026-01-25 11:02
import django.db.models.deletion import django.db.models.deletion
import uuid import uuid
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
initial = True initial = True
@ -118,8 +115,9 @@ class Migration(migrations.Migration):
('file_size', models.IntegerField()), ('file_size', models.IntegerField()),
('file_type', models.CharField(max_length=50)), ('file_type', models.CharField(max_length=50)),
('description', models.TextField(blank=True, default='')), ('description', models.TextField(blank=True, default='')),
('status', models.CharField(choices=[('ingesting', 'Ingesting'), ('chunked', 'Chunked'), ('embedded', 'Embedded'), ('failed', 'Failed')], default='ingesting', max_length=20)),
('is_processed', models.BooleanField(default=False)), ('is_processed', models.BooleanField(default=False)),
('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='orgs.organization')), ('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='orgs.role')),
('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)), ('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)),
], ],
options={ options={

View file

@ -3,7 +3,8 @@ from uuid import uuid4
from django.utils import timezone from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.db.models import BigAutoField, BooleanField, CASCADE, CharField, DateTimeField, ForeignKey, ManyToManyField, Model, TextField, UUIDField, IntegerField, FileField from django.db.models import BigAutoField, BooleanField, CASCADE, CharField, DateTimeField, ForeignKey, ManyToManyField, Model, TextField, UUIDField, IntegerField, FileField
from django.db.models.signals import post_delete from django.db.models.signals import post_delete, post_save
from django.db import transaction
from django.dispatch import receiver from django.dispatch import receiver
from apps.users.mixins import TimeStampMixin from apps.users.mixins import TimeStampMixin
from apps.users.models import User from apps.users.models import User
@ -104,9 +105,16 @@ class TrainingFile(TimeStampMixin, Model):
ALLOWED_EXTENSIONS = ('txt', 'pdf', 'md', 'csv', 'json', 'docx', 'doc') ALLOWED_EXTENSIONS = ('txt', 'pdf', 'md', 'csv', 'json', 'docx', 'doc')
STATUS_CHOICES = [
('ingesting', 'Ingesting'),
('chunked', 'Chunked'),
('embedded', 'Embedded'),
('failed', 'Failed'),
]
id = BigAutoField(primary_key = True) id = BigAutoField(primary_key = True)
uuid = UUIDField(default = uuid4, unique = True, editable = False) uuid = UUIDField(default = uuid4, unique = True, editable = False)
organization = ForeignKey(Organization, on_delete = CASCADE, related_name = "training_files") role = ForeignKey(Role, on_delete = CASCADE, related_name = "training_files")
uploaded_by = ForeignKey(User, on_delete = CASCADE, related_name = "uploaded_training_files") uploaded_by = ForeignKey(User, on_delete = CASCADE, related_name = "uploaded_training_files")
file = FileField(upload_to = 'training_files/%Y/%m/%d/') file = FileField(upload_to = 'training_files/%Y/%m/%d/')
@ -115,6 +123,7 @@ class TrainingFile(TimeStampMixin, Model):
file_type = CharField(max_length = 50) file_type = CharField(max_length = 50)
description = TextField(blank = True, default = '') description = TextField(blank = True, default = '')
status = CharField(max_length = 20, choices = STATUS_CHOICES, default = 'ingesting')
is_processed = BooleanField(default = False) is_processed = BooleanField(default = False)
class Meta: class Meta:
@ -123,7 +132,7 @@ class TrainingFile(TimeStampMixin, Model):
ordering = ['-created_at'] ordering = ['-created_at']
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.file_name} - {self.organization.name}" return f"{self.file_name} - {self.role.name}"
@receiver(post_delete, sender=TrainingFile) @receiver(post_delete, sender=TrainingFile)
@ -135,3 +144,15 @@ def delete_training_file_on_delete(sender, instance, **kwargs):
os.remove(instance.file.path) os.remove(instance.file.path)
except Exception: except Exception:
pass pass
@receiver(post_save, sender=TrainingFile)
def enqueue_training_file_ingestion(sender, instance, created, **kwargs):
if not created:
return
def _enqueue():
from apps.mlstore.tasks import ingest_training_file_task
ingest_training_file_task.delay(str(instance.uuid))
transaction.on_commit(_enqueue)

View file

@ -1,4 +1,5 @@
from rest_framework.serializers import ModelSerializer, SerializerMethodField, IntegerField from rest_framework.serializers import ModelSerializer, SerializerMethodField, IntegerField, UUIDField
from rest_framework.exceptions import ValidationError
from apps.orgs.models import Organization, OrganizationMembership, OrganizationInvitation, Role, RoleMembership, TrainingFile from apps.orgs.models import Organization, OrganizationMembership, OrganizationInvitation, Role, RoleMembership, TrainingFile
from apps.users.serializers import UserSerializer from apps.users.serializers import UserSerializer
@ -75,11 +76,13 @@ class RoleSerializer(ModelSerializer):
class TrainingFileSerializer(ModelSerializer): class TrainingFileSerializer(ModelSerializer):
uploaded_by = UserSerializer(read_only = True) uploaded_by = UserSerializer(read_only = True)
file_url = SerializerMethodField() file_url = SerializerMethodField()
role = RoleSerializer(read_only = True)
role_uuid = UUIDField(write_only = True, required = True)
class Meta: class Meta:
model = TrainingFile model = TrainingFile
fields = ['id', 'uuid', 'organization', 'uploaded_by', 'file', 'file_name', 'file_size', 'file_type', 'description', 'is_processed', 'file_url', 'created_at', 'updated_at'] fields = ['id', 'uuid', 'role', 'role_uuid', 'uploaded_by', 'file', 'file_name', 'file_size', 'file_type', 'description', 'status', 'is_processed', 'file_url', 'created_at', 'updated_at']
read_only_fields = ['uuid', 'uploaded_by', 'file_size', 'file_type', 'is_processed', 'created_at', 'updated_at', 'organization'] read_only_fields = ['uuid', 'uploaded_by', 'file_size', 'file_type', 'status', 'is_processed', 'created_at', 'updated_at', 'role']
def get_file_url(self, obj): def get_file_url(self, obj):
request = self.context.get('request') request = self.context.get('request')
@ -88,7 +91,6 @@ class TrainingFileSerializer(ModelSerializer):
return None return None
def validate_file(self, value): def validate_file(self, value):
"""Validate that uploaded file is a text-based file."""
if not value: if not value:
raise ValueError('File is required') raise ValueError('File is required')
@ -108,10 +110,18 @@ class TrainingFileSerializer(ModelSerializer):
return value return value
def create(self, validated_data): def create(self, validated_data):
role_uuid = validated_data.pop('role_uuid', None)
file_obj = validated_data.get('file') file_obj = validated_data.get('file')
if file_obj: if file_obj:
validated_data['file_size'] = file_obj.size validated_data['file_size'] = file_obj.size
import os import os
file_extension = os.path.splitext(file_obj.name)[1][1:].lower() file_extension = os.path.splitext(file_obj.name)[1][1:].lower()
validated_data['file_type'] = file_extension validated_data['file_type'] = file_extension
if not role_uuid:
raise ValidationError({'role_uuid': 'Role is required'})
try:
role = Role.objects.get(uuid = role_uuid)
except Role.DoesNotExist:
raise ValidationError({'role_uuid': 'Role not found'})
validated_data['role'] = role
return super().create(validated_data) return super().create(validated_data)

View file

@ -166,6 +166,12 @@ class OrganizationViewSet(ModelViewSet):
serializer = RoleSerializer(role) serializer = RoleSerializer(role)
return Response(serializer.data) return Response(serializer.data)
@action(detail=False, methods=['get'], url_path='role/mine')
def my_roles(self, request):
roles = Role.objects.filter(memberships__user=request.user).distinct()
serializer = RoleSerializer(roles, many=True)
return Response(serializer.data)
@action(detail=True, methods=['post'], url_path='role/(?P<role_uuid>[0-9a-f-]{36})/delete') @action(detail=True, methods=['post'], url_path='role/(?P<role_uuid>[0-9a-f-]{36})/delete')
def delete_role(self, request, uuid = None, role_uuid = None): def delete_role(self, request, uuid = None, role_uuid = None):
if not request.user.is_manager: if not request.user.is_manager:
@ -196,8 +202,11 @@ class OrganizationViewSet(ModelViewSet):
organization = self.get_object() organization = self.get_object()
if request.method == 'GET': if request.method == 'GET':
training_files = TrainingFile.objects.filter(organization=organization) role_uuid = request.query_params.get('role_uuid')
serializer = TrainingFileSerializer(training_files, many=True) training_files = TrainingFile.objects.filter(role__organization=organization)
if role_uuid:
training_files = training_files.filter(role__uuid=role_uuid, role__organization=organization)
serializer = TrainingFileSerializer(training_files, many=True, context={'request': request})
return Response(serializer.data) return Response(serializer.data)
if not (organization.owner == request.user or if not (organization.owner == request.user or
@ -207,9 +216,17 @@ class OrganizationViewSet(ModelViewSet):
status=HTTP_403_FORBIDDEN status=HTTP_403_FORBIDDEN
) )
serializer = TrainingFileSerializer(data=request.data) role_uuid = request.data.get('role_uuid')
if not role_uuid:
return Response({'error': 'role_uuid is required'}, status=HTTP_400_BAD_REQUEST)
try:
Role.objects.get(uuid=role_uuid, organization=organization)
except Role.DoesNotExist:
return Response({'error': 'Role not found in this organization'}, status=HTTP_404_NOT_FOUND)
serializer = TrainingFileSerializer(data=request.data, context={'request': request})
if serializer.is_valid(): if serializer.is_valid():
serializer.save(uploaded_by=request.user, organization=organization) serializer.save(uploaded_by=request.user)
return Response(serializer.data, status=201) return Response(serializer.data, status=201)
return Response(serializer.errors, status=HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=HTTP_400_BAD_REQUEST)
@ -218,17 +235,17 @@ class OrganizationViewSet(ModelViewSet):
organization = self.get_object() organization = self.get_object()
try: try:
training_file = TrainingFile.objects.get(uuid=file_uuid, organization=organization) training_file = TrainingFile.objects.get(uuid=file_uuid, role__organization=organization)
except TrainingFile.DoesNotExist: except TrainingFile.DoesNotExist:
return Response({'error': 'Training file not found'}, status=HTTP_404_NOT_FOUND) return Response({'error': 'Training file not found'}, status=HTTP_404_NOT_FOUND)
if request.method == 'GET': if request.method == 'GET':
serializer = TrainingFileSerializer(training_file) serializer = TrainingFileSerializer(training_file, context={'request': request})
return Response(serializer.data) return Response(serializer.data)
if not (training_file.uploaded_by == request.user or if not (training_file.uploaded_by == request.user or
training_file.organization.owner == request.user or training_file.role.organization.owner == request.user or
request.user.is_manager): request.user.is_manager):
return Response( return Response(
{'error': 'You do not have permission to delete this file'}, {'error': 'You do not have permission to delete this file'},
status=HTTP_403_FORBIDDEN status=HTTP_403_FORBIDDEN

View file

@ -3,11 +3,16 @@ from rest_framework.routers import DefaultRouter
from apps.orgs.viewsets import OrganizationViewSet from apps.orgs.viewsets import OrganizationViewSet
from apps.users.viewsets import UserViewSet from apps.users.viewsets import UserViewSet
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
from apps.onboarding.viewsets import OnboardingFlowViewSet, OnboardingPageViewSet, OnboardingFieldViewSet, OnboardingSessionViewSet
router = DefaultRouter() router = DefaultRouter()
router.register(r'user', UserViewSet, basename='user') router.register(r'user', UserViewSet, basename='user')
router.register(r'organization', OrganizationViewSet, basename='organization') router.register(r'organization', OrganizationViewSet, basename='organization')
router.register(r'agent', AgentViewSet, basename='agent') router.register(r'agent', AgentViewSet, basename='agent')
router.register(r'agent-run', AgentRunViewSet, basename='agent-run') router.register(r'agent-run', AgentRunViewSet, basename='agent-run')
router.register(r'onboarding/flow', OnboardingFlowViewSet, basename='onboarding-flow')
router.register(r'onboarding/page', OnboardingPageViewSet, basename='onboarding-page')
router.register(r'onboarding/field', OnboardingFieldViewSet, basename='onboarding-field')
router.register(r'onboarding/session', OnboardingSessionViewSet, basename='onboarding-session')
urlpatterns = router.urls urlpatterns = router.urls

View file

@ -71,6 +71,7 @@ LOCAL_APPS = [
'apps.users', 'apps.users',
'apps.orgs', 'apps.orgs',
'apps.mlstore', 'apps.mlstore',
'apps.onboarding',
] ]
INSTALLED_APPS = OVERRIDE_APPS + DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS INSTALLED_APPS = OVERRIDE_APPS + DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS

View file

@ -11,34 +11,36 @@ from aiohttp import web
from mcp.server import Server from mcp.server import Server
from mcp.types import Tool, TextContent from mcp.types import Tool, TextContent
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stderr),
logging.StreamHandler(sys.stdout),
]
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
model_cache_dir = os.path.join(project_root, "model", "base-model") model_cache_dir = os.path.join(project_root, "model", "base-model")
os.makedirs(model_cache_dir, exist_ok=True)
os.environ["HF_HOME"] = model_cache_dir def _init_runtime():
logger.info(f"Project root: {project_root}") if not logging.getLogger().handlers:
logger.info(f"HuggingFace model cache directory set to: {model_cache_dir}") logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stderr),
logging.StreamHandler(sys.stdout),
]
)
os.makedirs(model_cache_dir, exist_ok=True)
os.environ["HF_HOME"] = model_cache_dir
logger.info(f"Project root: {project_root}")
logger.info(f"HuggingFace model cache directory set to: {model_cache_dir}")
app = Server("mlstore-mcp-server") app = Server("mlstore-mcp-server")
logger.info("MCP Server initialized: mlstore-mcp-server")
LOADED_MODELS: Dict[str, Dict[str, Any]] = {} LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
PAIR_EXTRACTOR: Dict[str, Any] = {} PAIR_EXTRACTOR: Dict[str, Any] = {}
EMBEDDING_MODEL: Dict[str, Any] = {}
BASE_MODEL_CACHE_DIR = model_cache_dir BASE_MODEL_CACHE_DIR = model_cache_dir
@app.list_tools() @app.list_tools()
async def list_tools(): async def list_tools():
logger.info("Listing available tools") logger.info("Listing available tools")
@ -93,6 +95,18 @@ async def list_tools():
"required": ["model_path", "prompt"] "required": ["model_path", "prompt"]
}, },
), ),
Tool(
name="embed",
description="Generate embeddings for a list of texts",
inputSchema={
"type": "object",
"properties": {
"texts": {"type": "array", "items": {"type": "string"}},
"model": {"type": "string"}
},
"required": ["texts"]
},
),
] ]
logger.info(f"Available tools: {[t.name for t in tools]}") logger.info(f"Available tools: {[t.name for t in tools]}")
return tools return tools
@ -110,6 +124,112 @@ def _safe_dir_name(name: str) -> str:
return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".") return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".")
def _map_gguf_repo(model_name: str) -> tuple[str | None, str | None]:
gguf_repos = {
"Llama-3.1-8B-Instruct.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"),
"Meta-Llama-3.1-8B-Instruct.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"),
"Meta-Llama-3.1-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_0.gguf"),
"Llama-3.1-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_0.gguf"),
"Meta-Llama-3-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3-8B-Instruct-GGUF", "Meta-Llama-3-8B-Instruct-Q4_0.gguf"),
"Llama-3-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3-8B-Instruct-GGUF", "Meta-Llama-3-8B-Instruct-Q4_0.gguf"),
"mistral-7b-instruct-v0.3.Q4_0.gguf": ("bartowski/Mistral-7B-Instruct-v0.3-GGUF", "Mistral-7B-Instruct-v0.3-Q4_0.gguf"),
"Mistral-7B-Instruct-v0.3.Q4_0.gguf": ("bartowski/Mistral-7B-Instruct-v0.3-GGUF", "Mistral-7B-Instruct-v0.3-Q4_0.gguf"),
}
if model_name in gguf_repos:
return gguf_repos[model_name]
base_name = model_name.lower()
if "llama" in base_name and "3.1" in base_name and "8b" in base_name:
repo_id = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
if ".q4_0" in base_name:
return repo_id, "Meta-Llama-3.1-8B-Instruct-Q4_0.gguf"
if ".q4_k_m" in base_name:
return repo_id, "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"
return repo_id, "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"
if "llama" in base_name and "3" in base_name and "8b" in base_name:
repo_id = "bartowski/Meta-Llama-3-8B-Instruct-GGUF"
if ".q4_0" in base_name:
return repo_id, "Meta-Llama-3-8B-Instruct-Q4_0.gguf"
return repo_id, "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
if "mistral" in base_name and "7b" in base_name:
repo_id = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
if ".q4_0" in base_name:
return repo_id, "Mistral-7B-Instruct-v0.3-Q4_0.gguf"
return repo_id, "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
return None, None
def _find_existing_gguf(model_name: str) -> str | None:
repo_id, filename = _map_gguf_repo(model_name)
if not filename:
return None
candidate_paths = [
os.path.join(_model_root(), filename),
os.path.join(model_cache_dir, filename),
os.path.join(os.getcwd(), "model", filename),
os.path.join(os.getcwd(), filename),
]
for path in candidate_paths:
if os.path.exists(path):
return path
return None
def _download_gguf_from_hf(model_name: str) -> str:
"""
Download a GGUF model from Hugging Face Hub.
Returns the path to the downloaded model file.
"""
logger.info(f"Attempting to download GGUF model from Hugging Face: {model_name}")
existing_path = _find_existing_gguf(model_name)
if existing_path:
logger.info(f"Found existing GGUF locally: {existing_path}")
return existing_path
try:
from huggingface_hub import hf_hub_download, list_repo_files
model_dir = _model_root()
os.makedirs(model_dir, exist_ok=True)
repo_id, filename = _map_gguf_repo(model_name)
if repo_id and filename:
logger.info(f"Found known model mapping: {repo_id}/{filename}")
if not repo_id or not filename:
logger.error(f"Could not determine Hugging Face repo for model: {model_name}")
raise ValueError(f"Unknown model: {model_name}. Please specify a known GGUF model.")
logger.info(f"Downloading {filename} from {repo_id}...")
logger.info(f"This may take several minutes depending on model size and connection speed.")
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
cache_dir=model_cache_dir,
local_dir=model_dir,
local_dir_use_symlinks=False,
)
logger.info(f"Model downloaded successfully to: {downloaded_path}")
return downloaded_path
except ImportError as e:
logger.error(f"huggingface_hub not available: {str(e)}")
raise ImportError("huggingface_hub is required to download models. Install with: pip install huggingface_hub")
except Exception as e:
logger.error(f"Failed to download model from Hugging Face: {str(e)}", exc_info=True)
raise
def _resolve_model_path(model_path: str) -> str: def _resolve_model_path(model_path: str) -> str:
if not model_path: if not model_path:
return model_path return model_path
@ -298,14 +418,16 @@ def _prompt_based_pair_extraction(training_data: List[Any], base_model: str) ->
system_prompt = ( system_prompt = (
"You are a data extractor. Given a list of items, return a JSON array of training pairs. " "You are a data extractor. Given a list of items, return a JSON array of training pairs. "
"Each pair must have 'instruction' and 'response'. Keep answers concise. " "Each pair must have 'instruction' and 'response'. Keep answers concise. "
"If content is incomplete, still produce best-effort pairs." "If content is incomplete, still produce best-effort pairs. "
"End your answer with a complete sentence. Do not start lists or new sections."
) )
user_prompt = ( user_prompt = (
"Examples of desired output:\n" "Examples of desired output:\n"
f"{json.dumps(example_pairs, ensure_ascii=False, indent=2)}\n\n" f"{json.dumps(example_pairs, ensure_ascii=False, indent=2)}\n\n"
"Now extract training pairs from the following items. Return ONLY a JSON array, no extra text.\n" "Now extract training pairs from the following items. Return ONLY a JSON array, no extra text.\n"
f"Items:\n{data_block}" f"Items:\n{data_block}\n\n"
"Answer:"
) )
prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
@ -385,7 +507,7 @@ def _extract_training_pairs(training_data: List[Any]) -> List[Tuple[str, str]]:
if not training_data: if not training_data:
return [] return []
base_model = "meta-llama/Meta-Llama-3-8B-Instruct" base_model = "meta-llama/Llama-3.1-8B-Instruct"
pairs = _prompt_based_pair_extraction(training_data, base_model) pairs = _prompt_based_pair_extraction(training_data, base_model)
if not pairs: if not pairs:
@ -402,6 +524,19 @@ def _extract_training_pairs(training_data: List[Any]) -> List[Tuple[str, str]]:
return pairs return pairs
def _ensure_embedding_model(model_name: str):
if EMBEDDING_MODEL.get("model") is not None and EMBEDDING_MODEL.get("name") == model_name:
return EMBEDDING_MODEL["model"]
from sentence_transformers import SentenceTransformer
logger.info(f"Loading embedding model: {model_name}")
model = SentenceTransformer(model_name, cache_folder=model_cache_dir)
EMBEDDING_MODEL["model"] = model
EMBEDDING_MODEL["name"] = model_name
return model
async def _fine_tune_model_impl( async def _fine_tune_model_impl(
training_files: List[str], training_files: List[str],
hyperparams: Dict[str, Any], hyperparams: Dict[str, Any],
@ -409,7 +544,7 @@ async def _fine_tune_model_impl(
version: str, version: str,
output_dir: str output_dir: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
base_model = "meta-llama/Meta-Llama-3-8B-Instruct" base_model = "meta-llama/Llama-3.1-8B-Instruct"
logger.info(f"Starting fine-tune process with base model: {base_model}") logger.info(f"Starting fine-tune process with base model: {base_model}")
try: try:
@ -793,14 +928,47 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
logger.error("model_path_required error: no model path provided") logger.error("model_path_required error: no model path provided")
return {"status": "failed", "error": "model_path_required", "timestamp": _now()} return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
original_model_path = model_path
model_path = _resolve_model_path(model_path) model_path = _resolve_model_path(model_path)
logger.debug(f"Resolved model path: {model_path}") logger.debug(f"Resolved model path: {model_path}")
if not os.path.exists(model_path): if not os.path.exists(model_path):
logger.error(f"Model not found at: {model_path}") logger.warning(f"Model not found at: {model_path}")
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()} local_mapped = _find_existing_gguf(original_model_path)
if local_mapped:
logger.info(f"Using existing mapped GGUF: {local_mapped}")
model_path = local_mapped
else:
logger.info(f"Attempting to download model from Hugging Face...")
try:
model_path = _download_gguf_from_hf(original_model_path)
logger.info(f"Model downloaded successfully: {model_path}")
except Exception as download_error:
logger.error(f"Failed to download model: {str(download_error)}")
return {
"status": "failed",
"error": "model_not_found_and_download_failed",
"model_path": original_model_path,
"download_error": str(download_error),
"timestamp": _now()
}
try: try:
for loaded_path in list(LOADED_MODELS.keys()):
if loaded_path != model_path:
logger.info(f"Unloading cached model: {loaded_path}")
LOADED_MODELS.pop(loaded_path, None)
if model_path in LOADED_MODELS and "model" in LOADED_MODELS[model_path]:
logger.info(f"Model already loaded: {model_path}")
return {
"status": "completed",
"model_path": model_path,
"loaded": True,
"cached": True,
"timestamp": _now(),
}
from gpt4all import GPT4All from gpt4all import GPT4All
model_dir, model_file = _resolve_model_file(model_path) model_dir, model_file = _resolve_model_file(model_path)
@ -866,16 +1034,38 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
logger.error("model_path_required error: no model path provided") logger.error("model_path_required error: no model path provided")
return {"status": "failed", "error": "model_path_required", "timestamp": _now()} return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
original_model_path = model_path
model_path = _resolve_model_path(model_path) model_path = _resolve_model_path(model_path)
logger.debug(f"Resolved model path: {model_path}") logger.debug(f"Resolved model path: {model_path}")
if not os.path.exists(model_path): if not os.path.exists(model_path):
logger.error(f"Model not found at: {model_path}") logger.warning(f"Model not found at: {model_path}")
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()} local_mapped = _find_existing_gguf(original_model_path)
if local_mapped:
logger.info(f"Using existing mapped GGUF: {local_mapped}")
model_path = local_mapped
else:
logger.info(f"Attempting to download model from Hugging Face...")
try:
model_path = _download_gguf_from_hf(original_model_path)
logger.info(f"Model downloaded successfully: {model_path}")
except Exception as download_error:
logger.error(f"Failed to download model: {str(download_error)}")
return {
"status": "failed",
"error": "model_not_found_and_download_failed",
"model_path": original_model_path,
"download_error": str(download_error),
"timestamp": _now()
}
try: try:
if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]: if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
logger.info(f"Model not in memory, loading: {model_path}") logger.info(f"Model not in memory, loading: {model_path}")
for loaded_path in list(LOADED_MODELS.keys()):
if loaded_path != model_path:
logger.info(f"Unloading cached model: {loaded_path}")
LOADED_MODELS.pop(loaded_path, None)
from gpt4all import GPT4All from gpt4all import GPT4All
model_dir, model_file = _resolve_model_file(model_path) model_dir, model_file = _resolve_model_file(model_path)
@ -952,6 +1142,34 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
"timestamp": _now(), "timestamp": _now(),
} }
if name == "embed":
texts = arguments.get("texts") or []
model_name = arguments.get("model") or "all-MiniLM-L6-v2"
if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
return {"status": "failed", "error": "texts_must_be_list_of_strings", "timestamp": _now()}
if not texts:
return {"status": "completed", "embeddings": [], "timestamp": _now()}
try:
model = _ensure_embedding_model(model_name)
embeddings = model.encode(texts).tolist()
return {
"status": "completed",
"embeddings": embeddings,
"model": model_name,
"timestamp": _now(),
}
except Exception as e:
logger.error(f"Embedding failed: {str(e)}", exc_info=True)
return {
"status": "failed",
"error": str(e),
"error_type": type(e).__name__,
"timestamp": _now(),
}
raise ValueError(f"Unknown tool: {name}") raise ValueError(f"Unknown tool: {name}")
@ -1009,6 +1227,7 @@ async def handle_health(request: web.Request) -> web.Response:
async def run_http_server(): async def run_http_server():
_init_runtime()
host = os.getenv("MCP_HTTP_HOST", "0.0.0.0") host = os.getenv("MCP_HTTP_HOST", "0.0.0.0")
port = int(os.getenv("MCP_HTTP_PORT", "8001")) port = int(os.getenv("MCP_HTTP_PORT", "8001"))
logger.info(f"Starting HTTP server on {host}:{port}") logger.info(f"Starting HTTP server on {host}:{port}")
@ -1028,6 +1247,7 @@ async def run_http_server():
if __name__ == "__main__": if __name__ == "__main__":
_init_runtime()
logger.info("Starting MCP Server...") logger.info("Starting MCP Server...")
try: try:
asyncio.run(run_http_server()) asyncio.run(run_http_server())

View file

@ -5,9 +5,7 @@ import {
HomeOutlined, HomeOutlined,
InfoCircleOutlined, InfoCircleOutlined,
RocketOutlined, RocketOutlined,
TeamOutlined,
RobotOutlined, RobotOutlined,
BulbOutlined,
AppstoreOutlined, AppstoreOutlined,
DashboardOutlined, DashboardOutlined,
LoginOutlined, LoginOutlined,
@ -42,10 +40,8 @@ const navItems: NavItem[] = [
icon: BuildOutlined, icon: BuildOutlined,
path: '/organization', path: '/organization',
children: [ children: [
{ key: '/roles', label: 'Roles', icon: TeamOutlined, path: '/roles', manager: true },
{ key: '/onboarding', label: 'Onboarding', icon: RocketOutlined, path: '/onboarding' }, { key: '/onboarding', label: 'Onboarding', icon: RocketOutlined, path: '/onboarding' },
{ key: '/progress', label: 'Progress', icon: DashboardOutlined, path: '/progress' }, { key: '/progress', label: 'Progress', icon: DashboardOutlined, path: '/progress' },
{ key: '/assessments', label: 'Assessments', icon: BulbOutlined, path: '/assessments' },
{ key: '/resources', label: 'Resources', icon: AppstoreOutlined, path: '/resources' }, { key: '/resources', label: 'Resources', icon: AppstoreOutlined, path: '/resources' },
], ],
}, },

View file

@ -77,6 +77,7 @@ export const API = {
organizations: () => '/api/organization/', organizations: () => '/api/organization/',
organization: (id: string) => `/api/organization/${id}/`, organization: (id: string) => `/api/organization/${id}/`,
organizationRoles: (orgUuid: string) => `/api/organization/${orgUuid}/role/`, organizationRoles: (orgUuid: string) => `/api/organization/${orgUuid}/role/`,
organizationRolesMine: () => '/api/organization/role/mine/',
organizationRole: (orgUuid: string, roleUuid: string) => organizationRole: (orgUuid: string, roleUuid: string) =>
`/api/organization/${orgUuid}/role/${roleUuid}/`, `/api/organization/${orgUuid}/role/${roleUuid}/`,
organizationRoleMembers: (orgUuid: string, roleUuid: string) => organizationRoleMembers: (orgUuid: string, roleUuid: string) =>
@ -98,6 +99,18 @@ export const API = {
`/api/organization/${orgUuid}/training-file/${fileUuid}/`, `/api/organization/${orgUuid}/training-file/${fileUuid}/`,
agents: () => '/api/agent/', agents: () => '/api/agent/',
agent: (id: string) => `/api/agent/${id}/`, agent: (id: string) => `/api/agent/${id}/`,
onboardingFlows: () => '/api/onboarding/flow/',
onboardingFlow: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/`,
onboardingFlowPages: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/pages/`,
onboardingFlowGenerate: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/generate/`,
onboardingFlowPublish: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/publish/`,
onboardingSessions: () => '/api/onboarding/session/',
onboardingSessionGetOrCreate: () => '/api/onboarding/session/get_or_create/',
onboardingSession: (sessionUuid: string) => `/api/onboarding/session/${sessionUuid}/`,
onboardingSessionSubmit: (sessionUuid: string) =>
`/api/onboarding/session/${sessionUuid}/submit/`,
onboardingSessionFeedback: (sessionUuid: string) =>
`/api/onboarding/session/${sessionUuid}/feedback/`,
} }
export const apiClient = new ApiClient() export const apiClient = new ApiClient()

View file

@ -67,6 +67,18 @@ const router = createRouter({
component: () => import('../views/AgentDetailView.vue'), component: () => import('../views/AgentDetailView.vue'),
meta: { requiresAuth: true, requiresManager: true }, meta: { requiresAuth: true, requiresManager: true },
}, },
{
path: '/onboarding',
name: 'onboarding',
component: () => import('../views/OnboardingView.vue'),
meta: { requiresAuth: true },
},
{
path: '/onboarding/:roleId',
name: 'onboarding-role',
component: () => import('../views/OnboardingView.vue'),
meta: { requiresAuth: true },
},
], ],
}) })

View file

@ -97,15 +97,26 @@ export const useAgentStore = defineStore('agent', () => {
isConnected.value = false isConnected.value = false
} }
const startAgent = (data: { query?: string; prompt?: string; options?: Record<string, unknown> }) => { const startAgent = (data: {
query?: string
prompt?: string
role_uuid?: string
max_tokens?: number
options?: Record<string, unknown>
}) => {
if (!socket || socket.readyState !== WebSocket.OPEN) return if (!socket || socket.readyState !== WebSocket.OPEN) return
const prompt = data.query ?? data.prompt ?? '' const prompt = data.query ?? data.prompt ?? ''
const options = {
...(data.options ?? {}),
...(typeof data.max_tokens === 'number' ? { max_tokens: data.max_tokens } : {}),
}
socket.send( socket.send(
JSON.stringify({ JSON.stringify({
action: 'infer', action: 'infer',
input_data: { input_data: {
prompt, prompt,
options: data.options ?? {}, role_uuid: data.role_uuid,
options,
}, },
}), }),
) )
@ -131,6 +142,17 @@ export const useAgentStore = defineStore('agent', () => {
) )
} }
const sendOnboardingProgress = (executionId: string, content: Record<string, unknown>) => {
if (!socket || socket.readyState !== WebSocket.OPEN) return
socket.send(
JSON.stringify({
action: 'onboarding_progress',
execution_id: executionId,
content,
}),
)
}
const resetLog = () => { const resetLog = () => {
eventLog.value = [] eventLog.value = []
} }
@ -144,6 +166,7 @@ export const useAgentStore = defineStore('agent', () => {
startAgent, startAgent,
startFineTune, startFineTune,
stopAgent, stopAgent,
sendOnboardingProgress,
resetLog, resetLog,
lastExecutionId, lastExecutionId,
} }

44
src/types/onboarding.ts Normal file
View file

@ -0,0 +1,44 @@
export type OnboardingField = {
uuid: string
key: string
label: string
field_type: 'text' | 'textarea' | 'select' | 'multiselect' | 'number' | 'boolean' | 'date'
required: boolean
help_text?: string
placeholder?: string
options?: string[]
default_value?: unknown
validation?: Record<string, unknown>
}
export type OnboardingPage = {
uuid: string
order: number
title: string
body?: string
meta?: Record<string, unknown>
fields: OnboardingField[]
}
export type OnboardingFlow = {
uuid: string
role: string
agent?: string | null
title: string
description?: string
status: 'draft' | 'published' | 'archived'
pages?: OnboardingPage[]
}
export type OnboardingSession = {
uuid: string
flow: string
agent_run?: string | null
status: 'in_progress' | 'completed' | 'abandoned'
current_page_order: number
responses: Record<string, unknown>
}
export type OnboardingFeedback = {
summary?: string
}

View file

@ -37,7 +37,7 @@ export interface InviteToken {
export interface TrainingFile { export interface TrainingFile {
id: number id: number
uuid: string uuid: string
organization: string role: Role
uploaded_by: User uploaded_by: User
file: string file: string
file_name: string file_name: string
@ -45,6 +45,7 @@ export interface TrainingFile {
file_type: string file_type: string
description: string description: string
is_processed: boolean is_processed: boolean
status: 'ingesting' | 'chunked' | 'embedded' | 'failed'
file_url: string file_url: string
created_at: string created_at: string
updated_at: string updated_at: string

View file

@ -2,15 +2,15 @@
import { Card, Typography, Divider, List } from 'ant-design-vue' import { Card, Typography, Divider, List } from 'ant-design-vue'
const steps = [ const steps = [
'Register or login (demo credentials only).', 'Register or login.',
'Complete Onboarding and Training to simulate a role journey.', 'Complete onboarding and training to simulate a role journey.',
'Managers assign employees to roles and review progress reports.', 'Managers review onboarding completion and feedback.',
] ]
const features = [ const features = [
{ {
title: 'Modular Content', title: 'Modular Content',
desc: 'Compose learning journeys from small, reusable modules — mix assessments, videos and interactive checks.', desc: 'Compose learning journeys from small, reusable modules. Mix videos and interactive checks.',
img: 'https://placehold.co/600x400?text=Modular+Content', img: 'https://placehold.co/600x400?text=Modular+Content',
}, },
{ {
@ -20,7 +20,7 @@ const features = [
}, },
{ {
title: 'Reporting & Insights', title: 'Reporting & Insights',
desc: 'Lightweight reports showing completion, assessment scores and engagement metrics.', desc: 'Lightweight reports showing completion and engagement metrics.',
img: 'https://placehold.co/600x400?text=Reporting', img: 'https://placehold.co/600x400?text=Reporting',
}, },
] ]
@ -33,7 +33,7 @@ const features = [
<Typography.Paragraph type="secondary"> <Typography.Paragraph type="secondary">
Dynavera is a lightweight platform for onboarding, training, and assessing employees Dynavera is a lightweight platform for onboarding, training, and assessing employees
with modular content and agent-driven workflows. It is designed for teams that want with modular content and agent-driven workflows. It is designed for teams that want
angible learning experiences quickly without complex LMS setup. tangible learning experiences quickly without complex LMS setup.
</Typography.Paragraph> </Typography.Paragraph>
<Divider /> <Divider />
<Typography.Title :level="4">Getting started</Typography.Title> <Typography.Title :level="4">Getting started</Typography.Title>
@ -59,11 +59,11 @@ const features = [
<Divider /> <Divider />
<Typography.Title :level="4">More about Dynavera</Typography.Title> <Typography.Title :level="4">More about Dynavera</Typography.Title>
<Typography.Paragraph> <Typography.Paragraph>
Dynavera is built to be extensible plug your content sources, connect an LMS, or Dynavera is built to be extensible. Plug your content sources, connect an LMS, or
integrate third-party assessment engines. The platform focuses on flexibility and integrate third-party learning tools. The platform focuses on flexibility and ease
ease of use, so teams can get started quickly and adapt as their needs evolve. of use, so teams can get started quickly and adapt as their needs evolve. Whether
Whether youre a small startup or a growing enterprise, Dynavera aims to simplify you are a small startup or a growing enterprise, Dynavera aims to simplify the
the process of onboarding and training with modern, agent-driven approaches. process of onboarding and training with modern, agent-driven approaches.
</Typography.Paragraph> </Typography.Paragraph>
</Card> </Card>
</div> </div>

View file

@ -1,9 +1,10 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, onUnmounted, computed } from 'vue' import { ref, onMounted, onUnmounted, computed } from 'vue'
import { useRoute } from 'vue-router' import { useRoute } from 'vue-router'
import { Card, Typography, Button, List, Space, Spin, Input, message, Tag } from 'ant-design-vue' import { Card, Typography, Button, List, Space, Spin, Input, message, Tag, Select, InputNumber } from 'ant-design-vue'
import { useAgentStore } from '../stores/agentStore' import { useAgentStore } from '../stores/agentStore'
import { apiClient, isAxiosError, API } from '../router/api' import { apiClient, isAxiosError, API } from '../router/api'
import type { Role } from '../types/organization'
const route = useRoute() const route = useRoute()
const agentStore = useAgentStore() const agentStore = useAgentStore()
@ -17,6 +18,10 @@ const agent = ref<Record<string, unknown>>({
status: 'idle', status: 'idle',
}) })
const roles = ref<Role[]>([])
const selectedRoleUuid = ref('')
const maxTokens = ref<number>(256)
const queryInput = ref('') const queryInput = ref('')
const isRunning = computed(() => agentStore.executionStatus === 'running') const isRunning = computed(() => agentStore.executionStatus === 'running')
const isConnected = computed(() => agentStore.isConnected ?? false) const isConnected = computed(() => agentStore.isConnected ?? false)
@ -52,6 +57,16 @@ const fetchAgent = async () => {
try { try {
const response = await apiClient.get<Record<string, unknown>>(API.agent(agentId)) const response = await apiClient.get<Record<string, unknown>>(API.agent(agentId))
agent.value = response.data agent.value = response.data
const org = agent.value.organization as Record<string, unknown> | null | undefined
const orgUuid = org?.uuid as string | undefined
console.log('Agent loaded:', agent.value)
console.log('Organization:', org)
console.log('Organization UUID:', orgUuid)
if (orgUuid) {
await fetchRoles(orgUuid)
} else {
console.warn('No organization UUID found for agent')
}
} catch (error) { } catch (error) {
console.error('Failed to fetch agent:', error) console.error('Failed to fetch agent:', error)
if (isAxiosError(error)) { if (isAxiosError(error)) {
@ -65,6 +80,32 @@ const fetchAgent = async () => {
} }
} }
const fetchRoles = async (orgUuid: string) => {
try {
console.log('Fetching roles for organization:', orgUuid)
const response = await apiClient.get<Role[]>(API.organizationRoles(orgUuid))
console.log('Roles loaded:', response.data)
roles.value = response.data
if (!selectedRoleUuid.value && roles.value.length > 0) {
selectedRoleUuid.value = roles.value[0].uuid
console.log('Auto-selected role:', selectedRoleUuid.value)
}
} catch (error) {
console.error('Failed to fetch roles:', error)
if (isAxiosError(error)) {
console.error('Roles API error:', {
status: error.response?.status,
data: error.response?.data,
})
}
message.error('Failed to load roles')
}
}
const roleOptions = computed(() =>
roles.value.map((role) => ({ label: role.name, value: role.uuid })),
)
const startAgent = () => { const startAgent = () => {
if (!agentStore.isConnected) { if (!agentStore.isConnected) {
message.error('WebSocket not connected') message.error('WebSocket not connected')
@ -76,8 +117,15 @@ const startAgent = () => {
return return
} }
if (!selectedRoleUuid.value) {
message.error('Please select a role for this query')
return
}
const data = { const data = {
query: queryInput.value.trim(), query: queryInput.value.trim(),
role_uuid: selectedRoleUuid.value,
max_tokens: maxTokens.value,
} }
agentStore.startAgent(data) agentStore.startAgent(data)
@ -144,6 +192,28 @@ onUnmounted(() => {
/> />
</div> </div>
<div>
<Typography.Text>Role:</Typography.Text>
<Select
v-model:value="selectedRoleUuid"
:options="roleOptions"
:disabled="isRunning"
placeholder="Select a role"
style="width: 100%"
/>
</div>
<div>
<Typography.Text>Max Tokens:</Typography.Text>
<InputNumber
v-model:value="maxTokens"
:min="1"
:max="4096"
:disabled="isRunning"
style="width: 100%"
/>
</div>
<Space> <Space>
<Button type="primary" :disabled="isRunning || !isConnected" @click="startAgent"> <Button type="primary" :disabled="isRunning || !isConnected" @click="startAgent">
Run Agent Run Agent

View file

@ -0,0 +1,722 @@
<script setup lang="ts">
import { ref, computed, onMounted, onUnmounted, watch } from 'vue'
import type { Dayjs } from 'dayjs'
import { useRoute, useRouter } from 'vue-router'
import {
Card,
Typography,
Button,
Space,
Spin,
Select,
Form,
Input,
InputNumber,
Switch,
DatePicker,
Divider,
List,
message,
} from 'ant-design-vue'
import { apiClient, API } from '../router/api'
import { useUserStore } from '../stores/userStore'
import { useAgentStore } from '../stores/agentStore'
import type { Role } from '../types/organization'
import type { OnboardingFlow, OnboardingPage, OnboardingSession, OnboardingFeedback } from '../types/onboarding'
const route = useRoute()
const router = useRouter()
const userStore = useUserStore()
const agentStore = useAgentStore()
const roleId = computed(() => (route.params.roleId as string | undefined) || undefined)
const flows = ref<OnboardingFlow[]>([])
type SingleSelectValue = string | number | undefined
const selectedFlowUuid = ref<SingleSelectValue>(undefined)
const flowDetails = ref<OnboardingFlow | null>(null)
const session = ref<OnboardingSession | null>(null)
const currentPageIndex = ref(0)
const loading = ref(false)
const generating = ref(false)
const loadError = ref<string | null>(null)
const generateInstructions = ref('')
const creatingFlow = ref(false)
const resettingSession = ref(false)
const publishing = ref(false)
const agents = ref<Array<{ uuid: string; name: string }>>([])
const userRoles = ref<Role[]>([])
const roleLoading = ref(false)
const roleLoadError = ref<string | null>(null)
const createFlowForm = ref({
title: '',
description: '',
agent: undefined as SingleSelectValue,
})
const isManager = computed(() => userStore.isGeneralManager)
const hasRoleContext = computed(() => Boolean(roleId.value))
const pages = computed<OnboardingPage[]>(() => flowDetails.value?.pages ?? [])
const currentPage = computed<OnboardingPage | null>(() => pages.value[currentPageIndex.value] || null)
const hasNext = computed(() => currentPageIndex.value < pages.value.length - 1)
const hasPrev = computed(() => currentPageIndex.value > 0)
const formState = ref<Record<string, unknown>>({})
const feedbackLoading = ref(false)
const feedbackByPage = ref<Record<string, OnboardingFeedback>>({})
const getFieldValue = (key: string) => formState.value[key]
const setFieldValue = (key: string, value: unknown) => {
formState.value[key] = value
}
const toSingleValue = (value: unknown): SingleSelectValue => {
if (Array.isArray(value)) {
return value.length ? (value[0] as string | number) : undefined
}
if (typeof value === 'string' || typeof value === 'number') {
return value
}
return undefined
}
const toNumberValue = (value: unknown): number | undefined =>
typeof value === 'number' ? value : undefined
const toDateValue = (value: unknown): Dayjs | string | undefined =>
typeof value === 'string' || (value as Dayjs | undefined)?.isValid ? (value as Dayjs | string) : undefined
const fetchFlows = async () => {
loading.value = true
loadError.value = null
try {
const params: Record<string, string> = {}
if (roleId.value) params.role = roleId.value
if (!isManager.value) params.status = 'published'
const response = await apiClient.get<OnboardingFlow[]>(API.onboardingFlows(), { params })
flows.value = Array.isArray(response.data) ? response.data : []
if (flows.value.length === 1) {
selectedFlowUuid.value = flows.value[0].uuid
}
} catch (err: unknown) {
console.error('Failed to load onboarding flows', err)
loadError.value = 'Failed to load onboarding flows'
flows.value = []
} finally {
loading.value = false
}
}
const fetchAgents = async () => {
try {
const response = await apiClient.get<Array<{ uuid: string; name: string }>>(API.agents())
agents.value = Array.isArray(response.data) ? response.data : []
} catch (err: unknown) {
console.error('Failed to load agents', err)
agents.value = []
}
}
const fetchUserRoles = async () => {
roleLoading.value = true
roleLoadError.value = null
try {
const response = await apiClient.get<Role[]>(API.organizationRolesMine())
userRoles.value = Array.isArray(response.data) ? response.data : []
if (userRoles.value.length === 1) {
await router.replace(`/onboarding/${userRoles.value[0].uuid}`)
}
} catch (err: unknown) {
console.error('Failed to load user roles', err)
roleLoadError.value = 'Failed to load roles'
userRoles.value = []
} finally {
roleLoading.value = false
}
}
const fetchFlowDetails = async (flowUuid: string) => {
loading.value = true
loadError.value = null
try {
const response = await apiClient.get<OnboardingFlow>(API.onboardingFlowPages(flowUuid))
flowDetails.value = response.data
await ensureSession(flowUuid)
syncPageIndexFromSession()
hydrateFormState()
if (flowDetails.value?.agent) {
agentStore.connect(flowDetails.value.agent)
}
} catch (err: unknown) {
console.error('Failed to load onboarding flow', err)
loadError.value = 'Failed to load onboarding flow'
flowDetails.value = null
} finally {
loading.value = false
}
}
const ensureSession = async (flowUuid: string) => {
if (session.value && session.value.flow === flowUuid) return
const response = await apiClient.post<OnboardingSession>(API.onboardingSessionGetOrCreate(), {
flow: flowUuid,
})
session.value = response.data
}
const syncPageIndexFromSession = () => {
const order = session.value?.current_page_order
if (typeof order === 'number' && order >= 0) {
currentPageIndex.value = Math.min(order, Math.max(pages.value.length - 1, 0))
} else {
currentPageIndex.value = 0
}
}
const hydrateFormState = () => {
if (!currentPage.value) {
formState.value = {}
return
}
const existing = session.value?.responses?.[currentPage.value.uuid] || {}
const values: Record<string, unknown> = {}
currentPage.value.fields.forEach((field) => {
const existingValue = (existing as Record<string, unknown>)[field.key]
if (existingValue !== undefined) {
values[field.key] = existingValue
} else if (field.default_value !== undefined && field.default_value !== null) {
values[field.key] = field.default_value
} else if (field.field_type === 'boolean') {
values[field.key] = false
} else if (field.field_type === 'multiselect') {
values[field.key] = []
} else {
values[field.key] = ''
}
})
formState.value = values
const storedFeedback = (session.value?.responses as Record<string, unknown> | undefined)?.[
'__feedback__'
] as Record<string, { feedback?: OnboardingFeedback }> | undefined
if (storedFeedback && currentPage.value) {
const pageFeedback = storedFeedback[currentPage.value.uuid]?.feedback
if (pageFeedback) {
feedbackByPage.value[currentPage.value.uuid] = pageFeedback
}
}
}
const normalizeResponses = (raw: Record<string, unknown>) => {
const normalized: Record<string, unknown> = {}
Object.entries(raw).forEach(([key, value]) => {
if (value && typeof value === 'object') {
const maybeDate = value as { toISOString?: () => string; format?: (fmt: string) => string }
if (typeof maybeDate.toISOString === 'function') {
normalized[key] = maybeDate.toISOString()
return
}
if (typeof maybeDate.format === 'function') {
normalized[key] = maybeDate.format('YYYY-MM-DD')
return
}
}
normalized[key] = value
})
return normalized
}
const onSubmitPage = async () => {
if (!currentPage.value || !session.value) return
try {
const normalizedResponses = normalizeResponses(formState.value)
const payload = {
page_uuid: currentPage.value.uuid,
responses: normalizedResponses,
mark_complete: !hasNext.value,
}
const response = await apiClient.post<OnboardingSession>(
API.onboardingSessionSubmit(session.value.uuid),
payload,
)
session.value = response.data
if (session.value.agent_run) {
agentStore.sendOnboardingProgress(session.value.agent_run, {
flow_uuid: session.value.flow,
session_uuid: session.value.uuid,
page_uuid: currentPage.value.uuid,
page_order: currentPageIndex.value,
responses: normalizedResponses,
status: session.value.status,
})
}
if (hasNext.value) {
currentPageIndex.value += 1
hydrateFormState()
} else {
message.success('Onboarding completed')
}
} catch (err: unknown) {
console.error('Failed to submit onboarding page', err)
message.error('Failed to save onboarding progress')
}
}
const requestFeedback = async () => {
if (!currentPage.value || !session.value) return
feedbackLoading.value = true
try {
const normalizedResponses = normalizeResponses(formState.value)
const response = await apiClient.post<{ feedback: OnboardingFeedback; session: OnboardingSession }>(
API.onboardingSessionFeedback(session.value.uuid),
{
page_uuid: currentPage.value.uuid,
responses: normalizedResponses,
},
)
session.value = response.data.session
feedbackByPage.value[currentPage.value.uuid] = response.data.feedback
} catch (err: unknown) {
console.error('Failed to get feedback', err)
message.error('Failed to get feedback')
} finally {
feedbackLoading.value = false
}
}
const generateWithAgent = async () => {
if (!flowDetails.value) return
generating.value = true
try {
const response = await apiClient.post<OnboardingFlow>(
API.onboardingFlowGenerate(flowDetails.value.uuid),
{
instructions: generateInstructions.value,
},
)
flowDetails.value = response.data
currentPageIndex.value = 0
hydrateFormState()
message.success('Generated onboarding content')
} catch (err: unknown) {
console.error('Failed to generate onboarding flow', err)
message.error('Failed to generate onboarding content')
} finally {
generating.value = false
}
}
const publishFlow = async () => {
if (!flowDetails.value) return
publishing.value = true
try {
const response = await apiClient.post<OnboardingFlow>(
API.onboardingFlowPublish(flowDetails.value.uuid),
)
flowDetails.value = response.data
flows.value = flows.value.map((flow) =>
flow.uuid === response.data.uuid ? { ...flow, status: response.data.status } : flow,
)
message.success('Flow published')
} catch (err: unknown) {
console.error('Failed to publish onboarding flow', err)
message.error('Failed to publish onboarding flow')
} finally {
publishing.value = false
}
}
const createFlow = async () => {
if (!roleId.value) {
message.error('Role is required to create a flow')
return
}
if (!createFlowForm.value.title.trim()) {
message.error('Title is required')
return
}
creatingFlow.value = true
try {
const payload = {
role: roleId.value,
agent:
createFlowForm.value.agent !== undefined && createFlowForm.value.agent !== null
? String(createFlowForm.value.agent)
: null,
title: createFlowForm.value.title.trim(),
description: createFlowForm.value.description.trim(),
status: 'draft',
}
const response = await apiClient.post<OnboardingFlow>(API.onboardingFlows(), payload)
flows.value = [response.data, ...flows.value]
selectedFlowUuid.value = response.data.uuid
createFlowForm.value = {
title: '',
description: '',
agent: undefined,
}
message.success('Onboarding flow created')
} catch (err: unknown) {
console.error('Failed to create onboarding flow', err)
message.error('Failed to create onboarding flow')
} finally {
creatingFlow.value = false
}
}
const goPrev = () => {
if (hasPrev.value) {
currentPageIndex.value -= 1
hydrateFormState()
}
}
const resetSession = async () => {
if (!flowDetails.value) return
resettingSession.value = true
try {
const response = await apiClient.post<OnboardingSession>(API.onboardingSessions(), {
flow: flowDetails.value.uuid,
})
session.value = response.data
currentPageIndex.value = 0
feedbackByPage.value = {}
hydrateFormState()
message.success('Started a new onboarding session')
} catch (err: unknown) {
console.error('Failed to reset onboarding session', err)
message.error('Failed to start a new onboarding session')
} finally {
resettingSession.value = false
}
}
watch(selectedFlowUuid, async (value) => {
if (value) {
await fetchFlowDetails(String(value))
}
})
watch(roleId, async () => {
selectedFlowUuid.value = undefined
flowDetails.value = null
session.value = null
userRoles.value = []
if (!roleId.value) {
await fetchUserRoles()
return
}
await fetchFlows()
if (selectedFlowUuid.value) {
await fetchFlowDetails(String(selectedFlowUuid.value))
}
})
watch(currentPageIndex, hydrateFormState)
onUnmounted(() => {
agentStore.disconnect()
})
onMounted(async () => {
await userStore.fetchSession()
if (isManager.value) {
await fetchAgents()
}
if (!roleId.value) {
await fetchUserRoles()
return
}
await fetchFlows()
if (selectedFlowUuid.value) {
await fetchFlowDetails(String(selectedFlowUuid.value))
}
})
</script>
<template>
<div class="page">
<Spin :spinning="loading" tip="Loading onboarding...">
<Card class="panel" :bordered="false">
<div class="header">
<Typography.Title :level="2">Onboarding</Typography.Title>
<Space v-if="flowDetails && isManager">
<Input
v-model:value="generateInstructions"
placeholder="Optional generation instructions"
style="width: 320px"
/>
<Button
type="primary"
:loading="generating"
:disabled="!flowDetails?.agent"
@click="generateWithAgent"
>
Generate with Agent
</Button>
<Button
:loading="publishing"
:disabled="flowDetails?.status === 'published'"
@click="publishFlow"
>
{{ flowDetails?.status === 'published' ? 'Published' : 'Publish' }}
</Button>
</Space>
</div>
<Typography.Paragraph v-if="loadError" type="danger">{{ loadError }}</Typography.Paragraph>
<div v-if="!hasRoleContext">
<Typography.Paragraph type="secondary">
Choose a role to continue onboarding.
</Typography.Paragraph>
<Typography.Paragraph v-if="roleLoadError" type="danger">
{{ roleLoadError }}
</Typography.Paragraph>
<Spin :spinning="roleLoading">
<div v-if="userRoles.length > 0" class="role-list">
<List :data-source="userRoles" :bordered="false">
<template #renderItem="{ item }">
<List.Item class="role-item">
<List.Item.Meta
:title="item.name"
:description="item.organization?.name || 'Organization'"
/>
<Button type="primary" @click="router.push(`/onboarding/${item.uuid}`)">
Open
</Button>
</List.Item>
</template>
</List>
</div>
<Typography.Paragraph v-else type="secondary">
You are not a member of any roles yet.
</Typography.Paragraph>
</Spin>
</div>
<div v-else-if="flows.length > 1" class="flow-select">
<Typography.Text strong>Select a flow</Typography.Text>
<Select
:value="selectedFlowUuid"
@update:value="(val) => (selectedFlowUuid = toSingleValue(val))"
placeholder="Choose a flow"
style="width: 320px"
:options="flows.map((f) => ({ value: f.uuid, label: f.title }))"
/>
</div>
<div v-else-if="flows.length === 0">
<Typography.Paragraph type="secondary">
No onboarding flows available yet.
</Typography.Paragraph>
<div v-if="isManager" class="create-flow">
<Typography.Text strong>Create a flow</Typography.Text>
<Typography.Paragraph type="secondary">
Create a draft flow and then generate pages with an agent.
</Typography.Paragraph>
<Typography.Paragraph v-if="!roleId" type="warning">
Open this page with a role id to create a flow.
</Typography.Paragraph>
<Form layout="vertical" :model="createFlowForm" @finish="createFlow">
<Form.Item
label="Title"
name="title"
:rules="[{ required: true, message: 'Title is required' }]"
>
<Input v-model:value="createFlowForm.title" placeholder="Flow title" />
</Form.Item>
<Form.Item label="Description" name="description">
<Input
v-model:value="createFlowForm.description"
placeholder="Short description"
/>
</Form.Item>
<Form.Item label="Agent (optional)" name="agent">
<Select
:value="createFlowForm.agent"
@update:value="(val) => (createFlowForm.agent = toSingleValue(val))"
allow-clear
placeholder="Select an agent"
:options="agents.map((a) => ({ value: a.uuid, label: a.name }))"
/>
</Form.Item>
<Button
type="primary"
html-type="submit"
:loading="creatingFlow"
:disabled="!roleId"
>
Create draft flow
</Button>
</Form>
<Typography.Paragraph v-if="!agents.length" type="secondary">
No agents available yet. You can create a flow now and assign an agent later.
</Typography.Paragraph>
</div>
</div>
<div v-if="flowDetails" class="flow-body">
<Typography.Title :level="3">{{ flowDetails.title }}</Typography.Title>
<Typography.Paragraph type="secondary">
{{ flowDetails.description || 'No description provided.' }}
</Typography.Paragraph>
<Space v-if="session" class="session-actions">
<Button danger :loading="resettingSession" @click="resetSession">
Start New Session
</Button>
</Space>
<Divider />
<div v-if="currentPage" class="page-body">
<Typography.Title :level="4">
{{ currentPage.title }}
</Typography.Title>
<Typography.Paragraph>
{{ currentPage.body }}
</Typography.Paragraph>
<Form layout="vertical" :model="formState" @finish="onSubmitPage">
<Form.Item
v-for="field in currentPage.fields"
:key="field.uuid"
:label="field.label"
:name="field.key"
:rules="field.required ? [{ required: true, message: 'Required' }] : []"
>
<Input
v-if="field.field_type === 'text'"
:value="(getFieldValue(field.key) as string | number | undefined)"
@update:value="(val) => setFieldValue(field.key, val)"
:placeholder="field.placeholder"
/>
<Input.TextArea
v-else-if="field.field_type === 'textarea'"
:value="(getFieldValue(field.key) as string | undefined)"
@update:value="(val) => setFieldValue(field.key, val)"
:placeholder="field.placeholder"
:rows="4"
/>
<InputNumber
v-else-if="field.field_type === 'number'"
:value="toNumberValue(getFieldValue(field.key))"
@update:value="(val) => setFieldValue(field.key, val)"
style="width: 100%"
/>
<Select
v-else-if="field.field_type === 'select'"
:value="(getFieldValue(field.key) as string | undefined)"
@update:value="(val) => setFieldValue(field.key, val)"
:options="(field.options || []).map((opt) => ({ value: opt, label: opt }))"
/>
<Select
v-else-if="field.field_type === 'multiselect'"
:value="(getFieldValue(field.key) as string[] | undefined)"
@update:value="(val) => setFieldValue(field.key, val)"
mode="multiple"
:options="(field.options || []).map((opt) => ({ value: opt, label: opt }))"
/>
<Switch
v-else-if="field.field_type === 'boolean'"
:checked="Boolean(getFieldValue(field.key))"
@update:checked="(val) => setFieldValue(field.key, val)"
/>
<DatePicker
v-else-if="field.field_type === 'date'"
:value="toDateValue(getFieldValue(field.key))"
@update:value="(val) => setFieldValue(field.key, val)"
style="width: 100%"
/>
<Input
v-else
:value="(getFieldValue(field.key) as string | number | undefined)"
@update:value="(val) => setFieldValue(field.key, val)"
:placeholder="field.placeholder"
/>
</Form.Item>
<Space>
<Button :disabled="!hasPrev" @click="goPrev">Back</Button>
<Button :loading="feedbackLoading" @click="requestFeedback">
Get Feedback
</Button>
<Button type="primary" html-type="submit">
{{ hasNext ? 'Next' : 'Finish' }}
</Button>
</Space>
<div class="feedback">
<div v-if="feedbackByPage[currentPage.uuid]" class="feedback-panel">
<Typography.Text strong>Feedback</Typography.Text>
<Typography.Paragraph>
{{ feedbackByPage[currentPage.uuid]?.summary || '' }}
</Typography.Paragraph>
</div>
</div>
</Form>
</div>
</div>
</Card>
</Spin>
</div>
</template>
<style scoped>
.page {
max-width: 1100px;
margin: 0 auto;
padding: 1rem;
}
.panel {
background: #0f172a;
border: 1px solid #1f2937;
color: #e5e7eb;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
gap: 1rem;
margin-bottom: 1rem;
}
.flow-select {
margin: 1rem 0;
}
.role-list {
margin: 1rem 0;
}
.role-item :deep(.ant-list-item-meta-title),
.role-item :deep(.ant-list-item-meta-description) {
background: #0f172a;
border: 1px solid #1f2937;
}
.flow-body {
margin-top: 1rem;
}
.session-actions {
margin-top: 0.5rem;
}
.create-flow {
margin-top: 1rem;
padding: 1rem;
border: 1px dashed #1f2937;
border-radius: 6px;
background: #0b1220;
}
.page-body {
margin-top: 1rem;
}
.feedback {
margin-top: 1rem;
display: grid;
gap: 0.75rem;
}
.feedback-panel {
padding: 0.75rem 1rem;
border: 1px solid #1f2937;
border-radius: 6px;
background: #0b1220;
}
.feedback-panel ul {
margin: 0.5rem 0 0.75rem 1.25rem;
}
</style>

View file

@ -1,7 +1,7 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted, computed, h } from 'vue' import { ref, onMounted, computed, h } from 'vue'
import { useRouter, useRoute } from 'vue-router' import { useRouter, useRoute } from 'vue-router'
import { Card, Typography, Button, List, Space, Spin, message, Tag, Divider, Upload, Modal, Table } from 'ant-design-vue' import { Card, Typography, Button, List, Space, Spin, message, Tag, Divider, Upload, Modal, Table, Select } from 'ant-design-vue'
import { apiClient, isAxiosError, API } from '../router/api' import { apiClient, isAxiosError, API } from '../router/api'
import { useUserStore } from '../stores/userStore' import { useUserStore } from '../stores/userStore'
import { InboxOutlined, DeleteOutlined } from '@ant-design/icons-vue' import { InboxOutlined, DeleteOutlined } from '@ant-design/icons-vue'
@ -52,28 +52,13 @@ const fetchRoles = async () => {
} }
const fetchUserRoleMemberships = async () => { const fetchUserRoleMemberships = async () => {
const userRoleUuids: string[] = [] if (!organization.value?.uuid) return
const userId = auth.user?.id
if (!userId || !organization.value?.uuid) return
try { try {
const checks = roles.value.map(async (r) => { const response = await apiClient.get<Role[]>(API.organizationRolesMine())
try { const mine = Array.isArray(response.data) ? response.data : []
const resp = await apiClient.get<Array<{ user: { id: number } }>>( const orgUuid = organization.value.uuid
API.organizationRoleMembers(organization.value!.uuid, r.uuid), const joinedRoles = mine.filter((role) => role.organization?.uuid === orgUuid)
) auth.setJoinedRoles(joinedRoles)
const found =
Array.isArray(resp.data) && resp.data.some((m) => m.user?.id === userId)
if (found && r.uuid) userRoleUuids.push(r.uuid)
} catch {
// ignore individual role errors
}
})
await Promise.all(checks)
// update the global store with actual Role objects the user has joined
const joinedRoles = roles.value.filter((r) => userRoleUuids.includes(r.uuid))
if (joinedRoles.length) {
auth.setJoinedRoles(joinedRoles)
}
} catch (err) { } catch (err) {
console.error('Failed to fetch user role memberships', err) console.error('Failed to fetch user role memberships', err)
} }
@ -169,6 +154,7 @@ const beforeUpload = (file: File) => {
const selectedFile = ref<File | null>(null) const selectedFile = ref<File | null>(null)
const fileDescription = ref('') const fileDescription = ref('')
const selectedRoleUuid = ref<string>('')
const handleFileSelected = (file: File) => { const handleFileSelected = (file: File) => {
selectedFile.value = file selectedFile.value = file
@ -179,10 +165,15 @@ const handleFileUploadClick = async () => {
message.error('Please select a file to upload') message.error('Please select a file to upload')
return return
} }
if (!selectedRoleUuid.value) {
message.error('Please select a role for this training file')
return
}
await handleFileUpload(selectedFile.value, fileDescription.value) await handleFileUpload(selectedFile.value, fileDescription.value)
selectedFile.value = null selectedFile.value = null
fileDescription.value = '' fileDescription.value = ''
selectedRoleUuid.value = ''
} }
const handleFileUpload = async (file: File, description: string = '') => { const handleFileUpload = async (file: File, description: string = '') => {
@ -197,6 +188,9 @@ const handleFileUpload = async (file: File, description: string = '') => {
formData.append('file', file) formData.append('file', file)
formData.append('file_name', file.name) formData.append('file_name', file.name)
formData.append('description', description) formData.append('description', description)
if (selectedRoleUuid.value) {
formData.append('role_uuid', selectedRoleUuid.value)
}
const response = await apiClient.post<TrainingFile>( const response = await apiClient.post<TrainingFile>(
API.organizationTrainingFiles(organization.value.uuid), API.organizationTrainingFiles(organization.value.uuid),
@ -278,12 +272,25 @@ const trainingFileColumns = [
key: 'file_size', key: 'file_size',
customRender: ({ value }: { value: number }) => formatFileSize(value || 0), customRender: ({ value }: { value: number }) => formatFileSize(value || 0),
}, },
{
title: 'Role',
key: 'role',
customRender: ({ record }: { record: TrainingFile }) => record.role?.name || '-',
},
{ {
title: 'Status', title: 'Status',
dataIndex: 'is_processed', dataIndex: 'status',
key: 'is_processed', key: 'status',
customRender: ({ value }: { value: boolean }) => customRender: ({ value }: { value: string }) => {
h(Tag, { color: value ? 'success' : 'processing' }, () => value ? 'Processed' : 'Processing'), const statusMap: Record<string, { color: string; label: string }> = {
ingesting: { color: 'processing', label: 'Ingesting' },
chunked: { color: 'blue', label: 'Chunked' },
embedded: { color: 'success', label: 'Embedded' },
failed: { color: 'error', label: 'Failed' },
}
const status = statusMap[value] || { color: 'default', label: value }
return h(Tag, { color: status.color }, () => status.label)
},
}, },
{ {
title: 'Uploaded', title: 'Uploaded',
@ -312,13 +319,13 @@ const trainingFileColumns = [
}, },
] ]
onMounted(() => { onMounted(async () => {
fetchOrganization().then(async () => { await auth.fetchSession(true)
await fetchMembers() await fetchOrganization()
await fetchRoles() await fetchMembers()
await fetchUserRoleMemberships() await fetchRoles()
await fetchTrainingFiles() await fetchUserRoleMemberships()
}) await fetchTrainingFiles()
}) })
</script> </script>
@ -404,7 +411,6 @@ onMounted(() => {
<Button <Button
type="default" type="default"
size="small" size="small"
:disabled="isManager"
@click="router.push(`/onboarding/${item.uuid}`)" @click="router.push(`/onboarding/${item.uuid}`)"
> >
Start Onboarding Start Onboarding
@ -418,7 +424,9 @@ onMounted(() => {
> >
Join Role Join Role
</Button> </Button>
<Tag v-else color="success">Joined</Tag> <Button v-else size="small" disabled>
Joined
</Button>
</Space> </Space>
</List.Item> </List.Item>
</template> </template>
@ -436,7 +444,7 @@ onMounted(() => {
width="600px" width="600px"
ok-text="Upload" ok-text="Upload"
cancel-text="Cancel" cancel-text="Cancel"
:ok-button-props="{ loading: uploading, disabled: !selectedFile }" :ok-button-props="{ loading: uploading, disabled: !selectedFile || !selectedRoleUuid }"
@ok="handleFileUploadClick" @ok="handleFileUploadClick"
@cancel="showUploadModal = false" @cancel="showUploadModal = false"
> >
@ -444,6 +452,15 @@ onMounted(() => {
<Typography.Text> <Typography.Text>
Supported formats: <strong>txt, pdf, md, csv, json, docx, doc</strong> (Max 50MB) Supported formats: <strong>txt, pdf, md, csv, json, docx, doc</strong> (Max 50MB)
</Typography.Text> </Typography.Text>
<div>
<Typography.Text strong>Role</Typography.Text>
<Select
v-model:value="selectedRoleUuid"
placeholder="Select a role"
style="width: 100%"
:options="roles.map((role) => ({ label: role.name, value: role.uuid }))"
/>
</div>
<Upload.Dragger <Upload.Dragger
accept=".txt,.pdf,.md,.csv,.json,.docx,.doc" accept=".txt,.pdf,.md,.csv,.json,.docx,.doc"
:before-upload="(file) => { :before-upload="(file) => {