@@ -156,8 +156,11 @@ def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
156156#%%
157157
158158#%%
159+ from tqdm import tqdm
160+ from torch .utils .data import DataLoader , TensorDataset
159161from torchvision .datasets import CelebA
160162from torchvision .transforms import ToTensor , CenterCrop , Resize , Compose , Normalize
163+
161164tfm = Compose ([
162165 Resize (32 ),
163166 CenterCrop (32 ),
@@ -167,9 +170,7 @@ def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
167170dataset_rsz = CelebA ("/home/binxuwang/Datasets" , target_type = ["attr" ],
168171 transform = tfm , download = False ) # ,"identity"
169172#%%
170- from torch .utils .data import DataLoader , TensorDataset
171- from tqdm import tqdm
172-
173+ # def preprocess_dataset(dataset_rsz, ):
173174dataloader = DataLoader (dataset_rsz , batch_size = 64 , num_workers = 8 , shuffle = False )
174175x_col = []
175176y_col = []
@@ -181,12 +182,12 @@ def train_score_model(score_model, dataset, lr, n_epochs, batch_size, ckpt_name,
181182print (x_col .shape )
182183print (y_col .shape )
183184
184- maxlen = (y_col .sum (dim = 1 )).max ()
185185nantoken = 40
186- yseq_data = torch . ones (y_col .size ( 0 ), maxlen ,
187- dtype = int ).fill_ (nantoken )
186+ maxlen = (y_col .sum ( dim = 1 )). max ()
187+ yseq_data = torch . ones ( y_col . size ( 0 ), maxlen , dtype = int ).fill_ (nantoken )
188188
189189saved_dataset = TensorDataset (x_col , yseq_data )
190+ # return saved_dataset
190191#%%
191192import matplotlib .pyplot as plt
192193
0 commit comments