Commit 714556d

Add Stable Diffusion UNet model (ckpt loadable) and a toy model trained on CelebA,
plus extra experiments that you can play with.
1 parent 8debabf commit 714556d
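"Ckpt loadable" refers to copying pretrained weights from the diffusers pipeline into the re-implemented UNet via load_pipe_into_our_UNet, which the toy script imports from StableDiff_UNet_model.py. The exact helper signature is not shown in this diff, so the sketch below is a rough, assumed usage (pipeline loading mirrors StableDiffusion_exps.py):

import torch
from diffusers import StableDiffusionPipeline
from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True).to("cuda")
myunet = UNet_SD()                      # default config mirrors the SD v1 UNet
load_pipe_into_our_UNet(myunet, pipe)   # assumed call form: (our_unet, diffusers_pipe)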

3 files changed

Lines changed: 267 additions & 31 deletions


StableDiff_UNet_model.py

Lines changed: 22 additions & 15 deletions
@@ -12,30 +12,38 @@


 class UNet_SD(nn.Module):
-    def __init__(self, cat_unet=True):
+
+    def __init__(self, in_channels=4,
+                 base_channels=320,
+                 time_emb_dim=1280,
+                 context_dim=768,
+                 multipliers=(1, 2, 4, 4),
+                 attn_levels=(0, 1, 2),
+                 nResAttn_block=2,
+                 cat_unet=True):
         super().__init__()
-        self.in_channels = 4
-        self.out_channels = 4
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        base_channels = 320
-        time_proj_dim = 320
-        time_emb_dim = 1280
-        context_dim = 768
-        nlevel = 4
+        self.in_channels = in_channels
+        self.out_channels = in_channels
+        base_channels = base_channels
+        time_emb_dim = time_emb_dim
+        context_dim = context_dim
+        multipliers = multipliers
+        nlevel = len(multipliers)
         self.base_channels = base_channels
-        attn_levels = [0, 1, 2]
-        level_channels = [base_channels * mult for mult in [1, 2, 4, 4]]
+        # attn_levels = [0, 1, 2]
+        level_channels = [base_channels * mult for mult in multipliers]
         # Transform time into embedding
         self.time_embedding = nn.Sequential(OrderedDict({
-            "linear_1": nn.Linear(time_proj_dim, time_emb_dim, bias=True),
+            "linear_1": nn.Linear(base_channels, time_emb_dim, bias=True),
             "act": nn.SiLU(),
             "linear_2": nn.Linear(time_emb_dim, time_emb_dim, bias=True),
             })
         )  # 2 layer MLP
         self.conv_in = nn.Conv2d(self.in_channels, base_channels, 3, stride=1, padding=1)

         # Tensor Downsample blocks
-        nResAttn_block = 2
+        nResAttn_block = nResAttn_block
         self.down_blocks = TimeModulatedSequential()  # nn.ModuleList()
         self.down_blocks_channels = [base_channels]
         cur_chan = base_channels
@@ -81,14 +89,13 @@ def __init__(self, cat_unet=True):
             nn.Conv2d(base_channels, self.out_channels, 3, padding=1),
         )
         self.to(self.device)
-
     def time_proj(self, time_steps, max_period: int = 10000):
         if time_steps.ndim == 0:
             time_steps = time_steps.unsqueeze(0)
         half = self.base_channels // 2
         frequencies = torch.exp(- math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32) / half
-            ).to(device=time_steps.device)
+                                * torch.arange(start=0, end=half, dtype=torch.float32) / half
+                                ).to(device=time_steps.device)
         angles = time_steps[:, None].float() * frequencies[None, :]
         return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

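With the constructor made configurable, the defaults reproduce the Stable Diffusion v1 UNet layout (4 latent channels, 320 base channels, 768-dim text context), while smaller configs can be built for toy experiments as in the CelebA script below. A hedged usage sketch: the forward keywords (time_steps, cond, output_dict) are taken from the toy script, and the 77-token context length here is just an illustrative stand-in for CLIP text embeddings:

import torch
from StableDiff_UNet_model import UNet_SD

unet = UNet_SD()  # defaults: in_channels=4, base_channels=320, context_dim=768
latents = torch.randn(1, 4, 64, 64).to(unet.device)
t = torch.rand(1).to(unet.device)
text_ctx = torch.randn(1, 77, 768).to(unet.device)  # stand-in for CLIP text embeddings
out = unet(latents, time_steps=t, cond=text_ctx, output_dict=False)
print(out.shape)  # expected to match the latent shape, (1, 4, 64, 64)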

StableDiff_toy_celebA.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
#%%
import torch
import functools
from tqdm import tqdm, trange
import torch.multiprocessing
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
torch.multiprocessing.set_sharing_strategy('file_system')
#%%
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import CelebA
from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize


tfm = Compose([
    Resize(32),
    CenterCrop(32),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"],
                     transform=tfm, download=False)  # ,"identity"
#%%
dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False)
x_col = []
y_col = []
for xs, ys in tqdm(dataloader):
    x_col.append(xs)
    y_col.append(ys)
x_col = torch.concat(x_col, dim=0)
y_col = torch.concat(y_col, dim=0)
print(x_col.shape)
print(y_col.shape)

nantoken = 40
maxlen = (y_col.sum(dim=1)).max()
yseq_data = torch.ones(y_col.size(0), maxlen, dtype=int).fill_(nantoken)

saved_dataset = TensorDataset(x_col, yseq_data)
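The attribute labels are stored as fixed-length index sequences padded with nantoken = 40, matching the nn.Embedding(40 + 1, 256, padding_idx=40) conditioning table defined further down. The committed code only allocates the padded buffer; one plausible way to populate it (my assumption, not part of this commit) would be:

# Hypothetical encoding (not in this commit): write the indices of the active
# CelebA attributes into each row, leaving nantoken (40) as padding.
for i, attr in enumerate(y_col):
    idx = torch.nonzero(attr).flatten()
    yseq_data[i, :len(idx)] = idx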
#%%
import math
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
device = 'cuda'

def marginal_prob_std(t, sigma):
    t = torch.tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / math.log(sigma))


def diffusion_coeff(t, sigma):
    return torch.tensor(sigma ** t, device=device)


sigma = 25.0  # @param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
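These two helpers implement the variance-exploding SDE dx = sigma^t dw from the score-SDE tutorial this recipe follows: the perturbation kernel at time t is Gaussian with standard deviation std(t) = sqrt((sigma^(2t) - 1) / (2 ln sigma)). A throwaway sanity check of the closed form against numerical integration of the variance (not part of the commit; assumes a CUDA device, as the rest of the script does):

ts = torch.linspace(1e-3, 1.0, 5)
for t in ts:
    # Var[x_t] = integral_0^t sigma^(2s) ds, computed numerically on a fine grid
    s = torch.linspace(0, float(t), 10001)
    var_num = torch.trapz(sigma ** (2 * s), s)
    print(float(t), float(var_num.sqrt()), float(marginal_prob_std(t, sigma).cpu()))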
#%
#@title Training Loss function
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
        model: A PyTorch model instance that represents a
            time-dependent score-based model.
        x: A mini-batch of training data.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        eps: A tolerance value for numerical stability.
    """
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t, cond=y, output_dict=False)
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
    return loss
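This is the standard denoising score-matching objective: the model is asked to predict -z/std(t) from the perturbed input, with the conditioning embedding passed through the cond keyword (the y argument, which the docstring omits). A minimal smoke test with a hypothetical dummy score function, assuming a CUDA device as elsewhere in the script:

# Dummy "model" that returns a zero score and ignores its inputs.
dummy_model = lambda x, t, cond=None, output_dict=False: torch.zeros_like(x)
x_dummy = torch.randn(8, 3, 32, 32, device=device)
y_dummy = torch.randn(8, 20, 256, device=device)
# With a zero score the loss reduces to E[sum z^2] = 3 * 32 * 32, i.e. about 3072.
print(loss_fn_cond(dummy_model, x_dummy, y_dummy, marginal_prob_std_fn))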

#%
def train_score_model(score_model, cond_embed, dataset, lr, n_epochs, batch_size, ckpt_name,
                      marginal_prob_std_fn=marginal_prob_std_fn,
                      lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch),
                      device="cuda",
                      callback=None):  # resume=False,
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    optimizer = Adam([*score_model.parameters(), *cond_embed.parameters()], lr=lr)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn)
    tqdm_epoch = trange(n_epochs)
    for epoch in tqdm_epoch:
        score_model.train()
        avg_loss = 0.
        num_items = 0
        batch_tqdm = tqdm(data_loader)
        for x, y in batch_tqdm:
            x = x.to(device)
            y_emb = cond_embed(y.to(device))
            loss = loss_fn_cond(score_model, x, y_emb, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
            batch_tqdm.set_description("Epoch %d, loss %.4f" % (epoch, avg_loss / num_items))
        scheduler.step()
        lr_current = scheduler.get_last_lr()[0]
        print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
        # Print the averaged training loss so far.
        tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{ckpt_name}.pth')
        torch.save(cond_embed.state_dict(),
                   f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{ckpt_name}_cond_embed.pth')
        if callback is not None:
            score_model.eval()
            callback(score_model, epoch, ckpt_name)
#%%
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=500,
                           device='cuda',
                           eps=1e-3,
                           y=None):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate by calling this function once.
        num_steps: The number of sampling steps.
            Equivalent to the number of discretized time steps.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.

    Returns:
        Samples.
    """
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) \
        * marginal_prob_std(t)[:, None, None, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g ** 2)[:, None, None, None] * score_model(x, batch_time_step, cond=y, output_dict=False) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x
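The loop above is the Euler-Maruyama discretization of the reverse-time VE SDE, stepping t from 1 down to eps with

    x_{k+1} = x_k + g(t_k)^2 * score(x_k, t_k, y) * dt + g(t_k) * sqrt(dt) * z_k,   z_k ~ N(0, I),

and the function returns mean_x, i.e. the final update without the noise term.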
#%%
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def save_sample_callback(score_model, epocs, ckpt_name):
    sample_batch_size = 64
    num_steps = 250
    y_samp = yseq_data[:sample_batch_size, :]
    y_emb = cond_embed(y_samp.cuda())
    sampler = Euler_Maruyama_sampler
    samples = sampler(score_model,
                      marginal_prob_std_fn,
                      diffusion_coeff_fn,
                      sample_batch_size,
                      x_shape=(3, 32, 32),
                      num_steps=num_steps,
                      device=device,
                      y=y_emb, )
    denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                            [1/0.229, 1/0.224, 1/0.225])
    samples = denormalize(samples).clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(math.sqrt(sample_batch_size)))

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.tight_layout()
    plt.savefig(f"/home/binxuwang/DL_Projects/SDfromScratch/samples_{ckpt_name}_{epocs}.png")
    plt.show()
#%%
from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet
#%% UNet without latent space no VAE
unet_face = UNet_SD(in_channels=3,
                    base_channels=128,
                    time_emb_dim=256,
                    context_dim=256,
                    multipliers=(1, 1, 2),
                    attn_levels=(1, 2, ),
                    nResAttn_block=1,
                    )
cond_embed = nn.Embedding(40 + 1, 256, padding_idx=40).cuda()
#%%
torch.save(unet_face.state_dict(), "/home/binxuwang/DL_Projects/SDfromScratch/SD_unet_face.pt",)
#%%
unet_face(torch.randn(1, 3, 64, 64).cuda(), time_steps=torch.rand(1).cuda(),
          cond=torch.randn(1, 20, 256).cuda(),
          output_dict=False)
#%%
#%%
train_score_model(unet_face, cond_embed, saved_dataset,
                  lr=1.5e-4, n_epochs=100, batch_size=256,
                  ckpt_name="unet_SD_face", device=device,
                  callback=save_sample_callback)

#%%


save_sample_callback(unet_face, 0, "unet_SD_face")
#%%
torch.save(cond_embed.state_dict(), f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{"unet_SD_face"}_cond_embed.pth')
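After (or during) training, the saved weights can be restored the usual way; a brief sketch using the paths written by train_score_model above:

# Reload the toy checkpoints saved during training (paths as in train_score_model).
unet_face.load_state_dict(torch.load('/home/binxuwang/DL_Projects/SDfromScratch/ckpt_unet_SD_face.pth'))
cond_embed.load_state_dict(torch.load('/home/binxuwang/DL_Projects/SDfromScratch/ckpt_unet_SD_face_cond_embed.pth'))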

StableDiffusion_exps.py

Lines changed: 30 additions & 16 deletions
@@ -14,11 +14,25 @@ def plt_show_image(image):
     plt.show()


+def recursive_print(module, prefix="", depth=0, deepest=3):
+    """Simulating print(module) for torch.nn.Modules
+    but with depth control. Print to the `deepest` level. `deepest=0` means no print
+    """
+    if depth >= deepest:
+        return
+    for name, child in module.named_children():
+        if len([*child.named_children()]) == 0:
+            print(f"{prefix}({name}): {child}")
+        else:
+            print(f"{prefix}({name}): {type(child).__name__}")
+            recursive_print(child, prefix + " ", depth + 1, deepest)
+
+#%%
+
 pipe = StableDiffusionPipeline.from_pretrained(
     "CompVis/stable-diffusion-v1-4",
     use_auth_token=True
 ).to("cuda")
-#%%
 def dummy_checker(images, **kwargs): return images, False
 pipe.safety_checker = dummy_checker
 #%% Text to
@@ -40,36 +54,36 @@ def dummy_checker(images, **kwargs): return images, False


 #%% Saving images during diffusion process using callback
+
+latents_reservoir = []
 @torch.no_grad()
 def plot_show_callback(i, t, latents):
+    latents_reservoir.append(latents.detach().cpu())
     latents = 1 / 0.18215 * latents
     image = pipe.vae.decode(latents).sample
     image = (image / 2 + 0.5).clamp(0, 1)
     image = image.cpu().permute(0, 2, 3, 1).float().numpy()
     plt_show_image(image[0])
     plt.imsave(f"/home/binxuwang/DL_Projects/SDfromScratch/diffproc/sample_{i:02d}.png", image[0])

+latents_reservoir = []
+@torch.no_grad()
+def save_latents(i, t, latents):
+    latents_reservoir.append(latents.detach().cpu())
 #%%
 # prompt = "A ballerina dancing on a high ground in the starry night"
-prompt = "A cute cat running on the grass in the style of Monet"
+# prompt = "A cute cat running on the grass in the style of Monet"
+prompt = "A ballerina chasing her cat running on the grass in the style of Monet"
+prompt = "A kitty cat dressed like Lincoln, old timey style"
 with autocast("cuda"):
-    image = pipe(prompt, callback=plot_show_callback)["sample"][0]
+    image = pipe(prompt, callback=None)["sample"][0]  # plot_show_callback

-image.save("cat_Monet.png")
+image.save("cat_Lincoln.png")
 plt_show_image(image)
+#%%
+len(latents_reservoir)
+plt_show_image(latents_reservoir[-10][0, [0, 1, 2,], :].permute(1, 2, 0).cpu().numpy() / 1.6 + 0.4)
 #%% Visualize architecture
-def recursive_print(module, prefix="", depth=0, deepest=3):
-    """Simulating print(module) for torch.nn.Modules
-    but with depth control. Print to the `deepest` level. `deepest=0` means no print
-    """
-    if depth >= deepest:
-        return
-    for name, child in module.named_children():
-        if len([*child.named_children()]) == 0:
-            print(f"{prefix}({name}): {child}")
-        else:
-            print(f"{prefix}({name}): {type(child).__name__}")
-            recursive_print(child, prefix + " ", depth + 1, deepest)


 #%% Full unets
 recursive_print(pipe.unet, deepest=3)
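Since plot_show_callback and save_latents stash every intermediate latent in latents_reservoir, any step of the diffusion trajectory can be decoded through the VAE afterwards, using the same 1/0.18215 scaling shown in plot_show_callback. A small sketch, assuming latents_reservoir was populated by a run with callback=save_latents:

@torch.no_grad()
def decode_latent(latent_step):
    # Decode one stored latent (1, 4, 64, 64 for the default 512x512 setup) back to an image.
    latents = (1 / 0.18215) * latent_step.to("cuda")
    image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    return image.cpu().permute(0, 2, 3, 1).float().numpy()[0]

plt_show_image(decode_latent(latents_reservoir[len(latents_reservoir) // 2]))  # mid-trajectory sample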
