@@ -110,7 +110,7 @@ def _run_input_encoder_local_i2i(self):
110110 input_image = [input_image ]
111111
112112 condition_images = []
113- for img in input_image :
113+ for index , img in enumerate ( input_image ) :
114114 image_processor .check_image_input (img )
115115 image_width , image_height = img .size
116116 if image_width * image_height > 1024 * 1024 :
@@ -122,9 +122,8 @@ def _run_input_encoder_local_i2i(self):
122122 image_height = (image_height // multiple_of ) * multiple_of
123123 img = image_processor .preprocess (img , height = image_height , width = image_width , resize_mode = "crop" )
124124 condition_images .append (img .to (AI_DEVICE ))
125- if not hasattr (self .input_info , "auto_width" ):
126- self .input_info .auto_width = image_width
127- self .input_info .auto_height = image_height
125+ if index == 0 :
126+ self .input_info .target_shape = (image_height , image_width )
128127
129128 torch .cuda .empty_cache ()
130129 gc .collect ()
@@ -254,26 +253,18 @@ def get_custom_shape(self):
254253 return (width , height )
255254
256255 def set_target_shape (self ):
257- multiple_of = self .config .get ("vae_scale_factor" , 8 ) * 2
258-
259256 task = self .config .get ("task" , "t2i" )
260- if task == "i2i" and hasattr (self .input_info , "auto_width" ):
261- width = self .input_info .auto_width
262- height = self .input_info .auto_height
263- else :
257+ if task == "i2i" : # for i2i task, the target shape is already set in _run_input_encoder_local_i2i
258+ height , width = self .input_info .target_shape
259+ else : # for t2i task, calculate the target shape based on the resolution
264260 custom_shape = self .get_custom_shape ()
265261 if custom_shape is not None :
266262 width , height = custom_shape
267263 else :
268264 calculated_width , calculated_height , _ = calculate_dimensions (self .resolution * self .resolution , 16 / 9 )
269265 width = calculated_width // multiple_of * multiple_of
270266 height = calculated_height // multiple_of * multiple_of
271-
272- self .input_info .auto_width = width
273- self .input_info .auto_height = height
274-
275- self .input_info .target_shape = (height , width )
276- logger .info (f"Flux2Klein Image Runner set target shape: { width } x{ height } " )
267+ self .input_info .target_shape = (height , width )
277268
278269 multiple_of = self .config .get ("vae_scale_factor" , 8 ) * 2
279270
0 commit comments