44 # Serial (single process, all stages)
55 python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml
66
7- # Initialize workflow only (for parallel launch)
7+ # Parallel (N workers on one machine)
8+ python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --parallel 4
9+
10+ # Initialize workflow only (for SLURM parallel launch)
811 python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --init-only
912
1013 # Run as a worker (claims tasks from shared workflow dir)
1720import argparse
1821import os
1922import sys
20- from pathlib import Path
2123
2224import yaml
2325
@@ -30,16 +32,14 @@ def main():
3032 parser .add_argument ("--wait" , action = "store_true" , help = "Wait for all tasks to complete" )
3133 parser .add_argument ("--assemble" , action = "store_true" , help = "Assemble final output volume" )
3234 parser .add_argument ("--parallel" , type = int , default = None ,
33- help = "Run N worker processes on this machine (e.g. --parallel 8) " )
35+ help = "Run N worker processes on this machine" )
3436 parser .add_argument ("--max-tasks" , type = int , default = None , help = "Max tasks per worker" )
3537 parser .add_argument ("--idle-timeout" , type = float , default = 60.0 , help = "Worker idle timeout (seconds)" )
3638 parser .add_argument ("--worker-id" , type = str , default = None , help = "Worker identifier" )
3739 parser .add_argument ("--job-id" , type = str , default = None , help = "SLURM job ID" )
38- # Allow CLI overrides in key=value format
3940 parser .add_argument ("overrides" , nargs = "*" , help = "Config overrides (key=value)" )
4041 args = parser .parse_args ()
4142
42- # Load config
4343 with open (args .config ) as f :
4444 cfg = yaml .safe_load (f )
4545
@@ -51,7 +51,6 @@ def main():
5151 print (f"Warning: skipping invalid override '{ override } ' (expected key=value)" )
5252 continue
5353 key , value = override .split ("=" , 1 )
54- # Try numeric conversion
5554 try :
5655 value = int (value )
5756 except ValueError :
@@ -70,44 +69,21 @@ def main():
7069
7170 os .environ .setdefault ("CCACHE_DISABLE" , "1" )
7271
73- from waterz import LargeDecodeRunner
74-
75- # Parse config
76- chunk_shape = large_cfg .get ("chunk_shape" , [256 , 512 , 512 ])
77- if isinstance (chunk_shape , list ):
78- chunk_shape = tuple (chunk_shape )
79-
80- thresholds = large_cfg .get ("thresholds" , [0.5 ])
81- if isinstance (thresholds , (int , float )):
82- thresholds = [thresholds ]
83-
84- runner = LargeDecodeRunner .create (
85- affinity_path = large_cfg ["affinity_path" ],
86- workflow_root = large_cfg ["workflow_root" ],
87- chunk_shape = chunk_shape ,
88- thresholds = thresholds ,
89- merge_function = large_cfg .get ("merge_function" , "aff85_his256" ),
90- aff_threshold_low = float (large_cfg .get ("aff_threshold_low" , 0.1 )),
91- aff_threshold_high = float (large_cfg .get ("aff_threshold_high" , 0.999 )),
92- channel_order = large_cfg .get ("channel_order" , "xyz" ),
93- write_output = bool (large_cfg .get ("write_output" , True )),
94- output_path = large_cfg .get ("output_path" ) or None ,
95- min_overlap = int (large_cfg .get ("min_overlap" , 1 )),
96- iou_threshold = float (large_cfg .get ("iou_threshold" , 0.0 )),
97- one_sided_threshold = float (large_cfg .get ("one_sided_threshold" , 0.9 )),
98- one_sided_min_size = int (large_cfg .get ("one_sided_min_size" , 0 )),
99- affinity_threshold = float (large_cfg .get ("affinity_threshold" , 0.0 )),
100- compression = large_cfg .get ("compression" , "gzip" ),
101- compression_level = int (large_cfg .get ("compression_level" , 4 )),
102- )
72+ from waterz import LargeDecodeConfig , LargeDecodeRunner
73+
74+ # Build config from yaml dict — from_dict handles all field mapping
75+ config = LargeDecodeConfig .from_dict (large_cfg )
76+ runner = LargeDecodeRunner (config )
77+ runner .initialize ()
10378
10479 chunks = runner .chunks
10580 borders = runner .borders
106- print (f"Volume shape: { runner .config .volume_shape } " )
107- print (f"Chunk shape: { runner .config .chunk_shape } " )
81+ print (f"Volume shape: { config .volume_shape } " )
82+ print (f"Chunk shape: { config .chunk_shape } " )
83+ print (f"Overlap: { config .overlap } " )
10884 print (f"Chunks: { len (chunks )} " )
10985 print (f"Borders: { len (borders )} " )
110- print (f"Workflow: { runner . config .workflow_root } " )
86+ print (f"Workflow: { config .workflow_root } " )
11187
11288 if args .init_only :
11389 print ("Workflow initialized. Launch workers to execute tasks." )
@@ -130,22 +106,20 @@ def main():
130106 print ("Waiting for all tasks to complete..." )
131107 runner .wait (timeout = None )
132108 print ("All tasks completed." )
133- if args .assemble and runner . config .write_output :
109+ if args .assemble and config .write_output :
134110 print ("Assembling output..." )
135111 runner .handle_assemble_output (None )
136- print (f"Output: { runner . config .resolved_output_path } " )
112+ print (f"Output: { config .resolved_output_path } " )
137113 return
138114
139115 if args .parallel and args .parallel > 1 :
140- # Multi-process on one machine
141116 import multiprocessing as mp
142117
143118 n_workers = args .parallel
144119 print (f"Running parallel decode with { n_workers } workers..." )
145120
146121 def _worker_fn (worker_idx ):
147122 os .environ ["CCACHE_DISABLE" ] = "1"
148- # Each process loads its own runner from disk
149123 from waterz import LargeDecodeRunner as _LDR
150124 w = _LDR .load (large_cfg ["workflow_root" ])
151125 return w .run_worker (
@@ -159,7 +133,6 @@ def _worker_fn(worker_idx):
159133 n = sum (counts )
160134 print (f"Completed { n } tasks across { n_workers } workers." )
161135 else :
162- # Default: run serial (all stages in one process)
163136 print ("Running serial decode..." )
164137 n = runner .run_serial ()
165138 print (f"Completed { n } tasks." )
0 commit comments