From f9073d53b602ee93bc203c6a2678d7beefee5c05 Mon Sep 17 00:00:00 2001 From: Viswamedha Nalabotu Date: Sun, 22 Mar 2026 08:19:57 +0000 Subject: [PATCH] Added llm api basic auth and disabled docs --- .env.example | 2 ++ .env.template | 2 ++ config/settings.py | 4 +++- gpu_server.py | 24 ++++++++++++++++++------ 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/.env.example b/.env.example index 25e80cf..92ad612 100644 --- a/.env.example +++ b/.env.example @@ -32,3 +32,5 @@ POSTGRES_PORT=5432 INFERENCE_PROTOCOL=http INFERENCE_HOST=fyp-inference-dev INFERENCE_PORT=8001 +INFERENCE_USERNAME=admin +INFERENCE_PASSWORD=changeme diff --git a/.env.template b/.env.template index 08c961f..56792e7 100644 --- a/.env.template +++ b/.env.template @@ -34,6 +34,8 @@ POSTGRES_PORT=5432 INFERENCE_PROTOCOL=http INFERENCE_HOST=localhost INFERENCE_PORT=8001 +INFERENCE_USERNAME=admin +INFERENCE_PASSWORD=change_this_to_a_secure_password # Production YAML (Ignore if you're setting up locally) FYP_DJANGO_IMAGE=dynavera-django:prod diff --git a/config/settings.py b/config/settings.py index fb208a2..7be80b1 100644 --- a/config/settings.py +++ b/config/settings.py @@ -28,7 +28,9 @@ DJANGO_CELERY_BROKER_URL = os.getenv('DJANGO_CELERY_BROKER_URL', 'redis://localh INFERENCE_HOST = os.getenv('INFERENCE_HOST', 'localhost') INFERENCE_PORT = os.getenv('INFERENCE_PORT', '8001') INFERENCE_PROTOCOL = os.getenv('INFERENCE_PROTOCOL', 'http') -INFERENCE_URL = f"{INFERENCE_PROTOCOL}://{INFERENCE_HOST}:{INFERENCE_PORT}" +INFERENCE_USERNAME = os.getenv('INFERENCE_USERNAME', 'admin') +INFERENCE_PASSWORD = os.getenv('INFERENCE_PASSWORD', 'changeme') +INFERENCE_URL = f"{INFERENCE_PROTOCOL}://{INFERENCE_USERNAME}:{INFERENCE_PASSWORD}@{INFERENCE_HOST}:{INFERENCE_PORT}" INFERENCE_SEMANTIC_CHUNK_ENDPOINT = f"{INFERENCE_URL}/v1/semantic-chunk" INFERENCE_EMBEDDINGS_ENDPOINT = f"{INFERENCE_URL}/v1/embeddings" INFERENCE_CHAT_COMPLETIONS_ENDPOINT = f"{INFERENCE_URL}/v1/chat/completions" diff --git a/gpu_server.py b/gpu_server.py index 6eb6623..fe1c117 100644 --- a/gpu_server.py +++ b/gpu_server.py @@ -8,8 +8,10 @@ from typing import Dict, Any import numpy as np from torch import cuda, no_grad, Tensor import torch.nn.functional as F -from fastapi import FastAPI, Request, HTTPException +import secrets +from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.responses import StreamingResponse +from fastapi.security import HTTPBasic, HTTPBasicCredentials from llama_cpp import Llama from sentence_transformers import SentenceTransformer @@ -65,10 +67,20 @@ async def lifespan(app: FastAPI): if cuda.is_available(): cuda.empty_cache() -app = FastAPI(title="Agentic GPU Node", lifespan=lifespan) +app = FastAPI(title="Agentic GPU Node", lifespan=lifespan, docs_url=None, redoc_url=None, openapi_url=None) + +_security = HTTPBasic() +_API_USER = os.getenv("INFERENCE_USERNAME", "admin") +_API_PASS = os.getenv("INFERENCE_PASSWORD", "changeme") + +def require_auth(credentials: HTTPBasicCredentials = Depends(_security)): + valid_user = secrets.compare_digest(credentials.username.encode(), _API_USER.encode()) + valid_pass = secrets.compare_digest(credentials.password.encode(), _API_PASS.encode()) + if not (valid_user and valid_pass): + raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"}) -@app.get("/health") +@app.get("/health", dependencies=[Depends(require_auth)]) async def health(): return { "status": "ok", @@ -85,7 +97,7 @@ def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor: return F.normalize(embeddings, p=2, dim=1) -@app.post("/v1/embeddings") +@app.post("/v1/embeddings", dependencies=[Depends(require_auth)]) async def embeddings(request: Request): data = await request.json() input_data = data.get("input", "") @@ -148,7 +160,7 @@ async def embeddings(request: Request): }, } -@app.post("/v1/semantic-chunk") +@app.post("/v1/semantic-chunk", dependencies=[Depends(require_auth)]) async def semantic_chunk(request: Request): data = await request.json() raw_text = data.get("text", "") @@ -208,7 +220,7 @@ async def semantic_chunk(request: Request): result = await loop.run_in_executor(None, _chunk_and_embed) return result -@app.post("/v1/chat/completions") +@app.post("/v1/chat/completions", dependencies=[Depends(require_auth)]) async def chat_completions(request: Request): try: data = await request.json()