Added VRAM release mechanism
This commit is contained in:
parent
43618ee3f4
commit
c6f7f8917a
1 changed files with 45 additions and 14 deletions
|
|
@ -2,6 +2,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
@ -23,11 +24,47 @@ logger = logging.getLogger("gpu-node")
|
||||||
|
|
||||||
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
|
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
|
||||||
LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf")
|
LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf")
|
||||||
|
LLM_IDLE_TIMEOUT = int(os.getenv("LLM_IDLE_TIMEOUT", "1800"))
|
||||||
TARGET_DIMENSIONS = 768
|
TARGET_DIMENSIONS = 768
|
||||||
|
|
||||||
state: Dict[str, Any] = {}
|
state: Dict[str, Any] = {}
|
||||||
gpu_semaphore = asyncio.Semaphore(1)
|
gpu_semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_llm() -> Llama:
    """Construct the LLM from the GGUF file at ``LLM_MODEL_PATH``.

    Synchronous: the constructor runs in the calling thread, so async
    callers should dispatch it to an executor. The path is logged before
    loading starts.

    Returns:
        A freshly constructed ``Llama`` instance.
    """
    logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
    # NOTE(review): n_gpu_layers=-1 is presumably llama-cpp-python's
    # "offload every layer to the GPU" setting -- confirm against its docs.
    llama_options = {
        "model_path": LLM_MODEL_PATH,
        "n_gpu_layers": -1,
        "n_ctx": 8192,
        "n_batch": 512,
        "verbose": False,
    }
    return Llama(**llama_options)
|
||||||
|
|
||||||
|
def _unload_llm():
    """Remove the LLM from ``state`` and release the memory it held.

    A no-op when no LLM is currently stored: nothing is logged and the CUDA
    cache is left alone, so the log never claims an unload that did not
    happen.

    The model is released by dropping this function's reference to it. It is
    deliberately not closed explicitly: a request that fetched the model just
    before the unload may still hold a reference, and must be able to finish
    using it.
    """
    import gc

    llm = state.pop("llm", None)
    if llm is None:
        return

    del llm
    # Collect reference cycles now so that anything still reachable only
    # through a cycle is freed before the CUDA cache is flushed.
    gc.collect()
    if cuda.is_available():
        cuda.empty_cache()
    logger.info("LLM unloaded due to inactivity")
|
||||||
|
|
||||||
|
def _llm_idle_expired() -> bool:
    """Return True when an LLM is loaded and idle longer than the timeout."""
    last_used = state.get("llm_last_used")
    if state.get("llm") is None or last_used is None:
        return False
    return time.monotonic() - last_used > LLM_IDLE_TIMEOUT


async def _inactivity_watcher():
    """Background loop that unloads the LLM once it has been idle too long.

    Wakes every 60 seconds. When the LLM has not been touched for more than
    ``LLM_IDLE_TIMEOUT`` seconds it acquires ``gpu_semaphore`` and unloads
    the model.

    Runs until cancelled. Cancellation is propagated; any other error in a
    single iteration is logged and the loop continues, so one failed unload
    does not silently disable the idle-release mechanism.
    """
    while True:
        await asyncio.sleep(60)
        try:
            if not _llm_idle_expired():
                continue
            async with gpu_semaphore:
                # Waiting for the semaphore can take as long as the work
                # holding it; the LLM may have been used in the meantime,
                # so the idle check is repeated before unloading.
                if _llm_idle_expired():
                    _unload_llm()
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("LLM inactivity check failed")
|
||||||
|
|
||||||
|
def _touch_llm():
    """Record "now" (monotonic clock) as the LLM's last-use time in ``state``."""
    now = time.monotonic()
    state["llm_last_used"] = now
|
||||||
|
|
||||||
|
# Serialises model loading so concurrent requests that arrive while the LLM
# is unloaded trigger exactly one load. Created at module level, like
# ``gpu_semaphore``.
_llm_load_lock = asyncio.Lock()


async def _ensure_llm() -> Llama:
    """Return the loaded LLM, loading it on demand if it is not in ``state``.

    The load runs in the default executor so the event loop stays
    responsive. The last-use timestamp is refreshed on every call.

    Returns:
        The ``Llama`` instance stored in ``state["llm"]``.

    Raises:
        HTTPException: 503 when the model file does not exist at
            ``LLM_MODEL_PATH``.
    """
    llm = state.get("llm")
    if llm is None:
        async with _llm_load_lock:
            # Another request may have finished loading while this one was
            # waiting for the lock; check again before loading a second copy.
            llm = state.get("llm")
            if llm is None:
                if not os.path.exists(LLM_MODEL_PATH):
                    raise HTTPException(status_code=503, detail="LLM model file not found.")
                loop = asyncio.get_running_loop()
                llm = await loop.run_in_executor(None, _load_llm)
                state["llm"] = llm
    _touch_llm()
    return llm
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
device = "cuda" if cuda.is_available() else "cpu"
|
device = "cuda" if cuda.is_available() else "cpu"
|
||||||
|
|
@ -45,24 +82,21 @@ async def lifespan(app: FastAPI):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.exists(LLM_MODEL_PATH):
|
if not os.path.exists(LLM_MODEL_PATH):
|
||||||
logger.error(f"LLM File not found at {LLM_MODEL_PATH}")
|
logger.warning(f"LLM file not found at {LLM_MODEL_PATH} — will load on first request")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
|
state["llm"] = _load_llm()
|
||||||
state["llm"] = Llama(
|
_touch_llm()
|
||||||
model_path=LLM_MODEL_PATH,
|
|
||||||
n_gpu_layers=-1,
|
|
||||||
n_ctx=8192,
|
|
||||||
n_batch=512,
|
|
||||||
verbose=False
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("--- GPU Node Ready ---")
|
logger.info(f"--- GPU Node Ready (LLM idle timeout: {LLM_IDLE_TIMEOUT}s) ---")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load models: {e}")
|
logger.error(f"Failed to load models: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
watcher = asyncio.create_task(_inactivity_watcher())
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
watcher.cancel()
|
||||||
state.clear()
|
state.clear()
|
||||||
if cuda.is_available():
|
if cuda.is_available():
|
||||||
cuda.empty_cache()
|
cuda.empty_cache()
|
||||||
|
|
@ -235,9 +269,7 @@ async def chat_completions(request: Request):
|
||||||
|
|
||||||
logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")
|
logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")
|
||||||
|
|
||||||
llm = state.get("llm")
|
llm = await _ensure_llm()
|
||||||
if not llm:
|
|
||||||
raise HTTPException(status_code=503, detail="LLM not initialized or model file missing.")
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
temperature = data.get("temperature", 0.7)
|
temperature = data.get("temperature", 0.7)
|
||||||
|
|
@ -254,7 +286,6 @@ async def chat_completions(request: Request):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if stream:
|
if stream:
|
||||||
# For streaming, run inference in executor and stream results back
|
|
||||||
def _infer_stream():
|
def _infer_stream():
|
||||||
return llm.create_chat_completion(
|
return llm.create_chat_completion(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue