Skip to content

Commit 50b41f9

Browse files
committed
updated spine model and error checks in spine pipeline
1 parent 7009326 commit 50b41f9

2 files changed

Lines changed: 42 additions & 11 deletions

File tree

bin/C2C

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ from comp2comp.utils.process import process_3d
2929
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3030

3131
### AAA Pipeline
32-
3332
def AAAPipelineBuilder(path, args):
3433
pipeline = InferencePipeline(
3534
[
@@ -185,7 +184,7 @@ def argument_parser():
185184

186185
# Spine
187186
spine_parser = subparsers.add_parser("spine", parents=[base_parser])
188-
spine_parser.add_argument("--spine_model", default="ts_spine", type=str)
187+
spine_parser.add_argument("--spine_model", default="stanford_spine_v0.0.1", type=str)
189188

190189
# Spine + muscle + fat
191190
spine_muscle_adipose_tissue_parser = subparsers.add_parser(

comp2comp/spine/spine_utils.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def get_slices(seg: np.ndarray, centroids: Dict, spine_model_type):
9797
for level in centroids:
9898
label_idx = spine_model_type.categories[level]
9999
binary_seg = (seg[centroids[level], :, :] == label_idx).astype(int)
100+
# FIXME This should maybe throw an error if >200 is not met
100101
if (
101102
np.sum(binary_seg) > 200
102103
): # heuristic to make sure enough of the body is showing
@@ -223,49 +224,77 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
223224
roi_start_time = time.time()
224225

225226
mask = None
227+
updated_mask = None
226228
inferior_superior_line = seg[int(centroid[0]), int(centroid[1]), :]
227229
# get the center point
228230
updated_z_center = np.mean(np.where(inferior_superior_line == 1))
229231
lower_z_idx = updated_z_center - ((length_k * 1.5) // 2)
230232
upper_z_idx = updated_z_center + ((length_k * 1.5) // 2)
233+
231234
for idx in range(int(lower_z_idx), int(upper_z_idx) + 1):
235+
print(f'idx: {idx}')
232236
# take multiple to increase robustness
233237
posterior_anterior_lines = [
234238
slice[:, idx],
235239
slice[:, idx + 1],
236240
slice[:, idx - 1],
237241
]
238-
posterior_anterior_sums = [
242+
posterior_anterior_sums = np.array([
239243
np.sum(posterior_anterior_lines[0]),
240244
np.sum(posterior_anterior_lines[1]),
241245
np.sum(posterior_anterior_lines[2]),
242-
]
243-
min_idx = np.argmin(posterior_anterior_sums)
246+
])
247+
248+
if posterior_anterior_sums.sum() == 0:
249+
print(f'skipped idx {idx} since posterior_anterior_sums where 0')
250+
continue
251+
252+
# min_idx = np.argmin(posterior_anterior_sums)
253+
254+
# ensure the posterior_anterior_sums array is not zero
255+
nonzero_mask = posterior_anterior_sums != 0
256+
min_idx = np.argmin(posterior_anterior_sums[nonzero_mask])
257+
# Convert back to original indices
258+
min_idx = np.where(nonzero_mask)[0][min_idx]
244259

245260
posterior_anterior_line = posterior_anterior_lines[min_idx]
261+
246262
updated_posterior_anterior_center = (
247263
np.min(np.where(posterior_anterior_line == 1))
248264
+ np.sum(posterior_anterior_line) * 0.58
249265
)
250266
posterior_anterior_length = (posterior_anterior_sums[min_idx] * 0.5) // 2
267+
posterior_anterior_length = max(posterior_anterior_length, 1)
251268

252269
left_right_lines = [
253270
seg[:, int(updated_posterior_anterior_center), idx],
254271
seg[:, int(updated_posterior_anterior_center) + 1, idx],
255272
seg[:, int(updated_posterior_anterior_center) - 1, idx],
256273
]
257274

258-
left_right_sums = [
275+
left_right_sums = np.array([
259276
np.sum(left_right_lines[0]),
260277
np.sum(left_right_lines[1]),
261278
np.sum(left_right_lines[2]),
262-
]
279+
])
280+
281+
if left_right_sums.sum() == 0:
282+
print(f'skipped idx {idx} since left_right_sums where 0')
283+
continue
284+
285+
# min_idx = np.argmin(left_right_sums)
286+
287+
# ensure the posterior_anterior_sums array is not zero
288+
nonzero_mask = left_right_sums != 0
289+
min_idx = np.argmin(left_right_sums[nonzero_mask])
290+
# Convert back to original indices
291+
min_idx = np.where(nonzero_mask)[0][min_idx]
263292

264-
min_idx = np.argmin(left_right_sums)
265293
left_right_line = left_right_lines[min_idx]
266294

267295
updated_left_right_center = np.mean(np.where(left_right_line == 1))
268296
left_right_length = (left_right_sums[min_idx] * 0.65) // 2
297+
left_right_length = max(left_right_length, 1)
269298

270299
roi_2d = np.zeros((img_np.shape[0], img_np.shape[1]))
271300
h = updated_left_right_center
@@ -286,10 +315,11 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
286315
for y in range(y_min - 2, y_max + 2):
287316
if ((x - h) / a) ** 2 + ((y - k) / b) ** 2 <= 1:
288317
roi_2d[x, y] = 1
318+
289319
roi[:, :, idx] = roi_2d
290320
if idx == int(centroid[2]):
291321
mask = np.flip(np.flip(np.transpose(roi_2d), axis=0), axis=1)
292-
if idx == updated_z_center:
322+
if idx == int(updated_z_center):
293323
updated_mask = np.flip(np.flip(np.transpose(roi_2d), axis=0), axis=1)
294324
if mask is None:
295325
mask = updated_mask
@@ -309,10 +339,10 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
309339
end_time = time.time()
310340
roi_end_time = time.time()
311341
elapsed_time = end_time - start_time
312-
print(f"Elapsed time for erosion operation: {elapsed_time} seconds")
342+
print(f"Elapsed time for erosion operation: {elapsed_time:.1f} seconds")
313343

314344
elapsed_time = roi_end_time - roi_start_time
315-
print(f"Elapsed time for full ROI computation: {elapsed_time} seconds")
345+
print(f"Elapsed time for full ROI computation: {elapsed_time:.1f} seconds")
316346

317347
return roi, mask
318348

@@ -375,6 +405,8 @@ def compute_rois(seg, img, spine_model_type):
375405
slice = slices[level]
376406
center_of_mass = compute_center_of_mass(slice)
377407
centroid = np.array([centroids[level], center_of_mass[1], center_of_mass[0]])
408+
print(f'Processing i={i} at level={level}')
409+
378410
roi, mask_2d = roi_from_mask(
379411
img,
380412
centroid,

0 commit comments

Comments
 (0)