Skip to content

Commit 09f4267

Browse files
committed
added support for 2d images in utils functions
1 parent 1bd502f commit 09f4267

1 file changed

Lines changed: 71 additions & 35 deletions

File tree

cellcyclenet/utils.py

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def _load_and_calc_params(args):
2828
# load image + mask #
2929
image_fn = args[0]
3030
mask_fn = args[1]
31+
is_3d = args[2]
3132
image = imread(image_fn)
3233
mask = imread(mask_fn)
3334

@@ -37,28 +38,34 @@ def _load_and_calc_params(args):
3738

3839
# calculate median nuclear dims #
3940
prop = regionprops(mask)
40-
dims = np.array([(prop[k].bbox[3]-prop[k].bbox[0], # Z
41-
prop[k].bbox[4]-prop[k].bbox[1], # Y
42-
prop[k].bbox[5]-prop[k].bbox[2]) # X
43-
for k in range(len(prop))])
41+
if is_3d:
42+
dims = np.array([(prop[k].bbox[3]-prop[k].bbox[0], # Z
43+
prop[k].bbox[4]-prop[k].bbox[1], # Y
44+
prop[k].bbox[5]-prop[k].bbox[2]) # X
45+
for k in range(len(prop))])
46+
else:
47+
dims = np.array([(prop[k].bbox[2]-prop[k].bbox[0], # Y
48+
prop[k].bbox[3]-prop[k].bbox[1]) # X
49+
for k in range(len(prop))])
4450
median_dims = np.median(dims, axis=0).astype(int)
4551

4652
return (median_pixel, median_dims)
4753

4854

49-
def _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores):
55+
def _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d):
5056
'''
5157
Uses DAPI images + corresponding masks to calculate median non-zero pixel value + median nuclear diameter across entire dataset.
5258
Args:
5359
- image_fns [list] : filenames for DAPI images (format as 'tile_n.tif')
5460
- mask_fns [list] : filenames for segmentation masks (format as 'mask_n.tif')
5561
- num_cores [int] : number of cores to use for parallel processing (default: None uses a single core)
62+
- is_3d [bool] : flag to determine if images are 3D or 2D
5663
Out:
5764
- norm_factor [float] : median non-zero pixel value across entire dataset
5865
- scale_factor [np.ndarray] : downsampling factors along each axis to apply to user's images such that they match dim. of pretraining images
5966
'''
6067
# package image and mask fns together for multiprocessing compatiability #
61-
image_and_mask_fns = [(i_fn, m_fn) for i_fn, m_fn in zip(image_fns, mask_fns)]
68+
image_and_mask_fns = [(i_fn, m_fn, is_3d) for i_fn, m_fn in zip(image_fns, mask_fns)]
6269

6370
# compute median non-zero pixel value for each tile #
6471
if num_cores == None:
@@ -75,7 +82,10 @@ def _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores):
7582
norm_factor = np.median(median_pixels)
7683

7784
# scale factor is median of nuclear dims #
78-
pretrained_dims = np.array([16, 37, 37])
85+
if is_3d:
86+
pretrained_dims = np.array([16, 37, 37])
87+
else:
88+
pretrained_dims = np.array([37, 37])
7989
user_dims = np.median(median_dims, axis=0)
8090
scale_factor = user_dims / pretrained_dims
8191

@@ -98,6 +108,7 @@ def _gen_SNI(args):
98108
output_dir = args[2]
99109
norm_factor = args[3]
100110
scale_factor = args[4]
111+
is_3d = args[5]
101112

102113
# for each object in the mask... #
103114
obj_nums = np.unique(mask)[1:]
@@ -108,10 +119,15 @@ def _gen_SNI(args):
108119
obj = np.where(obj_mask, image, 0)
109120

110121
# crop out empty rows/columns/planes #
111-
full_rows = np.any(obj, axis=(1,2))
112-
full_columns = np.any(obj, axis=(0,2))
113-
full_stacks = np.any(obj, axis=(0,1))
114-
obj_crop = obj[full_rows][:, full_columns][:, :, full_stacks]
122+
if is_3d:
123+
full_rows = np.any(obj, axis=(1,2))
124+
full_columns = np.any(obj, axis=(0,2))
125+
full_stacks = np.any(obj, axis=(0,1))
126+
obj_crop = obj[full_rows][:, full_columns][:, :, full_stacks]
127+
else:
128+
full_rows = np.any(obj, axis=1)
129+
full_columns = np.any(obj, axis=0)
130+
obj_crop = obj[full_rows][:, full_columns]
115131

116132
# normalize image based on norm factor #
117133
obj_norm = obj_crop / norm_factor
@@ -121,14 +137,21 @@ def _gen_SNI(args):
121137
obj_rescale = resize_local_mean(obj_norm, out_dims)
122138

123139
# SNIs will be padded to have 1/8 of their diameter on each side #
124-
obj_z, obj_y, obj_x = obj_rescale.shape
125-
z, y, x = [1.25*dim for dim in [obj_z, obj_y, obj_x]]
126-
127-
# pad image + save #
140+
if is_3d:
141+
obj_z, obj_y, obj_x = obj_rescale.shape
142+
z, y, x = [1.25*dim for dim in [obj_z, obj_y, obj_x]]
143+
else:
144+
obj_y, obj_x = obj_rescale.shape
145+
y, x = [1.25*dim for dim in [obj_y, obj_x]]
146+
147+
# calculate padding #
128148
x_add = (int(floor((x - obj_x) / 2)), int(ceil((x - obj_x) / 2)))
129149
y_add = (int(floor((y - obj_y) / 2)), int(ceil((y - obj_y) / 2)))
130-
z_add = (int(floor((z - obj_z) / 2)), int(ceil((z - obj_z) / 2)))
131-
obj_pad = np.pad(obj_rescale, [z_add, y_add, x_add])
150+
if is_3d: z_add = (int(floor((z - obj_z) / 2)), int(ceil((z - obj_z) / 2)))
151+
152+
# pad image + save #
153+
if is_3d: obj_pad = np.pad(obj_rescale, [z_add, y_add, x_add])
154+
else: obj_pad = np.pad(obj_rescale, [y_add, x_add])
132155
imwrite(f'{output_dir}/{os.path.splitext(os.path.basename(args[0]))[0]}_obj_{obj_num}.tif', obj_pad)
133156

134157

@@ -148,6 +171,7 @@ def _gen_SNI_label(args):
148171
output_dir = args[3]
149172
norm_factor = args[4]
150173
scale_factor = args[5]
174+
is_3d = args[6]
151175

152176
# for each object in the mask... #
153177
obj_nums = np.unique(mask)[1:]
@@ -162,10 +186,15 @@ def _gen_SNI_label(args):
162186
label_str = 'G1' if label == 1 else 'S-G2'
163187

164188
# crop out empty rows/columns/planes #
165-
full_rows = np.any(obj, axis=(1,2))
166-
full_columns = np.any(obj, axis=(0,2))
167-
full_stacks = np.any(obj, axis=(0,1))
168-
obj_crop = obj[full_rows][:, full_columns][:, :, full_stacks]
189+
if is_3d:
190+
full_rows = np.any(obj, axis=(1,2))
191+
full_columns = np.any(obj, axis=(0,2))
192+
full_stacks = np.any(obj, axis=(0,1))
193+
obj_crop = obj[full_rows][:, full_columns][:, :, full_stacks]
194+
else:
195+
full_rows = np.any(obj, axis=1)
196+
full_columns = np.any(obj, axis=0)
197+
obj_crop = obj[full_rows][:, full_columns]
169198

170199
# normalize image based on norm factor #
171200
obj_norm = obj_crop / norm_factor
@@ -175,14 +204,21 @@ def _gen_SNI_label(args):
175204
obj_rescale = resize_local_mean(obj_norm, out_dims)
176205

177206
# SNIs will be padded to have 1/8 of their diameter on each side #
178-
obj_z, obj_y, obj_x = obj_rescale.shape
179-
z, y, x = [1.25*dim for dim in [obj_z, obj_y, obj_x]]
180-
181-
# pad image + save #
207+
if is_3d:
208+
obj_z, obj_y, obj_x = obj_rescale.shape
209+
z, y, x = [1.25*dim for dim in [obj_z, obj_y, obj_x]]
210+
else:
211+
obj_y, obj_x = obj_rescale.shape
212+
y, x = [1.25*dim for dim in [obj_y, obj_x]]
213+
214+
# calculate padding #
182215
x_add = (int(floor((x - obj_x) / 2)), int(ceil((x - obj_x) / 2)))
183216
y_add = (int(floor((y - obj_y) / 2)), int(ceil((y - obj_y) / 2)))
184-
z_add = (int(floor((z - obj_z) / 2)), int(ceil((z - obj_z) / 2)))
185-
obj_pad = np.pad(obj_rescale, [z_add, y_add, x_add])
217+
if is_3d: z_add = (int(floor((z - obj_z) / 2)), int(ceil((z - obj_z) / 2)))
218+
219+
# pad image + save #
220+
if is_3d: obj_pad = np.pad(obj_rescale, [z_add, y_add, x_add])
221+
else: obj_pad = np.pad(obj_rescale, [y_add, x_add])
186222
imwrite(f'{output_dir}/{os.path.splitext(os.path.basename(args[0]))[0]}_obj_{obj_num}_class_{label_str}.tif', obj_pad)
187223

188224

@@ -241,7 +277,7 @@ def _validate_paths(dir):
241277

242278
####################################################################################################
243279

244-
def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_cores=None):
280+
def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_cores=None, is_3d=True):
245281
'''
246282
Generates unlabeled single-nucleus images from tiled data for input to model for prediction only.
247283
Args:
@@ -278,15 +314,15 @@ def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_c
278314
mask_fns = sorted([os.path.join(mask_dir, fn) for fn in os.listdir(mask_dir)])
279315

280316
# calculate normalization and scaling factors #
281-
norm_factor, scale_factor = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores)
317+
norm_factor, scale_factor = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d)
282318

283319
# generate SNIs #
284320
if num_cores == None:
285321
for image_fn, mask_fn in zip(image_fns, mask_fns):
286-
_gen_SNI([image_fn, mask_fn, output_dir, norm_factor, scale_factor])
322+
_gen_SNI([image_fn, mask_fn, output_dir, norm_factor, scale_factor, is_3d])
287323
else:
288324
with Pool(num_cores) as pool:
289-
pool.map(_gen_SNI, [(image_fn, mask_fn, output_dir, norm_factor, scale_factor) for image_fn, mask_fn in zip(image_fns, mask_fns)])
325+
pool.map(_gen_SNI, [(image_fn, mask_fn, output_dir, norm_factor, scale_factor, is_3d) for image_fn, mask_fn in zip(image_fns, mask_fns)])
290326

291327
# generate DF #
292328
SNI_fns = sorted([os.path.join(output_dir, fn) for fn in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, fn))])
@@ -295,7 +331,7 @@ def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_c
295331

296332
####################################################################################################
297333

298-
def generate_images_labeled(image_dir, mask_dir, label_dir, output_dir=None, return_df=False, num_cores=None):
334+
def generate_images_labeled(image_dir, mask_dir, label_dir, output_dir=None, return_df=False, num_cores=None, is_3d=True):
299335
'''
300336
Generates labeled single-nucleus images from tiled data for input to model for prediction or fine-tuning.
301337
Args:
@@ -335,15 +371,15 @@ def generate_images_labeled(image_dir, mask_dir, label_dir, output_dir=None, ret
335371
label_fns = sorted([os.path.join(label_dir, fn) for fn in os.listdir(label_dir)])
336372

337373
# calculate normalization and scaling factors #
338-
norm_factor, scale_factor = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores)
374+
norm_factor, scale_factor = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d)
339375

340376
# generate SNIs #
341377
if num_cores == None:
342378
for image_fn, mask_fn, label_fn in zip(image_fns, mask_fns, label_fns):
343-
_gen_SNI_label([image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor])
379+
_gen_SNI_label([image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor, is_3d])
344380
else:
345381
with Pool(num_cores) as pool:
346-
pool.map(_gen_SNI_label, [(image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor) for image_fn, mask_fn, label_fn in zip(image_fns, mask_fns, label_fns)])
382+
pool.map(_gen_SNI_label, [(image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor, is_3d) for image_fn, mask_fn, label_fn in zip(image_fns, mask_fns, label_fns)])
347383

348384
# generate DF #
349385
SNI_fns = sorted([os.path.join(output_dir, fn) for fn in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, fn))])

0 commit comments

Comments
 (0)