diff --git a/gpu_server.py b/gpu_server.py index 44e8640..39200b9 100644 --- a/gpu_server.py +++ b/gpu_server.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager from typing import Dict, Any import numpy as np -import torch +from torch import cuda, no_grad, Tensor import torch.nn.functional as F from fastapi import FastAPI, Request, HTTPException from fastapi.responses import StreamingResponse @@ -25,15 +25,13 @@ state: Dict[str, Any] = {} @asynccontextmanager async def lifespan(app: FastAPI): - """Handles GPU model loading and cleanup.""" - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cuda" if cuda.is_available() else "cpu" logger.info(f"--- Initializing GPU Node on {device} ---") if device == "cpu": logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.") try: - # Load Embedding Model (Nomic) logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}") state["embed_model"] = SentenceTransformer( EMBED_MODEL_NAME, @@ -41,14 +39,13 @@ async def lifespan(app: FastAPI): device=device ) - # Load Llama Model (GGUF) if not os.path.exists(LLM_MODEL_PATH): logger.error(f"LLM File not found at {LLM_MODEL_PATH}") else: logger.info(f"Loading LLM: {LLM_MODEL_PATH}") state["llm"] = Llama( model_path=LLM_MODEL_PATH, - n_gpu_layers=-1, # Offload all layers to GPU + n_gpu_layers=-1, n_ctx=8192, n_batch=512, verbose=False @@ -61,10 +58,9 @@ async def lifespan(app: FastAPI): yield - # Cleanup state.clear() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if cuda.is_available(): + cuda.empty_cache() app = FastAPI(title="Agentic GPU Node", lifespan=lifespan) @@ -94,8 +90,7 @@ def _resolve_target_dimensions(payload: Dict[str, Any]) -> int: return target -def pad_and_normalize(embeddings: torch.Tensor, target_dimensions: int) -> torch.Tensor: - """Dimension standardization plus L2 normalization.""" +def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor: curr_dim = embeddings.shape[1] if curr_dim < target_dimensions: embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0) @@ -106,7 +101,6 @@ def pad_and_normalize(embeddings: torch.Tensor, target_dimensions: int) -> torch @app.post("/v1/embeddings") async def embeddings(request: Request): - """Generates text embeddings compatible with OpenAI API format.""" data = await request.json() input_data = data.get("input", "") target_dimensions = _resolve_target_dimensions(data) @@ -135,7 +129,7 @@ async def embeddings(request: Request): for text in inputs ] - with torch.no_grad(): + with no_grad(): vectors = model.encode(prefixed_inputs, convert_to_tensor=True) vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions) @@ -160,7 +154,6 @@ async def embeddings(request: Request): @app.post("/v1/semantic-chunk") async def semantic_chunk(request: Request): - """Processes raw text into semantically cohesive blocks.""" data = await request.json() raw_text = data.get("text", "") threshold_percentile = data.get("threshold", 95) @@ -176,7 +169,6 @@ async def semantic_chunk(request: Request): if model is None: raise HTTPException(status_code=503, detail="Embedding model not initialized") - # Split by sentences sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()] if len(sentences) < 2: single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True) @@ -186,7 +178,6 @@ async def semantic_chunk(request: Request): "embeddings": single.cpu().tolist(), } - # Generate sentence embeddings to find breakpoints via cosine distance s_embeddings = model.encode(sentences, convert_to_tensor=True) distances = [ 1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item() @@ -203,7 +194,7 @@ async def semantic_chunk(request: Request): start = idx + 1 chunks.append(". ".join(sentences[start:]) + ".") - with torch.no_grad(): + with no_grad(): final_embeddings = model.encode( [f"search_document: {c}" for c in chunks], convert_to_tensor=True @@ -217,7 +208,6 @@ async def semantic_chunk(request: Request): @app.post("/v1/chat/completions") async def chat_completions(request: Request): - """Unified LLM completion endpoint compatible with OpenAI-style requests.""" try: data = await request.json() except Exception as e: @@ -229,7 +219,6 @@ async def chat_completions(request: Request): messages = data.get("messages", []) stream = data.get("stream", False) - # Log incoming request details logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}") llm = state.get("llm") @@ -257,7 +246,6 @@ async def chat_completions(request: Request): raise HTTPException(status_code=500, detail=str(e)) async def llm_streamer(response_iterator): - """Iterates through llama-cpp generator and yields SSE chunks.""" for chunk in response_iterator: yield f"data: {json.dumps(chunk)}\n\n" yield "data: [DONE]\n\n"