Skip to content

Commit 3ea1a27

Browse files
committed
fix
1 parent 00eb78e commit 3ea1a27

2 files changed

Lines changed: 709 additions & 26 deletions

File tree

lightllm/server/function_call_parser.py

Lines changed: 51 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -239,18 +239,38 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
239239

240240
try:
241241
try:
242-
tool_call_pos = current_text.find(self.bot_token)
243-
if tool_call_pos != -1:
244-
start_idx = tool_call_pos + len(self.bot_token)
245-
elif self.current_tool_id > 0 and current_text.startswith(self.tool_call_separator):
242+
# Priority check: if we're processing a subsequent tool (current_tool_id > 0),
243+
# first check if text starts with the tool separator. This is critical for
244+
# parallel tool calls because the bot_token (e.g., '[') can also
245+
# appear inside array parameters of the current tool, and we must not
246+
# mistakenly identify that as the start of a new tool.
247+
used_separator_branch = False
248+
if self.current_tool_id > 0 and current_text.startswith(self.tool_call_separator):
246249
start_idx = len(self.tool_call_separator)
250+
used_separator_branch = True
247251
else:
248-
start_idx = 0
252+
tool_call_pos = current_text.find(self.bot_token)
253+
if tool_call_pos != -1:
254+
start_idx = tool_call_pos + len(self.bot_token)
255+
else:
256+
start_idx = 0
249257

250258
if start_idx >= len(current_text):
251259
return StreamingParseResult()
252260

253-
obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)
261+
try:
262+
obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)
263+
except (MalformedJSON, json.JSONDecodeError):
264+
# Separator landed on non-JSON markup; fall back to
265+
# bot_token which skips past all inter-object markup.
266+
# e.g. Qwen25: separator "\n" matches between eot/bot tags.
267+
if used_separator_branch and self.bot_token in current_text:
268+
start_idx = current_text.find(self.bot_token) + len(self.bot_token)
269+
if start_idx >= len(current_text):
270+
return StreamingParseResult()
271+
obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)
272+
else:
273+
raise
254274

255275
is_current_complete = _is_complete_json(current_text[start_idx : start_idx + end_idx])
256276

@@ -272,7 +292,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
272292

273293
current_tool_call = obj
274294

275-
except MalformedJSON:
295+
except (MalformedJSON, json.JSONDecodeError):
276296
return StreamingParseResult()
277297

278298
if not current_tool_call:
@@ -331,19 +351,24 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
331351
# Only remove the processed portion, keep unprocessed content
332352
self._buffer = current_text[start_idx + end_idx :]
333353

334-
if self.current_tool_id < len(self.prev_tool_call_arr):
335-
self.prev_tool_call_arr[self.current_tool_id].clear()
336-
self.current_tool_name_sent = False
337-
self.streamed_args_for_tool[self.current_tool_id] = ""
338-
self.current_tool_id += 1
339-
340354
# If the tool is still being parsed, send incremental changes
341355
elif prev_arguments:
342356
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
343357
if cur_args_json != prev_args_json:
344358
prefix = _find_common_prefix(prev_args_json, cur_args_json)
345359
argument_diff = prefix[sent:]
346360

361+
# Update prev_tool_call_arr BEFORE advancing current_tool_id
362+
if self.current_tool_id >= 0:
363+
while len(self.prev_tool_call_arr) <= self.current_tool_id:
364+
self.prev_tool_call_arr.append({})
365+
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
366+
367+
# Advance to next tool if complete
368+
if is_current_complete:
369+
self.current_tool_name_sent = False
370+
self.current_tool_id += 1
371+
347372
# Send the argument diff if there's something new
348373
if argument_diff is not None:
349374
# Use the correct tool_index: completing_tool_id for completed tools,
@@ -357,15 +382,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
357382
)
358383
],
359384
)
360-
if not is_current_complete:
361-
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
362-
363-
# Update prev_tool_call_arr with current state
364-
if self.current_tool_id >= 0:
365-
# Ensure prev_tool_call_arr is large enough
366-
while len(self.prev_tool_call_arr) <= self.current_tool_id:
367-
self.prev_tool_call_arr.append({})
368-
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
385+
self.streamed_args_for_tool[tool_index_to_use] += argument_diff
369386

370387
return res
371388

@@ -396,8 +413,8 @@ def __init__(self):
396413
Initializes the detector with necessary state variables.
397414
"""
398415
super().__init__()
399-
self.bot_token = "<tool_call>"
400-
self.eot_token = "</tool_call>"
416+
self.bot_token = "<tool_call>\n"
417+
self.eot_token = "\n</tool_call>"
401418
self.tool_call_separator = "\n"
402419
self._normal_text_buffer = "" # Buffer for handling partial end tokens
403420

@@ -443,7 +460,7 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
443460
self._normal_text_buffer += result.normal_text
444461

445462
# Check if buffer contains complete end token (without leading newline)
446-
end_token_without_newline = self.eot_token # "</tool_call>"
463+
end_token_without_newline = self.eot_token[1:] # "</tool_call>" (strip leading \n)
447464
if end_token_without_newline in self._normal_text_buffer:
448465
cleaned_text = self._normal_text_buffer.replace(end_token_without_newline, "")
449466
self._normal_text_buffer = ""
@@ -1890,7 +1907,15 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
18901907
self._buffer = current_text[last_end + len(self.eot_token) :].lstrip()
18911908
else:
18921909
self._buffer = ""
1893-
self.current_tool_id = -1
1910+
1911+
# Reassign tool_index sequentially so parallel calls using the same
1912+
# tool get distinct indices (detect_and_parse uses definition-position
1913+
# indices which collide when the same tool is called twice).
1914+
if self.current_tool_id == -1:
1915+
self.current_tool_id = 0
1916+
for call in result.calls:
1917+
call.tool_index = self.current_tool_id
1918+
self.current_tool_id += 1
18941919
self.current_tool_name_sent = False
18951920
return result
18961921

0 commit comments

Comments
 (0)