Tweaking and reframing yields and streaming with extra garbage collection
This commit is contained in:
parent
1eada257b9
commit
bf9eb6efb5
1 changed file with 35 additions and 18 deletions
|
|
@@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
|
|
@@ -38,6 +39,7 @@ def _load_llm() -> Llama:
|
|||
def _unload_llm():
|
||||
llm = state.pop("llm", None)
|
||||
del llm
|
||||
gc.collect()
|
||||
if cuda.is_available():
|
||||
cuda.empty_cache()
|
||||
logger.info("LLM unloaded due to inactivity")
|
||||
|
|
@@ -56,8 +58,9 @@ def _touch_llm():
|
|||
state["llm_last_used"] = time.monotonic()
|
||||
|
||||
async def _ensure_llm() -> Llama:
|
||||
llm = state.get("llm")
|
||||
if llm is None:
|
||||
if state.get("llm") is None:
|
||||
async with gpu_semaphore:
|
||||
if state.get("llm") is None:
|
||||
if not os.path.exists(LLM_MODEL_PATH):
|
||||
raise HTTPException(status_code=503, detail="LLM model file not found.")
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@@ -225,6 +228,7 @@ async def semantic_chunk(request: Request):
|
|||
single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS)
|
||||
return {"chunks": [raw_text], "embeddings": single.cpu().tolist()}
|
||||
|
||||
with no_grad():
|
||||
s_embeddings = model.encode(sentences, convert_to_tensor=True)
|
||||
distances = [
|
||||
1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item()
|
||||
|
|
@@ -286,20 +290,33 @@ async def chat_completions(request: Request):
|
|||
|
||||
try:
|
||||
if stream:
|
||||
def _infer_stream():
|
||||
return llm.create_chat_completion(
|
||||
sentinel = object()
|
||||
|
||||
async def _stream_response():
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
_loop = asyncio.get_event_loop()
|
||||
|
||||
def _produce():
|
||||
try:
|
||||
for chunk in llm.create_chat_completion(
|
||||
messages=messages,
|
||||
stream=True,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop=["<|eot_id|>", "<|end_of_text|>"],
|
||||
)
|
||||
):
|
||||
_loop.call_soon_threadsafe(queue.put_nowait, chunk)
|
||||
finally:
|
||||
_loop.call_soon_threadsafe(queue.put_nowait, sentinel)
|
||||
|
||||
async def _stream_response():
|
||||
async with gpu_semaphore:
|
||||
chunks = await loop.run_in_executor(None, lambda: list(_infer_stream()))
|
||||
for chunk in chunks:
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
fut = _loop.run_in_executor(None, _produce)
|
||||
while True:
|
||||
item = await queue.get()
|
||||
if item is sentinel:
|
||||
break
|
||||
yield f"data: {json.dumps(item)}\n\n"
|
||||
await fut
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(_stream_response(), media_type="text/event-stream")
|
||||
|
|
|
|||
Loading…
Reference in a new issue