15 changes: 7 additions & 8 deletions configs/lingbot_fast/lingbot_fast_i2v.json
@@ -15,13 +15,12 @@
"vae_cpu_offload": true,
"use_image_encoder": false,
"dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/",
"sf_config": {
"local_attn_size": -1,
"num_frame_per_block": 3,
"timesteps_index": [0, 179, 358, 679]
"ar_config": {
"local_attn_size": 21,
"num_frame_per_chunk": 3,
"timesteps_index": [0, 179, 358, 679],
"sink_size": 3,
"kv_offload": false
},
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
"rms_norm_type": "self_forcing"
}
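The switch from ``sf_config`` to ``ar_config`` replaces full attention (``local_attn_size: -1``, presumably unlimited) with a 21-frame sliding window, a 3-frame attention sink, and 3-frames-per-chunk autoregressive generation. A toy sketch of the frame-visibility rule these numbers imply; the standard attention-sink scheme is an assumption here, and ``visible_frames`` is a hypothetical helper, not code from this PR:

def visible_frames(cur_frame: int, local_attn_size: int = 21, sink_size: int = 3) -> list[int]:
    """Which earlier frames a new frame attends to (assumed sink + rolling window)."""
    window_start = max(0, cur_frame - local_attn_size)
    sink = list(range(min(sink_size, window_start)))   # first frames, always kept
    local = list(range(window_start, cur_frame))       # rolling local window
    return sink + local

# e.g. frame 40 attends to frames 0-2 (sink) and 19-39 (window).
assert visible_frames(40) == [0, 1, 2] + list(range(19, 40))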
31 changes: 31 additions & 0 deletions configs/lingbot_fast/lingbot_fast_i2v_kv_quant_offload.json
@@ -0,0 +1,31 @@
{
"infer_steps": 4,
"target_video_length": 161,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2_k_int8_v_fp8",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"sample_guide_scale": 1.0,
"sample_shift": 10.0,
"enable_cfg": false,
"cpu_offload": false,
"t5_cpu_offload": true,
"vae_cpu_offload": true,
"use_image_encoder": false,
"dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/",
"ar_config": {
"local_attn_size": 21,
"num_frame_per_chunk": 3,
"timesteps_index": [0, 179, 358, 679],
"sink_size": 3,
"kv_quant": {
"calibrate": false,
"smooth_k": true,
"calib_path": "calib_kv.pt"
},
"kv_offload": true
},
"rms_norm_type": "self_forcing"
}
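Relative to the base config, this variant turns on KV-cache quantisation (``kv_quant``, paired with the ``sage_attn2_k_int8_v_fp8`` attention) plus CPU offload (``kv_offload``). A rough sketch of how these flags plausibly map onto the pool classes added in ``lightx2v/common/kvcache`` below; the real selection logic lives in ``KVCacheManager`` and is not shown in this diff, so treat ``select_pool_cls`` as a hypothetical illustration:

# Hypothetical dispatch from ar_config to a pool class; the actual logic in
# KVCacheManager is not part of this diff.
from lightx2v.common.kvcache import (
    CalibRollingKVCachePool,
    OffloadQuantRollingKVCachePool,
    OffloadRollingKVCachePool,
    QuantRollingKVCachePool,
    RollingKVCachePool,
)

def select_pool_cls(ar_config: dict):
    kv_quant = ar_config.get("kv_quant")
    offload = ar_config.get("kv_offload", False)
    if kv_quant and kv_quant.get("calibrate"):
        return CalibRollingKVCachePool          # bf16 run that gathers stats
    if kv_quant and offload:
        return OffloadQuantRollingKVCachePool   # quantised cache, CPU-resident
    if kv_quant:
        return QuantRollingKVCachePool
    if offload:
        return OffloadRollingKVCachePool
    return RollingKVCachePool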
23 changes: 23 additions & 0 deletions lightx2v/common/kvcache/__init__.py
@@ -0,0 +1,23 @@
"""
KV cache for autoregressive transformer inference.

- ``base``: cross-attention pool
- ``rolling``: ``RollingKVCachePool`` (bf16 rolling-window cache)
- ``quant``: ``CalibRollingKVCachePool`` / ``QuantRollingKVCachePool``
- ``offload``: ``OffloadRollingKVCachePool`` / ``OffloadQuantRollingKVCachePool``
- ``manager``: ``KVCacheManager``
"""

from .manager import KVCacheManager
from .offload import OffloadQuantRollingKVCachePool, OffloadRollingKVCachePool
from .quant import CalibRollingKVCachePool, QuantRollingKVCachePool
from .rolling import RollingKVCachePool

__all__ = [
"KVCacheManager",
"RollingKVCachePool",
"CalibRollingKVCachePool",
"QuantRollingKVCachePool",
"OffloadRollingKVCachePool",
"OffloadQuantRollingKVCachePool",
]
61 changes: 61 additions & 0 deletions lightx2v/common/kvcache/base.py
@@ -0,0 +1,61 @@
import torch


class BaseKVCachePool:
def __init__(
self,
num_layers: int,
cache_size: int,
num_heads: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
self._num_layers = num_layers
self._cache_size = cache_size
self._num_heads = num_heads
self._head_dim = head_dim
self._device = device
self._dtype = dtype

def _init_kv_buffer(self):
self._k_buffer = torch.zeros(
(self._num_layers, self._cache_size, self._num_heads, self._head_dim),
dtype=self._dtype,
device=self._device,
)
self._v_buffer = torch.zeros(
(self._num_layers, self._cache_size, self._num_heads, self._head_dim),
dtype=self._dtype,
device=self._device,
)

def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
return self._k_buffer[layer_id][attn_start:local_end]

def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
return self._v_buffer[layer_id][attn_start:local_end]

def store_kv(self, k: torch.Tensor, v: torch.Tensor, layer_id: int) -> None:
self._k_buffer[layer_id, : k.shape[0]] = k
self._v_buffer[layer_id, : v.shape[0]] = v

def reset(self) -> None:
self._k_buffer.zero_()
self._v_buffer.zero_()

@property
def device(self) -> torch.device:
return self._device

@property
def dtype(self) -> torch.dtype:
return self._dtype

@property
def num_layers(self) -> int:
return self._num_layers

@property
def cache_size(self) -> int:
return self._cache_size
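For orientation, a minimal, self-contained sketch of how this buffer behaves, using only the methods defined above; the throwaway ``_DemoPool`` subclass exists only because ``__init__`` does not allocate the buffers itself:

import torch

from lightx2v.common.kvcache.base import BaseKVCachePool

class _DemoPool(BaseKVCachePool):
    """Toy subclass so that _init_kv_buffer actually runs."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._init_kv_buffer()

pool = _DemoPool(num_layers=2, cache_size=16, num_heads=4, head_dim=8,
                 dtype=torch.bfloat16, device=torch.device("cpu"))
k = torch.randn(5, 4, 8, dtype=torch.bfloat16)   # 5 new tokens for layer 0
pool.store_kv(k, k.clone(), layer_id=0)          # written at positions 0..4
assert pool.k_cache(0, attn_start=0, local_end=5).shape == (5, 4, 8)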
85 changes: 85 additions & 0 deletions lightx2v/common/kvcache/calibrate.py
@@ -0,0 +1,85 @@
"""
KV-cache quantisation calibration.

Step 1 — Calibration run
~~~~~~~~~~~~~~~~~~~~~~~~
Use a config with ``"calibrate": true`` and ``self_attn_1_type`` set to
the **non-quant** attention (e.g. ``"sage_attn2"``). This creates a
``CalibRollingKVCachePool`` that stores the KV cache in bf16 as usual
while also collecting the K mean and the per-channel V abs-max.

Config example (calibration)::

{
"self_attn_1_type": "sage_attn2",
"ar_config": {
...
"sage_quant_kv": {
"calibrate": true,
"smooth_k": true
}
}
}

After inference, call ``save_calibration`` to export the stats::

from lightx2v.common.kvcache.calibrate import save_calibration
runner.run_main()
save_calibration(runner.model.kv_cache_manager, "calib_kv.pt")

Step 2 — Quantised inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Switch to the quant attention and provide the calibration file::

{
"self_attn_1_type": "sage_attn2_kvquant",
"ar_config": {
...
"sage_quant_kv": {
"smooth_k": true,
"calib_path": "calib_kv.pt"
}
}
}
"""

from __future__ import annotations

import torch
from loguru import logger

from .quant import CalibRollingKVCachePool


def save_calibration(
kv_cache_manager,
output_path: str,
) -> dict[str, torch.Tensor]:
"""Export and save KV cache calibration from a completed run.

Parameters
----------
kv_cache_manager : KVCacheManager
The manager whose ``self_attn_kv_cache`` is a
``CalibRollingKVCachePool`` that has been used for at least one
full inference pass.
output_path : str
File path to save the calibration dict (``torch.save`` format).

Returns
-------
dict with keys ``'km'`` and ``'v_scale'``.
"""
pool = kv_cache_manager.self_attn_kv_cache
    if not isinstance(pool, CalibRollingKVCachePool):
        raise TypeError(
            f"Expected CalibRollingKVCachePool, got {type(pool).__name__}. "
            "Make sure the config has sage_quant_kv.calibrate=true and "
            "self_attn_1_type is NOT sage_attn2_kvquant."
        )

calib = pool.export_calibration()
torch.save(calib, output_path)
logger.info(
"KV calibration saved to {} — km {}, v_scale {}",
output_path,
list(calib["km"].shape),
list(calib["v_scale"].shape),
)
return calib
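A quick way to sanity-check a calibration file produced this way is to load it back. How ``km`` and ``v_scale`` are actually consumed by the quant pool is not part of this diff, so the smoothing/quantisation below is only an assumption extrapolated from the docstring (mean-smoothed K, per-channel abs-max V):

import torch

calib = torch.load("calib_kv.pt")            # produced by save_calibration above
km, v_scale = calib["km"], calib["v_scale"]
print(km.shape, v_scale.shape)

# Assumed usage, mirroring the docstring's description; not the pool's code.
def smooth_k(k: torch.Tensor) -> torch.Tensor:
    return k - km                            # "smooth_k": subtract the K mean

def quantize_v(v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    scale = (v_scale / 127.0).clamp(min=1e-8)        # int8 per-channel scales
    q = torch.clamp(torch.round(v / scale), -127, 127).to(torch.int8)
    return q, scale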