@@ -391,7 +391,7 @@ def read_image_input(self, img_path):
391391 latent_h = patched_h * self .config ["patch_size" ][1 ]
392392 latent_w = patched_w * self .config ["patch_size" ][2 ]
393393
394- if hasattr (self .input_info , "target_video_length" ):
394+ if hasattr (self .input_info , "target_video_length" ) and self . input_info . target_video_length is not None :
395395 target_video_length = self .input_info .target_video_length
396396 latent_shape = self .get_latent_shape_with_lat_hw (latent_h , latent_w , target_video_length )
397397 else :
@@ -510,7 +510,7 @@ def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_le
510510 """Prepare previous latents for conditioning"""
511511 dtype = GET_DTYPE ()
512512 tgt_h , tgt_w = self .input_info .target_shape [0 ], self .input_info .target_shape [1 ]
513- if hasattr (self .input_info , "target_video_length" ):
513+ if hasattr (self .input_info , "target_video_length" ) and self . input_info . target_video_length is not None :
514514 target_video_length = self .input_info .target_video_length
515515 else :
516516 target_video_length = self .config ["target_video_length" ]
@@ -836,7 +836,6 @@ def load_model(self):
836836
837837 def get_latent_shape_with_lat_hw (self , latent_h , latent_w , target_video_length = None ):
838838 target_video_length = target_video_length if target_video_length is not None else self .config ["target_video_length" ]
839- self .input_info .latent_freams = latent_h
840839 latent_shape = [
841840 self .config .get ("num_channels_latents" , 16 ),
842841 (target_video_length - 1 ) // self .config ["vae_stride" ][0 ] + 1 ,
0 commit comments