Skip to content

Commit f4c5184

Browse files
Support LongCat-Image block-level offload with two offload managers (#977)
1 parent 43687e3 commit f4c5184

7 files changed

Lines changed: 225 additions & 9 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"infer_steps": 50,
3+
"aspect_ratio": "16:9",
4+
"attn_type": "sage_attn2",
5+
"enable_cfg": true,
6+
"sample_guide_scale": 4.0,
7+
"enable_cfg_renorm": true,
8+
"cfg_renorm_min": 0.0,
9+
"axes_dims_rope": [16, 56, 56],
10+
"dit_quant_scheme": "Default",
11+
"rms_norm_type": "sgl-kernel",
12+
"cpu_offload": true,
13+
"offload_granularity": "block"
14+
}

lightx2v/models/networks/longcat_image/infer/offload/__init__.py

Whitespace-only changes.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch
2+
3+
from lightx2v.common.offload.manager import WeightAsyncStreamManager
4+
from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer
5+
from lightx2v_platform.base.global_var import AI_DEVICE
6+
7+
torch_device_module = getattr(torch, AI_DEVICE)
8+
9+
10+
class LongCatImageOffloadTransformerInfer(LongCatImageTransformerInfer):
    """Offload transformer inference for LongCat Image model.

    Supports block-level offload with double-buffer async prefetch for both
    double-stream blocks and single-stream blocks. Two independent
    ``WeightAsyncStreamManager`` instances are kept — one per block family —
    so each phase owns its own compute stream and CUDA weight buffers.
    """

    def __init__(self, config):
        """Configure the offload path.

        Args:
            config: dict-like model config. Keys read here:
                ``cpu_offload`` (bool, default False) and
                ``offload_granularity`` (str, default "block").
        """
        super().__init__(config)
        if self.config.get("cpu_offload", False):
            offload_granularity = self.config.get("offload_granularity", "block")
            if offload_granularity == "block":
                # Route the base class's infer() dispatch to the block-offload path.
                self.infer_func = self.infer_with_blocks_offload
            if offload_granularity != "model":
                # One manager per block family: double-stream and single-stream
                # blocks are processed in two separate phases below.
                self.offload_manager_double = WeightAsyncStreamManager(offload_granularity=offload_granularity)
                self.offload_manager_single = WeightAsyncStreamManager(offload_granularity=offload_granularity)

    def infer_with_blocks_offload(self, blocks, pre_infer_out):
        """Run transformer inference with block-level offload.

        Two-phase approach: first process all double blocks, then all single blocks,
        each with their own offload manager and cuda buffers.

        Args:
            blocks: weight container exposing ``double_blocks`` and
                ``single_blocks`` sequences.
            pre_infer_out: pre-infer output carrying ``hidden_states``,
                ``encoder_hidden_states``, ``temb``, ``image_rotary_emb``,
                and (for I2I) ``input_image_latents`` / ``output_seq_len``.

        Returns:
            The final ``hidden_states`` tensor; for I2I it is truncated to
            the first ``output_seq_len`` rows.
        """
        hidden_states = pre_infer_out.hidden_states
        encoder_hidden_states = pre_infer_out.encoder_hidden_states
        temb = pre_infer_out.temb
        image_rotary_emb = pre_infer_out.image_rotary_emb

        # For I2I task: concatenate output latents with input image latents
        # along the sequence dimension; remember the output length so the
        # conditioning rows can be stripped off again at the end.
        output_seq_len = None
        if pre_infer_out.input_image_latents is not None:
            output_seq_len = pre_infer_out.output_seq_len
            hidden_states = torch.cat([hidden_states, pre_infer_out.input_image_latents], dim=0)

        # Stage 1: double blocks offload
        # wait for default stream so inputs produced there are visible.
        current_stream = torch_device_module.current_stream()
        self.offload_manager_double.compute_stream.wait_stream(current_stream)
        for block_idx in range(len(blocks.double_blocks)):
            self.block_idx = block_idx

            # First iteration only: load block 0's weights into the buffer.
            if self.offload_manager_double.need_init_first_buffer:
                self.offload_manager_double.init_first_buffer(blocks.double_blocks)

            # Prefetch the next block while the current one computes; the
            # modulo wraps the last iteration's prefetch back to block 0.
            self.offload_manager_double.prefetch_weights((block_idx + 1) % len(blocks.double_blocks), blocks.double_blocks)

            with torch_device_module.stream(self.offload_manager_double.compute_stream):
                # cuda_buffers[0] holds the current block's weights; presumably
                # swap_blocks() below rotates the double buffer — confirm
                # against WeightAsyncStreamManager.
                encoder_hidden_states, hidden_states = self.infer_double_stream_block(
                    self.offload_manager_double.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            self.offload_manager_double.swap_blocks()

        # Stage 2: single blocks offload
        # wait for double stream so stage-1 results are visible here.
        self.offload_manager_single.compute_stream.wait_stream(self.offload_manager_double.compute_stream)
        for block_idx in range(len(blocks.single_blocks)):
            self.block_idx = block_idx

            if self.offload_manager_single.need_init_first_buffer:
                self.offload_manager_single.init_first_buffer(blocks.single_blocks)

            self.offload_manager_single.prefetch_weights((block_idx + 1) % len(blocks.single_blocks), blocks.single_blocks)

            with torch_device_module.stream(self.offload_manager_single.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_single_stream_block(
                    self.offload_manager_single.cuda_buffers[0],
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            self.offload_manager_single.swap_blocks()

        # For I2I task: only return output image latents
        if output_seq_len is not None:
            hidden_states = hidden_states[:output_seq_len]

        return hidden_states

lightx2v/models/networks/longcat_image/infer/transformer_infer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, config):
1616
self.config = config
1717
self.infer_conditional = True
1818
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
19-
19+
self.infer_func = self.infer_without_offload
2020
# Sequence parallel settings
2121
if self.config.get("seq_parallel", False):
2222
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
@@ -279,7 +279,7 @@ def infer_single_stream_block(
279279

280280
return encoder_hidden_states, hidden_states
281281

282-
def infer(self, block_weights, pre_infer_out):
282+
def infer_without_offload(self, block_weights, pre_infer_out):
283283
"""Run transformer inference through all blocks.
284284
285285
Args:
@@ -325,3 +325,7 @@ def infer(self, block_weights, pre_infer_out):
325325
hidden_states = hidden_states[:output_seq_len]
326326

327327
return hidden_states
328+
329+
def infer(self, block_weights, pre_infer_out):
    """Run transformer inference via the configured implementation.

    ``self.infer_func`` is chosen at construction time (offload vs.
    non-offload path); this wrapper keeps a stable public entry point.
    """
    return self.infer_func(block_weights, pre_infer_out)

lightx2v/models/networks/longcat_image/model.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.distributed as dist
33

44
from lightx2v.models.networks.base_model import BaseTransformerModel
5+
from lightx2v.models.networks.longcat_image.infer.offload.transformer_infer import LongCatImageOffloadTransformerInfer
56
from lightx2v.models.networks.longcat_image.infer.post_infer import LongCatImagePostInfer
67
from lightx2v.models.networks.longcat_image.infer.pre_infer import LongCatImagePreInfer
78
from lightx2v.models.networks.longcat_image.infer.transformer_infer import LongCatImageTransformerInfer
@@ -23,7 +24,7 @@ class LongCatImageTransformerModel(BaseTransformerModel):
2324
transformer_weight_class = LongCatImageTransformerWeights
2425
post_weight_class = LongCatImagePostWeights
2526

26-
def __init__(self, config, model_path, device):
27+
def __init__(self, model_path, config, device):
2728
super().__init__(model_path, config, device)
2829
# Use transformer_in_channels to avoid conflict with VAE's in_channels
2930
self.in_channels = self.config.get("transformer_in_channels", self.config.get("in_channels", 64))
@@ -35,17 +36,25 @@ def __init__(self, config, model_path, device):
3536
self._init_infer()
3637

3738
def _init_infer_class(self):
    """Select the infer implementation classes from the offload config.

    Block-granularity CPU offload swaps in the offload-aware transformer
    infer class; pre/post infer classes are the same either way.
    """
    use_block_offload = self.cpu_offload and self.offload_granularity == "block"
    self.transformer_infer_class = LongCatImageOffloadTransformerInfer if use_block_offload else LongCatImageTransformerInfer
    self.pre_infer_class = LongCatImagePreInfer
    self.post_infer_class = LongCatImagePostInfer
4145

4246
def _init_infer(self):
4347
self.transformer_infer = self.transformer_infer_class(self.config)
4448
self.pre_infer = self.pre_infer_class(self.config)
4549
self.post_infer = self.post_infer_class(self.config)
46-
if hasattr(self.transformer_infer, "offload_manager"):
50+
if hasattr(self.transformer_infer, "offload_manager_double") and hasattr(self.transformer_infer, "offload_manager_single"):
4751
self._init_offload_manager()
4852

53+
def _init_offload_manager(self):
54+
"""Initialize offload managers for double and single block buffers."""
55+
self.transformer_infer.offload_manager_double.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_double_block_cuda_buffers)
56+
self.transformer_infer.offload_manager_single.init_cuda_buffer(blocks_cuda_buffer=self.transformer_weights.offload_single_block_cuda_buffers)
57+
4958
@torch.no_grad()
5059
def _infer_cond_uncond(self, latents_input, prompt_embeds, infer_condition=True):
5160
self.scheduler.infer_condition = infer_condition
@@ -77,7 +86,11 @@ def _seq_parallel_post_process(self, x):
7786
@torch.no_grad()
7887
def infer(self, inputs):
7988
if self.cpu_offload:
80-
self.to_cuda()
89+
if self.offload_granularity == "model":
90+
self.to_cuda()
91+
elif self.offload_granularity == "block":
92+
self.pre_weight.to_cuda()
93+
self.post_weight.to_cuda()
8194

8295
latents = self.scheduler.latents
8396

@@ -129,3 +142,10 @@ def infer(self, inputs):
129142
# ==================== No CFG Processing ====================
130143
noise_pred = self._infer_cond_uncond(latents, inputs["text_encoder_output"]["prompt_embeds"], infer_condition=True)
131144
self.scheduler.noise_pred = noise_pred
145+
146+
if self.cpu_offload:
147+
if self.offload_granularity == "model":
148+
self.to_cpu()
149+
elif self.offload_granularity == "block":
150+
self.pre_weight.to_cpu()
151+
self.post_weight.to_cpu()

0 commit comments

Comments
 (0)