Skip to content

Commit 573596a

Browse files
authored
Fix inference with cpu only configs (#9)
### Description Exciting work here! I noticed that inference has some hard-coded `.cuda()` calls. The "everything" segmentation isn't really feasible on CPU, and there seem to be some issues with multiple-prompt inference on my end, but this change at least allows trying out single-prompt inference with CPU-only configs. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality).
1 parent 7aca195 commit 573596a

3 files changed

Lines changed: 27 additions & 11 deletions

File tree

monailabel/monaivista/lib/infers/vista_point_2pt5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
6464
]
6565

6666
def inferer(self, data=None) -> Inferer:
67-
return VISTASliceInferer()
67+
return VISTASliceInferer(device=data.get("device") if data else None)
6868

6969
def inverse_transforms(self, data=None):
7070
return []

monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,21 @@ def update_slice(
238238
continue
239239

240240
inputs = inputs_l[..., start_idx - (n_z_slices // 2) : start_idx + (n_z_slices // 2) + 1].permute(2, 0, 1)
241+
if device and (device == "cuda" or isinstance(device, torch.device) and device.type == "cuda"):
242+
inputs = inputs.cuda()
241243
data, unique_labels = prepare_sam_val_input(
242-
inputs.cuda(), class_prompts, point_prompts, start_idx, original_affine
244+
inputs, class_prompts, point_prompts, start_idx, original_affine, device=device
243245
)
244246

245247
predictor.eval()
246-
with torch.cuda.amp.autocast():
247-
outputs = predictor(data)
248-
logit = outputs[0]["high_res_logits"]
248+
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
249+
with torch.cuda.amp.autocast():
250+
outputs = predictor(data)
251+
logit = outputs[0]["high_res_logits"]
252+
else:
253+
with torch.cpu.amp.autocast():
254+
outputs = predictor(data)
255+
logit = outputs[0]["high_res_logits"]
249256

250257
out_list = torch.unbind(logit, dim=0)
251258
y_pred = torch.stack(post_pred_slice(out_list)).float()
@@ -290,11 +297,15 @@ def iterate_all(
290297
)
291298
for start_idx in start_range:
292299
inputs = inputs_l[..., start_idx - n_z_slices // 2 : start_idx + n_z_slices // 2 + 1].permute(2, 0, 1)
293-
data, unique_labels = prepare_sam_val_input(inputs.cuda(), class_prompts, point_prompts, start_idx)
300+
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
301+
inputs = inputs.cuda()
302+
data, unique_labels = prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, device=device)
294303
predictor = predictor.eval()
295304
with autocast():
296305
if cachedEmbedding:
297-
curr_embedding = cachedEmbedding[start_idx].cuda()
306+
curr_embedding = cachedEmbedding[start_idx]
307+
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
308+
curr_embedding = curr_embedding.cuda()
298309
outputs = predictor.get_mask_prediction(data, curr_embedding)
299310
else:
300311
outputs = predictor(data)

monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,16 @@ def distributed_all_gather(
8080
return tensor_list_out
8181

8282

83-
def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, original_affine=None):
83+
def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, original_affine=None, device=None):
8484
# Don't exclude background in val but will ignore it in metric calculation
8585
H, W = inputs.shape[1:]
8686
foreground_all = point_prompts["foreground"]
8787
background_all = point_prompts["background"]
8888

8989
class_list = [[i + 1] for i in class_prompts]
90-
unique_labels = torch.tensor(class_list).long().cuda()
90+
unique_labels = torch.tensor(class_list).long()
91+
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
92+
unique_labels = unique_labels.cuda()
9193

9294
volume_point_coords = [cp for cp in foreground_all]
9395
volume_point_labels = [1] * len(foreground_all)
@@ -129,8 +131,11 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi
129131
prepared_input[0].update({"labels": unique_labels})
130132

131133
if point_coords:
132-
point_coords = torch.tensor(point_coords).long().cuda()
133-
point_labels = torch.tensor(point_labels).long().cuda()
134+
point_coords = torch.tensor(point_coords).long()
135+
point_labels = torch.tensor(point_labels).long()
136+
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
137+
point_coords = point_coords.cuda()
138+
point_labels = point_labels.cuda()
134139

135140
prepared_input[0].update({"point_coords": point_coords, "point_labels": point_labels})
136141

0 commit comments

Comments (0)