189 lines
5.6 KiB
Text
189 lines
5.6 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "45d62106",
|
|
"metadata": {},
|
|
"source": [
|
|
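"# Basic RAG Implementation with a Local LLM\n",
"\n",
"This notebook splits a `.docx` glossary into fixed-size character chunks, embeds them with `all-MiniLM-L6-v2`, stores them in a persistent ChromaDB collection, retrieves the chunk most similar to a question, and answers the question with a local Llama 3 8B Instruct model run through GPT4All."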
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "4c312410",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from gpt4all import GPT4All\n",
|
|
"from sentence_transformers import SentenceTransformer\n",
|
|
"from chromadb import PersistentClient\n",
|
|
"from docx import Document\n",
|
|
"\n",
|
|
"# --- Generation model (GPT4All) ---\n",
"MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"  # GGUF model file name\n",
"CONTEXT_SIZE = 8_192  # passed to GPT4All as n_ctx\n",
"\n",
"# --- Retrieval ---\n",
"EMBEDDER = \"all-MiniLM-L6-v2\"  # SentenceTransformer used for chunks and queries\n",
"RAG_PATH = \"./build/rag_db\"  # directory of the persistent ChromaDB store\n",
"DOCS_PATH = \"./documents/fNIRS_Glossary_Hardware.docx\"  # document that gets indexed"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "90bae527",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"# Load the local generation model, the embedding model and the on-disk vector store.\n",
"model = GPT4All(\n",
"    model_name=MODEL,\n",
"    n_ctx=CONTEXT_SIZE,\n",
"    allow_download=True,  # let GPT4All download MODEL if it is not found locally\n",
"    device=\"cuda\",  # NOTE(review): hard-coded backend; presumably fails without a CUDA GPU\n",
")\n",
"embedder = SentenceTransformer(EMBEDDER)\n",
"client = PersistentClient(path=RAG_PATH)\n",
|
|
"\n",
|
|
"\n",
|
|
"class EmbeddingFunctionWrapper:\n",
"    \"\"\"Wrap a SentenceTransformer for use as a ChromaDB embedding function.\n",
"\n",
"    Calling the wrapper with an iterable of strings returns one embedding per\n",
"    string, as nested Python lists. Calling it with a single string returns\n",
"    only that string's embedding (a flat list).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model):\n",
"        # Expected to expose encode(list_of_texts) returning an object with .tolist().\n",
"        self.model = model\n",
"\n",
"    def name(self):\n",
"        \"\"\"Return the identifier string reported for this embedding function.\"\"\"\n",
"        return \"sentence-transformers\"\n",
"\n",
"    def __call__(self, input):\n",
"        # NOTE(review): keep the parameter named `input`; ChromaDB's\n",
"        # embedding-function interface uses that name.\n",
"        single = isinstance(input, str)\n",
"        batch = [input] if single else list(input)\n",
"        vectors = self.model.encode(batch).tolist()\n",
"        return vectors[0] if single else vectors\n",
|
|
"\n",
|
|
"# Passed to the ChromaDB collection as its embedding_function.\n",
"embedding_fn = EmbeddingFunctionWrapper(model=embedder)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "34efbc7c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc = Document(DOCS_PATH)\n",
"# Join the non-blank paragraphs, then cut the text into fixed-size character\n",
"# chunks (no overlap, so a chunk boundary can fall mid-word or mid-sentence).\n",
"docx_content = \"\\n\".join(paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip())\n",
"chunk_size = 1000\n",
"raw_chunks = (docx_content[start:start + chunk_size] for start in range(0, len(docx_content), chunk_size))\n",
"documents = [chunk for chunk in raw_chunks if chunk.strip()]\n",
"if not documents:\n",
"    raise ValueError(f\"No text found in {DOCS_PATH!r}; nothing to index.\")\n",
"ids = [f\"doc{i}\" for i in range(len(documents))]\n",
"\n",
"# Embed the chunks with the same SentenceTransformer that retrieve() uses for queries.\n",
"embeddings = embedder.encode(documents).tolist()\n",
"collection = client.get_or_create_collection(\n",
"    name=\"knowledge_base\",\n",
"    embedding_function=embedding_fn,\n",
")\n",
"\n",
"# The store is persistent and the ids are positional, so keep re-runs idempotent:\n",
"# upsert() overwrites existing ids (add() ignores ids that already exist, which would\n",
"# keep stale text after the document changes), and ids left over from a longer\n",
"# previous version of the document are deleted.\n",
"stale_ids = set(collection.get()[\"ids\"]) - set(ids)\n",
"if stale_ids:\n",
"    collection.delete(ids=sorted(stale_ids))\n",
"collection.upsert(\n",
"    documents=documents,\n",
"    embeddings=embeddings,\n",
"    ids=ids,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "ed2cc1ff",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def retrieve(query: str, top_k: int = 1) -> list[str]:\n",
"    \"\"\"Return the text of the `top_k` stored chunks most similar to `query`.\n",
"\n",
"    The query is embedded with `embedder`, the same SentenceTransformer that\n",
"    produced the stored chunk embeddings, and the collection is searched with\n",
"    that vector.\n",
"\n",
"    Args:\n",
"        query: Question text to search for.\n",
"        top_k: Maximum number of chunks to return.\n",
"\n",
"    Returns:\n",
"        The documents ChromaDB returns for this single query, nearest first.\n",
"    \"\"\"\n",
"    # Query by vector only. The embedding is computed once here instead of a\n",
"    # second time inside the collection's embedding function for a text query,\n",
"    # and errors surface instead of being hidden by a catch-all fallback.\n",
"    query_embedding = embedder.encode([query]).tolist()[0]\n",
"    results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n",
"    return results[\"documents\"][0]\n",
|
|
"\n",
|
|
"def rag_answer(\n",
"    query: str,\n",
"    top_k: int = 1,\n",
"    max_context_length: int | None = 500,\n",
"    max_tokens: int = 200,\n",
") -> str:\n",
"    \"\"\"Answer `query` with the local model, using retrieved chunks as context.\n",
"\n",
"    Args:\n",
"        query: Question to answer.\n",
"        top_k: Number of chunks to retrieve and join into the context.\n",
"        max_context_length: Maximum number of context characters kept; longer\n",
"            context is cut and \"...\" is appended. ``None`` disables the cut.\n",
"        max_tokens: Generation limit passed to ``model.generate``.\n",
"\n",
"    Returns:\n",
"        The text produced by ``model.generate``.\n",
"    \"\"\"\n",
"    retrieved_docs = retrieve(query, top_k=top_k)\n",
"    context = \"\\n\\n\".join(retrieved_docs)\n",
"    # NOTE(review): the default of 500 is half of the 1000-character chunks built\n",
"    # at indexing time, so by default only the first half of the retrieved chunk\n",
"    # reaches the model; consider raising it (CONTEXT_SIZE leaves plenty of room).\n",
"    if max_context_length is not None and len(context) > max_context_length:\n",
"        context = context[:max_context_length] + \"...\"\n",
"\n",
"    prompt = f\"\"\"\n",
"Use the context to answer the question.\n",
"Context:\n",
"{context}\n",
"Question:\n",
"{query}\n",
"Answer:\n",
"\"\"\"\n",
"    print(f\"Prompt length: {len(prompt)}\")\n",
"    return model.generate(prompt, max_tokens=max_tokens)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "6fa9fd10",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Number of documents: 68\n",
|
|
"Document lengths: [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 63]\n",
|
|
"Retrieved docs length: 1\n",
|
|
"Prompt length: 627\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
"\n",
"# Inspect the index: how many chunks there are and how long each one is.\n",
"print(f\"Number of documents: {len(documents)}\")\n",
"print(f\"Document lengths: {[len(doc) for doc in documents]}\")\n",
"\n",
"# Retrieval on its own first, then the full retrieve-and-generate pipeline.\n",
"retrieved = retrieve(query=query)\n",
"print(f\"Retrieved docs length: {len(retrieved)}\")\n",
"response = rag_answer(query=query)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "5a82353e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations.'"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Display the answer that rag_answer(query) returned in the previous cell.\n",
"response"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.13.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|