353 lines
10 KiB
Text
353 lines
10 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "45d62106",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Basic RAG Implementation with a local LLM"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "4c312410",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from gpt4all import GPT4All\n",
|
|
"from sentence_transformers import SentenceTransformer\n",
|
|
"from chromadb import PersistentClient\n",
|
|
"from docx import Document\n",
|
|
"\n",
|
|
"MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n",
|
|
"CONTEXT_SIZE = 8192\n",
|
|
"EMBEDDER = \"all-MiniLM-L6-v2\"\n",
|
|
"RAG_PATH = \"./build/rag_db\"\n",
|
|
"DOCS_PATH = \"./build/documents/fNIRS_Glossary_Hardware.docx\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "90bae527",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "104f2001edc34aa5aff82734b3388041",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"modules.json: 0%| | 0.00/349 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"c:\\Users\\nalab\\University\\vxn217\\.venv\\Lib\\site-packages\\huggingface_hub\\file_download.py:143: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\nalab\\.cache\\huggingface\\hub\\models--sentence-transformers--all-MiniLM-L6-v2. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
|
|
"To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
|
|
" warnings.warn(message)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "7bf16ea40d964be19217eadc81f5674e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"config_sentence_transformers.json: 0%| | 0.00/116 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "32962e77048440908808689c5dc386e0",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"README.md: 0.00B [00:00, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "bf08ffecdfa94eaca2841e2b6b88eea5",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"sentence_bert_config.json: 0%| | 0.00/53.0 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "6079ecdd0e464623a1d7e20999213213",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"config.json: 0%| | 0.00/612 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "60b2de9bec5c4237827d910660389db1",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"model.safetensors: 0%| | 0.00/90.9M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "05f352a112fb4ccd8968a7ffe335c80f",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"tokenizer_config.json: 0%| | 0.00/350 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "b5f7aa6547c0455eb55863ad8ec6c84f",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"vocab.txt: 0.00B [00:00, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "43605d598a604c10a85effee5869939e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"tokenizer.json: 0.00B [00:00, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "bd1a21fcccee4a92a50dcca08c858565",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"special_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "6d409c5032674774bfe157e1ec21eb3a",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"config.json: 0%| | 0.00/190 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"model = GPT4All(model_name = MODEL, n_ctx = CONTEXT_SIZE, allow_download = True, device = \"cuda\")\n",
|
|
"embedder = SentenceTransformer(EMBEDDER)\n",
|
|
"client = PersistentClient(path = RAG_PATH)\n",
|
|
"\n",
|
|
"\n",
|
|
"class EmbeddingFunctionWrapper:\n",
|
|
" def __init__(self, model):\n",
|
|
" self.model = model\n",
|
|
"\n",
|
|
" def name(self):\n",
|
|
" return \"sentence-transformers\"\n",
|
|
"\n",
|
|
" def __call__(self, input):\n",
|
|
" if isinstance(input, str):\n",
|
|
" texts = [input]\n",
|
|
" embs = self.model.encode(texts).tolist()\n",
|
|
" return embs[0]\n",
|
|
" else:\n",
|
|
" texts = list(input)\n",
|
|
" return self.model.encode(texts).tolist()\n",
|
|
"\n",
|
|
"embedding_fn = EmbeddingFunctionWrapper(embedder)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "34efbc7c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc = Document(DOCS_PATH)\n",
|
|
"docx_content = \"\\n\".join([paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()])\n",
|
|
"chunk_size = 1000\n",
|
|
"documents = [docx_content[i:i+chunk_size] for i in range(0, len(docx_content), chunk_size) if docx_content[i:i+chunk_size].strip()]\n",
|
|
"embeddings = embedder.encode(documents).tolist()\n",
|
|
"collection = client.get_or_create_collection(\n",
|
|
" name = \"knowledge_base\",\n",
|
|
" embedding_function = embedding_fn,\n",
|
|
")\n",
|
|
"collection.add(\n",
|
|
" documents=documents,\n",
|
|
" embeddings=embeddings,\n",
|
|
" ids=[f\"doc{i}\" for i in range(len(documents))]\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "ed2cc1ff",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def retrieve(query, top_k = 1):\n",
|
|
" query_embedding = embedder.encode([query]).tolist()[0]\n",
|
|
" try:\n",
|
|
" results = collection.query(query_texts=[query], n_results=top_k)\n",
|
|
" return results[\"documents\"][0]\n",
|
|
" except Exception:\n",
|
|
" results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n",
|
|
" return results[\"documents\"][0]\n",
|
|
"\n",
|
|
"def rag_answer(query):\n",
|
|
" retrieved_docs = retrieve(query)\n",
|
|
" context = \"\\n\\n\".join(retrieved_docs)\n",
|
|
" max_context_length = 500\n",
|
|
" if len(context) > max_context_length:\n",
|
|
" context = context[:max_context_length] + \"...\"\n",
|
|
"\n",
|
|
" prompt = f\"\"\"\n",
|
|
"Use the context to answer the question.\n",
|
|
"Context:\n",
|
|
"{context}\n",
|
|
"Question:\n",
|
|
"{query}\n",
|
|
"Answer:\n",
|
|
"\"\"\"\n",
|
|
" print(f\"Prompt length: {len(prompt)}\")\n",
|
|
" return model.generate(prompt, max_tokens=200)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "6fa9fd10",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Number of documents: 68\n",
|
|
"Document lengths: [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 63]\n",
|
|
"Retrieved docs length: 1\n",
|
|
"Prompt length: 627\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
|
|
"print(f\"Number of documents: {len(documents)}\")\n",
|
|
"print(f\"Document lengths: {[len(doc) for doc in documents]}\")\n",
|
|
"retrieved = retrieve(query)\n",
|
|
"print(f\"Retrieved docs length: {len(retrieved)}\")\n",
|
|
"response = rag_answer(query)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "5a82353e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations.'"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"response"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.13.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|