This repository was archived by the owner on Feb 22, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvae.py
More file actions
275 lines (233 loc) · 9.79 KB
/
vae.py
File metadata and controls
275 lines (233 loc) · 9.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Literal, Optional, Tuple, Type, TypeVar, Union

import torch
from diffusers import AutoencoderKL, AutoencoderKLHunyuanVideo
from torch import nn
# Logger used by the VAE wrappers in this module.
# NOTE: getLogger() is called without a name, so this is the *root* logger and
# the setLevel() call below changes the level for the whole process, not just
# for this module.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@dataclass
class VideoVAEArgs:
    """Settings read by the VAE wrapper classes in this module and by ``create_vae``."""

    # Backend selector. create_vae() lowercases it and matches it against
    # "hunyuan", "flux" and "cosmos".
    vae_type: str = "hunyuan"  # can be "hunyuan", "flux" or "cosmos"
    # Passed to from_pretrained() by the diffusers-based wrappers; the COSMOS
    # wrapper uses it as the prefix for "<path>/encoder.jit" and
    # "<path>/decoder.jit".
    pretrained_model_name_or_path: str = "tencent/HunyuanVideo"
    # Forwarded to from_pretrained() by HunyuanVideoVAE.
    revision: Optional[str] = None
    variant: Optional[str] = None
    # Read by HunyuanVideoVAE.__init__ to switch tiling / slicing on or off on
    # the underlying model.
    enable_tiling: bool = True
    enable_slicing: bool = True
    # Name of the compute dtype, e.g. "fp32" or "bfloat16". Resolved to a
    # torch dtype by BaseLatentVideoVAE.__init__; the COSMOS wrapper also
    # hands the raw string to ImageTokenizer.
    dtype: str = "bfloat16"
class BaseLatentVideoVAE:
    """Common interface and helpers shared by the VAE wrappers in this module.

    Subclasses are expected to override :meth:`encode` and :meth:`decode`.
    The slicing / tiling toggles defined here do nothing except log a warning;
    wrappers whose underlying model supports those features override them.
    """

    # Accepted spellings of ``args.dtype`` and the torch dtype each maps to.
    # "fp32" and "bfloat16" are the two original keys; the rest are aliases.
    _DTYPES = {
        "fp32": torch.float32,
        "float32": torch.float32,
        "fp16": torch.float16,
        "float16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
    }

    def __init__(self, args: VideoVAEArgs):
        """Store the config and resolve ``args.dtype`` to a torch dtype.

        Raises:
            KeyError: if ``args.dtype`` is not one of the known dtype names.
        """
        self.cfg = args
        try:
            self.dtype = self._DTYPES[args.dtype]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the user which names are valid.
            raise KeyError(
                f"Unsupported dtype {args.dtype!r}; expected one of {sorted(self._DTYPES)}"
            ) from None

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode inputs into latents. Must be overridden by subclasses."""
        raise NotImplementedError

    @torch.no_grad()
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latents back to inputs. Must be overridden by subclasses."""
        raise NotImplementedError

    def _warn_unsupported(self, feature: str) -> None:
        """Log that toggling ``feature`` has no effect on this wrapper."""
        logger.warning(
            f"Useless func call, {self.cfg.vae_type} TVAE model doesn't support {feature} !"
        )

    def enable_vae_slicing(self):
        """No-op default: only logs that slicing is unsupported."""
        self._warn_unsupported("slicing")

    def disable_vae_slicing(self):
        """No-op default: only logs that slicing is unsupported."""
        self._warn_unsupported("slicing")

    def enable_vae_tiling(self):
        """No-op default: only logs that tiling is unsupported."""
        self._warn_unsupported("tiling")

    def disable_vae_tiling(self):
        """No-op default: only logs that tiling is unsupported."""
        self._warn_unsupported("tiling")

    def extract_latents(
        self, batch: Dict[str, Any], flops_meter=None
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode ``batch["image"]`` into latents.

        Args:
            batch: mapping that must contain an ``"image"`` entry holding
                either one tensor or a list of tensors.
            flops_meter: optional object; when given, its
                ``log_vae_flops(shape)`` is called with the shape of every
                tensor passed to :meth:`encode`.

        Returns:
            For a tensor input, whatever :meth:`encode` returns for it.
            For a list input, a list of per-item latents in the same order
            as the input list (empty list in, empty list out).

        Raises:
            TypeError: if ``batch["image"]`` is neither a tensor nor a list.
        """
        images = batch["image"]

        if isinstance(images, torch.Tensor):
            # Already a single batched tensor: encode it in one call.
            if flops_meter is not None:
                flops_meter.log_vae_flops(images.shape)
            return self.encode(images)

        if not isinstance(images, list):
            raise TypeError(f"Unsupported type for batch['image']: {type(images)}")

        # Bucket item indices by full tensor shape so every bucket can be
        # stacked. Keying on (H, W) alone would make torch.stack fail when two
        # items share H and W but differ in another dimension.
        groups: Dict[Tuple[int, ...], List[int]] = {}
        for index, img in enumerate(images):
            groups.setdefault(tuple(img.shape), []).append(index)

        # Filled by original index so the output order matches the input order.
        results: List[Optional[torch.Tensor]] = [None] * len(images)
        for indices in groups.values():
            input_batch = torch.stack([images[i] for i in indices], dim=0)
            if flops_meter is not None:
                flops_meter.log_vae_flops(input_batch.shape)
            latent_batch = self.encode(input_batch)
            # Iterating the encoded batch yields one latent per stacked item.
            for index, latent in zip(indices, latent_batch):
                results[index] = latent
        return results
class HunyuanVideoVAE(BaseLatentVideoVAE):
    """Wrapper around diffusers' ``AutoencoderKLHunyuanVideo``.

    4D inputs get a singleton axis inserted at index 2 before being handed to
    the video VAE, and a singleton axis at index 2 is removed again from the
    result, so image tensors can round-trip through the video model.
    """

    def __init__(self, args: VideoVAEArgs):
        """Load the pretrained model, freeze it, and apply slicing/tiling flags."""
        super().__init__(args)
        vae = AutoencoderKLHunyuanVideo.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            revision=self.cfg.revision,
            variant=self.cfg.variant,
            torch_dtype=self.dtype,
            device_map="auto",
        ).requires_grad_(False)

        # Configure tiling and slicing explicitly in both directions so the
        # model state always reflects the config.
        if self.cfg.enable_slicing:
            vae.enable_slicing()
        else:
            vae.disable_slicing()
        if self.cfg.enable_tiling:
            vae.enable_tiling()
        else:
            vae.disable_tiling()
        self.vae = vae

    # TODO(jinjie): a video VAE is being used for BCHW image generation, which
    # makes the code below tricky. Refactor the dataloader once video
    # generation training begins so the VAE is only ever fed 5D BCTHW tensors.
    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into scaled latents.

        The latent is drawn with ``latent_dist.sample()`` and multiplied by
        ``vae.config.scaling_factor``.
        """
        x = x.to(self.vae.dtype)
        if x.ndim == 4:  # 4D input (BCHW image tensor)
            x = x.unsqueeze(2)  # add a temporal dimension (T=1) for the video VAE
        x = self.vae.encode(x).latent_dist.sample()
        # NOTE: this also squeezes when the caller passed a 5D input whose
        # latent happens to have T=1; the check does not track whether the
        # axis was added above.
        if x.ndim == 5 and x.shape[2] == 1:
            x = x.squeeze(2)  # remove the temporal dimension at index 2
        x = x * self.vae.config.scaling_factor
        return x

    @torch.no_grad()
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode scaled latents ``x``; undoes the scaling applied by :meth:`encode`."""
        if x.ndim == 4:  # 4D input (BCHW image latent)
            x = x.unsqueeze(2)  # add a temporal dimension (T=1) for the video VAE
        x = x.to(self.vae.dtype)
        x = x / self.vae.config.scaling_factor
        x = self.vae.decode(x).sample
        if x.ndim == 5 and x.shape[2] == 1:  # singleton temporal axis
            x = x.squeeze(2)  # remove the temporal dimension at index 2
        return x

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """Encode ``x`` and decode the result (reconstruction round trip)."""
        # ``x`` used to be declared as ``x=torch.Tensor``, which made the
        # Tensor *class* a default value; it is a required tensor argument.
        x = self.encode(x)
        return self.decode(x)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
class FluxVAE(BaseLatentVideoVAE):
    """Wrapper around diffusers' ``AutoencoderKL`` with a fixed shift/scale on latents.

    ``encode`` returns ``(mode - shift) * scale`` and ``decode`` applies the
    inverse before calling the underlying decoder.
    """

    def __init__(self, args: VideoVAEArgs):
        """Load the pretrained model, freeze it, and move it to the current CUDA device if any."""
        super().__init__(args)
        self.vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path,
            # Forward revision/variant like HunyuanVideoVAE does; both default
            # to None in the config, so existing setups load the same weights.
            revision=args.revision,
            variant=args.variant,
            torch_dtype=self.dtype,
        ).requires_grad_(False)
        if torch.cuda.is_available():
            self.vae = self.vae.to(torch.cuda.current_device())
        # NOTE(review): presumably the scaling / shift factors of the Flux VAE
        # checkpoint; confirm against the model config.
        self.scale = 0.3611
        self.shift = 0.1159

    @torch.compile
    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the shifted, scaled mode of the latent distribution."""
        x = x.to(device=self.vae.device, dtype=self.vae.dtype)
        # Uses latent_dist.mode() rather than sample().
        x = (self.vae.encode(x).latent_dist.mode() - self.shift) * self.scale
        return x

    @torch.compile
    @torch.no_grad()
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Undo the shift/scale applied by :meth:`encode` and decode the latents."""
        x = x.to(device=self.vae.device, dtype=self.vae.dtype)
        # Reverse the scaling and shifting from the encode method.
        x = x / self.scale + self.shift
        # Use the VAE's decode method and take the sample.
        decoded = self.vae.decode(x).sample
        return decoded

    @torch.compile
    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        """Encode ``x`` and decode the result (reconstruction round trip)."""
        # ``x`` used to be declared as ``x=torch.Tensor``, which made the
        # Tensor *class* a default value; it is a required tensor argument.
        x = self.encode(x)
        return self.decode(x)
class COSMOSContinuousVAE(BaseLatentVideoVAE):
    """Wrapper around ``cosmos_tokenizer``'s ``ImageTokenizer``."""

    def __init__(self, args: VideoVAEArgs):
        """
        Initialize the encoder and decoder for Continuous VAE.

        The tokenizer is built from ``encoder.jit`` and ``decoder.jit`` found
        under ``cfg.pretrained_model_name_or_path`` and is frozen with
        ``requires_grad_(False)``.
        """
        super().__init__(args)
        # Imported inside __init__, so cosmos_tokenizer is only needed when
        # this class is actually instantiated.
        from cosmos_tokenizer.image_lib import ImageTokenizer

        checkpoint_dir = self.cfg.pretrained_model_name_or_path
        tokenizer = ImageTokenizer(
            checkpoint_enc=f"{checkpoint_dir}/encoder.jit",
            checkpoint_dec=f"{checkpoint_dir}/decoder.jit",
            device="cuda",
            dtype=args.dtype,
        )
        self.vae = tokenizer.requires_grad_(False)

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encodes the input frames into latent representations.
        """
        frames = x.to(self.vae._dtype)
        # The tokenizer's encode result is unpacked as a one-element sequence.
        (latent,) = self.vae.encode(frames)
        return latent

    @torch.no_grad()
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decodes the latent representations back into reconstructed frames.
        """
        return self.vae.decode(x.to(self.vae._dtype))

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Full forward pass: encode and then decode.
        """
        return self.decode(self.encode(x))
def create_vae(args: VideoVAEArgs) -> BaseLatentVideoVAE:
    """
    Build the VAE wrapper selected by ``args.vae_type`` (matched case-insensitively).

    Raises:
        ValueError: if ``args.vae_type`` is not "hunyuan", "flux" or "cosmos".
    """
    kind = args.vae_type.lower()
    if kind == "hunyuan":
        return HunyuanVideoVAE(args)
    if kind == "flux":
        return FluxVAE(args)
    if kind == "cosmos":
        return COSMOSContinuousVAE(args)
    raise ValueError(f"Unknown VAE type: {args.vae_type}")