Skip to content

Commit 66bedda

Browse files
Donglai Wei and claude committed
Update large decode config, simplify launcher, add justfile targets
- waterz_decoding_large.yaml: rename border params to match face_merge API
- decode_large.py: simplify param passing to LargeDecodeRunner
- main.py: minor decode-only mode fixes
- justfile: add decode-large targets

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3e94085 commit 66bedda

4 files changed

Lines changed: 64 additions & 48 deletions

File tree

justfile

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,45 @@ visualize-files image='' label='' port='9999' *ARGS='':
273273
python -i scripts/visualize_neuroglancer.py $args {{ARGS}}
274274

275275
# Visualize multiple volumes with custom names (e.g., just visualize-volumes image:path/img.tif label:path/lbl.h5)
276+
# Override the default port with `port=8080` or `--port 8080`.
276277
visualize-volumes +volumes:
277-
python -i scripts/visualize_neuroglancer.py --volumes {{volumes}}
278+
#!/usr/bin/env bash
279+
set -euo pipefail
280+
port=9999
281+
args=({{volumes}})
282+
volume_args=()
283+
284+
i=0
285+
while [ $i -lt ${#args[@]} ]; do
286+
arg="${args[$i]}"
287+
case "$arg" in
288+
port=*)
289+
port="${arg#port=}"
290+
;;
291+
--port=*)
292+
port="${arg#--port=}"
293+
;;
294+
--port)
295+
i=$((i + 1))
296+
if [ $i -ge ${#args[@]} ]; then
297+
echo "ERROR: --port requires a value" >&2
298+
exit 1
299+
fi
300+
port="${args[$i]}"
301+
;;
302+
*)
303+
volume_args+=("$arg")
304+
;;
305+
esac
306+
i=$((i + 1))
307+
done
308+
309+
if [ ${#volume_args[@]} -eq 0 ]; then
310+
echo "ERROR: At least one volume spec is required" >&2
311+
exit 1
312+
fi
313+
314+
python -i scripts/visualize_neuroglancer.py --port "$port" --volumes "${volume_args[@]}"
278315

279316
# Visualize with remote access (use 0.0.0.0 for public IP, e.g., just visualize-remote 8080 tutorials/monai_lucchi.yaml)
280317
visualize-remote port config *ARGS='':

scripts/decode_large.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
# Serial (single process, all stages)
55
python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml
66
7-
# Initialize workflow only (for parallel launch)
7+
# Parallel (N workers on one machine)
8+
python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --parallel 4
9+
10+
# Initialize workflow only (for SLURM parallel launch)
811
python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --init-only
912
1013
# Run as a worker (claims tasks from shared workflow dir)
@@ -17,7 +20,6 @@
1720
import argparse
1821
import os
1922
import sys
20-
from pathlib import Path
2123

2224
import yaml
2325

@@ -30,16 +32,14 @@ def main():
3032
parser.add_argument("--wait", action="store_true", help="Wait for all tasks to complete")
3133
parser.add_argument("--assemble", action="store_true", help="Assemble final output volume")
3234
parser.add_argument("--parallel", type=int, default=None,
33-
help="Run N worker processes on this machine (e.g. --parallel 8)")
35+
help="Run N worker processes on this machine")
3436
parser.add_argument("--max-tasks", type=int, default=None, help="Max tasks per worker")
3537
parser.add_argument("--idle-timeout", type=float, default=60.0, help="Worker idle timeout (seconds)")
3638
parser.add_argument("--worker-id", type=str, default=None, help="Worker identifier")
3739
parser.add_argument("--job-id", type=str, default=None, help="SLURM job ID")
38-
# Allow CLI overrides in key=value format
3940
parser.add_argument("overrides", nargs="*", help="Config overrides (key=value)")
4041
args = parser.parse_args()
4142

42-
# Load config
4343
with open(args.config) as f:
4444
cfg = yaml.safe_load(f)
4545

@@ -51,7 +51,6 @@ def main():
5151
print(f"Warning: skipping invalid override '{override}' (expected key=value)")
5252
continue
5353
key, value = override.split("=", 1)
54-
# Try numeric conversion
5554
try:
5655
value = int(value)
5756
except ValueError:
@@ -70,44 +69,21 @@ def main():
7069

7170
os.environ.setdefault("CCACHE_DISABLE", "1")
7271

73-
from waterz import LargeDecodeRunner
74-
75-
# Parse config
76-
chunk_shape = large_cfg.get("chunk_shape", [256, 512, 512])
77-
if isinstance(chunk_shape, list):
78-
chunk_shape = tuple(chunk_shape)
79-
80-
thresholds = large_cfg.get("thresholds", [0.5])
81-
if isinstance(thresholds, (int, float)):
82-
thresholds = [thresholds]
83-
84-
runner = LargeDecodeRunner.create(
85-
affinity_path=large_cfg["affinity_path"],
86-
workflow_root=large_cfg["workflow_root"],
87-
chunk_shape=chunk_shape,
88-
thresholds=thresholds,
89-
merge_function=large_cfg.get("merge_function", "aff85_his256"),
90-
aff_threshold_low=float(large_cfg.get("aff_threshold_low", 0.1)),
91-
aff_threshold_high=float(large_cfg.get("aff_threshold_high", 0.999)),
92-
channel_order=large_cfg.get("channel_order", "xyz"),
93-
write_output=bool(large_cfg.get("write_output", True)),
94-
output_path=large_cfg.get("output_path") or None,
95-
min_overlap=int(large_cfg.get("min_overlap", 1)),
96-
iou_threshold=float(large_cfg.get("iou_threshold", 0.0)),
97-
one_sided_threshold=float(large_cfg.get("one_sided_threshold", 0.9)),
98-
one_sided_min_size=int(large_cfg.get("one_sided_min_size", 0)),
99-
affinity_threshold=float(large_cfg.get("affinity_threshold", 0.0)),
100-
compression=large_cfg.get("compression", "gzip"),
101-
compression_level=int(large_cfg.get("compression_level", 4)),
102-
)
72+
from waterz import LargeDecodeConfig, LargeDecodeRunner
73+
74+
# Build config from yaml dict — from_dict handles all field mapping
75+
config = LargeDecodeConfig.from_dict(large_cfg)
76+
runner = LargeDecodeRunner(config)
77+
runner.initialize()
10378

10479
chunks = runner.chunks
10580
borders = runner.borders
106-
print(f"Volume shape: {runner.config.volume_shape}")
107-
print(f"Chunk shape: {runner.config.chunk_shape}")
81+
print(f"Volume shape: {config.volume_shape}")
82+
print(f"Chunk shape: {config.chunk_shape}")
83+
print(f"Overlap: {config.overlap}")
10884
print(f"Chunks: {len(chunks)}")
10985
print(f"Borders: {len(borders)}")
110-
print(f"Workflow: {runner.config.workflow_root}")
86+
print(f"Workflow: {config.workflow_root}")
11187

11288
if args.init_only:
11389
print("Workflow initialized. Launch workers to execute tasks.")
@@ -130,22 +106,20 @@ def main():
130106
print("Waiting for all tasks to complete...")
131107
runner.wait(timeout=None)
132108
print("All tasks completed.")
133-
if args.assemble and runner.config.write_output:
109+
if args.assemble and config.write_output:
134110
print("Assembling output...")
135111
runner.handle_assemble_output(None)
136-
print(f"Output: {runner.config.resolved_output_path}")
112+
print(f"Output: {config.resolved_output_path}")
137113
return
138114

139115
if args.parallel and args.parallel > 1:
140-
# Multi-process on one machine
141116
import multiprocessing as mp
142117

143118
n_workers = args.parallel
144119
print(f"Running parallel decode with {n_workers} workers...")
145120

146121
def _worker_fn(worker_idx):
147122
os.environ["CCACHE_DISABLE"] = "1"
148-
# Each process loads its own runner from disk
149123
from waterz import LargeDecodeRunner as _LDR
150124
w = _LDR.load(large_cfg["workflow_root"])
151125
return w.run_worker(
@@ -159,7 +133,6 @@ def _worker_fn(worker_idx):
159133
n = sum(counts)
160134
print(f"Completed {n} tasks across {n_workers} workers.")
161135
else:
162-
# Default: run serial (all stages in one process)
163136
print("Running serial decode...")
164137
n = runner.run_serial()
165138
print(f"Completed {n} tasks.")

scripts/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -620,14 +620,19 @@ def _invert_save_prediction_transform(cfg: Config, data):
620620
save_pred_cfg = inference_cfg.save_prediction
621621
intensity_scale = getattr(save_pred_cfg, "intensity_scale", None)
622622

623-
data = data.astype(np.float32)
624623
if intensity_scale is not None and intensity_scale > 0 and intensity_scale != 1.0:
624+
data = data.astype(np.float32, copy=False)
625625
data = data / float(intensity_scale)
626626
print(f" Inverted intensity scaling by {intensity_scale}")
627-
elif intensity_scale is not None and intensity_scale < 0:
627+
return data
628+
629+
if intensity_scale is not None and intensity_scale < 0:
628630
print(
629-
f" INFO: Intensity scaling was disabled (scale={intensity_scale}), no inversion needed"
631+
f" INFO: Intensity scaling was disabled (scale={intensity_scale}), keeping dtype "
632+
f"{data.dtype}"
630633
)
634+
else:
635+
print(f" INFO: No intensity inversion needed, keeping dtype {data.dtype}")
631636

632637
return data
633638

tutorials/waterz_decoding_large.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ large_decode:
3636
merge_function: aff85_his256
3737
aff_threshold_low: 0
3838
aff_threshold_high: 1
39+
border_threshold: 0.3
3940
channel_order: xyz
4041
#use_aff_uint8: true # uint8 affinities (4x less aff memory)
4142
#use_seg_uint32: true # uint32 segment IDs (2x less seg memory)

0 commit comments

Comments (0)