|
47 | 47 | "max_input_tokens", |
48 | 48 | "use_phoenix", |
49 | 49 | "via_streaming", |
| 50 | + "max_recursive_calls", |
50 | 51 | "min_delay", # minimum delay between queries for that model |
51 | 52 | ] |
52 | 53 |
|
@@ -649,7 +650,8 @@ def query( |
649 | 650 | litellm_debug=None, |
650 | 651 | stream=False, |
651 | 652 | via_streaming=False, |
652 | | - recursive_call_info: Optional[Dict[str, any]] = None, |
| 653 | + max_recursive_calls=99, |
| 654 | + recursive_call_info: Optional[Dict[str, any]] = None, |
653 | 655 | **kwargs, |
654 | 656 | ) -> Dict[str, any]: |
655 | 657 | """ |
@@ -715,7 +717,7 @@ def cleaned_args(args: dict): |
715 | 717 | logger.debug(f"Options: via_streaming: {via_streaming}, stream: {stream}") |
716 | 718 | logger.debug(f"Initial completion kwargs: {cleaned_args(completion_kwargs)}") |
717 | 719 | if recursive_call_info is None: |
718 | | - recursive_call_info = {} |
| 720 | + recursive_call_info = dict(n_calls=0) |
719 | 721 | if llm.get("api_key"): |
720 | 722 | completion_kwargs["api_key"] = llm["api_key"] |
721 | 723 | elif llm.get("api_key_env"): |
@@ -953,8 +955,18 @@ def chunk_generator(model_generator, retobj): |
953 | 955 | if len(tool_calls) > 0: # not an empty list |
954 | 956 | if debug: |
955 | 957 | logger.debug(f"Appending response message: {response_message}") |
| 958 | + skip_tools = False |
| 959 | + if recursive_call_info["n_calls"] > max_recursive_calls: |
| 960 | + skip_tools = True |
956 | 961 | messages.append(response_message) |
957 | 962 | for tool_call in tool_calls: |
| 963 | + if skip_tools: |
| 964 | + messages.append( |
| 965 | + dict( |
| 966 | + tool_call_id=tool_call.id, |
 | 967 | +                        role="tool", name=tool_call.function.name,
| 968 | + content=f"ERROR: maximum number of tool calls ({max_recursive_calls}) exceeded!")) |
| 969 | + continue |
958 | 970 | function_name = tool_call.function.name |
959 | 971 | if debug: |
960 | 972 | logger.debug(f"Tool call {function_name}") |
|
0 commit comments