{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45d62106",
   "metadata": {},
   "source": [
    "# Basic RAG Implementation with a local LLM\n",
    "\n",
    "Retrieval-augmented generation (RAG) over a single Word document:\n",
    "\n",
    "1. Split the document text into fixed-size character chunks and embed them with a SentenceTransformer model.\n",
    "2. Store the chunks and their embeddings in a persistent Chroma collection.\n",
    "3. For a question, retrieve the most similar chunk(s) and pass them as context to a local Llama 3 model run through GPT4All."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c312410",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpt4all import GPT4All\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from chromadb import PersistentClient\n",
    "from docx import Document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3a9c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n",
    "CONTEXT_SIZE = 8192  # n_ctx passed to GPT4All\n",
    "DEVICE = \"cuda\"  # device passed to GPT4All\n",
    "EMBEDDER = \"all-MiniLM-L6-v2\"\n",
    "RAG_PATH = \"./build/rag_db\"\n",
    "DOCS_PATH = \"./documents/fNIRS_Glossary_Hardware.docx\"\n",
    "COLLECTION_NAME = \"knowledge_base\"\n",
    "\n",
    "CHUNK_SIZE = 1000  # characters per chunk\n",
    "TOP_K = 1  # chunks retrieved per question\n",
    "# Keep this >= TOP_K * CHUNK_SIZE so retrieved chunks reach the model uncut.\n",
    "MAX_CONTEXT_CHARS = 4000\n",
    "MAX_TOKENS = 200  # token budget for the generated answer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8e41d02",
   "metadata": {},
   "source": [
    "## Load models and vector store\n",
    "\n",
    "GPT4All downloads the model on first use (`allow_download=True`). The Chroma database is stored under `RAG_PATH` and persists between runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90bae527",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GPT4All(model_name=MODEL, n_ctx=CONTEXT_SIZE, allow_download=True, device=DEVICE)\n",
    "embedder = SentenceTransformer(EMBEDDER)\n",
    "client = PersistentClient(path=RAG_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d27e93",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EmbeddingFunctionWrapper:\n",
    "    \"\"\"Adapt a SentenceTransformer model to Chroma's embedding-function interface.\n",
    "\n",
    "    Called with a list (or other iterable) of texts, it returns one embedding\n",
    "    (a list of floats) per text; called with a single string, it returns that\n",
    "    string's embedding on its own.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "\n",
    "    def name(self):\n",
    "        # Name reported to Chroma for this embedding function.\n",
    "        return \"sentence-transformers\"\n",
    "\n",
    "    def __call__(self, input):\n",
    "        # Chroma expects this parameter to be called `input`, so keep the name\n",
    "        # even though it shadows the builtin.\n",
    "        if isinstance(input, str):\n",
    "            return self.model.encode([input]).tolist()[0]\n",
    "        return self.model.encode(list(input)).tolist()\n",
    "\n",
    "\n",
    "embedding_fn = EmbeddingFunctionWrapper(embedder)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a0f614",
   "metadata": {},
   "source": [
    "## Build the knowledge base\n",
    "\n",
    "Non-empty paragraphs of the document are joined with newlines, cut into `CHUNK_SIZE`-character chunks, embedded, and written to the collection. Re-running this section replaces the stored chunks instead of duplicating or keeping stale ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b7c845",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_text(text: str, chunk_size: int = CHUNK_SIZE) -> list[str]:\n",
    "    \"\"\"Split `text` into consecutive, non-overlapping `chunk_size`-character chunks.\n",
    "\n",
    "    Chunks that are empty or whitespace-only are dropped. The last chunk may be\n",
    "    shorter than `chunk_size`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `chunk_size` is not positive.\n",
    "    \"\"\"\n",
    "    if chunk_size <= 0:\n",
    "        raise ValueError(f\"chunk_size must be positive, got {chunk_size}\")\n",
    "    windows = (text[start:start + chunk_size] for start in range(0, len(text), chunk_size))\n",
    "    return [window for window in windows if window.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34efbc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc = Document(DOCS_PATH)\n",
    "docx_content = \"\\n\".join(paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip())\n",
    "documents = chunk_text(docx_content, CHUNK_SIZE)\n",
    "assert documents, f\"No text extracted from {DOCS_PATH}\"\n",
    "embeddings = embedder.encode(documents).tolist()\n",
    "chunk_ids = [f\"doc{i}\" for i in range(len(documents))]\n",
    "\n",
    "collection = client.get_or_create_collection(\n",
    "    name=COLLECTION_NAME,\n",
    "    embedding_function=embedding_fn,\n",
    ")\n",
    "# The collection persists on disk across runs: upsert (rather than add) so a\n",
    "# re-run refreshes existing chunks, then drop chunks left behind by a longer\n",
    "# earlier version of the document.\n",
    "collection.upsert(documents=documents, embeddings=embeddings, ids=chunk_ids)\n",
    "stale_ids = sorted(set(collection.get()[\"ids\"]) - set(chunk_ids))\n",
    "if stale_ids:\n",
    "    collection.delete(ids=stale_ids)\n",
    "assert collection.count() == len(documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f41c6a37",
   "metadata": {},
   "source": [
    "## Retrieval and generation\n",
    "\n",
    "`retrieve` returns the stored chunks closest to the question. `rag_answer` pastes them into a prompt and asks the local model to answer from that context only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed2cc1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT_TEMPLATE = (\n",
    "    \"Use only the context below to answer the question. \"\n",
    "    \"If the context does not contain the answer, say that you do not know.\\n\\n\"\n",
    "    \"Context:\\n{context}\\n\\n\"\n",
    "    \"Question:\\n{query}\\n\\n\"\n",
    "    \"Answer:\"\n",
    ")\n",
    "\n",
    "\n",
    "def retrieve(query, top_k=1):\n",
    "    \"\"\"Return the texts of the `top_k` stored chunks most similar to `query`.\n",
    "\n",
    "    The query is embedded with the same SentenceTransformer that embedded the\n",
    "    stored chunks, so the vectors are directly comparable.\n",
    "    \"\"\"\n",
    "    query_embedding = embedder.encode([query]).tolist()[0]\n",
    "    results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n",
    "    # One query embedding was sent, so its hits are the first result list.\n",
    "    return results[\"documents\"][0]\n",
    "\n",
    "\n",
    "def rag_answer(\n",
    "    query,\n",
    "    top_k=TOP_K,\n",
    "    max_context_chars=MAX_CONTEXT_CHARS,\n",
    "    max_tokens=MAX_TOKENS,\n",
    "    verbose=False,\n",
    "):\n",
    "    \"\"\"Answer `query` with the local LLM, using retrieved chunks as context.\n",
    "\n",
    "    Args:\n",
    "        query: The question to answer.\n",
    "        top_k: Number of chunks to retrieve as context.\n",
    "        max_context_chars: Retrieved text beyond this many characters is cut\n",
    "            off (and marked with \"...\") before it is put into the prompt.\n",
    "        max_tokens: Maximum number of tokens to generate.\n",
    "        verbose: If True, print the prompt length in characters.\n",
    "\n",
    "    Returns:\n",
    "        The generated answer text.\n",
    "    \"\"\"\n",
    "    context = \"\\n\\n\".join(retrieve(query, top_k=top_k))\n",
    "    if len(context) > max_context_chars:\n",
    "        context = context[:max_context_chars] + \"...\"\n",
    "    prompt = PROMPT_TEMPLATE.format(context=context, query=query)\n",
    "    if verbose:\n",
    "        print(f\"Prompt length: {len(prompt)}\")\n",
    "    # chat_session() makes GPT4All wrap the prompt in the model's chat template;\n",
    "    # outside a session the raw prompt is sent to the instruct model untemplated.\n",
    "    with model.chat_session():\n",
    "        return model.generate(prompt, max_tokens=max_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a6e2d59",
   "metadata": {},
   "source": [
    "## Example query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa9fd10",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
    "chunk_lengths = [len(chunk) for chunk in documents]\n",
    "print(f\"Number of documents: {len(documents)}\")\n",
    "print(f\"Chunk lengths: min {min(chunk_lengths)}, max {max(chunk_lengths)}\")\n",
    "retrieved = retrieve(query, top_k=TOP_K)\n",
    "print(f\"Retrieved docs length: {len(retrieved)}\")\n",
    "response = rag_answer(query, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a82353e",
   "metadata": {},
   "outputs": [],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c9b8f70",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "- Chunks are fixed-size character windows without overlap, so a glossary entry can be split across two chunks; with `TOP_K = 1` only one part of it may reach the model.\n",
    "- The answer is capped at `MAX_TOKENS` tokens and can be cut off mid-sentence; raise it if answers look truncated."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}