|
| 1 | +from typing import Iterable |
| 2 | +import os |
| 3 | +import io |
| 4 | +import time |
| 5 | +from datetime import datetime |
| 6 | +from pathlib import Path |
| 7 | + |
| 8 | + |
| 9 | +from AgentSettings import AgentSettings |
| 10 | + |
| 11 | +from openai.types.beta.threads.message_content_image_file import MessageContentImageFile |
| 12 | +from openai.types.beta.threads.message_content_text import MessageContentText |
| 13 | +from openai.types.beta.threads.messages import MessageFile |
| 14 | +from openai.types import FileObject |
| 15 | +from PIL import Image |
| 16 | +from ArgumentException import ArgumentExceptionError |
| 17 | + |
| 18 | + |
| 19 | +class AssistantAgent: |
| 20 | + def __init__(self, settings, client, name, instructions, data_folder, tools_list, keep_state: bool = False, fn_calling_delegate=None): |
| 21 | + if name is None: |
| 22 | + raise ArgumentExceptionError("name parameter missing") |
| 23 | + if instructions is None: |
| 24 | + raise ArgumentExceptionError("instructions parameter missing") |
| 25 | + if tools_list is None: |
| 26 | + raise ArgumentExceptionError("tools_list parameter missing") |
| 27 | + |
| 28 | + self.assistant = None |
| 29 | + self.settings = settings |
| 30 | + self.client = client |
| 31 | + self.name = name |
| 32 | + self.instructions = instructions |
| 33 | + self.data_folder = data_folder |
| 34 | + self.tools_list = tools_list |
| 35 | + self.fn_calling_delegate = fn_calling_delegate |
| 36 | + self.keep_state = keep_state |
| 37 | + self.ai_threads = [] |
| 38 | + self.ai_files = [] |
| 39 | + self.file_ids = [] |
| 40 | + self.get_agent() |
| 41 | + |
| 42 | + def upload_file(self, path: str) -> FileObject: |
| 43 | + print(path) |
| 44 | + with Path(path).open("rb") as f: |
| 45 | + return self.client.files.create(file=f, purpose="assistants") |
| 46 | + |
| 47 | + def upload_all_files(self): |
| 48 | + files_in_folder = os.listdir(self.data_folder) |
| 49 | + local_file_list = [] |
| 50 | + for file in files_in_folder: |
| 51 | + filePath = self.data_folder + file |
| 52 | + assistant_file = self.upload_file(filePath) |
| 53 | + self.ai_files.append(assistant_file) |
| 54 | + local_file_list.append(assistant_file) |
| 55 | + self.file_ids = [file.id for file in local_file_list] |
| 56 | + |
| 57 | + def get_agent(self): |
| 58 | + if self.data_folder is not None: |
| 59 | + self.upload_all_files() |
| 60 | + self.assistant = self.client.beta.assistants.create( |
| 61 | + name=self.name, # "Sales Assistant", |
| 62 | + # "You are a sales assistant. You can answer questions related to customer orders.", |
| 63 | + instructions=self.instructions, |
| 64 | + tools=self.tools_list, |
| 65 | + model=self.settings.model_deployment, |
| 66 | + file_ids=self.file_ids |
| 67 | + ) |
| 68 | + else: |
| 69 | + self.assistant = self.client.beta.assistants.create( |
| 70 | + name=self.name, # "Sales Assistant", |
| 71 | + # "You are a sales assistant. You can answer questions related to customer orders.", |
| 72 | + instructions=self.instructions, |
| 73 | + tools=self.tools_list, |
| 74 | + model=self.settings.model_deployment |
| 75 | + ) |
| 76 | + |
| 77 | + def process_prompt(self, user_name: str, user_id: str, prompt: str) -> None: |
| 78 | + |
| 79 | + # if keep_state: |
| 80 | + # thread_id = check_if_thread_exists(user_id) |
| 81 | + |
| 82 | + # # If a thread doesn't exist, create one and store it |
| 83 | + # if thread_id is None: |
| 84 | + # print(f"Creating new thread for {name} with user_id {user_id}") |
| 85 | + # thread = self.client.beta.threads.create() |
| 86 | + # store_thread(user_id, thread) |
| 87 | + # thread_id = thread.id |
| 88 | + # # Otherwise, retrieve the existing thread |
| 89 | + # else: |
| 90 | + # print( |
| 91 | + # f"Retrieving existing thread for {name} with user_id {user_id}") |
| 92 | + # thread = self.client.beta.threads.retrieve(thread_id) |
| 93 | + # add_thread(thread) |
| 94 | + # else: |
| 95 | + thread = self.client.beta.threads.create() |
| 96 | + |
| 97 | + self.client.beta.threads.messages.create( |
| 98 | + thread_id=thread.id, role="user", content=prompt) |
| 99 | + |
| 100 | + run = self.client.beta.threads.runs.create( |
| 101 | + thread_id=thread.id, |
| 102 | + assistant_id=self.assistant.id, |
| 103 | + instructions="Please address the user as Jane Doe. The user has a premium account. Be assertive, accurate, and polite. Ask if the user has further questions. Do not provide explanations for the answers." |
| 104 | + + "The current date and time is: " |
| 105 | + + datetime.now().strftime("%x %X") |
| 106 | + + ". ", |
| 107 | + ) |
| 108 | + |
| 109 | + print("processing ...") |
| 110 | + while True: |
| 111 | + run = self.client.beta.threads.runs.retrieve( |
| 112 | + thread_id=thread.id, run_id=run.id) |
| 113 | + if run.status == "completed": |
| 114 | + # Handle completed |
| 115 | + messages = self.client.beta.threads.messages.list( |
| 116 | + thread_id=thread.id) |
| 117 | + self.print_messages(user_name, messages) |
| 118 | + break |
| 119 | + if run.status == "failed": |
| 120 | + messages = self.client.beta.threads.messages.list( |
| 121 | + thread_id=thread.id) |
| 122 | + self.print_messages(user_name, messages) |
| 123 | + # Handle failed |
| 124 | + break |
| 125 | + if run.status == "expired": |
| 126 | + # Handle expired |
| 127 | + break |
| 128 | + if run.status == "cancelled": |
| 129 | + # Handle cancelled |
| 130 | + break |
| 131 | + if run.status == "requires_action": |
| 132 | + if self.fn_calling_delegate: |
| 133 | + self.fn_calling_delegate(self.client, thread, run) |
| 134 | + else: |
| 135 | + time.sleep(5) |
| 136 | + if not self.keep_state: |
| 137 | + self.client.beta.threads.delete(thread.id) |
| 138 | + print("Deleted thread: ", thread.id) |
| 139 | + |
| 140 | + def read_assistant_file(self, file_id: str): |
| 141 | + response_content = self.client.files.content(file_id) |
| 142 | + return response_content.read() |
| 143 | + |
| 144 | + def print_messages(self, name: str, messages: Iterable[MessageFile]) -> None: |
| 145 | + message_list = [] |
| 146 | + |
| 147 | + # Get all the messages till the last user message |
| 148 | + for message in messages: |
| 149 | + message_list.append(message) |
| 150 | + if message.role == "user": |
| 151 | + break |
| 152 | + |
| 153 | + # Reverse the messages to show the last user message first |
| 154 | + message_list.reverse() |
| 155 | + |
| 156 | + # Print the user or Assistant messages or images |
| 157 | + for message in message_list: |
| 158 | + for item in message.content: |
| 159 | + # Determine the content type |
| 160 | + if isinstance(item, MessageContentText): |
| 161 | + if message.role == "user": |
| 162 | + print(f"user: {name}:\n{item.text.value}\n") |
| 163 | + else: |
| 164 | + print(f"{message.role}:\n{item.text.value}\n") |
| 165 | + file_annotations = item.text.annotations |
| 166 | + if file_annotations: |
| 167 | + for annotation in file_annotations: |
| 168 | + file_id = annotation.file_path.file_id |
| 169 | + content = self.read_assistant_file(file_id) |
| 170 | + print(f"Annotation Content:\n{str(content)}\n") |
| 171 | + elif isinstance(item, MessageContentImageFile): |
| 172 | + # Retrieve image from file id |
| 173 | + data_in_bytes = self.read_assistant_file( |
| 174 | + item.image_file.file_id) |
| 175 | + # Convert bytes to image |
| 176 | + readable_buffer = io.BytesIO(data_in_bytes) |
| 177 | + image = Image.open(readable_buffer) |
| 178 | + # Resize image to fit in terminal |
| 179 | + width, height = image.size |
| 180 | + image = image.resize( |
| 181 | + (width // 2, height // 2), Image.LANCZOS) |
| 182 | + # Display image |
| 183 | + image.show() |
| 184 | + |
| 185 | + def cleanup(self): |
| 186 | + print(self.client.beta.assistants.delete(self.assistant.id)) |
| 187 | + print("Deleting: ", len(self.ai_threads), " threads.") |
| 188 | + for thread in self.ai_threads: |
| 189 | + print(self.client.beta.threads.delete(thread.id)) |
| 190 | + print("Deleting: ", len(self.ai_files), " files.") |
| 191 | + for file in self.ai_files: |
| 192 | + print(self.client.files.delete(file.id)) |
0 commit comments