Added mcp rag, onboarding app, frontend for learning and agent inference
This commit is contained in:
parent
a12d5f906c
commit
2f7b2001d4
41 changed files with 3305 additions and 134 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from django.contrib.admin import ModelAdmin, TabularInline
|
from django.contrib.admin import ModelAdmin, TabularInline
|
||||||
from apps.mlstore.models import AgentModel, AgentRun, Agent, AgentEvent
|
from apps.mlstore.models import AgentModel, AgentRun, Agent, AgentEvent, RoleRagDocument
|
||||||
|
|
||||||
|
|
||||||
class AgentInline(TabularInline):
|
class AgentInline(TabularInline):
|
||||||
|
|
@ -55,3 +55,12 @@ class AgentEventAdmin(ModelAdmin):
|
||||||
search_fields = ('event_type', 'execution__uuid', 'execution__agent__model__name')
|
search_fields = ('event_type', 'execution__uuid', 'execution__agent__model__name')
|
||||||
list_filter = ('event_type',)
|
list_filter = ('event_type',)
|
||||||
raw_id_fields = ('execution',)
|
raw_id_fields = ('execution',)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(RoleRagDocument)
|
||||||
|
class RoleRagDocumentAdmin(ModelAdmin):
|
||||||
|
list_display = ('id', 'uuid', 'role', 'training_file', 'chunk_index', 'is_active', 'created_at')
|
||||||
|
search_fields = ('role__name', 'training_file__file_name')
|
||||||
|
list_filter = ('is_active', 'created_at')
|
||||||
|
raw_id_fields = ('role', 'training_file')
|
||||||
|
readonly_fields = ('uuid', 'created_at', 'updated_at')
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,8 @@ class MLStoreConsumer(AsyncWebsocketConsumer):
|
||||||
await self.handle_fine_tune(data)
|
await self.handle_fine_tune(data)
|
||||||
elif action == "infer":
|
elif action == "infer":
|
||||||
await self.handle_infer(data)
|
await self.handle_infer(data)
|
||||||
|
elif action == "onboarding_progress":
|
||||||
|
await self.handle_onboarding_progress(data)
|
||||||
elif action in ("stop_agent", "stop"):
|
elif action in ("stop_agent", "stop"):
|
||||||
await self.handle_stop(data)
|
await self.handle_stop(data)
|
||||||
else:
|
else:
|
||||||
|
|
@ -90,6 +92,16 @@ class MLStoreConsumer(AsyncWebsocketConsumer):
|
||||||
return
|
return
|
||||||
|
|
||||||
input_data = data.get("input_data") or {}
|
input_data = data.get("input_data") or {}
|
||||||
|
role_uuid = input_data.get("role_uuid")
|
||||||
|
if not role_uuid:
|
||||||
|
options = input_data.get("options") or {}
|
||||||
|
role_uuid = options.get("role_uuid")
|
||||||
|
if not role_uuid:
|
||||||
|
await self.send(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "role_uuid is required for inference to enable RAG"
|
||||||
|
}))
|
||||||
|
return
|
||||||
execution = await self.create_run(agent, self.user, input_data)
|
execution = await self.create_run(agent, self.user, input_data)
|
||||||
|
|
||||||
await self.send(json.dumps({
|
await self.send(json.dumps({
|
||||||
|
|
@ -125,6 +137,35 @@ class MLStoreConsumer(AsyncWebsocketConsumer):
|
||||||
"error_message": "Execution stopped by user"
|
"error_message": "Execution stopped by user"
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
async def handle_onboarding_progress(self, data):
|
||||||
|
execution_id = data.get("execution_id")
|
||||||
|
if not execution_id:
|
||||||
|
await self.send(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "execution_id required for onboarding_progress"
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
execution = await self.get_execution(execution_id)
|
||||||
|
if not execution:
|
||||||
|
await self.send(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Execution not found"
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
content = data.get("content") or data.get("progress") or {}
|
||||||
|
await self.create_event(execution, "progress", content)
|
||||||
|
await self.channel_layer.group_send(
|
||||||
|
self.room_group_name,
|
||||||
|
{
|
||||||
|
"type": "mlstore_event",
|
||||||
|
"event_type": "progress",
|
||||||
|
"content": content,
|
||||||
|
"timestamp": timezone.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
async def mlstore_event(self, event):
|
async def mlstore_event(self, event):
|
||||||
await self.send(json.dumps({
|
await self.send(json.dumps({
|
||||||
"type": "mlstore_event",
|
"type": "mlstore_event",
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,22 @@ import django.db.models.deletion
|
||||||
import uuid
|
import uuid
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
from pgvector.django import VectorField
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vector_extension(apps, schema_editor):
|
||||||
|
if schema_editor.connection.vendor != 'postgresql':
|
||||||
|
return
|
||||||
|
with schema_editor.connection.cursor() as cursor:
|
||||||
|
cursor.execute('CREATE EXTENSION IF NOT EXISTS vector')
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_vector_extension(apps, schema_editor):
|
||||||
|
if schema_editor.connection.vendor != 'postgresql':
|
||||||
|
return
|
||||||
|
with schema_editor.connection.cursor() as cursor:
|
||||||
|
cursor.execute('DROP EXTENSION IF EXISTS vector')
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
|
@ -13,6 +29,10 @@ class Migration(migrations.Migration):
|
||||||
]
|
]
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
|
migrations.RunPython(
|
||||||
|
code=_create_vector_extension,
|
||||||
|
reverse_code=_drop_vector_extension,
|
||||||
|
),
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
name='AgentModel',
|
name='AgentModel',
|
||||||
fields=[
|
fields=[
|
||||||
|
|
@ -54,7 +74,7 @@ class Migration(migrations.Migration):
|
||||||
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||||
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
||||||
('status', models.CharField(choices=[('queued', 'Queued'), ('running', 'Running'), ('completed', 'Completed'), ('failed', 'Failed')], default='queued', max_length=20)),
|
('status', models.CharField(choices=[('queued', 'Queued'), ('running', 'Running'), ('completed', 'Completed'), ('failed', 'Failed')], default='queued', max_length=20)),
|
||||||
('input_data', models.JSONField(default=dict)),
|
('input_data', models.JSONField(blank=True, default=dict)),
|
||||||
('output_data', models.JSONField(blank=True, default=dict)),
|
('output_data', models.JSONField(blank=True, default=dict)),
|
||||||
('error_message', models.TextField(blank=True, default='')),
|
('error_message', models.TextField(blank=True, default='')),
|
||||||
('started_at', models.DateTimeField(blank=True, null=True)),
|
('started_at', models.DateTimeField(blank=True, null=True)),
|
||||||
|
|
@ -83,4 +103,25 @@ class Migration(migrations.Migration):
|
||||||
'ordering': ['timestamp'],
|
'ordering': ['timestamp'],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='RoleRagDocument',
|
||||||
|
fields=[
|
||||||
|
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
|
||||||
|
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
|
||||||
|
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||||
|
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
||||||
|
('content', models.TextField()),
|
||||||
|
('content_hash', models.CharField(db_index=True, max_length=64)),
|
||||||
|
('embedding', VectorField(blank=True, dimensions=1536, null=True)),
|
||||||
|
('metadata', models.JSONField(blank=True, default=dict)),
|
||||||
|
('chunk_index', models.IntegerField(default=0)),
|
||||||
|
('is_active', models.BooleanField(default=True)),
|
||||||
|
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='orgs.role')),
|
||||||
|
('training_file', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='orgs.trainingfile')),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Role RAG Document',
|
||||||
|
'verbose_name_plural': 'Role RAG Documents',
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
('mlstore', '0001_initial'),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.AlterField(
|
|
||||||
model_name='agentrun',
|
|
||||||
name='input_data',
|
|
||||||
field=models.JSONField(blank=True, default=dict),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from django.db.models import BigAutoField, CASCADE, CharField, DateTimeField, ForeignKey, JSONField, Model, TextField, UUIDField
|
from django.db.models import BigAutoField, BooleanField, CASCADE, CharField, DateTimeField, ForeignKey, JSONField, Model, TextField, UUIDField, IntegerField
|
||||||
|
from pgvector.django import VectorField
|
||||||
from apps.users.mixins import TimeStampMixin
|
from apps.users.mixins import TimeStampMixin
|
||||||
from apps.users.models import User
|
from apps.users.models import User
|
||||||
from apps.orgs.models import Organization
|
from apps.orgs.models import Organization, Role, TrainingFile
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
class AgentModel(Model):
|
class AgentModel(Model):
|
||||||
|
|
@ -100,3 +101,25 @@ class AgentEvent(Model):
|
||||||
ordering = ['timestamp']
|
ordering = ['timestamp']
|
||||||
verbose_name = "Agent Event"
|
verbose_name = "Agent Event"
|
||||||
verbose_name_plural = "Agent Events"
|
verbose_name_plural = "Agent Events"
|
||||||
|
|
||||||
|
|
||||||
|
class RoleRagDocument(TimeStampMixin, Model):
|
||||||
|
|
||||||
|
id = BigAutoField(primary_key = True)
|
||||||
|
uuid = UUIDField(default = uuid4, editable = False, unique = True)
|
||||||
|
role = ForeignKey(Role, on_delete = CASCADE, related_name = 'rag_documents')
|
||||||
|
training_file = ForeignKey(TrainingFile, on_delete = CASCADE, related_name = 'rag_documents', null = True, blank = True)
|
||||||
|
|
||||||
|
content = TextField()
|
||||||
|
content_hash = CharField(max_length = 64, db_index = True)
|
||||||
|
embedding = VectorField(dimensions = 1536, null = True, blank = True)
|
||||||
|
metadata = JSONField(default = dict, blank = True)
|
||||||
|
chunk_index = IntegerField(default = 0)
|
||||||
|
is_active = BooleanField(default = True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = "Role RAG Document"
|
||||||
|
verbose_name_plural = "Role RAG Documents"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.role.name} - chunk {self.chunk_index}"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from rest_framework.serializers import ModelSerializer
|
from rest_framework.serializers import ModelSerializer
|
||||||
from .models import AgentModel, Agent, AgentRun, AgentEvent
|
from .models import AgentModel, Agent, AgentRun, AgentEvent
|
||||||
|
from apps.orgs.serializers import OrganizationSerializer
|
||||||
|
|
||||||
|
|
||||||
class AgentModelSerializer(ModelSerializer):
|
class AgentModelSerializer(ModelSerializer):
|
||||||
|
|
@ -11,6 +12,7 @@ class AgentModelSerializer(ModelSerializer):
|
||||||
|
|
||||||
class AgentSerializer(ModelSerializer):
|
class AgentSerializer(ModelSerializer):
|
||||||
model = AgentModelSerializer(read_only=True)
|
model = AgentModelSerializer(read_only=True)
|
||||||
|
organization = OrganizationSerializer(read_only=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Agent
|
model = Agent
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,18 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Optional
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from mcp_agent.mcp_client import MCPClient
|
from mcp_agent.mcp_client import MCPClient
|
||||||
from .models import AgentModel
|
from .models import AgentModel, RoleRagDocument
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Get reference to the base model cache directory
|
|
||||||
try:
|
try:
|
||||||
from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR
|
from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR
|
||||||
BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR
|
BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback: construct the path manually
|
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")
|
BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")
|
||||||
|
|
||||||
|
|
@ -67,7 +66,6 @@ def fine_tune_model(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e) if str(e) else f"Unknown error: {type(e).__name__}"
|
error_msg = str(e) if str(e) else f"Unknown error: {type(e).__name__}"
|
||||||
logger.error(f"Fine-tune failed: {error_msg}", exc_info=True)
|
logger.error(f"Fine-tune failed: {error_msg}", exc_info=True)
|
||||||
# Return a failed response instead of raising
|
|
||||||
return {
|
return {
|
||||||
"status": "failed",
|
"status": "failed",
|
||||||
"error": error_msg,
|
"error": error_msg,
|
||||||
|
|
@ -114,3 +112,294 @@ def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel
|
||||||
NOTE: migrations are required after the model field change prior to using this in production.
|
NOTE: migrations are required after the model field change prior to using this in production.
|
||||||
"""
|
"""
|
||||||
return AgentModel.objects.create(name=name, version=version, path=model_path)
|
return AgentModel.objects.create(name=name, version=version, path=model_path)
|
||||||
|
|
||||||
|
|
||||||
|
def embed_texts(texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Generate embeddings for texts using the MCP embedding service.
|
||||||
|
|
||||||
|
Falls back to local sentence-transformers if MCP unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of text strings to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embedding vectors (list of floats).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If both MCP and local embedding fail.
|
||||||
|
"""
|
||||||
|
logger.info(f"Embedding {len(texts)} texts")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = asyncio.run(_call_mcp("embed", {"texts": texts}))
|
||||||
|
embeddings = result.get("embeddings", [])
|
||||||
|
if embeddings and len(embeddings) == len(texts):
|
||||||
|
logger.info(f"Successfully embedded {len(texts)} texts via MCP")
|
||||||
|
return embeddings
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"MCP embedding failed, trying local fallback: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
|
embeddings = model.encode(texts).tolist()
|
||||||
|
logger.info(f"Successfully embedded {len(texts)} texts via local model")
|
||||||
|
return embeddings
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Local embedding also failed: {e}")
|
||||||
|
raise RuntimeError(f"Failed to embed texts: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def embed_text(text: str) -> List[float]:
|
||||||
|
"""Generate embedding for a single text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text string to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding vector (list of floats).
|
||||||
|
"""
|
||||||
|
return embed_texts([text])[0]
|
||||||
|
|
||||||
|
|
||||||
|
def search_similar_documents(
|
||||||
|
query: str,
|
||||||
|
role_uuid: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
similarity_threshold: float = 0.0,
|
||||||
|
) -> List[Tuple[RoleRagDocument, float]]:
|
||||||
|
"""Search for documents similar to the query using vector similarity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query text to embed and search for.
|
||||||
|
role_uuid: UUID of role to scope search.
|
||||||
|
top_k: Number of top results to return.
|
||||||
|
similarity_threshold: Minimum similarity score (0-1) to include results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (RoleRagDocument, similarity_score) tuples, ordered by similarity DESC.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If role not found or embedding fails.
|
||||||
|
"""
|
||||||
|
from apps.orgs.models import Role
|
||||||
|
|
||||||
|
try:
|
||||||
|
query_embedding = embed_text(query)
|
||||||
|
logger.info(f"Embedded query: '{query[:50]}...' to {len(query_embedding)}D vector")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to embed query: {e}")
|
||||||
|
raise ValueError(f"Failed to embed query: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
role = Role.objects.get(uuid=role_uuid)
|
||||||
|
except Role.DoesNotExist:
|
||||||
|
raise ValueError(f"Role with UUID {role_uuid} not found")
|
||||||
|
|
||||||
|
queryset = RoleRagDocument.objects.filter(
|
||||||
|
role=role,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not queryset.exists():
|
||||||
|
logger.warning(f"No documents with embeddings found for role {role.uuid}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
from django.db import connection
|
||||||
|
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
query_sql = """
|
||||||
|
SELECT id, 1 - (embedding <=> %s::vector) as similarity
|
||||||
|
FROM mlstore_roleragdocument
|
||||||
|
WHERE role_id = %s AND embedding IS NOT NULL
|
||||||
|
ORDER BY similarity DESC
|
||||||
|
LIMIT %s
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
query_sql,
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_ids_with_scores = cursor.fetchall()
|
||||||
|
|
||||||
|
if not doc_ids_with_scores:
|
||||||
|
logger.info(f"No similar documents found for query in role {role.uuid}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
filtered_docs = [
|
||||||
|
(doc_id, score)
|
||||||
|
for doc_id, score in doc_ids_with_scores
|
||||||
|
if score >= similarity_threshold
|
||||||
|
][:top_k]
|
||||||
|
|
||||||
|
if not filtered_docs:
|
||||||
|
logger.info(
|
||||||
|
f"No documents met similarity threshold {similarity_threshold}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
doc_ids = [doc_id for doc_id, _ in filtered_docs]
|
||||||
|
doc_scores = {doc_id: score for doc_id, score in filtered_docs}
|
||||||
|
|
||||||
|
documents = RoleRagDocument.objects.filter(id__in=doc_ids)
|
||||||
|
results = [
|
||||||
|
(doc, doc_scores[doc.id])
|
||||||
|
for doc in documents
|
||||||
|
if doc.id in doc_scores
|
||||||
|
]
|
||||||
|
|
||||||
|
results.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Found {len(results)} similar documents for query "
|
||||||
|
f"(threshold={similarity_threshold}, top_k={top_k})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def batch_embed_documents(
|
||||||
|
documents: List[RoleRagDocument],
|
||||||
|
batch_size: int = 32,
|
||||||
|
force_reembed: bool = False,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""Batch embed documents that don't have embeddings yet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: List of RoleRagDocument instances to embed.
|
||||||
|
batch_size: Number of documents to embed per API call.
|
||||||
|
force_reembed: If True, re-embed documents that already have embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (num_embedded, num_failed).
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Updates documents in-place with embedding values.
|
||||||
|
"""
|
||||||
|
to_embed = [
|
||||||
|
doc for doc in documents
|
||||||
|
if force_reembed or not doc.embedding
|
||||||
|
]
|
||||||
|
|
||||||
|
if not to_embed:
|
||||||
|
logger.info("No documents to embed")
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
num_embedded = 0
|
||||||
|
num_failed = 0
|
||||||
|
|
||||||
|
for i in range(0, len(to_embed), batch_size):
|
||||||
|
batch = to_embed[i : i + batch_size]
|
||||||
|
logger.info(
|
||||||
|
f"Embedding batch {i // batch_size + 1} "
|
||||||
|
f"({len(batch)} documents)"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
texts = [doc.content for doc in batch]
|
||||||
|
embeddings = embed_texts(texts)
|
||||||
|
|
||||||
|
for doc, embedding in zip(batch, embeddings):
|
||||||
|
doc.embedding = embedding
|
||||||
|
num_embedded += 1
|
||||||
|
|
||||||
|
RoleRagDocument.objects.bulk_update(batch, ["embedding"], batch_size=500)
|
||||||
|
logger.info(f"Successfully embedded {len(batch)} documents")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to embed batch: {e}")
|
||||||
|
num_failed += len(batch)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Embedding complete: {num_embedded} embedded, {num_failed} failed"
|
||||||
|
)
|
||||||
|
return num_embedded, num_failed
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_for_query(
|
||||||
|
query: str,
|
||||||
|
role_uuid: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
similarity_threshold: float = 0.5,
|
||||||
|
) -> str:
|
||||||
|
"""Get context string from similar documents for a query.
|
||||||
|
|
||||||
|
Useful for augmenting prompts with retrieved context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query text.
|
||||||
|
role_uuid: UUID of role to search within.
|
||||||
|
top_k: Number of top results to include.
|
||||||
|
similarity_threshold: Minimum similarity score.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string with source attribution.
|
||||||
|
"""
|
||||||
|
def _clean_chunk_text(text: str) -> str:
|
||||||
|
"""Strip junk and deduplicate paragraphs to keep context lean."""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text = re.sub(r"\[\s*Answer\s*:.*?\]", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for raw_line in text.splitlines():
|
||||||
|
line = raw_line.strip()
|
||||||
|
if not line:
|
||||||
|
lines.append("")
|
||||||
|
continue
|
||||||
|
|
||||||
|
lower = line.lower()
|
||||||
|
|
||||||
|
if line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if "do you have any questions" in lower:
|
||||||
|
continue
|
||||||
|
if "feel free to ask" in lower:
|
||||||
|
continue
|
||||||
|
if "references" in lower or "sources" in lower or "wikipedia" in lower:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines.append(line)
|
||||||
|
|
||||||
|
cleaned = "\n".join(lines)
|
||||||
|
|
||||||
|
paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", cleaned) if p.strip()]
|
||||||
|
seen = set()
|
||||||
|
unique_paragraphs: List[str] = []
|
||||||
|
for para in paragraphs:
|
||||||
|
if para in seen:
|
||||||
|
continue
|
||||||
|
seen.add(para)
|
||||||
|
unique_paragraphs.append(para)
|
||||||
|
|
||||||
|
return "\n\n".join(unique_paragraphs)
|
||||||
|
try:
|
||||||
|
results = search_similar_documents(
|
||||||
|
query=query,
|
||||||
|
role_uuid=role_uuid,
|
||||||
|
top_k=top_k,
|
||||||
|
similarity_threshold=similarity_threshold,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to retrieve context: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
context_parts = []
|
||||||
|
for doc, similarity in results:
|
||||||
|
cleaned = _clean_chunk_text(doc.content)
|
||||||
|
if not cleaned:
|
||||||
|
continue
|
||||||
|
|
||||||
|
source = "unknown"
|
||||||
|
if doc.training_file:
|
||||||
|
source = doc.training_file.file_name
|
||||||
|
|
||||||
|
context_parts.append(
|
||||||
|
f"[Source: {source}, Similarity: {similarity:.2%}]\n{cleaned}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
context = "\n---\n".join(context_parts)
|
||||||
|
return context
|
||||||
|
|
@ -1,16 +1,128 @@
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from hashlib import sha256
|
||||||
from asgiref.sync import async_to_sync
|
from asgiref.sync import async_to_sync
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from channels.layers import get_channel_layer
|
from channels.layers import get_channel_layer
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from django.db import transaction
|
||||||
|
|
||||||
from apps.orgs.models import TrainingFile
|
from apps.orgs.models import TrainingFile, Role
|
||||||
from . import services
|
from . import services
|
||||||
from .models import Agent, AgentEvent, AgentModel, AgentRun
|
from .models import Agent, AgentEvent, AgentModel, AgentRun, RoleRagDocument
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mem_info() -> str:
|
||||||
|
try:
|
||||||
|
with open('/proc/self/status', 'r', encoding='utf-8') as f:
|
||||||
|
lines = f.read().splitlines()
|
||||||
|
mem = {line.split(':', 1)[0]: line.split(':', 1)[1].strip() for line in lines if ':' in line}
|
||||||
|
return f"VmRSS={mem.get('VmRSS','?')}, VmHWM={mem.get('VmHWM','?')}, VmSize={mem.get('VmSize','?')}"
|
||||||
|
except Exception:
|
||||||
|
return "mem_info_unavailable"
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens(text: str) -> int:
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
return len(re.findall(r"\w+|[^\s\w]", text))
|
||||||
|
|
||||||
|
|
||||||
|
def _split_semantic_units(text: str) -> list[str]:
|
||||||
|
paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
|
||||||
|
units: list[str] = []
|
||||||
|
for para in paragraphs:
|
||||||
|
sentences = re.split(r"(?<=[.!?])\s+", para)
|
||||||
|
for sent in sentences:
|
||||||
|
sent = sent.strip()
|
||||||
|
if sent:
|
||||||
|
units.append(sent)
|
||||||
|
return units or paragraphs
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_text(text: str, max_tokens: int = 400, overlap_tokens: int = 50) -> list[str]:
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
units = _split_semantic_units(text)
|
||||||
|
logger.info(
|
||||||
|
"Semantic chunking units=%s max_tokens=%s overlap_tokens=%s mem=%s",
|
||||||
|
len(units),
|
||||||
|
max_tokens,
|
||||||
|
overlap_tokens,
|
||||||
|
_get_mem_info(),
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
current: list[str] = []
|
||||||
|
current_tokens = 0
|
||||||
|
|
||||||
|
for unit in units:
|
||||||
|
unit_tokens = _estimate_tokens(unit)
|
||||||
|
if unit_tokens == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_tokens + unit_tokens > max_tokens and current:
|
||||||
|
chunk = " ".join(current).strip()
|
||||||
|
if chunk:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
if overlap_tokens > 0:
|
||||||
|
overlap: list[str] = []
|
||||||
|
overlap_count = 0
|
||||||
|
for prev in reversed(current):
|
||||||
|
prev_tokens = _estimate_tokens(prev)
|
||||||
|
if overlap_count + prev_tokens > overlap_tokens:
|
||||||
|
break
|
||||||
|
overlap.insert(0, prev)
|
||||||
|
overlap_count += prev_tokens
|
||||||
|
current = overlap
|
||||||
|
current_tokens = overlap_count
|
||||||
|
else:
|
||||||
|
current = []
|
||||||
|
current_tokens = 0
|
||||||
|
|
||||||
|
current.append(unit)
|
||||||
|
current_tokens += unit_tokens
|
||||||
|
|
||||||
|
if current:
|
||||||
|
chunk = " ".join(current).strip()
|
||||||
|
if chunk:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_file(file_path: str, file_type: str | None) -> str:
|
||||||
|
file_type = (file_type or '').lower()
|
||||||
|
if file_type in {'txt', 'md', 'csv', 'json'}:
|
||||||
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
if file_type == 'pdf':
|
||||||
|
try:
|
||||||
|
from PyPDF2 import PdfReader
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError('PyPDF2 is required to parse PDF files') from e
|
||||||
|
reader = PdfReader(file_path)
|
||||||
|
return "\n".join(page.extract_text() or "" for page in reader.pages)
|
||||||
|
|
||||||
|
if file_type in {'docx', 'doc'}:
|
||||||
|
try:
|
||||||
|
import docx
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError('python-docx is required to parse DOCX files') from e
|
||||||
|
doc = docx.Document(file_path)
|
||||||
|
return "\n".join(p.text for p in doc.paragraphs)
|
||||||
|
|
||||||
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
def _send_group_event(room_group_name: str, event_type: str, content: dict):
|
def _send_group_event(room_group_name: str, event_type: str, content: dict):
|
||||||
channel_layer = get_channel_layer()
|
channel_layer = get_channel_layer()
|
||||||
async_to_sync(channel_layer.group_send)(
|
async_to_sync(channel_layer.group_send)(
|
||||||
|
|
@ -68,11 +180,21 @@ def start_fine_tune_run_task(execution_id: str):
|
||||||
|
|
||||||
training_files = input_data.get("training_files") or []
|
training_files = input_data.get("training_files") or []
|
||||||
org_training_files = []
|
org_training_files = []
|
||||||
|
role_uuid = input_data.get("role_uuid")
|
||||||
if not training_files and agent.organization:
|
if not training_files and agent.organization:
|
||||||
org_training_files = list(TrainingFile.objects.filter(
|
training_files_qs = TrainingFile.objects.filter(
|
||||||
organization=agent.organization,
|
role__organization=agent.organization,
|
||||||
is_processed=False
|
is_processed=False
|
||||||
).select_related('uploaded_by'))
|
).select_related('uploaded_by', 'role')
|
||||||
|
|
||||||
|
if role_uuid:
|
||||||
|
try:
|
||||||
|
role = Role.objects.get(uuid=role_uuid, organization=agent.organization)
|
||||||
|
training_files_qs = training_files_qs.filter(role=role)
|
||||||
|
except Role.DoesNotExist:
|
||||||
|
logger.warning(f"Role {role_uuid} not found for organization {agent.organization.name}")
|
||||||
|
|
||||||
|
org_training_files = list(training_files_qs)
|
||||||
training_files = [tf.file.path for tf in org_training_files if tf.file]
|
training_files = [tf.file.path for tf in org_training_files if tf.file]
|
||||||
logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")
|
logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")
|
||||||
|
|
||||||
|
|
@ -186,6 +308,197 @@ def start_fine_tune_run_task(execution_id: str):
|
||||||
return {"status": "error", "execution_id": execution_id, "error": str(e)}
|
return {"status": "error", "execution_id": execution_id, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
def ingest_training_file_task(training_file_uuid: str):
|
||||||
|
logger.info(f"Ingest task started for training_file_uuid={training_file_uuid}")
|
||||||
|
started_at = time.time()
|
||||||
|
try:
|
||||||
|
training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
|
||||||
|
except TrainingFile.DoesNotExist:
|
||||||
|
logger.error(f"Training file not found: {training_file_uuid}")
|
||||||
|
return {"status": "error", "error": "training_file_not_found"}
|
||||||
|
|
||||||
|
if training_file.is_processed:
|
||||||
|
logger.info(f"Training file already processed: {training_file_uuid}")
|
||||||
|
return {"status": "skipped", "reason": "already_processed"}
|
||||||
|
|
||||||
|
if not training_file.file:
|
||||||
|
logger.error(f"Training file has no file attached: {training_file_uuid}")
|
||||||
|
return {"status": "error", "error": "file_missing"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_path = training_file.file.path
|
||||||
|
file_size = os.path.getsize(file_path) if os.path.exists(file_path) else 0
|
||||||
|
logger.info(
|
||||||
|
"Ingesting file: name=%s type=%s size_bytes=%s role=%s path=%s",
|
||||||
|
training_file.file_name,
|
||||||
|
training_file.file_type,
|
||||||
|
file_size,
|
||||||
|
training_file.role_id,
|
||||||
|
file_path,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to stat training file for {training_file_uuid}: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_file.status = 'ingesting'
|
||||||
|
training_file.save(update_fields=['status'])
|
||||||
|
|
||||||
|
extract_started = time.time()
|
||||||
|
text = _extract_text_from_file(training_file.file.path, training_file.file_type)
|
||||||
|
logger.info(
|
||||||
|
"Extracted text length=%s for training_file_uuid=%s in %.2fs mem=%s",
|
||||||
|
len(text),
|
||||||
|
training_file_uuid,
|
||||||
|
time.time() - extract_started,
|
||||||
|
_get_mem_info(),
|
||||||
|
)
|
||||||
|
chunk_started = time.time()
|
||||||
|
chunks = _chunk_text(text)
|
||||||
|
logger.info(
|
||||||
|
"Chunked text into %s chunks in %.2fs (sample lengths: %s) mem=%s",
|
||||||
|
len(chunks),
|
||||||
|
time.time() - chunk_started,
|
||||||
|
[len(c) for c in chunks[:5]],
|
||||||
|
_get_mem_info(),
|
||||||
|
)
|
||||||
|
if not chunks:
|
||||||
|
raise RuntimeError("No text extracted from file")
|
||||||
|
|
||||||
|
with transaction.atomic():
|
||||||
|
logger.info("Clearing existing RAG docs for training_file_uuid=%s mem=%s", training_file_uuid, _get_mem_info())
|
||||||
|
RoleRagDocument.objects.filter(training_file=training_file).delete()
|
||||||
|
logger.info("Preparing %s RAG docs for bulk_create mem=%s", len(chunks), _get_mem_info())
|
||||||
|
existing_hashes = set(
|
||||||
|
RoleRagDocument.objects.filter(role=training_file.role)
|
||||||
|
.values_list('content_hash', flat=True)
|
||||||
|
)
|
||||||
|
documents = []
|
||||||
|
skipped = 0
|
||||||
|
for index, chunk in enumerate(chunks):
|
||||||
|
content_hash = sha256(chunk.encode('utf-8')).hexdigest()
|
||||||
|
if content_hash in existing_hashes:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
documents.append(
|
||||||
|
RoleRagDocument(
|
||||||
|
role=training_file.role,
|
||||||
|
training_file=training_file,
|
||||||
|
content=chunk,
|
||||||
|
embedding=None,
|
||||||
|
content_hash=content_hash,
|
||||||
|
metadata={
|
||||||
|
"file_name": training_file.file_name,
|
||||||
|
"file_type": training_file.file_type,
|
||||||
|
"chunk_size": len(chunk),
|
||||||
|
"source": "training_file",
|
||||||
|
},
|
||||||
|
chunk_index=index,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info("Skipped %s duplicate chunks based on content_hash", skipped)
|
||||||
|
logger.info("Bulk creating RAG docs count=%s mem=%s", len(documents), _get_mem_info())
|
||||||
|
RoleRagDocument.objects.bulk_create(documents, batch_size=500)
|
||||||
|
training_file.status = 'chunked'
|
||||||
|
training_file.is_processed = True
|
||||||
|
training_file.save(update_fields=['status', 'is_processed'])
|
||||||
|
|
||||||
|
elapsed = time.time() - started_at
|
||||||
|
logger.info(
|
||||||
|
"Ingested training file %s into %s RAG chunks in %.2fs",
|
||||||
|
training_file_uuid,
|
||||||
|
len(chunks),
|
||||||
|
elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Enqueueing embedding task for training_file_uuid={training_file_uuid}")
|
||||||
|
embed_training_file_task.delay(training_file_uuid)
|
||||||
|
|
||||||
|
return {"status": "completed", "chunks": len(chunks)}
|
||||||
|
except Exception as e:
|
||||||
|
elapsed = time.time() - started_at
|
||||||
|
logger.error(f"Failed to ingest training file {training_file_uuid}: {str(e)}", exc_info=True)
|
||||||
|
logger.error(f"Ingest task failed after {elapsed:.2f}s for training_file_uuid={training_file_uuid}")
|
||||||
|
try:
|
||||||
|
TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {"status": "error", "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
def embed_training_file_task(training_file_uuid: str):
|
||||||
|
"""Generate embeddings for all documents of a training file.
|
||||||
|
|
||||||
|
This task is called after chunking to embed the document chunks
|
||||||
|
using the configured embedding provider (OpenAI, Google, or local).
|
||||||
|
"""
|
||||||
|
logger.info(f"Embedding task started for training_file_uuid={training_file_uuid}")
|
||||||
|
started_at = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
|
||||||
|
except TrainingFile.DoesNotExist:
|
||||||
|
logger.error(f"Training file not found: {training_file_uuid}")
|
||||||
|
return {"status": "error", "error": "training_file_not_found"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = list(RoleRagDocument.objects.filter(training_file=training_file))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
logger.warning(f"No RAG documents found for training_file_uuid={training_file_uuid}")
|
||||||
|
return {"status": "skipped", "reason": "no_documents"}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Starting to embed {len(documents)} documents for training_file_uuid={training_file_uuid} "
|
||||||
|
f"mem={_get_mem_info()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
num_embedded, num_failed = services.batch_embed_documents(documents, batch_size=32)
|
||||||
|
|
||||||
|
if num_failed == 0:
|
||||||
|
training_file.status = 'embedded'
|
||||||
|
training_file.save(update_fields=['status'])
|
||||||
|
logger.info(f"Successfully embedded all documents for training_file_uuid={training_file_uuid}")
|
||||||
|
elif num_embedded > 0:
|
||||||
|
training_file.status = 'embedded'
|
||||||
|
training_file.save(update_fields=['status'])
|
||||||
|
logger.warning(
|
||||||
|
f"Partially embedded {num_embedded} documents, {num_failed} failed "
|
||||||
|
f"for training_file_uuid={training_file_uuid}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
training_file.status = 'failed'
|
||||||
|
training_file.save(update_fields=['status'])
|
||||||
|
logger.error(f"Failed to embed any documents for training_file_uuid={training_file_uuid}")
|
||||||
|
return {"status": "error", "error": "embedding_failed", "num_failed": num_failed}
|
||||||
|
|
||||||
|
elapsed = time.time() - started_at
|
||||||
|
logger.info(
|
||||||
|
f"Embedding task completed for {training_file_uuid}: "
|
||||||
|
f"embedded={num_embedded}, failed={num_failed}, time={elapsed:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"num_embedded": num_embedded,
|
||||||
|
"num_failed": num_failed,
|
||||||
|
"elapsed": elapsed,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
elapsed = time.time() - started_at
|
||||||
|
logger.error(
|
||||||
|
f"Failed to embed training file {training_file_uuid}: {str(e)}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return {"status": "error", "error": str(e), "elapsed": elapsed}
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
def infer_run_task(execution_id: str):
|
def infer_run_task(execution_id: str):
|
||||||
logger.info(f"Inference run task started for execution: {execution_id}")
|
logger.info(f"Inference run task started for execution: {execution_id}")
|
||||||
|
|
@ -207,9 +520,70 @@ def infer_run_task(execution_id: str):
|
||||||
|
|
||||||
input_data = execution.input_data or {}
|
input_data = execution.input_data or {}
|
||||||
prompt = input_data.get("prompt") or input_data.get("query") or ""
|
prompt = input_data.get("prompt") or input_data.get("query") or ""
|
||||||
options = input_data.get("options") or {}
|
options = dict(input_data.get("options") or {})
|
||||||
|
role_uuid = input_data.get("role_uuid") or options.get("role_uuid")
|
||||||
|
rag_top_k = int(input_data.get("rag_top_k", 5))
|
||||||
|
rag_similarity_threshold = float(input_data.get("rag_similarity_threshold", 0.5))
|
||||||
|
|
||||||
|
options.setdefault("temperature", 0.2)
|
||||||
|
options.setdefault("top_p", 0.9)
|
||||||
|
options.setdefault("max_tokens", 200)
|
||||||
|
options.setdefault("stop", ["\n\n", "References:", "Sources:"])
|
||||||
|
|
||||||
logger.info(f"Prompt length: {len(prompt)} characters")
|
logger.info(f"Prompt length: {len(prompt)} characters")
|
||||||
|
|
||||||
|
if not role_uuid:
|
||||||
|
logger.warning(f"No role_uuid provided for inference run {execution_id}")
|
||||||
|
execution.status = "failed"
|
||||||
|
execution.error_message = "role_uuid_required"
|
||||||
|
execution.completed_at = timezone.now()
|
||||||
|
execution.save()
|
||||||
|
_update_agent_status(agent, "failed")
|
||||||
|
_send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": "role_uuid_required"})
|
||||||
|
_persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": "role_uuid_required"})
|
||||||
|
async_to_sync(get_channel_layer().group_send)(
|
||||||
|
room_group_name,
|
||||||
|
{
|
||||||
|
"type": "mlstore_error",
|
||||||
|
"execution_id": str(execution.uuid),
|
||||||
|
"error_message": "role_uuid_required",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return {"status": "failed", "execution_id": execution_id, "error": "role_uuid_required"}
|
||||||
|
|
||||||
|
if role_uuid and prompt:
|
||||||
|
try:
|
||||||
|
context = services.get_context_for_query(
|
||||||
|
query=prompt,
|
||||||
|
role_uuid=str(role_uuid),
|
||||||
|
top_k=rag_top_k,
|
||||||
|
similarity_threshold=rag_similarity_threshold,
|
||||||
|
)
|
||||||
|
if context:
|
||||||
|
logger.info(f"RAG context retrieved for role={role_uuid} (top_k={rag_top_k})")
|
||||||
|
prompt = (
|
||||||
|
"You are a technical assistant.\n\n"
|
||||||
|
"Answer the question using ONLY the information in the context.\n"
|
||||||
|
"Do NOT:\n"
|
||||||
|
"- ask follow-up questions\n"
|
||||||
|
"- include hashtags\n"
|
||||||
|
"- include references or sources\n"
|
||||||
|
"- repeat the question\n"
|
||||||
|
"- add headings or sections\n"
|
||||||
|
"- add information not present in the context\n\n"
|
||||||
|
"Answer in 3-6 concise sentences.\n"
|
||||||
|
"If the context is insufficient, say: \"The context does not provide enough information.\"\n\n"
|
||||||
|
"Context:\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
"Question:\n"
|
||||||
|
f"{prompt}\n\n"
|
||||||
|
"Answer:"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"No RAG context found for role={role_uuid}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"RAG context retrieval failed for role={role_uuid}: {e}")
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
logger.warning(f"No prompt provided for inference run {execution_id}")
|
logger.warning(f"No prompt provided for inference run {execution_id}")
|
||||||
execution.status = "failed"
|
execution.status = "failed"
|
||||||
|
|
|
||||||
0
apps/mlstore/tests/__init__.py
Normal file
0
apps/mlstore/tests/__init__.py
Normal file
91
apps/mlstore/tests/test_api.py
Normal file
91
apps/mlstore/tests/test_api.py
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
from unittest.mock import patch
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.test import TestCase
|
||||||
|
from rest_framework.test import APIRequestFactory, force_authenticate
|
||||||
|
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
from apps.orgs.models import Organization, Role
|
||||||
|
from apps.mlstore.models import AgentModel, Agent, AgentRun, AgentEvent, RoleRagDocument
|
||||||
|
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
|
class MLStoreAPITests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = APIRequestFactory()
|
||||||
|
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
|
||||||
|
self.other = User.objects.create_user(email_address='other@example.com', password='pass')
|
||||||
|
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
|
||||||
|
self.org = Organization.objects.create(name='Org', owner=self.manager)
|
||||||
|
self.role = Role.objects.create(name='Engineer', organization=self.org)
|
||||||
|
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
|
||||||
|
self.agent = Agent.objects.create(model=self.model, organization=self.org)
|
||||||
|
|
||||||
|
def test_agents_list_requires_auth(self):
|
||||||
|
view = AgentViewSet.as_view({'get': 'list'})
|
||||||
|
request = self.factory.get('/')
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
def test_agents_list_authenticated(self):
|
||||||
|
view = AgentViewSet.as_view({'get': 'list'})
|
||||||
|
request = self.factory.get('/')
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
|
||||||
|
def test_agent_runs_scoped_to_user(self):
|
||||||
|
AgentRun.objects.create(agent=self.agent, user=self.user)
|
||||||
|
AgentRun.objects.create(agent=self.agent, user=self.other)
|
||||||
|
view = AgentRunViewSet.as_view({'get': 'list'})
|
||||||
|
request = self.factory.get('/')
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertEqual(len(response.data), 1)
|
||||||
|
|
||||||
|
def test_agent_run_events(self):
|
||||||
|
run = AgentRun.objects.create(agent=self.agent, user=self.user)
|
||||||
|
AgentEvent.objects.create(execution=run, event_type='message', content={'msg': 'hi'})
|
||||||
|
view = AgentRunViewSet.as_view({'get': 'events'})
|
||||||
|
request = self.factory.get('/')
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request, uuid=str(run.uuid))
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertEqual(len(response.data), 1)
|
||||||
|
|
||||||
|
def test_retrieve_context_missing_params(self):
|
||||||
|
view = AgentRunViewSet.as_view({'post': 'retrieve_context'})
|
||||||
|
request = self.factory.post('/', {})
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
def test_retrieve_context_role_not_found(self):
|
||||||
|
view = AgentRunViewSet.as_view({'post': 'retrieve_context'})
|
||||||
|
request = self.factory.post('/', {'query': 'q', 'role_uuid': '00000000-0000-0000-0000-000000000000'})
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
@patch('apps.mlstore.viewsets.services.search_similar_documents')
|
||||||
|
@patch('apps.mlstore.viewsets.services.get_context_for_query')
|
||||||
|
def test_retrieve_context_success(self, mock_context, mock_search):
|
||||||
|
doc = RoleRagDocument.objects.create(
|
||||||
|
role=self.role,
|
||||||
|
content='chunk',
|
||||||
|
content_hash='hash',
|
||||||
|
chunk_index=0,
|
||||||
|
)
|
||||||
|
mock_search.return_value = [(doc, 0.9)]
|
||||||
|
mock_context.return_value = 'context text'
|
||||||
|
|
||||||
|
view = AgentRunViewSet.as_view({'post': 'retrieve_context'})
|
||||||
|
payload = {'query': 'hello', 'role_uuid': str(self.role.uuid)}
|
||||||
|
request = self.factory.post('/', payload, format='json')
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data.get('num_results'), 1)
|
||||||
|
self.assertEqual(response.data.get('context'), 'context text')
|
||||||
41
apps/mlstore/tests/test_models.py
Normal file
41
apps/mlstore/tests/test_models.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from apps.orgs.models import Organization, Role
|
||||||
|
from apps.mlstore.models import AgentModel, Agent, AgentRun, AgentEvent, RoleRagDocument
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
|
class MLStoreModelTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
|
||||||
|
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
|
||||||
|
self.org = Organization.objects.create(name='Org', owner=self.manager)
|
||||||
|
self.role = Role.objects.create(name='Engineer', organization=self.org)
|
||||||
|
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
|
||||||
|
self.agent = Agent.objects.create(model=self.model, organization=self.org)
|
||||||
|
|
||||||
|
def test_agent_model_str(self):
|
||||||
|
self.assertEqual(str(self.model), 'test-model')
|
||||||
|
|
||||||
|
def test_agent_str(self):
|
||||||
|
self.assertIn(self.model.name, str(self.agent))
|
||||||
|
|
||||||
|
def test_agent_run_str(self):
|
||||||
|
run = AgentRun.objects.create(agent=self.agent, user=self.user)
|
||||||
|
self.assertIn(str(run.uuid), str(run))
|
||||||
|
self.assertIn(str(self.agent), str(run))
|
||||||
|
|
||||||
|
def test_agent_event_str(self):
|
||||||
|
run = AgentRun.objects.create(agent=self.agent, user=self.user)
|
||||||
|
evt = AgentEvent.objects.create(execution=run, event_type='message', content={'msg': 'hi'})
|
||||||
|
self.assertIn('message', str(evt))
|
||||||
|
|
||||||
|
def test_role_rag_document_str(self):
|
||||||
|
doc = RoleRagDocument.objects.create(
|
||||||
|
role=self.role,
|
||||||
|
content='chunk',
|
||||||
|
content_hash='hash',
|
||||||
|
chunk_index=0,
|
||||||
|
)
|
||||||
|
self.assertIn(self.role.name, str(doc))
|
||||||
|
|
@ -1,9 +1,13 @@
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
from rest_framework.permissions import IsAuthenticated
|
from rest_framework.permissions import IsAuthenticated
|
||||||
from .models import Agent, AgentRun, AgentEvent
|
|
||||||
from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer
|
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
from rest_framework import status
|
||||||
|
from .models import Agent, AgentRun, AgentEvent
|
||||||
|
from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer
|
||||||
|
from . import services
|
||||||
|
from apps.orgs.models import Role
|
||||||
|
|
||||||
|
|
||||||
class AgentViewSet(ModelViewSet):
|
class AgentViewSet(ModelViewSet):
|
||||||
queryset = Agent.objects.all()
|
queryset = Agent.objects.all()
|
||||||
|
|
@ -27,3 +31,109 @@ class AgentRunViewSet(ModelViewSet):
|
||||||
events = AgentEvent.objects.filter(execution=run).order_by('timestamp')
|
events = AgentEvent.objects.filter(execution=run).order_by('timestamp')
|
||||||
serializer = AgentEventSerializer(events, many=True)
|
serializer = AgentEventSerializer(events, many=True)
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
@action(detail=False, methods=['post'], url_path='retrieve-context')
|
||||||
|
def retrieve_context(self, request):
|
||||||
|
"""Retrieve context documents from RAG using semantic similarity.
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
{
|
||||||
|
"query": "search query text",
|
||||||
|
"role_uuid": "role-uuid",
|
||||||
|
"top_k": 5, # optional, default 5
|
||||||
|
"similarity_threshold": 0.5 # optional, default 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"query": "search query text",
|
||||||
|
"context": "formatted context string with sources",
|
||||||
|
"documents": [
|
||||||
|
{
|
||||||
|
"id": 123,
|
||||||
|
"content": "chunk text",
|
||||||
|
"similarity": 0.87,
|
||||||
|
"source": "filename.pdf",
|
||||||
|
"chunk_index": 0
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
query = request.data.get('query', '').strip()
|
||||||
|
role_uuid = request.data.get('role_uuid', '').strip()
|
||||||
|
top_k = request.data.get('top_k', 5)
|
||||||
|
similarity_threshold = request.data.get('similarity_threshold', 0.5)
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return Response(
|
||||||
|
{"error": "query is required"},
|
||||||
|
status=status.HTTP_400_BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
if not role_uuid:
|
||||||
|
return Response(
|
||||||
|
{"error": "role_uuid is required"},
|
||||||
|
status=status.HTTP_400_BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Validate role exists and user has access
|
||||||
|
role = Role.objects.get(uuid=role_uuid)
|
||||||
|
# You can add additional permission checks here if needed
|
||||||
|
|
||||||
|
# Search for similar documents
|
||||||
|
results = services.search_similar_documents(
|
||||||
|
query=query,
|
||||||
|
role_uuid=role_uuid,
|
||||||
|
top_k=top_k,
|
||||||
|
similarity_threshold=similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format response
|
||||||
|
documents = []
|
||||||
|
for doc, similarity in results:
|
||||||
|
documents.append({
|
||||||
|
"id": doc.id,
|
||||||
|
"uuid": str(doc.uuid),
|
||||||
|
"content": doc.content,
|
||||||
|
"similarity": float(similarity),
|
||||||
|
"source": doc.training_file.file_name if doc.training_file else "unknown",
|
||||||
|
"chunk_index": doc.chunk_index,
|
||||||
|
"metadata": doc.metadata,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Get formatted context string
|
||||||
|
context = services.get_context_for_query(
|
||||||
|
query=query,
|
||||||
|
role_uuid=role_uuid,
|
||||||
|
top_k=top_k,
|
||||||
|
similarity_threshold=similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response({
|
||||||
|
"query": query,
|
||||||
|
"role_uuid": role_uuid,
|
||||||
|
"num_results": len(documents),
|
||||||
|
"context": context,
|
||||||
|
"documents": documents,
|
||||||
|
})
|
||||||
|
|
||||||
|
except Role.DoesNotExist:
|
||||||
|
return Response(
|
||||||
|
{"error": f"Role with UUID {role_uuid} not found"},
|
||||||
|
status=status.HTTP_404_NOT_FOUND
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
return Response(
|
||||||
|
{"error": str(e)},
|
||||||
|
status=status.HTTP_400_BAD_REQUEST
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import logging
|
||||||
|
logging.exception("Error retrieving context")
|
||||||
|
return Response(
|
||||||
|
{"error": "Failed to retrieve context", "detail": str(e)},
|
||||||
|
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
|
|
||||||
0
apps/onboarding/__init__.py
Normal file
0
apps/onboarding/__init__.py
Normal file
47
apps/onboarding/admin.py
Normal file
47
apps/onboarding/admin.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
from django.contrib import admin
|
||||||
|
from django.contrib.admin import ModelAdmin, TabularInline
|
||||||
|
from .models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingPageInline(TabularInline):
|
||||||
|
model = OnboardingPage
|
||||||
|
extra = 0
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFieldInline(TabularInline):
|
||||||
|
model = OnboardingField
|
||||||
|
extra = 0
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(OnboardingFlow)
|
||||||
|
class OnboardingFlowAdmin(ModelAdmin):
|
||||||
|
list_display = ('id', 'uuid', 'title', 'role', 'status')
|
||||||
|
search_fields = ('title', 'role__name')
|
||||||
|
list_filter = ('status',)
|
||||||
|
inlines = (OnboardingPageInline,)
|
||||||
|
readonly_fields = ('uuid',)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(OnboardingPage)
|
||||||
|
class OnboardingPageAdmin(ModelAdmin):
|
||||||
|
list_display = ('id', 'uuid', 'title', 'flow', 'order')
|
||||||
|
search_fields = ('title', 'flow__title')
|
||||||
|
list_filter = ('flow',)
|
||||||
|
inlines = (OnboardingFieldInline,)
|
||||||
|
readonly_fields = ('uuid',)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(OnboardingField)
|
||||||
|
class OnboardingFieldAdmin(ModelAdmin):
|
||||||
|
list_display = ('id', 'uuid', 'label', 'page', 'field_type', 'required')
|
||||||
|
search_fields = ('label', 'page__title')
|
||||||
|
list_filter = ('field_type',)
|
||||||
|
readonly_fields = ('uuid',)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(OnboardingSession)
|
||||||
|
class OnboardingSessionAdmin(ModelAdmin):
|
||||||
|
list_display = ('id', 'uuid', 'flow', 'user', 'status', 'current_page_order')
|
||||||
|
search_fields = ('flow__title', 'user__email_address')
|
||||||
|
list_filter = ('status',)
|
||||||
|
readonly_fields = ('uuid',)
|
||||||
6
apps/onboarding/apps.py
Normal file
6
apps/onboarding/apps.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingConfig(AppConfig):
|
||||||
|
default_auto_field = 'django.db.models.BigAutoField'
|
||||||
|
name = 'apps.onboarding'
|
||||||
100
apps/onboarding/migrations/0001_initial.py
Normal file
100
apps/onboarding/migrations/0001_initial.py
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
initial = True
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('orgs', '0001_initial'),
|
||||||
|
('mlstore', '0001_initial'),
|
||||||
|
('users', '0001_initial'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='OnboardingFlow',
|
||||||
|
fields=[
|
||||||
|
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||||
|
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
|
||||||
|
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
|
||||||
|
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
||||||
|
('title', models.CharField(max_length=255)),
|
||||||
|
('description', models.TextField(blank=True, default='')),
|
||||||
|
('status', models.CharField(choices=[('draft', 'Draft'), ('published', 'Published'), ('archived', 'Archived')], default='draft', max_length=20)),
|
||||||
|
('agent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_flows', to='mlstore.agent')),
|
||||||
|
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_flows', to='orgs.role')),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Onboarding Flow',
|
||||||
|
'verbose_name_plural': 'Onboarding Flows',
|
||||||
|
'ordering': ['-created_at'],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='OnboardingPage',
|
||||||
|
fields=[
|
||||||
|
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||||
|
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
|
||||||
|
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
|
||||||
|
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
||||||
|
('order', models.IntegerField(default=0)),
|
||||||
|
('title', models.CharField(max_length=255)),
|
||||||
|
('body', models.TextField(blank=True, default='')),
|
||||||
|
('meta', models.JSONField(blank=True, default=dict)),
|
||||||
|
('flow', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='pages', to='onboarding.onboardingflow')),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Onboarding Page',
|
||||||
|
'verbose_name_plural': 'Onboarding Pages',
|
||||||
|
'ordering': ['order', 'created_at'],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='OnboardingField',
|
||||||
|
fields=[
|
||||||
|
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||||
|
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
|
||||||
|
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
|
||||||
|
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
||||||
|
('order', models.IntegerField(default=0)),
|
||||||
|
('key', models.CharField(max_length=120)),
|
||||||
|
('label', models.CharField(max_length=255)),
|
||||||
|
('field_type', models.CharField(choices=[('text', 'Text'), ('textarea', 'Textarea'), ('select', 'Select'), ('multiselect', 'Multi Select'), ('number', 'Number'), ('boolean', 'Boolean'), ('date', 'Date')], default='text', max_length=30)),
|
||||||
|
('required', models.BooleanField(default=False)),
|
||||||
|
('help_text', models.TextField(blank=True, default='')),
|
||||||
|
('placeholder', models.CharField(blank=True, default='', max_length=255)),
|
||||||
|
('options', models.JSONField(blank=True, default=list)),
|
||||||
|
('default_value', models.JSONField(blank=True, null=True, default=None)),
|
||||||
|
('validation', models.JSONField(blank=True, default=dict)),
|
||||||
|
('page', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='fields', to='onboarding.onboardingpage')),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Onboarding Field',
|
||||||
|
'verbose_name_plural': 'Onboarding Fields',
|
||||||
|
'ordering': ['order', 'created_at'],
|
||||||
|
'unique_together': {('page', 'key')},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='OnboardingSession',
|
||||||
|
fields=[
|
||||||
|
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||||
|
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='Created At')),
|
||||||
|
('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')),
|
||||||
|
('uuid', models.UUIDField(default=uuid.uuid4, editable=False, unique=True)),
|
||||||
|
('status', models.CharField(choices=[('in_progress', 'In Progress'), ('completed', 'Completed'), ('abandoned', 'Abandoned')], default='in_progress', max_length=20)),
|
||||||
|
('current_page_order', models.IntegerField(default=0)),
|
||||||
|
('responses', models.JSONField(blank=True, default=dict)),
|
||||||
|
('completed_at', models.DateTimeField(blank=True, null=True)),
|
||||||
|
('agent_run', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_sessions', to='mlstore.agentrun')),
|
||||||
|
('flow', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sessions', to='onboarding.onboardingflow')),
|
||||||
|
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='onboarding_sessions', to='users.user')),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'verbose_name': 'Onboarding Session',
|
||||||
|
'verbose_name_plural': 'Onboarding Sessions',
|
||||||
|
'ordering': ['-created_at'],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
0
apps/onboarding/migrations/__init__.py
Normal file
0
apps/onboarding/migrations/__init__.py
Normal file
121
apps/onboarding/models.py
Normal file
121
apps/onboarding/models.py
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
from uuid import uuid4
|
||||||
|
from django.db.models import (
|
||||||
|
BigAutoField,
|
||||||
|
BooleanField,
|
||||||
|
CASCADE,
|
||||||
|
CharField,
|
||||||
|
DateTimeField,
|
||||||
|
ForeignKey,
|
||||||
|
IntegerField,
|
||||||
|
JSONField,
|
||||||
|
Model,
|
||||||
|
TextField,
|
||||||
|
UUIDField,
|
||||||
|
)
|
||||||
|
from apps.users.mixins import TimeStampMixin
|
||||||
|
from apps.users.models import User
|
||||||
|
from apps.orgs.models import Role
|
||||||
|
from apps.mlstore.models import Agent, AgentRun
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFlow(TimeStampMixin, Model):
|
||||||
|
STATUS_CHOICES = [
|
||||||
|
('draft', 'Draft'),
|
||||||
|
('published', 'Published'),
|
||||||
|
('archived', 'Archived'),
|
||||||
|
]
|
||||||
|
|
||||||
|
id = BigAutoField(primary_key=True)
|
||||||
|
uuid = UUIDField(default=uuid4, editable=False, unique=True)
|
||||||
|
role = ForeignKey(Role, on_delete=CASCADE, related_name='onboarding_flows')
|
||||||
|
agent = ForeignKey(Agent, on_delete=CASCADE, related_name='onboarding_flows', null=True, blank=True)
|
||||||
|
title = CharField(max_length=255)
|
||||||
|
description = TextField(blank=True, default='')
|
||||||
|
status = CharField(max_length=20, choices=STATUS_CHOICES, default='draft')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = 'Onboarding Flow'
|
||||||
|
verbose_name_plural = 'Onboarding Flows'
|
||||||
|
ordering = ['-created_at']
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'{self.title} ({self.role.name})'
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingPage(TimeStampMixin, Model):
|
||||||
|
id = BigAutoField(primary_key=True)
|
||||||
|
uuid = UUIDField(default=uuid4, editable=False, unique=True)
|
||||||
|
flow = ForeignKey(OnboardingFlow, on_delete=CASCADE, related_name='pages')
|
||||||
|
order = IntegerField(default=0)
|
||||||
|
title = CharField(max_length=255)
|
||||||
|
body = TextField(blank=True, default='')
|
||||||
|
meta = JSONField(default=dict, blank=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = 'Onboarding Page'
|
||||||
|
verbose_name_plural = 'Onboarding Pages'
|
||||||
|
ordering = ['order', 'created_at']
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'{self.flow.title} - {self.title}'
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingField(TimeStampMixin, Model):
|
||||||
|
FIELD_TYPES = [
|
||||||
|
('text', 'Text'),
|
||||||
|
('textarea', 'Textarea'),
|
||||||
|
('select', 'Select'),
|
||||||
|
('multiselect', 'Multi Select'),
|
||||||
|
('number', 'Number'),
|
||||||
|
('boolean', 'Boolean'),
|
||||||
|
('date', 'Date'),
|
||||||
|
]
|
||||||
|
|
||||||
|
id = BigAutoField(primary_key=True)
|
||||||
|
uuid = UUIDField(default=uuid4, editable=False, unique=True)
|
||||||
|
page = ForeignKey(OnboardingPage, on_delete=CASCADE, related_name='fields')
|
||||||
|
order = IntegerField(default=0)
|
||||||
|
key = CharField(max_length=120)
|
||||||
|
label = CharField(max_length=255)
|
||||||
|
field_type = CharField(max_length=30, choices=FIELD_TYPES, default='text')
|
||||||
|
required = BooleanField(default=False)
|
||||||
|
help_text = TextField(blank=True, default='')
|
||||||
|
placeholder = CharField(max_length=255, blank=True, default='')
|
||||||
|
options = JSONField(default=list, blank=True)
|
||||||
|
default_value = JSONField(null=True, blank=True, default=None)
|
||||||
|
validation = JSONField(default=dict, blank=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = 'Onboarding Field'
|
||||||
|
verbose_name_plural = 'Onboarding Fields'
|
||||||
|
ordering = ['order', 'created_at']
|
||||||
|
unique_together = [['page', 'key']]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'{self.page.title} - {self.label}'
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingSession(TimeStampMixin, Model):
|
||||||
|
STATUS_CHOICES = [
|
||||||
|
('in_progress', 'In Progress'),
|
||||||
|
('completed', 'Completed'),
|
||||||
|
('abandoned', 'Abandoned'),
|
||||||
|
]
|
||||||
|
|
||||||
|
id = BigAutoField(primary_key=True)
|
||||||
|
uuid = UUIDField(default=uuid4, editable=False, unique=True)
|
||||||
|
flow = ForeignKey(OnboardingFlow, on_delete=CASCADE, related_name='sessions')
|
||||||
|
user = ForeignKey(User, on_delete=CASCADE, related_name='onboarding_sessions')
|
||||||
|
agent_run = ForeignKey(AgentRun, on_delete=CASCADE, related_name='onboarding_sessions', null=True, blank=True)
|
||||||
|
status = CharField(max_length=20, choices=STATUS_CHOICES, default='in_progress')
|
||||||
|
current_page_order = IntegerField(default=0)
|
||||||
|
responses = JSONField(default=dict, blank=True)
|
||||||
|
completed_at = DateTimeField(null=True, blank=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
verbose_name = 'Onboarding Session'
|
||||||
|
verbose_name_plural = 'Onboarding Sessions'
|
||||||
|
ordering = ['-created_at']
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'{self.user.email_address} - {self.flow.title}'
|
||||||
105
apps/onboarding/serializers.py
Normal file
105
apps/onboarding/serializers.py
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
from rest_framework import serializers
|
||||||
|
from .models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
|
||||||
|
from apps.orgs.models import Role
|
||||||
|
from apps.mlstore.models import Agent
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFieldSerializer(serializers.ModelSerializer):
|
||||||
|
page = serializers.SlugRelatedField(slug_field='uuid', queryset=OnboardingPage.objects.all())
|
||||||
|
class Meta:
|
||||||
|
model = OnboardingField
|
||||||
|
fields = [
|
||||||
|
'id',
|
||||||
|
'uuid',
|
||||||
|
'page',
|
||||||
|
'order',
|
||||||
|
'key',
|
||||||
|
'label',
|
||||||
|
'field_type',
|
||||||
|
'required',
|
||||||
|
'help_text',
|
||||||
|
'placeholder',
|
||||||
|
'options',
|
||||||
|
'default_value',
|
||||||
|
'validation',
|
||||||
|
]
|
||||||
|
read_only_fields = ['id', 'uuid']
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingPageSerializer(serializers.ModelSerializer):
|
||||||
|
fields = OnboardingFieldSerializer(many=True, read_only=True)
|
||||||
|
flow = serializers.SlugRelatedField(slug_field='uuid', queryset=OnboardingFlow.objects.all())
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = OnboardingPage
|
||||||
|
fields = [
|
||||||
|
'id',
|
||||||
|
'uuid',
|
||||||
|
'flow',
|
||||||
|
'order',
|
||||||
|
'title',
|
||||||
|
'body',
|
||||||
|
'meta',
|
||||||
|
'fields',
|
||||||
|
]
|
||||||
|
read_only_fields = ['id', 'uuid']
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFlowSerializer(serializers.ModelSerializer):
|
||||||
|
role = serializers.SlugRelatedField(slug_field='uuid', queryset=Role.objects.all())
|
||||||
|
agent = serializers.SlugRelatedField(slug_field='uuid', queryset=Agent.objects.all(), allow_null=True, required=False)
|
||||||
|
class Meta:
|
||||||
|
model = OnboardingFlow
|
||||||
|
fields = [
|
||||||
|
'id',
|
||||||
|
'uuid',
|
||||||
|
'role',
|
||||||
|
'agent',
|
||||||
|
'title',
|
||||||
|
'description',
|
||||||
|
'status',
|
||||||
|
'created_at',
|
||||||
|
'updated_at',
|
||||||
|
]
|
||||||
|
read_only_fields = ['id', 'uuid', 'created_at', 'updated_at']
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFlowDetailSerializer(OnboardingFlowSerializer):
|
||||||
|
pages = OnboardingPageSerializer(many=True, read_only=True)
|
||||||
|
|
||||||
|
class Meta(OnboardingFlowSerializer.Meta):
|
||||||
|
fields = OnboardingFlowSerializer.Meta.fields + ['pages']
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingSessionSerializer(serializers.ModelSerializer):
|
||||||
|
flow = serializers.SlugRelatedField(slug_field='uuid', queryset=OnboardingFlow.objects.all())
|
||||||
|
user = serializers.SlugRelatedField(slug_field='uuid', read_only=True)
|
||||||
|
agent_run = serializers.SlugRelatedField(slug_field='uuid', read_only=True)
|
||||||
|
class Meta:
|
||||||
|
model = OnboardingSession
|
||||||
|
fields = [
|
||||||
|
'id',
|
||||||
|
'uuid',
|
||||||
|
'flow',
|
||||||
|
'user',
|
||||||
|
'agent_run',
|
||||||
|
'status',
|
||||||
|
'current_page_order',
|
||||||
|
'responses',
|
||||||
|
'created_at',
|
||||||
|
'updated_at',
|
||||||
|
'completed_at',
|
||||||
|
]
|
||||||
|
read_only_fields = ['id', 'uuid', 'user', 'agent_run', 'created_at', 'updated_at', 'completed_at']
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingSubmissionSerializer(serializers.Serializer):
|
||||||
|
page_uuid = serializers.CharField()
|
||||||
|
responses = serializers.DictField()
|
||||||
|
mark_complete = serializers.BooleanField(required=False, default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFeedbackSerializer(serializers.Serializer):
|
||||||
|
page_uuid = serializers.CharField()
|
||||||
|
responses = serializers.DictField()
|
||||||
|
question = serializers.CharField(required=False, allow_blank=True, default='')
|
||||||
0
apps/onboarding/tests/__init__.py
Normal file
0
apps/onboarding/tests/__init__.py
Normal file
124
apps/onboarding/tests/test_api.py
Normal file
124
apps/onboarding/tests/test_api.py
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.test import TestCase
|
||||||
|
from rest_framework.test import APIRequestFactory, force_authenticate
|
||||||
|
from rest_framework.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
from apps.orgs.models import Organization, Role
|
||||||
|
from apps.mlstore.models import AgentModel, Agent
|
||||||
|
from apps.onboarding.models import OnboardingFlow, OnboardingPage, OnboardingSession
|
||||||
|
from apps.onboarding.viewsets import OnboardingFlowViewSet, OnboardingSessionViewSet
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingAPITests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = APIRequestFactory()
|
||||||
|
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
|
||||||
|
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
|
||||||
|
self.org = Organization.objects.create(name='Org', owner=self.manager)
|
||||||
|
self.role = Role.objects.create(name='Engineer', organization=self.org)
|
||||||
|
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
|
||||||
|
self.agent = Agent.objects.create(model=self.model, organization=self.org)
|
||||||
|
|
||||||
|
def test_create_flow(self):
|
||||||
|
view = OnboardingFlowViewSet.as_view({'post': 'create'})
|
||||||
|
data = {
|
||||||
|
'role': str(self.role.uuid),
|
||||||
|
'agent': str(self.agent.uuid),
|
||||||
|
'title': 'Flow',
|
||||||
|
'description': 'Desc',
|
||||||
|
'status': 'draft',
|
||||||
|
}
|
||||||
|
request = self.factory.post('/', data)
|
||||||
|
force_authenticate(request, user=self.manager)
|
||||||
|
response = view(request)
|
||||||
|
self.assertIn(response.status_code, (HTTP_200_OK, HTTP_201_CREATED))
|
||||||
|
self.assertTrue(OnboardingFlow.objects.filter(title='Flow').exists())
|
||||||
|
|
||||||
|
def test_pages_action(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
|
||||||
|
view = OnboardingFlowViewSet.as_view({'get': 'pages'})
|
||||||
|
request = self.factory.get('/')
|
||||||
|
force_authenticate(request, user=self.manager)
|
||||||
|
response = view(request, uuid=str(flow.uuid))
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertEqual(len(response.data.get('pages', [])), 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
view = OnboardingSessionViewSet.as_view({'post': 'create'})
|
||||||
|
request = self.factory.post('/', {'flow': str(flow.uuid)})
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertIn(response.status_code, (HTTP_200_OK, HTTP_201_CREATED))
|
||||||
|
self.assertTrue(OnboardingSession.objects.filter(flow=flow, user=self.user).exists())
|
||||||
|
|
||||||
|
def test_submit_updates_session(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
page = OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
|
||||||
|
session = OnboardingSession.objects.create(flow=flow, user=self.user)
|
||||||
|
|
||||||
|
view = OnboardingSessionViewSet.as_view({'post': 'submit'})
|
||||||
|
payload = {'page_uuid': str(page.uuid), 'responses': {'q1': 'a'}, 'mark_complete': True}
|
||||||
|
request = self.factory.post('/', payload, format='json')
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request, uuid=str(session.uuid))
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
session.refresh_from_db()
|
||||||
|
self.assertEqual(session.status, 'completed')
|
||||||
|
self.assertIn(str(page.uuid), session.responses)
|
||||||
|
|
||||||
|
def test_publish_flow_as_manager(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
self.assertEqual(flow.status, 'draft')
|
||||||
|
view = OnboardingFlowViewSet.as_view({'post': 'publish'})
|
||||||
|
request = self.factory.post('/')
|
||||||
|
force_authenticate(request, user=self.manager)
|
||||||
|
response = view(request, uuid=str(flow.uuid))
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
flow.refresh_from_db()
|
||||||
|
self.assertEqual(flow.status, 'published')
|
||||||
|
|
||||||
|
def test_publish_flow_requires_manager(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
view = OnboardingFlowViewSet.as_view({'post': 'publish'})
|
||||||
|
request = self.factory.post('/')
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request, uuid=str(flow.uuid))
|
||||||
|
self.assertEqual(response.status_code, HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
def test_get_or_create_session_creates_when_missing(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
view = OnboardingSessionViewSet.as_view({'post': 'get_or_create'})
|
||||||
|
request = self.factory.post('/', {'flow': str(flow.uuid)})
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertTrue(OnboardingSession.objects.filter(flow=flow, user=self.user).exists())
|
||||||
|
|
||||||
|
def test_get_or_create_session_reuses_active(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
existing = OnboardingSession.objects.create(flow=flow, user=self.user, current_page_order=1)
|
||||||
|
view = OnboardingSessionViewSet.as_view({'post': 'get_or_create'})
|
||||||
|
request = self.factory.post('/', {'flow': str(flow.uuid)})
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data.get('uuid'), str(existing.uuid))
|
||||||
|
self.assertEqual(response.data.get('current_page_order'), 1)
|
||||||
|
|
||||||
|
def test_get_or_create_session_creates_after_completion(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
completed = OnboardingSession.objects.create(flow=flow, user=self.user, status='completed')
|
||||||
|
view = OnboardingSessionViewSet.as_view({'post': 'get_or_create'})
|
||||||
|
request = self.factory.post('/', {'flow': str(flow.uuid)})
|
||||||
|
force_authenticate(request, user=self.user)
|
||||||
|
response = view(request)
|
||||||
|
self.assertEqual(response.status_code, HTTP_200_OK)
|
||||||
|
self.assertNotEqual(response.data.get('uuid'), str(completed.uuid))
|
||||||
|
|
||||||
|
|
||||||
41
apps/onboarding/tests/test_models.py
Normal file
41
apps/onboarding/tests/test_models.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from apps.orgs.models import Organization, Role
|
||||||
|
from apps.mlstore.models import AgentModel, Agent
|
||||||
|
from apps.onboarding.models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
|
class OnboardingModelTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.user = User.objects.create_user(email_address='user@example.com', password='pass')
|
||||||
|
self.manager = User.objects.create_user(email_address='manager@example.com', password='pass', is_manager=True)
|
||||||
|
self.org = Organization.objects.create(name='Org', owner=self.manager)
|
||||||
|
self.role = Role.objects.create(name='Engineer', organization=self.org)
|
||||||
|
self.model = AgentModel.objects.create(name='test-model', version='v1', path='model.gguf')
|
||||||
|
self.agent = Agent.objects.create(model=self.model, organization=self.org)
|
||||||
|
|
||||||
|
def test_flow_str(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Welcome', description='Intro')
|
||||||
|
self.assertIn('Welcome', str(flow))
|
||||||
|
self.assertIn(self.role.name, str(flow))
|
||||||
|
|
||||||
|
def test_page_and_field_str(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
page = OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
|
||||||
|
field = OnboardingField.objects.create(page=page, order=0, key='q1', label='Question 1')
|
||||||
|
self.assertIn(flow.title, str(page))
|
||||||
|
self.assertIn(field.label, str(field))
|
||||||
|
|
||||||
|
def test_field_unique_key_per_page(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
page = OnboardingPage.objects.create(flow=flow, order=0, title='Page 1', body='Body')
|
||||||
|
OnboardingField.objects.create(page=page, order=0, key='dup', label='Dup 1')
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
OnboardingField.objects.create(page=page, order=1, key='dup', label='Dup 2')
|
||||||
|
|
||||||
|
def test_session_str(self):
|
||||||
|
flow = OnboardingFlow.objects.create(role=self.role, agent=self.agent, title='Flow', description='')
|
||||||
|
session = OnboardingSession.objects.create(flow=flow, user=self.user)
|
||||||
|
self.assertIn(self.user.email_address, str(session))
|
||||||
|
self.assertIn(flow.title, str(session))
|
||||||
451
apps/onboarding/viewsets.py
Normal file
451
apps/onboarding/viewsets.py
Normal file
|
|
@ -0,0 +1,451 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import html
|
||||||
|
from typing import Any
|
||||||
|
from django.db import transaction
|
||||||
|
from django.utils import timezone
|
||||||
|
from rest_framework import status
|
||||||
|
from rest_framework.exceptions import PermissionDenied
|
||||||
|
from rest_framework.decorators import action
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
|
from channels.layers import get_channel_layer
|
||||||
|
|
||||||
|
from apps.mlstore.models import AgentEvent, AgentRun
|
||||||
|
from apps.mlstore import services as ml_services
|
||||||
|
from .models import OnboardingFlow, OnboardingPage, OnboardingField, OnboardingSession
|
||||||
|
from .serializers import (
|
||||||
|
OnboardingFlowSerializer,
|
||||||
|
OnboardingFlowDetailSerializer,
|
||||||
|
OnboardingPageSerializer,
|
||||||
|
OnboardingFieldSerializer,
|
||||||
|
OnboardingSessionSerializer,
|
||||||
|
OnboardingSubmissionSerializer,
|
||||||
|
OnboardingFeedbackSerializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json(text: str) -> dict[str, Any]:
|
||||||
|
if not text:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Prefer fenced json blocks
|
||||||
|
fenced = re.search(r"```json\s*(\{[\s\S]*?\})\s*```", text, re.IGNORECASE)
|
||||||
|
if fenced:
|
||||||
|
try:
|
||||||
|
return json.loads(fenced.group(1))
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Fallback: find first balanced JSON object
|
||||||
|
start = text.find('{')
|
||||||
|
if start == -1:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
depth = 0
|
||||||
|
for idx in range(start, len(text)):
|
||||||
|
char = text[idx]
|
||||||
|
if char == '{':
|
||||||
|
depth += 1
|
||||||
|
elif char == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
candidate = text[start:idx + 1]
|
||||||
|
try:
|
||||||
|
return json.loads(candidate)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
cleaned = re.sub(r"<[^>]+>", " ", text)
|
||||||
|
cleaned = html.unescape(cleaned)
|
||||||
|
return re.sub(r"\s+", " ", cleaned).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _send_agent_progress_event(agent_run: AgentRun, content: dict):
|
||||||
|
try:
|
||||||
|
AgentEvent.objects.create(
|
||||||
|
execution=agent_run,
|
||||||
|
event_type='progress',
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
room_group_name = f"mlstore_agent_{agent_run.agent.uuid}"
|
||||||
|
async_to_sync(get_channel_layer().group_send)(
|
||||||
|
room_group_name,
|
||||||
|
{
|
||||||
|
"type": "mlstore_event",
|
||||||
|
"event_type": "progress",
|
||||||
|
"content": content,
|
||||||
|
"timestamp": timezone.now().isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to send progress event: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFlowViewSet(ModelViewSet):
|
||||||
|
queryset = OnboardingFlow.objects.select_related('role', 'agent').all()
|
||||||
|
serializer_class = OnboardingFlowSerializer
|
||||||
|
lookup_field = 'uuid'
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
qs = super().get_queryset()
|
||||||
|
role_uuid = self.request.query_params.get('role')
|
||||||
|
status_filter = self.request.query_params.get('status')
|
||||||
|
if role_uuid:
|
||||||
|
qs = qs.filter(role__uuid=role_uuid)
|
||||||
|
if status_filter:
|
||||||
|
qs = qs.filter(status=status_filter)
|
||||||
|
return qs
|
||||||
|
|
||||||
|
def get_serializer_class(self):
|
||||||
|
if self.action in ('retrieve', 'pages'):
|
||||||
|
return OnboardingFlowDetailSerializer
|
||||||
|
return super().get_serializer_class()
|
||||||
|
|
||||||
|
@action(detail=True, methods=['get'])
|
||||||
|
def pages(self, request, pk=None, uuid=None):
|
||||||
|
flow = self.get_object()
|
||||||
|
serializer = OnboardingFlowDetailSerializer(flow, context={'request': request})
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
@action(detail=True, methods=['post'])
|
||||||
|
def generate(self, request, pk=None, uuid=None):
|
||||||
|
flow = self.get_object()
|
||||||
|
if not request.user.is_authenticated or not getattr(request.user, 'is_manager', False):
|
||||||
|
return Response({"error": "permission_denied"}, status=status.HTTP_403_FORBIDDEN)
|
||||||
|
if not flow.agent or not flow.agent.model or not flow.agent.model.path:
|
||||||
|
return Response(
|
||||||
|
{"error": "flow_agent_model_required"},
|
||||||
|
status=status.HTTP_400_BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
instructions = request.data.get('instructions') or ''
|
||||||
|
rag_context = ""
|
||||||
|
try:
|
||||||
|
rag_context = ml_services.get_context_for_query(
|
||||||
|
query=f"Create onboarding content for role {flow.role.name}",
|
||||||
|
role_uuid=str(flow.role.uuid),
|
||||||
|
top_k=6,
|
||||||
|
similarity_threshold=0.35,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Onboarding generation RAG lookup failed: %s", e)
|
||||||
|
prompt = (
|
||||||
|
"You are creating onboarding content as JSON. "
|
||||||
|
"Return ONLY valid JSON (no prose, no markdown, no code fences).\n"
|
||||||
|
"Do not include explanations or examples.\n"
|
||||||
|
"Do not include HTML tags. Use plain text only.\n"
|
||||||
|
"Each page body must be 3-6 paragraphs, at least 320 words total, and include 1 short list of 3-5 bullets.\n"
|
||||||
|
"Before writing the body, create a brief outline of the key points to cover and include it in meta.outline.\n"
|
||||||
|
"The outline should be a short list of 3-6 bullets, not chain-of-thought.\n"
|
||||||
|
"Do NOT ask about the learner's personal experience. Ask about what someone in the role may encounter.\n"
|
||||||
|
"Do NOT use any select or multiselect fields. Use only text, textarea, number, boolean, or date.\n"
|
||||||
|
"Use the provided context for accurate, role-specific content.\n"
|
||||||
|
"If context is insufficient, make reasonable assumptions without inventing tools or policies.\n"
|
||||||
|
"JSON shape:\n"
|
||||||
|
"{\n"
|
||||||
|
" \"title\": string,\n"
|
||||||
|
" \"description\": string,\n"
|
||||||
|
" \"pages\": [\n"
|
||||||
|
" {\n"
|
||||||
|
" \"title\": string,\n"
|
||||||
|
" \"body\": string,\n"
|
||||||
|
" \"meta\": { \"outline\": [string] },\n"
|
||||||
|
" \"fields\": [\n"
|
||||||
|
" {\n"
|
||||||
|
" \"key\": string,\n"
|
||||||
|
" \"label\": string,\n"
|
||||||
|
" \"type\": one of [text, textarea, number, boolean, date],\n"
|
||||||
|
" \"required\": boolean,\n"
|
||||||
|
" \"help_text\": string,\n"
|
||||||
|
" \"placeholder\": string,\n"
|
||||||
|
" \"options\": []\n"
|
||||||
|
" }\n"
|
||||||
|
" ]\n"
|
||||||
|
" }\n"
|
||||||
|
" ]\n"
|
||||||
|
"}\n"
|
||||||
|
f"Role: {flow.role.name}\n"
|
||||||
|
f"Role description: {flow.role.description}\n"
|
||||||
|
f"Flow title: {flow.title}\n"
|
||||||
|
f"Flow description: {flow.description}\n"
|
||||||
|
f"Extra instructions: {instructions}\n"
|
||||||
|
f"Context:\n{rag_context}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = ml_services.infer_with_model(flow.agent.model.path, prompt, {
|
||||||
|
"max_tokens": 1800,
|
||||||
|
"temperature": 0.2,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Onboarding generate inference failed: %s", e, exc_info=True)
|
||||||
|
return Response({"error": "generation_failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
response_text = ''
|
||||||
|
if isinstance(result, dict):
|
||||||
|
response_text = result.get('response') or result.get('result') or ''
|
||||||
|
payload = _extract_json(str(response_text))
|
||||||
|
if not payload or 'pages' not in payload:
|
||||||
|
return Response({"error": "invalid_generation_output", "raw": response_text}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
with transaction.atomic():
|
||||||
|
flow.title = payload.get('title') or flow.title
|
||||||
|
# Keep existing description on regenerate unless explicitly empty
|
||||||
|
if not flow.description:
|
||||||
|
flow.description = payload.get('description') or flow.description
|
||||||
|
if flow.status != 'draft':
|
||||||
|
flow.status = 'draft'
|
||||||
|
flow.save(update_fields=['title', 'description', 'status'])
|
||||||
|
|
||||||
|
OnboardingPage.objects.filter(flow=flow).delete()
|
||||||
|
|
||||||
|
pages = payload.get('pages') or []
|
||||||
|
for page_index, page in enumerate(pages):
|
||||||
|
body_text = _strip_html(page.get('body') or '')
|
||||||
|
page_obj = OnboardingPage.objects.create(
|
||||||
|
flow=flow,
|
||||||
|
order=page_index,
|
||||||
|
title=page.get('title') or f"Page {page_index + 1}",
|
||||||
|
body=body_text,
|
||||||
|
meta=page.get('meta') or {},
|
||||||
|
)
|
||||||
|
for field_index, field in enumerate(page.get('fields') or []):
|
||||||
|
field_type = field.get('type') or 'text'
|
||||||
|
if field_type not in {"text", "textarea", "number", "boolean", "date"}:
|
||||||
|
field_type = 'text'
|
||||||
|
OnboardingField.objects.create(
|
||||||
|
page=page_obj,
|
||||||
|
order=field_index,
|
||||||
|
key=field.get('key') or f"field_{field_index + 1}",
|
||||||
|
label=field.get('label') or f"Field {field_index + 1}",
|
||||||
|
field_type=field_type,
|
||||||
|
required=bool(field.get('required')),
|
||||||
|
help_text=field.get('help_text') or '',
|
||||||
|
placeholder=field.get('placeholder') or '',
|
||||||
|
options=[],
|
||||||
|
default_value=field.get('default_value') if field.get('default_value') is not None else None,
|
||||||
|
validation=field.get('validation') or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
serializer = OnboardingFlowDetailSerializer(flow, context={'request': request})
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
@action(detail=True, methods=['post'])
|
||||||
|
def publish(self, request, pk=None, uuid=None):
|
||||||
|
flow = self.get_object()
|
||||||
|
if not request.user.is_authenticated or not getattr(request.user, 'is_manager', False):
|
||||||
|
return Response({"error": "permission_denied"}, status=status.HTTP_403_FORBIDDEN)
|
||||||
|
if flow.status != 'published':
|
||||||
|
flow.status = 'published'
|
||||||
|
flow.save(update_fields=['status'])
|
||||||
|
serializer = OnboardingFlowDetailSerializer(flow, context={'request': request})
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingPageViewSet(ModelViewSet):
|
||||||
|
queryset = OnboardingPage.objects.select_related('flow').prefetch_related('fields').all()
|
||||||
|
serializer_class = OnboardingPageSerializer
|
||||||
|
lookup_field = 'uuid'
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingFieldViewSet(ModelViewSet):
|
||||||
|
queryset = OnboardingField.objects.select_related('page').all()
|
||||||
|
serializer_class = OnboardingFieldSerializer
|
||||||
|
lookup_field = 'uuid'
|
||||||
|
|
||||||
|
|
||||||
|
class OnboardingSessionViewSet(ModelViewSet):
|
||||||
|
queryset = OnboardingSession.objects.select_related('flow', 'user', 'agent_run', 'flow__agent').all()
|
||||||
|
serializer_class = OnboardingSessionSerializer
|
||||||
|
lookup_field = 'uuid'
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
qs = super().get_queryset()
|
||||||
|
user = self.request.user
|
||||||
|
if user.is_authenticated and not getattr(user, 'is_manager', False):
|
||||||
|
qs = qs.filter(user=user)
|
||||||
|
return qs
|
||||||
|
|
||||||
|
def perform_create(self, serializer):
|
||||||
|
if not self.request.user or not self.request.user.is_authenticated:
|
||||||
|
raise PermissionDenied("Authentication required")
|
||||||
|
flow = serializer.validated_data.get('flow')
|
||||||
|
agent_run = None
|
||||||
|
if flow and flow.agent:
|
||||||
|
agent_run = AgentRun.objects.create(
|
||||||
|
agent=flow.agent,
|
||||||
|
user=self.request.user,
|
||||||
|
input_data={
|
||||||
|
"type": "onboarding_session",
|
||||||
|
"flow_uuid": str(flow.uuid),
|
||||||
|
"role_uuid": str(flow.role.uuid),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
serializer.save(user=self.request.user, agent_run=agent_run)
|
||||||
|
|
||||||
|
@action(detail=False, methods=['post'])
|
||||||
|
def get_or_create(self, request):
|
||||||
|
if not request.user or not request.user.is_authenticated:
|
||||||
|
raise PermissionDenied("Authentication required")
|
||||||
|
|
||||||
|
flow_uuid = request.data.get('flow')
|
||||||
|
if not flow_uuid:
|
||||||
|
return Response({"error": "flow_required"}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
try:
|
||||||
|
flow = OnboardingFlow.objects.get(uuid=flow_uuid)
|
||||||
|
except OnboardingFlow.DoesNotExist:
|
||||||
|
return Response({"error": "flow_not_found"}, status=status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
session = (
|
||||||
|
OnboardingSession.objects
|
||||||
|
.filter(flow=flow, user=request.user)
|
||||||
|
.exclude(status='completed')
|
||||||
|
.order_by('-updated_at')
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
agent_run = None
|
||||||
|
if flow.agent:
|
||||||
|
agent_run = AgentRun.objects.create(
|
||||||
|
agent=flow.agent,
|
||||||
|
user=request.user,
|
||||||
|
input_data={
|
||||||
|
"type": "onboarding_session",
|
||||||
|
"flow_uuid": str(flow.uuid),
|
||||||
|
"role_uuid": str(flow.role.uuid),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session = OnboardingSession.objects.create(
|
||||||
|
flow=flow,
|
||||||
|
user=request.user,
|
||||||
|
agent_run=agent_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(OnboardingSessionSerializer(session, context={'request': request}).data)
|
||||||
|
|
||||||
|
@action(detail=True, methods=['post'])
|
||||||
|
def submit(self, request, pk=None, uuid=None):
|
||||||
|
session = self.get_object()
|
||||||
|
serializer = OnboardingSubmissionSerializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
page_uuid = serializer.validated_data['page_uuid']
|
||||||
|
responses = serializer.validated_data['responses']
|
||||||
|
mark_complete = serializer.validated_data.get('mark_complete')
|
||||||
|
|
||||||
|
try:
|
||||||
|
page = OnboardingPage.objects.get(flow=session.flow, uuid=page_uuid)
|
||||||
|
except OnboardingPage.DoesNotExist:
|
||||||
|
return Response({"error": "page_not_found"}, status=status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
responses_payload = dict(session.responses or {})
|
||||||
|
responses_payload[str(page.uuid)] = responses
|
||||||
|
session.responses = responses_payload
|
||||||
|
session.current_page_order = page.order
|
||||||
|
|
||||||
|
if mark_complete or page.order >= session.flow.pages.count() - 1:
|
||||||
|
session.status = 'completed'
|
||||||
|
session.completed_at = timezone.now()
|
||||||
|
session.save(update_fields=['responses', 'current_page_order', 'status', 'completed_at'])
|
||||||
|
|
||||||
|
if session.agent_run:
|
||||||
|
progress_payload = {
|
||||||
|
"flow_uuid": str(session.flow.uuid),
|
||||||
|
"session_uuid": str(session.uuid),
|
||||||
|
"page_uuid": str(page.uuid),
|
||||||
|
"page_order": page.order,
|
||||||
|
"status": session.status,
|
||||||
|
"responses": responses,
|
||||||
|
}
|
||||||
|
_send_agent_progress_event(session.agent_run, progress_payload)
|
||||||
|
session.agent_run.output_data = {
|
||||||
|
**(session.agent_run.output_data or {}),
|
||||||
|
"onboarding": session.responses,
|
||||||
|
}
|
||||||
|
session.agent_run.save(update_fields=['output_data'])
|
||||||
|
|
||||||
|
return Response(OnboardingSessionSerializer(session, context={'request': request}).data)
|
||||||
|
|
||||||
|
@action(detail=True, methods=['post'])
|
||||||
|
def feedback(self, request, pk=None, uuid=None):
|
||||||
|
session = self.get_object()
|
||||||
|
serializer = OnboardingFeedbackSerializer(data=request.data)
|
||||||
|
serializer.is_valid(raise_exception=True)
|
||||||
|
page_uuid = serializer.validated_data['page_uuid']
|
||||||
|
responses = serializer.validated_data['responses']
|
||||||
|
question = serializer.validated_data.get('question') or ''
|
||||||
|
|
||||||
|
try:
|
||||||
|
page = OnboardingPage.objects.get(flow=session.flow, uuid=page_uuid)
|
||||||
|
except OnboardingPage.DoesNotExist:
|
||||||
|
return Response({"error": "page_not_found"}, status=status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
if not session.flow.agent or not session.flow.agent.model or not session.flow.agent.model.path:
|
||||||
|
return Response({"error": "flow_agent_model_required"}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
"You are an onboarding assessor. Provide concise feedback addressed directly to the learner using second-person \"You\" statements.\n"
|
||||||
|
"Return ONLY valid JSON (no prose, no markdown, no code fences).\n"
|
||||||
|
"JSON shape:\n"
|
||||||
|
"{\n"
|
||||||
|
" \"summary\": string\n"
|
||||||
|
"}\n\n"
|
||||||
|
f"Page title: {page.title}\n"
|
||||||
|
f"Page body: {page.body}\n"
|
||||||
|
f"Responses: {json.dumps(responses)}\n"
|
||||||
|
)
|
||||||
|
if question:
|
||||||
|
prompt += f"Learner question: {question}\n"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = ml_services.infer_with_model(session.flow.agent.model.path, prompt, {
|
||||||
|
"max_tokens": 900,
|
||||||
|
"temperature": 0.2,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Onboarding feedback inference failed: %s", e, exc_info=True)
|
||||||
|
return Response({"error": "feedback_failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
feedback_text = ''
|
||||||
|
if isinstance(result, dict):
|
||||||
|
feedback_text = result.get('response') or result.get('result') or ''
|
||||||
|
feedback_text = str(feedback_text).strip()
|
||||||
|
|
||||||
|
feedback_payload = _extract_json(feedback_text)
|
||||||
|
if not feedback_payload:
|
||||||
|
feedback_payload = {
|
||||||
|
"summary": feedback_text or "Feedback generated.",
|
||||||
|
}
|
||||||
|
|
||||||
|
responses_payload = dict(session.responses or {})
|
||||||
|
feedback_store = dict(responses_payload.get("__feedback__") or {})
|
||||||
|
feedback_store[str(page.uuid)] = {
|
||||||
|
"feedback": feedback_payload,
|
||||||
|
"question": question,
|
||||||
|
"updated_at": timezone.now().isoformat(),
|
||||||
|
}
|
||||||
|
responses_payload["__feedback__"] = feedback_store
|
||||||
|
session.responses = responses_payload
|
||||||
|
session.save(update_fields=['responses'])
|
||||||
|
|
||||||
|
return Response({
|
||||||
|
"feedback": feedback_payload,
|
||||||
|
"session": OnboardingSessionSerializer(session, context={'request': request}).data,
|
||||||
|
})
|
||||||
|
|
@ -54,8 +54,8 @@ class RoleMembershipAdmin(ModelAdmin):
|
||||||
|
|
||||||
@register(TrainingFile)
|
@register(TrainingFile)
|
||||||
class TrainingFileAdmin(ModelAdmin):
|
class TrainingFileAdmin(ModelAdmin):
|
||||||
list_display = ('id', 'uuid', 'file_name', 'organization', 'uploaded_by', 'is_processed', 'created_at')
|
list_display = ('id', 'uuid', 'file_name', 'role', 'uploaded_by', 'status', 'is_processed', 'created_at')
|
||||||
search_fields = ('file_name', 'organization__name', 'uploaded_by__email_address')
|
search_fields = ('file_name', 'role__name', 'uploaded_by__email_address')
|
||||||
list_filter = ('is_processed', 'created_at')
|
list_filter = ('status', 'is_processed', 'created_at')
|
||||||
raw_id_fields = ('organization', 'uploaded_by')
|
raw_id_fields = ('role', 'uploaded_by')
|
||||||
readonly_fields = ('uuid', 'created_at', 'updated_at')
|
readonly_fields = ('uuid', 'created_at', 'updated_at')
|
||||||
|
|
@ -1,11 +1,8 @@
|
||||||
# Generated by Django 5.2.10 on 2026-01-25 11:02
|
|
||||||
|
|
||||||
import django.db.models.deletion
|
import django.db.models.deletion
|
||||||
import uuid
|
import uuid
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
initial = True
|
initial = True
|
||||||
|
|
@ -118,8 +115,9 @@ class Migration(migrations.Migration):
|
||||||
('file_size', models.IntegerField()),
|
('file_size', models.IntegerField()),
|
||||||
('file_type', models.CharField(max_length=50)),
|
('file_type', models.CharField(max_length=50)),
|
||||||
('description', models.TextField(blank=True, default='')),
|
('description', models.TextField(blank=True, default='')),
|
||||||
|
('status', models.CharField(choices=[('ingesting', 'Ingesting'), ('chunked', 'Chunked'), ('embedded', 'Embedded'), ('failed', 'Failed')], default='ingesting', max_length=20)),
|
||||||
('is_processed', models.BooleanField(default=False)),
|
('is_processed', models.BooleanField(default=False)),
|
||||||
('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='orgs.organization')),
|
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='orgs.role')),
|
||||||
('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)),
|
('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)),
|
||||||
],
|
],
|
||||||
options={
|
options={
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ from uuid import uuid4
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.db.models import BigAutoField, BooleanField, CASCADE, CharField, DateTimeField, ForeignKey, ManyToManyField, Model, TextField, UUIDField, IntegerField, FileField
|
from django.db.models import BigAutoField, BooleanField, CASCADE, CharField, DateTimeField, ForeignKey, ManyToManyField, Model, TextField, UUIDField, IntegerField, FileField
|
||||||
from django.db.models.signals import post_delete
|
from django.db.models.signals import post_delete, post_save
|
||||||
|
from django.db import transaction
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from apps.users.mixins import TimeStampMixin
|
from apps.users.mixins import TimeStampMixin
|
||||||
from apps.users.models import User
|
from apps.users.models import User
|
||||||
|
|
@ -104,9 +105,16 @@ class TrainingFile(TimeStampMixin, Model):
|
||||||
|
|
||||||
ALLOWED_EXTENSIONS = ('txt', 'pdf', 'md', 'csv', 'json', 'docx', 'doc')
|
ALLOWED_EXTENSIONS = ('txt', 'pdf', 'md', 'csv', 'json', 'docx', 'doc')
|
||||||
|
|
||||||
|
STATUS_CHOICES = [
|
||||||
|
('ingesting', 'Ingesting'),
|
||||||
|
('chunked', 'Chunked'),
|
||||||
|
('embedded', 'Embedded'),
|
||||||
|
('failed', 'Failed'),
|
||||||
|
]
|
||||||
|
|
||||||
id = BigAutoField(primary_key = True)
|
id = BigAutoField(primary_key = True)
|
||||||
uuid = UUIDField(default = uuid4, unique = True, editable = False)
|
uuid = UUIDField(default = uuid4, unique = True, editable = False)
|
||||||
organization = ForeignKey(Organization, on_delete = CASCADE, related_name = "training_files")
|
role = ForeignKey(Role, on_delete = CASCADE, related_name = "training_files")
|
||||||
uploaded_by = ForeignKey(User, on_delete = CASCADE, related_name = "uploaded_training_files")
|
uploaded_by = ForeignKey(User, on_delete = CASCADE, related_name = "uploaded_training_files")
|
||||||
|
|
||||||
file = FileField(upload_to = 'training_files/%Y/%m/%d/')
|
file = FileField(upload_to = 'training_files/%Y/%m/%d/')
|
||||||
|
|
@ -115,6 +123,7 @@ class TrainingFile(TimeStampMixin, Model):
|
||||||
file_type = CharField(max_length = 50)
|
file_type = CharField(max_length = 50)
|
||||||
|
|
||||||
description = TextField(blank = True, default = '')
|
description = TextField(blank = True, default = '')
|
||||||
|
status = CharField(max_length = 20, choices = STATUS_CHOICES, default = 'ingesting')
|
||||||
is_processed = BooleanField(default = False)
|
is_processed = BooleanField(default = False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
@ -123,7 +132,7 @@ class TrainingFile(TimeStampMixin, Model):
|
||||||
ordering = ['-created_at']
|
ordering = ['-created_at']
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.file_name} - {self.organization.name}"
|
return f"{self.file_name} - {self.role.name}"
|
||||||
|
|
||||||
|
|
||||||
@receiver(post_delete, sender=TrainingFile)
|
@receiver(post_delete, sender=TrainingFile)
|
||||||
|
|
@ -135,3 +144,15 @@ def delete_training_file_on_delete(sender, instance, **kwargs):
|
||||||
os.remove(instance.file.path)
|
os.remove(instance.file.path)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(post_save, sender=TrainingFile)
|
||||||
|
def enqueue_training_file_ingestion(sender, instance, created, **kwargs):
|
||||||
|
if not created:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _enqueue():
|
||||||
|
from apps.mlstore.tasks import ingest_training_file_task
|
||||||
|
ingest_training_file_task.delay(str(instance.uuid))
|
||||||
|
|
||||||
|
transaction.on_commit(_enqueue)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField, IntegerField
|
from rest_framework.serializers import ModelSerializer, SerializerMethodField, IntegerField, UUIDField
|
||||||
|
from rest_framework.exceptions import ValidationError
|
||||||
from apps.orgs.models import Organization, OrganizationMembership, OrganizationInvitation, Role, RoleMembership, TrainingFile
|
from apps.orgs.models import Organization, OrganizationMembership, OrganizationInvitation, Role, RoleMembership, TrainingFile
|
||||||
from apps.users.serializers import UserSerializer
|
from apps.users.serializers import UserSerializer
|
||||||
|
|
||||||
|
|
@ -75,11 +76,13 @@ class RoleSerializer(ModelSerializer):
|
||||||
class TrainingFileSerializer(ModelSerializer):
|
class TrainingFileSerializer(ModelSerializer):
|
||||||
uploaded_by = UserSerializer(read_only = True)
|
uploaded_by = UserSerializer(read_only = True)
|
||||||
file_url = SerializerMethodField()
|
file_url = SerializerMethodField()
|
||||||
|
role = RoleSerializer(read_only = True)
|
||||||
|
role_uuid = UUIDField(write_only = True, required = True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = TrainingFile
|
model = TrainingFile
|
||||||
fields = ['id', 'uuid', 'organization', 'uploaded_by', 'file', 'file_name', 'file_size', 'file_type', 'description', 'is_processed', 'file_url', 'created_at', 'updated_at']
|
fields = ['id', 'uuid', 'role', 'role_uuid', 'uploaded_by', 'file', 'file_name', 'file_size', 'file_type', 'description', 'status', 'is_processed', 'file_url', 'created_at', 'updated_at']
|
||||||
read_only_fields = ['uuid', 'uploaded_by', 'file_size', 'file_type', 'is_processed', 'created_at', 'updated_at', 'organization']
|
read_only_fields = ['uuid', 'uploaded_by', 'file_size', 'file_type', 'status', 'is_processed', 'created_at', 'updated_at', 'role']
|
||||||
|
|
||||||
def get_file_url(self, obj):
|
def get_file_url(self, obj):
|
||||||
request = self.context.get('request')
|
request = self.context.get('request')
|
||||||
|
|
@ -88,7 +91,6 @@ class TrainingFileSerializer(ModelSerializer):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def validate_file(self, value):
|
def validate_file(self, value):
|
||||||
"""Validate that uploaded file is a text-based file."""
|
|
||||||
if not value:
|
if not value:
|
||||||
raise ValueError('File is required')
|
raise ValueError('File is required')
|
||||||
|
|
||||||
|
|
@ -108,10 +110,18 @@ class TrainingFileSerializer(ModelSerializer):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def create(self, validated_data):
|
def create(self, validated_data):
|
||||||
|
role_uuid = validated_data.pop('role_uuid', None)
|
||||||
file_obj = validated_data.get('file')
|
file_obj = validated_data.get('file')
|
||||||
if file_obj:
|
if file_obj:
|
||||||
validated_data['file_size'] = file_obj.size
|
validated_data['file_size'] = file_obj.size
|
||||||
import os
|
import os
|
||||||
file_extension = os.path.splitext(file_obj.name)[1][1:].lower()
|
file_extension = os.path.splitext(file_obj.name)[1][1:].lower()
|
||||||
validated_data['file_type'] = file_extension
|
validated_data['file_type'] = file_extension
|
||||||
|
if not role_uuid:
|
||||||
|
raise ValidationError({'role_uuid': 'Role is required'})
|
||||||
|
try:
|
||||||
|
role = Role.objects.get(uuid = role_uuid)
|
||||||
|
except Role.DoesNotExist:
|
||||||
|
raise ValidationError({'role_uuid': 'Role not found'})
|
||||||
|
validated_data['role'] = role
|
||||||
return super().create(validated_data)
|
return super().create(validated_data)
|
||||||
|
|
@ -166,6 +166,12 @@ class OrganizationViewSet(ModelViewSet):
|
||||||
serializer = RoleSerializer(role)
|
serializer = RoleSerializer(role)
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
@action(detail=False, methods=['get'], url_path='role/mine')
|
||||||
|
def my_roles(self, request):
|
||||||
|
roles = Role.objects.filter(memberships__user=request.user).distinct()
|
||||||
|
serializer = RoleSerializer(roles, many=True)
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
@action(detail=True, methods=['post'], url_path='role/(?P<role_uuid>[0-9a-f-]{36})/delete')
|
@action(detail=True, methods=['post'], url_path='role/(?P<role_uuid>[0-9a-f-]{36})/delete')
|
||||||
def delete_role(self, request, uuid = None, role_uuid = None):
|
def delete_role(self, request, uuid = None, role_uuid = None):
|
||||||
if not request.user.is_manager:
|
if not request.user.is_manager:
|
||||||
|
|
@ -196,8 +202,11 @@ class OrganizationViewSet(ModelViewSet):
|
||||||
organization = self.get_object()
|
organization = self.get_object()
|
||||||
|
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
training_files = TrainingFile.objects.filter(organization=organization)
|
role_uuid = request.query_params.get('role_uuid')
|
||||||
serializer = TrainingFileSerializer(training_files, many=True)
|
training_files = TrainingFile.objects.filter(role__organization=organization)
|
||||||
|
if role_uuid:
|
||||||
|
training_files = training_files.filter(role__uuid=role_uuid, role__organization=organization)
|
||||||
|
serializer = TrainingFileSerializer(training_files, many=True, context={'request': request})
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
if not (organization.owner == request.user or
|
if not (organization.owner == request.user or
|
||||||
|
|
@ -207,9 +216,17 @@ class OrganizationViewSet(ModelViewSet):
|
||||||
status=HTTP_403_FORBIDDEN
|
status=HTTP_403_FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
serializer = TrainingFileSerializer(data=request.data)
|
role_uuid = request.data.get('role_uuid')
|
||||||
|
if not role_uuid:
|
||||||
|
return Response({'error': 'role_uuid is required'}, status=HTTP_400_BAD_REQUEST)
|
||||||
|
try:
|
||||||
|
Role.objects.get(uuid=role_uuid, organization=organization)
|
||||||
|
except Role.DoesNotExist:
|
||||||
|
return Response({'error': 'Role not found in this organization'}, status=HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
serializer = TrainingFileSerializer(data=request.data, context={'request': request})
|
||||||
if serializer.is_valid():
|
if serializer.is_valid():
|
||||||
serializer.save(uploaded_by=request.user, organization=organization)
|
serializer.save(uploaded_by=request.user)
|
||||||
return Response(serializer.data, status=201)
|
return Response(serializer.data, status=201)
|
||||||
return Response(serializer.errors, status=HTTP_400_BAD_REQUEST)
|
return Response(serializer.errors, status=HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
|
@ -218,16 +235,16 @@ class OrganizationViewSet(ModelViewSet):
|
||||||
organization = self.get_object()
|
organization = self.get_object()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
training_file = TrainingFile.objects.get(uuid=file_uuid, organization=organization)
|
training_file = TrainingFile.objects.get(uuid=file_uuid, role__organization=organization)
|
||||||
except TrainingFile.DoesNotExist:
|
except TrainingFile.DoesNotExist:
|
||||||
return Response({'error': 'Training file not found'}, status=HTTP_404_NOT_FOUND)
|
return Response({'error': 'Training file not found'}, status=HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
serializer = TrainingFileSerializer(training_file)
|
serializer = TrainingFileSerializer(training_file, context={'request': request})
|
||||||
return Response(serializer.data)
|
return Response(serializer.data)
|
||||||
|
|
||||||
if not (training_file.uploaded_by == request.user or
|
if not (training_file.uploaded_by == request.user or
|
||||||
training_file.organization.owner == request.user or
|
training_file.role.organization.owner == request.user or
|
||||||
request.user.is_manager):
|
request.user.is_manager):
|
||||||
return Response(
|
return Response(
|
||||||
{'error': 'You do not have permission to delete this file'},
|
{'error': 'You do not have permission to delete this file'},
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,16 @@ from rest_framework.routers import DefaultRouter
|
||||||
from apps.orgs.viewsets import OrganizationViewSet
|
from apps.orgs.viewsets import OrganizationViewSet
|
||||||
from apps.users.viewsets import UserViewSet
|
from apps.users.viewsets import UserViewSet
|
||||||
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
|
from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet
|
||||||
|
from apps.onboarding.viewsets import OnboardingFlowViewSet, OnboardingPageViewSet, OnboardingFieldViewSet, OnboardingSessionViewSet
|
||||||
|
|
||||||
router = DefaultRouter()
|
router = DefaultRouter()
|
||||||
router.register(r'user', UserViewSet, basename='user')
|
router.register(r'user', UserViewSet, basename='user')
|
||||||
router.register(r'organization', OrganizationViewSet, basename='organization')
|
router.register(r'organization', OrganizationViewSet, basename='organization')
|
||||||
router.register(r'agent', AgentViewSet, basename='agent')
|
router.register(r'agent', AgentViewSet, basename='agent')
|
||||||
router.register(r'agent-run', AgentRunViewSet, basename='agent-run')
|
router.register(r'agent-run', AgentRunViewSet, basename='agent-run')
|
||||||
|
router.register(r'onboarding/flow', OnboardingFlowViewSet, basename='onboarding-flow')
|
||||||
|
router.register(r'onboarding/page', OnboardingPageViewSet, basename='onboarding-page')
|
||||||
|
router.register(r'onboarding/field', OnboardingFieldViewSet, basename='onboarding-field')
|
||||||
|
router.register(r'onboarding/session', OnboardingSessionViewSet, basename='onboarding-session')
|
||||||
|
|
||||||
urlpatterns = router.urls
|
urlpatterns = router.urls
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,7 @@ LOCAL_APPS = [
|
||||||
'apps.users',
|
'apps.users',
|
||||||
'apps.orgs',
|
'apps.orgs',
|
||||||
'apps.mlstore',
|
'apps.mlstore',
|
||||||
|
'apps.onboarding',
|
||||||
]
|
]
|
||||||
INSTALLED_APPS = OVERRIDE_APPS + DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS
|
INSTALLED_APPS = OVERRIDE_APPS + DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,34 +11,36 @@ from aiohttp import web
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.types import Tool, TextContent
|
from mcp.types import Tool, TextContent
|
||||||
|
|
||||||
logging.basicConfig(
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
model_cache_dir = os.path.join(project_root, "model", "base-model")
|
||||||
|
|
||||||
|
def _init_runtime():
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.StreamHandler(sys.stderr),
|
logging.StreamHandler(sys.stderr),
|
||||||
logging.StreamHandler(sys.stdout),
|
logging.StreamHandler(sys.stdout),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
os.makedirs(model_cache_dir, exist_ok=True)
|
||||||
|
os.environ["HF_HOME"] = model_cache_dir
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
logger.info(f"Project root: {project_root}")
|
||||||
model_cache_dir = os.path.join(project_root, "model", "base-model")
|
logger.info(f"HuggingFace model cache directory set to: {model_cache_dir}")
|
||||||
os.makedirs(model_cache_dir, exist_ok=True)
|
|
||||||
os.environ["HF_HOME"] = model_cache_dir
|
|
||||||
logger.info(f"Project root: {project_root}")
|
|
||||||
logger.info(f"HuggingFace model cache directory set to: {model_cache_dir}")
|
|
||||||
|
|
||||||
app = Server("mlstore-mcp-server")
|
app = Server("mlstore-mcp-server")
|
||||||
logger.info("MCP Server initialized: mlstore-mcp-server")
|
|
||||||
|
|
||||||
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
|
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
PAIR_EXTRACTOR: Dict[str, Any] = {}
|
PAIR_EXTRACTOR: Dict[str, Any] = {}
|
||||||
|
EMBEDDING_MODEL: Dict[str, Any] = {}
|
||||||
|
|
||||||
BASE_MODEL_CACHE_DIR = model_cache_dir
|
BASE_MODEL_CACHE_DIR = model_cache_dir
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.list_tools()
|
@app.list_tools()
|
||||||
async def list_tools():
|
async def list_tools():
|
||||||
logger.info("Listing available tools")
|
logger.info("Listing available tools")
|
||||||
|
|
@ -93,6 +95,18 @@ async def list_tools():
|
||||||
"required": ["model_path", "prompt"]
|
"required": ["model_path", "prompt"]
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Tool(
|
||||||
|
name="embed",
|
||||||
|
description="Generate embeddings for a list of texts",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"texts": {"type": "array", "items": {"type": "string"}},
|
||||||
|
"model": {"type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["texts"]
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
logger.info(f"Available tools: {[t.name for t in tools]}")
|
logger.info(f"Available tools: {[t.name for t in tools]}")
|
||||||
return tools
|
return tools
|
||||||
|
|
@ -110,6 +124,112 @@ def _safe_dir_name(name: str) -> str:
|
||||||
return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".")
|
return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".")
|
||||||
|
|
||||||
|
|
||||||
|
def _map_gguf_repo(model_name: str) -> tuple[str | None, str | None]:
|
||||||
|
gguf_repos = {
|
||||||
|
"Llama-3.1-8B-Instruct.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"),
|
||||||
|
"Meta-Llama-3.1-8B-Instruct.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"),
|
||||||
|
"Meta-Llama-3.1-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_0.gguf"),
|
||||||
|
"Llama-3.1-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_0.gguf"),
|
||||||
|
"Meta-Llama-3-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3-8B-Instruct-GGUF", "Meta-Llama-3-8B-Instruct-Q4_0.gguf"),
|
||||||
|
"Llama-3-8B-Instruct.Q4_0.gguf": ("bartowski/Meta-Llama-3-8B-Instruct-GGUF", "Meta-Llama-3-8B-Instruct-Q4_0.gguf"),
|
||||||
|
"mistral-7b-instruct-v0.3.Q4_0.gguf": ("bartowski/Mistral-7B-Instruct-v0.3-GGUF", "Mistral-7B-Instruct-v0.3-Q4_0.gguf"),
|
||||||
|
"Mistral-7B-Instruct-v0.3.Q4_0.gguf": ("bartowski/Mistral-7B-Instruct-v0.3-GGUF", "Mistral-7B-Instruct-v0.3-Q4_0.gguf"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_name in gguf_repos:
|
||||||
|
return gguf_repos[model_name]
|
||||||
|
|
||||||
|
base_name = model_name.lower()
|
||||||
|
if "llama" in base_name and "3.1" in base_name and "8b" in base_name:
|
||||||
|
repo_id = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
|
||||||
|
if ".q4_0" in base_name:
|
||||||
|
return repo_id, "Meta-Llama-3.1-8B-Instruct-Q4_0.gguf"
|
||||||
|
if ".q4_k_m" in base_name:
|
||||||
|
return repo_id, "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"
|
||||||
|
return repo_id, "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"
|
||||||
|
|
||||||
|
if "llama" in base_name and "3" in base_name and "8b" in base_name:
|
||||||
|
repo_id = "bartowski/Meta-Llama-3-8B-Instruct-GGUF"
|
||||||
|
if ".q4_0" in base_name:
|
||||||
|
return repo_id, "Meta-Llama-3-8B-Instruct-Q4_0.gguf"
|
||||||
|
return repo_id, "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
|
||||||
|
|
||||||
|
if "mistral" in base_name and "7b" in base_name:
|
||||||
|
repo_id = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
|
||||||
|
if ".q4_0" in base_name:
|
||||||
|
return repo_id, "Mistral-7B-Instruct-v0.3-Q4_0.gguf"
|
||||||
|
return repo_id, "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_existing_gguf(model_name: str) -> str | None:
|
||||||
|
repo_id, filename = _map_gguf_repo(model_name)
|
||||||
|
if not filename:
|
||||||
|
return None
|
||||||
|
|
||||||
|
candidate_paths = [
|
||||||
|
os.path.join(_model_root(), filename),
|
||||||
|
os.path.join(model_cache_dir, filename),
|
||||||
|
os.path.join(os.getcwd(), "model", filename),
|
||||||
|
os.path.join(os.getcwd(), filename),
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in candidate_paths:
|
||||||
|
if os.path.exists(path):
|
||||||
|
return path
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _download_gguf_from_hf(model_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Download a GGUF model from Hugging Face Hub.
|
||||||
|
Returns the path to the downloaded model file.
|
||||||
|
"""
|
||||||
|
logger.info(f"Attempting to download GGUF model from Hugging Face: {model_name}")
|
||||||
|
|
||||||
|
existing_path = _find_existing_gguf(model_name)
|
||||||
|
if existing_path:
|
||||||
|
logger.info(f"Found existing GGUF locally: {existing_path}")
|
||||||
|
return existing_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download, list_repo_files
|
||||||
|
|
||||||
|
model_dir = _model_root()
|
||||||
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
|
repo_id, filename = _map_gguf_repo(model_name)
|
||||||
|
if repo_id and filename:
|
||||||
|
logger.info(f"Found known model mapping: {repo_id}/{filename}")
|
||||||
|
|
||||||
|
if not repo_id or not filename:
|
||||||
|
logger.error(f"Could not determine Hugging Face repo for model: {model_name}")
|
||||||
|
raise ValueError(f"Unknown model: {model_name}. Please specify a known GGUF model.")
|
||||||
|
|
||||||
|
logger.info(f"Downloading {filename} from {repo_id}...")
|
||||||
|
logger.info(f"This may take several minutes depending on model size and connection speed.")
|
||||||
|
|
||||||
|
downloaded_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
cache_dir=model_cache_dir,
|
||||||
|
local_dir=model_dir,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Model downloaded successfully to: {downloaded_path}")
|
||||||
|
return downloaded_path
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"huggingface_hub not available: {str(e)}")
|
||||||
|
raise ImportError("huggingface_hub is required to download models. Install with: pip install huggingface_hub")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download model from Hugging Face: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model_path(model_path: str) -> str:
|
def _resolve_model_path(model_path: str) -> str:
|
||||||
if not model_path:
|
if not model_path:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
@ -298,14 +418,16 @@ def _prompt_based_pair_extraction(training_data: List[Any], base_model: str) ->
|
||||||
system_prompt = (
|
system_prompt = (
|
||||||
"You are a data extractor. Given a list of items, return a JSON array of training pairs. "
|
"You are a data extractor. Given a list of items, return a JSON array of training pairs. "
|
||||||
"Each pair must have 'instruction' and 'response'. Keep answers concise. "
|
"Each pair must have 'instruction' and 'response'. Keep answers concise. "
|
||||||
"If content is incomplete, still produce best-effort pairs."
|
"If content is incomplete, still produce best-effort pairs. "
|
||||||
|
"End your answer with a complete sentence. Do not start lists or new sections."
|
||||||
)
|
)
|
||||||
|
|
||||||
user_prompt = (
|
user_prompt = (
|
||||||
"Examples of desired output:\n"
|
"Examples of desired output:\n"
|
||||||
f"{json.dumps(example_pairs, ensure_ascii=False, indent=2)}\n\n"
|
f"{json.dumps(example_pairs, ensure_ascii=False, indent=2)}\n\n"
|
||||||
"Now extract training pairs from the following items. Return ONLY a JSON array, no extra text.\n"
|
"Now extract training pairs from the following items. Return ONLY a JSON array, no extra text.\n"
|
||||||
f"Items:\n{data_block}"
|
f"Items:\n{data_block}\n\n"
|
||||||
|
"Answer:"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
@ -385,7 +507,7 @@ def _extract_training_pairs(training_data: List[Any]) -> List[Tuple[str, str]]:
|
||||||
if not training_data:
|
if not training_data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
|
base_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
pairs = _prompt_based_pair_extraction(training_data, base_model)
|
pairs = _prompt_based_pair_extraction(training_data, base_model)
|
||||||
if not pairs:
|
if not pairs:
|
||||||
|
|
@ -402,6 +524,19 @@ def _extract_training_pairs(training_data: List[Any]) -> List[Tuple[str, str]]:
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_embedding_model(model_name: str):
|
||||||
|
if EMBEDDING_MODEL.get("model") is not None and EMBEDDING_MODEL.get("name") == model_name:
|
||||||
|
return EMBEDDING_MODEL["model"]
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
logger.info(f"Loading embedding model: {model_name}")
|
||||||
|
model = SentenceTransformer(model_name, cache_folder=model_cache_dir)
|
||||||
|
EMBEDDING_MODEL["model"] = model
|
||||||
|
EMBEDDING_MODEL["name"] = model_name
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
async def _fine_tune_model_impl(
|
async def _fine_tune_model_impl(
|
||||||
training_files: List[str],
|
training_files: List[str],
|
||||||
hyperparams: Dict[str, Any],
|
hyperparams: Dict[str, Any],
|
||||||
|
|
@ -409,7 +544,7 @@ async def _fine_tune_model_impl(
|
||||||
version: str,
|
version: str,
|
||||||
output_dir: str
|
output_dir: str
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
|
base_model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
logger.info(f"Starting fine-tune process with base model: {base_model}")
|
logger.info(f"Starting fine-tune process with base model: {base_model}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -793,14 +928,47 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
||||||
logger.error("model_path_required error: no model path provided")
|
logger.error("model_path_required error: no model path provided")
|
||||||
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
||||||
|
|
||||||
|
original_model_path = model_path
|
||||||
model_path = _resolve_model_path(model_path)
|
model_path = _resolve_model_path(model_path)
|
||||||
logger.debug(f"Resolved model path: {model_path}")
|
logger.debug(f"Resolved model path: {model_path}")
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
logger.error(f"Model not found at: {model_path}")
|
logger.warning(f"Model not found at: {model_path}")
|
||||||
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
|
local_mapped = _find_existing_gguf(original_model_path)
|
||||||
|
if local_mapped:
|
||||||
|
logger.info(f"Using existing mapped GGUF: {local_mapped}")
|
||||||
|
model_path = local_mapped
|
||||||
|
else:
|
||||||
|
logger.info(f"Attempting to download model from Hugging Face...")
|
||||||
|
try:
|
||||||
|
model_path = _download_gguf_from_hf(original_model_path)
|
||||||
|
logger.info(f"Model downloaded successfully: {model_path}")
|
||||||
|
except Exception as download_error:
|
||||||
|
logger.error(f"Failed to download model: {str(download_error)}")
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": "model_not_found_and_download_failed",
|
||||||
|
"model_path": original_model_path,
|
||||||
|
"download_error": str(download_error),
|
||||||
|
"timestamp": _now()
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
for loaded_path in list(LOADED_MODELS.keys()):
|
||||||
|
if loaded_path != model_path:
|
||||||
|
logger.info(f"Unloading cached model: {loaded_path}")
|
||||||
|
LOADED_MODELS.pop(loaded_path, None)
|
||||||
|
|
||||||
|
if model_path in LOADED_MODELS and "model" in LOADED_MODELS[model_path]:
|
||||||
|
logger.info(f"Model already loaded: {model_path}")
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"model_path": model_path,
|
||||||
|
"loaded": True,
|
||||||
|
"cached": True,
|
||||||
|
"timestamp": _now(),
|
||||||
|
}
|
||||||
|
|
||||||
from gpt4all import GPT4All
|
from gpt4all import GPT4All
|
||||||
|
|
||||||
model_dir, model_file = _resolve_model_file(model_path)
|
model_dir, model_file = _resolve_model_file(model_path)
|
||||||
|
|
@ -866,16 +1034,38 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
||||||
logger.error("model_path_required error: no model path provided")
|
logger.error("model_path_required error: no model path provided")
|
||||||
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
||||||
|
|
||||||
|
original_model_path = model_path
|
||||||
model_path = _resolve_model_path(model_path)
|
model_path = _resolve_model_path(model_path)
|
||||||
logger.debug(f"Resolved model path: {model_path}")
|
logger.debug(f"Resolved model path: {model_path}")
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
logger.error(f"Model not found at: {model_path}")
|
logger.warning(f"Model not found at: {model_path}")
|
||||||
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
|
local_mapped = _find_existing_gguf(original_model_path)
|
||||||
|
if local_mapped:
|
||||||
|
logger.info(f"Using existing mapped GGUF: {local_mapped}")
|
||||||
|
model_path = local_mapped
|
||||||
|
else:
|
||||||
|
logger.info(f"Attempting to download model from Hugging Face...")
|
||||||
|
try:
|
||||||
|
model_path = _download_gguf_from_hf(original_model_path)
|
||||||
|
logger.info(f"Model downloaded successfully: {model_path}")
|
||||||
|
except Exception as download_error:
|
||||||
|
logger.error(f"Failed to download model: {str(download_error)}")
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": "model_not_found_and_download_failed",
|
||||||
|
"model_path": original_model_path,
|
||||||
|
"download_error": str(download_error),
|
||||||
|
"timestamp": _now()
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
|
if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
|
||||||
logger.info(f"Model not in memory, loading: {model_path}")
|
logger.info(f"Model not in memory, loading: {model_path}")
|
||||||
|
for loaded_path in list(LOADED_MODELS.keys()):
|
||||||
|
if loaded_path != model_path:
|
||||||
|
logger.info(f"Unloading cached model: {loaded_path}")
|
||||||
|
LOADED_MODELS.pop(loaded_path, None)
|
||||||
from gpt4all import GPT4All
|
from gpt4all import GPT4All
|
||||||
|
|
||||||
model_dir, model_file = _resolve_model_file(model_path)
|
model_dir, model_file = _resolve_model_file(model_path)
|
||||||
|
|
@ -952,6 +1142,34 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
||||||
"timestamp": _now(),
|
"timestamp": _now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if name == "embed":
|
||||||
|
texts = arguments.get("texts") or []
|
||||||
|
model_name = arguments.get("model") or "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
|
||||||
|
return {"status": "failed", "error": "texts_must_be_list_of_strings", "timestamp": _now()}
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return {"status": "completed", "embeddings": [], "timestamp": _now()}
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = _ensure_embedding_model(model_name)
|
||||||
|
embeddings = model.encode(texts).tolist()
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"embeddings": embeddings,
|
||||||
|
"model": model_name,
|
||||||
|
"timestamp": _now(),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embedding failed: {str(e)}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": str(e),
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
"timestamp": _now(),
|
||||||
|
}
|
||||||
|
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1009,6 +1227,7 @@ async def handle_health(request: web.Request) -> web.Response:
|
||||||
|
|
||||||
|
|
||||||
async def run_http_server():
|
async def run_http_server():
|
||||||
|
_init_runtime()
|
||||||
host = os.getenv("MCP_HTTP_HOST", "0.0.0.0")
|
host = os.getenv("MCP_HTTP_HOST", "0.0.0.0")
|
||||||
port = int(os.getenv("MCP_HTTP_PORT", "8001"))
|
port = int(os.getenv("MCP_HTTP_PORT", "8001"))
|
||||||
logger.info(f"Starting HTTP server on {host}:{port}")
|
logger.info(f"Starting HTTP server on {host}:{port}")
|
||||||
|
|
@ -1028,6 +1247,7 @@ async def run_http_server():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
_init_runtime()
|
||||||
logger.info("Starting MCP Server...")
|
logger.info("Starting MCP Server...")
|
||||||
try:
|
try:
|
||||||
asyncio.run(run_http_server())
|
asyncio.run(run_http_server())
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,7 @@ import {
|
||||||
HomeOutlined,
|
HomeOutlined,
|
||||||
InfoCircleOutlined,
|
InfoCircleOutlined,
|
||||||
RocketOutlined,
|
RocketOutlined,
|
||||||
TeamOutlined,
|
|
||||||
RobotOutlined,
|
RobotOutlined,
|
||||||
BulbOutlined,
|
|
||||||
AppstoreOutlined,
|
AppstoreOutlined,
|
||||||
DashboardOutlined,
|
DashboardOutlined,
|
||||||
LoginOutlined,
|
LoginOutlined,
|
||||||
|
|
@ -42,10 +40,8 @@ const navItems: NavItem[] = [
|
||||||
icon: BuildOutlined,
|
icon: BuildOutlined,
|
||||||
path: '/organization',
|
path: '/organization',
|
||||||
children: [
|
children: [
|
||||||
{ key: '/roles', label: 'Roles', icon: TeamOutlined, path: '/roles', manager: true },
|
|
||||||
{ key: '/onboarding', label: 'Onboarding', icon: RocketOutlined, path: '/onboarding' },
|
{ key: '/onboarding', label: 'Onboarding', icon: RocketOutlined, path: '/onboarding' },
|
||||||
{ key: '/progress', label: 'Progress', icon: DashboardOutlined, path: '/progress' },
|
{ key: '/progress', label: 'Progress', icon: DashboardOutlined, path: '/progress' },
|
||||||
{ key: '/assessments', label: 'Assessments', icon: BulbOutlined, path: '/assessments' },
|
|
||||||
{ key: '/resources', label: 'Resources', icon: AppstoreOutlined, path: '/resources' },
|
{ key: '/resources', label: 'Resources', icon: AppstoreOutlined, path: '/resources' },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ export const API = {
|
||||||
organizations: () => '/api/organization/',
|
organizations: () => '/api/organization/',
|
||||||
organization: (id: string) => `/api/organization/${id}/`,
|
organization: (id: string) => `/api/organization/${id}/`,
|
||||||
organizationRoles: (orgUuid: string) => `/api/organization/${orgUuid}/role/`,
|
organizationRoles: (orgUuid: string) => `/api/organization/${orgUuid}/role/`,
|
||||||
|
organizationRolesMine: () => '/api/organization/role/mine/',
|
||||||
organizationRole: (orgUuid: string, roleUuid: string) =>
|
organizationRole: (orgUuid: string, roleUuid: string) =>
|
||||||
`/api/organization/${orgUuid}/role/${roleUuid}/`,
|
`/api/organization/${orgUuid}/role/${roleUuid}/`,
|
||||||
organizationRoleMembers: (orgUuid: string, roleUuid: string) =>
|
organizationRoleMembers: (orgUuid: string, roleUuid: string) =>
|
||||||
|
|
@ -98,6 +99,18 @@ export const API = {
|
||||||
`/api/organization/${orgUuid}/training-file/${fileUuid}/`,
|
`/api/organization/${orgUuid}/training-file/${fileUuid}/`,
|
||||||
agents: () => '/api/agent/',
|
agents: () => '/api/agent/',
|
||||||
agent: (id: string) => `/api/agent/${id}/`,
|
agent: (id: string) => `/api/agent/${id}/`,
|
||||||
|
onboardingFlows: () => '/api/onboarding/flow/',
|
||||||
|
onboardingFlow: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/`,
|
||||||
|
onboardingFlowPages: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/pages/`,
|
||||||
|
onboardingFlowGenerate: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/generate/`,
|
||||||
|
onboardingFlowPublish: (flowUuid: string) => `/api/onboarding/flow/${flowUuid}/publish/`,
|
||||||
|
onboardingSessions: () => '/api/onboarding/session/',
|
||||||
|
onboardingSessionGetOrCreate: () => '/api/onboarding/session/get_or_create/',
|
||||||
|
onboardingSession: (sessionUuid: string) => `/api/onboarding/session/${sessionUuid}/`,
|
||||||
|
onboardingSessionSubmit: (sessionUuid: string) =>
|
||||||
|
`/api/onboarding/session/${sessionUuid}/submit/`,
|
||||||
|
onboardingSessionFeedback: (sessionUuid: string) =>
|
||||||
|
`/api/onboarding/session/${sessionUuid}/feedback/`,
|
||||||
}
|
}
|
||||||
|
|
||||||
export const apiClient = new ApiClient()
|
export const apiClient = new ApiClient()
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,18 @@ const router = createRouter({
|
||||||
component: () => import('../views/AgentDetailView.vue'),
|
component: () => import('../views/AgentDetailView.vue'),
|
||||||
meta: { requiresAuth: true, requiresManager: true },
|
meta: { requiresAuth: true, requiresManager: true },
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
path: '/onboarding',
|
||||||
|
name: 'onboarding',
|
||||||
|
component: () => import('../views/OnboardingView.vue'),
|
||||||
|
meta: { requiresAuth: true },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
path: '/onboarding/:roleId',
|
||||||
|
name: 'onboarding-role',
|
||||||
|
component: () => import('../views/OnboardingView.vue'),
|
||||||
|
meta: { requiresAuth: true },
|
||||||
|
},
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -97,15 +97,26 @@ export const useAgentStore = defineStore('agent', () => {
|
||||||
isConnected.value = false
|
isConnected.value = false
|
||||||
}
|
}
|
||||||
|
|
||||||
const startAgent = (data: { query?: string; prompt?: string; options?: Record<string, unknown> }) => {
|
const startAgent = (data: {
|
||||||
|
query?: string
|
||||||
|
prompt?: string
|
||||||
|
role_uuid?: string
|
||||||
|
max_tokens?: number
|
||||||
|
options?: Record<string, unknown>
|
||||||
|
}) => {
|
||||||
if (!socket || socket.readyState !== WebSocket.OPEN) return
|
if (!socket || socket.readyState !== WebSocket.OPEN) return
|
||||||
const prompt = data.query ?? data.prompt ?? ''
|
const prompt = data.query ?? data.prompt ?? ''
|
||||||
|
const options = {
|
||||||
|
...(data.options ?? {}),
|
||||||
|
...(typeof data.max_tokens === 'number' ? { max_tokens: data.max_tokens } : {}),
|
||||||
|
}
|
||||||
socket.send(
|
socket.send(
|
||||||
JSON.stringify({
|
JSON.stringify({
|
||||||
action: 'infer',
|
action: 'infer',
|
||||||
input_data: {
|
input_data: {
|
||||||
prompt,
|
prompt,
|
||||||
options: data.options ?? {},
|
role_uuid: data.role_uuid,
|
||||||
|
options,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
@ -131,6 +142,17 @@ export const useAgentStore = defineStore('agent', () => {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const sendOnboardingProgress = (executionId: string, content: Record<string, unknown>) => {
|
||||||
|
if (!socket || socket.readyState !== WebSocket.OPEN) return
|
||||||
|
socket.send(
|
||||||
|
JSON.stringify({
|
||||||
|
action: 'onboarding_progress',
|
||||||
|
execution_id: executionId,
|
||||||
|
content,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const resetLog = () => {
|
const resetLog = () => {
|
||||||
eventLog.value = []
|
eventLog.value = []
|
||||||
}
|
}
|
||||||
|
|
@ -144,6 +166,7 @@ export const useAgentStore = defineStore('agent', () => {
|
||||||
startAgent,
|
startAgent,
|
||||||
startFineTune,
|
startFineTune,
|
||||||
stopAgent,
|
stopAgent,
|
||||||
|
sendOnboardingProgress,
|
||||||
resetLog,
|
resetLog,
|
||||||
lastExecutionId,
|
lastExecutionId,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
44
src/types/onboarding.ts
Normal file
44
src/types/onboarding.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
export type OnboardingField = {
|
||||||
|
uuid: string
|
||||||
|
key: string
|
||||||
|
label: string
|
||||||
|
field_type: 'text' | 'textarea' | 'select' | 'multiselect' | 'number' | 'boolean' | 'date'
|
||||||
|
required: boolean
|
||||||
|
help_text?: string
|
||||||
|
placeholder?: string
|
||||||
|
options?: string[]
|
||||||
|
default_value?: unknown
|
||||||
|
validation?: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export type OnboardingPage = {
|
||||||
|
uuid: string
|
||||||
|
order: number
|
||||||
|
title: string
|
||||||
|
body?: string
|
||||||
|
meta?: Record<string, unknown>
|
||||||
|
fields: OnboardingField[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export type OnboardingFlow = {
|
||||||
|
uuid: string
|
||||||
|
role: string
|
||||||
|
agent?: string | null
|
||||||
|
title: string
|
||||||
|
description?: string
|
||||||
|
status: 'draft' | 'published' | 'archived'
|
||||||
|
pages?: OnboardingPage[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export type OnboardingSession = {
|
||||||
|
uuid: string
|
||||||
|
flow: string
|
||||||
|
agent_run?: string | null
|
||||||
|
status: 'in_progress' | 'completed' | 'abandoned'
|
||||||
|
current_page_order: number
|
||||||
|
responses: Record<string, unknown>
|
||||||
|
}
|
||||||
|
|
||||||
|
export type OnboardingFeedback = {
|
||||||
|
summary?: string
|
||||||
|
}
|
||||||
|
|
@ -37,7 +37,7 @@ export interface InviteToken {
|
||||||
export interface TrainingFile {
|
export interface TrainingFile {
|
||||||
id: number
|
id: number
|
||||||
uuid: string
|
uuid: string
|
||||||
organization: string
|
role: Role
|
||||||
uploaded_by: User
|
uploaded_by: User
|
||||||
file: string
|
file: string
|
||||||
file_name: string
|
file_name: string
|
||||||
|
|
@ -45,6 +45,7 @@ export interface TrainingFile {
|
||||||
file_type: string
|
file_type: string
|
||||||
description: string
|
description: string
|
||||||
is_processed: boolean
|
is_processed: boolean
|
||||||
|
status: 'ingesting' | 'chunked' | 'embedded' | 'failed'
|
||||||
file_url: string
|
file_url: string
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,15 @@
|
||||||
import { Card, Typography, Divider, List } from 'ant-design-vue'
|
import { Card, Typography, Divider, List } from 'ant-design-vue'
|
||||||
|
|
||||||
const steps = [
|
const steps = [
|
||||||
'Register or login (demo credentials only).',
|
'Register or login.',
|
||||||
'Complete Onboarding and Training to simulate a role journey.',
|
'Complete onboarding and training to simulate a role journey.',
|
||||||
'Managers assign employees to roles and review progress reports.',
|
'Managers review onboarding completion and feedback.',
|
||||||
]
|
]
|
||||||
|
|
||||||
const features = [
|
const features = [
|
||||||
{
|
{
|
||||||
title: 'Modular Content',
|
title: 'Modular Content',
|
||||||
desc: 'Compose learning journeys from small, reusable modules — mix assessments, videos and interactive checks.',
|
desc: 'Compose learning journeys from small, reusable modules. Mix videos and interactive checks.',
|
||||||
img: 'https://placehold.co/600x400?text=Modular+Content',
|
img: 'https://placehold.co/600x400?text=Modular+Content',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -20,7 +20,7 @@ const features = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: 'Reporting & Insights',
|
title: 'Reporting & Insights',
|
||||||
desc: 'Lightweight reports showing completion, assessment scores and engagement metrics.',
|
desc: 'Lightweight reports showing completion and engagement metrics.',
|
||||||
img: 'https://placehold.co/600x400?text=Reporting',
|
img: 'https://placehold.co/600x400?text=Reporting',
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
@ -33,7 +33,7 @@ const features = [
|
||||||
<Typography.Paragraph type="secondary">
|
<Typography.Paragraph type="secondary">
|
||||||
Dynavera is a lightweight platform for onboarding, training, and assessing employees
|
Dynavera is a lightweight platform for onboarding, training, and assessing employees
|
||||||
with modular content and agent-driven workflows. It is designed for teams that want
|
with modular content and agent-driven workflows. It is designed for teams that want
|
||||||
angible learning experiences quickly without complex LMS setup.
|
tangible learning experiences quickly without complex LMS setup.
|
||||||
</Typography.Paragraph>
|
</Typography.Paragraph>
|
||||||
<Divider />
|
<Divider />
|
||||||
<Typography.Title :level="4">Getting started</Typography.Title>
|
<Typography.Title :level="4">Getting started</Typography.Title>
|
||||||
|
|
@ -59,11 +59,11 @@ const features = [
|
||||||
<Divider />
|
<Divider />
|
||||||
<Typography.Title :level="4">More about Dynavera</Typography.Title>
|
<Typography.Title :level="4">More about Dynavera</Typography.Title>
|
||||||
<Typography.Paragraph>
|
<Typography.Paragraph>
|
||||||
Dynavera is built to be extensible — plug your content sources, connect an LMS, or
|
Dynavera is built to be extensible. Plug your content sources, connect an LMS, or
|
||||||
integrate third-party assessment engines. The platform focuses on flexibility and
|
integrate third-party learning tools. The platform focuses on flexibility and ease
|
||||||
ease of use, so teams can get started quickly and adapt as their needs evolve.
|
of use, so teams can get started quickly and adapt as their needs evolve. Whether
|
||||||
Whether you’re a small startup or a growing enterprise, Dynavera aims to simplify
|
you are a small startup or a growing enterprise, Dynavera aims to simplify the
|
||||||
the process of onboarding and training with modern, agent-driven approaches.
|
process of onboarding and training with modern, agent-driven approaches.
|
||||||
</Typography.Paragraph>
|
</Typography.Paragraph>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
import { ref, onMounted, onUnmounted, computed } from 'vue'
|
||||||
import { useRoute } from 'vue-router'
|
import { useRoute } from 'vue-router'
|
||||||
import { Card, Typography, Button, List, Space, Spin, Input, message, Tag } from 'ant-design-vue'
|
import { Card, Typography, Button, List, Space, Spin, Input, message, Tag, Select, InputNumber } from 'ant-design-vue'
|
||||||
import { useAgentStore } from '../stores/agentStore'
|
import { useAgentStore } from '../stores/agentStore'
|
||||||
import { apiClient, isAxiosError, API } from '../router/api'
|
import { apiClient, isAxiosError, API } from '../router/api'
|
||||||
|
import type { Role } from '../types/organization'
|
||||||
|
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
const agentStore = useAgentStore()
|
const agentStore = useAgentStore()
|
||||||
|
|
@ -17,6 +18,10 @@ const agent = ref<Record<string, unknown>>({
|
||||||
status: 'idle',
|
status: 'idle',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const roles = ref<Role[]>([])
|
||||||
|
const selectedRoleUuid = ref('')
|
||||||
|
const maxTokens = ref<number>(256)
|
||||||
|
|
||||||
const queryInput = ref('')
|
const queryInput = ref('')
|
||||||
const isRunning = computed(() => agentStore.executionStatus === 'running')
|
const isRunning = computed(() => agentStore.executionStatus === 'running')
|
||||||
const isConnected = computed(() => agentStore.isConnected ?? false)
|
const isConnected = computed(() => agentStore.isConnected ?? false)
|
||||||
|
|
@ -52,6 +57,16 @@ const fetchAgent = async () => {
|
||||||
try {
|
try {
|
||||||
const response = await apiClient.get<Record<string, unknown>>(API.agent(agentId))
|
const response = await apiClient.get<Record<string, unknown>>(API.agent(agentId))
|
||||||
agent.value = response.data
|
agent.value = response.data
|
||||||
|
const org = agent.value.organization as Record<string, unknown> | null | undefined
|
||||||
|
const orgUuid = org?.uuid as string | undefined
|
||||||
|
console.log('Agent loaded:', agent.value)
|
||||||
|
console.log('Organization:', org)
|
||||||
|
console.log('Organization UUID:', orgUuid)
|
||||||
|
if (orgUuid) {
|
||||||
|
await fetchRoles(orgUuid)
|
||||||
|
} else {
|
||||||
|
console.warn('No organization UUID found for agent')
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to fetch agent:', error)
|
console.error('Failed to fetch agent:', error)
|
||||||
if (isAxiosError(error)) {
|
if (isAxiosError(error)) {
|
||||||
|
|
@ -65,6 +80,32 @@ const fetchAgent = async () => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fetchRoles = async (orgUuid: string) => {
|
||||||
|
try {
|
||||||
|
console.log('Fetching roles for organization:', orgUuid)
|
||||||
|
const response = await apiClient.get<Role[]>(API.organizationRoles(orgUuid))
|
||||||
|
console.log('Roles loaded:', response.data)
|
||||||
|
roles.value = response.data
|
||||||
|
if (!selectedRoleUuid.value && roles.value.length > 0) {
|
||||||
|
selectedRoleUuid.value = roles.value[0].uuid
|
||||||
|
console.log('Auto-selected role:', selectedRoleUuid.value)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to fetch roles:', error)
|
||||||
|
if (isAxiosError(error)) {
|
||||||
|
console.error('Roles API error:', {
|
||||||
|
status: error.response?.status,
|
||||||
|
data: error.response?.data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
message.error('Failed to load roles')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const roleOptions = computed(() =>
|
||||||
|
roles.value.map((role) => ({ label: role.name, value: role.uuid })),
|
||||||
|
)
|
||||||
|
|
||||||
const startAgent = () => {
|
const startAgent = () => {
|
||||||
if (!agentStore.isConnected) {
|
if (!agentStore.isConnected) {
|
||||||
message.error('WebSocket not connected')
|
message.error('WebSocket not connected')
|
||||||
|
|
@ -76,8 +117,15 @@ const startAgent = () => {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!selectedRoleUuid.value) {
|
||||||
|
message.error('Please select a role for this query')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const data = {
|
const data = {
|
||||||
query: queryInput.value.trim(),
|
query: queryInput.value.trim(),
|
||||||
|
role_uuid: selectedRoleUuid.value,
|
||||||
|
max_tokens: maxTokens.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
agentStore.startAgent(data)
|
agentStore.startAgent(data)
|
||||||
|
|
@ -144,6 +192,28 @@ onUnmounted(() => {
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Typography.Text>Role:</Typography.Text>
|
||||||
|
<Select
|
||||||
|
v-model:value="selectedRoleUuid"
|
||||||
|
:options="roleOptions"
|
||||||
|
:disabled="isRunning"
|
||||||
|
placeholder="Select a role"
|
||||||
|
style="width: 100%"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Typography.Text>Max Tokens:</Typography.Text>
|
||||||
|
<InputNumber
|
||||||
|
v-model:value="maxTokens"
|
||||||
|
:min="1"
|
||||||
|
:max="4096"
|
||||||
|
:disabled="isRunning"
|
||||||
|
style="width: 100%"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
<Space>
|
<Space>
|
||||||
<Button type="primary" :disabled="isRunning || !isConnected" @click="startAgent">
|
<Button type="primary" :disabled="isRunning || !isConnected" @click="startAgent">
|
||||||
Run Agent
|
Run Agent
|
||||||
|
|
|
||||||
722
src/views/OnboardingView.vue
Normal file
722
src/views/OnboardingView.vue
Normal file
|
|
@ -0,0 +1,722 @@
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed, onMounted, onUnmounted, watch } from 'vue'
|
||||||
|
import type { Dayjs } from 'dayjs'
|
||||||
|
import { useRoute, useRouter } from 'vue-router'
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
Typography,
|
||||||
|
Button,
|
||||||
|
Space,
|
||||||
|
Spin,
|
||||||
|
Select,
|
||||||
|
Form,
|
||||||
|
Input,
|
||||||
|
InputNumber,
|
||||||
|
Switch,
|
||||||
|
DatePicker,
|
||||||
|
Divider,
|
||||||
|
List,
|
||||||
|
message,
|
||||||
|
} from 'ant-design-vue'
|
||||||
|
import { apiClient, API } from '../router/api'
|
||||||
|
import { useUserStore } from '../stores/userStore'
|
||||||
|
import { useAgentStore } from '../stores/agentStore'
|
||||||
|
import type { Role } from '../types/organization'
|
||||||
|
import type { OnboardingFlow, OnboardingPage, OnboardingSession, OnboardingFeedback } from '../types/onboarding'
|
||||||
|
|
||||||
|
const route = useRoute()
|
||||||
|
const router = useRouter()
|
||||||
|
const userStore = useUserStore()
|
||||||
|
const agentStore = useAgentStore()
|
||||||
|
|
||||||
|
const roleId = computed(() => (route.params.roleId as string | undefined) || undefined)
|
||||||
|
const flows = ref<OnboardingFlow[]>([])
|
||||||
|
type SingleSelectValue = string | number | undefined
|
||||||
|
const selectedFlowUuid = ref<SingleSelectValue>(undefined)
|
||||||
|
const flowDetails = ref<OnboardingFlow | null>(null)
|
||||||
|
const session = ref<OnboardingSession | null>(null)
|
||||||
|
const currentPageIndex = ref(0)
|
||||||
|
const loading = ref(false)
|
||||||
|
const generating = ref(false)
|
||||||
|
const loadError = ref<string | null>(null)
|
||||||
|
const generateInstructions = ref('')
|
||||||
|
const creatingFlow = ref(false)
|
||||||
|
const resettingSession = ref(false)
|
||||||
|
const publishing = ref(false)
|
||||||
|
const agents = ref<Array<{ uuid: string; name: string }>>([])
|
||||||
|
const userRoles = ref<Role[]>([])
|
||||||
|
const roleLoading = ref(false)
|
||||||
|
const roleLoadError = ref<string | null>(null)
|
||||||
|
const createFlowForm = ref({
|
||||||
|
title: '',
|
||||||
|
description: '',
|
||||||
|
agent: undefined as SingleSelectValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
const isManager = computed(() => userStore.isGeneralManager)
|
||||||
|
const hasRoleContext = computed(() => Boolean(roleId.value))
|
||||||
|
const pages = computed<OnboardingPage[]>(() => flowDetails.value?.pages ?? [])
|
||||||
|
const currentPage = computed<OnboardingPage | null>(() => pages.value[currentPageIndex.value] || null)
|
||||||
|
const hasNext = computed(() => currentPageIndex.value < pages.value.length - 1)
|
||||||
|
const hasPrev = computed(() => currentPageIndex.value > 0)
|
||||||
|
|
||||||
|
const formState = ref<Record<string, unknown>>({})
|
||||||
|
const feedbackLoading = ref(false)
|
||||||
|
const feedbackByPage = ref<Record<string, OnboardingFeedback>>({})
|
||||||
|
|
||||||
|
const getFieldValue = (key: string) => formState.value[key]
|
||||||
|
const setFieldValue = (key: string, value: unknown) => {
|
||||||
|
formState.value[key] = value
|
||||||
|
}
|
||||||
|
const toSingleValue = (value: unknown): SingleSelectValue => {
|
||||||
|
if (Array.isArray(value)) {
|
||||||
|
return value.length ? (value[0] as string | number) : undefined
|
||||||
|
}
|
||||||
|
if (typeof value === 'string' || typeof value === 'number') {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
const toNumberValue = (value: unknown): number | undefined =>
|
||||||
|
typeof value === 'number' ? value : undefined
|
||||||
|
const toDateValue = (value: unknown): Dayjs | string | undefined =>
|
||||||
|
typeof value === 'string' || (value as Dayjs | undefined)?.isValid ? (value as Dayjs | string) : undefined
|
||||||
|
|
||||||
|
const fetchFlows = async () => {
|
||||||
|
loading.value = true
|
||||||
|
loadError.value = null
|
||||||
|
try {
|
||||||
|
const params: Record<string, string> = {}
|
||||||
|
if (roleId.value) params.role = roleId.value
|
||||||
|
if (!isManager.value) params.status = 'published'
|
||||||
|
const response = await apiClient.get<OnboardingFlow[]>(API.onboardingFlows(), { params })
|
||||||
|
flows.value = Array.isArray(response.data) ? response.data : []
|
||||||
|
if (flows.value.length === 1) {
|
||||||
|
selectedFlowUuid.value = flows.value[0].uuid
|
||||||
|
}
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to load onboarding flows', err)
|
||||||
|
loadError.value = 'Failed to load onboarding flows'
|
||||||
|
flows.value = []
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const fetchAgents = async () => {
|
||||||
|
try {
|
||||||
|
const response = await apiClient.get<Array<{ uuid: string; name: string }>>(API.agents())
|
||||||
|
agents.value = Array.isArray(response.data) ? response.data : []
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to load agents', err)
|
||||||
|
agents.value = []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const fetchUserRoles = async () => {
|
||||||
|
roleLoading.value = true
|
||||||
|
roleLoadError.value = null
|
||||||
|
try {
|
||||||
|
const response = await apiClient.get<Role[]>(API.organizationRolesMine())
|
||||||
|
userRoles.value = Array.isArray(response.data) ? response.data : []
|
||||||
|
if (userRoles.value.length === 1) {
|
||||||
|
await router.replace(`/onboarding/${userRoles.value[0].uuid}`)
|
||||||
|
}
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to load user roles', err)
|
||||||
|
roleLoadError.value = 'Failed to load roles'
|
||||||
|
userRoles.value = []
|
||||||
|
} finally {
|
||||||
|
roleLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const fetchFlowDetails = async (flowUuid: string) => {
|
||||||
|
loading.value = true
|
||||||
|
loadError.value = null
|
||||||
|
try {
|
||||||
|
const response = await apiClient.get<OnboardingFlow>(API.onboardingFlowPages(flowUuid))
|
||||||
|
flowDetails.value = response.data
|
||||||
|
await ensureSession(flowUuid)
|
||||||
|
syncPageIndexFromSession()
|
||||||
|
hydrateFormState()
|
||||||
|
if (flowDetails.value?.agent) {
|
||||||
|
agentStore.connect(flowDetails.value.agent)
|
||||||
|
}
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to load onboarding flow', err)
|
||||||
|
loadError.value = 'Failed to load onboarding flow'
|
||||||
|
flowDetails.value = null
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const ensureSession = async (flowUuid: string) => {
|
||||||
|
if (session.value && session.value.flow === flowUuid) return
|
||||||
|
const response = await apiClient.post<OnboardingSession>(API.onboardingSessionGetOrCreate(), {
|
||||||
|
flow: flowUuid,
|
||||||
|
})
|
||||||
|
session.value = response.data
|
||||||
|
}
|
||||||
|
|
||||||
|
const syncPageIndexFromSession = () => {
|
||||||
|
const order = session.value?.current_page_order
|
||||||
|
if (typeof order === 'number' && order >= 0) {
|
||||||
|
currentPageIndex.value = Math.min(order, Math.max(pages.value.length - 1, 0))
|
||||||
|
} else {
|
||||||
|
currentPageIndex.value = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const hydrateFormState = () => {
|
||||||
|
if (!currentPage.value) {
|
||||||
|
formState.value = {}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const existing = session.value?.responses?.[currentPage.value.uuid] || {}
|
||||||
|
const values: Record<string, unknown> = {}
|
||||||
|
currentPage.value.fields.forEach((field) => {
|
||||||
|
const existingValue = (existing as Record<string, unknown>)[field.key]
|
||||||
|
if (existingValue !== undefined) {
|
||||||
|
values[field.key] = existingValue
|
||||||
|
} else if (field.default_value !== undefined && field.default_value !== null) {
|
||||||
|
values[field.key] = field.default_value
|
||||||
|
} else if (field.field_type === 'boolean') {
|
||||||
|
values[field.key] = false
|
||||||
|
} else if (field.field_type === 'multiselect') {
|
||||||
|
values[field.key] = []
|
||||||
|
} else {
|
||||||
|
values[field.key] = ''
|
||||||
|
}
|
||||||
|
})
|
||||||
|
formState.value = values
|
||||||
|
|
||||||
|
const storedFeedback = (session.value?.responses as Record<string, unknown> | undefined)?.[
|
||||||
|
'__feedback__'
|
||||||
|
] as Record<string, { feedback?: OnboardingFeedback }> | undefined
|
||||||
|
if (storedFeedback && currentPage.value) {
|
||||||
|
const pageFeedback = storedFeedback[currentPage.value.uuid]?.feedback
|
||||||
|
if (pageFeedback) {
|
||||||
|
feedbackByPage.value[currentPage.value.uuid] = pageFeedback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalizeResponses = (raw: Record<string, unknown>) => {
|
||||||
|
const normalized: Record<string, unknown> = {}
|
||||||
|
Object.entries(raw).forEach(([key, value]) => {
|
||||||
|
if (value && typeof value === 'object') {
|
||||||
|
const maybeDate = value as { toISOString?: () => string; format?: (fmt: string) => string }
|
||||||
|
if (typeof maybeDate.toISOString === 'function') {
|
||||||
|
normalized[key] = maybeDate.toISOString()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (typeof maybeDate.format === 'function') {
|
||||||
|
normalized[key] = maybeDate.format('YYYY-MM-DD')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalized[key] = value
|
||||||
|
})
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
const onSubmitPage = async () => {
|
||||||
|
if (!currentPage.value || !session.value) return
|
||||||
|
try {
|
||||||
|
const normalizedResponses = normalizeResponses(formState.value)
|
||||||
|
const payload = {
|
||||||
|
page_uuid: currentPage.value.uuid,
|
||||||
|
responses: normalizedResponses,
|
||||||
|
mark_complete: !hasNext.value,
|
||||||
|
}
|
||||||
|
const response = await apiClient.post<OnboardingSession>(
|
||||||
|
API.onboardingSessionSubmit(session.value.uuid),
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
session.value = response.data
|
||||||
|
|
||||||
|
if (session.value.agent_run) {
|
||||||
|
agentStore.sendOnboardingProgress(session.value.agent_run, {
|
||||||
|
flow_uuid: session.value.flow,
|
||||||
|
session_uuid: session.value.uuid,
|
||||||
|
page_uuid: currentPage.value.uuid,
|
||||||
|
page_order: currentPageIndex.value,
|
||||||
|
responses: normalizedResponses,
|
||||||
|
status: session.value.status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasNext.value) {
|
||||||
|
currentPageIndex.value += 1
|
||||||
|
hydrateFormState()
|
||||||
|
} else {
|
||||||
|
message.success('Onboarding completed')
|
||||||
|
}
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to submit onboarding page', err)
|
||||||
|
message.error('Failed to save onboarding progress')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const requestFeedback = async () => {
|
||||||
|
if (!currentPage.value || !session.value) return
|
||||||
|
feedbackLoading.value = true
|
||||||
|
try {
|
||||||
|
const normalizedResponses = normalizeResponses(formState.value)
|
||||||
|
const response = await apiClient.post<{ feedback: OnboardingFeedback; session: OnboardingSession }>(
|
||||||
|
API.onboardingSessionFeedback(session.value.uuid),
|
||||||
|
{
|
||||||
|
page_uuid: currentPage.value.uuid,
|
||||||
|
responses: normalizedResponses,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session.value = response.data.session
|
||||||
|
feedbackByPage.value[currentPage.value.uuid] = response.data.feedback
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to get feedback', err)
|
||||||
|
message.error('Failed to get feedback')
|
||||||
|
} finally {
|
||||||
|
feedbackLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const generateWithAgent = async () => {
|
||||||
|
if (!flowDetails.value) return
|
||||||
|
generating.value = true
|
||||||
|
try {
|
||||||
|
const response = await apiClient.post<OnboardingFlow>(
|
||||||
|
API.onboardingFlowGenerate(flowDetails.value.uuid),
|
||||||
|
{
|
||||||
|
instructions: generateInstructions.value,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
flowDetails.value = response.data
|
||||||
|
currentPageIndex.value = 0
|
||||||
|
hydrateFormState()
|
||||||
|
message.success('Generated onboarding content')
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to generate onboarding flow', err)
|
||||||
|
message.error('Failed to generate onboarding content')
|
||||||
|
} finally {
|
||||||
|
generating.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const publishFlow = async () => {
|
||||||
|
if (!flowDetails.value) return
|
||||||
|
publishing.value = true
|
||||||
|
try {
|
||||||
|
const response = await apiClient.post<OnboardingFlow>(
|
||||||
|
API.onboardingFlowPublish(flowDetails.value.uuid),
|
||||||
|
)
|
||||||
|
flowDetails.value = response.data
|
||||||
|
flows.value = flows.value.map((flow) =>
|
||||||
|
flow.uuid === response.data.uuid ? { ...flow, status: response.data.status } : flow,
|
||||||
|
)
|
||||||
|
message.success('Flow published')
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to publish onboarding flow', err)
|
||||||
|
message.error('Failed to publish onboarding flow')
|
||||||
|
} finally {
|
||||||
|
publishing.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const createFlow = async () => {
|
||||||
|
if (!roleId.value) {
|
||||||
|
message.error('Role is required to create a flow')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (!createFlowForm.value.title.trim()) {
|
||||||
|
message.error('Title is required')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
creatingFlow.value = true
|
||||||
|
try {
|
||||||
|
const payload = {
|
||||||
|
role: roleId.value,
|
||||||
|
agent:
|
||||||
|
createFlowForm.value.agent !== undefined && createFlowForm.value.agent !== null
|
||||||
|
? String(createFlowForm.value.agent)
|
||||||
|
: null,
|
||||||
|
title: createFlowForm.value.title.trim(),
|
||||||
|
description: createFlowForm.value.description.trim(),
|
||||||
|
status: 'draft',
|
||||||
|
}
|
||||||
|
const response = await apiClient.post<OnboardingFlow>(API.onboardingFlows(), payload)
|
||||||
|
flows.value = [response.data, ...flows.value]
|
||||||
|
selectedFlowUuid.value = response.data.uuid
|
||||||
|
createFlowForm.value = {
|
||||||
|
title: '',
|
||||||
|
description: '',
|
||||||
|
agent: undefined,
|
||||||
|
}
|
||||||
|
message.success('Onboarding flow created')
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to create onboarding flow', err)
|
||||||
|
message.error('Failed to create onboarding flow')
|
||||||
|
} finally {
|
||||||
|
creatingFlow.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const goPrev = () => {
|
||||||
|
if (hasPrev.value) {
|
||||||
|
currentPageIndex.value -= 1
|
||||||
|
hydrateFormState()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const resetSession = async () => {
|
||||||
|
if (!flowDetails.value) return
|
||||||
|
resettingSession.value = true
|
||||||
|
try {
|
||||||
|
const response = await apiClient.post<OnboardingSession>(API.onboardingSessions(), {
|
||||||
|
flow: flowDetails.value.uuid,
|
||||||
|
})
|
||||||
|
session.value = response.data
|
||||||
|
currentPageIndex.value = 0
|
||||||
|
feedbackByPage.value = {}
|
||||||
|
hydrateFormState()
|
||||||
|
message.success('Started a new onboarding session')
|
||||||
|
} catch (err: unknown) {
|
||||||
|
console.error('Failed to reset onboarding session', err)
|
||||||
|
message.error('Failed to start a new onboarding session')
|
||||||
|
} finally {
|
||||||
|
resettingSession.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
watch(selectedFlowUuid, async (value) => {
|
||||||
|
if (value) {
|
||||||
|
await fetchFlowDetails(String(value))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
watch(roleId, async () => {
|
||||||
|
selectedFlowUuid.value = undefined
|
||||||
|
flowDetails.value = null
|
||||||
|
session.value = null
|
||||||
|
userRoles.value = []
|
||||||
|
if (!roleId.value) {
|
||||||
|
await fetchUserRoles()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
await fetchFlows()
|
||||||
|
if (selectedFlowUuid.value) {
|
||||||
|
await fetchFlowDetails(String(selectedFlowUuid.value))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
watch(currentPageIndex, hydrateFormState)
|
||||||
|
|
||||||
|
onUnmounted(() => {
|
||||||
|
agentStore.disconnect()
|
||||||
|
})
|
||||||
|
|
||||||
|
onMounted(async () => {
|
||||||
|
await userStore.fetchSession()
|
||||||
|
if (isManager.value) {
|
||||||
|
await fetchAgents()
|
||||||
|
}
|
||||||
|
if (!roleId.value) {
|
||||||
|
await fetchUserRoles()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
await fetchFlows()
|
||||||
|
if (selectedFlowUuid.value) {
|
||||||
|
await fetchFlowDetails(String(selectedFlowUuid.value))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<div class="page">
|
||||||
|
<Spin :spinning="loading" tip="Loading onboarding...">
|
||||||
|
<Card class="panel" :bordered="false">
|
||||||
|
<div class="header">
|
||||||
|
<Typography.Title :level="2">Onboarding</Typography.Title>
|
||||||
|
<Space v-if="flowDetails && isManager">
|
||||||
|
<Input
|
||||||
|
v-model:value="generateInstructions"
|
||||||
|
placeholder="Optional generation instructions"
|
||||||
|
style="width: 320px"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
:loading="generating"
|
||||||
|
:disabled="!flowDetails?.agent"
|
||||||
|
@click="generateWithAgent"
|
||||||
|
>
|
||||||
|
Generate with Agent
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
:loading="publishing"
|
||||||
|
:disabled="flowDetails?.status === 'published'"
|
||||||
|
@click="publishFlow"
|
||||||
|
>
|
||||||
|
{{ flowDetails?.status === 'published' ? 'Published' : 'Publish' }}
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Typography.Paragraph v-if="loadError" type="danger">{{ loadError }}</Typography.Paragraph>
|
||||||
|
|
||||||
|
<div v-if="!hasRoleContext">
|
||||||
|
<Typography.Paragraph type="secondary">
|
||||||
|
Choose a role to continue onboarding.
|
||||||
|
</Typography.Paragraph>
|
||||||
|
<Typography.Paragraph v-if="roleLoadError" type="danger">
|
||||||
|
{{ roleLoadError }}
|
||||||
|
</Typography.Paragraph>
|
||||||
|
<Spin :spinning="roleLoading">
|
||||||
|
<div v-if="userRoles.length > 0" class="role-list">
|
||||||
|
<List :data-source="userRoles" :bordered="false">
|
||||||
|
<template #renderItem="{ item }">
|
||||||
|
<List.Item class="role-item">
|
||||||
|
<List.Item.Meta
|
||||||
|
:title="item.name"
|
||||||
|
:description="item.organization?.name || 'Organization'"
|
||||||
|
/>
|
||||||
|
<Button type="primary" @click="router.push(`/onboarding/${item.uuid}`)">
|
||||||
|
Open
|
||||||
|
</Button>
|
||||||
|
</List.Item>
|
||||||
|
</template>
|
||||||
|
</List>
|
||||||
|
</div>
|
||||||
|
<Typography.Paragraph v-else type="secondary">
|
||||||
|
You are not a member of any roles yet.
|
||||||
|
</Typography.Paragraph>
|
||||||
|
</Spin>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="flows.length > 1" class="flow-select">
|
||||||
|
<Typography.Text strong>Select a flow</Typography.Text>
|
||||||
|
<Select
|
||||||
|
:value="selectedFlowUuid"
|
||||||
|
@update:value="(val) => (selectedFlowUuid = toSingleValue(val))"
|
||||||
|
placeholder="Choose a flow"
|
||||||
|
style="width: 320px"
|
||||||
|
:options="flows.map((f) => ({ value: f.uuid, label: f.title }))"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-else-if="flows.length === 0">
|
||||||
|
<Typography.Paragraph type="secondary">
|
||||||
|
No onboarding flows available yet.
|
||||||
|
</Typography.Paragraph>
|
||||||
|
<div v-if="isManager" class="create-flow">
|
||||||
|
<Typography.Text strong>Create a flow</Typography.Text>
|
||||||
|
<Typography.Paragraph type="secondary">
|
||||||
|
Create a draft flow and then generate pages with an agent.
|
||||||
|
</Typography.Paragraph>
|
||||||
|
<Typography.Paragraph v-if="!roleId" type="warning">
|
||||||
|
Open this page with a role id to create a flow.
|
||||||
|
</Typography.Paragraph>
|
||||||
|
<Form layout="vertical" :model="createFlowForm" @finish="createFlow">
|
||||||
|
<Form.Item
|
||||||
|
label="Title"
|
||||||
|
name="title"
|
||||||
|
:rules="[{ required: true, message: 'Title is required' }]"
|
||||||
|
>
|
||||||
|
<Input v-model:value="createFlowForm.title" placeholder="Flow title" />
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item label="Description" name="description">
|
||||||
|
<Input
|
||||||
|
v-model:value="createFlowForm.description"
|
||||||
|
placeholder="Short description"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Form.Item label="Agent (optional)" name="agent">
|
||||||
|
<Select
|
||||||
|
:value="createFlowForm.agent"
|
||||||
|
@update:value="(val) => (createFlowForm.agent = toSingleValue(val))"
|
||||||
|
allow-clear
|
||||||
|
placeholder="Select an agent"
|
||||||
|
:options="agents.map((a) => ({ value: a.uuid, label: a.name }))"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
html-type="submit"
|
||||||
|
:loading="creatingFlow"
|
||||||
|
:disabled="!roleId"
|
||||||
|
>
|
||||||
|
Create draft flow
|
||||||
|
</Button>
|
||||||
|
</Form>
|
||||||
|
<Typography.Paragraph v-if="!agents.length" type="secondary">
|
||||||
|
No agents available yet. You can create a flow now and assign an agent later.
|
||||||
|
</Typography.Paragraph>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="flowDetails" class="flow-body">
|
||||||
|
<Typography.Title :level="3">{{ flowDetails.title }}</Typography.Title>
|
||||||
|
<Typography.Paragraph type="secondary">
|
||||||
|
{{ flowDetails.description || 'No description provided.' }}
|
||||||
|
</Typography.Paragraph>
|
||||||
|
<Space v-if="session" class="session-actions">
|
||||||
|
<Button danger :loading="resettingSession" @click="resetSession">
|
||||||
|
Start New Session
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
|
||||||
|
<Divider />
|
||||||
|
|
||||||
|
<div v-if="currentPage" class="page-body">
|
||||||
|
<Typography.Title :level="4">
|
||||||
|
{{ currentPage.title }}
|
||||||
|
</Typography.Title>
|
||||||
|
<Typography.Paragraph>
|
||||||
|
{{ currentPage.body }}
|
||||||
|
</Typography.Paragraph>
|
||||||
|
|
||||||
|
<Form layout="vertical" :model="formState" @finish="onSubmitPage">
|
||||||
|
<Form.Item
|
||||||
|
v-for="field in currentPage.fields"
|
||||||
|
:key="field.uuid"
|
||||||
|
:label="field.label"
|
||||||
|
:name="field.key"
|
||||||
|
:rules="field.required ? [{ required: true, message: 'Required' }] : []"
|
||||||
|
>
|
||||||
|
<Input
|
||||||
|
v-if="field.field_type === 'text'"
|
||||||
|
:value="(getFieldValue(field.key) as string | number | undefined)"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
:placeholder="field.placeholder"
|
||||||
|
/>
|
||||||
|
<Input.TextArea
|
||||||
|
v-else-if="field.field_type === 'textarea'"
|
||||||
|
:value="(getFieldValue(field.key) as string | undefined)"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
:placeholder="field.placeholder"
|
||||||
|
:rows="4"
|
||||||
|
/>
|
||||||
|
<InputNumber
|
||||||
|
v-else-if="field.field_type === 'number'"
|
||||||
|
:value="toNumberValue(getFieldValue(field.key))"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
style="width: 100%"
|
||||||
|
/>
|
||||||
|
<Select
|
||||||
|
v-else-if="field.field_type === 'select'"
|
||||||
|
:value="(getFieldValue(field.key) as string | undefined)"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
:options="(field.options || []).map((opt) => ({ value: opt, label: opt }))"
|
||||||
|
/>
|
||||||
|
<Select
|
||||||
|
v-else-if="field.field_type === 'multiselect'"
|
||||||
|
:value="(getFieldValue(field.key) as string[] | undefined)"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
mode="multiple"
|
||||||
|
:options="(field.options || []).map((opt) => ({ value: opt, label: opt }))"
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
v-else-if="field.field_type === 'boolean'"
|
||||||
|
:checked="Boolean(getFieldValue(field.key))"
|
||||||
|
@update:checked="(val) => setFieldValue(field.key, val)"
|
||||||
|
/>
|
||||||
|
<DatePicker
|
||||||
|
v-else-if="field.field_type === 'date'"
|
||||||
|
:value="toDateValue(getFieldValue(field.key))"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
style="width: 100%"
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
v-else
|
||||||
|
:value="(getFieldValue(field.key) as string | number | undefined)"
|
||||||
|
@update:value="(val) => setFieldValue(field.key, val)"
|
||||||
|
:placeholder="field.placeholder"
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Space>
|
||||||
|
<Button :disabled="!hasPrev" @click="goPrev">Back</Button>
|
||||||
|
<Button :loading="feedbackLoading" @click="requestFeedback">
|
||||||
|
Get Feedback
|
||||||
|
</Button>
|
||||||
|
<Button type="primary" html-type="submit">
|
||||||
|
{{ hasNext ? 'Next' : 'Finish' }}
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
<div class="feedback">
|
||||||
|
<div v-if="feedbackByPage[currentPage.uuid]" class="feedback-panel">
|
||||||
|
<Typography.Text strong>Feedback</Typography.Text>
|
||||||
|
<Typography.Paragraph>
|
||||||
|
{{ feedbackByPage[currentPage.uuid]?.summary || '' }}
|
||||||
|
</Typography.Paragraph>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
</Spin>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<style scoped>
|
||||||
|
.page {
|
||||||
|
max-width: 1100px;
|
||||||
|
margin: 0 auto;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
.panel {
|
||||||
|
background: #0f172a;
|
||||||
|
border: 1px solid #1f2937;
|
||||||
|
color: #e5e7eb;
|
||||||
|
}
|
||||||
|
.header {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
gap: 1rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
.flow-select {
|
||||||
|
margin: 1rem 0;
|
||||||
|
}
|
||||||
|
.role-list {
|
||||||
|
margin: 1rem 0;
|
||||||
|
}
|
||||||
|
.role-item :deep(.ant-list-item-meta-title),
|
||||||
|
.role-item :deep(.ant-list-item-meta-description) {
|
||||||
|
background: #0f172a;
|
||||||
|
border: 1px solid #1f2937;
|
||||||
|
}
|
||||||
|
.flow-body {
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
.session-actions {
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
}
|
||||||
|
.create-flow {
|
||||||
|
margin-top: 1rem;
|
||||||
|
padding: 1rem;
|
||||||
|
border: 1px dashed #1f2937;
|
||||||
|
border-radius: 6px;
|
||||||
|
background: #0b1220;
|
||||||
|
}
|
||||||
|
.page-body {
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
.feedback {
|
||||||
|
margin-top: 1rem;
|
||||||
|
display: grid;
|
||||||
|
gap: 0.75rem;
|
||||||
|
}
|
||||||
|
.feedback-panel {
|
||||||
|
padding: 0.75rem 1rem;
|
||||||
|
border: 1px solid #1f2937;
|
||||||
|
border-radius: 6px;
|
||||||
|
background: #0b1220;
|
||||||
|
}
|
||||||
|
.feedback-panel ul {
|
||||||
|
margin: 0.5rem 0 0.75rem 1.25rem;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, onMounted, computed, h } from 'vue'
|
import { ref, onMounted, computed, h } from 'vue'
|
||||||
import { useRouter, useRoute } from 'vue-router'
|
import { useRouter, useRoute } from 'vue-router'
|
||||||
import { Card, Typography, Button, List, Space, Spin, message, Tag, Divider, Upload, Modal, Table } from 'ant-design-vue'
|
import { Card, Typography, Button, List, Space, Spin, message, Tag, Divider, Upload, Modal, Table, Select } from 'ant-design-vue'
|
||||||
import { apiClient, isAxiosError, API } from '../router/api'
|
import { apiClient, isAxiosError, API } from '../router/api'
|
||||||
import { useUserStore } from '../stores/userStore'
|
import { useUserStore } from '../stores/userStore'
|
||||||
import { InboxOutlined, DeleteOutlined } from '@ant-design/icons-vue'
|
import { InboxOutlined, DeleteOutlined } from '@ant-design/icons-vue'
|
||||||
|
|
@ -52,28 +52,13 @@ const fetchRoles = async () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
const fetchUserRoleMemberships = async () => {
|
const fetchUserRoleMemberships = async () => {
|
||||||
const userRoleUuids: string[] = []
|
if (!organization.value?.uuid) return
|
||||||
const userId = auth.user?.id
|
|
||||||
if (!userId || !organization.value?.uuid) return
|
|
||||||
try {
|
try {
|
||||||
const checks = roles.value.map(async (r) => {
|
const response = await apiClient.get<Role[]>(API.organizationRolesMine())
|
||||||
try {
|
const mine = Array.isArray(response.data) ? response.data : []
|
||||||
const resp = await apiClient.get<Array<{ user: { id: number } }>>(
|
const orgUuid = organization.value.uuid
|
||||||
API.organizationRoleMembers(organization.value!.uuid, r.uuid),
|
const joinedRoles = mine.filter((role) => role.organization?.uuid === orgUuid)
|
||||||
)
|
|
||||||
const found =
|
|
||||||
Array.isArray(resp.data) && resp.data.some((m) => m.user?.id === userId)
|
|
||||||
if (found && r.uuid) userRoleUuids.push(r.uuid)
|
|
||||||
} catch {
|
|
||||||
// ignore individual role errors
|
|
||||||
}
|
|
||||||
})
|
|
||||||
await Promise.all(checks)
|
|
||||||
// update the global store with actual Role objects the user has joined
|
|
||||||
const joinedRoles = roles.value.filter((r) => userRoleUuids.includes(r.uuid))
|
|
||||||
if (joinedRoles.length) {
|
|
||||||
auth.setJoinedRoles(joinedRoles)
|
auth.setJoinedRoles(joinedRoles)
|
||||||
}
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Failed to fetch user role memberships', err)
|
console.error('Failed to fetch user role memberships', err)
|
||||||
}
|
}
|
||||||
|
|
@ -169,6 +154,7 @@ const beforeUpload = (file: File) => {
|
||||||
|
|
||||||
const selectedFile = ref<File | null>(null)
|
const selectedFile = ref<File | null>(null)
|
||||||
const fileDescription = ref('')
|
const fileDescription = ref('')
|
||||||
|
const selectedRoleUuid = ref<string>('')
|
||||||
|
|
||||||
const handleFileSelected = (file: File) => {
|
const handleFileSelected = (file: File) => {
|
||||||
selectedFile.value = file
|
selectedFile.value = file
|
||||||
|
|
@ -179,10 +165,15 @@ const handleFileUploadClick = async () => {
|
||||||
message.error('Please select a file to upload')
|
message.error('Please select a file to upload')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if (!selectedRoleUuid.value) {
|
||||||
|
message.error('Please select a role for this training file')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
await handleFileUpload(selectedFile.value, fileDescription.value)
|
await handleFileUpload(selectedFile.value, fileDescription.value)
|
||||||
selectedFile.value = null
|
selectedFile.value = null
|
||||||
fileDescription.value = ''
|
fileDescription.value = ''
|
||||||
|
selectedRoleUuid.value = ''
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleFileUpload = async (file: File, description: string = '') => {
|
const handleFileUpload = async (file: File, description: string = '') => {
|
||||||
|
|
@ -197,6 +188,9 @@ const handleFileUpload = async (file: File, description: string = '') => {
|
||||||
formData.append('file', file)
|
formData.append('file', file)
|
||||||
formData.append('file_name', file.name)
|
formData.append('file_name', file.name)
|
||||||
formData.append('description', description)
|
formData.append('description', description)
|
||||||
|
if (selectedRoleUuid.value) {
|
||||||
|
formData.append('role_uuid', selectedRoleUuid.value)
|
||||||
|
}
|
||||||
|
|
||||||
const response = await apiClient.post<TrainingFile>(
|
const response = await apiClient.post<TrainingFile>(
|
||||||
API.organizationTrainingFiles(organization.value.uuid),
|
API.organizationTrainingFiles(organization.value.uuid),
|
||||||
|
|
@ -278,12 +272,25 @@ const trainingFileColumns = [
|
||||||
key: 'file_size',
|
key: 'file_size',
|
||||||
customRender: ({ value }: { value: number }) => formatFileSize(value || 0),
|
customRender: ({ value }: { value: number }) => formatFileSize(value || 0),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
title: 'Role',
|
||||||
|
key: 'role',
|
||||||
|
customRender: ({ record }: { record: TrainingFile }) => record.role?.name || '-',
|
||||||
|
},
|
||||||
{
|
{
|
||||||
title: 'Status',
|
title: 'Status',
|
||||||
dataIndex: 'is_processed',
|
dataIndex: 'status',
|
||||||
key: 'is_processed',
|
key: 'status',
|
||||||
customRender: ({ value }: { value: boolean }) =>
|
customRender: ({ value }: { value: string }) => {
|
||||||
h(Tag, { color: value ? 'success' : 'processing' }, () => value ? 'Processed' : 'Processing'),
|
const statusMap: Record<string, { color: string; label: string }> = {
|
||||||
|
ingesting: { color: 'processing', label: 'Ingesting' },
|
||||||
|
chunked: { color: 'blue', label: 'Chunked' },
|
||||||
|
embedded: { color: 'success', label: 'Embedded' },
|
||||||
|
failed: { color: 'error', label: 'Failed' },
|
||||||
|
}
|
||||||
|
const status = statusMap[value] || { color: 'default', label: value }
|
||||||
|
return h(Tag, { color: status.color }, () => status.label)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: 'Uploaded',
|
title: 'Uploaded',
|
||||||
|
|
@ -312,13 +319,13 @@ const trainingFileColumns = [
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(async () => {
|
||||||
fetchOrganization().then(async () => {
|
await auth.fetchSession(true)
|
||||||
|
await fetchOrganization()
|
||||||
await fetchMembers()
|
await fetchMembers()
|
||||||
await fetchRoles()
|
await fetchRoles()
|
||||||
await fetchUserRoleMemberships()
|
await fetchUserRoleMemberships()
|
||||||
await fetchTrainingFiles()
|
await fetchTrainingFiles()
|
||||||
})
|
|
||||||
})
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
|
@ -404,7 +411,6 @@ onMounted(() => {
|
||||||
<Button
|
<Button
|
||||||
type="default"
|
type="default"
|
||||||
size="small"
|
size="small"
|
||||||
:disabled="isManager"
|
|
||||||
@click="router.push(`/onboarding/${item.uuid}`)"
|
@click="router.push(`/onboarding/${item.uuid}`)"
|
||||||
>
|
>
|
||||||
Start Onboarding
|
Start Onboarding
|
||||||
|
|
@ -418,7 +424,9 @@ onMounted(() => {
|
||||||
>
|
>
|
||||||
Join Role
|
Join Role
|
||||||
</Button>
|
</Button>
|
||||||
<Tag v-else color="success">Joined</Tag>
|
<Button v-else size="small" disabled>
|
||||||
|
Joined
|
||||||
|
</Button>
|
||||||
</Space>
|
</Space>
|
||||||
</List.Item>
|
</List.Item>
|
||||||
</template>
|
</template>
|
||||||
|
|
@ -436,7 +444,7 @@ onMounted(() => {
|
||||||
width="600px"
|
width="600px"
|
||||||
ok-text="Upload"
|
ok-text="Upload"
|
||||||
cancel-text="Cancel"
|
cancel-text="Cancel"
|
||||||
:ok-button-props="{ loading: uploading, disabled: !selectedFile }"
|
:ok-button-props="{ loading: uploading, disabled: !selectedFile || !selectedRoleUuid }"
|
||||||
@ok="handleFileUploadClick"
|
@ok="handleFileUploadClick"
|
||||||
@cancel="showUploadModal = false"
|
@cancel="showUploadModal = false"
|
||||||
>
|
>
|
||||||
|
|
@ -444,6 +452,15 @@ onMounted(() => {
|
||||||
<Typography.Text>
|
<Typography.Text>
|
||||||
Supported formats: <strong>txt, pdf, md, csv, json, docx, doc</strong> (Max 50MB)
|
Supported formats: <strong>txt, pdf, md, csv, json, docx, doc</strong> (Max 50MB)
|
||||||
</Typography.Text>
|
</Typography.Text>
|
||||||
|
<div>
|
||||||
|
<Typography.Text strong>Role</Typography.Text>
|
||||||
|
<Select
|
||||||
|
v-model:value="selectedRoleUuid"
|
||||||
|
placeholder="Select a role"
|
||||||
|
style="width: 100%"
|
||||||
|
:options="roles.map((role) => ({ label: role.name, value: role.uuid }))"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
<Upload.Dragger
|
<Upload.Dragger
|
||||||
accept=".txt,.pdf,.md,.csv,.json,.docx,.doc"
|
accept=".txt,.pdf,.md,.csv,.json,.docx,.doc"
|
||||||
:before-upload="(file) => {
|
:before-upload="(file) => {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue