"""GPU inference node: embeddings, semantic chunking and chat completions.

The service loads a SentenceTransformer embedding model and a llama.cpp LLM
at startup and exposes them over HTTP via FastAPI.
"""
import json
import logging
import os
from contextlib import asynccontextmanager
from typing import Any, Dict, List

import numpy as np
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
from torch import Tensor, cuda, no_grad

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("gpu-node")

EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
LLM_MODEL_PATH = os.getenv(
    "LLM_MODEL_PATH",
    "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
)

# Process-wide registry of loaded models ("embed_model", "llm").
state: Dict[str, Any] = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models on startup and release them on shutdown.

    A missing LLM file is logged but does not abort startup; the chat
    endpoint then answers 503 and ``/health`` reports ``llm_ready: false``.
    Any other loading failure is logged with its traceback and re-raised.
    """
    device = "cuda" if cuda.is_available() else "cpu"
    logger.info("--- Initializing GPU Node on %s ---", device)
    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")

    try:
        logger.info("Loading Embedding Model: %s", EMBED_MODEL_NAME)
        state["embed_model"] = SentenceTransformer(
            EMBED_MODEL_NAME,
            trust_remote_code=True,
            device=device,
        )

        if not os.path.exists(LLM_MODEL_PATH):
            logger.error("LLM File not found at %s", LLM_MODEL_PATH)
        else:
            logger.info("Loading LLM: %s", LLM_MODEL_PATH)
            state["llm"] = Llama(
                model_path=LLM_MODEL_PATH,
                n_gpu_layers=-1,
                n_ctx=8192,
                n_batch=512,
                verbose=False,
            )
        logger.info("--- GPU Node Ready ---")
    except Exception:
        logger.exception("Failed to load models")
        raise

    try:
        yield
    finally:
        # Run teardown even when the context is exited through an exception.
        state.clear()
        if cuda.is_available():
            cuda.empty_cache()


app = FastAPI(title="Agentic GPU Node", lifespan=lifespan)


@app.get("/health")
async def health():
    """Report liveness and which models are currently loaded."""
    return {
        "status": "ok",
        "embedding_ready": state.get("embed_model") is not None,
        "llm_ready": state.get("llm") is not None,
    }


async def _read_json_object(request: Request, endpoint: str) -> Dict[str, Any]:
    """Parse the request body as a JSON object or raise HTTP 400.

    Args:
        request: Incoming request whose body is decoded.
        endpoint: Label used in the log line when decoding fails.

    Raises:
        HTTPException: 400 when the body is not valid JSON or not an object.
    """
    try:
        payload = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error(
            "Invalid JSON payload for %s: %s; body_preview=%s",
            endpoint,
            exc,
            preview,
        )
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc

    # Handlers call ``.get`` on the payload, so anything but a dict would
    # otherwise surface as an AttributeError (HTTP 500).
    if not isinstance(payload, dict):
        logger.warning("%s JSON body is not an object: %s", endpoint, type(payload).__name__)
        raise HTTPException(status_code=400, detail="JSON body must be an object")
    return payload


def _resolve_target_dimensions(payload: Dict[str, Any]) -> int:
    """Extract and validate ``target_dimensions`` from a request payload.

    Integers, integral floats (``768.0``) and integer strings (``"768"``)
    are accepted.

    Raises:
        HTTPException: 400 when the value is missing, not an integer, or
            not strictly positive.
    """
    raw_target = payload.get("target_dimensions")
    if raw_target in (None, ""):
        raise HTTPException(
            status_code=400,
            detail="'target_dimensions' is required and must be a positive integer",
        )

    # bool is an int subclass and fractional/non-finite floats would be
    # silently truncated (or overflow) by int(); reject them explicitly.
    is_fractional = isinstance(raw_target, float) and not raw_target.is_integer()
    if isinstance(raw_target, bool) or is_fractional:
        logger.warning("Invalid target_dimensions value: %s", raw_target)
        raise HTTPException(status_code=400, detail="'target_dimensions' must be an integer")

    try:
        target = int(raw_target)
    except (TypeError, ValueError, OverflowError) as exc:
        logger.warning("Invalid target_dimensions value: %s", raw_target)
        raise HTTPException(
            status_code=400, detail="'target_dimensions' must be an integer"
        ) from exc

    if target <= 0:
        logger.warning("Non-positive target_dimensions value: %s", target)
        raise HTTPException(status_code=400, detail="'target_dimensions' must be > 0")
    return target


def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor:
    """Resize each row to ``target_dimensions`` columns and L2-normalize it.

    Rows narrower than the target are right-padded with zeros; wider rows
    are truncated. ``embeddings`` is indexed as a 2-D tensor (rows, columns).
    """
    curr_dim = embeddings.shape[1]
    if curr_dim < target_dimensions:
        embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0)
    elif curr_dim > target_dimensions:
        embeddings = embeddings[:, :target_dimensions]
    return F.normalize(embeddings, p=2, dim=1)


@app.post("/v1/embeddings")
async def embeddings(request: Request):
    """Embed a string or list of strings.

    Request JSON: ``input`` (string or list) and ``target_dimensions``
    (positive integer). Inputs that do not already start with ``search_``
    are prefixed with ``search_query: `` before encoding.

    Raises:
        HTTPException: 400 for a malformed body or bad parameters, 503 when
            the embedding model is not loaded.
    """
    data = await _read_json_object(request, "/v1/embeddings")
    input_data = data.get("input", "")
    input_kind = type(input_data).__name__
    if isinstance(input_data, list):
        input_count = len(input_data)
    else:
        input_count = 1 if isinstance(input_data, str) else 0
    logger.info(
        "/v1/embeddings request received: input_kind=%s input_count=%s",
        input_kind,
        input_count,
    )

    target_dimensions = _resolve_target_dimensions(data)
    logger.info("/v1/embeddings resolved target_dimensions=%s", target_dimensions)

    if isinstance(input_data, str):
        inputs = [input_data]
    elif isinstance(input_data, list):
        # NOTE(review): blank items are dropped, so the returned ``index``
        # values count surviving items, not positions in the caller's list.
        inputs = [str(item) for item in input_data if str(item).strip()]
    else:
        logger.warning("/v1/embeddings bad input type: %s", input_kind)
        raise HTTPException(
            status_code=400, detail="'input' must be a string or list of strings"
        )

    if not inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    prefixed_inputs = [
        text if text.startswith("search_") else f"search_query: {text}"
        for text in inputs
    ]

    with no_grad():
        vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
        vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions)
    vector_list = vectors.cpu().tolist()

    # Whitespace-separated word count, used as the reported token usage.
    token_count = sum(len(text.split()) for text in inputs)
    return {
        "object": "list",
        "data": [
            {"object": "embedding", "index": idx, "embedding": embedding}
            for idx, embedding in enumerate(vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {"prompt_tokens": token_count, "total_tokens": token_count},
    }


def _validate_threshold(raw_threshold: Any) -> float:
    """Return the percentile threshold as a float or raise HTTP 400.

    ``np.percentile`` rejects values outside [0, 100]; checking here turns
    that failure into a client error instead of an unhandled exception.
    """
    is_number = isinstance(raw_threshold, (int, float)) and not isinstance(raw_threshold, bool)
    # The chained comparison is also False for NaN.
    if not is_number or not 0 <= raw_threshold <= 100:
        logger.warning("/v1/semantic-chunk invalid threshold: %s", raw_threshold)
        raise HTTPException(
            status_code=400, detail="'threshold' must be a number between 0 and 100"
        )
    return float(raw_threshold)


def _assemble_chunks(sentences: List[str], break_after: List[int]) -> List[str]:
    """Join sentences into chunks, closing a chunk after each break index.

    Interior chunks get a trailing period because the ``'. '`` split removed
    it. The final sentence was never split, so it may still carry its own
    terminal punctuation; a period is appended only when it does not.
    """
    chunks = []
    start = 0
    for idx in break_after:
        chunks.append(". ".join(sentences[start : idx + 1]) + ".")
        start = idx + 1

    tail = ". ".join(sentences[start:])
    if not tail.endswith((".", "!", "?")):
        tail += "."
    chunks.append(tail)
    return chunks


@app.post("/v1/semantic-chunk")
async def semantic_chunk(request: Request):
    """Split text into semantically coherent chunks and embed each chunk.

    Request JSON: ``text`` (string, at most 50000 characters),
    ``threshold`` (percentile in [0, 100], default 95) and
    ``target_dimensions`` (positive integer). A chunk boundary is placed
    after every sentence whose cosine distance to the next sentence exceeds
    the given percentile of all adjacent-sentence distances.

    Raises:
        HTTPException: 400 for a malformed body or bad parameters, 413 when
            the text is too long, 503 when the embedding model is not loaded.
    """
    data = await _read_json_object(request, "/v1/semantic-chunk")
    raw_text = data.get("text", "")
    threshold_percentile = data.get("threshold", 95)
    raw_text_len = len(raw_text) if isinstance(raw_text, str) else -1
    logger.info(
        "/v1/semantic-chunk request received: text_len=%s threshold=%s",
        raw_text_len,
        threshold_percentile,
    )

    target_dimensions = _resolve_target_dimensions(data)
    logger.info("/v1/semantic-chunk resolved target_dimensions=%s", target_dimensions)

    if not raw_text:
        logger.info("/v1/semantic-chunk empty text payload")
        return {"chunks": [], "embeddings": []}

    if not isinstance(raw_text, str):
        logger.warning("/v1/semantic-chunk bad text type: %s", type(raw_text).__name__)
        raise HTTPException(status_code=400, detail="'text' must be a string")

    if len(raw_text) > 50000:
        logger.warning("/v1/semantic-chunk payload too large: text_len=%s", len(raw_text))
        raise HTTPException(
            status_code=413, detail="Text block too large. Please batch on the client."
        )

    model = state.get("embed_model")
    if model is None:
        logger.error("/v1/semantic-chunk embedding model not initialized")
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    sentences = [s.strip() for s in raw_text.replace("\n", " ").split(". ") if s.strip()]

    if len(sentences) < 2:
        # Nothing to split: embed the text as a single chunk.
        with no_grad():
            single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
            single = pad_and_normalize(single, target_dimensions=target_dimensions)
        return {
            "chunks": [raw_text],
            "embeddings": single.cpu().tolist(),
        }

    percentile = _validate_threshold(threshold_percentile)

    # NOTE(review): sentences are encoded without the "search_document: "
    # prefix that the chunk embeddings below receive -- confirm intended.
    with no_grad():
        s_embeddings = model.encode(sentences, convert_to_tensor=True)
        # One batched call over all adjacent (i, i + 1) sentence pairs.
        similarities = F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1)
    distances = (1 - similarities).cpu().tolist()

    breakpoint_threshold = np.percentile(distances, percentile)
    indices = [i for i, d in enumerate(distances) if d > breakpoint_threshold]
    chunks = _assemble_chunks(sentences, indices)

    with no_grad():
        final_embeddings = model.encode(
            [f"search_document: {c}" for c in chunks],
            convert_to_tensor=True,
        )
        final_embeddings = pad_and_normalize(
            final_embeddings, target_dimensions=target_dimensions
        )

    return {
        "chunks": chunks,
        "embeddings": final_embeddings.cpu().tolist(),
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Run a chat completion on the local LLM, optionally streaming.

    Request JSON: ``messages`` (list), ``stream`` (default false),
    ``temperature`` (default 0.7) and ``max_tokens`` (default 1024).

    Raises:
        HTTPException: 400 for a malformed body or non-list ``messages``,
            503 when the LLM is not loaded, 500 when inference fails.
    """
    data = await _read_json_object(request, "chat completions")

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    stream = data.get("stream", False)

    logger.info("Chat completion request: %s messages, stream=%s", len(messages), stream)

    llm = state.get("llm")
    if not llm:
        raise HTTPException(
            status_code=503, detail="LLM not initialized or model file missing."
        )

    try:
        response = llm.create_chat_completion(
            messages=messages,
            stream=stream,
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1024),
            stop=["<|eot_id|>", "<|end_of_text|>"],
        )
    except Exception as exc:
        logger.exception("Inference error: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    if stream:
        return StreamingResponse(
            llm_streamer(response),
            media_type="text/event-stream",
        )
    return response


async def llm_streamer(response_iterator):
    """Yield completion chunks as server-sent events, ending with ``[DONE]``.

    NOTE(review): ``response_iterator`` is consumed with a plain ``for``
    inside an async generator, so each chunk is produced on the event loop
    -- confirm this is acceptable for concurrent requests.
    """
    for chunk in response_iterator:
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("gpu_server:app", host="0.0.0.0", port=8001, reload=True)