From 12e0f141febcaa66594d466335013b26c458d9cc Mon Sep 17 00:00:00 2001 From: Viswamedha Nalabotu Date: Wed, 19 Nov 2025 21:44:18 +0000 Subject: [PATCH] Added rag implementation for testing of a local model --- apps/agents/langgraph_adapter.py | 29 --- apps/agents/llm.py | 63 ------ apps/agents/service.py | 16 -- apps/domains/migrations/0001_initial.py | 34 ++++ apps/domains/models.py | 21 +- .../local-model-rag-implementation.ipynb | 189 ++++++++++++++++++ 6 files changed, 243 insertions(+), 109 deletions(-) delete mode 100644 apps/agents/langgraph_adapter.py delete mode 100644 apps/agents/llm.py delete mode 100644 apps/agents/service.py create mode 100644 apps/domains/migrations/0001_initial.py create mode 100644 notebooks/local-model-rag-implementation.ipynb diff --git a/apps/agents/langgraph_adapter.py b/apps/agents/langgraph_adapter.py deleted file mode 100644 index 1342dd8..0000000 --- a/apps/agents/langgraph_adapter.py +++ /dev/null @@ -1,29 +0,0 @@ - -from typing import Any -import logging - -from .llm import get_llm_for_domain - -logger = logging.getLogger(__name__) - - -class SimpleAgent: - """Minimal agent abstraction that calls a local LLM and returns responses.""" - - def __init__(self, name: str, domain: str, system_message: str | None = None): - self.name = name - self.domain = domain - self.system_message = system_message or "You are an assistant." - self._llm = get_llm_for_domain(domain) - - def run(self, prompt: str, **kwargs: Any) -> str: - full_prompt = f"{self.system_message}\n\nUser: {prompt}" - logger.debug("Agent %s running prompt: %s", self.name, prompt) - return self._llm.generate(full_prompt) - - -def build_agents_for_domains(domains: list[str]) -> dict[str, SimpleAgent]: - agents = {} - for d in domains: - agents[d] = SimpleAgent(name=f"agent-{d}", domain=d, system_message=f"You are a tutor for {d}.") - return agents diff --git a/apps/agents/llm.py b/apps/agents/llm.py deleted file mode 100644 index ce1428c..0000000 --- a/apps/agents/llm.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Lightweight local LLM wrappers. - -This file provides simple wrappers for `llama_cpp` and `transformers` backends. -They are intentionally minimal — adapt to your runtime and model formats. -""" -from typing import Optional -import logging - - -logger = logging.getLogger(__name__) - - -class BaseLLM: - def generate(self, prompt: str) -> str: - raise NotImplementedError() - - -class LlamaCPPWrapper(BaseLLM): - def __init__(self, model_path: str): - try: - from llama_cpp import Llama - - self._llm = Llama(model_path=model_path) - except Exception: - logger.exception("llama_cpp is unavailable or model failed to load") - self._llm = None - - def generate(self, prompt: str) -> str: - if self._llm is None: - raise RuntimeError("Llama model not available") - resp = self._llm(prompt) - return resp.get("text") if isinstance(resp, dict) else str(resp) - - -class TransformersWrapper(BaseLLM): - def __init__(self, model_name_or_path: str): - try: - from transformers import AutoModelForCausalLM, AutoTokenizer - import torch - - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map="auto") - except Exception: - logger.exception("transformers not available or model failed to load") - self.model = None - self.tokenizer = None - - def generate(self, prompt: str) -> str: - if self.model is None or self.tokenizer is None: - raise RuntimeError("Transformers model not available") - inputs = self.tokenizer(prompt, return_tensors="pt") - outputs = self.model.generate(**inputs, max_new_tokens=256) - return self.tokenizer.decode(outputs[0], skip_special_tokens=True) - - -def get_llm_for_domain(domain: str, prefer: str | None = None) -> BaseLLM: - # Basic loader: choose Llama (gguf) if file exists, else fall back to transformers - model_dir = "models" / domain - gguf = model_dir / "model.gguf" - if gguf.exists(): - return LlamaCPPWrapper(str(gguf)) - # fallback: try transformers - return TransformersWrapper(str(model_dir)) diff --git a/apps/agents/service.py b/apps/agents/service.py deleted file mode 100644 index 8b6755d..0000000 --- a/apps/agents/service.py +++ /dev/null @@ -1,16 +0,0 @@ -import logging - -from .models import AgentRun -from .langgraph_adapter import SimpleAgent - -logger = logging.getLogger(__name__) - - -def run_agent(agent: SimpleAgent, prompt: str) -> str: - """Run the agent and store an AgentRun record using the Django ORM.""" - out = agent.run(prompt) - try: - AgentRun.objects.create(agent_name=agent.name, input_text=prompt, output_text=out) - except Exception: - logger.exception("Failed to persist agent run via Django ORM") - return out diff --git a/apps/domains/migrations/0001_initial.py b/apps/domains/migrations/0001_initial.py new file mode 100644 index 0000000..231535b --- /dev/null +++ b/apps/domains/migrations/0001_initial.py @@ -0,0 +1,34 @@ +# Generated by Django 5.2.8 on 2025-11-19 14:22 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='Domain', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255, unique=True)), + ('description', models.TextField(blank=True, default='')), + ], + ), + migrations.CreateModel( + name='Dataset', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255)), + ('description', models.TextField(blank=True, default='')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('domain', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='datasets', to='domains.domain')), + ], + ), + ] diff --git a/apps/domains/models.py b/apps/domains/models.py index 71a8362..5fd772b 100644 --- a/apps/domains/models.py +++ b/apps/domains/models.py @@ -1,3 +1,22 @@ from django.db import models -# Create your models here. + +class Domain(models.Model): + + name = models.CharField(max_length = 255, unique = True) + description = models.TextField(blank = True, default = "") + + def __str__(self) -> str: # pragma: no cover - trivial + return self.name + + +class Dataset(models.Model): + + domain = models.ForeignKey(Domain, on_delete = models.CASCADE, related_name = "datasets") + name = models.CharField(max_length = 255) + description = models.TextField(blank = True, default = "") + created_at = models.DateTimeField(auto_now_add = True) + updated_at = models.DateTimeField(auto_now = True) + + def __str__(self) -> str: # pragma: no cover - trivial + return f"{self.name} ({self.domain.name})" \ No newline at end of file diff --git a/notebooks/local-model-rag-implementation.ipynb b/notebooks/local-model-rag-implementation.ipynb new file mode 100644 index 0000000..3ff0146 --- /dev/null +++ b/notebooks/local-model-rag-implementation.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "45d62106", + "metadata": {}, + "source": [ + "# Basic RAG Implementation with a local LLM" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "4c312410", + "metadata": {}, + "outputs": [], + "source": [ + "from gpt4all import GPT4All\n", + "from sentence_transformers import SentenceTransformer\n", + "from chromadb import PersistentClient\n", + "from docx import Document\n", + "\n", + "MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n", + "CONTEXT_SIZE = 8192\n", + "EMBEDDER = \"all-MiniLM-L6-v2\"\n", + "RAG_PATH = \"./build/rag_db\"\n", + "DOCS_PATH = \"C:\\\\Users\\\\nalab\\\\Downloads\\\\fNIRS_Glossary_Hardware.docx\"" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "90bae527", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "model = GPT4All(model_name = MODEL, n_ctx = CONTEXT_SIZE, allow_download = True)\n", + "embedder = SentenceTransformer(EMBEDDER)\n", + "client = PersistentClient(path = RAG_PATH)\n", + "\n", + "\n", + "class EmbeddingFunctionWrapper:\n", + " def __init__(self, model):\n", + " self.model = model\n", + "\n", + " def name(self):\n", + " return \"sentence-transformers\"\n", + "\n", + " def __call__(self, input):\n", + " if isinstance(input, str):\n", + " texts = [input]\n", + " embs = self.model.encode(texts).tolist()\n", + " return embs[0]\n", + " else:\n", + " texts = list(input)\n", + " return self.model.encode(texts).tolist()\n", + "\n", + "embedding_fn = EmbeddingFunctionWrapper(embedder)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34efbc7c", + "metadata": {}, + "outputs": [], + "source": [ + "doc = Document(DOCS_PATH)\n", + "docx_content = \"\\n\".join([paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()])\n", + "chunk_size = 1000\n", + "documents = [docx_content[i:i+chunk_size] for i in range(0, len(docx_content), chunk_size) if docx_content[i:i+chunk_size].strip()]\n", + "embeddings = embedder.encode(documents).tolist()\n", + "collection = client.get_or_create_collection(\n", + " name = \"knowledge_base\",\n", + " embedding_function = embedding_fn,\n", + ")\n", + "collection.add(\n", + " documents=documents,\n", + " embeddings=embeddings,\n", + " ids=[f\"doc{i}\" for i in range(len(documents))]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed2cc1ff", + "metadata": {}, + "outputs": [], + "source": [ + "def retrieve(query, top_k = 1):\n", + " query_embedding = embedder.encode([query]).tolist()[0]\n", + " try:\n", + " results = collection.query(query_texts=[query], n_results=top_k)\n", + " return results[\"documents\"][0]\n", + " except Exception:\n", + " results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n", + " return results[\"documents\"][0]\n", + "\n", + "def rag_answer(query):\n", + " retrieved_docs = retrieve(query)\n", + " context = \"\\n\\n\".join(retrieved_docs)\n", + " max_context_length = 500\n", + " if len(context) > max_context_length:\n", + " context = context[:max_context_length] + \"...\"\n", + "\n", + " prompt = f\"\"\"\n", + "Use the context to answer the question.\n", + "Context:\n", + "{context}\n", + "Question:\n", + "{query}\n", + "Answer:\n", + "\"\"\"\n", + " print(f\"Prompt length: {len(prompt)}\")\n", + " return model.generate(prompt, max_tokens=200)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6fa9fd10", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of documents: 68\n", + "Document lengths: [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 63]\n", + "Retrieved docs length: 1\n", + "Prompt length: 630\n" + ] + } + ], + "source": [ + "query = \"What can Frequency domain multidistance NIRS estimate?\"\n", + "print(f\"Number of documents: {len(documents)}\")\n", + "print(f\"Document lengths: {[len(doc) for doc in documents]}\")\n", + "retrieved = retrieve(query)\n", + "print(f\"Retrieved docs length: {len(retrieved)}\")\n", + "response = rag_answer(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "5a82353e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations.'" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}