Added role agnostic training files
This commit is contained in:
parent
6ccb7822c9
commit
a9ba16c76d
11 changed files with 171 additions and 67 deletions
|
|
@ -5,22 +5,22 @@ from apps.knowledge.models import RoleRagDocument, TrainingFile
|
||||||
|
|
||||||
@admin.register(TrainingFile)
|
@admin.register(TrainingFile)
|
||||||
class TrainingFileAdmin(admin.ModelAdmin):
|
class TrainingFileAdmin(admin.ModelAdmin):
|
||||||
list_display = ('file_name', 'role', 'status', 'is_processed', 'uploaded_by', 'created_at')
|
list_display = ('file_name', 'organization', 'role', 'status', 'is_processed', 'uploaded_by', 'created_at')
|
||||||
list_filter = ('status', 'is_processed', 'role__organization', 'created_at')
|
list_filter = ('status', 'is_processed', 'organization', 'created_at')
|
||||||
search_fields = ('file_name', 'role__name', 'uploaded_by__email_address')
|
search_fields = ('file_name', 'organization__name', 'role__name', 'uploaded_by__email_address')
|
||||||
raw_id_fields = ('role', 'uploaded_by')
|
raw_id_fields = ('organization', 'role', 'uploaded_by')
|
||||||
readonly_fields = ('uuid', 'file_size', 'file_type', 'created_at', 'updated_at')
|
readonly_fields = ('uuid', 'file_size', 'file_type', 'created_at', 'updated_at')
|
||||||
ordering = ('-created_at',)
|
ordering = ('-created_at',)
|
||||||
|
|
||||||
@admin.register(RoleRagDocument)
|
@admin.register(RoleRagDocument)
|
||||||
class RoleRagDocumentAdmin(admin.ModelAdmin):
|
class RoleRagDocumentAdmin(admin.ModelAdmin):
|
||||||
list_display = ('role', 'chunk_index', 'training_file', 'is_active', 'created_at')
|
list_display = ('organization', 'role', 'chunk_index', 'training_file', 'is_active', 'created_at')
|
||||||
list_filter = ('is_active', 'role__organization', 'created_at')
|
list_filter = ('is_active', 'organization', 'created_at')
|
||||||
search_fields = ('content', 'role__name', 'training_file__file_name')
|
search_fields = ('content', 'organization__name', 'role__name', 'training_file__file_name')
|
||||||
raw_id_fields = ('role', 'training_file')
|
raw_id_fields = ('organization', 'role', 'training_file')
|
||||||
|
|
||||||
readonly_fields = ('uuid', 'content_hash', 'display_embedding', 'created_at', 'updated_at')
|
readonly_fields = ('uuid', 'content_hash', 'display_embedding', 'created_at', 'updated_at')
|
||||||
ordering = ('role', 'chunk_index')
|
ordering = ('organization', 'role', 'chunk_index')
|
||||||
|
|
||||||
def get_fields(self, request, obj=None):
|
def get_fields(self, request, obj=None):
|
||||||
fields = super().get_fields(request, obj)
|
fields = super().get_fields(request, obj)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,8 @@ class Migration(migrations.Migration):
|
||||||
('description', models.TextField(blank=True, default='')),
|
('description', models.TextField(blank=True, default='')),
|
||||||
('status', models.CharField(choices=[('ingesting', 'Ingesting'), ('chunked', 'Chunked'), ('embedded', 'Embedded'), ('failed', 'Failed')], default='ingesting', max_length=20)),
|
('status', models.CharField(choices=[('ingesting', 'Ingesting'), ('chunked', 'Chunked'), ('embedded', 'Embedded'), ('failed', 'Failed')], default='ingesting', max_length=20)),
|
||||||
('is_processed', models.BooleanField(default=False)),
|
('is_processed', models.BooleanField(default=False)),
|
||||||
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='accounts.role')),
|
('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='accounts.organization')),
|
||||||
|
('role', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='training_files', to='accounts.role')),
|
||||||
('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)),
|
('uploaded_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='uploaded_training_files', to=settings.AUTH_USER_MODEL)),
|
||||||
],
|
],
|
||||||
options={
|
options={
|
||||||
|
|
@ -52,7 +53,8 @@ class Migration(migrations.Migration):
|
||||||
('metadata', models.JSONField(blank=True, default=dict)),
|
('metadata', models.JSONField(blank=True, default=dict)),
|
||||||
('chunk_index', models.IntegerField(default=0)),
|
('chunk_index', models.IntegerField(default=0)),
|
||||||
('is_active', models.BooleanField(default=True)),
|
('is_active', models.BooleanField(default=True)),
|
||||||
('role', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='accounts.role')),
|
('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='rag_documents', to='accounts.organization')),
|
||||||
|
('role', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='rag_documents', to='accounts.role')),
|
||||||
('training_file', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='chunks', to='knowledge.trainingfile')),
|
('training_file', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='chunks', to='knowledge.trainingfile')),
|
||||||
],
|
],
|
||||||
options={
|
options={
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,14 @@ import os
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
from django.db.models import CASCADE, BooleanField, CharField, FileField, ForeignKey, IntegerField, JSONField, Model, TextField
|
from django.db.models import CASCADE, SET_NULL, BooleanField, CharField, FileField, ForeignKey, IntegerField, JSONField, Model, TextField
|
||||||
from django.db.models.signals import post_delete, post_save
|
from django.db.models.signals import post_delete, post_save
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from pgvector.django import VectorField
|
from pgvector.django import VectorField
|
||||||
|
|
||||||
from apps.accounts.mixins import IdentifierMixin, TimeStampMixin
|
from apps.accounts.mixins import IdentifierMixin, TimeStampMixin
|
||||||
from apps.accounts.models import Role, User
|
from apps.accounts.models import Organization, Role, User
|
||||||
|
|
||||||
class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
|
class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
|
||||||
STATUS_CHOICES = [
|
STATUS_CHOICES = [
|
||||||
|
|
@ -19,7 +19,8 @@ class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
|
||||||
('failed', 'Failed'),
|
('failed', 'Failed'),
|
||||||
]
|
]
|
||||||
|
|
||||||
role = ForeignKey(Role, on_delete=CASCADE, related_name="training_files")
|
organization = ForeignKey(Organization, on_delete=CASCADE, related_name="training_files")
|
||||||
|
role = ForeignKey(Role, on_delete=CASCADE, related_name="training_files", null=True, blank=True)
|
||||||
uploaded_by = ForeignKey(User, on_delete=CASCADE, related_name="uploaded_training_files")
|
uploaded_by = ForeignKey(User, on_delete=CASCADE, related_name="uploaded_training_files")
|
||||||
|
|
||||||
file = FileField(upload_to='training_files/%Y/%m/%d/')
|
file = FileField(upload_to='training_files/%Y/%m/%d/')
|
||||||
|
|
@ -37,11 +38,14 @@ class TrainingFile(IdentifierMixin, TimeStampMixin, Model):
|
||||||
ordering = ['-created_at']
|
ordering = ['-created_at']
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.file_name} ({self.role.name})"
|
if self.role_id:
|
||||||
|
return f"{self.file_name} ({self.role.name})"
|
||||||
|
return f"{self.file_name} ({self.organization.name} - Organization-wide)"
|
||||||
|
|
||||||
class RoleRagDocument(IdentifierMixin, TimeStampMixin, Model):
|
class RoleRagDocument(IdentifierMixin, TimeStampMixin, Model):
|
||||||
|
|
||||||
role = ForeignKey(Role, on_delete=CASCADE, related_name='rag_documents')
|
organization = ForeignKey(Organization, on_delete=CASCADE, related_name='rag_documents')
|
||||||
|
role = ForeignKey(Role, on_delete=SET_NULL, related_name='rag_documents', null=True, blank=True)
|
||||||
training_file = ForeignKey(TrainingFile, on_delete=CASCADE, related_name='chunks', null=True, blank=True)
|
training_file = ForeignKey(TrainingFile, on_delete=CASCADE, related_name='chunks', null=True, blank=True)
|
||||||
|
|
||||||
content = TextField()
|
content = TextField()
|
||||||
|
|
@ -58,7 +62,9 @@ class RoleRagDocument(IdentifierMixin, TimeStampMixin, Model):
|
||||||
verbose_name_plural = _("Role RAG Documents")
|
verbose_name_plural = _("Role RAG Documents")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.role.name} - Chunk {self.chunk_index}"
|
if self.role_id:
|
||||||
|
return f"{self.role.name} - Chunk {self.chunk_index}"
|
||||||
|
return f"{self.organization.name} (Organization-wide) - Chunk {self.chunk_index}"
|
||||||
|
|
||||||
@receiver(post_delete, sender=TrainingFile)
|
@receiver(post_delete, sender=TrainingFile)
|
||||||
def delete_physical_file(sender, instance, **kwargs):
|
def delete_physical_file(sender, instance, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,37 @@
|
||||||
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
from rest_framework.serializers import ModelSerializer, SerializerMethodField
|
||||||
|
|
||||||
from apps.accounts.serializers import RoleSerializer, UserSerializer
|
from apps.accounts.serializers import OrganizationSerializer, RoleSerializer, UserSerializer
|
||||||
from apps.knowledge.models import RoleRagDocument, TrainingFile
|
from apps.knowledge.models import RoleRagDocument, TrainingFile
|
||||||
|
|
||||||
class TrainingFileSerializer(ModelSerializer):
|
class TrainingFileSerializer(ModelSerializer):
|
||||||
uploaded_by = UserSerializer(read_only=True)
|
uploaded_by = UserSerializer(read_only=True)
|
||||||
|
organization = OrganizationSerializer(read_only=True)
|
||||||
role = RoleSerializer(read_only=True)
|
role = RoleSerializer(read_only=True)
|
||||||
file_url = SerializerMethodField()
|
file_url = SerializerMethodField()
|
||||||
|
scope = SerializerMethodField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = TrainingFile
|
model = TrainingFile
|
||||||
fields = [
|
fields = [
|
||||||
'id', 'uuid', 'role', 'uploaded_by', 'file', 'file_url',
|
'id', 'uuid', 'organization', 'role', 'scope', 'uploaded_by', 'file', 'file_url',
|
||||||
'file_name', 'file_size', 'file_type', 'description',
|
'file_name', 'file_size', 'file_type', 'description',
|
||||||
'status', 'is_processed', 'created_at', 'updated_at'
|
'status', 'is_processed', 'created_at', 'updated_at'
|
||||||
]
|
]
|
||||||
read_only_fields = [
|
read_only_fields = [
|
||||||
'id', 'uuid', 'uploaded_by', 'file_size', 'file_type',
|
'id', 'uuid', 'uploaded_by', 'file_size', 'file_type',
|
||||||
'status', 'is_processed', 'created_at', 'updated_at',
|
'status', 'is_processed', 'created_at', 'updated_at',
|
||||||
'role'
|
'organization', 'role', 'scope'
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_file_url(self, obj: TrainingFile) -> str:
|
def get_file_url(self, obj: TrainingFile):
|
||||||
request = self.context.get('request')
|
request = self.context.get('request')
|
||||||
if obj.file and request:
|
if obj.file and request:
|
||||||
return request.build_absolute_uri(obj.file.url)
|
return request.build_absolute_uri(obj.file.url)
|
||||||
return obj.file.url if obj.file else None
|
return obj.file.url if obj.file else None
|
||||||
|
|
||||||
|
def get_scope(self, obj: TrainingFile) -> str:
|
||||||
|
return 'role' if obj.role_id else 'organization'
|
||||||
|
|
||||||
class RoleRagDocumentSerializer(ModelSerializer):
|
class RoleRagDocumentSerializer(ModelSerializer):
|
||||||
training_file_name = SerializerMethodField()
|
training_file_name = SerializerMethodField()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,13 +77,18 @@ def ingest_training_file_task(self, file_uuid):
|
||||||
|
|
||||||
for chunk_text, embedding in zip(chunks, embeddings):
|
for chunk_text, embedding in zip(chunks, embeddings):
|
||||||
all_documents.append(RoleRagDocument(
|
all_documents.append(RoleRagDocument(
|
||||||
|
organization=file_obj.organization,
|
||||||
role=file_obj.role,
|
role=file_obj.role,
|
||||||
training_file=file_obj,
|
training_file=file_obj,
|
||||||
content=chunk_text,
|
content=chunk_text,
|
||||||
content_hash=hashlib.sha256(chunk_text.encode('utf-8')).hexdigest(),
|
content_hash=hashlib.sha256(chunk_text.encode('utf-8')).hexdigest(),
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
chunk_index=chunk_counter,
|
chunk_index=chunk_counter,
|
||||||
metadata={"source": file_obj.file_name}
|
metadata={
|
||||||
|
"source": file_obj.file_name,
|
||||||
|
"file_name": file_obj.file_name,
|
||||||
|
"scope": "role" if file_obj.role_id else "organization",
|
||||||
|
},
|
||||||
))
|
))
|
||||||
chunk_counter += 1
|
chunk_counter += 1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ class KnowledgeApiTests(TestCase):
|
||||||
self.role = Role.objects.create(name='Researcher', organization=self.org)
|
self.role = Role.objects.create(name='Researcher', organization=self.org)
|
||||||
|
|
||||||
self.training_file = TrainingFile.objects.create(
|
self.training_file = TrainingFile.objects.create(
|
||||||
|
organization=self.org,
|
||||||
role=self.role,
|
role=self.role,
|
||||||
uploaded_by=self.owner,
|
uploaded_by=self.owner,
|
||||||
file=SimpleUploadedFile('doc.txt', b'content', content_type='text/plain'),
|
file=SimpleUploadedFile('doc.txt', b'content', content_type='text/plain'),
|
||||||
|
|
@ -52,6 +53,7 @@ class KnowledgeApiTests(TestCase):
|
||||||
file_type='text/plain',
|
file_type='text/plain',
|
||||||
)
|
)
|
||||||
self.rag_doc = RoleRagDocument.objects.create(
|
self.rag_doc = RoleRagDocument.objects.create(
|
||||||
|
organization=self.org,
|
||||||
role=self.role,
|
role=self.role,
|
||||||
training_file=self.training_file,
|
training_file=self.training_file,
|
||||||
content='chunk body',
|
content='chunk body',
|
||||||
|
|
@ -136,7 +138,7 @@ class KnowledgeApiTests(TestCase):
|
||||||
'file_name': 'new.txt',
|
'file_name': 'new.txt',
|
||||||
})
|
})
|
||||||
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
|
self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST)
|
||||||
self.assertIn('role_uuid', response.json())
|
self.assertIn('organization_uuid', response.json())
|
||||||
|
|
||||||
def test_training_file_create_by_owner_succeeds(self):
|
def test_training_file_create_by_owner_succeeds(self):
|
||||||
self.client.force_authenticate(self.owner)
|
self.client.force_authenticate(self.owner)
|
||||||
|
|
@ -148,6 +150,17 @@ class KnowledgeApiTests(TestCase):
|
||||||
})
|
})
|
||||||
self.assertEqual(response.status_code, HTTP_201_CREATED)
|
self.assertEqual(response.status_code, HTTP_201_CREATED)
|
||||||
|
|
||||||
|
def test_training_file_create_org_wide_by_owner_succeeds(self):
|
||||||
|
self.client.force_authenticate(self.owner)
|
||||||
|
uploaded = SimpleUploadedFile('org-wide.txt', b'org policy', content_type='text/plain')
|
||||||
|
response = self.client.post('/api/training-file/', {
|
||||||
|
'organization_uuid': str(self.org.uuid),
|
||||||
|
'file': uploaded,
|
||||||
|
'file_name': 'org-wide.txt',
|
||||||
|
})
|
||||||
|
self.assertEqual(response.status_code, HTTP_201_CREATED)
|
||||||
|
self.assertIsNone(response.json().get('role'))
|
||||||
|
|
||||||
def test_training_file_destroy_forbidden_for_regular_member(self):
|
def test_training_file_destroy_forbidden_for_regular_member(self):
|
||||||
self.client.force_authenticate(self.member)
|
self.client.force_authenticate(self.member)
|
||||||
response = self.client.delete(f'/api/training-file/{self.training_file.uuid}/')
|
response = self.client.delete(f'/api/training-file/{self.training_file.uuid}/')
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ class KnowledgeModelTests(TestCase):
|
||||||
def test_training_file_fields_and_defaults(self):
|
def test_training_file_fields_and_defaults(self):
|
||||||
uploaded = SimpleUploadedFile('training.txt', b'hello world', content_type='text/plain')
|
uploaded = SimpleUploadedFile('training.txt', b'hello world', content_type='text/plain')
|
||||||
training_file = TrainingFile.objects.create(
|
training_file = TrainingFile.objects.create(
|
||||||
|
organization=self.org,
|
||||||
role=self.role,
|
role=self.role,
|
||||||
uploaded_by=self.user,
|
uploaded_by=self.user,
|
||||||
file=uploaded,
|
file=uploaded,
|
||||||
|
|
@ -44,6 +45,7 @@ class KnowledgeModelTests(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(training_file.role, self.role)
|
self.assertEqual(training_file.role, self.role)
|
||||||
|
self.assertEqual(training_file.organization, self.org)
|
||||||
self.assertEqual(training_file.uploaded_by, self.user)
|
self.assertEqual(training_file.uploaded_by, self.user)
|
||||||
self.assertEqual(training_file.file_name, 'training.txt')
|
self.assertEqual(training_file.file_name, 'training.txt')
|
||||||
self.assertEqual(training_file.file_size, 11)
|
self.assertEqual(training_file.file_size, 11)
|
||||||
|
|
@ -62,6 +64,7 @@ class KnowledgeModelTests(TestCase):
|
||||||
def test_role_rag_document_fields_and_defaults(self):
|
def test_role_rag_document_fields_and_defaults(self):
|
||||||
uploaded = SimpleUploadedFile('base.txt', b'base', content_type='text/plain')
|
uploaded = SimpleUploadedFile('base.txt', b'base', content_type='text/plain')
|
||||||
training_file = TrainingFile.objects.create(
|
training_file = TrainingFile.objects.create(
|
||||||
|
organization=self.org,
|
||||||
role=self.role,
|
role=self.role,
|
||||||
uploaded_by=self.user,
|
uploaded_by=self.user,
|
||||||
file=uploaded,
|
file=uploaded,
|
||||||
|
|
@ -70,6 +73,7 @@ class KnowledgeModelTests(TestCase):
|
||||||
file_type='text/plain',
|
file_type='text/plain',
|
||||||
)
|
)
|
||||||
document = RoleRagDocument.objects.create(
|
document = RoleRagDocument.objects.create(
|
||||||
|
organization=self.org,
|
||||||
role=self.role,
|
role=self.role,
|
||||||
training_file=training_file,
|
training_file=training_file,
|
||||||
content='Chunk content',
|
content='Chunk content',
|
||||||
|
|
@ -80,6 +84,7 @@ class KnowledgeModelTests(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(document.role, self.role)
|
self.assertEqual(document.role, self.role)
|
||||||
|
self.assertEqual(document.organization, self.org)
|
||||||
self.assertEqual(document.training_file, training_file)
|
self.assertEqual(document.training_file, training_file)
|
||||||
self.assertEqual(document.content, 'Chunk content')
|
self.assertEqual(document.content, 'Chunk content')
|
||||||
self.assertEqual(document.content_hash, 'a' * 64)
|
self.assertEqual(document.content_hash, 'a' * 64)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from rest_framework.parsers import FormParser, MultiPartParser
|
||||||
from rest_framework.permissions import IsAuthenticated
|
from rest_framework.permissions import IsAuthenticated
|
||||||
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
from rest_framework.viewsets import ModelViewSet, ReadOnlyModelViewSet
|
||||||
|
|
||||||
from apps.accounts.models import Role
|
from apps.accounts.models import Organization, Role
|
||||||
from apps.accounts.permissions import can_manage_organization
|
from apps.accounts.permissions import can_manage_organization
|
||||||
from apps.knowledge.models import RoleRagDocument, TrainingFile
|
from apps.knowledge.models import RoleRagDocument, TrainingFile
|
||||||
from apps.knowledge.serializers import RoleRagDocumentSerializer, TrainingFileSerializer
|
from apps.knowledge.serializers import RoleRagDocumentSerializer, TrainingFileSerializer
|
||||||
|
|
@ -19,35 +19,51 @@ class TrainingFileViewSet(ModelViewSet):
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
user = self.request.user
|
user = self.request.user
|
||||||
queryset = TrainingFile.objects.filter(
|
queryset = TrainingFile.objects.filter(
|
||||||
Q(role__organization__owner=user) |
|
Q(organization__owner=user) |
|
||||||
Q(role__organization__members=user)
|
Q(organization__members=user)
|
||||||
).distinct()
|
).distinct()
|
||||||
|
|
||||||
organization_uuid = self.request.query_params.get('organization_uuid')
|
organization_uuid = self.request.query_params.get('organization_uuid')
|
||||||
if organization_uuid in (None, ''):
|
if organization_uuid in (None, ''):
|
||||||
organization_uuid = self.request.data.get('organization_uuid')
|
organization_uuid = self.request.data.get('organization_uuid')
|
||||||
if organization_uuid:
|
if organization_uuid:
|
||||||
queryset = queryset.filter(role__organization__uuid=organization_uuid)
|
queryset = queryset.filter(organization__uuid=organization_uuid)
|
||||||
|
|
||||||
role_uuid = self.request.query_params.get('role_uuid')
|
role_uuid = self.request.query_params.get('role_uuid')
|
||||||
if role_uuid in (None, ''):
|
if role_uuid in (None, ''):
|
||||||
role_uuid = self.request.data.get('role_uuid')
|
role_uuid = self.request.data.get('role_uuid')
|
||||||
if role_uuid:
|
if role_uuid:
|
||||||
queryset = queryset.filter(role__uuid=role_uuid)
|
queryset = queryset.filter(Q(role__uuid=role_uuid) | Q(role__isnull=True))
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
def perform_create(self, serializer):
|
def perform_create(self, serializer):
|
||||||
role_uuid = self.request.data.get('role_uuid')
|
role_uuid = self.request.data.get('role_uuid')
|
||||||
if not role_uuid:
|
organization_uuid = self.request.data.get('organization_uuid')
|
||||||
raise ValidationError({'role_uuid': 'role_uuid is required.'})
|
|
||||||
|
|
||||||
try:
|
role = None
|
||||||
role = Role.objects.get(uuid=role_uuid)
|
organization = None
|
||||||
except Role.DoesNotExist:
|
|
||||||
raise NotFound('Role not found')
|
|
||||||
|
|
||||||
if not can_manage_organization(self.request.user, role.organization):
|
if role_uuid:
|
||||||
|
try:
|
||||||
|
role = Role.objects.select_related('organization').get(uuid=role_uuid)
|
||||||
|
except Role.DoesNotExist:
|
||||||
|
raise NotFound('Role not found')
|
||||||
|
|
||||||
|
organization = role.organization
|
||||||
|
|
||||||
|
if organization_uuid and str(organization.uuid) != str(organization_uuid):
|
||||||
|
raise ValidationError({'organization_uuid': 'organization_uuid does not match role organization.'})
|
||||||
|
else:
|
||||||
|
if not organization_uuid:
|
||||||
|
raise ValidationError({'organization_uuid': 'organization_uuid is required when role_uuid is not provided.'})
|
||||||
|
|
||||||
|
try:
|
||||||
|
organization = Organization.objects.get(uuid=organization_uuid)
|
||||||
|
except Organization.DoesNotExist:
|
||||||
|
raise NotFound('Organization not found')
|
||||||
|
|
||||||
|
if not can_manage_organization(self.request.user, organization):
|
||||||
raise PermissionDenied('Permission denied')
|
raise PermissionDenied('Permission denied')
|
||||||
|
|
||||||
uploaded_file = self.request.FILES.get('file')
|
uploaded_file = self.request.FILES.get('file')
|
||||||
|
|
@ -56,6 +72,7 @@ class TrainingFileViewSet(ModelViewSet):
|
||||||
|
|
||||||
serializer.save(
|
serializer.save(
|
||||||
uploaded_by=self.request.user,
|
uploaded_by=self.request.user,
|
||||||
|
organization=organization,
|
||||||
role=role,
|
role=role,
|
||||||
file_name=uploaded_file.name,
|
file_name=uploaded_file.name,
|
||||||
file_size=uploaded_file.size,
|
file_size=uploaded_file.size,
|
||||||
|
|
@ -66,8 +83,8 @@ class TrainingFileViewSet(ModelViewSet):
|
||||||
instance = self.get_object()
|
instance = self.get_object()
|
||||||
|
|
||||||
is_uploader = instance.uploaded_by == request.user
|
is_uploader = instance.uploaded_by == request.user
|
||||||
is_org_owner = instance.role.organization.owner == request.user
|
is_org_owner = instance.organization.owner == request.user
|
||||||
is_org_manager = bool(request.user.is_manager) and instance.role.organization.members.filter(id=request.user.id).exists()
|
is_org_manager = bool(request.user.is_manager) and instance.organization.members.filter(id=request.user.id).exists()
|
||||||
|
|
||||||
if not (is_uploader or is_org_owner or is_org_manager):
|
if not (is_uploader or is_org_owner or is_org_manager):
|
||||||
raise PermissionDenied('Permission denied')
|
raise PermissionDenied('Permission denied')
|
||||||
|
|
@ -83,15 +100,15 @@ class RoleRagDocumentViewSet(ReadOnlyModelViewSet):
|
||||||
def get_queryset(self):
|
def get_queryset(self):
|
||||||
user = self.request.user
|
user = self.request.user
|
||||||
queryset = RoleRagDocument.objects.filter(
|
queryset = RoleRagDocument.objects.filter(
|
||||||
Q(role__organization__owner=user) |
|
Q(organization__owner=user) |
|
||||||
Q(role__organization__members=user)
|
Q(organization__members=user)
|
||||||
).distinct()
|
).distinct()
|
||||||
|
|
||||||
organization_uuid = self.request.query_params.get('organization_uuid')
|
organization_uuid = self.request.query_params.get('organization_uuid')
|
||||||
if organization_uuid in (None, ''):
|
if organization_uuid in (None, ''):
|
||||||
organization_uuid = self.request.data.get('organization_uuid')
|
organization_uuid = self.request.data.get('organization_uuid')
|
||||||
if organization_uuid:
|
if organization_uuid:
|
||||||
queryset = queryset.filter(role__organization__uuid=organization_uuid)
|
queryset = queryset.filter(organization__uuid=organization_uuid)
|
||||||
|
|
||||||
role_uuid = self.request.query_params.get('role_uuid')
|
role_uuid = self.request.query_params.get('role_uuid')
|
||||||
if role_uuid in (None, ''):
|
if role_uuid in (None, ''):
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ import random
|
||||||
|
|
||||||
from channels.db import database_sync_to_async
|
from channels.db import database_sync_to_async
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.db.models import Q
|
||||||
from pgvector.django import CosineDistance
|
from pgvector.django import CosineDistance
|
||||||
|
|
||||||
|
from apps.accounts.models import Role
|
||||||
from apps.knowledge.models import RoleRagDocument
|
from apps.knowledge.models import RoleRagDocument
|
||||||
from apps.onboarding.models import OnboardingSession
|
from apps.onboarding.models import OnboardingSession
|
||||||
|
|
||||||
|
|
@ -105,9 +107,17 @@ class MCPRouter:
|
||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def _search_knowledge_documents(self, role_uuid, query_vector):
|
def _search_knowledge_documents(self, role_uuid, query_vector):
|
||||||
|
role = Role.objects.select_related('organization').filter(uuid=role_uuid).first()
|
||||||
|
if role is None:
|
||||||
|
logger.warning('MCP search_knowledge_documents role not found: role_uuid=%s', role_uuid)
|
||||||
|
return []
|
||||||
|
|
||||||
docs = RoleRagDocument.objects.filter(
|
docs = RoleRagDocument.objects.filter(
|
||||||
role__uuid=role_uuid,
|
organization=role.organization,
|
||||||
|
embedding__isnull=False,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
|
).filter(
|
||||||
|
Q(role__uuid=role_uuid) | Q(role__isnull=True),
|
||||||
).annotate(
|
).annotate(
|
||||||
distance=CosineDistance('embedding', query_vector)
|
distance=CosineDistance('embedding', query_vector)
|
||||||
).order_by('distance')[:5]
|
).order_by('distance')[:5]
|
||||||
|
|
@ -115,7 +125,7 @@ class MCPRouter:
|
||||||
results = [
|
results = [
|
||||||
{
|
{
|
||||||
'content': d.content,
|
'content': d.content,
|
||||||
'source': d.metadata.get('file_name', 'Unknown Source'),
|
'source': d.metadata.get('file_name') or d.metadata.get('source', 'Unknown Source'),
|
||||||
'relevance': round(1 - d.distance, 4),
|
'relevance': round(1 - d.distance, 4),
|
||||||
}
|
}
|
||||||
for d in docs
|
for d in docs
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,9 @@ export interface InviteToken {
|
||||||
}
|
}
|
||||||
export interface TrainingFile {
|
export interface TrainingFile {
|
||||||
uuid: string
|
uuid: string
|
||||||
role: Role
|
organization: Organization
|
||||||
|
role: Role | null
|
||||||
|
scope?: 'role' | 'organization'
|
||||||
uploaded_by: User
|
uploaded_by: User
|
||||||
file: string
|
file: string
|
||||||
file_name: string
|
file_name: string
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,12 @@ const inviteModalVisible = ref(false)
|
||||||
const newInviteUrl = ref('')
|
const newInviteUrl = ref('')
|
||||||
const editingDescription = ref(false)
|
const editingDescription = ref(false)
|
||||||
const newDescription = ref('')
|
const newDescription = ref('')
|
||||||
|
const ORGANIZATION_WIDE_SCOPE = '__organization_wide__'
|
||||||
|
|
||||||
|
const uploadRoleOptions = computed(() => [
|
||||||
|
{ label: 'Organization-wide (all roles)', value: ORGANIZATION_WIDE_SCOPE },
|
||||||
|
...Roles.value.map((role) => ({ label: role.name, value: role.uuid })),
|
||||||
|
])
|
||||||
|
|
||||||
const filteredMembers = computed(() => {
|
const filteredMembers = computed(() => {
|
||||||
const query = memberSearch.value.trim().toLowerCase()
|
const query = memberSearch.value.trim().toLowerCase()
|
||||||
|
|
@ -151,6 +157,8 @@ const fetchTrainingFiles = async () => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getScopeLabel = (file: TrainingFile) => (file.role?.name ? file.role.name : 'Organization-wide')
|
||||||
|
|
||||||
|
|
||||||
const resetRoleWizard = () => {
|
const resetRoleWizard = () => {
|
||||||
roleWizardStep.value = 0
|
roleWizardStep.value = 0
|
||||||
|
|
@ -200,7 +208,7 @@ const validateUploadFile = (file: File): boolean => {
|
||||||
}
|
}
|
||||||
|
|
||||||
const uploadTrainingFile = async (
|
const uploadTrainingFile = async (
|
||||||
roleUuid: string,
|
roleUuid: string | null,
|
||||||
file: File,
|
file: File,
|
||||||
description: string,
|
description: string,
|
||||||
): Promise<TrainingFile | null> => {
|
): Promise<TrainingFile | null> => {
|
||||||
|
|
@ -208,7 +216,10 @@ const uploadTrainingFile = async (
|
||||||
formData.append('file', file)
|
formData.append('file', file)
|
||||||
formData.append('file_name', file.name)
|
formData.append('file_name', file.name)
|
||||||
formData.append('description', description)
|
formData.append('description', description)
|
||||||
formData.append('role_uuid', roleUuid)
|
formData.append('organization_uuid', organizationUuid)
|
||||||
|
if (roleUuid) {
|
||||||
|
formData.append('role_uuid', roleUuid)
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await apiClient.post<TrainingFile>(API.knowledge.trainingFiles.list(), formData, {
|
const response = await apiClient.post<TrainingFile>(API.knowledge.trainingFiles.list(), formData, {
|
||||||
|
|
@ -234,6 +245,10 @@ const uploadTrainingFile = async (
|
||||||
const getTrainingFilesByRole = (roleUuid: string): TrainingFile[] =>
|
const getTrainingFilesByRole = (roleUuid: string): TrainingFile[] =>
|
||||||
trainingFiles.value.filter((file) => file.role?.uuid === roleUuid)
|
trainingFiles.value.filter((file) => file.role?.uuid === roleUuid)
|
||||||
|
|
||||||
|
const organizationWideTrainingFiles = computed(() =>
|
||||||
|
trainingFiles.value.filter((file) => !file.role?.uuid),
|
||||||
|
)
|
||||||
|
|
||||||
const deleteTrainingFile = async (uuid: string, fileName: string) => {
|
const deleteTrainingFile = async (uuid: string, fileName: string) => {
|
||||||
Modal.confirm({
|
Modal.confirm({
|
||||||
title: 'Delete File',
|
title: 'Delete File',
|
||||||
|
|
@ -289,9 +304,9 @@ const trainingFileColumns = [
|
||||||
customRender: ({ value }: { value: number }) => formatFileSize(value || 0),
|
customRender: ({ value }: { value: number }) => formatFileSize(value || 0),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: 'Role',
|
title: 'Scope',
|
||||||
key: 'role',
|
key: 'role',
|
||||||
customRender: ({ record }: { record: TrainingFile }) => record.role?.name || '-',
|
customRender: ({ record }: { record: TrainingFile }) => getScopeLabel(record),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
title: 'Status',
|
title: 'Status',
|
||||||
|
|
@ -426,7 +441,7 @@ const uploadFileFromWizard = async () => {
|
||||||
}
|
}
|
||||||
|
|
||||||
const openUploadModal = (role?: Role) => {
|
const openUploadModal = (role?: Role) => {
|
||||||
uploadRoleUuid.value = role?.uuid || ''
|
uploadRoleUuid.value = role?.uuid || ORGANIZATION_WIDE_SCOPE
|
||||||
uploadSelectedFile.value = null
|
uploadSelectedFile.value = null
|
||||||
uploadFileDescription.value = ''
|
uploadFileDescription.value = ''
|
||||||
uploadModalVisible.value = true
|
uploadModalVisible.value = true
|
||||||
|
|
@ -441,20 +456,18 @@ const handleUploadModalFileSelected = (file: File) => {
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleUploadModalOk = async () => {
|
const handleUploadModalOk = async () => {
|
||||||
if (!uploadRoleUuid.value) {
|
|
||||||
message.error('Please select a role for this training file')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!uploadSelectedFile.value) {
|
if (!uploadSelectedFile.value) {
|
||||||
message.error('Please select a file to upload')
|
message.error('Please select a file to upload')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const selectedRoleUuid =
|
||||||
|
uploadRoleUuid.value === ORGANIZATION_WIDE_SCOPE ? null : uploadRoleUuid.value
|
||||||
|
|
||||||
uploadingFile.value = true
|
uploadingFile.value = true
|
||||||
try {
|
try {
|
||||||
const uploaded = await uploadTrainingFile(
|
const uploaded = await uploadTrainingFile(
|
||||||
uploadRoleUuid.value,
|
selectedRoleUuid,
|
||||||
uploadSelectedFile.value,
|
uploadSelectedFile.value,
|
||||||
uploadFileDescription.value,
|
uploadFileDescription.value,
|
||||||
)
|
)
|
||||||
|
|
@ -848,7 +861,29 @@ onMounted(async () => {
|
||||||
</List.Item>
|
</List.Item>
|
||||||
</template>
|
</template>
|
||||||
</List>
|
</List>
|
||||||
<Typography.Paragraph v-else type="secondary">
|
<div v-if="organizationWideTrainingFiles.length > 0" class="role-files" style="margin-top: 1rem">
|
||||||
|
<Typography.Text strong>
|
||||||
|
Organization-wide training files (applies to all roles)
|
||||||
|
</Typography.Text>
|
||||||
|
<List
|
||||||
|
:data-source="organizationWideTrainingFiles"
|
||||||
|
size="small"
|
||||||
|
:bordered="false"
|
||||||
|
>
|
||||||
|
<template #renderItem="{ item: file }">
|
||||||
|
<List.Item>
|
||||||
|
<Space style="display: flex; justify-content: space-between; width: 100%">
|
||||||
|
<Typography.Text>{{ file.file_name }}</Typography.Text>
|
||||||
|
<Tag color="geekblue">Organization-wide</Tag>
|
||||||
|
</Space>
|
||||||
|
</List.Item>
|
||||||
|
</template>
|
||||||
|
</List>
|
||||||
|
</div>
|
||||||
|
<Typography.Paragraph
|
||||||
|
v-if="filteredRoles.length === 0 && organizationWideTrainingFiles.length === 0"
|
||||||
|
type="secondary"
|
||||||
|
>
|
||||||
{{ roleEmptyMessage }}
|
{{ roleEmptyMessage }}
|
||||||
</Typography.Paragraph>
|
</Typography.Paragraph>
|
||||||
</div>
|
</div>
|
||||||
|
|
@ -913,7 +948,8 @@ onMounted(async () => {
|
||||||
<Typography.Paragraph type="secondary" style="margin-bottom: 0">
|
<Typography.Paragraph type="secondary" style="margin-bottom: 0">
|
||||||
Upload optional training files for
|
Upload optional training files for
|
||||||
<strong>{{ createdRoleForWizard?.name }}</strong>
|
<strong>{{ createdRoleForWizard?.name }}</strong>
|
||||||
. You can also do this later.
|
. You can also do this later. Use the main Upload Training File modal for
|
||||||
|
organization-wide files.
|
||||||
</Typography.Paragraph>
|
</Typography.Paragraph>
|
||||||
|
|
||||||
<Input.TextArea
|
<Input.TextArea
|
||||||
|
|
@ -970,7 +1006,7 @@ onMounted(async () => {
|
||||||
title="Upload Training File"
|
title="Upload Training File"
|
||||||
ok-text="Upload"
|
ok-text="Upload"
|
||||||
cancel-text="Cancel"
|
cancel-text="Cancel"
|
||||||
:ok-button-props="{ loading: uploadingFile, disabled: !uploadRoleUuid || !uploadSelectedFile }"
|
:ok-button-props="{ loading: uploadingFile, disabled: !uploadSelectedFile }"
|
||||||
@ok="handleUploadModalOk"
|
@ok="handleUploadModalOk"
|
||||||
@cancel="uploadModalVisible = false"
|
@cancel="uploadModalVisible = false"
|
||||||
>
|
>
|
||||||
|
|
@ -982,13 +1018,16 @@ onMounted(async () => {
|
||||||
</Typography.Text>
|
</Typography.Text>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<Typography.Text strong>Role</Typography.Text>
|
<Typography.Text strong>Scope</Typography.Text>
|
||||||
<Select
|
<Select
|
||||||
v-model:value="uploadRoleUuid"
|
v-model:value="uploadRoleUuid"
|
||||||
placeholder="Select a role"
|
placeholder="Select training scope"
|
||||||
style="width: 100%"
|
style="width: 100%"
|
||||||
:options="Roles.map((role) => ({ label: role.name, value: role.uuid }))"
|
:options="uploadRoleOptions"
|
||||||
/>
|
/>
|
||||||
|
<Typography.Paragraph type="secondary" style="margin: 0.5rem 0 0">
|
||||||
|
Organization-wide files apply to every role in this organization.
|
||||||
|
</Typography.Paragraph>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<Input.TextArea
|
<Input.TextArea
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue