Skip to content

Commit a520cb5

Browse files
author
Antigravity
committed
[fix]:flux2 i2i auto target shape
1 parent 5230943 commit a520cb5

1 file changed

Lines changed: 7 additions & 16 deletions

File tree

lightx2v/models/runners/flux2_klein/flux2_klein_runner.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _run_input_encoder_local_i2i(self):
110110
input_image = [input_image]
111111

112112
condition_images = []
113-
for img in input_image:
113+
for index, img in enumerate(input_image):
114114
image_processor.check_image_input(img)
115115
image_width, image_height = img.size
116116
if image_width * image_height > 1024 * 1024:
@@ -122,9 +122,8 @@ def _run_input_encoder_local_i2i(self):
122122
image_height = (image_height // multiple_of) * multiple_of
123123
img = image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
124124
condition_images.append(img.to(AI_DEVICE))
125-
if not hasattr(self.input_info, "auto_width"):
126-
self.input_info.auto_width = image_width
127-
self.input_info.auto_height = image_height
125+
if index == 0:
126+
self.input_info.target_shape = (image_height, image_width)
128127

129128
torch.cuda.empty_cache()
130129
gc.collect()
@@ -254,26 +253,18 @@ def get_custom_shape(self):
254253
return (width, height)
255254

256255
def set_target_shape(self):
257-
multiple_of = self.config.get("vae_scale_factor", 8) * 2
258-
259256
task = self.config.get("task", "t2i")
260-
if task == "i2i" and hasattr(self.input_info, "auto_width"):
261-
width = self.input_info.auto_width
262-
height = self.input_info.auto_height
263-
else:
257+
if task == "i2i": # for i2i task, the target shape is already set in _run_input_encoder_local_i2i
258+
height, width = self.input_info.target_shape
259+
else: # for t2i task, calculate the target shape based on the resolution
264260
custom_shape = self.get_custom_shape()
265261
if custom_shape is not None:
266262
width, height = custom_shape
267263
else:
268264
calculated_width, calculated_height, _ = calculate_dimensions(self.resolution * self.resolution, 16 / 9)
269265
width = calculated_width // multiple_of * multiple_of
270266
height = calculated_height // multiple_of * multiple_of
271-
272-
self.input_info.auto_width = width
273-
self.input_info.auto_height = height
274-
275-
self.input_info.target_shape = (height, width)
276-
logger.info(f"Flux2Klein Image Runner set target shape: {width}x{height}")
267+
self.input_info.target_shape = (height, width)
277268

278269
multiple_of = self.config.get("vae_scale_factor", 8) * 2
279270

0 commit comments

Comments
 (0)