Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/sunagent-app/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "sunagent-app"
version = "0.0.27b8"
version = "0.0.27b9"
license = { file = "LICENSE-CODE" }
description = "sunagent app"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion packages/sunagent-app/src/sunagent_app/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from ._context_builder_agent import ContextBuilderAgent, MentionStream
from ._image_generate_agent import ImageGenerateAgent
from ._steemit_context_builder_agent import SteemContextBuilder
from ._storm_agent import StormAgent, StormConfig
from ._token_launch_agent import TokenLaunchAgent
from ._tweet_analysis_agent import TweetAnalysisAgent
from ._tweet_check_agent import TweetCheckAgent
from ._storm_agent import StormAgent, StormConfig

__all__ = [
"ContextBuilderAgent",
Expand Down
102 changes: 87 additions & 15 deletions packages/sunagent-app/src/sunagent_app/agents/_image_generate_agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import base64
import json
import logging
import random
import traceback
from io import BytesIO
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Sequence,
cast,
)

from autogen_agentchat.agents import BaseChatAgent
Expand All @@ -22,7 +27,7 @@
UserMessage,
)
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
from google.genai import Client, types
from google.genai import types
from PIL import Image as PILImage

from sunagent_app._constants import LOGGER_NAME
Expand All @@ -32,6 +37,34 @@

logger = logging.getLogger(LOGGER_NAME)

# Allowed sizes for the Azure OpenAI images.generate endpoint. Downstream code
# builds a "{width}x{height}" string and falls back to "1024x1024" when it is
# not in this set, so keep this Literal and that fallback in sync.
_OpenAIImageSize = Literal[
    "auto",
    "1024x1024",
    "1536x1024",
    "1024x1536",
    "256x256",
    "512x512",
    "1792x1024",
    "1024x1792",
]


class _OpenAIImagesClient(Protocol):
    """Structural (duck-typed) interface for an OpenAI-style ``images`` namespace exposing ``generate``."""

    def generate(self, model: str, prompt: str, *, size: _OpenAIImageSize | None = None) -> Any: ...


class _OpenAIClient(Protocol):
    """Structural interface for an OpenAI-style client: any object with an ``images`` attribute."""

    images: _OpenAIImagesClient


class _GoogleModelsClient(Protocol):
    """Structural interface for a google-genai-style ``models`` namespace exposing ``generate_images``."""

    def generate_images(self, *, model: str, prompt: str, config: Any) -> Any: ...


class _GoogleClient(Protocol):
    """Structural interface for a google-genai-style client: any object with a ``models`` attribute."""

    models: _GoogleModelsClient


class ImageGenerateAgent(BaseChatAgent):
"""An agent that generate an image based on the description in the tweet.
Expand All @@ -42,7 +75,7 @@ def __init__(
self,
name: str,
text_model_client: AzureOpenAIChatCompletionClient,
image_model_client: Client,
image_model_client: _GoogleClient | _OpenAIClient,
*,
description: str = """
An agent that extract image attachment of given tweet, or generate an image according to the image description in the tweet.
Expand All @@ -52,6 +85,7 @@ def __init__(
system_message: str,
image_styles: List[str],
image_model_name: str = "imagen-3.0-generate-002",
image_provider: Literal["google", "openai"] = "google",
image_path: str = "generated_image.png",
width: int = 400,
height: int = 400,
Expand All @@ -60,6 +94,7 @@ def __init__(
self.system_message = system_message
self.image_styles = image_styles
self._image_model_name = image_model_name
self._image_provider = image_provider
self._image_path = image_path
self.image_model_client = image_model_client
self.text_model_client = text_model_client
Expand Down Expand Up @@ -138,21 +173,58 @@ async def _generate_image(self, image_prompt: str) -> Optional[PILImage.Image]:
"""Generate an image based on the image prompt."""
try:
logger.info(f"Generating image with prompt: {image_prompt}")
response = self.image_model_client.models.generate_images(
model=self._image_model_name,
prompt=image_prompt,
config=types.GenerateImagesConfig(number_of_images=1),
)
model_api_success_count.inc()
if (
response.generated_images is not None
and len(response.generated_images) > 0
and response.generated_images[0].image is not None
and response.generated_images[0].image.image_bytes is not None
):
if self._image_provider == "openai":
image_client = cast(_OpenAIClient, self.image_model_client)
size_str = f"{self.width}x{self.height}"
allowed_sizes: set[_OpenAIImageSize] = {
"auto",
"1024x1024",
"1536x1024",
"1024x1536",
"256x256",
"512x512",
"1792x1024",
"1024x1792",
}
size_literal: _OpenAIImageSize = (
cast(_OpenAIImageSize, size_str) if size_str in allowed_sizes else "1024x1024"
)
response = image_client.images.generate(
model=self._image_model_name,
prompt=image_prompt,
size=size_literal,
)
data = getattr(response, "data", None)
if not data or not isinstance(data, list):
logger.error("OpenAI image response missing data")
return None
first = data[0]
b64_json = getattr(first, "b64_json", None)
if not isinstance(b64_json, str):
logger.error("OpenAI image response missing b64_json")
return None
raw_image = base64.b64decode(b64_json)
image = PILImage.open(BytesIO(raw_image), formats=["PNG", "JPEG"])
else:
google_client = cast(_GoogleClient, self.image_model_client)
response = google_client.models.generate_images(
model=self._image_model_name,
prompt=image_prompt,
config=types.GenerateImagesConfig(number_of_images=1),
)
if (
response.generated_images is None
or len(response.generated_images) == 0
or response.generated_images[0].image is None
or response.generated_images[0].image.image_bytes is None
):
logger.error("Failed to generate image")
return None
raw_image = response.generated_images[0].image.image_bytes
image = PILImage.open(BytesIO(raw_image), formats=["PNG"])
return image.resize((self.width, self.height))

model_api_success_count.inc()
return image.resize((self.width, self.height))
except Exception as e:
logger.error(f"Error generating image: {e}")
model_api_failure_count.inc()
Expand Down
4 changes: 2 additions & 2 deletions packages/sunagent-app/src/sunagent_app/agents/_storm_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import logging
import traceback
import os
import tempfile
import traceback
from typing import Any, Optional, Sequence, Union

from autogen_agentchat.agents import BaseChatAgent
Expand Down Expand Up @@ -169,7 +169,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
def _create_error_response(self) -> Response:
    """Build the agent's generic failure reply carrying the EARLY_TERMINATE marker."""
    error_message = TextMessage(
        content="system internal error, EARLY_TERMINATE",
        source=self.name,
    )
    return Response(chat_message=error_message)
Expand Down
Loading