Skip to content

Commit a82ce56

Browse files
committed
Merge remote-tracking branch 'origin/vista3d' into vista3d-export
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
2 parents d854bc5 + 3aaa4e4 commit a82ce56

4 files changed

Lines changed: 28 additions & 30 deletions

File tree

data/README.md

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ train_files, _, dataset_specific_transforms, dataset_specific_transforms_val = \
4545

4646
The following steps are necessary for creating a multi-dataset data loader for model training.
4747
Step 1 and 2 generate persistent JSON files based on the original dataset (the `image` and `label` pairs; without the additional pseudo label or supervoxel-based label), and only need to be run once when the JSON files don't exist.
48-
Step 3 is optional for generating overall data analysis stats.
4948

5049
##### 1. Generate data list JSON file
5150
```
@@ -73,34 +72,16 @@ creates a JSON file in a format:
7372
```
7473

7574
This step includes a 5-fold cross validation splitting and
76-
some logic for 80-20 training/testing splitting.
75+
some logic for 80-20 training/testing splitting. Users need to modify the code in make_datalists.py for their own dataset. Meanwhile, the "training_transform" should be manually added for each dataset.
7776

7877
The `original_label_dict` corresponds to the original dataset label definitions.
7978
The `label_dict` modifies `original_label_dict` by simply rephrasing the terms.
8079
For example in Task06, `cancer` is renamed to `lung tumor`.
8180
The output of this step is multiple JSON files, each file corresponds
8281
to one dataset.
8382

84-
85-
##### 2. Verify data pairs and generate a global label dictionary
86-
```
87-
python -m data.datasets
88-
```
89-
90-
This script computes a super set of labels from all the dataset JSON files.
91-
The output of this step is a `jsons/label_dict.json` file,
92-
representing the global label dictionary mapping, from class names to globally unique class indices (integers).
93-
94-
95-
##### 3. Compute class frequencies, data transform utilities
96-
```
97-
python -m data.analyzer ...
98-
```
99-
100-
This file (`data/analyzer.py`) contains useful transforms for reading images
101-
and labels, converting labels from dataset-specific labels to the global labels
102-
according to `jsons/label_dict.json`.
103-
83+
##### 2. Add label_dict.json and label_mapping.json
84+
Add new class indices to `label_dict.json` and the local-to-global mapping to `label_mapping.json`.
10485

10586
## SuperVoxel Generation
10687
1. Download the segment anything repo and download the ViT-H weights

scripts/export.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from .sliding_window import point_based_window_inferer, sliding_window_inference
3232
from .train import CONFIG
33-
from .utils.trans_utils import VistaPostTransform
33+
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point
3434
from .utils.trt_utils import ExportWrapper, TRTWrapper
3535
import time
3636

@@ -62,6 +62,7 @@ def infer_wrapper(inputs, model, **kwargs):
6262
outputs = model(input_images=inputs, **kwargs)
6363
return outputs.transpose(1, 0)
6464

65+
6566
class InferClass:
6667
def __init__(self, config_file="./configs/infer.yaml", **override):
6768
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -73,7 +74,6 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
7374
parser.read_config(config_file_)
7475
parser.update(pairs=_args)
7576

76-
# We do not use AMP for export
7777
self.amp = parser.get_parsed_content("amp")
7878
input_channels = parser.get_parsed_content("input_channels")
7979
patch_size = parser.get_parsed_content("patch_size")
@@ -182,10 +182,14 @@ def infer(
182182
batch_data = self.batch_data
183183
else:
184184
batch_data = self.infer_transforms(image_file)
185-
batch_data["label_prompt"] = label_prompt
185+
if label_prompt is not None:
186+
batch_data["label_prompt"] = label_prompt
186187
batch_data = list_data_collate([batch_data])
187188
self.batch_data = batch_data
188189
if point is not None:
190+
if type(point) is list:
191+
point = np.array(point)[np.newaxis, ...]
192+
point_label = np.array(point_label)[np.newaxis, ...]
189193
point = self.transform_points(
190194
point,
191195
np.linalg.inv(batch_data["image"].affine[0])
@@ -245,6 +249,10 @@ def infer(
245249
meta=batch_data["image"].meta,
246250
)
247251
self.prev_mask = batch_data["pred"]
252+
if label_prompt is None and point is not None:
253+
batch_data["pred"] = get_largest_connected_component_point(
254+
batch_data["pred"], point_coords=point, point_labels=point_label
255+
)
248256
batch_data["image"] = batch_data["image"].to("cpu")
249257
batch_data["pred"] = batch_data["pred"].to("cpu")
250258
torch.cuda.empty_cache()

scripts/infer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from .sliding_window import point_based_window_inferer, sliding_window_inference
3232
from .train import CONFIG
33-
from .utils.trans_utils import VistaPostTransform
33+
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point
3434

3535
rearrange, _ = optional_import("einops", name="rearrange")
3636
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -168,10 +168,14 @@ def infer(
168168
batch_data = self.batch_data
169169
else:
170170
batch_data = self.infer_transforms(image_file)
171-
batch_data["label_prompt"] = label_prompt
171+
if label_prompt is not None:
172+
batch_data["label_prompt"] = label_prompt
172173
batch_data = list_data_collate([batch_data])
173174
self.batch_data = batch_data
174175
if point is not None:
176+
if type(point) is list:
177+
point = np.array(point)[np.newaxis, ...]
178+
point_label = np.array(point_label)[np.newaxis, ...]
175179
point = self.transform_points(
176180
point,
177181
np.linalg.inv(batch_data["image"].affine[0])
@@ -231,6 +235,10 @@ def infer(
231235
meta=batch_data["image"].meta,
232236
)
233237
self.prev_mask = batch_data["pred"]
238+
if label_prompt is None and point is not None:
239+
batch_data["pred"] = get_largest_connected_component_point(
240+
batch_data["pred"], point_coords=point, point_labels=point_label
241+
)
234242
batch_data["image"] = batch_data["image"].to("cpu")
235243
batch_data["pred"] = batch_data["pred"].to("cpu")
236244
torch.cuda.empty_cache()

scripts/utils/trans_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def dilate3d(input_tensor, erosion=3):
195195

196196

197197
def get_largest_connected_component_point(
198-
img: NdarrayTensor, point_coords=None, point_labels=None, post_idx=3
198+
img: NdarrayTensor, point_coords=None, point_labels=None
199199
) -> NdarrayTensor:
200200
"""
201201
Gets the largest connected component mask of an image. img is before post process! And will include NaN values.
@@ -349,8 +349,9 @@ def __call__(
349349
pred += 0.5 # inplace mapping to avoid cloning pred
350350
for i in range(1, object_num + 1):
351351
frac = i + 0.5
352-
pred[pred == frac] = torch.tensor(data["label_prompt"][i - 1]).to(pred.dtype)
353-
# pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype)
352+
pred[pred == frac] = torch.tensor(
353+
data["label_prompt"][i - 1]
354+
).to(pred.dtype)
354355
pred[pred == 0.5] = 0.0
355356
data[keys] = pred
356357
return data

0 commit comments

Comments
 (0)