@@ -197,3 +197,90 @@ def infer(self, inputs):
197197 elif self .offload_granularity != "model" :
198198 self .pre_weight .to_cpu ()
199199 self .transformer_weights .non_block_weights_to_cpu ()
200+
@torch.no_grad()
def infer_tensor_once(self, latents, timestep, context, context_null=None):
    """
    Run one WAN forward pass from explicit tensors.

    Inputs are validated and normalised before any weights are moved, so a
    bad argument never leaves offloaded weights resident on the accelerator.

    Args:
        latents: noisy latents, shape [C,F,H,W] or [1,F,C,H,W].
        timestep: timestep tensor (scalar / [1] / [1,F]); only the first
            value is used, cast to int64.
        context: conditional text embeddings, shape [L,D] or [1,L,D].
        context_null: optional unconditional text embeddings, shape [L,D]
            or [1,L,D]. Falls back to ``context`` when omitted.
    Returns:
        Tuple ``(noise_pred, pred_x0)``:
            noise_pred: model prediction (CFG-combined when the config
                enables ``enable_cfg``); expected shape [C,F,H,W].
            pred_x0: ``latents - sigma_t * prediction`` where ``sigma_t`` is
                the scheduler sigma whose timestep is nearest to ``timestep``.
    Raises:
        ValueError: if latents / context / context_null have an unsupported
            rank, 5D latents have batch size != 1, or timestep is empty.
    """
    # ---- Validate and normalise inputs (no model state touched yet) ----
    if latents.ndim == 5:
        # [B,F,C,H,W] -> [C,F,H,W], only batch size 1 is supported.
        if latents.shape[0] != 1:
            raise ValueError(f"Expected batch size 1 for 5D latents, got shape {tuple(latents.shape)}")
        latents = latents.squeeze(0).permute(1, 0, 2, 3).contiguous()
    elif latents.ndim != 4:
        raise ValueError(f"Expected latents ndim in [4,5], got {latents.ndim}")

    if context.ndim == 2:
        context = context.unsqueeze(0)
    if context.ndim != 3:
        raise ValueError(f"Expected context ndim in [2,3], got {context.ndim}")

    if context_null is None:
        context_null = context
    else:
        if context_null.ndim == 2:
            context_null = context_null.unsqueeze(0)
        # Mirror the rank check applied to `context`.
        if context_null.ndim != 3:
            raise ValueError(f"Expected context_null ndim in [2,3], got {context_null.ndim}")

    timestep = timestep.flatten()
    if timestep.numel() == 0:
        raise ValueError("Empty timestep tensor")
    timestep = timestep[:1].to(torch.int64).contiguous()

    def _convert_flow_pred_to_x0(flow_pred, xt, timestep_tensor):
        """Convert a flow prediction into an x0 estimate: x0 = xt - sigma_t * flow."""
        original_dtype = flow_pred.dtype
        device = flow_pred.device
        # Everything, including the timestep, must live on the prediction's
        # device; the caller's timestep may still be on the CPU.
        flow_pred, xt, sigmas, timesteps, timestep_tensor = (
            t.to(device=device, dtype=torch.double)
            for t in (flow_pred, xt, self.scheduler.sigmas, self.scheduler.timesteps, timestep_tensor)
        )
        # Pick the scheduler step whose timestep is closest to the requested one.
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep_tensor.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
        x0_pred = xt - sigma_t * flow_pred
        return x0_pred.to(original_dtype)

    # ---- Decide which weights (if any) have to be shuttled to the device ----
    move_whole_model = False
    move_non_block_weights = False
    if self.cpu_offload:
        if self.offload_granularity == "model":
            move_whole_model = "wan2.2_moe" not in self.config["model_cls"]
        else:
            move_non_block_weights = True

    if move_whole_model:
        self.to_cuda()
    elif move_non_block_weights:
        self.pre_weight.to_cuda()
        self.transformer_weights.non_block_weights_to_cuda()

    try:
        # NOTE(review): prepare() is called with a placeholder latent shape,
        # presumably only to initialise scheduler state such as sigmas and
        # timesteps; the latents are overwritten immediately below.
        self.scheduler.prepare(seed=0, latent_shape=[1, 1, 1, 1], image_encoder_output={})
        self.scheduler.latents = latents.to(AI_DEVICE)
        self.scheduler.timestep_input = timestep.to(AI_DEVICE)

        inputs = {
            "text_encoder_output": {
                "context": context.to(AI_DEVICE),
                "context_null": context_null.to(AI_DEVICE),
            },
            "image_encoder_output": {},
        }

        if self.config.get("enable_cfg", False):
            noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
            noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)
            pred_x0_cond = _convert_flow_pred_to_x0(noise_pred_cond, self.scheduler.latents, timestep)
            pred_x0_uncond = _convert_flow_pred_to_x0(noise_pred_uncond, self.scheduler.latents, timestep)
            # Classifier-free guidance applied to both the raw prediction and x0.
            guide_scale = self.scheduler.sample_guide_scale
            noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
            pred_x0 = pred_x0_uncond + guide_scale * (pred_x0_cond - pred_x0_uncond)
        else:
            noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
            pred_x0 = _convert_flow_pred_to_x0(noise_pred, self.scheduler.latents, timestep)
    finally:
        # Always hand the weights back to the CPU, even if inference failed.
        if move_whole_model:
            self.to_cpu()
        elif move_non_block_weights:
            self.pre_weight.to_cpu()
            self.transformer_weights.non_block_weights_to_cpu()

    return noise_pred, pred_x0
0 commit comments