Added RAG implementation for testing a local model

This commit is contained in:
Viswamedha Nalabotu 2025-11-19 21:44:18 +00:00
parent 42cd79662d
commit 12e0f141fe
6 changed files with 243 additions and 109 deletions

View file

@ -1,29 +0,0 @@
from typing import Any
import logging
from .llm import get_llm_for_domain
logger = logging.getLogger(__name__)
class SimpleAgent:
    """Thin agent that prepends a system message and delegates to a domain LLM."""

    def __init__(self, name: str, domain: str, system_message: str | None = None):
        # A falsy system message (None or "") falls back to the generic default.
        if not system_message:
            system_message = "You are an assistant."
        self.name = name
        self.domain = domain
        self.system_message = system_message
        # The backend is resolved once, at construction time.
        self._llm = get_llm_for_domain(domain)

    def run(self, prompt: str, **kwargs: Any) -> str:
        """Send ``prompt`` to the LLM, preceded by the system message.

        Extra keyword arguments are accepted but not used.
        """
        logger.debug("Agent %s running prompt: %s", self.name, prompt)
        composed = f"{self.system_message}\n\nUser: {prompt}"
        return self._llm.generate(composed)
def build_agents_for_domains(domains: list[str]) -> dict[str, SimpleAgent]:
    """Create one tutor agent per domain name, keyed by that name."""
    return {
        domain: SimpleAgent(
            name=f"agent-{domain}",
            domain=domain,
            system_message=f"You are a tutor for {domain}.",
        )
        for domain in domains
    }

View file

@ -1,63 +0,0 @@
"""Lightweight local LLM wrappers.
This file provides simple wrappers for `llama_cpp` and `transformers` backends.
They are intentionally minimal; adapt them to your runtime and model formats.
"""
from typing import Optional
import logging
logger = logging.getLogger(__name__)
class BaseLLM:
    """Common interface for the LLM backends: text prompt in, text out."""

    def generate(self, prompt: str) -> str:
        """Produce a completion for ``prompt``; subclasses must override this."""
        raise NotImplementedError
class LlamaCPPWrapper(BaseLLM):
    """Text-generation backend built on a ``llama_cpp.Llama`` model.

    Loading is best-effort: if ``llama_cpp`` cannot be imported or the model
    fails to load, the failure is logged and ``generate`` raises
    ``RuntimeError`` when it is called.
    """

    def __init__(self, model_path: str):
        try:
            # Imported here so a missing llama_cpp is handled like a load failure.
            from llama_cpp import Llama
            self._llm = Llama(model_path=model_path)
        except Exception:
            logger.exception("llama_cpp is unavailable or model failed to load")
            self._llm = None

    def generate(self, prompt: str) -> str:
        """Return the completion text for ``prompt``.

        Raises:
            RuntimeError: if the model could not be loaded in ``__init__``.
        """
        if self._llm is None:
            raise RuntimeError("Llama model not available")
        resp = self._llm(prompt)
        if not isinstance(resp, dict):
            return str(resp)
        # llama-cpp-python returns an OpenAI-style completion payload, so the
        # generated text lives under choices[0]["text"], not a top-level key.
        choices = resp.get("choices")
        if choices:
            return choices[0].get("text") or ""
        # Keep accepting a flat {"text": ...} payload, but never return None.
        text = resp.get("text")
        return "" if text is None else text
class TransformersWrapper(BaseLLM):
    """Text-generation backend built on a Hugging Face causal language model.

    Loading is best-effort: if ``transformers``/``torch`` cannot be imported or
    the model fails to load, the failure is logged and ``generate`` raises
    ``RuntimeError`` when it is called.
    """

    def __init__(self, model_name_or_path: str):
        try:
            # Imported here so missing packages are handled like a load failure.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception:
            logger.exception("transformers not available or model failed to load")
            self.model = None
            self.tokenizer = None

    def generate(self, prompt: str) -> str:
        """Return up to 256 newly generated tokens for ``prompt`` as text.

        Raises:
            RuntimeError: if the model or tokenizer could not be loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Transformers model not available")
        # device_map="auto" may place the model on an accelerator, so the
        # tokenized inputs have to be moved to the same device.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=256)
        # The output sequence starts with the prompt tokens; decode only what
        # follows them so the caller gets the completion, not an echo.
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
def get_llm_for_domain(domain: str, prefer: str | None = None) -> BaseLLM:
    """Pick an LLM backend for ``domain`` from the ``models/<domain>`` directory.

    If that directory contains ``model.gguf``, the llama.cpp backend is used;
    otherwise the directory itself is handed to the transformers backend.
    ``prefer`` is accepted for interface compatibility but is not consulted.
    """
    # Imported locally so this function stands on its own.
    from pathlib import Path

    # A plain "models" / domain is str / str and raises TypeError; the base
    # has to be a Path for the "/" join (and .exists()) to work.
    model_dir = Path("models") / domain
    gguf = model_dir / "model.gguf"
    if gguf.exists():
        return LlamaCPPWrapper(str(gguf))
    # Fallback: let transformers try to load from the directory.
    return TransformersWrapper(str(model_dir))

View file

@ -1,16 +0,0 @@
import logging
from .models import AgentRun
from .langgraph_adapter import SimpleAgent
logger = logging.getLogger(__name__)
def run_agent(agent: SimpleAgent, prompt: str) -> str:
    """Execute ``agent`` on ``prompt`` and record the exchange as an AgentRun.

    Persistence is best-effort: a failure to store the record is logged and
    the agent's output is still returned.
    """
    response = agent.run(prompt)
    try:
        record = {
            "agent_name": agent.name,
            "input_text": prompt,
            "output_text": response,
        }
        AgentRun.objects.create(**record)
    except Exception:
        logger.exception("Failed to persist agent run via Django ORM")
    return response

View file

@ -0,0 +1,34 @@
# Generated by Django 5.2.8 on 2025-11-19 14:22
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
    """Initial schema migration: creates the Domain and Dataset models."""

    # Flags this as the first migration of the app.
    initial = True

    # Empty: nothing has to be applied before this migration.
    dependencies = [
    ]

    operations = [
        # Domain: a uniquely named record with an optional description.
        migrations.CreateModel(
            name='Domain',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=255, unique=True)),
                ('description', models.TextField(blank=True, default='')),
            ],
        ),
        # Dataset: belongs to exactly one Domain and is removed with it (CASCADE);
        # reachable from a Domain through the "datasets" reverse accessor.
        migrations.CreateModel(
            name='Dataset',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(max_length=255)),
                ('description', models.TextField(blank=True, default='')),
                # auto_now_add / auto_now: timestamps maintained by Django on create / save.
                ('created_at', models.DateTimeField(auto_now_add=True)),
                ('updated_at', models.DateTimeField(auto_now=True)),
                ('domain', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='datasets', to='domains.domain')),
            ],
        ),
    ]

View file

@ -1,3 +1,22 @@
from django.db import models
# Create your models here.
class Domain(models.Model):
    """A uniquely named domain record with an optional description."""

    name = models.CharField(max_length=255, unique=True)
    description = models.TextField(blank=True, default="")

    def __str__(self) -> str:  # pragma: no cover - trivial
        """Display the domain by its name."""
        return self.name
class Dataset(models.Model):
    """A named dataset owned by one Domain; deleted together with that Domain."""

    domain = models.ForeignKey(
        Domain,
        on_delete=models.CASCADE,
        related_name="datasets",
    )
    name = models.CharField(max_length=255)
    description = models.TextField(blank=True, default="")
    # Timestamps maintained by Django on creation and on every save.
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    def __str__(self) -> str:  # pragma: no cover - trivial
        """Display as "<dataset name> (<domain name>)"."""
        return f"{self.name} ({self.domain.name})"

View file

@ -0,0 +1,189 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "45d62106",
"metadata": {},
"source": [
"# Basic RAG Implementation with a local LLM"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "4c312410",
"metadata": {},
"outputs": [],
"source": [
"from gpt4all import GPT4All\n",
"from sentence_transformers import SentenceTransformer\n",
"from chromadb import PersistentClient\n",
"from docx import Document\n",
"\n",
"MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n",
"CONTEXT_SIZE = 8192\n",
"EMBEDDER = \"all-MiniLM-L6-v2\"\n",
"RAG_PATH = \"./build/rag_db\"\n",
"DOCS_PATH = \"C:\\\\Users\\\\nalab\\\\Downloads\\\\fNIRS_Glossary_Hardware.docx\""
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "90bae527",
"metadata": {},
"outputs": [],
"source": [
"\n",
"model = GPT4All(model_name = MODEL, n_ctx = CONTEXT_SIZE, allow_download = True)\n",
"embedder = SentenceTransformer(EMBEDDER)\n",
"client = PersistentClient(path = RAG_PATH)\n",
"\n",
"\n",
"class EmbeddingFunctionWrapper:\n",
"    \"\"\"Callable adapter that turns text into embeddings with the wrapped model.\"\"\"\n",
"\n",
"    def __init__(self, model):\n",
"        self.model = model\n",
"\n",
"    def name(self):\n",
"        return \"sentence-transformers\"\n",
"\n",
"    def __call__(self, input):\n",
"        # One string in -> one vector out; any other iterable -> a list of vectors.\n",
"        single = isinstance(input, str)\n",
"        texts = [input] if single else list(input)\n",
"        vectors = self.model.encode(texts).tolist()\n",
"        return vectors[0] if single else vectors\n",
"\n",
"embedding_fn = EmbeddingFunctionWrapper(embedder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34efbc7c",
"metadata": {},
"outputs": [],
"source": [
"# Read the .docx, keep its non-empty paragraphs, and cut the text into\n",
"# fixed-size character chunks that become the retrievable documents.\n",
"doc = Document(DOCS_PATH)\n",
"docx_content = \"\\n\".join([paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()])\n",
"chunk_size = 1000\n",
"documents = [docx_content[i:i+chunk_size] for i in range(0, len(docx_content), chunk_size) if docx_content[i:i+chunk_size].strip()]\n",
"embeddings = embedder.encode(documents).tolist()\n",
"collection = client.get_or_create_collection(\n",
"    name = \"knowledge_base\",\n",
"    embedding_function = embedding_fn,\n",
")\n",
"# upsert rather than add: the store is persistent and the ids are fixed (doc0, doc1, ...),\n",
"# so re-running this cell must overwrite the existing records instead of keeping stale ones.\n",
"collection.upsert(\n",
"    documents=documents,\n",
"    embeddings=embeddings,\n",
"    ids=[f\"doc{i}\" for i in range(len(documents))]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed2cc1ff",
"metadata": {},
"outputs": [],
"source": [
"def retrieve(query, top_k = 1):\n",
"    \"\"\"Return the top_k stored chunks closest to query.\"\"\"\n",
"    # Embed the query with the same model used at ingestion and always search by\n",
"    # vector, rather than trying query_texts first and hiding its failure behind\n",
"    # a broad except.\n",
"    query_embedding = embedder.encode([query]).tolist()[0]\n",
"    results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n",
"    return results[\"documents\"][0]\n",
"\n",
"def rag_answer(query, top_k = 1, max_context_length = 500):\n",
"    \"\"\"Answer query with the local model, using retrieved chunks as context.\n",
"\n",
"    top_k is forwarded to retrieve(); the joined context is cut to\n",
"    max_context_length characters (followed by an ellipsis) before prompting.\n",
"    \"\"\"\n",
"    retrieved_docs = retrieve(query, top_k)\n",
"    context = \"\\n\\n\".join(retrieved_docs)\n",
"    if len(context) > max_context_length:\n",
"        context = context[:max_context_length] + \"...\"\n",
"\n",
"    prompt = f\"\"\"\n",
"Use the context to answer the question.\n",
"Context:\n",
"{context}\n",
"Question:\n",
"{query}\n",
"Answer:\n",
"\"\"\"\n",
"    print(f\"Prompt length: {len(prompt)}\")\n",
"    return model.generate(prompt, max_tokens=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fa9fd10",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of documents: 68\n",
"Document lengths: [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 63]\n",
"Retrieved docs length: 1\n",
"Prompt length: 630\n"
]
}
],
"source": [
"query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
"print(f\"Number of documents: {len(documents)}\")\n",
"print(f\"Document lengths: {[len(doc) for doc in documents]}\")\n",
"retrieved = retrieve(query)\n",
"print(f\"Retrieved docs length: {len(retrieved)}\")\n",
"response = rag_answer(query)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "5a82353e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations.'"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}