|
| 1 | +import torch |
| 2 | +import functools |
| 3 | +from tqdm import tqdm, trange |
| 4 | +import torch.multiprocessing |
# The 'file_system' strategy shares tensors between DataLoader worker
# processes via files instead of file descriptors, avoiding "too many open
# files" errors when many tensors cross process boundaries.
torch.multiprocessing.set_sharing_strategy('file_system')

# @title Diffusion constant and noise strength
# Global compute device used by all helpers below.
device = 'cuda' # @param ['cuda', 'cpu'] {'type':'string'}
| 9 | + |
def marginal_prob_std(t, sigma):
    r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$ for the VE SDE.

    For the variance-exploding SDE $dx = \sigma^t dw$, the perturbation
    kernel is Gaussian with mean $x(0)$ and the std returned here.

    Args:
        t: A vector of time steps (scalar, sequence, or tensor).
        sigma: The $\sigma$ in our SDE.

    Returns:
        The standard deviation as a tensor on the module-level `device`.
    """
    # `as_tensor` avoids the warning-and-copy that `torch.tensor` triggers
    # when `t` is already a tensor (it is at sampling time).
    t = torch.as_tensor(t, device=device)
    # Closed form: sqrt((sigma^{2t} - 1) / (2 ln sigma)); sigma is a plain
    # Python float, so `math.log` (imported later in this file) suffices.
    return torch.sqrt((sigma ** (2 * t) - 1.) / (2. * math.log(sigma)))
| 22 | + |
| 23 | + |
def diffusion_coeff(t, sigma):
    r"""Compute the diffusion coefficient $g(t) = \sigma^t$ of the VE SDE.

    Args:
        t: A vector of time steps.
        sigma: The $\sigma$ in our SDE.

    Returns:
        The vector of diffusion coefficients on the module-level `device`.
    """
    # When `t` is a tensor, `sigma ** t` already is one; `torch.as_tensor`
    # avoids the UserWarning and extra copy that `torch.tensor` would emit.
    return torch.as_tensor(sigma ** t, device=device)
| 35 | + |
| 36 | + |
# Noise scale of the VE SDE; fixed for the whole experiment.
sigma = 25.0  # @param {'type':'number'}
# Bind sigma into the helpers so downstream code calls them with `t` alone.
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
| 40 | +#%% |
| 41 | +#@title Training Loss function |
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for a conditional score model.

    Perturbs `x` with Gaussian noise scaled by std(t) at a uniformly
    sampled time t, then penalizes the model for failing to predict the
    (negated, std-scaled) noise.

    Args:
        model: Time-dependent score model, called as `model(x_t, t, y=y)`.
        x: Mini-batch of training data, shape (B, C, H, W).
        y: Conditioning information forwarded unchanged to the model.
        marginal_prob_std: Maps times t to the perturbation-kernel std.
        eps: Lower clamp on sampled times, for numerical stability.

    Returns:
        Scalar tensor holding the batch-averaged loss.
    """
    batch = x.shape[0]
    # Draw times uniformly from [eps, 1) to avoid the degenerate std at t=0.
    random_t = torch.rand(batch, device=x.device) * (1. - eps) + eps
    noise = torch.randn_like(x)
    # Broadcast the per-sample std over the channel/spatial dims.
    std = marginal_prob_std(random_t)[:, None, None, None]
    perturbed = x + noise * std
    score = model(perturbed, random_t, y=y)
    # Score-matching residual: score * std should approximate -noise.
    residual = score * std + noise
    per_sample = residual.pow(2).sum(dim=(1, 2, 3))
    return per_sample.mean()
| 60 | +#%% |
| 61 | +# @title Diffusion Model Sampler |
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=500,
                           device='cuda',
                           eps=1e-3,
                           y=None):
    """Generate samples from a score-based model with the Euler-Maruyama solver.

    Integrates the reverse-time VE SDE from t=1 down to t=eps.

    Args:
        score_model: Time-dependent score model, called as model(x, t, y=y).
        marginal_prob_std: Maps times t to the perturbation-kernel std.
        diffusion_coeff: Maps times t to the SDE diffusion coefficient g(t).
        batch_size: Number of samples produced per call.
        x_shape: Per-sample tensor shape (C, H, W).
        num_steps: Number of discretized time steps.
        device: 'cuda' for GPUs, 'cpu' otherwise.
        eps: Smallest time step, for numerical stability.
        y: Optional conditioning forwarded to the model.

    Returns:
        A (batch_size, *x_shape) tensor of samples.
    """
    # Initialize from the t=1 prior: N(0, std(1)^2 I).
    ones_t = torch.ones(batch_size, device=device)
    x = torch.randn(batch_size, *x_shape, device=device) \
        * marginal_prob_std(ones_t)[:, None, None, None]
    # Uniform time grid from 1 down to eps; constant step size.
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_t = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_t)[:, None, None, None]
            # Euler-Maruyama update: deterministic drift, then diffusion noise.
            mean_x = x + (g ** 2) * score_model(x, batch_t, y=y) * step_size
            x = mean_x + torch.sqrt(step_size) * g * torch.randn_like(x)
    # Return the last mean: no noise is injected in the final step.
    return mean_x
| 101 | + |
| 102 | +#%% |
| 103 | +import torch |
| 104 | +import torch.nn as nn |
| 105 | +import torch.nn.functional as F |
| 106 | +import numpy as np |
| 107 | +import functools |
| 108 | +import math |
| 109 | +from torch.optim import Adam |
| 110 | +from torch.utils.data import DataLoader |
| 111 | +import torchvision.transforms as transforms |
| 112 | +from torchvision.datasets import MNIST |
| 113 | +import tqdm |
| 114 | +from tqdm.notebook import trange, tqdm |
| 115 | +from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR |
| 116 | + |
| 117 | +import matplotlib.pyplot as plt |
| 118 | + |
| 119 | +from torchvision.utils import make_grid |
| 120 | +from einops import rearrange |
| 121 | + |
| 122 | +#@title Diffusion Trainer |
def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
                      marginal_prob_std_fn=marginal_prob_std_fn,
                      lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch),
                      device="cuda",
                      callback=None): # resume=False,
    """Train `score_model` on `dataset`, checkpointing after every epoch.

    Args:
        score_model: The score network to optimize (trained in place).
        dataset: Torch dataset yielding (image, conditioning) pairs.
        lr: Initial Adam learning rate.
        n_epochs: Number of full passes over the dataset.
        batch_size: Mini-batch size for the DataLoader.
        ckpt_name: Suffix for the 'ckpt_{name}.pth' checkpoint file.
        marginal_prob_std_fn: Std function handed to the loss.
        lr_scheduler_fn: Per-epoch multiplicative LR factor (floored at 0.2).
        device: Device to move each image batch to.
        callback: Optional hook called as callback(model, epoch, ckpt_name)
            after each epoch (model is put in eval mode first).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    optimizer = Adam(score_model.parameters(), lr=lr)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn)
    epoch_bar = trange(n_epochs)
    for epoch in epoch_bar:
        score_model.train()
        running_loss = 0.
        seen = 0
        for x, y in tqdm(loader):
            x = x.to(device)
            loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            running_loss += loss.item() * x.shape[0]
            seen += x.shape[0]
        scheduler.step()
        lr_current = scheduler.get_last_lr()[0]
        print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, running_loss / seen, lr_current))
        # Mirror the running average on the progress bar.
        epoch_bar.set_description('Average Loss: {:5f}'.format(running_loss / seen))
        # Overwrite the checkpoint after every epoch.
        torch.save(score_model.state_dict(), f'ckpt_{ckpt_name}.pth')
        if callback is not None:
            score_model.eval()
            callback(score_model, epoch, ckpt_name)
| 155 | + |
| 156 | +#%% |
| 157 | + |
| 158 | +#%% |
| 159 | +from torchvision.datasets import CelebA |
| 160 | +from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize |
# Resize the short side to 32, center-crop to 32x32, then apply
# ImageNet-style channel normalization (must be inverted before plotting).
tfm = Compose([
    Resize(32),
    CenterCrop(32),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# CelebA with the 40 binary attributes as targets; expects the archive to
# already exist under this hard-coded local path (download=False).
dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"],
                     transform=tfm, download=False) # ,"identity"
| 169 | +#%% |
| 170 | +from torch.utils.data import DataLoader, TensorDataset |
| 171 | +from tqdm import tqdm |
| 172 | + |
# Materialize the whole transformed dataset into memory once, so training
# epochs skip the per-image decode/transform cost.
dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False)
x_col = []
y_col = []
for xs, ys in tqdm(dataloader):
    x_col.append(xs)
    y_col.append(ys)
x_col = torch.concat(x_col, dim=0)
y_col = torch.concat(y_col, dim=0)
print(x_col.shape)
print(y_col.shape)

# Sequence length = the largest number of active attributes on any image.
maxlen = (y_col.sum(dim=1)).max()
nantoken = 40  # padding token id, one past the 40 attribute indices 0..39
yseq_data = torch.ones(y_col.size(0), maxlen,
                       dtype=int).fill_(nantoken)
# NOTE(review): yseq_data is only filled with the padding token here — the
# active attribute indices from y_col are never written into it, so every
# conditioning sequence is all-padding. Presumably a tokenization/fill step
# is missing; confirm against the intended conditioning scheme.

saved_dataset = TensorDataset(x_col, yseq_data)
| 190 | +#%% |
| 191 | +import matplotlib.pyplot as plt |
| 192 | + |
def save_sample_callback(score_model, epocs, ckpt_name):
    """Sample a 64-image grid from the model, save it as a PNG, and show it.

    Intended as the per-epoch `callback` of train_score_model; the file is
    written as samples_{ckpt_name}_{epocs}.png in the working directory.
    """
    n_samples = 64
    n_steps = 250
    # Condition on the first n_samples attribute sequences from the dataset.
    cond = yseq_data[:n_samples, :]
    samples = Euler_Maruyama_sampler(score_model,
                                     marginal_prob_std_fn,
                                     diffusion_coeff_fn,
                                     n_samples,
                                     x_shape=(3, 32, 32),
                                     num_steps=n_steps,
                                     device=device,
                                     y=cond, )
    # Invert the ImageNet-style normalization back to the [0, 1] pixel range.
    denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                            [1/0.229, 1/0.224, 1/0.225])
    imgs = denormalize(samples).clamp(0.0, 1.0)
    grid = make_grid(imgs, nrow=int(np.sqrt(n_samples)))

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.tight_layout()
    plt.savefig(f"samples_{ckpt_name}_{epocs}.png")
    plt.show()
| 217 | +#%% |
| 218 | +#@title Training model |
| 219 | + |
| 220 | +# continue_training = False #@param {type:"boolean"} |
| 221 | +# if not continue_training: |
| 222 | +# print("initilize new score model...") |
# Wrap in DataParallel so forward passes spread across all visible GPUs.
# NOTE(review): UNet_Tranformer_attrb is not defined in this file — it must
# be defined or imported elsewhere before this cell runs.
score_model = torch.nn.DataParallel(
    UNet_Tranformer_attrb(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 50
batch_size = 1024
lr = 10e-4  # == 1e-3; unusual spelling — presumably intentional, confirm
train_score_model(score_model, saved_dataset, lr, n_epochs, batch_size,
                  "Unet-tfmer_pad", device="cuda", callback=save_sample_callback)
| 232 | +#%% |
| 233 | + |
| 234 | +#%% |
# Second, lower-LR training run on the same model, saved under a new name.
n_epochs = 50
batch_size = 1048  # NOTE(review): 1048 looks like a typo for 1024 — confirm
lr = 2e-4
train_score_model(score_model, saved_dataset, lr, n_epochs, batch_size,
                  "Unet-tfmer", device="cuda", callback=save_sample_callback)
| 240 | +#%% |
| 241 | + |
sample_batch_size = 64 #@param {'type':'integer'}
num_steps = 250 #@param {'type':'integer'}
sampler = Euler_Maruyama_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}
# score_model.eval()
## Generate samples with the chosen sampler, conditioned on the first
## `sample_batch_size` attribute sequences from the dataset.
samples = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  x_shape=(3, 32, 32),
                  num_steps=num_steps,
                  device=device,
                  y=yseq_data[:sample_batch_size,:], )

## Sample visualization: invert the ImageNet-style normalization so pixel
## values are back in [0, 1] before clamping and plotting.
denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                        [1/0.229, 1/0.224, 1/0.225])

samples = denormalize(samples).clamp(0.0, 1.0)
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.tight_layout()
plt.show()
| 268 | + |
0 commit comments