Fixed formatting and output
This commit is contained in:
parent
4e548fdefd
commit
40bade6e1e
1 changed files with 8 additions and 20 deletions
|
|
@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
from torch import cuda, no_grad, Tensor
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
@ -25,15 +25,13 @@ state: Dict[str, Any] = {}
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Handles GPU model loading and cleanup."""
|
device = "cuda" if cuda.is_available() else "cpu"
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
logger.info(f"--- Initializing GPU Node on {device} ---")
|
logger.info(f"--- Initializing GPU Node on {device} ---")
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")
|
logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load Embedding Model (Nomic)
|
|
||||||
logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
|
logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
|
||||||
state["embed_model"] = SentenceTransformer(
|
state["embed_model"] = SentenceTransformer(
|
||||||
EMBED_MODEL_NAME,
|
EMBED_MODEL_NAME,
|
||||||
|
|
@ -41,14 +39,13 @@ async def lifespan(app: FastAPI):
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load Llama Model (GGUF)
|
|
||||||
if not os.path.exists(LLM_MODEL_PATH):
|
if not os.path.exists(LLM_MODEL_PATH):
|
||||||
logger.error(f"LLM File not found at {LLM_MODEL_PATH}")
|
logger.error(f"LLM File not found at {LLM_MODEL_PATH}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
|
logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
|
||||||
state["llm"] = Llama(
|
state["llm"] = Llama(
|
||||||
model_path=LLM_MODEL_PATH,
|
model_path=LLM_MODEL_PATH,
|
||||||
n_gpu_layers=-1, # Offload all layers to GPU
|
n_gpu_layers=-1,
|
||||||
n_ctx=8192,
|
n_ctx=8192,
|
||||||
n_batch=512,
|
n_batch=512,
|
||||||
verbose=False
|
verbose=False
|
||||||
|
|
@ -61,10 +58,9 @@ async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
state.clear()
|
state.clear()
|
||||||
if torch.cuda.is_available():
|
if cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
cuda.empty_cache()
|
||||||
|
|
||||||
app = FastAPI(title="Agentic GPU Node", lifespan=lifespan)
|
app = FastAPI(title="Agentic GPU Node", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
@ -94,8 +90,7 @@ def _resolve_target_dimensions(payload: Dict[str, Any]) -> int:
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def pad_and_normalize(embeddings: torch.Tensor, target_dimensions: int) -> torch.Tensor:
|
def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor:
|
||||||
"""Dimension standardization plus L2 normalization."""
|
|
||||||
curr_dim = embeddings.shape[1]
|
curr_dim = embeddings.shape[1]
|
||||||
if curr_dim < target_dimensions:
|
if curr_dim < target_dimensions:
|
||||||
embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0)
|
embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0)
|
||||||
|
|
@ -106,7 +101,6 @@ def pad_and_normalize(embeddings: torch.Tensor, target_dimensions: int) -> torch
|
||||||
|
|
||||||
@app.post("/v1/embeddings")
|
@app.post("/v1/embeddings")
|
||||||
async def embeddings(request: Request):
|
async def embeddings(request: Request):
|
||||||
"""Generates text embeddings compatible with OpenAI API format."""
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
input_data = data.get("input", "")
|
input_data = data.get("input", "")
|
||||||
target_dimensions = _resolve_target_dimensions(data)
|
target_dimensions = _resolve_target_dimensions(data)
|
||||||
|
|
@ -135,7 +129,7 @@ async def embeddings(request: Request):
|
||||||
for text in inputs
|
for text in inputs
|
||||||
]
|
]
|
||||||
|
|
||||||
with torch.no_grad():
|
with no_grad():
|
||||||
vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
|
vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
|
||||||
vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions)
|
vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions)
|
||||||
|
|
||||||
|
|
@ -160,7 +154,6 @@ async def embeddings(request: Request):
|
||||||
|
|
||||||
@app.post("/v1/semantic-chunk")
|
@app.post("/v1/semantic-chunk")
|
||||||
async def semantic_chunk(request: Request):
|
async def semantic_chunk(request: Request):
|
||||||
"""Processes raw text into semantically cohesive blocks."""
|
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
raw_text = data.get("text", "")
|
raw_text = data.get("text", "")
|
||||||
threshold_percentile = data.get("threshold", 95)
|
threshold_percentile = data.get("threshold", 95)
|
||||||
|
|
@ -176,7 +169,6 @@ async def semantic_chunk(request: Request):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise HTTPException(status_code=503, detail="Embedding model not initialized")
|
raise HTTPException(status_code=503, detail="Embedding model not initialized")
|
||||||
|
|
||||||
# Split by sentences
|
|
||||||
sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]
|
sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]
|
||||||
if len(sentences) < 2:
|
if len(sentences) < 2:
|
||||||
single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
|
single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
|
||||||
|
|
@ -186,7 +178,6 @@ async def semantic_chunk(request: Request):
|
||||||
"embeddings": single.cpu().tolist(),
|
"embeddings": single.cpu().tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate sentence embeddings to find breakpoints via cosine distance
|
|
||||||
s_embeddings = model.encode(sentences, convert_to_tensor=True)
|
s_embeddings = model.encode(sentences, convert_to_tensor=True)
|
||||||
distances = [
|
distances = [
|
||||||
1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item()
|
1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item()
|
||||||
|
|
@ -203,7 +194,7 @@ async def semantic_chunk(request: Request):
|
||||||
start = idx + 1
|
start = idx + 1
|
||||||
chunks.append(". ".join(sentences[start:]) + ".")
|
chunks.append(". ".join(sentences[start:]) + ".")
|
||||||
|
|
||||||
with torch.no_grad():
|
with no_grad():
|
||||||
final_embeddings = model.encode(
|
final_embeddings = model.encode(
|
||||||
[f"search_document: {c}" for c in chunks],
|
[f"search_document: {c}" for c in chunks],
|
||||||
convert_to_tensor=True
|
convert_to_tensor=True
|
||||||
|
|
@ -217,7 +208,6 @@ async def semantic_chunk(request: Request):
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completions(request: Request):
|
async def chat_completions(request: Request):
|
||||||
"""Unified LLM completion endpoint compatible with OpenAI-style requests."""
|
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -229,7 +219,6 @@ async def chat_completions(request: Request):
|
||||||
messages = data.get("messages", [])
|
messages = data.get("messages", [])
|
||||||
stream = data.get("stream", False)
|
stream = data.get("stream", False)
|
||||||
|
|
||||||
# Log incoming request details
|
|
||||||
logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")
|
logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")
|
||||||
|
|
||||||
llm = state.get("llm")
|
llm = state.get("llm")
|
||||||
|
|
@ -257,7 +246,6 @@ async def chat_completions(request: Request):
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
async def llm_streamer(response_iterator):
|
async def llm_streamer(response_iterator):
|
||||||
"""Iterates through llama-cpp generator and yields SSE chunks."""
|
|
||||||
for chunk in response_iterator:
|
for chunk in response_iterator:
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue