diff --git a/packages/sunagent-app/pyproject.toml b/packages/sunagent-app/pyproject.toml index 44c6d24..dc47f14 100644 --- a/packages/sunagent-app/pyproject.toml +++ b/packages/sunagent-app/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sunagent-app" -version = "0.0.27b8" +version = "0.0.27b9" license = { file = "LICENSE-CODE" } description = "sunagent app" readme = "README.md" diff --git a/packages/sunagent-app/src/sunagent_app/agents/__init__.py b/packages/sunagent-app/src/sunagent_app/agents/__init__.py index 05ef906..0887321 100644 --- a/packages/sunagent-app/src/sunagent_app/agents/__init__.py +++ b/packages/sunagent-app/src/sunagent_app/agents/__init__.py @@ -1,10 +1,10 @@ from ._context_builder_agent import ContextBuilderAgent, MentionStream from ._image_generate_agent import ImageGenerateAgent from ._steemit_context_builder_agent import SteemContextBuilder +from ._storm_agent import StormAgent, StormConfig from ._token_launch_agent import TokenLaunchAgent from ._tweet_analysis_agent import TweetAnalysisAgent from ._tweet_check_agent import TweetCheckAgent -from ._storm_agent import StormAgent, StormConfig __all__ = [ "ContextBuilderAgent", diff --git a/packages/sunagent-app/src/sunagent_app/agents/_image_generate_agent.py b/packages/sunagent-app/src/sunagent_app/agents/_image_generate_agent.py index 20fe247..5686543 100644 --- a/packages/sunagent-app/src/sunagent_app/agents/_image_generate_agent.py +++ b/packages/sunagent-app/src/sunagent_app/agents/_image_generate_agent.py @@ -1,13 +1,18 @@ +import base64 import json import logging import random import traceback from io import BytesIO from typing import ( + Any, Dict, List, + Literal, Optional, + Protocol, Sequence, + cast, ) from autogen_agentchat.agents import BaseChatAgent @@ -22,7 +27,7 @@ UserMessage, ) from autogen_ext.models.openai import AzureOpenAIChatCompletionClient -from google.genai import Client, types +from google.genai import types from PIL import Image as PILImage from sunagent_app._constants import LOGGER_NAME @@ -32,6 +37,34 @@ logger = logging.getLogger(LOGGER_NAME) +# Allowed sizes for Azure OpenAI images.generate +_OpenAIImageSize = Literal[ + "auto", + "1024x1024", + "1536x1024", + "1024x1536", + "256x256", + "512x512", + "1792x1024", + "1024x1792", +] + + +class _OpenAIImagesClient(Protocol): + def generate(self, model: str, prompt: str, *, size: _OpenAIImageSize | None = None) -> Any: ... + + +class _OpenAIClient(Protocol): + images: _OpenAIImagesClient + + +class _GoogleModelsClient(Protocol): + def generate_images(self, *, model: str, prompt: str, config: Any) -> Any: ... + + +class _GoogleClient(Protocol): + models: _GoogleModelsClient + class ImageGenerateAgent(BaseChatAgent): """An agent that generate an image based on the description in the tweet. @@ -42,7 +75,7 @@ def __init__( self, name: str, text_model_client: AzureOpenAIChatCompletionClient, - image_model_client: Client, + image_model_client: _GoogleClient | _OpenAIClient, *, description: str = """ An agent that extract image attachment of given tweet, or generate an image according to the image description in the tweet. @@ -52,6 +85,7 @@ def __init__( system_message: str, image_styles: List[str], image_model_name: str = "imagen-3.0-generate-002", + image_provider: Literal["google", "openai"] = "google", image_path: str = "generated_image.png", width: int = 400, height: int = 400, @@ -60,6 +94,7 @@ def __init__( self.system_message = system_message self.image_styles = image_styles self._image_model_name = image_model_name + self._image_provider = image_provider self._image_path = image_path self.image_model_client = image_model_client self.text_model_client = text_model_client @@ -138,21 +173,58 @@ async def _generate_image(self, image_prompt: str) -> Optional[PILImage.Image]: """Generate an image based on the image prompt.""" try: logger.info(f"Generating image with prompt: {image_prompt}") - response = self.image_model_client.models.generate_images( - model=self._image_model_name, - prompt=image_prompt, - config=types.GenerateImagesConfig(number_of_images=1), - ) - model_api_success_count.inc() - if ( - response.generated_images is not None - and len(response.generated_images) > 0 - and response.generated_images[0].image is not None - and response.generated_images[0].image.image_bytes is not None - ): + if self._image_provider == "openai": + image_client = cast(_OpenAIClient, self.image_model_client) + size_str = f"{self.width}x{self.height}" + allowed_sizes: set[_OpenAIImageSize] = { + "auto", + "1024x1024", + "1536x1024", + "1024x1536", + "256x256", + "512x512", + "1792x1024", + "1024x1792", + } + size_literal: _OpenAIImageSize = ( + cast(_OpenAIImageSize, size_str) if size_str in allowed_sizes else "1024x1024" + ) + response = image_client.images.generate( + model=self._image_model_name, + prompt=image_prompt, + size=size_literal, + ) + data = getattr(response, "data", None) + if not data or not isinstance(data, list): + logger.error("OpenAI image response missing data") + return None + first = data[0] + b64_json = getattr(first, "b64_json", None) + if not isinstance(b64_json, str): + logger.error("OpenAI image response missing b64_json") + return None + raw_image = base64.b64decode(b64_json) + image = PILImage.open(BytesIO(raw_image), formats=["PNG", "JPEG"]) + else: + google_client = cast(_GoogleClient, self.image_model_client) + response = google_client.models.generate_images( + model=self._image_model_name, + prompt=image_prompt, + config=types.GenerateImagesConfig(number_of_images=1), + ) + if ( + response.generated_images is None + or len(response.generated_images) == 0 + or response.generated_images[0].image is None + or response.generated_images[0].image.image_bytes is None + ): + logger.error("Failed to generate image") + return None raw_image = response.generated_images[0].image.image_bytes image = PILImage.open(BytesIO(raw_image), formats=["PNG"]) - return image.resize((self.width, self.height)) + + model_api_success_count.inc() + return image.resize((self.width, self.height)) except Exception as e: logger.error(f"Error generating image: {e}") model_api_failure_count.inc() diff --git a/packages/sunagent-app/src/sunagent_app/agents/_storm_agent.py b/packages/sunagent-app/src/sunagent_app/agents/_storm_agent.py index 139a1b8..8986c82 100644 --- a/packages/sunagent-app/src/sunagent_app/agents/_storm_agent.py +++ b/packages/sunagent-app/src/sunagent_app/agents/_storm_agent.py @@ -1,8 +1,8 @@ import asyncio import logging -import traceback import os import tempfile +import traceback from typing import Any, Optional, Sequence, Union from autogen_agentchat.agents import BaseChatAgent @@ -169,7 +169,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: def _create_error_response(self) -> Response: return Response( chat_message=TextMessage( - content=f"system internal error, EARLY_TERMINATE", + content="system internal error, EARLY_TERMINATE", source=self.name, ) )