Skip to content

Commit ef94957

Browse files
committed
ADD
1 parent 714556d commit ef94957

2 files changed

Lines changed: 11 additions & 8 deletions

File tree

Diffusion_training_demo.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,11 @@ def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
156156
#%%
157157

158158
#%%
159+
from tqdm import tqdm
160+
from torch.utils.data import DataLoader, TensorDataset
159161
from torchvision.datasets import CelebA
160162
from torchvision.transforms import ToTensor, CenterCrop, Resize, Compose, Normalize
163+
161164
tfm = Compose([
162165
Resize(32),
163166
CenterCrop(32),
@@ -167,9 +170,7 @@ def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
167170
dataset_rsz = CelebA("/home/binxuwang/Datasets", target_type=["attr"],
168171
transform=tfm, download=False) # ,"identity"
169172
#%%
170-
from torch.utils.data import DataLoader, TensorDataset
171-
from tqdm import tqdm
172-
173+
# def preprocess_dataset(dataset_rsz, ):
173174
dataloader = DataLoader(dataset_rsz, batch_size=64, num_workers=8, shuffle=False)
174175
x_col = []
175176
y_col = []
@@ -181,12 +182,12 @@ def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
181182
print(x_col.shape)
182183
print(y_col.shape)
183184

184-
maxlen = (y_col.sum(dim=1)).max()
185185
nantoken = 40
186-
yseq_data = torch.ones(y_col.size(0), maxlen,
187-
dtype=int).fill_(nantoken)
186+
maxlen = (y_col.sum(dim=1)).max()
187+
yseq_data = torch.ones(y_col.size(0), maxlen, dtype=int).fill_(nantoken)
188188

189189
saved_dataset = TensorDataset(x_col, yseq_data)
190+
# return saved_dataset
190191
#%%
191192
import matplotlib.pyplot as plt
192193

StableDiffusion_exps.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def recursive_print(module, prefix="", depth=0, deepest=3):
3535
).to("cuda")
3636
def dummy_checker(images, **kwargs): return images, False
3737
pipe.safety_checker = dummy_checker
38+
#%%
39+
recursive_print(pipe.unet, deepest=2)
3840
#%% Text to
3941
# prompt = "a photo of an ballerina riding a horse on mars"
4042
prompt = "A ballerina riding a Harley Motorcycle, CG Art"
@@ -76,13 +78,13 @@ def save_latents(i, t, latents):
7678
prompt = "A ballerina chasing her cat running on the grass in the style of Monet"
7779
prompt = "A kitty cat dressed like Lincoln, old timey style"
7880
with autocast("cuda"):
79-
image = pipe(prompt, callback=None)["sample"][0] # plot_show_callback
81+
image = pipe(prompt, callback=plot_show_callback)["sample"][0] # plot_show_callback
8082

8183
image.save("cat_Lincoln.png")
8284
plt_show_image(image)
8385
#%%
8486
len(latents_reservoir)
85-
plt_show_image(latents_reservoir[-10][0, [0, 1, 2,], :].permute(1, 2, 0).cpu().numpy() / 1.6 + 0.4)
87+
plt_show_image(latents_reservoir[-1][0, [0, 1, 2,], :].permute(1, 2, 0).cpu().numpy() / 1.6 + 0.4)
8688
#%% Visualize architecture
8789

8890
#%% Full unets

0 commit comments

Comments
 (0)