|
| 1 | +""" |
| 2 | +Run Command with a input text. |
| 3 | +""" |
| 4 | +import os |
| 5 | +import sys |
| 6 | +import json |
| 7 | +import subprocess |
| 8 | +from typing import List |
| 9 | +import shlex |
| 10 | + |
| 11 | +import openai |
| 12 | + |
| 13 | +from devchat.utils import get_logger |
| 14 | +from . import Command |
| 15 | + |
| 16 | + |
| 17 | +logger = get_logger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +# Equivalent of CommandRun in Python\which executes subprocesses |
| 21 | +class CommandRunner: |
| 22 | + def __init__(self, model_name: str): |
| 23 | + self.process = None |
| 24 | + self._model_name = model_name |
| 25 | + |
| 26 | + def _call_function_by_llm(self, |
| 27 | + command_name: str, |
| 28 | + command: Command, |
| 29 | + history_messages: List[dict]): |
| 30 | + """ |
| 31 | + command needs multi parameters, so we need parse each |
| 32 | + parameter by LLM from input_text |
| 33 | + """ |
| 34 | + properties = {} |
| 35 | + required = [] |
| 36 | + for key, value in command.parameters.items(): |
| 37 | + properties[key] = {} |
| 38 | + for key1, value1 in value.dict().items(): |
| 39 | + if key1 not in ['type', 'description', 'enum'] or value1 is None: |
| 40 | + continue |
| 41 | + properties[key][key1] = value1 |
| 42 | + required.append(key) |
| 43 | + |
| 44 | + tools = [ |
| 45 | + { |
| 46 | + "type": "function", |
| 47 | + "function": { |
| 48 | + "name": command_name, |
| 49 | + "description": command.description, |
| 50 | + "parameters": { |
| 51 | + "type": "object", |
| 52 | + "properties": properties, |
| 53 | + "required": required, |
| 54 | + }, |
| 55 | + } |
| 56 | + } |
| 57 | + ] |
| 58 | + |
| 59 | + client = openai.OpenAI( |
| 60 | + api_key=os.environ.get("OPENAI_API_KEY", None), |
| 61 | + base_url=os.environ.get("OPENAI_API_BASE", None) |
| 62 | + ) |
| 63 | + |
| 64 | + connection_error = '' |
| 65 | + for _1 in range(3): |
| 66 | + try: |
| 67 | + response = client.chat.completions.create( |
| 68 | + messages=history_messages, |
| 69 | + model="gpt-3.5-turbo-16k", |
| 70 | + stream=False, |
| 71 | + tools=tools, |
| 72 | + tool_choice={"type": "function", "function": {"name": command_name}} |
| 73 | + ) |
| 74 | + |
| 75 | + respose_message = response.dict()["choices"][0]["message"] |
| 76 | + if not respose_message['tool_calls']: |
| 77 | + return None |
| 78 | + tool_call = respose_message['tool_calls'][0]['function'] |
| 79 | + if tool_call['name'] != command_name: |
| 80 | + return None |
| 81 | + parameters = json.loads(tool_call['arguments']) |
| 82 | + return parameters |
| 83 | + except (ConnectionError, openai.APIConnectionError) as err: |
| 84 | + connection_error = err |
| 85 | + continue |
| 86 | + except Exception as err: |
| 87 | + print("Exception:", err, file=sys.stderr, flush=True) |
| 88 | + logger.exception("Call command by LLM error: %s", err) |
| 89 | + return None |
| 90 | + print("Connect Error:", connection_error, file=sys.stderr, flush=True) |
| 91 | + return None |
| 92 | + |
| 93 | + |
| 94 | + def run_command(self, |
| 95 | + command_name: str, |
| 96 | + command: Command, |
| 97 | + history_messages: List[dict], |
| 98 | + input_text: str, |
| 99 | + parent_hash: str, |
| 100 | + context_contents: List[str]): |
| 101 | + """ |
| 102 | + if command has parameters, then generate command parameters from input by LLM |
| 103 | + if command.input is "required", and input is null, then return error |
| 104 | + """ |
| 105 | + if command.parameters and len(command.parameters) > 0: |
| 106 | + if not self._model_name.startswith("gpt-"): |
| 107 | + return None |
| 108 | + |
| 109 | + arguments = self._call_function_by_llm(command_name, command, history_messages) |
| 110 | + if not arguments: |
| 111 | + print("No valid parameters generated by LLM", file=sys.stderr, flush=True) |
| 112 | + return (-1, "") |
| 113 | + return self.run_command_with_parameters( |
| 114 | + command, |
| 115 | + { |
| 116 | + "input": input_text, |
| 117 | + **arguments |
| 118 | + }, |
| 119 | + parent_hash, |
| 120 | + context_contents) |
| 121 | + |
| 122 | + return self.run_command_with_parameters( |
| 123 | + command, |
| 124 | + { |
| 125 | + "input": input_text |
| 126 | + }, |
| 127 | + parent_hash, |
| 128 | + context_contents) |
| 129 | + |
| 130 | + |
| 131 | + def run_command_with_parameters(self, |
| 132 | + command: Command, |
| 133 | + parameters: dict[str, str], |
| 134 | + parent_hash: str, |
| 135 | + context_contents: List[str]): |
| 136 | + """ |
| 137 | + replace $xxx in command.steps[0].run with parameters[xxx] |
| 138 | + then run command.steps[0].run |
| 139 | + """ |
| 140 | + try: |
| 141 | + # add environment variables to parameters |
| 142 | + if parent_hash: |
| 143 | + os.environ['PARENT_HASH'] = parent_hash |
| 144 | + if context_contents: |
| 145 | + os.environ['CONTEXT_CONTENTS'] = json.dumps(context_contents) |
| 146 | + for env_var in os.environ: |
| 147 | + parameters[env_var] = os.environ[env_var] |
| 148 | + parameters["command_python"] = os.environ['command_python'] |
| 149 | + |
| 150 | + command_run = command.steps[0]["run"] |
| 151 | + # Replace parameters in command run |
| 152 | + for parameter in parameters: |
| 153 | + command_run = command_run.replace('$' + parameter, str(parameters[parameter])) |
| 154 | + |
| 155 | + # Run command_run |
| 156 | + env = os.environ.copy() |
| 157 | + if 'PYTHONPATH' in env: |
| 158 | + del env['PYTHONPATH'] |
| 159 | + # result = subprocess.run(command_run, shell=True, env=env) |
| 160 | + # return result |
| 161 | + process = subprocess.Popen( |
| 162 | + shlex.split(command_run), |
| 163 | + stdout=subprocess.PIPE, |
| 164 | + stderr=subprocess.STDOUT, |
| 165 | + text=True |
| 166 | + ) |
| 167 | + |
| 168 | + # 实时读取输出并打印 |
| 169 | + stdout = '' |
| 170 | + while True: |
| 171 | + output = process.stdout.readline() |
| 172 | + if output == '' and process.poll() is not None: |
| 173 | + break |
| 174 | + if output: |
| 175 | + stdout += output |
| 176 | + print(output, end='\n') |
| 177 | + rc = process.poll() |
| 178 | + return (rc, stdout) |
| 179 | + except Exception as err: |
| 180 | + print("Exception:", type(err), err, file=sys.stderr, flush=True) |
| 181 | + return (-1, "") |
0 commit comments