#%%
import torch
import functools
from tqdm import tqdm, trange
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
# Use the file_system sharing strategy to avoid "too many open files" errors
# when loading data with many worker processes.
torch.multiprocessing.set_sharing_strategy('file_system')
| 10 | +#%% |
| 11 | +from torch.utils.data import DataLoader, TensorDataset |
| 12 | +from torchvision.datasets import CelebA |
| 13 | +from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize |
| 14 | + |
| 15 | + |
| 16 | +tfm = Compose([ |
| 17 | + Resize(32), |
| 18 | + CenterCrop(32), |
| 19 | + ToTensor(), |
| 20 | + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) |
| 21 | +]) |
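# Note: these Normalize statistics are the standard ImageNet channel mean/std;
# the sampling callback below applies the inverse transform before plotting.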
dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"],
                     transform=tfm, download=False)  # ,"identity"
#%%
# Iterate once over the dataset to cache all resized images and attribute
# vectors as in-memory tensors, so training epochs avoid repeated JPEG decoding.
dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False)
x_col = []
y_col = []
for xs, ys in tqdm(dataloader):
    x_col.append(xs)
    y_col.append(ys)
x_col = torch.cat(x_col, dim=0)
y_col = torch.cat(y_col, dim=0)
print(x_col.shape)
print(y_col.shape)
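# For the default train split of CelebA, expect x_col of shape
# (162770, 3, 32, 32) and y_col of shape (162770, 40).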

# Encode each 40-dim binary attribute vector as a sequence of attribute indices
# (0-39), right-padded with a dedicated "nan" token at index 40. The fill loop
# below is an assumed completion: the original left yseq_data all-padding, but
# the padding_idx=40 embedding further down implies real indices are intended.
nantoken = 40
maxlen = int(y_col.sum(dim=1).max())
yseq_data = torch.full((y_col.size(0), maxlen), nantoken, dtype=torch.long)
for i, attr in enumerate(y_col):
    idx = attr.nonzero(as_tuple=True)[0]
    yseq_data[i, :len(idx)] = idx

saved_dataset = TensorDataset(x_col, yseq_data)
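# Quick sanity check (a sketch): the first row should list sample 0's active
# attribute indices, padded with nantoken; attr_names is looked up defensively.
print(yseq_data[0])
if hasattr(dataset_rsz, "attr_names"):
    print([dataset_rsz.attr_names[i] for i in yseq_data[0].tolist() if i < nantoken])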
#%%
import math
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
device = 'cuda'

def marginal_prob_std(t, sigma):
    """Standard deviation of the perturbation kernel p_{0t}(x(t) | x(0))."""
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / math.log(sigma))


def diffusion_coeff(t, sigma):
    """Diffusion coefficient g(t) = sigma^t of the forward SDE."""
    return torch.as_tensor(sigma ** t, device=device)


sigma = 25.0  # @param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
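# These two functions define the variance-exploding SDE  dx = sigma^t dw:
# integrating g(s)^2 = sigma^(2s) from 0 to t gives the perturbation-kernel
# variance (sigma^(2t) - 1) / (2 log sigma), the expression under the sqrt
# in marginal_prob_std.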
#%%
#@title Training Loss function
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
        model: A PyTorch model instance that represents a
            time-dependent score-based model.
        x: A mini-batch of training data.
        y: A mini-batch of conditioning embeddings for x.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        eps: A tolerance value for numerical stability.
    """
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t, cond=y, output_dict=False)
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
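# Why this weighting works: x(t) = x(0) + std(t) * z, so the true score of the
# perturbation kernel is -z / std(t). Multiplying the predicted score by std(t)
# and adding z gives a residual that vanishes exactly when the model matches
# that score, i.e. denoising score matching with weight lambda(t) = std(t)^2.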

#%%
def train_score_model(score_model, cond_embed, dataset, lr, n_epochs, batch_size, ckpt_name,
                      marginal_prob_std_fn=marginal_prob_std_fn,
                      lr_scheduler_fn=lambda epoch: max(0.2, 0.98 ** epoch),
                      device="cuda",
                      callback=None):  # resume=False,
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    # Jointly optimize the score model and the conditioning embedding.
    optimizer = Adam([*score_model.parameters(), *cond_embed.parameters()], lr=lr)
    # LambdaLR multiplies the base lr by max(0.2, 0.98^epoch): about 2% decay
    # per epoch, floored at 20% of the initial rate.
    scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn)
    tqdm_epoch = trange(n_epochs)
    for epoch in tqdm_epoch:
        score_model.train()
        avg_loss = 0.
        num_items = 0
        batch_tqdm = tqdm(data_loader)
        for x, y in batch_tqdm:
            x = x.to(device)
            y_emb = cond_embed(y.to(device))
            loss = loss_fn_cond(score_model, x, y_emb, marginal_prob_std_fn)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
            batch_tqdm.set_description("Epoch %d, loss %.4f" % (epoch, avg_loss / num_items))
        scheduler.step()
        lr_current = scheduler.get_last_lr()[0]
        print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
        # Print the averaged training loss so far.
        tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{ckpt_name}.pth')
        torch.save(cond_embed.state_dict(),
                   f'/home/binxuwang/DL_Projects/SDfromScratch/ckpt_{ckpt_name}_cond_embed.pth')
        if callback is not None:
            score_model.eval()
            callback(score_model, epoch, ckpt_name)
#%%
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=500,
                           device='cuda',
                           eps=1e-3,
                           y=None):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation of
            the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
        batch_size: The number of samples to generate by calling this function once.
        num_steps: The number of sampling steps.
            Equivalent to the number of discretized time steps.
        device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.
        y: Optional conditioning embedding passed to the score model.

    Returns:
        Samples.
    """
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) \
        * marginal_prob_std(t)[:, None, None, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g ** 2)[:, None, None, None] * score_model(x, batch_time_step, cond=y, output_dict=False) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x
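# Each loop iteration is one Euler-Maruyama step of the reverse-time SDE
#   dx = -g(t)^2 * score(x, t) dt + g(t) dw_bar,
# integrated from t = 1 down to t = eps. Returning mean_x skips the final
# noise injection, the usual trick for slightly cleaner samples.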
#%%
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def save_sample_callback(score_model, epoch, ckpt_name):
    sample_batch_size = 64
    num_steps = 250
    y_samp = yseq_data[:sample_batch_size, :]
    y_emb = cond_embed(y_samp.cuda())
    sampler = Euler_Maruyama_sampler
    samples = sampler(score_model,
                      marginal_prob_std_fn,
                      diffusion_coeff_fn,
                      sample_batch_size,
                      x_shape=(3, 32, 32),
                      num_steps=num_steps,
                      device=device,
                      y=y_emb, )
    # Invert the ImageNet normalization applied in the data transform.
    denormalize = Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225],
                            [1/0.229, 1/0.224, 1/0.225])
    samples = denormalize(samples).clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(math.sqrt(sample_batch_size)))

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.tight_layout()
    plt.savefig(f"/home/binxuwang/DL_Projects/SDfromScratch/samples_{ckpt_name}_{epoch}.png")
    plt.show()
#%%
from StableDiff_UNet_model import UNet_SD, load_pipe_into_our_UNet
#%% Pixel-space UNet: no VAE, no latent space
unet_face = UNet_SD(in_channels=3,
                    base_channels=128,
                    time_emb_dim=256,
                    context_dim=256,
                    multipliers=(1, 1, 2),
                    attn_levels=(1, 2, ),
                    nResAttn_block=1,
                    ).cuda()  # assumed fix: move to GPU, since the training loop and the forward test below feed CUDA tensors
cond_embed = nn.Embedding(40 + 1, 256, padding_idx=40).cuda()
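# The embedding table has 41 rows: one per CelebA attribute (indices 0-39) plus
# the padding token at index 40 (= nantoken). padding_idx keeps that row fixed
# at zero, so padded positions contribute a null context vector.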
#%%
torch.save(unet_face.state_dict(), "/home/binxuwang/DL_Projects/SDfromScratch/SD_unet_face.pt",)
#%%
# Smoke test: one forward pass with random image, time, and context tensors;
# a (1, 3, 64, 64) input should come back with the same shape.
unet_face(torch.randn(1, 3, 64, 64).cuda(), time_steps=torch.rand(1).cuda(),
          cond=torch.randn(1, 20, 256).cuda(),
          output_dict=False)
#%%
train_score_model(unet_face, cond_embed, saved_dataset,
                  lr=1.5e-4, n_epochs=100, batch_size=256,
                  ckpt_name="unet_SD_face", device=device,
                  callback=save_sample_callback)

#%%
save_sample_callback(unet_face, 0, "unet_SD_face")
#%%
torch.save(cond_embed.state_dict(), '/home/binxuwang/DL_Projects/SDfromScratch/ckpt_unet_SD_face_cond_embed.pth')