Skip to content

Commit d71a2d1

Browse files
Add point postprocessing (#27)
Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] In-line docstrings updated. --------- Signed-off-by: heyufan1995 <heyufan1995@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5aa1472 commit d71a2d1

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

scripts/infer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from .sliding_window import point_based_window_inferer, sliding_window_inference
3232
from .train import CONFIG
33-
from .utils.trans_utils import VistaPostTransform
33+
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point
3434

3535
rearrange, _ = optional_import("einops", name="rearrange")
3636
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -168,7 +168,8 @@ def infer(
168168
batch_data = self.batch_data
169169
else:
170170
batch_data = self.infer_transforms(image_file)
171-
batch_data["label_prompt"] = label_prompt
171+
if label_prompt is not None:
172+
batch_data["label_prompt"] = label_prompt
172173
batch_data = list_data_collate([batch_data])
173174
self.batch_data = batch_data
174175
if point is not None:
@@ -231,6 +232,10 @@ def infer(
231232
meta=batch_data["image"].meta,
232233
)
233234
self.prev_mask = batch_data["pred"]
235+
if label_prompt is None and point is not None:
236+
batch_data["pred"] = get_largest_connected_component_point(
237+
batch_data["pred"], point_coords=point, point_labels=point_label
238+
)
234239
batch_data["image"] = batch_data["image"].to("cpu")
235240
batch_data["pred"] = batch_data["pred"].to("cpu")
236241
torch.cuda.empty_cache()

scripts/utils/trans_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def dilate3d(input_tensor, erosion=3):
195195

196196

197197
def get_largest_connected_component_point(
198-
img: NdarrayTensor, point_coords=None, point_labels=None, post_idx=3
198+
img: NdarrayTensor, point_coords=None, point_labels=None
199199
) -> NdarrayTensor:
200200
"""
201201
Gets the largest connected component mask of an image. img is before post process! And will include NaN values.

0 commit comments

Comments
 (0)