@@ -37,6 +37,11 @@ def calculate_shift(
3737 return mu
3838
3939
def time_shift_linear(mu: float, t: torch.Tensor) -> torch.Tensor:
    """Apply zoe-diffusion's linear time shift to timesteps *t*.

    Computes ``mu / (mu + (1/t - 1))`` with the exact same floating-point
    operation order as zoe-diffusion, so shifted schedules stay bit-identical.
    """
    denominator = mu + (1.0 / t - 1.0)
    return mu / denominator
44+
4045# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
4146def retrieve_timesteps (
4247 scheduler ,
@@ -428,7 +433,7 @@ def __init__(self, config):
428433 with open (os .path .join (config ["model_path" ], "scheduler" , "scheduler_config.json" ), "r" ) as f :
429434 self .scheduler_config = json .load (f )
430435 self .dtype = torch .bfloat16
431- self .sample_guide_scale = self .config [ "sample_guide_scale" ]
436+ self .sample_guide_scale = self .config . get ( "sample_guide_scale" , None )
432437 self .zero_cond_t = config .get ("zero_cond_t" , False )
433438 if self .config ["seq_parallel" ]:
434439 self .seq_p_group = self .config .get ("device_mesh" ).get_group (mesh_dim = "seq_p" )
@@ -480,43 +485,84 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
480485
481486 return latent_image_ids .to (device = device , dtype = dtype )
482487
488+ def _prepare_latents_lightx2v (self , shape , height , width , num_channels_latents ):
489+ """Original LightX2V latent generation: noise in [B, T, C, H, W] then pack."""
490+ latents = randn_tensor (shape , generator = self .generator , device = AI_DEVICE , dtype = self .dtype )
491+ if self .is_layered :
492+ latents = self ._pack_latents (latents , 1 , num_channels_latents , height , width , self .layers + 1 )
493+ else :
494+ latents = self ._pack_latents (latents , 1 , num_channels_latents , height , width )
495+ return latents
496+
497+ def _prepare_latents_zoe (self , shape , height , width , num_channels_latents ):
498+ """Zoe-aligned latent generation: noise in packed format [B, C*4, T, H//2, W//2].
499+ Ensures the same random sampling order as Zoe for bit-exact alignment.
500+ """
501+ b , t = shape [0 ], shape [1 ]
502+ zoe_shape = (b , num_channels_latents * 4 , t , height // 2 , width // 2 )
503+ latents = randn_tensor (zoe_shape , generator = self .generator , device = AI_DEVICE , dtype = self .dtype )
504+ # Convert to LightX2V sequence format: [B, (H//2)*(W//2), C*4]
505+ latents = latents .squeeze (2 ) # [B, C*4, H//2, W//2]
506+ latents = latents .permute (0 , 2 , 3 , 1 ) # [B, H//2, W//2, C*4]
507+ latents = latents .reshape (b , (height // 2 ) * (width // 2 ), num_channels_latents * 4 )
508+ return latents
509+
483510 def prepare_latents (self , input_info ):
484511 self .input_info = input_info
485512 shape = input_info .target_shape
513+ # shape: [B, T, C, H, W]
486514 width , height = shape [- 1 ], shape [- 2 ]
487- latents = randn_tensor (shape , generator = self .generator , device = AI_DEVICE , dtype = self .dtype )
488- if self .is_layered :
489- latents = self ._pack_latents (latents , 1 , self .config .get ("num_channels_latents" , 16 ), height , width , self .layers + 1 )
515+ num_channels_latents = self .config .get ("num_channels_latents" , 16 )
516+
517+ if self .config .get ("zoe_style_noise" , False ) and not self .is_layered :
518+ latents = self ._prepare_latents_zoe (shape , height , width , num_channels_latents )
490519 else :
491- latents = self ._pack_latents (latents , 1 , self .config .get ("num_channels_latents" , 16 ), height , width )
520+ latents = self ._prepare_latents_lightx2v (shape , height , width , num_channels_latents )
521+
492522 latent_image_ids = self ._prepare_latent_image_ids (1 , height // 2 , width // 2 , AI_DEVICE , self .dtype )
493523 self .latents = latents
494524 self .latent_image_ids = latent_image_ids
495525 self .noise_pred = None
496526
497527 def set_timesteps (self ):
498- sigmas = np .linspace (1.0 , 1 / self .config ["infer_steps" ], self .config ["infer_steps" ])
499- image_seq_len = self .latents .shape [1 ]
500- if self .is_layered :
501- base_seqlen = 256 * 256 / 16 / 16
502- image_seq_len = self .latents .shape [1 ] // 5
503- mu = (image_seq_len / base_seqlen ) ** 0.5
528+ num_inference_steps = self .config ["infer_steps" ]
529+ sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
530+
531+ sample_shift = self .config .get ("sample_shift" , None )
532+ if sample_shift is not None :
533+ # Zoe-style: linear time shift with a fixed mu, resolution-independent.
534+ # Formula: t_shifted = mu / (mu + (1/t - 1))
535+ sigmas_tensor = torch .from_numpy (sigmas ).float ().to (AI_DEVICE )
536+ sigmas_shifted = time_shift_linear (mu = sample_shift , t = sigmas_tensor )
537+ sigmas_shifted = torch .cat ([sigmas_shifted , torch .zeros (1 , device = AI_DEVICE )])
538+ self .scheduler .sigmas = sigmas_shifted .to (dtype = torch .float32 , device = AI_DEVICE )
539+ self .scheduler .timesteps = sigmas_shifted [:- 1 ] * self .scheduler_config ["num_train_timesteps" ]
540+ self .scheduler .timesteps = self .scheduler .timesteps .to (AI_DEVICE )
541+ self .scheduler ._step_index = None
542+ self .scheduler ._begin_index = None
543+ timesteps = self .scheduler .timesteps
504544 else :
505- mu = calculate_shift (
506- image_seq_len ,
507- self .scheduler_config .get ("base_image_seq_len" , 256 ),
508- self .scheduler_config .get ("max_image_seq_len" , 4096 ),
509- self .scheduler_config .get ("base_shift" , 0.5 ),
510- self .scheduler_config .get ("max_shift" , 1.15 ),
545+ # Original: resolution-adaptive exponential shift via diffusers.
546+ image_seq_len = self .latents .shape [1 ]
547+ if self .is_layered :
548+ base_seqlen = 256 * 256 / 16 / 16
549+ image_seq_len = self .latents .shape [1 ] // 5
550+ mu = (image_seq_len / base_seqlen ) ** 0.5
551+ else :
552+ mu = calculate_shift (
553+ image_seq_len ,
554+ self .scheduler_config .get ("base_image_seq_len" , 256 ),
555+ self .scheduler_config .get ("max_image_seq_len" , 4096 ),
556+ self .scheduler_config .get ("base_shift" , 0.5 ),
557+ self .scheduler_config .get ("max_shift" , 1.15 ),
558+ )
559+ timesteps , num_inference_steps = retrieve_timesteps (
560+ self .scheduler ,
561+ num_inference_steps ,
562+ AI_DEVICE ,
563+ sigmas = sigmas ,
564+ mu = mu ,
511565 )
512- num_inference_steps = self .config ["infer_steps" ]
513- timesteps , num_inference_steps = retrieve_timesteps (
514- self .scheduler ,
515- num_inference_steps ,
516- AI_DEVICE ,
517- sigmas = sigmas ,
518- mu = mu ,
519- )
520566
521567 self .timesteps = timesteps
522568 self .infer_steps = num_inference_steps
0 commit comments