Commit 8debabf

committed
add personal version of extremely succinct stable diffusion.
0 parents  commit 8debabf

5 files changed

Lines changed: 1402 additions & 0 deletions

File tree

Diffusion_training_demo.py

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
import torch
import functools
import numpy as np  # needed by np.log in marginal_prob_std below
from tqdm import tqdm, trange
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')


# @title Diffusion constant and noise strength
device = 'cuda'  # @param ['cuda', 'cpu'] {'type':'string'}


def marginal_prob_std(t, sigma):
    """Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

    Args:
        t: A vector of time steps.
        sigma: The $\sigma$ in our SDE.

    Returns:
        The standard deviation.
    """
    t = torch.tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))


def diffusion_coeff(t, sigma):
    """Compute the diffusion coefficient of our SDE.

    Args:
        t: A vector of time steps.
        sigma: The $\sigma$ in our SDE.

    Returns:
        The vector of diffusion coefficients.
    """
    return torch.tensor(sigma ** t, device=device)


sigma = 25.0  # @param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
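# Quick sanity check (an illustrative addition, not part of the training pipeline):
# for this variance-exploding SDE the perturbation std at t = 1 should equal
# sqrt((sigma**2 - 1) / (2 * log(sigma))), about 9.85 for sigma = 25.
# print(marginal_prob_std_fn(torch.ones(1)))   # ~9.85 on a CUDA machine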
#%%
#@title Training Loss function
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
        model: A PyTorch model instance that represents a
            time-dependent score-based model.
        x: A mini-batch of training data.
        y: A mini-batch of conditioning information passed to the model.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        eps: A tolerance value for numerical stability.
    """
    # Sample a random time in (eps, 1] for each example in the batch.
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t, y=y)
    # Denoising score matching: the model should predict -z / std.
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
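# Minimal smoke test for loss_fn_cond (illustrative; DummyScore is a hypothetical
# stand-in for the real conditional UNet, matching the model(x, t, y=y) signature):
# class DummyScore(torch.nn.Module):
#     def forward(self, x, t, y=None):
#         return torch.zeros_like(x)
# x_toy = torch.randn(4, 1, 28, 28, device=device)
# print(loss_fn_cond(DummyScore().to(device), x_toy, None, marginal_prob_std_fn))
# # Should print a finite scalar, roughly E[||z||^2] = 28 * 28 = 784.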
#%%
# @title Diffusion Model Sampler
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=500,
                           device='cuda',
                           eps=1e-3,
                           y=None):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate by calling this function once.
        num_steps: The number of sampling steps.
            Equivalent to the number of discretized time steps.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.
        y: Optional conditioning information forwarded to the model.

    Returns:
        Samples.
    """
    # Start from the prior: pure noise scaled to the marginal std at t = 1.
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) \
        * marginal_prob_std(t)[:, None, None, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g ** 2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x

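# The loop above is the Euler-Maruyama discretization of the reverse-time SDE
#   dx = g(t)^2 * score(x, t) dt + g(t) dW,
# integrated from t = 1 down to t = eps with step dt = (1 - eps) / (num_steps - 1);
# returning mean_x drops the noise of the final step.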
#%%
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from tqdm.notebook import trange, tqdm
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR

import matplotlib.pyplot as plt

from torchvision.utils import make_grid
from einops import rearrange

#@title Diffusion Trainer
def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
                      marginal_prob_std_fn=marginal_prob_std_fn,
                      lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch),
                      device="cuda",
                      callback=None):  # resume=False,
    """Train `score_model` on `dataset` with Adam and a decaying LR schedule."""
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    optimizer = Adam(score_model.parameters(), lr=lr)
    # Multiply the base lr by 0.98**epoch, floored at 0.2 of the base lr.
    scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn)
    tqdm_epoch = trange(n_epochs)
    for epoch in tqdm_epoch:
        score_model.train()
        avg_loss = 0.
        num_items = 0
        for x, y in tqdm(data_loader):
            x = x.to(device)
            loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
        scheduler.step()
        lr_current = scheduler.get_last_lr()[0]
        print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
        # Print the averaged training loss so far.
        tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), f'ckpt_{ckpt_name}.pth')
        if callback is not None:
            score_model.eval()
            callback(score_model, epoch, ckpt_name)

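# Illustrative usage (a sketch, not part of the original pipeline): MNIST is
# imported above but never used in this script; any model exposing the
# model(x, t, y=y) signature would slot in as score_model here.
# mnist = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
# train_score_model(score_model, mnist, lr=1e-3, n_epochs=10, batch_size=256,
#                   ckpt_name="mnist_demo")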
#%%
from torchvision.datasets import CelebA
from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize

tfm = Compose([
    Resize(32),
    CenterCrop(32),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"],
                     transform=tfm, download=False)  # ,"identity"
#%%
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False)
x_col = []
y_col = []
# Cache the whole resized dataset in memory as tensors.
for xs, ys in tqdm(dataloader):
    x_col.append(xs)
    y_col.append(ys)
x_col = torch.concat(x_col, dim=0)
y_col = torch.concat(y_col, dim=0)
print(x_col.shape)
print(y_col.shape)

# Encode each face's binary attribute vector as a padded sequence of attribute
# indices: maxlen is the largest number of active attributes over the dataset,
# and nantoken (40, one past the 40 CelebA attribute indices 0-39) is the pad token.
maxlen = int(y_col.sum(dim=1).max())
nantoken = 40
yseq_data = torch.ones(y_col.size(0), maxlen,
                       dtype=int).fill_(nantoken)
for i in range(y_col.size(0)):
    attr_idx = torch.nonzero(y_col[i]).flatten()
    yseq_data[i, :len(attr_idx)] = attr_idx

saved_dataset = TensorDataset(x_col, yseq_data)
#%%
def save_sample_callback(score_model, epoch, ckpt_name):
    """Draw a batch of conditional samples and save them as an image grid."""
    sample_batch_size = 64
    num_steps = 250
    y_samp = yseq_data[:sample_batch_size, :]
    sampler = Euler_Maruyama_sampler
    samples = sampler(score_model,
                      marginal_prob_std_fn,
                      diffusion_coeff_fn,
                      sample_batch_size,
                      x_shape=(3, 32, 32),
                      num_steps=num_steps,
                      device=device,
                      y=y_samp, )
    # Invert the ImageNet-style normalization before display.
    denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                            [1/0.229, 1/0.224, 1/0.225])
    samples = denormalize(samples).clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.tight_layout()
    plt.savefig(f"samples_{ckpt_name}_{epoch}.png")
    plt.show()
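# Note on the inversion above: torchvision's Normalize computes (x - mean) / std,
# so Normalize(-mean/std, 1/std) yields (x + mean/std) * std = x * std + mean,
# which exactly undoes the Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# applied at data-loading time.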
#%%
#@title Training model

# continue_training = False #@param {type:"boolean"}
# if not continue_training:
#     print("initialize new score model...")
# UNet_Tranformer_attrb (a transformer-conditioned UNet) is defined in one of
# the other files of this commit.
score_model = torch.nn.DataParallel(
    UNet_Tranformer_attrb(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 50
batch_size = 1024
lr = 1e-3
train_score_model(score_model, saved_dataset, lr, n_epochs, batch_size,
                  "Unet-tfmer_pad", device="cuda", callback=save_sample_callback)
#%%
n_epochs = 50
batch_size = 1048
lr = 2e-4
train_score_model(score_model, saved_dataset, lr, n_epochs, batch_size,
                  "Unet-tfmer", device="cuda", callback=save_sample_callback)
#%%
sample_batch_size = 64  #@param {'type':'integer'}
num_steps = 250  #@param {'type':'integer'}
sampler = Euler_Maruyama_sampler  #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}
# score_model.eval()
## Generate samples using the specified sampler.
samples = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  x_shape=(3, 32, 32),
                  num_steps=num_steps,
                  device=device,
                  y=yseq_data[:sample_batch_size, :], )

## Sample visualization.
denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                        [1/0.229, 1/0.224, 1/0.225])

samples = denormalize(samples).clamp(0.0, 1.0)
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.tight_layout()
plt.show()
