Skip to content

Commit 8c2b473

Browse files
committed
SNIs now generated with same dimensions
1 parent ed00093 commit 8c2b473

1 file changed

Lines changed: 38 additions & 31 deletions

File tree

cellcyclenet/utils.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,10 @@ def _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d):
8787
else:
8888
pretrained_dims = np.array([37, 37])
8989
user_dims = np.median(median_dims, axis=0)
90+
max_dims = np.max(median_dims, axis=0)
9091
scale_factor = user_dims / pretrained_dims
9192

92-
return norm_factor, scale_factor
93+
return norm_factor, scale_factor, max_dims
9394

9495
####################################################################################################
9596

@@ -108,7 +109,12 @@ def _gen_SNI(args):
108109
output_dir = args[2]
109110
norm_factor = args[3]
110111
scale_factor = args[4]
111-
is_3d = args[5]
112+
max_dims = args[5]
113+
is_3d = args[6]
114+
115+
# set padding to 1/8 of max dims #
116+
if is_3d: z, y, x = [1.05*dim for dim in max_dims]
117+
else: y, x = [1.05*dim for dim in max_dims]
112118

113119
# for each object in the mask... #
114120
obj_nums = np.unique(mask)[1:]
@@ -136,22 +142,20 @@ def _gen_SNI(args):
136142
out_dims = np.array([int(in_dim / scale) for in_dim, scale in zip(obj_crop.shape, scale_factor)])
137143
obj_rescale = resize_local_mean(obj_norm, out_dims)
138144

139-
# SNIs will be padded to have 1/8 of their diameter on each side #
145+
# pad image to have 1/8 of its diameter on each side #
140146
if is_3d:
141147
obj_z, obj_y, obj_x = obj_rescale.shape
142-
z, y, x = [1.25*dim for dim in [obj_z, obj_y, obj_x]]
148+
z_pad = (int((z - obj_z) / 2), int((z - obj_z + 1) / 2))
149+
y_pad = (int((y - obj_y) / 2), int((y - obj_y + 1) / 2))
150+
x_pad = (int((x - obj_x) / 2), int((x - obj_x + 1) / 2))
151+
obj_pad = np.pad(obj_rescale, [z_pad, y_pad, x_pad])
143152
else:
144153
obj_y, obj_x = obj_rescale.shape
145-
y, x = [1.25*dim for dim in [obj_y, obj_x]]
146-
147-
# calculate padding #
148-
x_add = (int(floor((x - obj_x) / 2)), int(ceil((x - obj_x) / 2)))
149-
y_add = (int(floor((y - obj_y) / 2)), int(ceil((y - obj_y) / 2)))
150-
if is_3d: z_add = (int(floor((z - obj_z) / 2)), int(ceil((z - obj_z) / 2)))
154+
y_pad = (int((y - obj_y) / 2), int((y - obj_y + 1) / 2))
155+
x_pad = (int((x - obj_x) / 2), int((x - obj_x + 1) / 2))
156+
obj_pad = np.pad(obj_rescale, [y_pad, x_pad])
151157

152-
# pad image + save #
153-
if is_3d: obj_pad = np.pad(obj_rescale, [z_add, y_add, x_add])
154-
else: obj_pad = np.pad(obj_rescale, [y_add, x_add])
158+
# save #
155159
imwrite(f'{output_dir}/{os.path.splitext(os.path.basename(args[0]))[0]}_obj_{obj_num}.tif', obj_pad)
156160

157161

@@ -171,7 +175,12 @@ def _gen_SNI_label(args):
171175
output_dir = args[3]
172176
norm_factor = args[4]
173177
scale_factor = args[5]
174-
is_3d = args[6]
178+
max_dims = args[6]
179+
is_3d = args[7]
180+
181+
# set padding to 1/20th of max dims #
182+
if is_3d: z, y, x = [1.05*dim for dim in max_dims]
183+
else: y, x = [1.05*dim for dim in max_dims]
175184

176185
# for each object in the mask... #
177186
obj_nums = np.unique(mask)[1:]
@@ -203,22 +212,20 @@ def _gen_SNI_label(args):
203212
out_dims = np.array([int(in_dim / scale) for in_dim, scale in zip(obj_crop.shape, scale_factor)])
204213
obj_rescale = resize_local_mean(obj_norm, out_dims)
205214

206-
# SNIs will be padded to have 1/8 of their diameter on each side #
215+
# pad image to have 1/8 of its diameter on each side #
207216
if is_3d:
208217
obj_z, obj_y, obj_x = obj_rescale.shape
209-
z, y, x = [1.25*dim for dim in [obj_z, obj_y, obj_x]]
218+
z_pad = (int((z - obj_z) / 2), int((z - obj_z + 1) / 2))
219+
y_pad = (int((y - obj_y) / 2), int((y - obj_y + 1) / 2))
220+
x_pad = (int((x - obj_x) / 2), int((x - obj_x + 1) / 2))
221+
obj_pad = np.pad(obj_rescale, [z_pad, y_pad, x_pad])
210222
else:
211223
obj_y, obj_x = obj_rescale.shape
212-
y, x = [1.25*dim for dim in [obj_y, obj_x]]
213-
214-
# calculate padding #
215-
x_add = (int(floor((x - obj_x) / 2)), int(ceil((x - obj_x) / 2)))
216-
y_add = (int(floor((y - obj_y) / 2)), int(ceil((y - obj_y) / 2)))
217-
if is_3d: z_add = (int(floor((z - obj_z) / 2)), int(ceil((z - obj_z) / 2)))
224+
y_pad = (int((y - obj_y) / 2), int((y - obj_y + 1) / 2))
225+
x_pad = (int((x - obj_x) / 2), int((x - obj_x + 1) / 2))
226+
obj_pad = np.pad(obj_rescale, [y_pad, x_pad])
218227

219-
# pad image + save #
220-
if is_3d: obj_pad = np.pad(obj_rescale, [z_add, y_add, x_add])
221-
else: obj_pad = np.pad(obj_rescale, [y_add, x_add])
228+
# save #
222229
imwrite(f'{output_dir}/{os.path.splitext(os.path.basename(args[0]))[0]}_obj_{obj_num}_class_{label_str}.tif', obj_pad)
223230

224231

@@ -314,15 +321,15 @@ def generate_images(image_dir, mask_dir, output_dir=None, return_df=False, num_c
314321
mask_fns = sorted([os.path.join(mask_dir, fn) for fn in os.listdir(mask_dir)])
315322

316323
# calculate normalization and scaling factors #
317-
norm_factor, scale_factor = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d)
324+
norm_factor, scale_factor, max_dims = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d)
318325

319326
# generate SNIs #
320327
if num_cores == None:
321328
for image_fn, mask_fn in zip(image_fns, mask_fns):
322-
_gen_SNI([image_fn, mask_fn, output_dir, norm_factor, scale_factor, is_3d])
329+
_gen_SNI([image_fn, mask_fn, output_dir, norm_factor, scale_factor, max_dims, is_3d])
323330
else:
324331
with Pool(num_cores) as pool:
325-
pool.map(_gen_SNI, [(image_fn, mask_fn, output_dir, norm_factor, scale_factor, is_3d) for image_fn, mask_fn in zip(image_fns, mask_fns)])
332+
pool.map(_gen_SNI, [(image_fn, mask_fn, output_dir, norm_factor, scale_factor, max_dims, is_3d) for image_fn, mask_fn in zip(image_fns, mask_fns)])
326333

327334
# generate DF #
328335
SNI_fns = sorted([os.path.join(output_dir, fn) for fn in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, fn))])
@@ -371,15 +378,15 @@ def generate_images_labeled(image_dir, mask_dir, label_dir, output_dir=None, ret
371378
label_fns = sorted([os.path.join(label_dir, fn) for fn in os.listdir(label_dir)])
372379

373380
# calculate normalization and scaling factors #
374-
norm_factor, scale_factor = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d)
381+
norm_factor, scale_factor, max_dims = _calc_norm_and_scale_factor(image_fns, mask_fns, num_cores, is_3d)
375382

376383
# generate SNIs #
377384
if num_cores == None:
378385
for image_fn, mask_fn, label_fn in zip(image_fns, mask_fns, label_fns):
379-
_gen_SNI_label([image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor, is_3d])
386+
_gen_SNI_label([image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor, max_dims, is_3d])
380387
else:
381388
with Pool(num_cores) as pool:
382-
pool.map(_gen_SNI_label, [(image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor, is_3d) for image_fn, mask_fn, label_fn in zip(image_fns, mask_fns, label_fns)])
389+
pool.map(_gen_SNI_label, [(image_fn, mask_fn, label_fn, output_dir, norm_factor, scale_factor, max_dims, is_3d) for image_fn, mask_fn, label_fn in zip(image_fns, mask_fns, label_fns)])
383390

384391
# generate DF #
385392
SNI_fns = sorted([os.path.join(output_dir, fn) for fn in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, fn))])

0 commit comments

Comments
 (0)