|
26 | 26 | from opentelemetry import trace |
27 | 27 | from opentelemetry.trace import Span |
28 | 28 | from volcenginesdkarkruntime import Ark |
29 | | -from volcenginesdkarkruntime.types.images.images import SequentialImageGenerationOptions |
| 29 | +from volcenginesdkarkruntime.types.images.images import ( |
| 30 | + SequentialImageGenerationOptions, |
| 31 | +) |
30 | 32 |
|
31 | 33 | from veadk.config import getenv, settings |
32 | 34 | from veadk.consts import ( |
|
41 | 43 |
|
42 | 44 | client = Ark( |
43 | 45 | api_key=getenv( |
44 | | - "MODEL_IMAGE_API_KEY", getenv("MODEL_AGENT_API_KEY", settings.model.api_key) |
| 46 | + "MODEL_IMAGE_API_KEY", |
| 47 | + getenv("MODEL_AGENT_API_KEY", settings.model.api_key), |
45 | 48 | ), |
46 | 49 | base_url=getenv("MODEL_IMAGE_API_BASE", DEFAULT_IMAGE_GENERATE_MODEL_API_BASE), |
47 | 50 | ) |
@@ -119,11 +122,24 @@ def handle_single_task_sync( |
119 | 122 | and sequential_image_generation == "auto" |
120 | 123 | and max_images |
121 | 124 | ): |
122 | | - response = client.images.generate( |
123 | | - model=getenv("MODEL_IMAGE_NAME", DEFAULT_IMAGE_GENERATE_MODEL_NAME), |
124 | | - **inputs, |
125 | | - sequential_image_generation_options=SequentialImageGenerationOptions( |
126 | | - max_images=max_images |
| 125 | + response = ( |
| 126 | + client.images.generate( |
| 127 | + model=getenv( |
| 128 | + "MODEL_IMAGE_NAME", |
| 129 | + DEFAULT_IMAGE_GENERATE_MODEL_NAME, |
| 130 | + ), |
| 131 | + **inputs, |
| 132 | + sequential_image_generation_options=SequentialImageGenerationOptions( |
| 133 | + max_images=max_images |
| 134 | + ), |
| 135 | + extra_headers={ |
| 136 | + "veadk-source": "veadk", |
| 137 | + "veadk-version": VERSION, |
| 138 | + "User-Agent": f"VeADK/{VERSION}", |
| 139 | + "X-Client-Request-Id": getenv( |
| 140 | + "MODEL_AGENT_CLIENT_REQ_ID", f"veadk/{VERSION}" |
| 141 | + ), |
| 142 | + }, |
127 | 143 | ), |
128 | 144 | ) |
129 | 145 | else: |
@@ -157,7 +173,8 @@ def handle_single_task_sync( |
157 | 173 | continue |
158 | 174 | image_bytes = base64.b64decode(b64) |
159 | 175 | image_url = _upload_image_to_tos( |
160 | | - image_bytes=image_bytes, object_key=f"{image_name}.png" |
| 176 | + image_bytes=image_bytes, |
| 177 | + object_key=f"{image_name}.png", |
161 | 178 | ) |
162 | 179 | if not image_url: |
163 | 180 | logger.error(f"Upload image to TOS failed: {image_name}") |
@@ -367,7 +384,11 @@ def make_task(idx, item): |
367 | 384 | logger.debug( |
368 | 385 | f"image_generate success_list: {success_list}\nerror_list: {error_list}" |
369 | 386 | ) |
370 | | - return {"status": "success", "success_list": success_list, "error_list": error_list} |
| 387 | + return { |
| 388 | + "status": "success", |
| 389 | + "success_list": success_list, |
| 390 | + "error_list": error_list, |
| 391 | + } |
371 | 392 |
|
372 | 393 |
|
373 | 394 | def add_span_attributes( |
|
0 commit comments