Dynavera/notebooks/local-model-rag-implementation.ipynb

190 lines
5.6 KiB
Text
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"id": "45d62106",
"metadata": {},
"source": [
"# Basic RAG Implementation with a local LLM"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "4c312410",
"metadata": {},
"outputs": [],
"source": [
"from gpt4all import GPT4All\n",
"from sentence_transformers import SentenceTransformer\n",
"from chromadb import PersistentClient\n",
"from docx import Document\n",
"\n",
"MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n",
"CONTEXT_SIZE = 8192\n",
"EMBEDDER = \"all-MiniLM-L6-v2\"\n",
"RAG_PATH = \"./build/rag_db\"\n",
"DOCS_PATH = \"C:\\\\Users\\\\nalab\\\\Downloads\\\\fNIRS_Glossary_Hardware.docx\""
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "90bae527",
"metadata": {},
"outputs": [],
"source": [
"\n",
"model = GPT4All(model_name = MODEL, n_ctx = CONTEXT_SIZE, allow_download = True)\n",
"embedder = SentenceTransformer(EMBEDDER)\n",
"client = PersistentClient(path = RAG_PATH)\n",
"\n",
"\n",
"class EmbeddingFunctionWrapper:\n",
" def __init__(self, model):\n",
" self.model = model\n",
"\n",
" def name(self):\n",
" return \"sentence-transformers\"\n",
"\n",
" def __call__(self, input):\n",
" if isinstance(input, str):\n",
" texts = [input]\n",
" embs = self.model.encode(texts).tolist()\n",
" return embs[0]\n",
" else:\n",
" texts = list(input)\n",
" return self.model.encode(texts).tolist()\n",
"\n",
"embedding_fn = EmbeddingFunctionWrapper(embedder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34efbc7c",
"metadata": {},
"outputs": [],
"source": [
"doc = Document(DOCS_PATH)\n",
"docx_content = \"\\n\".join([paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()])\n",
"chunk_size = 1000\n",
"documents = [docx_content[i:i+chunk_size] for i in range(0, len(docx_content), chunk_size) if docx_content[i:i+chunk_size].strip()]\n",
"embeddings = embedder.encode(documents).tolist()\n",
"collection = client.get_or_create_collection(\n",
" name = \"knowledge_base\",\n",
" embedding_function = embedding_fn,\n",
")\n",
"collection.add(\n",
" documents=documents,\n",
" embeddings=embeddings,\n",
" ids=[f\"doc{i}\" for i in range(len(documents))]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed2cc1ff",
"metadata": {},
"outputs": [],
"source": [
"def retrieve(query, top_k = 1):\n",
" query_embedding = embedder.encode([query]).tolist()[0]\n",
" try:\n",
" results = collection.query(query_texts=[query], n_results=top_k)\n",
" return results[\"documents\"][0]\n",
" except Exception:\n",
" results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n",
" return results[\"documents\"][0]\n",
"\n",
"def rag_answer(query):\n",
" retrieved_docs = retrieve(query)\n",
" context = \"\\n\\n\".join(retrieved_docs)\n",
" max_context_length = 500\n",
" if len(context) > max_context_length:\n",
" context = context[:max_context_length] + \"...\"\n",
"\n",
" prompt = f\"\"\"\n",
"Use the context to answer the question.\n",
"Context:\n",
"{context}\n",
"Question:\n",
"{query}\n",
"Answer:\n",
"\"\"\"\n",
" print(f\"Prompt length: {len(prompt)}\")\n",
" return model.generate(prompt, max_tokens=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fa9fd10",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of documents: 68\n",
"Document lengths: [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 63]\n",
"Retrieved docs length: 1\n",
"Prompt length: 630\n"
]
}
],
"source": [
"query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
"print(f\"Number of documents: {len(documents)}\")\n",
"print(f\"Document lengths: {[len(doc) for doc in documents]}\")\n",
"retrieved = retrieve(query)\n",
"print(f\"Retrieved docs length: {len(retrieved)}\")\n",
"response = rag_answer(query)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "5a82353e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations.'"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}