Skip to content

Commit b9bd34f

Browse files
committed
Added max_recursive_calls to limit tool calls
1 parent 9c03600 commit b9bd34f

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

llms_wrapper/llms.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"max_input_tokens",
4848
"use_phoenix",
4949
"via_streaming",
50+
"max_recursive_calls",
5051
"min_delay", # minimum delay between queries for that model
5152
]
5253

@@ -649,7 +650,8 @@ def query(
649650
litellm_debug=None,
650651
stream=False,
651652
via_streaming=False,
652-
recursive_call_info: Optional[Dict[str, any]] = None,
653+
max_recursive_calls=99,
654+
recursive_call_info: Optional[Dict[str, any]] = None,
653655
**kwargs,
654656
) -> Dict[str, any]:
655657
"""
@@ -715,7 +717,7 @@ def cleaned_args(args: dict):
715717
logger.debug(f"Options: via_streaming: {via_streaming}, stream: {stream}")
716718
logger.debug(f"Initial completion kwargs: {cleaned_args(completion_kwargs)}")
717719
if recursive_call_info is None:
718-
recursive_call_info = {}
720+
recursive_call_info = dict(n_calls=0)
719721
if llm.get("api_key"):
720722
completion_kwargs["api_key"] = llm["api_key"]
721723
elif llm.get("api_key_env"):
@@ -953,8 +955,18 @@ def chunk_generator(model_generator, retobj):
953955
if len(tool_calls) > 0: # not an empty list
954956
if debug:
955957
logger.debug(f"Appending response message: {response_message}")
958+
skip_tools = False
959+
if recursive_call_info["n_calls"] > max_recursive_calls:
960+
skip_tools = True
956961
messages.append(response_message)
957962
for tool_call in tool_calls:
963+
if skip_tools:
964+
messages.append(
965+
dict(
966+
tool_call_id=tool_call.id,
967+
role="tool", name=function_name,
968+
content=f"ERROR: maximum number of tool calls ({max_recursive_calls}) exceeded!"))
969+
continue
958970
function_name = tool_call.function.name
959971
if debug:
960972
logger.debug(f"Tool call {function_name}")

0 commit comments

Comments
 (0)