From 4c75be063ba69bfeb80bbce716e8f3a36da34038 Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167@qq.com> Date: Thu, 23 Apr 2026 03:15:40 +0000 Subject: [PATCH 1/8] support kv quant/offload --- configs/lingbot_fast/lingbot_fast_i2v.json | 15 +- .../lingbot_fast_i2v_kv_quant_offload.json | 31 ++ lightx2v/common/kvcache/__init__.py | 23 + lightx2v/common/kvcache/base.py | 61 +++ lightx2v/common/kvcache/calibrate.py | 85 ++++ lightx2v/common/kvcache/kernel.py | 208 ++++++++ lightx2v/common/kvcache/manager.py | 168 +++++++ lightx2v/common/kvcache/offload.py | 399 ++++++++++++++++ lightx2v/common/kvcache/quant.py | 449 ++++++++++++++++++ lightx2v/common/kvcache/rolling.py | 58 +++ lightx2v/common/ops/attn/__init__.py | 2 +- lightx2v/common/ops/attn/sage_attn.py | 134 +++++- lightx2v/models/networks/base_model.py | 0 .../wan/infer/lingbot/transformer_infer.py | 4 +- .../wan/infer/lingbot_fast/pre_infer.py | 8 +- .../infer/lingbot_fast/transformer_infer.py | 392 +++++++-------- .../models/networks/wan/lingbot_fast_model.py | 19 +- .../models/runners/wan/wan_audio_runner.py | 3 +- .../runners/wan/wan_lingbot_fast_runner.py | 54 +-- .../schedulers/wan/lingbot_fast/scheduler.py | 17 +- .../video_encoders/hf/ltx2/audio_vae/ops.py | 3 +- scripts/lingbot/run_lingbot_fast_i2v.sh | 4 +- 22 files changed, 1854 insertions(+), 283 deletions(-) create mode 100755 configs/lingbot_fast/lingbot_fast_i2v_kv_quant_offload.json create mode 100644 lightx2v/common/kvcache/__init__.py create mode 100644 lightx2v/common/kvcache/base.py create mode 100644 lightx2v/common/kvcache/calibrate.py create mode 100644 lightx2v/common/kvcache/kernel.py create mode 100644 lightx2v/common/kvcache/manager.py create mode 100644 lightx2v/common/kvcache/offload.py create mode 100644 lightx2v/common/kvcache/quant.py create mode 100644 lightx2v/common/kvcache/rolling.py mode change 100755 => 100644 lightx2v/models/networks/base_model.py mode change 100644 => 100755 lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py mode change 100644 => 100755 lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py mode change 100755 => 100644 lightx2v/models/runners/wan/wan_lingbot_fast_runner.py mode change 100644 => 100755 lightx2v/models/schedulers/wan/lingbot_fast/scheduler.py diff --git a/configs/lingbot_fast/lingbot_fast_i2v.json b/configs/lingbot_fast/lingbot_fast_i2v.json index 728e43546..2e5e107a9 100755 --- a/configs/lingbot_fast/lingbot_fast_i2v.json +++ b/configs/lingbot_fast/lingbot_fast_i2v.json @@ -15,13 +15,12 @@ "vae_cpu_offload": true, "use_image_encoder": false, "dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/", - "sf_config": { - "local_attn_size": -1, - "num_frame_per_block": 3, - "timesteps_index": [0, 179, 358, 679] + "ar_config": { + "local_attn_size": 21, + "num_frame_per_chunk": 3, + "timesteps_index": [0, 179, 358, 679], + "sink_size": 3, + "kv_offload": false }, - "parallel": { - "seq_p_size": 4, - "seq_p_attn_type": "ulysses" - } + "rms_norm_type": "self_forcing" } diff --git a/configs/lingbot_fast/lingbot_fast_i2v_kv_quant_offload.json b/configs/lingbot_fast/lingbot_fast_i2v_kv_quant_offload.json new file mode 100755 index 000000000..b17380d14 --- /dev/null +++ b/configs/lingbot_fast/lingbot_fast_i2v_kv_quant_offload.json @@ -0,0 +1,31 @@ +{ + "infer_steps": 4, + "target_video_length": 161, + "text_len": 512, + "target_height": 720, + "target_width": 1280, + "self_attn_1_type": "sage_attn2_k_int8_v_fp8", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": 
"sage_attn2", + "sample_guide_scale": 1.0, + "sample_shift": 10.0, + "enable_cfg": false, + "cpu_offload": false, + "t5_cpu_offload": true, + "vae_cpu_offload": true, + "use_image_encoder": false, + "dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/", + "ar_config": { + "local_attn_size": 21, + "num_frame_per_chunk": 3, + "timesteps_index": [0, 179, 358, 679], + "sink_size": 3, + "kv_quant": { + "calibrate": false, + "smooth_k": true, + "calib_path": "calib_kv.pt" + }, + "kv_offload": true + }, + "rms_norm_type": "self_forcing" +} diff --git a/lightx2v/common/kvcache/__init__.py b/lightx2v/common/kvcache/__init__.py new file mode 100644 index 000000000..597b3b7a2 --- /dev/null +++ b/lightx2v/common/kvcache/__init__.py @@ -0,0 +1,23 @@ +""" +KV cache for autoregressive transformer inference. + +- ``base``: cross-attention pool +- ``rolling``: ``RollingKVCachePool`` (bf16 rolling-window cache) +- ``quant``: ``CalibRollingKVCachePool`` / ``QuantRollingKVCachePool`` +- ``offload``: ``OffloadRollingKVCachePool`` / ``OffloadQuantRollingKVCachePool`` +- ``manager``: ``KVCacheManager`` +""" + +from .manager import KVCacheManager +from .offload import OffloadQuantRollingKVCachePool, OffloadRollingKVCachePool +from .quant import CalibRollingKVCachePool, QuantRollingKVCachePool +from .rolling import RollingKVCachePool + +__all__ = [ + "KVCacheManager", + "RollingKVCachePool", + "CalibRollingKVCachePool", + "QuantRollingKVCachePool", + "OffloadRollingKVCachePool", + "OffloadQuantRollingKVCachePool", +] diff --git a/lightx2v/common/kvcache/base.py b/lightx2v/common/kvcache/base.py new file mode 100644 index 000000000..79287e8d6 --- /dev/null +++ b/lightx2v/common/kvcache/base.py @@ -0,0 +1,61 @@ +import torch + + +class BaseKVCachePool: + def __init__( + self, + num_layers: int, + cache_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + self._num_layers = num_layers + self._cache_size = cache_size + self._num_heads = num_heads + self._head_dim = head_dim + self._device = device + self._dtype = dtype + + def _init_kv_buffer(self): + self._k_buffer = torch.zeros( + (self._num_layers, self._cache_size, self._num_heads, self._head_dim), + dtype=self._dtype, + device=self._device, + ) + self._v_buffer = torch.zeros( + (self._num_layers, self._cache_size, self._num_heads, self._head_dim), + dtype=self._dtype, + device=self._device, + ) + + def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor: + return self._k_buffer[layer_id][attn_start:local_end] + + def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor: + return self._v_buffer[layer_id][attn_start:local_end] + + def store_kv(self, k: torch.Tensor, v: torch.Tensor, layer_id: int) -> None: + self._k_buffer[layer_id, : k.shape[0]] = k + self._v_buffer[layer_id, : v.shape[0]] = v + + def reset(self) -> None: + self._k_buffer.zero_() + self._v_buffer.zero_() + + @property + def device(self) -> torch.device: + return self._device + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def num_layers(self) -> int: + return self._num_layers + + @property + def cache_size(self) -> int: + return self._cache_size diff --git a/lightx2v/common/kvcache/calibrate.py b/lightx2v/common/kvcache/calibrate.py new file mode 100644 index 000000000..04c61e4e6 --- /dev/null +++ b/lightx2v/common/kvcache/calibrate.py @@ -0,0 +1,85 @@ +""" +KV-cache 
quantisation calibration. + +Step 1 — Calibration run +~~~~~~~~~~~~~~~~~~~~~~~~ +Use a config with ``"calibrate": true`` and ``self_attn_1_type`` set to +the **non-quant** attention (e.g. ``"sage_attn2"``). This creates a +``CalibRollingKVCachePool`` that stores bf16 KV normally while +collecting K-mean and V per-channel abs-max. + +Config example (calibration):: + + { + "self_attn_1_type": "sage_attn2", + "ar_config": { + ... + "sage_quant_kv": { + "calibrate": true, + "smooth_k": true + } + } + } + +After inference, call ``save_calibration`` to export the stats:: + + from lightx2v.common.kvcache.calibrate import save_calibration + runner.run_main() + save_calibration(runner.model.kv_cache_manager, "calib_kv.pt") + +Step 2 — Quantised inference +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Switch to the quant attention and provide the calibration file:: + + { + "self_attn_1_type": "sage_attn2_kvquant", + "ar_config": { + ... + "sage_quant_kv": { + "smooth_k": true, + "calib_path": "calib_kv.pt" + } + } + } +""" + +from __future__ import annotations + +import torch +from loguru import logger + +from .quant import CalibRollingKVCachePool + + +def save_calibration( + kv_cache_manager, + output_path: str, +) -> dict[str, torch.Tensor]: + """Export and save KV cache calibration from a completed run. + + Parameters + ---------- + kv_cache_manager : KVCacheManager + The manager whose ``self_attn_kv_cache`` is a + ``CalibRollingKVCachePool`` that has been used for at least one + full inference pass. + output_path : str + File path to save the calibration dict (``torch.save`` format). + + Returns + ------- + dict with keys ``'km'`` and ``'v_scale'``. + """ + pool = kv_cache_manager.self_attn_kv_cache + if not isinstance(pool, CalibRollingKVCachePool): + raise TypeError(f"Expected CalibRollingKVCachePool, got {type(pool).__name__}. Make sure the config has sage_quant_kv.calibrate=true and self_attn_1_type is NOT sage_attn2_kvquant.") + + calib = pool.export_calibration() + torch.save(calib, output_path) + logger.info( + "KV calibration saved to {} — km {}, v_scale {}", + output_path, + list(calib["km"].shape), + list(calib["v_scale"].shape), + ) + return calib diff --git a/lightx2v/common/kvcache/kernel.py b/lightx2v/common/kvcache/kernel.py new file mode 100644 index 000000000..b221ca0d2 --- /dev/null +++ b/lightx2v/common/kvcache/kernel.py @@ -0,0 +1,208 @@ +import torch +import triton +import triton.language as tl +from triton import next_power_of_2 + + +@triton.jit +def quant_key_per_thread_int8_static_scale_kernel( + Input, # [chunk_len, H, D] bf16/fp16 + Output, # [chunk_len, H, D] int8 + Scale, # [num_blk, H, 4] fp32 (preset) + L, # chunk_len + StartIdx, # absolute position in buffer where chunk starts + stride_iz, + stride_ih, + stride_in, + stride_oz, + stride_oh, + stride_on, + stride_sb, + stride_sh, # stride per-block, per-head; per-thread stride is 1 + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # Translate block-relative token offsets into chunk-local indices. + # When StartIdx % BLK != 0, the first chunk block begins at a + # negative chunk-local index — those positions are masked off. 
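+    # Illustrative numbers (hypothetical): with BLK=128 and StartIdx=200,
+    # StartIdx % BLK = 72, so the first block of this launch starts at
+    # chunk-local index -72; its first 72 token slots belong to earlier
+    # chunks and are masked, and only chunk-local positions 0..55 of that
+    # block fall inside this chunk.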
+ block_local_base = off_blk * BLK - (StartIdx % BLK) + offs_in_blk = tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n0 = block_local_base + offs_in_blk + offs_n1 = offs_n0 + 1 + offs_k = tl.arange(0, C) + + mask_n0 = (offs_n0 >= 0) & (offs_n0 < L) + mask_n1 = (offs_n1 >= 0) & (offs_n1 < L) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + + # Scale layout [num_blk, H, 4] — per-thread stride is 1. + scale = tl.load(Scale + off_blk * stride_sb + off_h * stride_sh + off_tld) + + x0 = tl.load(input_ptrs0, mask=mask_n0[:, None]).to(tl.float32) + x1 = tl.load(input_ptrs1, mask=mask_n1[:, None]).to(tl.float32) + + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + + # Saturate before int8 cast — preset scale doesn't bound |x/scale|. + x0_int8 = tl.minimum(tl.maximum(x0_int8, -127.0), 127.0).to(tl.int8) + x1_int8 = tl.minimum(tl.maximum(x1_int8, -127.0), 127.0).to(tl.int8) + + tl.store(output_ptrs0, x0_int8, mask=mask_n0[:, None]) + tl.store(output_ptrs1, x1_int8, mask=mask_n1[:, None]) + + +@triton.jit +def fp8_v_quantize_nhd_prescale_kernel( + X, + OUT, + S, # [H, D] fp32 (per-channel v_scale = amax / 448, shared across L) + n_tok: tl.int32, + n_heads: tl.int32, + D: tl.int32, + BLOCK_D: tl.constexpr, + FP8_MAX_VAL: tl.constexpr, + SCALE_EPS: tl.constexpr, +): + """Quantise V ``[L, H, D]`` contiguous to fp32 staging, ``y = x / S[h,d]``.""" + row = tl.program_id(0) + h = row % n_heads + t = row // n_heads + d_off = tl.arange(0, BLOCK_D) + m = d_off < D + base_v = t * n_heads * D + h * D + base_s = h * D + x = tl.load(X + base_v + d_off, mask=m, other=0.0).to(tl.float32) + s = tl.load(S + base_s + d_off, mask=m, other=0.0).to(tl.float32) + s = tl.maximum(s, SCALE_EPS) + y = x / s + y = tl.clamp(y, -FP8_MAX_VAL, FP8_MAX_VAL) + tl.store(OUT + base_v + d_off, y, mask=m) + + +# --------------------------------------------------------------------------- # +# K int8 rescale on rolling: new_int8 ≈ round( old * src_scale / dst_scale ) +# (per-token, per-head ratio; D channels share the same ratio for that t,h) +# --------------------------------------------------------------------------- # + + +@triton.jit +def k_int8_roll_rescale_nhd_kernel( + X, # int8 [T, H, D] + OUT, # int8 [T, H, D] + S_SRC, # f32 [T, H] row-major + S_DST, # f32 [T, H] + T: tl.int32, + H: tl.int32, + D: tl.int32, + BLOCK_D: tl.constexpr, + SCALE_EPS: tl.constexpr, +): + """Re-quant int8 so dequant with ``dst`` position's scale recovers + the value encoded with ``src`` position's scale (Sage k_block_scale). + Rounding matches the repo's int8 path: add 0.5*sign, clamp, to int8. 
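+
+    Illustrative arithmetic (hypothetical values): old = 100 stored with
+    s_src = 0.02 encodes 2.0; after the roll the same token sits in a block
+    whose scale is s_dst = 0.04, so new = round(100 * 0.02 / 0.04) = 50 and
+    dequantising 50 * 0.04 = 2.0 recovers the original value.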
+ """ + row = tl.program_id(0) + h = row % H + t = row // H + offs = tl.arange(0, BLOCK_D) + m = offs < D + base = t * H * D + h * D + s_src = tl.load(S_SRC + t * H + h).to(tl.float32) + s_dst = tl.load(S_DST + t * H + h).to(tl.float32) + s_dst = tl.maximum(s_dst, SCALE_EPS) + ratio = s_src / s_dst + x = tl.load(X + base + offs, mask=m, other=0.0).to(tl.float32) + y = x * ratio + y = y + 0.5 * tl.where(y >= 0, 1, -1) + y = tl.minimum(tl.maximum(y, -127.0), 127.0) + y = y.to(tl.int8) + tl.store(OUT + base + offs, y, mask=m) + + +def k_int8_roll_rescale_triton( + x: torch.Tensor, + out: torch.Tensor, + src_scale: torch.Tensor, + dst_scale: torch.Tensor, + *, + scale_eps: float = 1e-5, +) -> None: + """In-place: ``out[t,h,d] = sat(round_half_away( x * src_s/dst_s ))``. + + Shapes: ``x``, ``out`` are ``[T, H, D]`` int8, ``src_scale``/``dst_scale`` are + ``[T, H]`` fp32 (Sage K scale one value per token per head for the + current thread group). + """ + if x.shape != out.shape: + raise ValueError(f"x and out must match, got {x.shape} vs {out.shape}") + T, h_, d_ = x.shape + if src_scale.shape != (T, h_) or dst_scale.shape != (T, h_): + raise ValueError("src_scale and dst_scale must be [T, H]") + if not out.is_contiguous(): + raise ValueError("out must be contiguous to write in-place to the K buffer") + if not x.is_contiguous(): + x = x.contiguous() + ss = src_scale.to(device=x.device, dtype=torch.float32, copy=False).contiguous() + ds = dst_scale.to(device=x.device, dtype=torch.float32, copy=False).contiguous() + block_d = next_power_of_2(d_) + grid = (T * h_,) + t_i, h_i, d_i = int(T), int(h_), int(d_) + k_int8_roll_rescale_nhd_kernel[grid]( + x, + out, + ss, + ds, + t_i, + h_i, + d_i, + block_d, + SCALE_EPS=scale_eps, + num_warps=4, + ) + + +def quant_value_per_channel_fp8_static_scale_kernel( + v: torch.Tensor, + v_scale: torch.Tensor, + *, + fp8_max: float = 448.0, + scale_eps: float = 1e-5, +) -> torch.Tensor: + """Sage-matched per-channel V quant: ``v`` ``[L,H,D]``, ``v_scale`` ``[H,D]`` (``amax/448``).""" + v = v.contiguous() + vs = v_scale.to(device=v.device, dtype=torch.float32, copy=False) + n_tok, n_h, d = v.shape + if vs.shape != (n_h, d): + raise ValueError(f"v_scale {tuple(vs.shape)} must be [H,D]={(n_h, d)} for v {tuple(v.shape)}") + vs = vs.contiguous() + out = torch.empty_like(v, dtype=torch.float32, device=v.device) + block_d = next_power_of_2(d) + grid = (n_tok * n_h,) + n_t = int(n_tok) + n_h_ = int(n_h) + d_ = int(d) + fp8_v_quantize_nhd_prescale_kernel[grid]( + v, + out, + vs, + n_t, + n_h_, + d_, + block_d, + FP8_MAX_VAL=fp8_max, + SCALE_EPS=scale_eps, + num_warps=8, + ) + return out.to(torch.float8_e4m3fn) diff --git a/lightx2v/common/kvcache/manager.py b/lightx2v/common/kvcache/manager.py new file mode 100644 index 000000000..ead62cef6 --- /dev/null +++ b/lightx2v/common/kvcache/manager.py @@ -0,0 +1,168 @@ +import torch +import torch.distributed as dist +from loguru import logger + +from lightx2v.utils.envs import GET_DTYPE + +from .base import BaseKVCachePool +from .quant import CalibRollingKVCachePool, QuantRollingKVCachePool +from .rolling import RollingKVCachePool + + +def _self_attn_pool_from_config(config, ar_config, kv_size, dtype, device): + kv_offload = ar_config.get("kv_offload", False) + sq = ar_config.get("kv_quant") + + if not sq: + if kv_offload: + from .offload import OffloadRollingKVCachePool + + return OffloadRollingKVCachePool( + num_layers=config["num_layers"], + cache_size=kv_size, + num_heads=config["num_heads"], + head_dim=config["dim"] // 
config["num_heads"], + dtype=dtype, + device=device, + ) + return RollingKVCachePool( + num_layers=config["num_layers"], + cache_size=kv_size, + num_heads=config["num_heads"], + head_dim=config["dim"] // config["num_heads"], + dtype=dtype, + device=device, + ) + else: + common = dict( + num_layers=config["num_layers"], + cache_size=kv_size, + num_heads=config["num_heads"], + head_dim=config["dim"] // config["num_heads"], + dtype=dtype, + device=device, + smooth_k=sq.get("smooth_k", True), + ) + + calibrate = sq.get("calibrate", False) + calib_path = sq.get("calib_path", None) + if not calibrate: + if kv_offload: + from .offload import OffloadQuantRollingKVCachePool + + return OffloadQuantRollingKVCachePool(**common, calib_path=calib_path) + return QuantRollingKVCachePool(**common, calib_path=calib_path) + else: + num_steps = config.get("infer_steps", 1) + return CalibRollingKVCachePool(**common, num_steps=num_steps) + + +class KVCacheManager: + def __init__( + self, + config={}, + device=torch.device("cuda"), + sp_group=None, + ): + self.config = config + self.ar_config = self.config.get("ar_config", {}) + self.dtype = GET_DTYPE() + self.device = device + self.sp_group = sp_group + + @property + def current_step(self) -> int: + return getattr(self.self_attn_kv_cache, "current_step", 0) + + @current_step.setter + def current_step(self, value: int) -> None: + pool = self.self_attn_kv_cache + if hasattr(pool, "current_step"): + pool.current_step = value + + def _create_self_attn_kv_cache(self): + return _self_attn_pool_from_config( + self.config, + self.ar_config, + self.kv_size, + self.dtype, + self.device, + ) + + def _create_cross_attn_kv_cache(self): + return BaseKVCachePool( + num_layers=self.config["num_layers"], + cache_size=self.config["text_len"], + num_heads=self.config["num_heads"], + head_dim=self.config["dim"] // self.config["num_heads"], + dtype=self.dtype, + device=self.device, + ) + + def _compute_frame_seq_length(self, latent_shape): + lat_f = latent_shape[1] + lat_h = latent_shape[2] + lat_w = latent_shape[3] + patch_size = self.config.get("patch_size", (1, 2, 2)) + frame_seq_length = (lat_h // patch_size[1]) * (lat_w // patch_size[2]) + num_output_frames = lat_f - (lat_f % self.ar_config.get("num_frame_per_chunk", 3)) + return frame_seq_length, num_output_frames + + def _create_kv_caches(self, latent_shape): + """Create (or recreate) cache pools with resolution-dependent sizes.""" + + self.frame_seq_length, self.num_output_frames = self._compute_frame_seq_length(latent_shape) + ws = dist.get_world_size(self.sp_group) if self.sp_group is not None else 1 + self.kv_size = self.frame_seq_length * self.num_output_frames + self.local_attn_size = self.ar_config.get("local_attn_size", -1) + self.sink_size = self.ar_config.get("sink_size", 0) + self.max_attention_size = self.ar_config.get("max_attention_size", None) + + if self.local_attn_size != -1: + self.kv_size = self.local_attn_size * self.frame_seq_length // ws + else: + self.kv_size = self.kv_size // ws + + if self.max_attention_size is not None: + self.max_attention_size = self.max_attention_size // ws + else: + self.max_attention_size = self.kv_size + + self.self_attn_kv_cache = self._create_self_attn_kv_cache() + self.cross_attn_kv_cache = self._create_cross_attn_kv_cache() + self.self_attn_kv_cache._init_kv_buffer() + self.cross_attn_kv_cache._init_kv_buffer() + + logger.info( + "[KVCacheManager] init: frame_seq_length={}, num_output_frames={}, kv_cache_size={}, max_attention_size={}, ws={}, local_attn_size={}, 
sink_size={}, kv_quant={}, kv_offload={}", + self.frame_seq_length, + self.num_output_frames, + self.kv_size, + self.max_attention_size, + ws, + self.local_attn_size, + self.sink_size, + bool(self.ar_config.get("kv_quant")), + bool(self.ar_config.get("kv_offload")), + ) + + def maybe_save_calibration(self) -> None: + """Auto-save calibration if running in calibrate mode with calib_path.""" + sq = self.ar_config.get("kv_quant") + if not sq or not isinstance(sq, dict): + return + if not sq.get("calibrate", False): + return + output_path = sq.get("calib_path", "calib_kv.pt") + pool = self.self_attn_kv_cache + if not isinstance(pool, CalibRollingKVCachePool): + return + calib = pool.export_calibration() + torch.save(calib, output_path) + logger.info( + "[KVCacheManager] calibration saved to {} — km {}, v_scale {}, k_block_scale {}", + output_path, + list(calib["km"].shape), + list(calib["v_scale"].shape), + list(calib["k_block_scale"].shape), + ) diff --git a/lightx2v/common/kvcache/offload.py b/lightx2v/common/kvcache/offload.py new file mode 100644 index 000000000..87057eaa9 --- /dev/null +++ b/lightx2v/common/kvcache/offload.py @@ -0,0 +1,399 @@ +"""KV cache pools with CPU offloading. + +Only 2 layers' worth of KV data resides on GPU at any time (double-buffered). +Async CPU↔GPU transfers via dedicated CUDA streams overlap with compute: + + Per-layer timeline (steady state): + ┌──────────────────────────────────────────────────────────────────────┐ + │ load_stream: [pf i+1.....] [pf i+2.................] │ + │ compute: [self-attn i] [cross+ffn i] [self-attn i+1] [cross+] │ + │ store_stream: [wb i.....] [wb i+1.....]│ + └──────────────────────────────────────────────────────────────────────┘ + + - prefetch (i+2) is queued at end_layer(i) so it starts during cross+ffn i + (PCIe Gen4 has dual DMA engines — H2D & D2H run concurrently) + - writeback (i) is also queued at end_layer(i) + +Usage: + - ``prefetch_initial([0, 1])`` once before the loop (preloads first 2 layers) + - ``begin_layer(layer_id)`` before self-attention (just waits via GPU event) + - ``end_layer(layer_id, next_prefetch=i+2)`` after self-attention + (writeback + queue prefetch — both overlap with cross-attn + FFN) + - ``sync_all()`` after the last layer +""" + +import torch +from loguru import logger + +from .kernel import quant_value_per_channel_fp8_static_scale_kernel +from .quant import _FP8_MAX, QuantRollingKVCachePool +from .rolling import RollingKVCachePool + +# ====================================================================== # +# Mixin: double-buffered CPU↔GPU transfer logic +# ====================================================================== # + + +class _KVCacheOffloadMixin: + """Double-buffered async CPU↔GPU KV cache transfer via CUDA streams. + + Pure GPU-side event handoff — no CPU-blocking sync inside the loop: + + load_stream: [load 0] [load 1] ─────── [load 2] ─────── ... + ↓ ↓ + compute: [self-attn 0] [self-attn 1] ... + ↓ ↓ + store_stream: [wb 0] ─────── [wb 1] ─────────── ... 
+ + Subclasses must implement: + _offload_async_load(layer_id, buf) – CPU→GPU full copy + _offload_async_store(layer_id, buf, start, end) – GPU→CPU partial copy + """ + + def _init_offload(self): + self._load_stream = torch.cuda.Stream() + self._store_stream = torch.cuda.Stream() + # Per-buffer events for fine-grained dependency tracking + self._load_done = [torch.cuda.Event() for _ in range(2)] + self._writeback_done = [torch.cuda.Event() for _ in range(2)] + # Seed events as "completed" so first wait_event is a no-op + cur = torch.cuda.current_stream() + for e in self._load_done + self._writeback_done: + e.record(cur) + + self._cur_buf = 0 + self._gpu_layer = [-1, -1] + # Per-buffer dirty token range (start, end), inclusive-exclusive. + # None = clean (no need to writeback). Updated by store_kv / roll_window. + self._dirty: list[tuple[int, int] | None] = [None, None] + + def _mark_dirty(self, buf: int, start: int, end: int) -> None: + if self._dirty[buf] is None: + self._dirty[buf] = (start, end) + else: + s, e = self._dirty[buf] + self._dirty[buf] = (min(s, start), max(e, end)) + + # ------------------------------------------------------------------ # + + def _issue_prefetch(self, layer_id: int, buf: int) -> None: + """Queue an async H2D load of *layer_id* into *buf* on load_stream. + + Waits for any in-flight writeback of the same buf before overwriting. + """ + self._load_stream.wait_event(self._writeback_done[buf]) + with torch.cuda.stream(self._load_stream): + self._offload_async_load(layer_id, buf) + self._load_done[buf].record(self._load_stream) + self._gpu_layer[buf] = layer_id + self._dirty[buf] = None # fresh from CPU → in sync + + def prefetch_initial(self, layer_ids: list[int]) -> None: + """Pre-fill GPU buffers before the loop starts. + + Pass at most 2 layer ids — typically ``[0, 1]``. Subsequent prefetches + are issued automatically by ``end_layer(next_prefetch=...)``. + """ + assert len(layer_ids) <= 2 + self._cur_buf = 0 + for buf, lid in enumerate(layer_ids): + self._issue_prefetch(lid, buf) + + def begin_layer(self, layer_id: int): + """Wait (GPU-side) until *layer_id*'s KV is loaded into the active buffer. + + Falls back to issuing a load if *layer_id* wasn't prefetched (e.g. on + the very first call without ``prefetch_initial``). No CPU block. + """ + buf = self._cur_buf + + if self._gpu_layer[buf] != layer_id: + self._issue_prefetch(layer_id, buf) + + torch.cuda.current_stream().wait_event(self._load_done[buf]) + + def end_layer(self, layer_id: int, next_prefetch: int | None = None): + """Queue (a) writeback of the active buffer and (b) the *next* prefetch + — both run on dedicated streams in parallel with subsequent compute. + + Writeback only transfers the dirty token range (modified by store_kv / + roll_window). The next prefetch is queued *here* (not in the next + iteration's begin_layer) so it can start the moment writeback finishes + — letting it overlap with cross-attn + FFN of the current layer. 
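+
+        Sketch of the intended per-layer call pattern (illustrative only;
+        the driving transformer loop lives outside this module and the
+        ``num_layers`` / ``pool`` names here are hypothetical)::
+
+            pool.prefetch_initial([0, 1])
+            for i in range(num_layers):
+                pool.begin_layer(i)
+                # ... self-attention: pool.k_cache / pool.v_cache / pool.store_kv ...
+                nxt = i + 2 if i + 2 < num_layers else None
+                pool.end_layer(i, next_prefetch=nxt)
+            pool.sync_all()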
+ """ + buf = self._cur_buf + dirty = self._dirty[buf] + + if dirty is None: + # Nothing to write back — mark writeback as done immediately so a + # subsequent prefetch into this buf doesn't wait unnecessarily + self._writeback_done[buf].record(torch.cuda.current_stream()) + else: + start, end = dirty + done = torch.cuda.Event() + done.record() # captures compute stream's progress + self._store_stream.wait_event(done) + with torch.cuda.stream(self._store_stream): + self._offload_async_store(layer_id, buf, start, end) + self._writeback_done[buf].record(self._store_stream) + self._dirty[buf] = None + + # Queue the next prefetch into the buf we're about to free. + # PCIe Gen4 has independent H2D / D2H DMA engines, so the upcoming + # prefetch (H2D) can start in parallel with the writeback (D2H) of the + # OTHER buffer — and crucially before the next iteration's CPU-side + # kernel launches block load_stream from being scheduled. + if next_prefetch is not None: + self._issue_prefetch(next_prefetch, buf) + + self._cur_buf = 1 - self._cur_buf + + def sync_all(self): + """Block until all outstanding transfers complete (call after the loop).""" + self._store_stream.synchronize() + self._load_stream.synchronize() + + +# ====================================================================== # +# bf16 rolling KV cache with CPU offload +# ====================================================================== # + + +class OffloadRollingKVCachePool(_KVCacheOffloadMixin, RollingKVCachePool): + """RollingKVCachePool with CPU offload — only 2 layers on GPU.""" + + def __init__(self, num_layers, cache_size, num_heads, head_dim, dtype, device): + super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device) + + def _init_kv_buffer(self): + L, N, H, D = self._num_layers, self._cache_size, self._num_heads, self._head_dim + + # CPU pinned memory — ground-truth storage for all layers + self._k_cpu = torch.zeros(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() + self._v_cpu = torch.zeros(L, N, H, D, dtype=self._dtype, device="cpu").pin_memory() + + # GPU — fixed contiguous double buffers (2 layers only) + self._k_gpu_buf = torch.zeros(2, N, H, D, dtype=self._dtype, device=self._device) + self._v_gpu_buf = torch.zeros(2, N, H, D, dtype=self._dtype, device=self._device) + + self._global_end = torch.zeros(L, dtype=torch.long, device=self._device) + self._local_end = torch.zeros(L, dtype=torch.long, device=self._device) + + self._init_offload() + + gpu_mb = (self._k_gpu_buf.nbytes + self._v_gpu_buf.nbytes) / (1024 * 1024) + cpu_mb = (self._k_cpu.nbytes + self._v_cpu.nbytes) / (1024 * 1024) + logger.info( + "[OffloadRollingKVCachePool] GPU fixed buffer: {:.1f} MB, CPU pinned: {:.1f} MB (saved {:.1f} MB GPU)", + gpu_mb, + cpu_mb, + cpu_mb - gpu_mb, + ) + + # ------------------------------------------------------------------ # + # offload copy helpers + # ------------------------------------------------------------------ # + + def _offload_async_load(self, layer_id, buf): + self._k_gpu_buf[buf].copy_(self._k_cpu[layer_id], non_blocking=True) + self._v_gpu_buf[buf].copy_(self._v_cpu[layer_id], non_blocking=True) + + def _offload_async_store(self, layer_id, buf, start, end): + self._k_cpu[layer_id, start:end].copy_( + self._k_gpu_buf[buf, start:end], + non_blocking=True, + ) + self._v_cpu[layer_id, start:end].copy_( + self._v_gpu_buf[buf, start:end], + non_blocking=True, + ) + + # ------------------------------------------------------------------ # + # KV access (redirected to GPU double buffers) + 
# ------------------------------------------------------------------ # + + def store_kv(self, k, v, start_idx, end_idx, layer_id): + buf = self._cur_buf + self._k_gpu_buf[buf, start_idx:end_idx] = k + self._v_gpu_buf[buf, start_idx:end_idx] = v + self._mark_dirty(buf, start_idx, end_idx) + + def k_cache(self, layer_id, attn_start, local_end): + return self._k_gpu_buf[self._cur_buf, attn_start:local_end] + + def v_cache(self, layer_id, attn_start, local_end): + return self._v_gpu_buf[self._cur_buf, attn_start:local_end] + + def roll_window(self, layer_id, sink_tokens, num_evicted): + buf = self._cur_buf + num_kept = int(self._local_end[layer_id].item()) - num_evicted - sink_tokens + src_s = sink_tokens + num_evicted + dst_s = sink_tokens + + kb = self._k_gpu_buf[buf] + vb = self._v_gpu_buf[buf] + kb[dst_s : dst_s + num_kept].copy_(kb[src_s : src_s + num_kept].clone()) + vb[dst_s : dst_s + num_kept].copy_(vb[src_s : src_s + num_kept].clone()) + + # roll shifts data within the GPU buffer — CPU is now stale at [dst, dst+num_kept] + self._mark_dirty(buf, dst_s, dst_s + num_kept) + + def reset(self): + self._k_cpu.zero_() + self._v_cpu.zero_() + self._k_gpu_buf.zero_() + self._v_gpu_buf.zero_() + self._global_end.zero_() + self._local_end.zero_() + self._gpu_layer = [-1, -1] + self._dirty = [None, None] + self._cur_buf = 0 + + +# ====================================================================== # +# Quantized (K int8 + V fp8) rolling KV cache with CPU offload +# ====================================================================== # +class OffloadQuantRollingKVCachePool(_KVCacheOffloadMixin, QuantRollingKVCachePool): + """QuantRollingKVCachePool with CPU offload — only 2 layers on GPU. + + K (int8) and V (fp8) bulk data live on CPU in pinned memory; GPU + keeps a 2-buffer rolling window. Calibration data (km, v_scale, + k_block_scale) is small and stays on GPU permanently — both store_kv + and ``k_cache`` / ``v_cache`` (with window args) look up calibrated + scales by ``(step, layer)`` and + apply them on the active GPU buffer. + + Inherits ``_quant_key``, K/V pack layout, and ``_roll_window_on_k_v`` from + :class:`QuantRollingKVCachePool` while routing bulk data through + :meth:`_load_calib` and the double ``_[kv]_gpu_buf`` workspace. + + Because ``k_block_scale`` is loaded once into ``_calib_k_block_scale``, + we no longer need to ship per-layer K-scale buffers between CPU and + GPU (saves both memory and DMA bandwidth). 
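+
+    Footprint, derived from the buffer shapes allocated below: the GPU
+    double buffer is fixed at roughly ``4 * N * H * D`` bytes (two layers
+    of int8 K plus fp8 V), while the pinned CPU copy grows as
+    ``2 * num_layers * N * H * D`` bytes, where ``N`` is ``cache_size``.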
+ """ + + # ------------------------------------------------------------------ # + # buffer init + # ------------------------------------------------------------------ # + + def _init_kv_buffer(self) -> None: + L, N, H, D = self._num_layers, self._cache_size, self._num_heads, self._head_dim + self._load_calib() + + # CPU pinned memory — ground truth for K (int8) and V (fp8 viewed as uint8) + self._k_cpu = torch.zeros(L, N, H, D, dtype=torch.int8, device="cpu").pin_memory() + self._v_cpu = torch.zeros(L, N, H, D, dtype=torch.uint8, device="cpu").pin_memory() + + # GPU — fixed contiguous double buffers (2 layers only) + self._k_gpu_buf = torch.zeros(2, N, H, D, dtype=torch.int8, device=self._device) + self._v_gpu_buf = torch.zeros(2, N, H, D, dtype=torch.float8_e4m3fn, device=self._device) + + self._global_end = torch.zeros(L, dtype=torch.long, device=self._device) + self._local_end = torch.zeros(L, dtype=torch.long, device=self._device) + + self._init_offload() + + gpu_mb = (self._k_gpu_buf.nbytes + self._v_gpu_buf.nbytes) / (1024 * 1024) + cpu_mb = (self._k_cpu.nbytes + self._v_cpu.nbytes) / (1024 * 1024) + logger.info( + "[OffloadQuantRollingKVCachePool] GPU fixed buffer: {:.1f} MB, CPU pinned: {:.1f} MB (saved {:.1f} MB GPU)", + gpu_mb, + cpu_mb, + cpu_mb - gpu_mb, + ) + + def _offload_async_load(self, layer_id, buf): + self._k_gpu_buf[buf].copy_(self._k_cpu[layer_id], non_blocking=True) + self._v_gpu_buf[buf].view(torch.uint8).copy_(self._v_cpu[layer_id], non_blocking=True) + + def _offload_async_store(self, layer_id, buf, start, end): + self._k_cpu[layer_id, start:end].copy_( + self._k_gpu_buf[buf, start:end], + non_blocking=True, + ) + v_gpu_slice_u8 = self._v_gpu_buf[buf, start:end].view(torch.uint8) + self._v_cpu[layer_id, start:end].copy_(v_gpu_slice_u8, non_blocking=True) + + def store_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + start_idx: int, + end_idx: int, + layer_id: int, + ) -> None: + buf = self._cur_buf + km = self._lookup_km(layer_id) + if km is not None: + km_lowp = km.to(k.dtype).squeeze(0) + k_smoothed = k - km_lowp + else: + k_smoothed = k + + blk_start = start_idx // self._BLKK + last_blk = (end_idx - 1) // self._BLKK + num_blk = last_blk - blk_start + 1 + preset_scale = self._lookup_k_block_scale(layer_id, blk_start, num_blk) + k_int8 = self._quant_key(k_smoothed, preset_scale, start_idx, self._BLKK) + self._k_gpu_buf[buf, start_idx:end_idx] = k_int8 + + v_scale = self._lookup_v_scale(layer_id) + v_fp8 = quant_value_per_channel_fp8_static_scale_kernel(v, v_scale, fp8_max=_FP8_MAX) + self._v_gpu_buf[buf, start_idx:end_idx] = v_fp8 + + self._mark_dirty(buf, start_idx, end_idx) + + # ------------------------------------------------------------------ # + # read (sage_attn2_kvquant) — same tuple layout as QuantRollingKVCachePool + # ------------------------------------------------------------------ # + + def k_cache(self, layer_id: int, attn_start: int, local_end: int): + BLK = self._BLKK + buf = self._cur_buf + aligned_start = (attn_start // BLK) * BLK + k_int8 = self._k_gpu_buf[buf, aligned_start:local_end].unsqueeze(0).contiguous() + blk_s = aligned_start // BLK + blk_e = (local_end + BLK - 1) // BLK + k_scale = self._calib_k_block_scale[self.current_step, layer_id, blk_s:blk_e].permute(1, 0, 2).reshape(1, self._num_heads, -1).contiguous() + return k_int8, k_scale + + def v_cache(self, layer_id: int, attn_start: int, local_end: int): + BLK = self._BLKK + buf = self._cur_buf + aligned_start = (attn_start // BLK) * BLK + v_fp8 = self._v_gpu_buf[buf, 
aligned_start:local_end] + v_fp8 = self._transpose_permute_v(v_fp8) + v_scale = self._lookup_v_scale(layer_id).unsqueeze(0).contiguous() + return v_fp8, v_scale + + # ------------------------------------------------------------------ # + # roll + # ------------------------------------------------------------------ # + + def roll_window(self, layer_id: int, sink_tokens: int, num_evicted: int) -> None: + buf = self._cur_buf + self._roll_window_on_k_v( + self._k_gpu_buf[buf], + self._v_gpu_buf[buf], + layer_id, + sink_tokens, + num_evicted, + ) + num_kept = int(self._local_end[layer_id].item()) - num_evicted - sink_tokens + dst_s = sink_tokens + self._mark_dirty(buf, dst_s, dst_s + num_kept) + + # ------------------------------------------------------------------ # + # misc + # ------------------------------------------------------------------ # + + def reset(self): + self._k_cpu.zero_() + self._v_cpu.zero_() + self._k_gpu_buf.zero_() + self._v_gpu_buf.zero_() + self._global_end.zero_() + self._local_end.zero_() + self._gpu_layer = [-1, -1] + self._dirty = [None, None] + self._cur_buf = 0 diff --git a/lightx2v/common/kvcache/quant.py b/lightx2v/common/kvcache/quant.py new file mode 100644 index 000000000..2b6353f35 --- /dev/null +++ b/lightx2v/common/kvcache/quant.py @@ -0,0 +1,449 @@ +import torch + +try: + from sageattention.triton.quant_per_thread import quant_key_per_thread_int8_kernel +except ImportError: + quant_key_per_thread_int8_kernel = None + +from .kernel import ( + k_int8_roll_rescale_triton, + quant_key_per_thread_int8_static_scale_kernel, + quant_value_per_channel_fp8_static_scale_kernel, +) +from .rolling import RollingKVCachePool + +_FP8_MAX = 448.0 + + +class CalibRollingKVCachePool(RollingKVCachePool): + """Normal bf16 rolling cache that additionally captures the (km, + v_channel_max, k_block_scale) that sage_attn computes internally — + keyed by ``(step, layer)`` and **shared across all chunks**. + + Capture semantics + ----------------- + Each chunk's call to ``capture_attn`` runs sage's K-quant kernel on + the full attention window currently in the buffer and overwrites the + entries at ``[step, layer]`` — but only as long as the window keeps + growing. Once rolling kicks in (window stops growing), captures are + skipped: the rolled-state buffer no longer matches what early-chunk + inference will see at those positions, so freezing the pre-roll + snapshot gives consistent calibration. + + The scales are stored at *buffer-absolute* block positions so that + the quant cache can index them directly when storing later chunks. + + After inference, ``export_calibration()`` returns: + ``km`` shape [S, L, 1, H, D] fp32 + ``v_scale`` shape [S, L, H, D] fp32 + ``k_block_scale`` shape [S, L, max_blks, H, scales_per_blk] fp32 + + Set ``current_step`` before each denoising step so captures land in + the right slot. + + Implementation notes + -------------------- + - The K slice fed to the calibration kernel starts at ``aligned_start + = (attn_start // 128) * 128`` so the per-block scales line up with + the buffer's natural 128-token blocks. The same alignment is used + by ``QuantRollingKVCachePool.k_cache`` / ``v_cache`` with + ``attn_start`` and ``local_end``. + - km is captured in bf16 (matching sage's ``k.mean(...)`` dtype), + then cast to fp32 for storage. This avoids extra mantissa bits + that would otherwise diverge from sage at the bf16 ``k - km`` + subtraction step. 
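+    - ``num_steps`` should match the number of denoising steps; the KV
+      cache manager passes ``config["infer_steps"]``, so e.g. the shipped
+      ``infer_steps: 4`` config yields calibration tensors with S = 4,
+      indexed by ``current_step`` = 0..3.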
+ """ + + _BLKK = 128 + _SCALES_PER_BLK = 4 # WARPK=128 ⇒ 4 thread groups per block per head + + def __init__( + self, + num_layers: int, + cache_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + *, + smooth_k: bool = True, + num_steps: int = 1, + ) -> None: + self._smooth_k = smooth_k + self._num_steps = num_steps + self.current_step: int = 0 + super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device) + + def _init_kv_buffer(self) -> None: + super()._init_kv_buffer() + S = self._num_steps + L, H, D = self._num_layers, self._num_heads, self._head_dim + BLK = self._BLKK + max_blks = (self._cache_size + BLK - 1) // BLK + self._km = torch.zeros(S, L, 1, H, D, dtype=torch.float32, device=self._device) + self._v_channel_max = torch.zeros(S, L, H, D, dtype=torch.float32, device=self._device) + self._k_block_scale_calib = torch.zeros( + S, + L, + max_blks, + H, + self._SCALES_PER_BLK, + dtype=torch.float32, + device=self._device, + ) + self._capture_flag = torch.zeros(S, L, dtype=torch.bool, device=self._device) + self._captured_window_size = torch.zeros(S, L, dtype=torch.long, device="cpu") + + def _quant_key(self, k: torch.Tensor, km: torch.Tensor | None = None, BLKK: int = 128, WARPK: int = 128): + """Run sage's per_thread int8 K-quantisation kernel on ``k``. + + Returns ``(k_int8, k_scale)`` where ``k`` is ``[B, kv_len, H, D]`` (NHD). + The km subtraction (if any) is done in ``k.dtype`` to match sage's + behaviour exactly — sage does ``k - km`` in bf16, NOT fp32. + + This is the source-of-truth quantisation used both at calibration time + (to capture the per-block scale we'll later replay) and as a reference + for the preset-scale quantisation path. + """ + if km is not None: + km_lowp = km.to(k.dtype) if km.dtype != k.dtype else km + k = k - km_lowp + + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + b, kv_len, h_kv, head_dim = k.shape + + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = ( + k_int8.stride(0), + k_int8.stride(2), + k_int8.stride(1), + ) + + num_blk = (kv_len + BLKK - 1) // BLKK + scales_per_blk = (BLKK // WARPK) * 4 + k_scale = torch.empty( + (b, h_kv, num_blk * scales_per_blk), + device=k.device, + dtype=torch.float32, + ) + + grid = (num_blk * scales_per_blk, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, + k_int8, + k_scale, + kv_len, + stride_bz_k, + stride_h_k, + stride_seq_k, + stride_bz_ko, + stride_h_ko, + stride_seq_ko, + k_scale.stride(0), + k_scale.stride(1), + C=head_dim, + BLK=WARPK, + ) + return k_int8, k_scale + + def capture_attn( + self, + layer_id: int, + attn_start: int, + local_end: int, + ) -> None: + """Capture (km, v_channel_max, k_block_scale) from the buffer's + current state — exactly what sage_attn would see at this call. + + Parameters + ---------- + attn_start : start position of the attention window in the buffer + (may not be 128-aligned). + local_end : end position (exclusive) — the buffer's current valid + length for this layer. + + The captured K slice is aligned down to the nearest 128 boundary + so per-block scales map cleanly to buffer block indices. 
+ """ + BLK = self._BLKK + aligned_start = (attn_start // BLK) * BLK + step, layer = self.current_step, layer_id + + k_full = self._k_buffer[layer_id, aligned_start:local_end] # [kv_len_a, H, D] bf16 + v_full = self._v_buffer[layer_id, aligned_start:local_end] # [kv_len_a, H, D] bf16 + kv_len_a = k_full.size(0) + if kv_len_a == 0: + return + + prev_window = int(self._captured_window_size[step, layer].item()) + if 0 < prev_window >= kv_len_a: + return + self._captured_window_size[step, layer] = kv_len_a + + # ---- km (bf16 mean to match sage) ---- + if self._smooth_k: + km_lowp = k_full.mean(dim=0, keepdim=True) # bf16 [1, H, D] + self._km[step, layer] = km_lowp.to(torch.float32) + else: + km_lowp = None + + # ---- k_block_scale via sage's quant kernel on (k - km) ---- + k_batch = k_full.unsqueeze(0).contiguous() # [1, kv_len_a, H, D] + _, k_scale_raw = self._quant_key(k_batch, km_lowp) # [1, H, num_blk*4] + num_blk_local = (kv_len_a + BLK - 1) // BLK + k_scale_local = k_scale_raw[0].reshape(self._num_heads, num_blk_local, self._SCALES_PER_BLK).permute(1, 0, 2) # [num_blk_local, H, 4] + blk_offset = aligned_start // BLK + self._k_block_scale_calib[step, layer, blk_offset : blk_offset + num_blk_local] = k_scale_local + self._v_channel_max[step, layer] = v_full.float().abs().amax(dim=0) # [H, D] + self._capture_flag[step, layer] = True + + def export_calibration(self) -> dict[str, torch.Tensor]: + v_scale = self._v_channel_max.clamp(min=1e-5) / _FP8_MAX + return { + "km": self._km.clone(), + "v_scale": v_scale, + "k_block_scale": self._k_block_scale_calib.clone(), + } + + def reset(self) -> None: + super().reset() + self._km.zero_() + self._v_channel_max.zero_() + self._k_block_scale_calib.zero_() + self._capture_flag.zero_() + self._captured_window_size.zero_() + + +class QuantRollingKVCachePool(RollingKVCachePool): + _BLKK = 128 + _SCALES_PER_BLK = 4 # (BLKK // WARPK) * 4, WARPK=128 + _PERM_16 = torch.tensor([0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15], dtype=torch.long, device="cuda") + + def __init__( + self, + num_layers: int, + cache_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + *, + smooth_k: bool = True, + calib_path: str, + ) -> None: + self._smooth_k_sage = smooth_k + self._calib_path = calib_path + self.current_step: int = 0 + super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device) + + def _quant_key( + self, + k_smoothed: torch.Tensor, + preset_scale: torch.Tensor, + start_idx: int, + BLKK: int = 128, + ) -> torch.Tensor: + chunk_len, H, D = k_smoothed.shape + num_blk = preset_scale.size(0) + + k_int8 = torch.empty_like(k_smoothed, dtype=torch.int8) + preset_scale_c = preset_scale.contiguous() + grid = (num_blk * 4, H, 1) + quant_key_per_thread_int8_static_scale_kernel[grid]( + k_smoothed, + k_int8, + preset_scale_c, + chunk_len, + start_idx, + 0, + k_smoothed.stride(1), + k_smoothed.stride(0), + 0, + k_int8.stride(1), + k_int8.stride(0), + preset_scale_c.stride(0), + preset_scale_c.stride(1), + C=D, + BLK=BLKK, + ) + return k_int8 + + def _lookup_km(self, layer_id: int) -> torch.Tensor | None: + """Return km of shape [1, 1, H, D] for the current (step, layer), + or None if K smoothing is disabled. 
+ + Supported calibration file shapes (newest → legacy): + [S, L, 1, H, D] – per (step, layer) ← preferred + [ L, 1, H, D] – per (layer) ← legacy + """ + if not self._smooth_k_sage: + return None + km_cal = self._calib_km + if km_cal.dim() == 5: + return km_cal[self.current_step, layer_id].unsqueeze(0) + return km_cal[layer_id].unsqueeze(0) + + def _lookup_v_scale(self, layer_id: int) -> torch.Tensor: + """Return v_scale of shape [H, D] for the current (step, layer). + + Supported calibration file shapes (newest → legacy): + [S, L, H, D] – per (step, layer) ← preferred + [ L, H, D] – per (layer) ← legacy + """ + vs_cal = self._calib_v_scale + if vs_cal.dim() == 4: + return vs_cal[self.current_step, layer_id] + return vs_cal[layer_id] + + def _lookup_k_block_scale( + self, + layer_id: int, + blk_start: int, + num_blk: int, + ) -> torch.Tensor: + """Return ``[num_blk, H, scales_per_blk]`` slice of the calibrated + k-block scale at the given absolute buffer block range. + + Calibration file shape: ``[S, L, max_blks, H, scales_per_blk]``. + """ + return self._calib_k_block_scale[ + self.current_step, + layer_id, + blk_start : blk_start + num_blk, + ] + + def _load_calib(self) -> None: + calib = torch.load(self._calib_path, map_location=self._device, weights_only=True) + self._calib_km = calib["km"].to(device=self._device, dtype=torch.float32) + self._calib_v_scale = calib["v_scale"].to(device=self._device, dtype=torch.float32) + if "k_block_scale" not in calib: + raise RuntimeError(f"Calibration file {self._calib_path!r} is missing 'k_block_scale'. Re-run calibration with CalibRollingKVCachePool.") + self._calib_k_block_scale = calib["k_block_scale"].to( + device=self._device, + dtype=torch.float32, + ) + + def _init_kv_buffer(self) -> None: + L = self._num_layers + N = self._cache_size + H = self._num_heads + D = self._head_dim + self._load_calib() + self._k_buffer = torch.zeros(L, N, H, D, dtype=torch.int8, device=self._device) + self._v_buffer = torch.zeros(L, N, H, D, dtype=torch.float8_e4m3fn, device=self._device) + + self._global_end = torch.zeros(L, dtype=torch.long, device=self._device) + self._local_end = torch.zeros(L, dtype=torch.long, device=self._device) + + def store_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + start_idx: int, + end_idx: int, + layer_id: int, + ) -> None: + km = self._lookup_km(layer_id) + if km is not None: + km_lowp = km.to(k.dtype).squeeze(0) + k_smoothed = k - km_lowp + else: + k_smoothed = k + + blk_start = start_idx // self._BLKK + last_blk = (end_idx - 1) // self._BLKK + num_blk = last_blk - blk_start + 1 + + preset_scale = self._lookup_k_block_scale(layer_id, blk_start, num_blk) + k_int8 = self._quant_key(k_smoothed, preset_scale, start_idx, self._BLKK) + + self._k_buffer[layer_id, start_idx:end_idx] = k_int8 + + v_scale = self._lookup_v_scale(layer_id) + v_fp8 = quant_value_per_channel_fp8_static_scale_kernel(v, v_scale, fp8_max=_FP8_MAX) + self._v_buffer[layer_id, start_idx:end_idx] = v_fp8 + + def _gather_per_token_k_scale( + self, + layer_id: int, + start_pos: int, + num_tokens: int, + ) -> torch.Tensor: + positions = torch.arange( + start_pos, + start_pos + num_tokens, + device=self._device, + ) + blk_idx = positions // self._BLKK + thread = (positions % self._BLKK // 2) % 4 + return self._calib_k_block_scale[ + self.current_step, + layer_id, + blk_idx, + :, + thread, + ] + + def _transpose_permute_v(self, v: torch.Tensor) -> torch.Tensor: + kv_len, H, D = v.shape + padded_len = (kv_len + 127) // 128 * 128 + + if padded_len > kv_len: + v_t = 
v.new_zeros(D, H, padded_len) + v_t[:, :, :kv_len].copy_(v.permute(2, 1, 0)) + else: + v_t = v.permute(2, 1, 0).contiguous() + + v_t = v_t.view(D, H, -1, 16)[:, :, :, self._PERM_16].contiguous() + v_t = v_t.view(1, D, H, padded_len) + return v_t + + def _roll_window_on_k_v(self, kb: torch.Tensor, vb: torch.Tensor, layer_id: int, sink_tokens: int, num_evicted: int) -> None: + num_kept = int(self._local_end[layer_id].item()) - num_evicted - sink_tokens + src_start = sink_tokens + num_evicted + src_end = src_start + num_kept + dst_start = sink_tokens + dst_end = dst_start + num_kept + if num_kept > 0: + x = kb[src_start:src_end].contiguous() # [num_kept, H, D] + out = kb[dst_start:dst_end] + src_scale = self._gather_per_token_k_scale(layer_id, src_start, num_kept) + dst_scale = self._gather_per_token_k_scale(layer_id, dst_start, num_kept) + k_int8_roll_rescale_triton(x, out, src_scale, dst_scale, scale_eps=1e-5) + vb[dst_start:dst_end].copy_(vb[src_start:src_end].clone()) + + def roll_window(self, layer_id: int, sink_tokens: int, num_evicted: int) -> None: + self._roll_window_on_k_v( + self._k_buffer[layer_id], + self._v_buffer[layer_id], + layer_id, + sink_tokens, + num_evicted, + ) + + def k_cache( + self, + layer_id: int, + attn_start: int, + local_end: int, + ): + BLK = self._BLKK + aligned_start = (attn_start // BLK) * BLK + k_int8 = self._k_buffer[layer_id, aligned_start:local_end].unsqueeze(0).contiguous() + blk_s = aligned_start // BLK + blk_e = (local_end + BLK - 1) // BLK + k_scale = self._calib_k_block_scale[self.current_step, layer_id, blk_s:blk_e].permute(1, 0, 2).reshape(1, self._num_heads, -1).contiguous() + return k_int8, k_scale + + def v_cache( + self, + layer_id: int, + attn_start: int, + local_end: int, + ): + BLK = self._BLKK + aligned_start = (attn_start // BLK) * BLK + v_fp8 = self._v_buffer[layer_id, aligned_start:local_end] + v_fp8 = self._transpose_permute_v(v_fp8) + v_scale = self._lookup_v_scale(layer_id).unsqueeze(0).contiguous() + return v_fp8, v_scale diff --git a/lightx2v/common/kvcache/rolling.py b/lightx2v/common/kvcache/rolling.py new file mode 100644 index 000000000..30ab9dc35 --- /dev/null +++ b/lightx2v/common/kvcache/rolling.py @@ -0,0 +1,58 @@ +import torch + +from .base import BaseKVCachePool + + +class RollingKVCachePool(BaseKVCachePool): + def __init__( + self, + num_layers: int, + cache_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device) + + def _init_kv_buffer(self): + super()._init_kv_buffer() + self._global_end = torch.zeros(self._num_layers, dtype=torch.long, device=self._device) + self._local_end = torch.zeros(self._num_layers, dtype=torch.long, device=self._device) + + def store_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + start_idx: int, + end_idx: int, + layer_id: int, + ) -> None: + self._k_buffer[layer_id][start_idx:end_idx] = k + self._v_buffer[layer_id][start_idx:end_idx] = v + + def get_global_end(self, layer_id: int) -> int: + return int(self._global_end[layer_id].item()) + + def get_local_end(self, layer_id: int) -> int: + return int(self._local_end[layer_id].item()) + + def set_ends(self, layer_id: int, global_end: int, local_end: int) -> None: + self._global_end[layer_id] = global_end + self._local_end[layer_id] = local_end + + def roll_window( + self, + layer_id: int, + sink_tokens: int, + num_evicted: int, + ) -> None: + num_kept = int(self._local_end[layer_id].item()) - num_evicted - 
sink_tokens + src_start = sink_tokens + num_evicted + src_end = src_start + num_kept + dst_start = sink_tokens + dst_end = dst_start + num_kept + + kb, vb = self._k_buffer[layer_id], self._v_buffer[layer_id] + kb[dst_start:dst_end].copy_(kb[src_start:src_end].clone()) + vb[dst_start:dst_end].copy_(vb[src_start:src_end].clone()) diff --git a/lightx2v/common/ops/attn/__init__.py b/lightx2v/common/ops/attn/__init__.py index aa7d11a43..aeea39507 100755 --- a/lightx2v/common/ops/attn/__init__.py +++ b/lightx2v/common/ops/attn/__init__.py @@ -4,7 +4,7 @@ from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer from .radial_attn import RadialAttnWeight from .ring_attn import RingAttnWeight -from .sage_attn import SageAttn2Weight, SageAttn3Weight, SparseSageAttn2Weight, SparseSageAttn3Weight +from .sage_attn import SageAttn2KInt8VFP8Weight, SageAttn2Weight, SageAttn3Weight, SparseSageAttn2Weight, SparseSageAttn3Weight from .sla_attn import SlaAttnWeight from .sparge_attn import SpargeAttnWeight from .sparse_mask_generator import NbhdMaskGenerator, SlaMaskGenerator, SpargeMaskGenerator, SvgMaskGenerator diff --git a/lightx2v/common/ops/attn/sage_attn.py b/lightx2v/common/ops/attn/sage_attn.py index 0bda8a6be..3c58c768c 100755 --- a/lightx2v/common/ops/attn/sage_attn.py +++ b/lightx2v/common/ops/attn/sage_attn.py @@ -5,7 +5,12 @@ from .template import AttnWeightTemplate from .utils.sla_util import get_block_map, get_cuda_arch -from .utils.sparge_util import block_map_incremental_lut_triton, block_map_ordinal_lut_triton, get_block_map_meansim, sage2_block_sparse_attn +from .utils.sparge_util import ( + block_map_incremental_lut_triton, + block_map_ordinal_lut_triton, + get_block_map_meansim, + sage2_block_sparse_attn, +) try: from sageattn3_sparse import sage3_block_sparse_attn @@ -39,6 +44,14 @@ logger.info("sageattn3_sparse not found, please install sageattention sparse first") sparse_sageattn3 = None +try: + from sageattention._qattn_sm90 import qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf + from sageattention.triton.quant_per_thread import quant_query_per_thread_int8_kernel +except ImportError: + quant_query_per_thread_int8_kernel = None + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = None + logger.info("sageattention not found, please install sageattention first") + @ATTN_WEIGHT_REGISTER("sage_attn2") class SageAttn2Weight(AttnWeightTemplate): @@ -134,7 +147,15 @@ def apply( sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK) elif self.sparse_mode == "sparge_mode": smooth_k = k - k.mean(dim=-2, keepdim=True) - sparse_map = get_block_map_meansim(q, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK) + sparse_map = get_block_map_meansim( + q, + smooth_k, + cdfthreshd=None, + topk=self.topk, + return_lut=False, + BLKQ=self.BLKQ, + BLKK=self.BLKK, + ) else: logger.info(f"spas_sage_attn2 sparse_mode only support sla_mode and sparge_mode now.") @@ -176,7 +197,15 @@ def apply( sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK) elif self.sparse_mode == "sparge_mode": smooth_k = k - k.mean(dim=-2, keepdim=True) - sparse_map = get_block_map_meansim(q, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK) + sparse_map = get_block_map_meansim( + q, + smooth_k, + cdfthreshd=None, + topk=self.topk, + return_lut=False, + BLKQ=self.BLKQ, + BLKK=self.BLKK, + ) else: logger.info(f"spas_sage_attn3 sparse_mode only 
support sla_mode and sparge_mode now.") @@ -184,3 +213,102 @@ def apply( x = sage3_block_sparse_attn(q, k, v, lut, valid_block_num, per_block_mean=self.per_block_mean) x = x.transpose(1, 2).reshape(bs * max_seqlen_q, -1) return x + + +@ATTN_WEIGHT_REGISTER("sage_attn2_k_int8_v_fp8") +class SageAttn2KInt8VFP8Weight(AttnWeightTemplate): + def __init__(self): + self.config = {} + + def quant_query_per_thread_int8(self, q, BLKQ=128, WARPQ=32, sm_scale=None): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + b, qo_len, h_qo, head_dim = q.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = ( + q_int8.stride(0), + q_int8.stride(2), + q_int8.stride(1), + ) + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), + device=q.device, + dtype=torch.float32, + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, + q_int8, + q_scale, + qo_len, + stride_bz_q, + stride_h_q, + stride_seq_q, + stride_bz_qo, + stride_h_qo, + stride_seq_qo, + q_scale.stride(0), + q_scale.stride(1), + C=head_dim, + BLK=WARPQ, + ) + return q_int8, q_scale + + def apply( + self, + q, + k, + v, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + sm_scale=None, + **kwargs, + ): + k_int8, k_scale = k + v_fp8, v_scale = v + q, k_int8, v_fp8 = q.contiguous(), k_int8.contiguous(), v_fp8.contiguous() + assert capability == (9, 0) + assert q.dtype in [torch.float16, torch.bfloat16] + assert k_int8.dtype == torch.int8 + assert k_scale is not None + assert v_scale is not None + assert q.stride(-1) == 1 and k_int8.stride(-1) == 1 + + dtype = q.dtype + + if len(q.shape) == 3: + bs = 1 + q = q.unsqueeze(0) + if len(k_int8.shape) == 3: + k_int8 = k_int8.unsqueeze(0) + if len(v_fp8.shape) == 3: + v_fp8 = v_fp8.unsqueeze(0) + elif len(q.shape) == 4: + bs = q.shape[0] + + head_dim_og = q.size(-1) + if sm_scale is None: + sm_scale = float(head_dim_og**-0.5) + + q_int8, q_scale = self.quant_query_per_thread_int8(q, BLKQ=64, WARPQ=16) + o = torch.empty(q.size(), dtype=dtype, device=q.device) + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + 0, + 0, + 3, + sm_scale, + 0, + ) + o = o.view(bs * max_seqlen_q, -1) + return o diff --git a/lightx2v/models/networks/base_model.py b/lightx2v/models/networks/base_model.py old mode 100755 new mode 100644 diff --git a/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py b/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py old mode 100644 new mode 100755 index dfed29f1c..44f4b3986 --- a/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py @@ -1,12 +1,12 @@ import torch -from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer +from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer from lightx2v_platform.base.global_var import AI_DEVICE torch_device_module = getattr(torch, AI_DEVICE) -class WanLingbotTransformerInfer(WanTransformerInfer): +class WanLingbotTransformerInfer(WanOffloadTransformerInfer): def infer_block(self, block, x, pre_infer_out): if hasattr(block.compute_phases[0], "before_proj") and block.compute_phases[0].before_proj.weight is not None: x = 
block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x diff --git a/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py b/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py old mode 100644 new mode 100755 index 5110de169..013c16fcc --- a/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot_fast/pre_infer.py @@ -51,9 +51,9 @@ def _build_lingbot_conditional_dict(self, weights, inputs, x_tokens: torch.Tenso if c2ws_plucker_emb.dim() != 5: return {} - seg_start = self.scheduler.seg_index * self.scheduler.num_frame_per_block + seg_start = self.scheduler.seg_index * self.scheduler.num_frame_per_chunk seg_end = min( - (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_block, + (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_chunk, self.scheduler.num_output_frames, ) sliced = c2ws_plucker_emb[:, :, seg_start:seg_end, :, :] @@ -78,9 +78,9 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0): vae_encoder_out = image_encoder_output.get("vae_encoder_out", None) if vae_encoder_out is not None: - seg_start = self.scheduler.seg_index * self.scheduler.num_frame_per_block + seg_start = self.scheduler.seg_index * self.scheduler.num_frame_per_chunk seg_end = min( - (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_block, + (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_chunk, self.scheduler.num_output_frames, ) vae_chunk = vae_encoder_out[:, seg_start:seg_end] diff --git a/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py b/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py index 2d33679d2..ceb54fcd2 100755 --- a/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot_fast/transformer_infer.py @@ -2,111 +2,24 @@ import torch.distributed as dist import torch.nn.functional as F -from lightx2v.models.networks.wan.infer.self_forcing.transformer_infer import ( - WanSFTransformerInfer, - causal_rope_apply, -) -from lightx2v.utils.envs import GET_DTYPE +from lightx2v.common.kvcache.quant import CalibRollingKVCachePool +from lightx2v.models.networks.wan.infer.lingbot.transformer_infer import WanLingbotTransformerInfer +from lightx2v.models.networks.wan.infer.self_forcing.transformer_infer import causal_rope_apply from lightx2v_platform.base.global_var import AI_DEVICE torch_device_module = getattr(torch, AI_DEVICE) -class WanLingbotFastTransformerInfer(WanSFTransformerInfer): - """Fast (autoregressive) transformer infer with lingbot camera injection and KV cache. 
- - Fixes over base WanSFTransformerInfer: - - KV/cross-attn cache uses actual num_heads/head_dim (not hardcoded 12/128) - - causal_rope_apply works with sequence parallelism - - KV cache indexing uses actual token count (not hardcoded frame_seq_length) - """ - +class WanLingbotFastTransformerInfer(WanLingbotTransformerInfer): def __init__(self, config): super().__init__(config) - self._text_len = config.get("text_len", 512) - - def _sp_world_size(self): - if self.config.get("seq_parallel", False) and dist.is_initialized(): - return dist.get_world_size(self.seq_p_group) - return 1 - - def reinit_caches(self, frame_seq_length, num_output_frames, text_len=None): - self.frame_seq_length = frame_seq_length - self._kv_size = frame_seq_length * num_output_frames - if text_len is not None: - self._text_len = text_len - ws = self._sp_world_size() - cfg_max = self.config.get("sf_config", {}).get("max_attention_size", None) - if cfg_max is not None: - self.max_attention_size = cfg_max // ws - elif self.local_attn_size == -1: - self.max_attention_size = self._kv_size // ws + self.num_frame_per_chunk = config.get("ar_config", {}).get("num_frame_per_chunk", 3) + if config.get("ar_config", {}).get("kv_offload", False): + self.infer_func = self.infer_with_kvcache_offload else: - self.max_attention_size = self.local_attn_size * frame_seq_length // ws - - self._initialize_kv_cache(self.dtype, self.device) - self._initialize_crossattn_cache(self.dtype, self.device) - - def _initialize_kv_cache(self, dtype, device): - if not hasattr(self, "_kv_size"): - return - kv_cache1 = [] - ws = self._sp_world_size() - if self.local_attn_size != -1: - kv_cache_size = self.local_attn_size * self.frame_seq_length // ws - else: - kv_cache_size = self._kv_size // ws - self.kv_cache_size = kv_cache_size - - n, d = self.num_heads, self.head_dim - if self.kv_quant_config is not None: - k_bit = self.kv_quant_config["k_bit"] - v_bit = self.kv_quant_config["v_bit"] - self.k_cache_dtype = torch.float8_e4m3fn if k_bit == "e4m3" else torch.float8_e5m2 - self.v_cache_dtype = torch.float8_e4m3fn if v_bit == "e4m3" else torch.float8_e5m2 - else: - self.k_cache_dtype = None - self.v_cache_dtype = None - - for _ in range(self.config["num_layers"]): - if self.k_cache_dtype is not None: - entry = { - "k": torch.zeros((self.kv_cache_size, n, d), dtype=self.k_cache_dtype, device=self.device), - "v": torch.zeros((self.kv_cache_size, n, d), dtype=self.v_cache_dtype, device=self.device), - "k_scales": torch.zeros((self.kv_cache_size, n, 1), dtype=GET_DTYPE(), device=self.device), - "v_scales": torch.zeros((self.kv_cache_size, n, 1), dtype=GET_DTYPE(), device=self.device), - "global_end_index": torch.tensor([0], dtype=torch.long, device=self.device), - "local_end_index": torch.tensor([0], dtype=torch.long, device=self.device), - } - else: - entry = { - "k": torch.zeros((self.kv_cache_size, n, d)).to(dtype).to(device), - "v": torch.zeros((self.kv_cache_size, n, d)).to(dtype).to(device), - "global_end_index": torch.tensor([0], dtype=torch.long).to(device), - "local_end_index": torch.tensor([0], dtype=torch.long).to(device), - } - kv_cache1.append(entry) - - self.kv_cache1_default = kv_cache1 - - def _initialize_crossattn_cache(self, dtype, device): - if not hasattr(self, "_kv_size"): - return - crossattn_cache = [] - n, d = self.num_heads, self.head_dim - # Align with source: cross_kv_shape = [batch, max_sequence_length, num_heads, head_dim] - text_len = self._text_len - for _ in range(self.config["num_layers"]): - crossattn_cache.append( - { - 
"k": torch.zeros((text_len, n, d)).to(dtype).to(device), - "v": torch.zeros((text_len, n, d)).to(dtype).to(device), - } - ) - self.crossattn_cache_default = crossattn_cache + self.infer_func = self.infer_with_kvcache def _apply_rope_sp(self, q, k, grid_sizes, freqs, start_frame): - """Apply causal RoPE correctly when tokens are split across GPUs.""" f, h, w = grid_sizes[0].tolist() full_seq_len = f * h * w c = q.size(-1) // 2 @@ -127,14 +40,11 @@ def _apply_rope_sp(self, q, k, grid_sizes, freqs, start_frame): padding_size = (multiple - (full_seq_len % multiple)) % multiple if padding_size > 0: pos_freqs = F.pad(pos_freqs, (0, 0, 0, 0, 0, padding_size)) - pos_freqs = torch.chunk(pos_freqs, world_size, dim=0)[cur_rank] - - actual_len = q.size(0) - pos_freqs = pos_freqs[:actual_len] + pos_freqs = torch.chunk(pos_freqs, world_size, dim=0)[cur_rank][: q.size(0)] n = q.size(1) - q_c = torch.view_as_complex(q.to(torch.float64).reshape(actual_len, n, -1, 2)) - k_c = torch.view_as_complex(k.to(torch.float64).reshape(actual_len, n, -1, 2)) + q_c = torch.view_as_complex(q.to(torch.float64).reshape(q.size(0), n, -1, 2)) + k_c = torch.view_as_complex(k.to(torch.float64).reshape(k.size(0), n, -1, 2)) q = torch.view_as_real(q_c * pos_freqs).flatten(2).type_as(q) k = torch.view_as_real(k_c * pos_freqs).flatten(2).type_as(k) return q, k @@ -143,11 +53,9 @@ def _apply_rope_sp(self, q, k, grid_sizes, freqs, start_frame): def _a2a_seq_to_heads(x, world_size, shard_heads, group): """[local_seq, all_heads, dim] -> [full_seq, shard_heads, dim]""" local_seq, _, dim = x.shape - x = x.reshape(local_seq, world_size, shard_heads, dim) - x = x.permute(1, 0, 2, 3).contiguous() # [world_size, local_seq, shard_heads, dim] + x = x.reshape(local_seq, world_size, shard_heads, dim).permute(1, 0, 2, 3).contiguous() out = torch.empty_like(x) dist.all_to_all_single(out, x, group=group) - # out[i] = rank i's local tokens for this rank's head shard — keep contiguous return out.reshape(local_seq * world_size, shard_heads, dim) @staticmethod @@ -161,12 +69,6 @@ def _a2a_heads_to_seq(x, world_size, shard_heads, group): return out.permute(1, 0, 2, 3).reshape(local_seq, world_size * shard_heads, dim) def _sp_kvcache_attn(self, q, k_cache, v_cache, phase): - """Self-attention with KV cache under sequence parallelism. - - The standard Ulysses all-to-all assumes Q/K/V have the same "image" - length, but with KV cache K/V is longer than Q. We do separate - all-to-all for Q and K/V so the full history is properly assembled. 
- """ world_size = dist.get_world_size(self.seq_p_group) shard_heads = self.num_heads // world_size d = self.head_dim @@ -190,13 +92,44 @@ def _sp_kvcache_attn(self, q, k_cache, v_cache, phase): max_seqlen_kv=full_k.size(0), ) - # flash_attn returns 2D [seq, shard_heads*dim]; reshape to 3D for all-to-all attn_out = attn_out.view(full_q.size(0), shard_heads, d) attn_out = self._a2a_heads_to_seq(attn_out, world_size, shard_heads, self.seq_p_group) - # flatten back to 2D [local_seq, all_heads*dim] return attn_out.reshape(q.size(0), self.num_heads * d) - # ---- Override self-attention to fix RoPE and KV cache indexing ---- + def _calculate_q_k_len(self, q, k_lens): + q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) + cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) + cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) + return cu_seqlens_q, cu_seqlens_k + + def infer_with_kvcache(self, blocks, x, pre_infer_out): + mgr = self.kv_cache_manager + self.kv_cache_size = mgr.kv_size + self.max_attention_size = mgr.max_attention_size + self._kv_offload = False + for block_idx in range(len(blocks)): + self.block_idx = block_idx + x = self.infer_block_with_kvcache(blocks[block_idx], x, pre_infer_out) + return x + + def infer_with_kvcache_offload(self, blocks, x, pre_infer_out): + mgr = self.kv_cache_manager + self.kv_cache_size = mgr.kv_size + self.max_attention_size = mgr.max_attention_size + self._kv_offload = True + kv_cache = mgr.self_attn_kv_cache + num_blocks = len(blocks) + + kv_cache.prefetch_initial(list(range(min(2, num_blocks)))) + + for block_idx in range(num_blocks): + self.block_idx = block_idx + self._next_prefetch = block_idx + 2 if block_idx + 2 < num_blocks else None + kv_cache.begin_layer(block_idx) + x = self.infer_block_with_kvcache(blocks[block_idx], x, pre_infer_out) + + kv_cache.sync_all() + return x def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa): if hasattr(phase, "smooth_norm1_weight"): @@ -207,27 +140,19 @@ def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, sh norm1_bias = shift_msa.squeeze() norm1_out = phase.norm1.apply(x) - if self.sensitive_layer_dtype != self.infer_dtype: norm1_out = norm1_out.to(self.sensitive_layer_dtype) - norm1_out.mul_(norm1_weight[0:1, :]).add_(norm1_bias[0:1, :]) - if self.sensitive_layer_dtype != self.infer_dtype: norm1_out = norm1_out.to(self.infer_dtype) s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim - - q0 = phase.self_attn_q.apply(norm1_out) - k0 = phase.self_attn_k.apply(norm1_out) - - q = phase.self_attn_norm_q.apply(q0).view(s, n, d) - k = phase.self_attn_norm_k.apply(k0).view(s, n, d) + q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) + k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) v = phase.self_attn_v.apply(norm1_out).view(s, n, d) seg_index = int(self.scheduler.seg_index) - frame_seqlen = grid_sizes[0][1:].prod().item() - current_start_frame = seg_index * self.num_frame_per_block + current_start_frame = seg_index * self.num_frame_per_chunk if self.config.get("seq_parallel", False): q, k = self._apply_rope_sp(q, k, grid_sizes, freqs, current_start_frame) @@ -235,74 +160,50 @@ def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, sh q = causal_rope_apply(q.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0] k = 
causal_rope_apply(k.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0] - num_new_tokens = int(q.size(0)) - # Use num_new_tokens for KV cache positioning — it already adapts to SP - # (with SP each rank holds total_tokens/world_size per segment). - # Using frame_seqlen (full spatial) would leave gaps in per-rank caches. - current_start = seg_index * num_new_tokens - current_end = current_start + num_new_tokens - kv_cache = self.kv_cache1[self.block_idx] - local_per_frame = num_new_tokens // self.num_frame_per_block if self.num_frame_per_block > 0 else 0 - sink_tokens = self.sink_size * local_per_frame - - global_end = int(kv_cache["global_end_index"].item()) - local_end = int(kv_cache["local_end_index"].item()) - - if self.local_attn_size != -1 and (current_end > global_end) and (num_new_tokens + local_end > self.kv_cache_size): - num_evicted_tokens = num_new_tokens + local_end - self.kv_cache_size - num_rolled_tokens = local_end - num_evicted_tokens - sink_tokens - src_start = sink_tokens + num_evicted_tokens - src_end = src_start + num_rolled_tokens - dst_start = sink_tokens - dst_end = dst_start + num_rolled_tokens - kv_cache["k"][dst_start:dst_end] = kv_cache["k"][src_start:src_end].clone() - kv_cache["v"][dst_start:dst_end] = kv_cache["v"][src_start:src_end].clone() - if self.kv_quant_config is not None: - kv_cache["k_scales"][dst_start:dst_end] = kv_cache["k_scales"][src_start:src_end].clone() - kv_cache["v_scales"][dst_start:dst_end] = kv_cache["v_scales"][src_start:src_end].clone() - local_end_index = local_end + current_end - global_end - num_evicted_tokens - local_start_index = local_end_index - num_new_tokens - else: - local_end_index = local_end + current_end - global_end - local_start_index = local_end_index - num_new_tokens - - if self.kv_quant_config is not None: - s0, s1, s2 = k.shape - k_2d = k.view(s0 * s1, s2) - v_2d = v.view(s0 * s1, s2) - k_q, k_scales = self.quant_fp8_vllm(k_2d) - v_q, v_scales = self.quant_fp8_vllm(v_2d) - kv_cache["k"][local_start_index:local_end_index] = k_q.view(s0, s1, s2) - kv_cache["v"][local_start_index:local_end_index] = v_q.view(s0, s1, s2) - kv_cache["k_scales"][local_start_index:local_end_index] = k_scales.view(s0, s1, 1) - kv_cache["v_scales"][local_start_index:local_end_index] = v_scales.view(s0, s1, 1) - else: - kv_cache["k"][local_start_index:local_end_index] = k - kv_cache["v"][local_start_index:local_end_index] = v - - kv_cache["global_end_index"].fill_(current_end) - kv_cache["local_end_index"].fill_(local_end_index) - - attn_start = max(0, local_end_index - self.max_attention_size) - if self.kv_quant_config is not None: - k_fp8 = kv_cache["k"][attn_start:local_end_index] - v_fp8 = kv_cache["v"][attn_start:local_end_index] - k_sc = kv_cache["k_scales"][attn_start:local_end_index] - v_sc = kv_cache["v_scales"][attn_start:local_end_index] - attn_k = self.dequant_fp8_vllm(k_fp8, k_sc, self.dtype) - attn_v = self.dequant_fp8_vllm(v_fp8, v_sc, self.dtype) + kv_cache = self.kv_cache_manager.self_attn_kv_cache + + num_new = int(q.size(0)) + current_start = seg_index * num_new + current_end = current_start + num_new + global_end = kv_cache.get_global_end(self.block_idx) + local_end = kv_cache.get_local_end(self.block_idx) + local_per_frame = num_new // self.num_frame_per_chunk if self.num_frame_per_chunk > 0 else 0 + sink_tokens = self.kv_cache_manager.sink_size * local_per_frame + + if self.kv_cache_manager.local_attn_size != -1 and current_end > global_end and num_new + local_end > self.kv_cache_size: + 
num_evicted = num_new + local_end - self.kv_cache_size + kv_cache.roll_window(self.block_idx, sink_tokens, num_evicted) + local_end_idx = local_end + current_end - global_end - num_evicted else: - attn_k = kv_cache["k"][attn_start:local_end_index] - attn_v = kv_cache["v"][attn_start:local_end_index] + local_end_idx = local_end + current_end - global_end + local_start_idx = local_end_idx - num_new + + kv_cache.store_kv(k, v, local_start_idx, local_end_idx, self.block_idx) + kv_cache.set_ends(self.block_idx, current_end, local_end_idx) + attn_start = max(0, local_end_idx - self.max_attention_size) if self.clean_cuda_cache: del norm1_out, norm1_weight, norm1_bias torch_device_module.empty_cache() if self.config.get("seq_parallel", False): + attn_k = kv_cache.k_cache(self.block_idx, attn_start, local_end_idx) + attn_v = kv_cache.v_cache(self.block_idx, attn_start, local_end_idx) attn_out = self._sp_kvcache_attn(q, attn_k, attn_v, phase) else: - k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0)) + attn_k = kv_cache.k_cache(self.block_idx, attn_start, local_end_idx) + attn_v = kv_cache.v_cache(self.block_idx, attn_start, local_end_idx) + + if isinstance(kv_cache, CalibRollingKVCachePool) and not getattr( + self.scheduler, + "is_rerun", + False, + ): + kv_cache.capture_attn(self.block_idx, attn_start, local_end_idx) + if isinstance(attn_k, tuple): + k_lens = torch.empty_like(seq_lens).fill_(attn_k[0].size(0)) + else: + k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0)) cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens) attn_out = phase.self_attn_1.apply( q=q, @@ -311,13 +212,13 @@ def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, sh cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=q.size(0), - max_seqlen_kv=attn_k.size(0), + max_seqlen_kv=attn_k.size(0) if not isinstance(attn_k, tuple) else attn_k[0].size(0), ) y = phase.self_attn_o.apply(attn_out) if self.clean_cuda_cache: - del q, k, v, attn_out, attn_k, attn_v + del q, k, v, attn_out torch_device_module.empty_cache() return y @@ -338,6 +239,12 @@ def infer_block_with_kvcache(self, block, x, pre_infer_out): scale_msa, ) + if self._kv_offload: + self.kv_cache_manager.self_attn_kv_cache.end_layer( + self.block_idx, + next_prefetch=self._next_prefetch, + ) + x, attn_out = self.infer_cross_attn_with_kvcache( block.compute_phases[1], x, @@ -353,13 +260,12 @@ def infer_block_with_kvcache(self, block, x, pre_infer_out): if self.has_post_adapter: x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out) - # print(x, x.shape) - # exit() + return x def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, block=None, conditional_dict=None): num_frames = gate_msa.shape[0] - frame_seqlen = x.shape[0] // gate_msa.shape[0] + frame_seqlen = x.shape[0] // num_frames seg_index = self.scheduler.seg_index x.add_((y_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) * gate_msa).flatten(0, 1)) @@ -369,15 +275,12 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, bloc if cam.dim() == 3: cam = cam.squeeze(0) if cam.shape[0] < x.shape[0]: - cam = torch.nn.functional.pad(cam, (0, 0, 0, x.shape[0] - cam.shape[0])) + cam = F.pad(cam, (0, 0, 0, x.shape[0] - cam.shape[0])) elif cam.shape[0] > x.shape[0]: cam = cam[: x.shape[0]] cam = cam.to(dtype=x.dtype, device=x.device) - cam_hidden = block.cam_injector_layer2.apply(torch.nn.functional.silu(block.cam_injector_layer1.apply(cam))) - cam_hidden = cam_hidden + cam - cam_scale = 
block.cam_scale_layer.apply(cam_hidden) - cam_shift = block.cam_shift_layer.apply(cam_hidden) - x = (1.0 + cam_scale) * x + cam_shift + cam_hidden = block.cam_injector_layer2.apply(F.silu(block.cam_injector_layer1.apply(cam))) + cam + x = (1.0 + block.cam_scale_layer.apply(cam_hidden)) * x + block.cam_shift_layer.apply(cam_hidden) norm3_out = phase.norm3.apply(x) @@ -389,21 +292,23 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, bloc if self.sensitive_layer_dtype != self.infer_dtype: context = context.to(self.infer_dtype) - if self.task in ["i2v", "flf2v", "animate", "s2v", "rs2v"] and self.config.get("use_image_encoder", True): + if context_img is not None: context_img = context_img.to(self.infer_dtype) n, d = self.num_heads, self.head_dim - q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d) + cross_kv_cache = self.kv_cache_manager.cross_attn_kv_cache + if seg_index == 0: k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d) v = phase.cross_attn_v.apply(context).view(-1, n, d) - self.crossattn_cache[self.block_idx]["k"] = k - self.crossattn_cache[self.block_idx]["v"] = v + cross_kv_cache.store_kv(k, v, self.block_idx) + self._cross_kv_len = k.size(0) else: - k = self.crossattn_cache[self.block_idx]["k"] - v = self.crossattn_cache[self.block_idx]["v"] + L = self._cross_kv_len + k = cross_kv_cache.k_cache(self.block_idx)[:L] + v = cross_kv_cache.v_cache(self.block_idx)[:L] cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, @@ -419,27 +324,27 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, bloc max_seqlen_kv=k.size(0), ) - if self.task in ["i2v", "flf2v", "animate", "s2v", "rs2v"] and self.config.get("use_image_encoder", True) and context_img is not None: + if context_img is not None: k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d) v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d) - cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len( q, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), ) - img_attn_out = phase.cross_attn_2.apply( - q=q, - k=k_img, - v=v_img, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_k, - max_seqlen_q=q.size(0), - max_seqlen_kv=k_img.size(0), + attn_out.add_( + phase.cross_attn_2.apply( + q=q, + k=k_img, + v=v_img, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_k, + max_seqlen_q=q.size(0), + max_seqlen_kv=k_img.size(0), + ) ) - attn_out.add_(img_attn_out) if self.clean_cuda_cache: - del k_img, v_img, img_attn_out + del k_img, v_img torch_device_module.empty_cache() attn_out = phase.cross_attn_o.apply(attn_out) @@ -448,3 +353,66 @@ def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa, bloc del q, k, v, norm3_out, context, context_img torch_device_module.empty_cache() return x, attn_out + + def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa): + x.add_(attn_out) + + if self.clean_cuda_cache: + del attn_out + torch.cuda.empty_cache() + + num_frames = c_shift_msa.shape[0] + frame_seqlen = x.shape[0] // c_shift_msa.shape[0] + + norm2_weight = 1 + c_scale_msa + norm2_bias = c_shift_msa + + norm2_out = phase.norm2.apply(x) + norm2_out = norm2_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + norm2_out.mul_(norm2_weight).add_(norm2_bias) + norm2_out = norm2_out.flatten(0, 1) + + y = phase.ffn_0.apply(norm2_out) + if self.clean_cuda_cache: + del norm2_out, x, norm2_weight, norm2_bias + 
torch.cuda.empty_cache() + y = torch.nn.functional.gelu(y, approximate="tanh") + if self.clean_cuda_cache: + torch.cuda.empty_cache() + y = phase.ffn_2.apply(y) + + return y + + def post_process(self, x, y, c_gate_msa, pre_infer_out=None): + num_frames = c_gate_msa.shape[0] + frame_seqlen = x.shape[0] // c_gate_msa.shape[0] + y = y.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + x = x.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + x.add_(y * c_gate_msa) + x = x.flatten(0, 1) + + if self.clean_cuda_cache: + del y, c_gate_msa + torch.cuda.empty_cache() + return x + + def infer_non_blocks(self, weights, x, e): + num_frames = e.shape[0] + frame_seqlen = x.shape[0] // e.shape[0] + + x = weights.norm.apply(x) + x = x.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) + + t = self.scheduler.timestep_input + e = e.unflatten(dim=0, sizes=t.shape).unsqueeze(2) + modulation = weights.head_modulation.tensor + e = (modulation.unsqueeze(1) + e).chunk(2, dim=2) + + x.mul_(1 + e[1][0]).add_(e[0][0]) + x = x.flatten(0, 1) + x = weights.head.apply(x) + + if self.clean_cuda_cache: + del e + torch.cuda.empty_cache() + return x diff --git a/lightx2v/models/networks/wan/lingbot_fast_model.py b/lightx2v/models/networks/wan/lingbot_fast_model.py index f9797c813..4c513fd50 100644 --- a/lightx2v/models/networks/wan/lingbot_fast_model.py +++ b/lightx2v/models/networks/wan/lingbot_fast_model.py @@ -2,6 +2,9 @@ from lightx2v.models.networks.wan.infer.lingbot_fast.pre_infer import WanLingbotFastPreInfer from lightx2v.models.networks.wan.infer.lingbot_fast.transformer_infer import WanLingbotFastTransformerInfer +from lightx2v.models.networks.wan.infer.offload.transformer_infer import ( + WanOffloadTransformerInfer, +) from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer from lightx2v.models.networks.wan.lingbot_model import WanLingbotModel @@ -9,14 +12,20 @@ class WanLingbotFastModel(WanLingbotModel): """Lingbot fast (autoregressive) model. - Inherits WanLingbotModel for lingbot weights and seq-parallel camera handling. - Adds SF checkpoint loading and SF inference loop. 
+ MRO: WanLingbotFastModel -> AutoRegressiveBaseTransformerModel + -> WanLingbotModel -> WanModel -> BaseTransformerModel + + - AutoRegressiveBaseTransformerModel adds KVCacheManager lifecycle + - WanLingbotModel adds lingbot weights and seq-parallel camera handling """ + def __init__(self, model_path, config, device, model_type="wan2.1", lora_path=None, lora_strength=1.0): + super().__init__(model_path, config, device, model_type, lora_path, lora_strength) + def _init_infer_class(self): self.pre_infer_class = WanLingbotFastPreInfer self.post_infer_class = WanPostInfer - self.transformer_infer_class = WanLingbotFastTransformerInfer + self.transformer_infer_class = WanLingbotFastTransformerInfer if not self.cpu_offload else WanOffloadTransformerInfer @torch.no_grad() def infer(self, inputs): @@ -27,8 +36,8 @@ def infer(self, inputs): self.pre_weight.to_cuda() self.transformer_weights.non_block_weights_to_cuda() - current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block - current_end_frame = (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_block + current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_chunk + current_end_frame = (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_chunk noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) self.scheduler.noise_pred[:, current_start_frame:current_end_frame] = noise_pred diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py index 1bf294903..81d742e2d 100755 --- a/lightx2v/models/runners/wan/wan_audio_runner.py +++ b/lightx2v/models/runners/wan/wan_audio_runner.py @@ -9,7 +9,8 @@ import numpy as np import torch import torch.nn.functional as F -import torchaudio as ta + +# import torchaudio as ta import torchvision.transforms.functional as TF from PIL import Image, ImageCms, ImageOps from einops import rearrange diff --git a/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py b/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py old mode 100755 new mode 100644 index 5bdba466b..61c2fe650 --- a/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py +++ b/lightx2v/models/runners/wan/wan_lingbot_fast_runner.py @@ -1,6 +1,7 @@ import torch from loguru import logger +from lightx2v.common.kvcache import KVCacheManager from lightx2v.models.networks.wan.lingbot_fast_model import WanLingbotFastModel from lightx2v.models.runners.wan.wan_runner import LingbotRunner, WanRunner, build_wan_model_with_lora from lightx2v.models.schedulers.wan.lingbot_fast.scheduler import LingbotFastScheduler @@ -48,43 +49,19 @@ def load_transformer(self): model = build_wan_model_with_lora(WanLingbotFastModel, self.config, wan_model_kwargs, lora_configs, model_type="wan2.1") return model - # ---- SF scheduling ---- - def init_scheduler(self): self.scheduler = LingbotFastScheduler(self.config) - def set_target_shape(self): - num_frame_per_block = self.config["sf_config"].get("num_frame_per_block", 3) - latent_shape = self.input_info.latent_shape - lat_f = latent_shape[1] - lat_h, lat_w = latent_shape[2], latent_shape[3] - num_output_frames = lat_f - (lat_f % num_frame_per_block) - self.input_info.latent_shape = [latent_shape[0], num_output_frames, lat_h, lat_w] - self.config.target_shape = [self.config.get("num_channels_latents", 16), num_output_frames, lat_h, lat_w] - - self.scheduler.num_output_frames = num_output_frames - self.scheduler.num_blocks = num_output_frames // num_frame_per_block - - p = self.config.get("patch_size", [1, 
2, 2]) - frame_seq_length = (lat_h // p[1]) * (lat_w // p[2]) - logger.info( - "[lingbot_fast] lat_f={}, num_output_frames={}, frame_seq_length={} (lat_h={}, lat_w={}, patch={})", - lat_f, - num_output_frames, - frame_seq_length, - lat_h, - lat_w, - p, - ) - - if hasattr(self, "model") and hasattr(self.model, "transformer_infer"): - self.model.transformer_infer.reinit_caches( - frame_seq_length, - num_output_frames, - ) + def init_kv_cache_manager(self): + self.model.kv_cache_manager = KVCacheManager(config=self.config, device=torch.device("cuda"), sp_group=self.model.seq_p_group) + self.model.kv_cache_manager._create_kv_caches(self.input_info.latent_shape) + self.model.transformer_infer.kv_cache_manager = self.model.kv_cache_manager + self.input_info.latent_shape = [self.input_info.latent_shape[0], self.model.kv_cache_manager.num_output_frames, self.input_info.latent_shape[2], self.input_info.latent_shape[3]] + self.scheduler.num_output_frames = self.model.kv_cache_manager.num_output_frames + self.scheduler.num_chunks = self.model.kv_cache_manager.num_output_frames // self.config.get("ar_config", {}).get("num_frame_per_chunk", 3) def get_video_segment_num(self): - self.video_segment_num = self.scheduler.num_blocks + self.video_segment_num = self.scheduler.num_chunks def run_segment(self, segment_idx=0): infer_steps = self.model.scheduler.infer_steps @@ -92,6 +69,7 @@ def run_segment(self, segment_idx=0): if self.video_segment_num == 1: self.check_stop() logger.info(f"==> step_index: {step_index + 1} / {infer_steps}") + self.model.kv_cache_manager.current_step = step_index with ProfilingContext4DebugL1("step_pre"): self.model.scheduler.step_pre(seg_index=segment_idx, step_index=step_index, is_rerun=False) @@ -109,6 +87,14 @@ def run_segment(self, segment_idx=0): return self.model.scheduler.stream_output + def init_run(self): + self.init_kv_cache_manager() + super().init_run() + + def end_run(self): + self.model.kv_cache_manager.maybe_save_calibration() + super().end_run() + @ProfilingContext4DebugL2("Run DiT") def run_main(self, total_steps=None): """Collect all segment latents, then decode at once with normal VAE. 
@@ -117,7 +103,6 @@ def run_main(self, total_steps=None): pred_latent_chunks = torch.cat(pred_latent_chunks, dim=1) videos = self.vae.decode([pred_latent_chunks]) """ - self.set_target_shape() self.init_run() if self.config.get("compile", False): self.model.select_graph_for_compile(self.input_info) @@ -154,8 +139,6 @@ def run_main(self, total_steps=None): self.end_run() return gen_video_final - # ---- Live streaming mode (per-segment decode, kept for future use) ---- - def get_rank_and_world_size(self): rank = 0 world_size = 1 @@ -206,7 +189,6 @@ def run_main_live(self, total_steps=None): self.video_recorder.start(self.width, self.height) if world_size > 1 and dist is not None: dist.barrier() - self.set_target_shape() self.init_run() if self.config.get("compile", False): self.model.select_graph_for_compile(self.input_info) diff --git a/lightx2v/models/schedulers/wan/lingbot_fast/scheduler.py b/lightx2v/models/schedulers/wan/lingbot_fast/scheduler.py old mode 100644 new mode 100755 index dc340c3bd..0484dd36f --- a/lightx2v/models/schedulers/wan/lingbot_fast/scheduler.py +++ b/lightx2v/models/schedulers/wan/lingbot_fast/scheduler.py @@ -23,8 +23,8 @@ class LingbotFastScheduler(WanScheduler): def __init__(self, config): super().__init__(config) self.dtype = torch.bfloat16 - self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"] - self.timesteps_index = self.config["sf_config"]["timesteps_index"] + self.num_frame_per_chunk = self.config["ar_config"]["num_frame_per_chunk"] + self.timesteps_index = self.config["ar_config"]["timesteps_index"] self.infer_steps = len(self.timesteps_index) self.context_noise = 0 @@ -89,12 +89,13 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): def step_pre(self, seg_index, step_index, is_rerun=False): self.step_index = step_index self.seg_index = seg_index + self.is_rerun = is_rerun if not GET_DTYPE() == GET_SENSITIVE_DTYPE(): self.latents = self.latents.to(GET_DTYPE()) - seg_start = self.seg_index * self.num_frame_per_block - seg_end = min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames) + seg_start = self.seg_index * self.num_frame_per_chunk + seg_end = min((self.seg_index + 1) * self.num_frame_per_chunk, self.num_output_frames) self.latents_input = self.latents[:, seg_start:seg_end] if not is_rerun: @@ -104,9 +105,9 @@ def step_pre(self, seg_index, step_index, is_rerun=False): # Align with: context_timestep = [timesteps[-1] * 0.0] t_val = self.context_noise - # Shape [1, num_frame_per_block] required by infer_non_blocks + # Shape [1, num_frame_per_chunk] required by infer_non_blocks # (unflatten uses t.shape to reshape embed for head modulation broadcast) - self.timestep_input = torch.full([1, self.num_frame_per_block], t_val, device=AI_DEVICE, dtype=torch.long) + self.timestep_input = torch.full([1, self.num_frame_per_chunk], t_val, device=AI_DEVICE, dtype=torch.long) def step_post(self): """Align with source denoising loop: @@ -116,8 +117,8 @@ def step_post(self): next_timestep = timesteps[timestep_idx + 1] current_latent = scheduler.add_noise(x0, noise, next_timestep) """ - seg_start = self.seg_index * self.num_frame_per_block - seg_end = min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames) + seg_start = self.seg_index * self.num_frame_per_chunk + seg_end = min((self.seg_index + 1) * self.num_frame_per_chunk, self.num_output_frames) flow_pred = self.noise_pred[:, seg_start:seg_end] xt = self.latents_input diff --git a/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py 
b/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py index 68fadcde1..658f227dc 100755 --- a/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py +++ b/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, replace import torch -import torchaudio + +# import torchaudio from torch import nn diff --git a/scripts/lingbot/run_lingbot_fast_i2v.sh b/scripts/lingbot/run_lingbot_fast_i2v.sh index d2468122c..823979816 100755 --- a/scripts/lingbot/run_lingbot_fast_i2v.sh +++ b/scripts/lingbot/run_lingbot_fast_i2v.sh @@ -4,12 +4,12 @@ lightx2v_path=/data/nvme4/gushiqiao/LightX2V model_path=/data/nvme4/models/lingbot-world-base-cam -export CUDA_VISIBLE_DEVICES=1,2,3,4 +export CUDA_VISIBLE_DEVICES=2 # set environment variables source ${lightx2v_path}/scripts/base/base.sh -torchrun --nproc_per_node=4 -m lightx2v.infer \ +python -m lightx2v.infer \ --model_cls lingbot_world_fast \ --task i2v \ --model_path $model_path \ From 36f378a02900fc787662699c0e132fa12dabecc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:19:18 +0800 Subject: [PATCH 2/8] Update transformer_infer.py --- .../models/networks/wan/infer/lingbot/transformer_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py b/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py index 44f4b3986..dfed29f1c 100755 --- a/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/lingbot/transformer_infer.py @@ -1,12 +1,12 @@ import torch -from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer +from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer from lightx2v_platform.base.global_var import AI_DEVICE torch_device_module = getattr(torch, AI_DEVICE) -class WanLingbotTransformerInfer(WanOffloadTransformerInfer): +class WanLingbotTransformerInfer(WanTransformerInfer): def infer_block(self, block, x, pre_infer_out): if hasattr(block.compute_phases[0], "before_proj") and block.compute_phases[0].before_proj.weight is not None: x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x From 75ed4f89801b5985e09c7dff34c13ecaeb7a8042 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:20:35 +0800 Subject: [PATCH 3/8] Update lingbot_fast_model.py --- lightx2v/models/networks/wan/lingbot_fast_model.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/lightx2v/models/networks/wan/lingbot_fast_model.py b/lightx2v/models/networks/wan/lingbot_fast_model.py index 4c513fd50..a085bb014 100644 --- a/lightx2v/models/networks/wan/lingbot_fast_model.py +++ b/lightx2v/models/networks/wan/lingbot_fast_model.py @@ -2,30 +2,18 @@ from lightx2v.models.networks.wan.infer.lingbot_fast.pre_infer import WanLingbotFastPreInfer from lightx2v.models.networks.wan.infer.lingbot_fast.transformer_infer import WanLingbotFastTransformerInfer -from lightx2v.models.networks.wan.infer.offload.transformer_infer import ( - WanOffloadTransformerInfer, -) from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer from lightx2v.models.networks.wan.lingbot_model import WanLingbotModel class WanLingbotFastModel(WanLingbotModel): - """Lingbot fast 
(autoregressive) model. - - MRO: WanLingbotFastModel -> AutoRegressiveBaseTransformerModel - -> WanLingbotModel -> WanModel -> BaseTransformerModel - - - AutoRegressiveBaseTransformerModel adds KVCacheManager lifecycle - - WanLingbotModel adds lingbot weights and seq-parallel camera handling - """ - def __init__(self, model_path, config, device, model_type="wan2.1", lora_path=None, lora_strength=1.0): super().__init__(model_path, config, device, model_type, lora_path, lora_strength) def _init_infer_class(self): self.pre_infer_class = WanLingbotFastPreInfer self.post_infer_class = WanPostInfer - self.transformer_infer_class = WanLingbotFastTransformerInfer if not self.cpu_offload else WanOffloadTransformerInfer + self.transformer_infer_class = WanLingbotFastTransformerInfer @torch.no_grad() def infer(self, inputs): From dba299c98a145330f30f80e6a3f93b3011f0e447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:21:13 +0800 Subject: [PATCH 4/8] Update wan_audio_runner.py --- lightx2v/models/runners/wan/wan_audio_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py index 81d742e2d..1bf294903 100755 --- a/lightx2v/models/runners/wan/wan_audio_runner.py +++ b/lightx2v/models/runners/wan/wan_audio_runner.py @@ -9,8 +9,7 @@ import numpy as np import torch import torch.nn.functional as F - -# import torchaudio as ta +import torchaudio as ta import torchvision.transforms.functional as TF from PIL import Image, ImageCms, ImageOps from einops import rearrange From 5f076f3b87ced8dda9ff1dc45cfde3f479b6a601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:21:47 +0800 Subject: [PATCH 5/8] Uncomment torchaudio import in ops.py --- lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py b/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py index 658f227dc..68fadcde1 100755 --- a/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py +++ b/lightx2v/models/video_encoders/hf/ltx2/audio_vae/ops.py @@ -1,8 +1,7 @@ from dataclasses import dataclass, replace import torch - -# import torchaudio +import torchaudio from torch import nn From da4ffff33720fa1fb50410556b762fbb25945d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:17:18 +0800 Subject: [PATCH 6/8] Update lightx2v/common/kvcache/quant.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightx2v/common/kvcache/quant.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/lightx2v/common/kvcache/quant.py b/lightx2v/common/kvcache/quant.py index 2b6353f35..26482088b 100644 --- a/lightx2v/common/kvcache/quant.py +++ b/lightx2v/common/kvcache/quant.py @@ -217,7 +217,25 @@ def reset(self) -> None: class QuantRollingKVCachePool(RollingKVCachePool): _BLKK = 128 _SCALES_PER_BLK = 4 # (BLKK // WARPK) * 4, WARPK=128 - _PERM_16 = torch.tensor([0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15], dtype=torch.long, device="cuda") + _PERM_16_VAL = [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 
14, 15] + + def __init__( + self, + num_layers: int, + cache_size: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + *, + smooth_k: bool = True, + calib_path: str, + ) -> None: + self._smooth_k_sage = smooth_k + self._calib_path = calib_path + self.current_step: int = 0 + self._PERM_16 = torch.tensor(self._PERM_16_VAL, dtype=torch.long, device=device) + super().__init__(num_layers, cache_size, num_heads, head_dim, dtype, device) def __init__( self, From a31578fe42a750d26ee4f1243a0c3f6a8b2b4255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:17:32 +0800 Subject: [PATCH 7/8] Update lightx2v/common/ops/attn/sage_attn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightx2v/common/ops/attn/sage_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightx2v/common/ops/attn/sage_attn.py b/lightx2v/common/ops/attn/sage_attn.py index 3c58c768c..02377f327 100755 --- a/lightx2v/common/ops/attn/sage_attn.py +++ b/lightx2v/common/ops/attn/sage_attn.py @@ -271,7 +271,7 @@ def apply( k_int8, k_scale = k v_fp8, v_scale = v q, k_int8, v_fp8 = q.contiguous(), k_int8.contiguous(), v_fp8.contiguous() - assert capability == (9, 0) + assert torch.cuda.get_device_capability(q.device) == (9, 0) assert q.dtype in [torch.float16, torch.bfloat16] assert k_int8.dtype == torch.int8 assert k_scale is not None From 02d3fbcf57aa87bedfd09f218e940e4d1212830a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shiqiao=20Gu=20=28=E8=B0=B7=E7=9F=B3=E6=A1=A5=29?= <77222802+gushiqiao@users.noreply.github.com> Date: Thu, 23 Apr 2026 12:18:48 +0800 Subject: [PATCH 8/8] Update lightx2v/common/kvcache/offload.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightx2v/common/kvcache/offload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightx2v/common/kvcache/offload.py b/lightx2v/common/kvcache/offload.py index 87057eaa9..a3f748253 100644 --- a/lightx2v/common/kvcache/offload.py +++ b/lightx2v/common/kvcache/offload.py @@ -51,8 +51,8 @@ class _KVCacheOffloadMixin: """ def _init_offload(self): - self._load_stream = torch.cuda.Stream() - self._store_stream = torch.cuda.Stream() + self._load_stream = torch.cuda.Stream(device=self._device) + self._store_stream = torch.cuda.Stream(device=self._device) # Per-buffer events for fine-grained dependency tracking self._load_done = [torch.cuda.Event() for _ in range(2)] self._writeback_done = [torch.cuda.Event() for _ in range(2)]
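For context on how these two streams and the per-buffer events are meant to interact with the per-layer loop (prefetch_initial / begin_layer / end_layer / sync_all in WanLingbotFastTransformerInfer.infer_with_kvcache_offload), below is a minimal, self-contained sketch of the double-buffered KV offload pattern. It assumes a CUDA device and toy sizes; the class name DoubleBufferedKVOffload, the pinned-host layout, and a begin_layer that returns the staging buffers are illustrative simplifications, not the actual OffloadRollingKVCachePool API.

import torch


class DoubleBufferedKVOffload:
    """Toy double-buffered KV offload: the full cache lives in pinned host memory,
    two GPU staging buffers are rotated ping-pong style across layers."""

    def __init__(self, num_layers, cache_size, num_heads, head_dim, dtype=torch.bfloat16):
        dev = torch.device("cuda")
        shape = (cache_size, num_heads, head_dim)
        # One pinned (k, v) pair per layer on the host.
        self.cpu_kv = [
            (torch.zeros(shape, dtype=dtype, pin_memory=True),
             torch.zeros(shape, dtype=dtype, pin_memory=True))
            for _ in range(num_layers)
        ]
        # Two (k, v) staging pairs on the device, indexed by layer_id % 2.
        self.gpu_kv = [
            (torch.empty(shape, dtype=dtype, device=dev),
             torch.empty(shape, dtype=dtype, device=dev))
            for _ in range(2)
        ]
        self.load_stream = torch.cuda.Stream(device=dev)   # H2D prefetch
        self.store_stream = torch.cuda.Stream(device=dev)  # D2H write-back
        self.load_done = [torch.cuda.Event() for _ in range(2)]

    def prefetch(self, layer_id):
        buf = layer_id % 2
        with torch.cuda.stream(self.load_stream):
            for dst, src in zip(self.gpu_kv[buf], self.cpu_kv[layer_id]):
                dst.copy_(src, non_blocking=True)
            self.load_done[buf].record(self.load_stream)

    def begin_layer(self, layer_id):
        # Compute (default stream) must not touch the staging buffer before its H2D copy finished.
        torch.cuda.current_stream().wait_event(self.load_done[layer_id % 2])
        return self.gpu_kv[layer_id % 2]

    def end_layer(self, layer_id, next_prefetch=None):
        buf = layer_id % 2
        # Write-back waits for this layer's compute, then runs on the store stream.
        self.store_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.store_stream):
            for dst, src in zip(self.cpu_kv[layer_id], self.gpu_kv[buf]):
                dst.copy_(src, non_blocking=True)
        if next_prefetch is not None:
            # The same staging buffer is reused for layer_id + 2, so the next H2D
            # copy must wait until the D2H write-back has released it.
            self.load_stream.wait_stream(self.store_stream)
            self.prefetch(next_prefetch)

    def sync_all(self):
        torch.cuda.synchronize()


# Toy driver mirroring the loop in infer_with_kvcache_offload.
num_layers = 4
pool = DoubleBufferedKVOffload(num_layers, cache_size=1024, num_heads=8, head_dim=64)
for i in range(min(2, num_layers)):          # mirrors prefetch_initial([0, 1])
    pool.prefetch(i)
for layer_id in range(num_layers):
    k_buf, v_buf = pool.begin_layer(layer_id)
    k_buf.add_(1.0)                           # stand-in for the layer's KV update
    v_buf.add_(1.0)
    nxt = layer_id + 2 if layer_id + 2 < num_layers else None
    pool.end_layer(layer_id, next_prefetch=nxt)
pool.sync_all()

The point this sketch shares with the real pool is that the compute stream only waits on the event of the one staging buffer it is about to use, so the H2D prefetch of layer i+2 and the D2H write-back of layer i can overlap with attention for layer i+1.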