
Commit 09f01d3

update zoe qwen-image (#1012)
1 parent 6db002f commit 09f01d3

5 files changed

Lines changed: 137 additions & 25 deletions

File tree

configs/qwen_image/qwen_image_t2i_2512_distill_zoe.json
configs/qwen_image/qwen_image_t2i_2512_distill_zoe_fp8.json
lightx2v/models/schedulers/qwen_image/scheduler.py
two new run scripts (below)

configs/qwen_image/qwen_image_t2i_2512_distill_zoe.json (new file)

Lines changed: 11 additions & 0 deletions

{
    "infer_steps": 4,
    "max_custom_size": 4096,
    "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 34,
    "attn_type": "sage_attn2",
    "enable_cfg": false,
    "dit_original_ckpt": "/data/nvme1/yongyang/ccc/models/distill_zoe_diff_qwen_image_data_680w_neo_prompt_res2k_3kiter_multi_large_char_200iter_step4.safetensors",
    "sample_shift": 5.0,
    "zoe_style_noise": true
}
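The last two keys are what this commit is about: "zoe_style_noise" switches latent initialization to Zoe's packed sampling order, and "sample_shift" selects a fixed linear time shift (see the scheduler diff below). As a rough sketch, not part of the commit, the 4-step schedule this config yields can be computed directly; the 1000 train timesteps are an assumption matching the usual flow-matching default:

import numpy as np

# Values from this config; formula from time_shift_linear in the scheduler diff below.
infer_steps, sample_shift = 4, 5.0
sigmas = np.linspace(1.0, 1 / infer_steps, infer_steps)         # [1.0, 0.75, 0.5, 0.25]
shifted = sample_shift / (sample_shift + (1.0 / sigmas - 1.0))  # [1.0, 0.9375, 0.8333, 0.625]
print(shifted * 1000)  # timesteps, assuming num_train_timesteps = 1000: [1000, 937.5, 833.3, 625]

With shift 5.0 the four steps stay concentrated near the high-noise end of the trajectory, which is typical for few-step distilled models.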
configs/qwen_image/qwen_image_t2i_2512_distill_zoe_fp8.json (new file)

Lines changed: 13 additions & 0 deletions

{
    "infer_steps": 4,
    "max_custom_size": 4096,
    "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 34,
    "attn_type": "sage_attn2",
    "enable_cfg": false,
    "dit_quantized": true,
    "dit_quantized_ckpt": "/data/nvme1/yongyang/ccc/models/distill_zoe_diff_qwen_image_data_680w_neo_prompt_res2k_3kiter_multi_large_char_200iter_step4_fp8_mix.safetensors",
    "dit_quant_scheme": "fp8-sgl",
    "sample_shift": 5.0,
    "zoe_style_noise": true
}
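The FP8 variant is otherwise identical to the BF16 config; only the checkpoint-related keys differ. A quick hedged check, assuming the two files live at the repo-relative paths the run scripts below pass via --config_json:

import json

with open("configs/qwen_image/qwen_image_t2i_2512_distill_zoe.json") as f:
    base = json.load(f)
with open("configs/qwen_image/qwen_image_t2i_2512_distill_zoe_fp8.json") as f:
    fp8 = json.load(f)

# Symmetric difference of the key sets: only the checkpoint keys differ.
print(sorted(base.keys() ^ fp8.keys()))
# ['dit_original_ckpt', 'dit_quant_scheme', 'dit_quantized', 'dit_quantized_ckpt']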

lightx2v/models/schedulers/qwen_image/scheduler.py

Lines changed: 71 additions & 25 deletions
@@ -37,6 +37,11 @@ def calculate_shift(
     return mu
 
 
+def time_shift_linear(mu: float, t: torch.Tensor) -> torch.Tensor:
+    """Linear time shift: mu / (mu + (1/t - 1)), matching zoe-diffusion's implementation."""
+    return mu / (mu + (1.0 / t - 1.0))
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
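A side note on the new helper, not part of the diff: mu / (mu + (1/t - 1)) rearranges to mu*t / (1 + (mu - 1)*t), which, as far as the algebra goes, is the static shift formula that flow-matching schedulers apply when dynamic shifting is off, with shift = mu. The config's "sample_shift": 5.0 therefore plays the role of the familiar shift parameter, applied with a fixed value instead of a resolution-dependent one. A minimal check:

import torch

def time_shift_linear(mu: float, t: torch.Tensor) -> torch.Tensor:
    return mu / (mu + (1.0 / t - 1.0))

t = torch.linspace(0.05, 1.0, 20)
static_shift = 5.0 * t / (1 + (5.0 - 1.0) * t)  # shift-style form with shift = 5.0
assert torch.allclose(time_shift_linear(5.0, t), static_shift)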
@@ -428,7 +433,7 @@ def __init__(self, config):
         with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
             self.scheduler_config = json.load(f)
         self.dtype = torch.bfloat16
-        self.sample_guide_scale = self.config["sample_guide_scale"]
+        self.sample_guide_scale = self.config.get("sample_guide_scale", None)
         self.zero_cond_t = config.get("zero_cond_t", False)
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
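This makes "sample_guide_scale" optional: the new distill configs set "enable_cfg": false and omit the key entirely, which the old indexed lookup could not tolerate.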
@@ -480,43 +485,84 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
 
         return latent_image_ids.to(device=device, dtype=dtype)
 
+    def _prepare_latents_lightx2v(self, shape, height, width, num_channels_latents):
+        """Original LightX2V latent generation: noise in [B, T, C, H, W] then pack."""
+        latents = randn_tensor(shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)
+        if self.is_layered:
+            latents = self._pack_latents(latents, 1, num_channels_latents, height, width, self.layers + 1)
+        else:
+            latents = self._pack_latents(latents, 1, num_channels_latents, height, width)
+        return latents
+
+    def _prepare_latents_zoe(self, shape, height, width, num_channels_latents):
+        """Zoe-aligned latent generation: noise in packed format [B, C*4, T, H//2, W//2].
+        Ensures the same random sampling order as Zoe for bit-exact alignment.
+        """
+        b, t = shape[0], shape[1]
+        zoe_shape = (b, num_channels_latents * 4, t, height // 2, width // 2)
+        latents = randn_tensor(zoe_shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)
+        # Convert to LightX2V sequence format: [B, (H//2)*(W//2), C*4]
+        latents = latents.squeeze(2)  # [B, C*4, H//2, W//2]
+        latents = latents.permute(0, 2, 3, 1)  # [B, H//2, W//2, C*4]
+        latents = latents.reshape(b, (height // 2) * (width // 2), num_channels_latents * 4)
+        return latents
+
     def prepare_latents(self, input_info):
         self.input_info = input_info
         shape = input_info.target_shape
+        # shape: [B, T, C, H, W]
         width, height = shape[-1], shape[-2]
-        latents = randn_tensor(shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)
-        if self.is_layered:
-            latents = self._pack_latents(latents, 1, self.config.get("num_channels_latents", 16), height, width, self.layers + 1)
+        num_channels_latents = self.config.get("num_channels_latents", 16)
+
+        if self.config.get("zoe_style_noise", False) and not self.is_layered:
+            latents = self._prepare_latents_zoe(shape, height, width, num_channels_latents)
         else:
-            latents = self._pack_latents(latents, 1, self.config.get("num_channels_latents", 16), height, width)
+            latents = self._prepare_latents_lightx2v(shape, height, width, num_channels_latents)
+
         latent_image_ids = self._prepare_latent_image_ids(1, height // 2, width // 2, AI_DEVICE, self.dtype)
         self.latents = latents
         self.latent_image_ids = latent_image_ids
         self.noise_pred = None
 
     def set_timesteps(self):
-        sigmas = np.linspace(1.0, 1 / self.config["infer_steps"], self.config["infer_steps"])
-        image_seq_len = self.latents.shape[1]
-        if self.is_layered:
-            base_seqlen = 256 * 256 / 16 / 16
-            image_seq_len = self.latents.shape[1] // 5
-            mu = (image_seq_len / base_seqlen) ** 0.5
+        num_inference_steps = self.config["infer_steps"]
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+        sample_shift = self.config.get("sample_shift", None)
+        if sample_shift is not None:
+            # Zoe-style: linear time shift with a fixed mu, resolution-independent.
+            # Formula: t_shifted = mu / (mu + (1/t - 1))
+            sigmas_tensor = torch.from_numpy(sigmas).float().to(AI_DEVICE)
+            sigmas_shifted = time_shift_linear(mu=sample_shift, t=sigmas_tensor)
+            sigmas_shifted = torch.cat([sigmas_shifted, torch.zeros(1, device=AI_DEVICE)])
+            self.scheduler.sigmas = sigmas_shifted.to(dtype=torch.float32, device=AI_DEVICE)
+            self.scheduler.timesteps = sigmas_shifted[:-1] * self.scheduler_config["num_train_timesteps"]
+            self.scheduler.timesteps = self.scheduler.timesteps.to(AI_DEVICE)
+            self.scheduler._step_index = None
+            self.scheduler._begin_index = None
+            timesteps = self.scheduler.timesteps
         else:
-            mu = calculate_shift(
-                image_seq_len,
-                self.scheduler_config.get("base_image_seq_len", 256),
-                self.scheduler_config.get("max_image_seq_len", 4096),
-                self.scheduler_config.get("base_shift", 0.5),
-                self.scheduler_config.get("max_shift", 1.15),
+            # Original: resolution-adaptive exponential shift via diffusers.
+            image_seq_len = self.latents.shape[1]
+            if self.is_layered:
+                base_seqlen = 256 * 256 / 16 / 16
+                image_seq_len = self.latents.shape[1] // 5
+                mu = (image_seq_len / base_seqlen) ** 0.5
+            else:
+                mu = calculate_shift(
+                    image_seq_len,
+                    self.scheduler_config.get("base_image_seq_len", 256),
+                    self.scheduler_config.get("max_image_seq_len", 4096),
+                    self.scheduler_config.get("base_shift", 0.5),
+                    self.scheduler_config.get("max_shift", 1.15),
+                )
+            timesteps, num_inference_steps = retrieve_timesteps(
+                self.scheduler,
+                num_inference_steps,
+                AI_DEVICE,
+                sigmas=sigmas,
+                mu=mu,
             )
-        num_inference_steps = self.config["infer_steps"]
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            AI_DEVICE,
-            sigmas=sigmas,
-            mu=mu,
-        )
 
         self.timesteps = timesteps
         self.infer_steps = num_inference_steps
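Why a separate Zoe path at all: with a fixed seed, the noise values are drawn sequentially and laid out in the requested shape, so sampling directly in Zoe's packed shape [B, C*4, T, H//2, W//2] reproduces Zoe's exact noise tensor, whereas sampling in [B, T, C, H, W] and packing afterwards scatters the same draws to different positions. A standalone sketch of the shape conversion (hypothetical sizes; torch.randn standing in for diffusers' randn_tensor):

import torch

b, t, c = 1, 1, 16                     # C*4 = 64 packed channels
h, w = 192, 344                        # hypothetical latent height/width

zoe_shape = (b, c * 4, t, h // 2, w // 2)
latents = torch.randn(zoe_shape)       # same draw order as zoe-diffusion

latents = latents.squeeze(2)           # [B, C*4, H//2, W//2]
latents = latents.permute(0, 2, 3, 1)  # [B, H//2, W//2, C*4]
latents = latents.reshape(b, (h // 2) * (w // 2), c * 4)

print(latents.shape)                   # torch.Size([1, 16512, 64])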
New run script (BF16 distill config)

Lines changed: 21 additions & 0 deletions

#!/bin/bash

# set paths first
lightx2v_path=/data/nvme1/yongyang/ccc/LightX2V
model_path=/data/nvme1/models/Qwen/Qwen-Image-2512

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
    --model_cls qwen_image \
    --task t2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i_2512_distill_zoe.json \
    --prompt '2K超高清画质,16:9宽屏比例,电影级渲染。一个精致的咖啡店门口场景,温馨的街道氛围。门口摆放着一个复古风格的木质黑板,黑板上用粉笔字体写着"日日新咖啡,2美元一杯",笔触温馨可爱。旁边有一个闪烁的霓虹灯招牌,红色霓虹灯管拼出"商汤科技"字样,现代科技感。旁边立着一幅精美的海报,海报上是一位优雅的中国美女模特,海报下方用时尚字体写着"SenseNova newbee"。整体氛围是东西方文化交融的现代咖啡馆,暖色调灯光,傍晚时分,细节精致,高质量渲染' \
    --negative_prompt " " \
    --save_result_path ${lightx2v_path}/save_results/qwen_image_t2i_2512_distill_zoe.png \
    --seed 42 \
    --target_shape 1536 2752
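For reference, the packed sequence length the scheduler would see at this resolution, assuming an 8x VAE downsample ahead of the 2x2 packing shown in the scheduler diff (both factors are assumptions here, not stated by the commit):

height, width = 1536, 2752              # --target_shape, roughly the 16:9 frame the prompt asks for
lat_h, lat_w = height // 8, width // 8  # assumed 8x VAE downsample -> 192 x 344
seq_len = (lat_h // 2) * (lat_w // 2)   # 2x2 packing -> 96 * 172
print(seq_len)                          # 16512 tokens

With sample_shift set, this length no longer affects the timestep schedule; it only matters for the resolution-adaptive fallback branch.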
New run script (FP8 distill config)

Lines changed: 21 additions & 0 deletions

#!/bin/bash

# set paths first
lightx2v_path=/data/nvme1/yongyang/ccc/LightX2V
model_path=/data/nvme1/models/Qwen/Qwen-Image-2512

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
    --model_cls qwen_image \
    --task t2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i_2512_distill_zoe_fp8.json \
    --prompt '2K超高清画质,16:9宽屏比例,电影级渲染。一个精致的咖啡店门口场景,温馨的街道氛围。门口摆放着一个复古风格的木质黑板,黑板上用粉笔字体写着"日日新咖啡,2美元一杯",笔触温馨可爱。旁边有一个闪烁的霓虹灯招牌,红色霓虹灯管拼出"商汤科技"字样,现代科技感。旁边立着一幅精美的海报,海报上是一位优雅的中国美女模特,海报下方用时尚字体写着"SenseNova newbee"。整体氛围是东西方文化交融的现代咖啡馆,暖色调灯光,傍晚时分,细节精致,高质量渲染' \
    --negative_prompt " " \
    --save_result_path ${lightx2v_path}/save_results/qwen_image_t2i_2512_distill_zoe_fp8.png \
    --seed 42 \
    --target_shape 1536 2752
