Deign86 committed on
Commit
a8945eb
·
1 Parent(s): 1af7678

fix: enable DeepSeek direct streaming for chat/stream endpoint

Browse files
Files changed (2) hide show
  1. main.py +9 -6
  2. services/inference_client.py +4 -6
main.py CHANGED
@@ -1418,14 +1418,17 @@ async def call_hf_chat_stream_async(
1418
  yield str(chunk)
1419
 
1420
  client = get_inference_client()
1421
- async for chunk in client._call_deepseek_stream(
1422
- messages,
1423
- max_tokens=max_tokens,
1424
- temperature=temperature,
1425
- top_p=top_p,
1426
  model=model,
1427
  task_type=task_type,
1428
- ):
 
 
 
 
 
 
1429
  yield str(chunk)
1430
 
1431
 
 
1418
  yield str(chunk)
1419
 
1420
  client = get_inference_client()
1421
+ req = InferenceRequest(
1422
+ messages=messages,
 
 
 
1423
  model=model,
1424
  task_type=task_type,
1425
+ request_tag=f"{task_type}-async-{int(time.time() * 1000)}",
1426
+ max_new_tokens=max_tokens,
1427
+ temperature=temperature,
1428
+ top_p=top_p,
1429
+ timeout_sec=timeout,
1430
+ )
1431
+ async for chunk in client._call_deepseek_stream(req):
1432
  yield str(chunk)
1433
 
1434
 
services/inference_client.py CHANGED
@@ -913,9 +913,7 @@ class InferenceClient:
913
  "model": target_model,
914
  "messages": req.messages,
915
  "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
916
- "stream": True,
917
  }
918
-
919
  if target_model == REASONER_MODEL:
920
  params["max_tokens"] = req.max_new_tokens or 1024
921
  else:
@@ -932,10 +930,10 @@ class InferenceClient:
932
  )
933
  start = time.perf_counter()
934
  try:
935
- async with client.chat.completions.stream(**params, timeout=timeout) as stream:
936
- async for event in stream:
937
- if event.type == "content.delta" and event.content:
938
- yield event.content
939
 
940
  latency_ms = (time.perf_counter() - start) * 1000
941
  log_model_call(
 
913
  "model": target_model,
914
  "messages": req.messages,
915
  "max_tokens": req.max_new_tokens or self.default_max_new_tokens,
 
916
  }
 
917
  if target_model == REASONER_MODEL:
918
  params["max_tokens"] = req.max_new_tokens or 1024
919
  else:
 
930
  )
931
  start = time.perf_counter()
932
  try:
933
+ stream = client.chat.completions.stream(**params, timeout=timeout)
934
+ async for chunk in stream:
935
+ if chunk.choices[0].delta.content:
936
+ yield chunk.choices[0].delta.content
937
 
938
  latency_ms = (time.perf_counter() - start) * 1000
939
  log_model_call(