diff --git a/apps/knowledge/admin.py b/apps/knowledge/admin.py index 12d0015..c63be62 100644 --- a/apps/knowledge/admin.py +++ b/apps/knowledge/admin.py @@ -28,7 +28,7 @@ class RoleRagDocumentAdmin(admin.ModelAdmin): fields.remove('embedding') return fields - @admin.display(description=_("Embedding Preview (1536d)")) + @admin.display(description=_("Embedding Preview")) def display_embedding(self, obj): if obj.embedding is not None: preview = list(obj.embedding[:5]) diff --git a/apps/knowledge/migrations/0001_initial.py b/apps/knowledge/migrations/0001_initial.py index 27378a2..9c30bfe 100644 --- a/apps/knowledge/migrations/0001_initial.py +++ b/apps/knowledge/migrations/0001_initial.py @@ -48,7 +48,7 @@ class Migration(migrations.Migration): ('updated_at', models.DateTimeField(auto_now=True, verbose_name='Updated At')), ('content', models.TextField()), ('content_hash', models.CharField(db_index=True, max_length=64)), - ('embedding', pgvector.django.vector.VectorField(blank=True, dimensions=1536, null=True)), + ('embedding', pgvector.django.vector.VectorField(blank=True, dimensions=getattr(settings, 'EMBEDDING_DIMENSIONS', 768), null=True)), ('metadata', models.JSONField(blank=True, default=dict)), ('chunk_index', models.IntegerField(default=0)), ('is_active', models.BooleanField(default=True)), diff --git a/apps/knowledge/models.py b/apps/knowledge/models.py index acb08b4..ed0ab3e 100644 --- a/apps/knowledge/models.py +++ b/apps/knowledge/models.py @@ -1,5 +1,6 @@ import os +from django.conf import settings from django.db import transaction from django.db.models import CASCADE, BooleanField, CharField, FileField, ForeignKey, IntegerField, JSONField, Model, TextField from django.db.models.signals import post_delete, post_save @@ -46,7 +47,7 @@ class RoleRagDocument(IdentifierMixin, TimeStampMixin, Model): content = TextField() content_hash = CharField(max_length=64, db_index=True) - embedding = VectorField(dimensions=1536, null=True, blank=True) + embedding = VectorField(dimensions=settings.EMBEDDING_DIMENSIONS, null=True, blank=True) metadata = JSONField(default=dict, blank=True) chunk_index = IntegerField(default=0) diff --git a/apps/knowledge/tasks.py b/apps/knowledge/tasks.py index 3889724..7a548f5 100644 --- a/apps/knowledge/tasks.py +++ b/apps/knowledge/tasks.py @@ -50,6 +50,8 @@ def ingest_training_file_task(self, file_uuid): file_obj.status = 'ingesting' file_obj.save() + target_dimensions = RoleRagDocument._meta.get_field('embedding').dimensions + try: raw_text = _extract_text_from_training_file(file_obj) if not raw_text: @@ -65,7 +67,11 @@ def ingest_training_file_task(self, file_uuid): for text_segment in _get_text_chunks(raw_text): response = client.post( settings.INFERENCE_SEMANTIC_CHUNK_ENDPOINT, - json={"text": text_segment, "threshold": 95} + json={ + "text": text_segment, + "threshold": 95, + "target_dimensions": target_dimensions, + }, ) response.raise_for_status() result = response.json() diff --git a/apps/onboarding/mcp.py b/apps/onboarding/mcp.py index 23df4e5..9227cba 100644 --- a/apps/onboarding/mcp.py +++ b/apps/onboarding/mcp.py @@ -68,10 +68,14 @@ class MCPRouter: async def _get_embedding(self, text): logger.info('MCP embedding request started') + target_dimensions = RoleRagDocument._meta.get_field('embedding').dimensions async with httpx.AsyncClient() as client: response = await client.post( settings.INFERENCE_EMBEDDINGS_ENDPOINT, - json={'input': text}, + json={ + 'input': text, + 'target_dimensions': target_dimensions, + }, ) response.raise_for_status() embedding = response.json()['data'][0]['embedding'] diff --git a/config/settings.py b/config/settings.py index 049d1db..05bd3dd 100644 --- a/config/settings.py +++ b/config/settings.py @@ -32,6 +32,7 @@ INFERENCE_SEMANTIC_CHUNK_ENDPOINT = f"{INFERENCE_URL}/v1/semantic-chunk" INFERENCE_EMBEDDINGS_ENDPOINT = f"{INFERENCE_URL}/v1/embeddings" INFERENCE_CHAT_COMPLETIONS_ENDPOINT = f"{INFERENCE_URL}/v1/chat/completions" INFERENCE_INGEST_TIMEOUT = float(os.getenv('INFERENCE_INGEST_TIMEOUT', '600')) +EMBEDDING_DIMENSIONS = int(os.getenv('EMBEDDING_DIMENSIONS', '768')) STATIC_URL = os.getenv('DJANGO_STATIC_URL', '/static/') MEDIA_URL = os.getenv('DJANGO_MEDIA_URL', '/media/') diff --git a/gpu_server.py b/gpu_server.py index 36f33e2..44e8640 100644 --- a/gpu_server.py +++ b/gpu_server.py @@ -20,7 +20,6 @@ logger = logging.getLogger("gpu-node") EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5" LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf") -TARGET_DIMENSIONS = 1536 state: Dict[str, Any] = {} @@ -79,13 +78,29 @@ async def health(): } -def pad_and_normalize(embeddings: torch.Tensor) -> torch.Tensor: - """Standardizes vector dimensions to 1536 for pgvector compatibility.""" +def _resolve_target_dimensions(payload: Dict[str, Any]) -> int: + raw_target = payload.get("target_dimensions") + if raw_target in (None, ""): + raise HTTPException(status_code=400, detail="'target_dimensions' is required") + + try: + target = int(raw_target) + except (TypeError, ValueError) as exc: + raise HTTPException(status_code=400, detail="'target_dimensions' must be an integer") from exc + + if target <= 0: + raise HTTPException(status_code=400, detail="'target_dimensions' must be > 0") + + return target + + +def pad_and_normalize(embeddings: torch.Tensor, target_dimensions: int) -> torch.Tensor: + """Dimension standardization plus L2 normalization.""" curr_dim = embeddings.shape[1] - if curr_dim < TARGET_DIMENSIONS: - embeddings = F.pad(embeddings, (0, TARGET_DIMENSIONS - curr_dim), "constant", 0) - elif curr_dim > TARGET_DIMENSIONS: - embeddings = embeddings[:, :TARGET_DIMENSIONS] + if curr_dim < target_dimensions: + embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0) + elif curr_dim > target_dimensions: + embeddings = embeddings[:, :target_dimensions] return F.normalize(embeddings, p=2, dim=1) @@ -94,6 +109,7 @@ async def embeddings(request: Request): """Generates text embeddings compatible with OpenAI API format.""" data = await request.json() input_data = data.get("input", "") + target_dimensions = _resolve_target_dimensions(data) if isinstance(input_data, str): inputs = [input_data] @@ -121,7 +137,7 @@ async def embeddings(request: Request): with torch.no_grad(): vectors = model.encode(prefixed_inputs, convert_to_tensor=True) - vectors = pad_and_normalize(vectors) + vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions) vector_list = vectors.cpu().tolist() @@ -148,6 +164,7 @@ async def semantic_chunk(request: Request): data = await request.json() raw_text = data.get("text", "") threshold_percentile = data.get("threshold", 95) + target_dimensions = _resolve_target_dimensions(data) if not raw_text: return {"chunks": [], "embeddings": []} @@ -162,9 +179,11 @@ async def semantic_chunk(request: Request): # Split by sentences sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()] if len(sentences) < 2: + single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True) + single = pad_and_normalize(single, target_dimensions=target_dimensions) return { "chunks": [raw_text], - "embeddings": model.encode([f"search_document: {raw_text}"]).tolist() + "embeddings": single.cpu().tolist(), } # Generate sentence embeddings to find breakpoints via cosine distance @@ -189,7 +208,7 @@ async def semantic_chunk(request: Request): [f"search_document: {c}" for c in chunks], convert_to_tensor=True ) - final_embeddings = pad_and_normalize(final_embeddings) + final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=target_dimensions) return { "chunks": chunks, diff --git a/report/report.tex b/report/report.tex index 76de114..603a549 100644 --- a/report/report.tex +++ b/report/report.tex @@ -461,7 +461,7 @@ embeddings. This avoids naive fixed-size splits that can break context mid-concept. \underline{Vector storage and retrieval with pgvector}\\ -Returned chunk embeddings are stored in RoleRagDocument.embedding (1536 +Returned chunk embeddings are stored in RoleRagDocument.embedding (768 dimensions) in PostgreSQL using pgvector, linked relationally to role and source file metadata. Retrieval is performed in SQL using cosine-distance ranking and top-k selection, allowing role filtering and