diff --git a/gpu_server.py b/gpu_server.py index 4e5b439..64602ce 100644 --- a/gpu_server.py +++ b/gpu_server.py @@ -1,4 +1,5 @@ import asyncio +import gc import logging import os import json @@ -38,6 +39,7 @@ def _load_llm() -> Llama: def _unload_llm(): llm = state.pop("llm", None) del llm + gc.collect() if cuda.is_available(): cuda.empty_cache() logger.info("LLM unloaded due to inactivity") @@ -56,12 +58,13 @@ def _touch_llm(): state["llm_last_used"] = time.monotonic() async def _ensure_llm() -> Llama: - llm = state.get("llm") - if llm is None: - if not os.path.exists(LLM_MODEL_PATH): - raise HTTPException(status_code=503, detail="LLM model file not found.") - loop = asyncio.get_event_loop() - state["llm"] = await loop.run_in_executor(None, _load_llm) + if state.get("llm") is None: + async with gpu_semaphore: + if state.get("llm") is None: + if not os.path.exists(LLM_MODEL_PATH): + raise HTTPException(status_code=503, detail="LLM model file not found.") + loop = asyncio.get_event_loop() + state["llm"] = await loop.run_in_executor(None, _load_llm) _touch_llm() return state["llm"] @@ -225,7 +228,8 @@ async def semantic_chunk(request: Request): single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS) return {"chunks": [raw_text], "embeddings": single.cpu().tolist()} - s_embeddings = model.encode(sentences, convert_to_tensor=True) + with no_grad(): + s_embeddings = model.encode(sentences, convert_to_tensor=True) distances = [ 1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item() for i in range(len(s_embeddings) - 1) @@ -286,20 +290,33 @@ async def chat_completions(request: Request): try: if stream: - def _infer_stream(): - return llm.create_chat_completion( - messages=messages, - stream=True, - temperature=temperature, - max_tokens=max_tokens, - stop=["<|eot_id|>", "<|end_of_text|>"], - ) + sentinel = object() async def _stream_response(): + queue: asyncio.Queue = asyncio.Queue() + _loop = asyncio.get_event_loop() + + def _produce(): + try: + for chunk in llm.create_chat_completion( + messages=messages, + stream=True, + temperature=temperature, + max_tokens=max_tokens, + stop=["<|eot_id|>", "<|end_of_text|>"], + ): + _loop.call_soon_threadsafe(queue.put_nowait, chunk) + finally: + _loop.call_soon_threadsafe(queue.put_nowait, sentinel) + async with gpu_semaphore: - chunks = await loop.run_in_executor(None, lambda: list(_infer_stream())) - for chunk in chunks: - yield f"data: {json.dumps(chunk)}\n\n" + fut = _loop.run_in_executor(None, _produce) + while True: + item = await queue.get() + if item is sentinel: + break + yield f"data: {json.dumps(item)}\n\n" + await fut yield "data: [DONE]\n\n" return StreamingResponse(_stream_response(), media_type="text/event-stream")