@@ -28,6 +28,7 @@ def _load_and_calc_params(args):
2828 # load image + mask #
2929 image_fn = args [0 ]
3030 mask_fn = args [1 ]
31+ is_3d = args [2 ]
3132 image = imread (image_fn )
3233 mask = imread (mask_fn )
3334
@@ -37,28 +38,34 @@ def _load_and_calc_params(args):
3738
3839 # calculate median nuclear dims #
3940 prop = regionprops (mask )
40- dims = np .array ([(prop [k ].bbox [3 ]- prop [k ].bbox [0 ], # Z
41- prop [k ].bbox [4 ]- prop [k ].bbox [1 ], # Y
42- prop [k ].bbox [5 ]- prop [k ].bbox [2 ]) # X
43- for k in range (len (prop ))])
41+ if is_3d :
42+ dims = np .array ([(prop [k ].bbox [3 ]- prop [k ].bbox [0 ], # Z
43+ prop [k ].bbox [4 ]- prop [k ].bbox [1 ], # Y
44+ prop [k ].bbox [5 ]- prop [k ].bbox [2 ]) # X
45+ for k in range (len (prop ))])
46+ else :
47+ dims = np .array ([(prop [k ].bbox [2 ]- prop [k ].bbox [0 ], # Y
48+ prop [k ].bbox [3 ]- prop [k ].bbox [1 ]) # X
49+ for k in range (len (prop ))])
4450 median_dims = np .median (dims , axis = 0 ).astype (int )
4551
4652 return (median_pixel , median_dims )
4753
4854
49- def _calc_norm_and_scale_factor (image_fns , mask_fns , num_cores ):
55+ def _calc_norm_and_scale_factor (image_fns , mask_fns , num_cores , is_3d ):
5056 '''
5157 Uses DAPI images + corresponding masks to calculate median non-zero pixel value + median nuclear diameter across entire dataset.
5258 Args:
5359 - image_fns [list] : filenames for DAPI images (format as 'tile_n.tif')
5460 - mask_fns [list] : filenames for segmentation masks (format as 'mask_n.tif')
5561 - num_cores [int] : number of cores to use for parallel processing (default: None uses a single core)
62+ - is_3d [bool] : flag to determine if images are 3D or 2D
5663 Out:
5764 - norm_factor [float] : median non-zero pixel value across entire dataset
5865 - scale_factor [np.ndarray] : downsampling factors along each axis to apply to user's images such that they match dim. of pretraining images
5966 '''
6067 # package image and mask fns together for multiprocessing compatiability #
61- image_and_mask_fns = [(i_fn , m_fn ) for i_fn , m_fn in zip (image_fns , mask_fns )]
68+ image_and_mask_fns = [(i_fn , m_fn , is_3d ) for i_fn , m_fn in zip (image_fns , mask_fns )]
6269
6370 # compute median non-zero pixel value for each tile #
6471 if num_cores == None :
@@ -75,7 +82,10 @@ def _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores):
7582 norm_factor = np .median (median_pixels )
7683
7784 # scale factor is median of nuclear dims #
78- pretrained_dims = np .array ([16 , 37 , 37 ])
85+ if is_3d :
86+ pretrained_dims = np .array ([16 , 37 , 37 ])
87+ else :
88+ pretrained_dims = np .array ([37 , 37 ])
7989 user_dims = np .median (median_dims , axis = 0 )
8090 scale_factor = user_dims / pretrained_dims
8191
@@ -98,6 +108,7 @@ def _gen_SNI(args):
98108 output_dir = args [2 ]
99109 norm_factor = args [3 ]
100110 scale_factor = args [4 ]
111+ is_3d = args [5 ]
101112
102113 # for each object in the mask... #
103114 obj_nums = np .unique (mask )[1 :]
@@ -108,10 +119,15 @@ def _gen_SNI(args):
108119 obj = np .where (obj_mask , image , 0 )
109120
110121 # crop out empty rows/columns/planes #
111- full_rows = np .any (obj , axis = (1 ,2 ))
112- full_columns = np .any (obj , axis = (0 ,2 ))
113- full_stacks = np .any (obj , axis = (0 ,1 ))
114- obj_crop = obj [full_rows ][:, full_columns ][:, :, full_stacks ]
122+ if is_3d :
123+ full_rows = np .any (obj , axis = (1 ,2 ))
124+ full_columns = np .any (obj , axis = (0 ,2 ))
125+ full_stacks = np .any (obj , axis = (0 ,1 ))
126+ obj_crop = obj [full_rows ][:, full_columns ][:, :, full_stacks ]
127+ else :
128+ full_rows = np .any (obj , axis = 1 )
129+ full_columns = np .any (obj , axis = 0 )
130+ obj_crop = obj [full_rows ][:, full_columns ]
115131
116132 # normalize image based on norm factor #
117133 obj_norm = obj_crop / norm_factor
@@ -121,14 +137,21 @@ def _gen_SNI(args):
121137 obj_rescale = resize_local_mean (obj_norm , out_dims )
122138
123139 # SNIs will be padded to have 1/8 of their diameter on each side #
124- obj_z , obj_y , obj_x = obj_rescale .shape
125- z , y , x = [1.25 * dim for dim in [obj_z , obj_y , obj_x ]]
126-
127- # pad image + save #
140+ if is_3d :
141+ obj_z , obj_y , obj_x = obj_rescale .shape
142+ z , y , x = [1.25 * dim for dim in [obj_z , obj_y , obj_x ]]
143+ else :
144+ obj_y , obj_x = obj_rescale .shape
145+ y , x = [1.25 * dim for dim in [obj_y , obj_x ]]
146+
147+ # calculate padding #
128148 x_add = (int (floor ((x - obj_x ) / 2 )), int (ceil ((x - obj_x ) / 2 )))
129149 y_add = (int (floor ((y - obj_y ) / 2 )), int (ceil ((y - obj_y ) / 2 )))
130- z_add = (int (floor ((z - obj_z ) / 2 )), int (ceil ((z - obj_z ) / 2 )))
131- obj_pad = np .pad (obj_rescale , [z_add , y_add , x_add ])
150+ if is_3d : z_add = (int (floor ((z - obj_z ) / 2 )), int (ceil ((z - obj_z ) / 2 )))
151+
152+ # pad image + save #
153+ if is_3d : obj_pad = np .pad (obj_rescale , [z_add , y_add , x_add ])
154+ else : obj_pad = np .pad (obj_rescale , [y_add , x_add ])
132155 imwrite (f'{ output_dir } /{ os .path .splitext (os .path .basename (args [0 ]))[0 ]} _obj_{ obj_num } .tif' , obj_pad )
133156
134157
@@ -148,6 +171,7 @@ def _gen_SNI_label(args):
148171 output_dir = args [3 ]
149172 norm_factor = args [4 ]
150173 scale_factor = args [5 ]
174+ is_3d = args [6 ]
151175
152176 # for each object in the mask... #
153177 obj_nums = np .unique (mask )[1 :]
@@ -162,10 +186,15 @@ def _gen_SNI_label(args):
162186 label_str = 'G1' if label == 1 else 'S-G2'
163187
164188 # crop out empty rows/columns/planes #
165- full_rows = np .any (obj , axis = (1 ,2 ))
166- full_columns = np .any (obj , axis = (0 ,2 ))
167- full_stacks = np .any (obj , axis = (0 ,1 ))
168- obj_crop = obj [full_rows ][:, full_columns ][:, :, full_stacks ]
189+ if is_3d :
190+ full_rows = np .any (obj , axis = (1 ,2 ))
191+ full_columns = np .any (obj , axis = (0 ,2 ))
192+ full_stacks = np .any (obj , axis = (0 ,1 ))
193+ obj_crop = obj [full_rows ][:, full_columns ][:, :, full_stacks ]
194+ else :
195+ full_rows = np .any (obj , axis = 1 )
196+ full_columns = np .any (obj , axis = 0 )
197+ obj_crop = obj [full_rows ][:, full_columns ]
169198
170199 # normalize image based on norm factor #
171200 obj_norm = obj_crop / norm_factor
@@ -175,14 +204,21 @@ def _gen_SNI_label(args):
175204 obj_rescale = resize_local_mean (obj_norm , out_dims )
176205
177206 # SNIs will be padded to have 1/8 of their diameter on each side #
178- obj_z , obj_y , obj_x = obj_rescale .shape
179- z , y , x = [1.25 * dim for dim in [obj_z , obj_y , obj_x ]]
180-
181- # pad image + save #
207+ if is_3d :
208+ obj_z , obj_y , obj_x = obj_rescale .shape
209+ z , y , x = [1.25 * dim for dim in [obj_z , obj_y , obj_x ]]
210+ else :
211+ obj_y , obj_x = obj_rescale .shape
212+ y , x = [1.25 * dim for dim in [obj_y , obj_x ]]
213+
214+ # calculate padding #
182215 x_add = (int (floor ((x - obj_x ) / 2 )), int (ceil ((x - obj_x ) / 2 )))
183216 y_add = (int (floor ((y - obj_y ) / 2 )), int (ceil ((y - obj_y ) / 2 )))
184- z_add = (int (floor ((z - obj_z ) / 2 )), int (ceil ((z - obj_z ) / 2 )))
185- obj_pad = np .pad (obj_rescale , [z_add , y_add , x_add ])
217+ if is_3d : z_add = (int (floor ((z - obj_z ) / 2 )), int (ceil ((z - obj_z ) / 2 )))
218+
219+ # pad image + save #
220+ if is_3d : obj_pad = np .pad (obj_rescale , [z_add , y_add , x_add ])
221+ else : obj_pad = np .pad (obj_rescale , [y_add , x_add ])
186222 imwrite (f'{ output_dir } /{ os .path .splitext (os .path .basename (args [0 ]))[0 ]} _obj_{ obj_num } _class_{ label_str } .tif' , obj_pad )
187223
188224
@@ -241,7 +277,7 @@ def _validate_paths(dir):
241277
242278####################################################################################################
243279
244- def generate_images (image_dir , mask_dir , output_dir = None , return_df = False , num_cores = None ):
280+ def generate_images (image_dir , mask_dir , output_dir = None , return_df = False , num_cores = None , is_3d = True ):
245281 '''
246282 Generates unlabeled single-nucleus images from tiled data for input to model for prediction only.
247283 Args:
@@ -278,15 +314,15 @@ def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_c
278314 mask_fns = sorted ([os .path .join (mask_dir , fn ) for fn in os .listdir (mask_dir )])
279315
280316 # calculate normalization and scaling factors #
281- norm_factor , scale_factor = _calc_norm_and_scale_factor (image_fns , mask_fns , num_cores )
317+ norm_factor , scale_factor = _calc_norm_and_scale_factor (image_fns , mask_fns , num_cores , is_3d )
282318
283319 # generate SNIs #
284320 if num_cores == None :
285321 for image_fn , mask_fn in zip (image_fns , mask_fns ):
286- _gen_SNI ([image_fn , mask_fn , output_dir , norm_factor , scale_factor ])
322+ _gen_SNI ([image_fn , mask_fn , output_dir , norm_factor , scale_factor , is_3d ])
287323 else :
288324 with Pool (num_cores ) as pool :
289- pool .map (_gen_SNI , [(image_fn , mask_fn , output_dir , norm_factor , scale_factor ) for image_fn , mask_fn in zip (image_fns , mask_fns )])
325+ pool .map (_gen_SNI , [(image_fn , mask_fn , output_dir , norm_factor , scale_factor , is_3d ) for image_fn , mask_fn in zip (image_fns , mask_fns )])
290326
291327 # generate DF #
292328 SNI_fns = sorted ([os .path .join (output_dir , fn ) for fn in os .listdir (output_dir ) if os .path .isfile (os .path .join (output_dir , fn ))])
@@ -295,7 +331,7 @@ def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_c
295331
296332####################################################################################################
297333
298- def generate_images_labeled (image_dir , mask_dir , label_dir , output_dir = None , return_df = False , num_cores = None ):
334+ def generate_images_labeled (image_dir , mask_dir , label_dir , output_dir = None , return_df = False , num_cores = None , is_3d = True ):
299335 '''
300336 Generates labeled single-nucleus images from tiled data for input to model for prediction or fine-tuning.
301337 Args:
@@ -335,15 +371,15 @@ def generate_images_labeled(image_dir, mask_dir, label_dir, output_dir=None, ret
335371 label_fns = sorted ([os .path .join (label_dir , fn ) for fn in os .listdir (label_dir )])
336372
337373 # calculate normalization and scaling factors #
338- norm_factor , scale_factor = _calc_norm_and_scale_factor (image_fns , mask_fns , num_cores )
374+ norm_factor , scale_factor = _calc_norm_and_scale_factor (image_fns , mask_fns , num_cores , is_3d )
339375
340376 # generate SNIs #
341377 if num_cores == None :
342378 for image_fn , mask_fn , label_fn in zip (image_fns , mask_fns , label_fns ):
343- _gen_SNI_label ([image_fn , mask_fn , label_fn , output_dir , norm_factor , scale_factor ])
379+ _gen_SNI_label ([image_fn , mask_fn , label_fn , output_dir , norm_factor , scale_factor , is_3d ])
344380 else :
345381 with Pool (num_cores ) as pool :
346- pool .map (_gen_SNI_label , [(image_fn , mask_fn , label_fn , output_dir , norm_factor , scale_factor ) for image_fn , mask_fn , label_fn in zip (image_fns , mask_fns , label_fns )])
382+ pool .map (_gen_SNI_label , [(image_fn , mask_fn , label_fn , output_dir , norm_factor , scale_factor , is_3d ) for image_fn , mask_fn , label_fn in zip (image_fns , mask_fns , label_fns )])
347383
348384 # generate DF #
349385 SNI_fns = sorted ([os .path .join (output_dir , fn ) for fn in os .listdir (output_dir ) if os .path .isfile (os .path .join (output_dir , fn ))])
0 commit comments