 from http import HTTPStatus
 from PIL import Image
 import multiprocessing as mp
-from typing import Any, AsyncGenerator, Optional, Union, List, Dict
+from typing import Any, AsyncGenerator, Optional, Union, List, Dict, Tuple
 from typing import Callable
 from lightllm.server import TokenLoad
 from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -371,16 +371,13 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         )
         return resp
 
-    if sampling_params.n != 1:
-        return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only support n = 1")
-
     parser_dict = {}
     reasoning_parser_dict = {}
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
-        finish_reason = None
-        has_emitted_tool_calls = False
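+        # Track tool-call emission per sub-request, and cache each tool call's id
+        # per (choice, tool) pair so argument-delta chunks reuse the first chunk's id.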
+        has_emitted_tool_calls: Dict[int, bool] = collections.defaultdict(bool)
+        stream_tool_call_ids: Dict[Tuple[int, int], str] = {}
         from .req_id_generator import convert_sub_id_to_group_id
 
         prompt_tokens = 0
@@ -389,18 +386,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             prompt_tokens = metadata["prompt_tokens"]
             completion_tokens += 1
             group_request_id = convert_sub_id_to_group_id(sub_req_id)
-            index = sub_req_id
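+            # Assumes sub-request ids are allocated as group_request_id plus a
+            # per-choice offset (see req_id_generator), so subtraction recovers
+            # the choice index.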
+            choice_index = sub_req_id - group_request_id
+
             delta = request_output
-            finish_reason = finish_status.get_finish_reason()
+            current_finish_reason = finish_status.get_finish_reason()
 
             # Handle reasoning content
             if get_env_start_args().reasoning_parser and request.separate_reasoning:
                 reasoning_text, delta = _process_reasoning_stream(
-                    index, delta, reasoning_parser_dict, request_output, request
+                    choice_index, delta, reasoning_parser_dict, request_output, request
                 )
                 if reasoning_text:
                     choice_data = ChatCompletionStreamResponseChoice(
-                        index=0,
+                        index=choice_index,
                         delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text),
                         finish_reason=None,
                     )
@@ -415,13 +413,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
             if request.tool_choice != "none" and request.tools:
                 # parse_increment => returns (normal_text, calls)
                 normal_text, calls = _process_tools_stream(
-                    index=index, delta=delta, parser_dict=parser_dict, request=request
+                    index=choice_index, delta=delta, parser_dict=parser_dict, request=request
                 )
 
                 # 1) if there's normal_text, output it as normal content
                 if normal_text:
                     choice_data = ChatCompletionStreamResponseChoice(
-                        index=0,
+                        index=choice_index,
                         delta=DeltaMessage(role="assistant", content=normal_text),
                         finish_reason=None,
                     )
@@ -435,32 +433,39 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
                 # 2) if we found calls, we output them as separate chunk(s)
                 history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
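+                # Incremental tool-call parser for this choice (assumed to have
+                # been populated in parser_dict by _process_tools_stream above).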
+                fc_parser = parser_dict[choice_index]
                 for call_item in calls:
-                    has_emitted_tool_calls = True
+                    has_emitted_tool_calls[sub_req_id] = True
                     # transform call_item -> FunctionResponse + ToolCall
-                    if finish_reason == "stop":
-                        latest_delta_len = 0
-                        if isinstance(call_item.parameters, str):
-                            latest_delta_len = len(call_item.parameters)
-
-                        expected_call = json.dumps(
-                            parser.multi_format_parser.detectors[0].prev_tool_call_arr[index].get("arguments", {}),
-                            ensure_ascii=False,
-                        )
-                        actual_call = parser.multi_format_parser.detectors[0].streamed_args_for_tool[index]
-                        if latest_delta_len > 0:
-                            actual_call = actual_call[:-latest_delta_len]
-                        remaining_call = expected_call.replace(actual_call, "", 1)
-                        call_item.parameters = remaining_call
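+                    # On the final ("stop") chunk, flush any argument text the
+                    # detector parsed but has not yet streamed for this tool call,
+                    # guarding against out-of-range tool indices.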
+                    if current_finish_reason == "stop":
+                        det = fc_parser.detector
+                        ti = call_item.tool_index
+                        if ti >= 0 and ti < len(det.prev_tool_call_arr) and ti < len(det.streamed_args_for_tool):
+                            latest_delta_len = 0
+                            if isinstance(call_item.parameters, str):
+                                latest_delta_len = len(call_item.parameters)
+
+                            expected_call = json.dumps(
+                                det.prev_tool_call_arr[ti].get("arguments", {}),
+                                ensure_ascii=False,
+                            )
+                            actual_call = det.streamed_args_for_tool[ti]
+                            if latest_delta_len > 0:
+                                actual_call = actual_call[:-latest_delta_len]
+                            remaining_call = expected_call.replace(actual_call, "", 1)
+                            call_item.parameters = remaining_call
 
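+                    # Key the cached tool_call id by (choice, tool): the first
+                    # named chunk mints the id, later argument deltas look it up.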
+                    tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
+                    id_key = (choice_index, call_item.tool_index)
                     if call_item.name:
-                        # First chunk: include ID and function name
-                        tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
-                        tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt)
+                        if id_key not in stream_tool_call_ids:
+                            stream_tool_call_ids[id_key] = _process_tool_call_id(
+                                tool_parser, call_item, history_tool_calls_cnt
+                            )
+                        tool_call_id = stream_tool_call_ids[id_key]
                         function_name = call_item.name
                     else:
-                        # Subsequent chunks: null ID and name for argument deltas
-                        tool_call_id = None
+                        tool_call_id = stream_tool_call_ids.get(id_key)
                         function_name = None
 
                     tool_call = ToolCall(
@@ -472,7 +477,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                         ),
                     )
                     choice_data = ChatCompletionStreamResponseChoice(
-                        index=0,
+                        index=choice_index,
                         delta=DeltaMessage(role="assistant", tool_calls=[tool_call]),
                         finish_reason=None,
                     )
@@ -485,7 +490,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                     yield f"data: {chunk.model_dump_json()}\n\n"
             else:
                 delta_message = DeltaMessage(role="assistant", content=delta)
-                stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None)
+                stream_choice = ChatCompletionStreamResponseChoice(
+                    index=choice_index, delta=delta_message, finish_reason=None
+                )
                 stream_resp = ChatCompletionStreamResponse(
                     id=group_request_id,
                     created=created_time,
@@ -494,24 +501,22 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                 )
                 yield f"data: {stream_resp.model_dump_json()}\n\n"
 
-        # Determine final finish_reason: override to "tool_calls" if tool calls were emitted
-        if has_emitted_tool_calls and finish_reason == "stop":
-            finish_reason = "tool_calls"
-
-        # Final empty chunk containing only finish_reason (and role)
-        if finish_reason is not None:
-            final_choice = ChatCompletionStreamResponseChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=finish_reason,
-            )
-            final_chunk = ChatCompletionStreamResponse(
-                id=group_request_id,
-                created=created_time,
-                model=request.model,
-                choices=[final_choice],
-            )
-            yield f"data: {final_chunk.model_dump_json()}\n\n"
+            # Emit a per-choice final empty chunk with finish_reason.
+            if current_finish_reason is not None:
+                if has_emitted_tool_calls[sub_req_id] and current_finish_reason == "stop":
+                    current_finish_reason = "tool_calls"
+                final_choice = ChatCompletionStreamResponseChoice(
+                    index=choice_index,
+                    delta=DeltaMessage(),
+                    finish_reason=current_finish_reason,
+                )
+                final_chunk = ChatCompletionStreamResponse(
+                    id=group_request_id,
+                    created=created_time,
+                    model=request.model,
+                    choices=[final_choice],
+                )
+                yield f"data: {final_chunk.model_dump_json()}\n\n"
 
         if request.stream_options and request.stream_options.include_usage:
             usage = UsageInfo(
@@ -634,9 +639,6 @@ async def _process_prompts_completion(
                 "Streaming is not supported for batch requests",
             )
 
-        if sampling_params.n != 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only support n = 1")
-
         return await _handle_streaming_completion(
             prompts[0], sampling_params, multimodal_params, raw_request, request, created_time
         )
@@ -690,6 +692,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
 
         async for sub_req_id, request_output, metadata, finish_status in results_generator:
             group_request_id = convert_sub_id_to_group_id(sub_req_id)
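+            # Same sub-id to choice-index mapping as the chat endpoint above.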
+            choice_index = sub_req_id - group_request_id
             prompt_tokens = metadata["prompt_tokens"]
             completion_tokens += 1
             current_finish_reason = None
@@ -704,7 +707,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
                 output_text = prompt_str + output_text
 
             stream_choice = CompletionStreamChoice(
-                index=0,
+                index=choice_index,
                 text=output_text,
                 finish_reason=current_finish_reason,
                 logprobs=None if request.logprobs is None else {},