Skip to content

Commit 695eea5

Browse files
committed
updates
1 parent c61dfe2 commit 695eea5

9 files changed

Lines changed: 113 additions & 0 deletions
7.48 MB
Binary file not shown.
1.57 MB
Binary file not shown.
9.77 KB
Binary file not shown.
4.44 KB
Binary file not shown.
44.9 MB
Binary file not shown.
Binary file not shown.
58.6 KB
Binary file not shown.
Binary file not shown.
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torchvision import datasets, transforms
5+
from torch.utils.data import DataLoader
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
from tqdm import tqdm
9+
10+
# === Hyperparameters ===
11+
T = 300 # Number of diffusion steps
12+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13+
14+
# === Beta schedule (linear) ===
15+
betas = torch.linspace(1e-4, 0.02, T).to(device)
16+
alphas = 1. - betas
17+
alphas_cumprod = torch.cumprod(alphas, dim=0)
18+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
19+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)
20+
21+
# === Forward diffusion ===
22+
def forward_diffusion_sample(x_0, t, noise=None):
23+
if noise is None:
24+
noise = torch.randn_like(x_0)
25+
sqrt_alpha = sqrt_alphas_cumprod[t][:, None, None, None]
26+
sqrt_one_minus_alpha = sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
27+
return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise, noise
28+
29+
# === Simple CNN for denoising ===
30+
class SimpleUNet(nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
self.net = nn.Sequential(
34+
nn.Conv2d(2, 32, 3, padding=1),
35+
nn.ReLU(),
36+
nn.Conv2d(32, 64, 3, padding=1),
37+
nn.ReLU(),
38+
nn.Conv2d(64, 32, 3, padding=1),
39+
nn.ReLU(),
40+
nn.Conv2d(32, 1, 3, padding=1),
41+
)
42+
43+
def forward(self, x, t):
44+
t_emb = t[:, None, None, None].float() / T # normalize timestep
45+
t_emb = t_emb.expand(-1, 1, 28, 28)
46+
x_input = torch.cat([x, t_emb], dim=1)
47+
return self.net(x_input)
48+
49+
# === Data ===
50+
transform = transforms.Compose([
51+
transforms.ToTensor(),
52+
transforms.Lambda(lambda x: (x - 0.5) * 2), # scale to [-1, 1]
53+
])
54+
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
55+
loader = DataLoader(dataset, batch_size=128, shuffle=True)
56+
57+
# === Model, optimizer ===
58+
model = SimpleUNet().to(device)
59+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
60+
61+
# === Training loop ===
62+
def train(epochs=10):
63+
model.train()
64+
for epoch in range(epochs):
65+
pbar = tqdm(loader)
66+
for batch, _ in pbar:
67+
batch = batch.to(device)
68+
t = torch.randint(0, T, (batch.size(0),), device=device).long()
69+
x_noisy, noise = forward_diffusion_sample(batch, t)
70+
noise_pred = model(x_noisy, t)
71+
loss = F.mse_loss(noise_pred, noise)
72+
73+
optimizer.zero_grad()
74+
loss.backward()
75+
optimizer.step()
76+
pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")
77+
78+
# === Sampling loop ===
79+
@torch.no_grad()
80+
def sample():
81+
model.eval()
82+
img = torch.randn((16, 1, 28, 28), device=device)
83+
for t in reversed(range(T)):
84+
t_batch = torch.full((img.shape[0],), t, device=device, dtype=torch.long)
85+
noise_pred = model(img, t_batch)
86+
beta = betas[t]
87+
alpha = alphas[t]
88+
alpha_cumprod = alphas_cumprod[t]
89+
coef1 = 1 / torch.sqrt(alpha)
90+
coef2 = (1 - alpha) / torch.sqrt(1 - alpha_cumprod)
91+
if t > 0:
92+
noise = torch.randn_like(img)
93+
else:
94+
noise = 0
95+
img = coef1 * (img - coef2 * noise_pred) + torch.sqrt(beta) * noise
96+
return img
97+
98+
# === Plotting generated samples ===
99+
def show_samples(imgs):
100+
imgs = imgs.cpu().clamp(-1, 1)
101+
imgs = (imgs + 1) / 2 # back to [0, 1]
102+
grid = torch.cat([img for img in imgs], dim=2).squeeze()
103+
plt.figure(figsize=(12, 2))
104+
plt.imshow(grid, cmap="gray")
105+
plt.axis('off')
106+
plt.title("Generated Samples")
107+
plt.show()
108+
109+
# === Run training and generate ===
110+
if __name__ == "__main__":
111+
train(epochs=10)
112+
samples = sample()
113+
show_samples(samples)

0 commit comments

Comments
 (0)