Added VRAM release mechanism

This commit is contained in:
Viswamedha Nalabotu 2026-03-22 15:34:06 +00:00
parent 43618ee3f4
commit c6f7f8917a

View file

@ -2,6 +2,7 @@ import asyncio
import logging import logging
import os import os
import json import json
import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Dict, Any from typing import Dict, Any
@ -23,11 +24,47 @@ logger = logging.getLogger("gpu-node")
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5" EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf") LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf")
LLM_IDLE_TIMEOUT = int(os.getenv("LLM_IDLE_TIMEOUT", "1800"))
TARGET_DIMENSIONS = 768 TARGET_DIMENSIONS = 768
state: Dict[str, Any] = {} state: Dict[str, Any] = {}
gpu_semaphore = asyncio.Semaphore(1) gpu_semaphore = asyncio.Semaphore(1)
def _load_llm() -> Llama:
logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
return Llama(model_path=LLM_MODEL_PATH, n_gpu_layers=-1, n_ctx=8192, n_batch=512, verbose=False)
def _unload_llm():
llm = state.pop("llm", None)
del llm
if cuda.is_available():
cuda.empty_cache()
logger.info("LLM unloaded due to inactivity")
async def _inactivity_watcher():
while True:
await asyncio.sleep(60)
llm = state.get("llm")
last_used = state.get("llm_last_used")
if llm is not None and last_used is not None:
if time.monotonic() - last_used > LLM_IDLE_TIMEOUT:
async with gpu_semaphore:
_unload_llm()
def _touch_llm():
state["llm_last_used"] = time.monotonic()
async def _ensure_llm() -> Llama:
llm = state.get("llm")
if llm is None:
if not os.path.exists(LLM_MODEL_PATH):
raise HTTPException(status_code=503, detail="LLM model file not found.")
loop = asyncio.get_event_loop()
state["llm"] = await loop.run_in_executor(None, _load_llm)
_touch_llm()
return state["llm"]
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
device = "cuda" if cuda.is_available() else "cpu" device = "cuda" if cuda.is_available() else "cpu"
@ -45,24 +82,21 @@ async def lifespan(app: FastAPI):
) )
if not os.path.exists(LLM_MODEL_PATH): if not os.path.exists(LLM_MODEL_PATH):
logger.error(f"LLM File not found at {LLM_MODEL_PATH}") logger.warning(f"LLM file not found at {LLM_MODEL_PATH} — will load on first request")
else: else:
logger.info(f"Loading LLM: {LLM_MODEL_PATH}") state["llm"] = _load_llm()
state["llm"] = Llama( _touch_llm()
model_path=LLM_MODEL_PATH,
n_gpu_layers=-1,
n_ctx=8192,
n_batch=512,
verbose=False
)
logger.info("--- GPU Node Ready ---") logger.info(f"--- GPU Node Ready (LLM idle timeout: {LLM_IDLE_TIMEOUT}s) ---")
except Exception as e: except Exception as e:
logger.error(f"Failed to load models: {e}") logger.error(f"Failed to load models: {e}")
raise e raise e
watcher = asyncio.create_task(_inactivity_watcher())
yield yield
watcher.cancel()
state.clear() state.clear()
if cuda.is_available(): if cuda.is_available():
cuda.empty_cache() cuda.empty_cache()
@ -235,9 +269,7 @@ async def chat_completions(request: Request):
logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}") logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")
llm = state.get("llm") llm = await _ensure_llm()
if not llm:
raise HTTPException(status_code=503, detail="LLM not initialized or model file missing.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
temperature = data.get("temperature", 0.7) temperature = data.get("temperature", 0.7)
@ -254,7 +286,6 @@ async def chat_completions(request: Request):
try: try:
if stream: if stream:
# For streaming, run inference in executor and stream results back
def _infer_stream(): def _infer_stream():
return llm.create_chat_completion( return llm.create_chat_completion(
messages=messages, messages=messages,