Dynavera/gpu_server.py
Viswamedha Nalabotu 95fc6dccf8 Added extra logging
2026-03-11 16:12:05 +00:00

268 lines
No EOL
9.2 KiB
Python

import logging
import os
import json
from contextlib import asynccontextmanager
from typing import Dict, Any
import numpy as np
from torch import cuda, no_grad, Tensor
import torch.nn.functional as F
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
# One named logger shared by every handler in this module.
logger = logging.getLogger("gpu-node")

# Hugging Face id of the sentence-embedding model loaded at startup.
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
# GGUF weights for the chat model; overridable through the LLM_MODEL_PATH env var.
LLM_MODEL_PATH = os.getenv(
    "LLM_MODEL_PATH",
    "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
)

# Process-wide model registry: the lifespan hook stores "embed_model" and
# (when the GGUF file exists) "llm" here; request handlers read from it.
state: Dict[str, Any] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models on startup and release them on shutdown.

    Startup:
        * Chooses CUDA when available, otherwise CPU (with a warning).
        * Loads the sentence-transformers model into ``state["embed_model"]``.
        * Loads the llama.cpp model into ``state["llm"]`` only if
          ``LLM_MODEL_PATH`` exists. A missing file is logged but is not fatal,
          so the node can still serve embeddings; ``/health`` then reports
          ``llm_ready: false``.
        * Any exception while loading is logged with its traceback, ``state``
          is cleared, and the exception is re-raised, which aborts startup.

    Shutdown:
        Clears ``state`` and empties the CUDA cache. This runs in ``finally``,
        so it also happens when the application exits through an exception.

    Args:
        app: the FastAPI application. It is required by the lifespan protocol
            but not used here.
    """
    device = "cuda" if cuda.is_available() else "cpu"
    logger.info("--- Initializing GPU Node on %s ---", device)
    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")
    try:
        logger.info("Loading Embedding Model: %s", EMBED_MODEL_NAME)
        state["embed_model"] = SentenceTransformer(
            EMBED_MODEL_NAME,
            trust_remote_code=True,
            device=device,
        )
        if not os.path.exists(LLM_MODEL_PATH):
            logger.error("LLM File not found at %s", LLM_MODEL_PATH)
        else:
            logger.info("Loading LLM: %s", LLM_MODEL_PATH)
            state["llm"] = Llama(
                model_path=LLM_MODEL_PATH,
                n_gpu_layers=-1,  # llama-cpp-python convention: -1 offloads all layers
                n_ctx=8192,
                n_batch=512,
                verbose=False,
            )
        logger.info("--- GPU Node Ready ---")
    except Exception as exc:
        # logger.exception keeps the traceback, which the old error+raise lost.
        logger.exception("Failed to load models: %s", exc)
        # Drop any partially loaded model so it is not kept alive by `state`.
        state.clear()
        raise
    try:
        yield
    finally:
        # Cleanup must run even if the app shuts down via an exception thrown
        # into this generator at `yield`.
        state.clear()
        if cuda.is_available():
            cuda.empty_cache()
# ASGI application; model loading/teardown is delegated to the `lifespan` hook.
app = FastAPI(lifespan=lifespan, title="Agentic GPU Node")
@app.get("/health")
async def health():
    """Report liveness plus per-model readiness.

    ``status`` is always ``"ok"`` while the process answers requests. Each
    flag is true only when the corresponding entry in ``state`` is populated.
    """
    embed_loaded = state.get("embed_model") is not None
    llm_loaded = state.get("llm") is not None
    return {"status": "ok", "embedding_ready": embed_loaded, "llm_ready": llm_loaded}
def _resolve_target_dimensions(payload: Dict[str, Any]) -> int:
    """Read and validate ``payload["target_dimensions"]``.

    Accepted values:
        * an int;
        * an integral float such as ``768.0``;
        * a string holding an integer such as ``"768"``.

    Rejected values:
        * booleans (``int(True)`` would silently give 1);
        * fractional or non-finite floats (``int(3.7)`` would silently give 3,
          and ``int(inf)`` raises ``OverflowError``);
        * anything else that ``int()`` cannot parse.

    Args:
        payload: the decoded JSON request body.

    Returns:
        The target embedding width, always > 0.

    Raises:
        HTTPException: 400 if the value is missing, empty, not an integer, or <= 0.
    """
    raw_target = payload.get("target_dimensions")
    if raw_target in (None, ""):
        raise HTTPException(status_code=400, detail="'target_dimensions' is required and must be a positive integer")
    # bool is a subclass of int, and int() truncates floats, so screen both
    # before converting. float("nan"/"inf").is_integer() is False.
    if isinstance(raw_target, bool) or (isinstance(raw_target, float) and not raw_target.is_integer()):
        logger.warning("Invalid target_dimensions value: %s", raw_target)
        raise HTTPException(status_code=400, detail="'target_dimensions' must be an integer")
    try:
        target = int(raw_target)
    except (TypeError, ValueError) as exc:
        logger.warning("Invalid target_dimensions value: %s", raw_target)
        raise HTTPException(status_code=400, detail="'target_dimensions' must be an integer") from exc
    if target <= 0:
        logger.warning("Non-positive target_dimensions value: %s", target)
        raise HTTPException(status_code=400, detail="'target_dimensions' must be > 0")
    return target
def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor:
    """Resize embeddings to ``target_dimensions`` on the last axis, then L2-normalise.

    Vectors narrower than the target are right-padded with zeros. Wider ones
    keep only their first ``target_dimensions`` components. Every vector is
    then scaled to unit L2 norm. ``F.normalize`` divides by
    ``max(norm, eps)``, so all-zero vectors stay zero instead of becoming NaN.

    Args:
        embeddings: tensor whose last axis holds the embedding components,
            e.g. ``(batch, dim)`` or a single ``(dim,)`` vector.
        target_dimensions: desired size of the last axis.

    Returns:
        A tensor with the same leading shape and a last axis of size
        ``target_dimensions``.
    """
    # Work on the last axis so 1-D vectors are handled too. For the usual
    # (batch, dim) input this is the same as the original axis 1.
    curr_dim = embeddings.shape[-1]
    if curr_dim < target_dimensions:
        # F.pad's tuple addresses the last dimension first: (left, right).
        embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0)
    elif curr_dim > target_dimensions:
        embeddings = embeddings[..., :target_dimensions]
    return F.normalize(embeddings, p=2, dim=-1)
@app.post("/v1/embeddings")
async def embeddings(request: Request):
    """OpenAI-style embeddings endpoint backed by the loaded embedding model.

    Request JSON:
        input: a string, or a list whose items are converted with ``str``.
            In the list form, entries that are blank after stripping are
            skipped.
        target_dimensions: required positive integer. Vectors are zero-padded
            or truncated to this width, then L2-normalised.

    Texts not already starting with ``search_`` get the ``search_query: ``
    prefix before encoding.

    Returns:
        An OpenAI-style ``list`` object. Each item's ``index`` is the position
        of its text in the *original* ``input`` list, so indices stay correct
        when blank entries are skipped. ``usage`` counts whitespace-separated
        words, not model tokens.

    Raises:
        HTTPException 400: malformed JSON, a body that is not an object, a bad
            ``input`` type, or an invalid ``target_dimensions``.
        HTTPException 503: the embedding model is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        logger.warning("/v1/embeddings invalid JSON payload: %s", exc)
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        logger.warning("/v1/embeddings body is not a JSON object: %s", type(data).__name__)
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    input_data = data.get("input", "")
    input_kind = type(input_data).__name__
    if isinstance(input_data, list):
        input_count = len(input_data)
    elif isinstance(input_data, str):
        input_count = 1
    else:
        input_count = 0
    logger.info("/v1/embeddings request received: input_kind=%s input_count=%s", input_kind, input_count)
    target_dimensions = _resolve_target_dimensions(data)
    logger.info("/v1/embeddings resolved target_dimensions=%s", target_dimensions)

    # (position in the request's input, text) pairs.
    if isinstance(input_data, str):
        indexed_inputs = [(0, input_data)]
    elif isinstance(input_data, list):
        as_text = [str(item) for item in input_data]
        # Blank entries are dropped, but each kept text remembers its original
        # position so the response "index" still points at the right input.
        indexed_inputs = [(pos, text) for pos, text in enumerate(as_text) if text.strip()]
    else:
        logger.warning("/v1/embeddings bad input type: %s", input_kind)
        raise HTTPException(status_code=400, detail="'input' must be a string or list of strings")

    if not indexed_inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    texts = [text for _, text in indexed_inputs]
    # Texts already starting with "search_" (e.g. a caller-supplied
    # "search_document: " prefix) are sent as-is; the rest are treated as queries.
    prefixed_inputs = [
        text if text.startswith("search_") else f"search_query: {text}"
        for text in texts
    ]
    with no_grad():
        vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
        vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions)
    vector_list = vectors.cpu().tolist()

    # Whitespace word count: a cheap approximation, not the tokenizer's count.
    word_count = sum(len(text.split()) for text in texts)
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": pos,
                "embedding": embedding,
            }
            for (pos, _), embedding in zip(indexed_inputs, vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {
            "prompt_tokens": word_count,
            "total_tokens": word_count,
        },
    }
def _resolve_threshold_percentile(payload: Dict[str, Any]) -> float:
    """Return the breakpoint percentile from ``payload["threshold"]`` (default 95).

    The value must be a finite number in [0, 100]; numeric strings are accepted.
    Anything else would make ``np.percentile`` fail with an unhandled error
    (HTTP 500), so it is rejected with HTTP 400 instead.
    """
    raw_threshold = payload.get("threshold", 95)
    detail = "'threshold' must be a number between 0 and 100"
    if isinstance(raw_threshold, bool):
        logger.warning("Invalid threshold value: %s", raw_threshold)
        raise HTTPException(status_code=400, detail=detail)
    try:
        value = float(raw_threshold)
    except (TypeError, ValueError) as exc:
        logger.warning("Invalid threshold value: %s", raw_threshold)
        raise HTTPException(status_code=400, detail=detail) from exc
    # The chained comparison is also False for NaN and +/-inf.
    if not 0 <= value <= 100:
        logger.warning("Out-of-range threshold value: %s", raw_threshold)
        raise HTTPException(status_code=400, detail=detail)
    return value


@app.post("/v1/semantic-chunk")
async def semantic_chunk(request: Request):
    """Split text into semantically coherent chunks and embed each chunk.

    Algorithm:
        1. Replace newlines with spaces and split the text on ". " into
           sentences.
        2. Compute the cosine distance between each pair of adjacent sentence
           embeddings.
        3. Place a chunk boundary after every sentence whose distance to the
           next one exceeds the ``threshold``-th percentile of all adjacent
           distances.
        4. Embed each chunk with the ``search_document: `` prefix, then resize
           and normalise it to ``target_dimensions``.

    Request JSON:
        text: string of at most 50000 characters.
        threshold: percentile in [0, 100], default 95.
        target_dimensions: required positive integer.

    Returns:
        ``{"chunks": [...], "embeddings": [...]}`` with one embedding per
        chunk. Empty text yields empty lists. Text with fewer than two
        sentences comes back as a single chunk.

    Raises:
        HTTPException 400: bad JSON, body, ``text``, ``threshold`` or
            ``target_dimensions``.
        HTTPException 413: ``text`` is too long.
        HTTPException 503: the embedding model is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        logger.warning("/v1/semantic-chunk invalid JSON payload: %s", exc)
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        logger.warning("/v1/semantic-chunk body is not a JSON object: %s", type(data).__name__)
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    raw_text = data.get("text", "")
    raw_threshold = data.get("threshold", 95)
    raw_text_len = len(raw_text) if isinstance(raw_text, str) else -1
    logger.info("/v1/semantic-chunk request received: text_len=%s threshold=%s", raw_text_len, raw_threshold)
    target_dimensions = _resolve_target_dimensions(data)
    logger.info("/v1/semantic-chunk resolved target_dimensions=%s", target_dimensions)
    threshold_percentile = _resolve_threshold_percentile(data)

    if not raw_text:
        logger.info("/v1/semantic-chunk empty text payload")
        return {"chunks": [], "embeddings": []}
    if not isinstance(raw_text, str):
        # A truthy non-string would otherwise crash in len()/replace() with HTTP 500.
        logger.warning("/v1/semantic-chunk bad text type: %s", type(raw_text).__name__)
        raise HTTPException(status_code=400, detail="'text' must be a string")
    if len(raw_text) > 50000:
        logger.warning("/v1/semantic-chunk payload too large: text_len=%s", len(raw_text))
        raise HTTPException(status_code=413, detail="Text block too large. Please batch on the client.")

    model = state.get("embed_model")
    if model is None:
        logger.error("/v1/semantic-chunk embedding model not initialized")
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]
    with no_grad():
        if len(sentences) < 2:
            single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
            single = pad_and_normalize(single, target_dimensions=target_dimensions)
            return {
                "chunks": [raw_text],
                "embeddings": single.cpu().tolist(),
            }

        # NOTE(review): sentences are encoded without a task prefix here, while
        # the final chunks use "search_document: " — confirm this is intended.
        s_embeddings = model.encode(sentences, convert_to_tensor=True)
        # One batched op for all adjacent pairs (row i vs row i+1) instead of a
        # Python loop that syncs the device with .item() on every pair.
        distances = (
            1 - F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1)
        ).cpu().tolist()
        breakpoint_threshold = np.percentile(distances, threshold_percentile)
        boundaries = [i for i, d in enumerate(distances) if d > breakpoint_threshold]

        chunks = []
        start = 0
        for idx in boundaries:
            # The split consumed this chunk's last ". ", so restore the period.
            chunks.append(". ".join(sentences[start : idx + 1]) + ".")
            start = idx + 1
        tail = ". ".join(sentences[start:])
        # The final sentence was never followed by ". ", so it keeps its own
        # terminal punctuation; only add a period when it has none. This avoids
        # the "...end.." output the unconditional "+ '.'" used to produce.
        chunks.append(tail if tail.endswith((".", "!", "?")) else tail + ".")

        final_embeddings = model.encode(
            [f"search_document: {c}" for c in chunks],
            convert_to_tensor=True,
        )
        final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=target_dimensions)
    return {
        "chunks": chunks,
        "embeddings": final_embeddings.cpu().tolist(),
    }
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-style chat completion endpoint backed by the llama.cpp model.

    Request JSON:
        messages: list of chat messages, passed straight to
            ``Llama.create_chat_completion``.
        stream: when truthy, the reply is sent as Server-Sent Events through
            ``llm_streamer``.
        temperature: sampling temperature, default 0.7.
        max_tokens: generation cap, default 1024.

    Generation always stops on ``<|eot_id|>`` or ``<|end_of_text|>``.

    Raises:
        HTTPException 400: the body is not valid JSON, is not an object, or
            ``messages`` is not a list.
        HTTPException 503: the LLM is not loaded (e.g. the model file was
            missing at startup).
        HTTPException 500: llama.cpp raised while creating the completion.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        # Starlette caches the body after the first read, so this re-read is cheap.
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error("Invalid JSON payload for chat completions: %s; body_preview=%s", exc, preview)
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        # len() below would otherwise raise (or mislead) before any validation.
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    stream = data.get("stream", False)
    logger.info("Chat completion request: %s messages, stream=%s", len(messages), stream)

    llm = state.get("llm")
    if llm is None:
        raise HTTPException(status_code=503, detail="LLM not initialized or model file missing.")

    try:
        response = llm.create_chat_completion(
            messages=messages,
            stream=stream,
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1024),
            stop=["<|eot_id|>", "<|end_of_text|>"],
        )
    except Exception as exc:
        # logger.exception keeps the traceback for server-side debugging.
        logger.exception("Inference error: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    if stream:
        # With stream=True llama-cpp-python returns an iterator. Errors raised
        # during token generation therefore presumably surface while
        # llm_streamer iterates, after the response has started, and are not
        # caught by the handler above.
        return StreamingResponse(
            llm_streamer(response),
            media_type="text/event-stream",
        )
    return response
async def llm_streamer(response_iterator):
    """Re-emit streamed completion chunks as Server-Sent Events.

    Each chunk is serialised with ``json.dumps`` and framed as a ``data:``
    event. A final ``data: [DONE]`` sentinel marks the end of the stream.

    NOTE: ``response_iterator`` is consumed with a plain synchronous ``for``
    loop. If producing a chunk is slow (e.g. token generation), the event loop
    is blocked for that time.
    """
    for piece in response_iterator:
        encoded = json.dumps(piece)
        yield f"data: {encoded}\n\n"
    yield "data: [DONE]\n\n"
if __name__ == "__main__":
    import uvicorn

    # Host, port and auto-reload can be overridden per deployment; the defaults
    # reproduce the previous hard-coded values. Auto-reload is a development
    # convenience: every reload restarts the app, so `lifespan` reloads both
    # models. Set GPU_SERVER_RELOAD=0 to disable it.
    uvicorn.run(
        "gpu_server:app",
        host=os.getenv("GPU_SERVER_HOST", "0.0.0.0"),
        port=int(os.getenv("GPU_SERVER_PORT", "8001")),
        reload=os.getenv("GPU_SERVER_RELOAD", "1").strip().lower() in ("1", "true", "yes", "on"),
    )