2929from glob import glob
3030from cellcyclenet .unet3d .model import UNet3D
3131from cellcyclenet import models
32+ from cellcyclenet .unet2d import UNet2D
3233from torch .utils .data import Dataset , DataLoader
3334import torchvision .transforms .v2 as transforms
3435from skimage .transform import downscale_local_mean
4243
4344class CCN_Dataset (Dataset ):
4445 '''Create a class to hold the PyTorch Dataset, input to PyTorch Dataloader.'''
45- def __init__ (self , X , y , norm_factor , scale_factors , transform , lazy_load ):
46+ def __init__ (self , X , y , transform , lazy_load ):
4647 self .X = X
4748 self .y = y
4849 self .transform = transform
49- self .norm_factor = norm_factor
50- self .scale_factors = scale_factors
5150 self .lazy_load = lazy_load
5251
5352 def __len__ (self ):
@@ -66,8 +65,6 @@ def __getitem__(self, index):
6766 # if initialized with image fns (lazy loading), load image from disk #
6766 else :
6867 image = imread (self .X [index ])
69- image = downscale_local_mean (image , self .scale_factors )
70- image = image / self .norm_factor
7168
7269 # convert image to float 32 #
7370 X = np .asarray (image , dtype = np .float32 )
@@ -92,40 +89,47 @@ def __getitem__(self, index):
9289
9390class CellCycleNet :
9491
def __init__(self, state_dict_path=None, is_3d=True):
    '''
    Build the network, load its weights, and move it to the compute device.
    Arguments:
        - state_dict_path [str] : path to a saved state dict; if None, the pretrained weights shipped in cellcyclenet.models are used
        - is_3d [bool] : if True the UNet3D architecture is used, otherwise UNet2D
    '''
    # initialize device as GPU if available, otherwise CPU #
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.is_3d = is_3d

    # initialize model architecture #
    if self.is_3d:
        self.model = UNet3D(in_channels=1, out_channels=1, is_segmentation=False, f_maps=32)
    else:
        self.model = UNet2D(in_channels=1, init_features=32)

    # wrap before loading: the state dict below is loaded into the DataParallel wrapper, not the bare network #
    self.model = torch.nn.DataParallel(self.model)

    # if user does not provide a path to weights, load pretrained weights matching the architecture #
    if state_dict_path is None:
        weights_fn = 'pretrained-model_3D.pt' if self.is_3d else 'pretrained-model_2D.pt'
        state_dict_path = pkg_resources.files(models).joinpath(weights_fn)

    # load model weights directly onto the selected device #
    # (map_location remaps tensors saved on any GPU index / CPU, so a checkpoint written on a
    #  device that does not exist on this machine still loads)
    state_dict = torch.load(state_dict_path, map_location=self.device)
    self.model.load_state_dict(state_dict)
    self.model.to(self.device)
108118
109119 ################################################################################################
110120
111- def create_dataset (self , dataframe , norm_factor , scale_factors , split_data = False , seed = 15 ):
121+ def create_dataset (self , dataframe , split_data = False , seed = 15 ):
112122 '''
113123 Creates training, validation, and testing dataframes containing GT labels and image filenames for each nucleus.
114124 (NOTE: it is assumed that the order of labels in the dataframe is the same as order of images in image_dir.)
115125 Arguments:
116126 - image_dir [str] : path to directory of single nucleus images
117127 - dataframe [pd.DataFrame] : dataframe containing GT labels; if None, it is assumed that user wants to proceed without using labels (use case 1)
118- - norm_factor [int] : images are divided pixelwise by this value during loading (calculated as median value of the medians of each input image)
119- - scale_factor [tuple] : Z,Y,X scale factors
120128 - split_data [bool] : flag to determine if input data is split into train/val/test sets (use case 2) or just a single dataset (use case 1)
121129 - seed [int] : random seed to use for train/val/test split
122130 Outputs:
123131 - train, val, test [pd.DataFrame] : dataframe containing GT label (if inputted) + image filename for each nucleus
124132 '''
125- # set norm / scale factors as attributes (to be called by .train() and .predict() when creating dataloaders)
126- self .norm_factor = norm_factor
127- self .scale_factors = scale_factors
128-
129133 ### FIXME small DF for debugging ###
130134 # dataframe = dataframe.iloc[::15]
131135
@@ -166,9 +170,10 @@ def run_epoch(self, dataloader, is_train):
166170 images .to (self .device )
167171 labels .to (self .device )
168172
169- # reshape images to [batch, channel, Z, Y, X] #
170- images = torch .swapaxes (images , 1 , 2 )
171- images = torch .unsqueeze (images , 1 )
173+ # for 3D images, reshape images to [batch, channel, Z, Y, X] #
174+ if self .is_3d :
175+ images = torch .swapaxes (images , 1 , 2 )
176+ images = torch .unsqueeze (images , 1 )
172177
173178 ### FORWARD PASS ###
174179 outputs = self .model (images )
@@ -234,13 +239,11 @@ def train(self, train_df, val_df, n_epochs, batch_size=4, initial_LR=1e-5, trans
234239 train_X_fn = train_df ['filename' ].values
235240 train_y = np .where (train_df ['label' ].values == 'G1' , 0 , 1 )
236241 if lazy_load :
237- train_dataloader = DataLoader (CCN_Dataset (train_X_fn , train_y , self . norm_factor , self . scale_factors , transform = transform , lazy_load = lazy_load ),
242+ train_dataloader = DataLoader (CCN_Dataset (train_X_fn , train_y , transform = transform , lazy_load = lazy_load ),
238243 batch_size = batch_size , shuffle = False )
239244 else :
240245 train_X = np .asarray ([imread (fn ) for fn in train_X_fn ])
241- train_X_ds = np .asarray ([downscale_local_mean (image , self .scale_factors ) for image in train_X ])
242- train_X_norm = np .asarray ([image / self .norm_factor for image in train_X_ds ])
243- train_dataloader = DataLoader (CCN_Dataset (train_X_norm , train_y , self .norm_factor , self .scale_factors , transform = transform , lazy_load = lazy_load ),
246+ train_dataloader = DataLoader (CCN_Dataset (train_X , train_y , transform = transform , lazy_load = lazy_load ),
244247 batch_size = batch_size , shuffle = False )
245248
246249 # create dataloader for validation data #
@@ -249,13 +252,11 @@ def train(self, train_df, val_df, n_epochs, batch_size=4, initial_LR=1e-5, trans
249252 val_y = np .where (val_df ['label' ].values == 'G1' , 0 , 1 )
250253
251254 if lazy_load :
252- val_dataloader = DataLoader (CCN_Dataset (val_X_fn , val_y , self . norm_factor , self . scale_factors , transform = None , lazy_load = lazy_load ),
255+ val_dataloader = DataLoader (CCN_Dataset (val_X_fn , val_y , transform = None , lazy_load = lazy_load ),
253256 batch_size = batch_size , shuffle = False )
254257 else :
255258 val_X = np .asarray ([imread (fn ) for fn in val_X_fn ])
256- val_X_ds = np .asarray ([downscale_local_mean (image , self .scale_factors ) for image in val_X ])
257- val_X_norm = np .asarray ([image / self .norm_factor for image in val_X_ds ])
258- val_dataloader = DataLoader (CCN_Dataset (val_X_norm , val_y , self .norm_factor , self .scale_factors , transform = None , lazy_load = lazy_load ),
259+ val_dataloader = DataLoader (CCN_Dataset (val_X , val_y , transform = None , lazy_load = lazy_load ),
259260 batch_size = batch_size , shuffle = False )
260261
261262 # track val acc for each epoch to check for improvement #
@@ -313,7 +314,7 @@ def predict(self, dataframe, with_labels, decision_threshold=0.5):
313314
314315 X_fn = dataframe ['filename' ].values
315316 y = np .where (dataframe ['label' ].values == 'G1' , 0 , 1 ) if with_labels else np .zeros (len (X_fn ))
316- dataloader = DataLoader (CCN_Dataset (X_fn , y , self . norm_factor , self . scale_factors , transform = None , lazy_load = True ), batch_size = 4 , shuffle = False )
317+ dataloader = DataLoader (CCN_Dataset (X_fn , y , transform = None , lazy_load = True ), batch_size = 4 , shuffle = False )
317318
318319 # Run through network #
319320 with torch .no_grad ():
@@ -322,9 +323,10 @@ def predict(self, dataframe, with_labels, decision_threshold=0.5):
322323 images = images .to (self .device )
323324 labels = labels .to (self .device )
324325
325- # Reshape to [batch, channels, Z, Y, X] #
326- images = torch .swapaxes (images , 1 , 2 )
327- images = torch .unsqueeze (images , 1 )
326+ # for 3D images, reshape to [batch, channels, Z, Y, X] #
327+ if self .is_3d :
328+ images = torch .swapaxes (images , 1 , 2 )
329+ images = torch .unsqueeze (images , 1 )
328330
329331 # Run inference #
330332 outputs = self .model (images )
@@ -422,7 +424,10 @@ def show_image(self, dataframe, index=None, with_preds=True, hide_plot=False, fi
422424 if not hide_plot :
423425 plt .figure (figsize = figsize )
424426 plt .axis ('off' )
425- plt .imshow (np .max (image , axis = 0 ))
427+ if self .is_3d :
428+ plt .imshow (np .max (image , axis = 0 ))
429+ else :
430+ plt .imshow (image )
426431 plt .title (f'Index: { index } / Label: { label } / Pred: { pred } / Prob: { prob :.3f} ' , fontsize = 10 )
427432 plt .show ()
428433
0 commit comments