Skip to content

Commit fbc6551

Browse files
committed
Fix Flax/Torch namespace collision in wan_vae_test.py
1 parent 5b62056 commit fbc6551

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
flax.config.update("flax_always_shard_variable", False)
5858

5959

60-
class TorchWanRMS_norm(nn.Module):
60+
class TorchWanRMS_norm(torch.nn.Module):
6161
r"""
6262
A custom RMS normalization layer.
6363
@@ -76,14 +76,14 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
7676

7777
self.channel_first = channel_first
7878
self.scale = dim**0.5
79-
self.gamma = nn.Parameter(torch.ones(shape))
80-
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
79+
self.gamma = torch.nn.Parameter(torch.ones(shape))
80+
self.bias = torch.nn.Parameter(torch.zeros(shape)) if bias else 0.0
8181

8282
def forward(self, x):
8383
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
8484

8585

86-
class TorchWanResample(nn.Module):
86+
class TorchWanResample(torch.nn.Module):
8787
r"""
8888
A custom resampling module for 2D and 3D data.
8989
@@ -104,18 +104,18 @@ def __init__(self, dim: int, mode: str) -> None:
104104

105105
# layers
106106
if mode == "upsample2d":
107-
self.resample = nn.Sequential(
108-
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
107+
self.resample = torch.nn.Sequential(
108+
torch.nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest"), torch.nn.Conv2d(dim, dim // 2, 3, padding=1)
109109
)
110110
elif mode == "upsample3d":
111111
raise Exception("downsample3d not supported")
112112

113113
elif mode == "downsample2d":
114-
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
114+
self.resample = torch.nn.Sequential(torch.nn.ZeroPad2d((0, 1, 0, 1)), torch.nn.Conv2d(dim, dim, 3, stride=(2, 2)))
115115
elif mode == "downsample3d":
116116
raise Exception("downsample3d not supported")
117117
else:
118-
self.resample = nn.Identity()
118+
self.resample = torch.nn.Identity()
119119

120120
def forward(self, x, feat_cache=None, feat_idx=[0]):
121121
b, c, t, h, w = x.size()
@@ -218,7 +218,7 @@ def test_zero_padded_conv(self):
218218
dim = 96
219219
kernel_size = 3
220220
stride = (2, 2)
221-
resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride))
221+
resample = torch.nn.Sequential(torch.nn.ZeroPad2d((0, 1, 0, 1)), torch.nn.Conv2d(dim, dim, kernel_size, stride=stride))
222222
input_shape = (1, 96, 480, 720)
223223
input = torch.ones(input_shape)
224224
output_torch = resample(input)

0 commit comments

Comments
 (0)