Skip to content

Commit 1322a9a

Browse files
Donglai Wei and claude committed
Remove _mergedUntil filter from extractRegionGraph, reuse full agglomeration graph
extractRegionGraph was filtering edges with score < _mergedUntil (the agglomeration threshold), discarding high-affinity edges that dust merge needs. Self-edges are already handled by the _deleted filter.

Remove the threshold filter so extractRegionGraph returns the full graph with accumulated scoring statistics. This lets dust merge reuse the agglomeration's region graph directly — no rebuild needed, and results match the original behavior.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fe60f0b commit 1322a9a

2 files changed

Lines changed: 61 additions & 63 deletions

File tree

connectomics/decoding/decoders/waterz.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,6 @@ def _merge_function_to_scoring(shorthand: str) -> str:
7373
)
7474

7575

76-
def _strip_oneminus(scoring_function: str) -> str:
77-
"""Strip ``OneMinus<...>`` or ``One255Minus<...>`` wrapper.
78-
79-
``merge_dust`` / ``buildRegionGraphOnly`` expects the raw scoring
80-
function (high score = strong connection), not the inverted wrapper
81-
used by the agglomeration priority queue.
82-
"""
83-
for prefix in ("OneMinus<", "One255Minus<"):
84-
if scoring_function.startswith(prefix) and scoring_function.endswith(">"):
85-
return scoring_function[len(prefix):-1]
86-
return scoring_function
87-
8876

8977
def decode_waterz(
9078
predictions: np.ndarray,
@@ -159,11 +147,11 @@ def decode_waterz(
159147
min_instance_size: Minimum instance size in voxels. Instances smaller
160148
than this are removed (set to background). Set to 0 to disable.
161149
Default: 0
162-
dust_merge: Enable dust postprocessing. Rebuilds the region graph
163-
via ``waterz.merge_dust`` using the same scoring function as
164-
agglomeration (e.g. p85 histogram quantile), ensuring consistent
165-
edge weights. When False, the dust merge and dust removal
166-
thresholds below are ignored. Default: True
150+
dust_merge: Enable dust postprocessing. Reuses the agglomeration's
151+
full region graph (with accumulated scoring statistics) via
152+
``waterz.merge_segments`` — no graph rebuild needed.
153+
When False, the dust merge and dust removal thresholds below
154+
are ignored. Default: True
167155
dust_merge_size: Size+affinity dust merge (zwatershed-style).
168156
Segments with fewer voxels than this are merged into their
169157
highest-affinity neighbor. Unlike *min_instance_size* which
@@ -308,29 +296,50 @@ def decode_waterz(
308296
waterz_kwargs["fragments"] = fragments.astype(np.uint64, copy=False)
309297

310298
do_dust_merge = bool(dust_merge) and dust_merge_size > 0
311-
312-
# For dust merge, strip OneMinus/One255Minus so buildRegionGraphOnly
313-
# uses the same scoring function as agglomeration (e.g. p85 histogram)
314-
# but returns raw affinities (high = strong) instead of inverted scores.
315-
dust_scoring = _strip_oneminus(scoring_function) if do_dust_merge else ""
299+
waterz_kwargs["return_region_graph"] = do_dust_merge
316300

317301
# waterz.waterz() runs watershed + region-graph once, then incrementally
318302
# merges for each threshold. Returns all segmentations (copied).
319303
seg_list = waterz.waterz(affs, thresholds=thresholds_list, **waterz_kwargs)
320304

321305
# Post-process each result
322306
processed: List[np.ndarray] = []
323-
for seg in seg_list:
324-
# Size+affinity dust merge via buildRegionGraphOnly with the same
325-
# scoring function as agglomeration (not MeanAffinity default).
307+
for waterz_result in seg_list:
308+
if do_dust_merge:
309+
seg, region_graph = waterz_result
310+
else:
311+
seg = waterz_result
312+
313+
# Size+affinity dust merge reusing the agglomeration's full region
314+
# graph (extractRegionGraph returns all non-deleted edges with
315+
# accumulated scores from the agglomeration process).
326316
if do_dust_merge:
327317
seg = seg.astype(np.uint64, copy=False)
328-
waterz.merge_dust(
329-
seg, affs,
318+
n_edges = len(region_graph)
319+
rg_affs = np.empty(n_edges, dtype=np.float32)
320+
id1 = np.empty(n_edges, dtype=np.uint64)
321+
id2 = np.empty(n_edges, dtype=np.uint64)
322+
# Invert OneMinus/One255Minus scores to raw affinities.
323+
score_max = 255.0 if is_uint8 else 1.0
324+
for idx, edge in enumerate(region_graph):
325+
rg_affs[idx] = score_max - float(edge["score"])
326+
id1[idx] = int(edge["u"])
327+
id2[idx] = int(edge["v"])
328+
if n_edges:
329+
np.clip(rg_affs, 0.0, score_max, out=rg_affs)
330+
order = np.argsort(rg_affs)[::-1]
331+
rg_affs = np.ascontiguousarray(rg_affs[order])
332+
id1 = np.ascontiguousarray(id1[order])
333+
id2 = np.ascontiguousarray(id2[order])
334+
ids, cnts = np.unique(seg, return_counts=True)
335+
max_id = int(ids.max()) if len(ids) else 0
336+
counts = np.zeros(max_id + 1, dtype=np.uint64)
337+
counts[ids] = cnts
338+
waterz.merge_segments(
339+
seg, rg_affs, id1, id2, counts,
330340
size_th=dust_merge_size,
331341
weight_th=dust_merge_affinity,
332342
dust_th=dust_remove_size,
333-
scoring_function=dust_scoring,
334343
)
335344
# Branch merge: resolve false splits via z-slice IOU analysis
336345
if branch_merge:

tests/unit/test_decode_waterz.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,33 @@ class _FakeWaterzModule:
1111
"""Minimal waterz stub for testing wrapper behavior."""
1212

1313
def __init__(self):
14-
self.merge_dust_calls = []
14+
self.merge_segments_calls = []
1515
self.waterz_calls = []
1616

1717
def waterz(self, affs, thresholds, **kwargs):
1818
self.waterz_calls.append(kwargs.copy())
1919
seg = np.zeros(affs.shape[1:], dtype=np.uint64)
2020
seg[:, :, :2] = 1
2121
seg[:, :, 2:] = 2
22+
if kwargs.get("return_region_graph", False):
23+
# ScoredEdge dicts from extractRegionGraph.
24+
# OneMinus score 0.2 → affinity = 1.0 - 0.2 = 0.8
25+
rg = [{"u": 1, "v": 2, "score": 0.2}]
26+
return [(seg.copy(), list(rg)) for _ in thresholds]
2227
return [seg.copy() for _ in thresholds]
2328

24-
def merge_dust(self, seg, affs, size_th, weight_th, dust_th, scoring_function, channels="all"):
25-
self.merge_dust_calls.append(
29+
def merge_segments(self, seg, rg_affs, id1, id2, counts,
30+
size_th, weight_th, dust_th):
31+
self.merge_segments_calls.append(
2632
{
2733
"seg_shape": seg.shape,
28-
"aff_shape": affs.shape,
34+
"rg_affs": rg_affs.tolist(),
35+
"id1": id1.tolist(),
36+
"id2": id2.tolist(),
37+
"counts": counts.tolist(),
2938
"size_th": size_th,
3039
"weight_th": weight_th,
3140
"dust_th": dust_th,
32-
"scoring_function": scoring_function,
3341
}
3442
)
3543

@@ -56,13 +64,14 @@ def test_decode_waterz_skips_dust_postprocessing_when_disabled(monkeypatch):
5664
"scoring_function": "OneMinus<HistogramQuantileAffinity<RegionGraphType, 50, ScoreValue, 256>>",
5765
"aff_threshold_low": 0.0001,
5866
"aff_threshold_high": 0.9999,
67+
"return_region_graph": False,
5968
}
6069
]
61-
assert fake_waterz.merge_dust_calls == []
70+
assert fake_waterz.merge_segments_calls == []
6271

6372

64-
def test_decode_waterz_dust_merge_uses_same_scoring_function(monkeypatch):
65-
"""Dust merge rebuilds graph with same scoring as agglomeration (OneMinus stripped)."""
73+
def test_decode_waterz_reuses_agglomeration_region_graph_for_dust(monkeypatch):
74+
"""Dust merge reuses agglomeration's region graph with inverted scores."""
6675
fake_waterz = _FakeWaterzModule()
6776
monkeypatch.setattr(waterz_decoder, "waterz", fake_waterz)
6877
monkeypatch.setattr(waterz_decoder, "WATERZ_AVAILABLE", True)
@@ -72,7 +81,6 @@ def test_decode_waterz_dust_merge_uses_same_scoring_function(monkeypatch):
7281
waterz_decoder.decode_waterz(
7382
predictions,
7483
thresholds=0.4,
75-
merge_function="aff85_his256",
7684
dust_merge=True,
7785
dust_merge_size=100,
7886
dust_merge_affinity=0.3,
@@ -81,40 +89,21 @@ def test_decode_waterz_dust_merge_uses_same_scoring_function(monkeypatch):
8189

8290
assert fake_waterz.waterz_calls == [
8391
{
84-
"scoring_function": "OneMinus<HistogramQuantileAffinity<RegionGraphType, 85, ScoreValue, 256>>",
92+
"scoring_function": "OneMinus<HistogramQuantileAffinity<RegionGraphType, 50, ScoreValue, 256>>",
8593
"aff_threshold_low": 0.0001,
8694
"aff_threshold_high": 0.9999,
95+
"return_region_graph": True,
8796
}
8897
]
89-
assert fake_waterz.merge_dust_calls == [
98+
assert fake_waterz.merge_segments_calls == [
9099
{
91100
"seg_shape": (4, 4, 4),
92-
"aff_shape": (3, 4, 4, 4),
101+
"rg_affs": [0.800000011920929],
102+
"id1": [1],
103+
"id2": [2],
104+
"counts": [0, 32, 32],
93105
"size_th": 100,
94106
"weight_th": 0.3,
95107
"dust_th": 50,
96-
"scoring_function": "HistogramQuantileAffinity<RegionGraphType, 85, ScoreValue, 256>",
97108
}
98109
]
99-
100-
101-
def test_decode_waterz_dust_merge_strips_one255minus(monkeypatch):
102-
"""One255Minus wrapper is also stripped for dust merge scoring."""
103-
fake_waterz = _FakeWaterzModule()
104-
monkeypatch.setattr(waterz_decoder, "waterz", fake_waterz)
105-
monkeypatch.setattr(waterz_decoder, "WATERZ_AVAILABLE", True)
106-
107-
predictions = np.ones((3, 4, 4, 4), dtype=np.float32)
108-
109-
waterz_decoder.decode_waterz(
110-
predictions,
111-
thresholds=0.4,
112-
merge_function="aff50_his256_ran255",
113-
dust_merge=True,
114-
dust_merge_size=100,
115-
dust_merge_affinity=0.3,
116-
dust_remove_size=50,
117-
)
118-
119-
call = fake_waterz.merge_dust_calls[0]
120-
assert call["scoring_function"] == "HistogramQuantileAffinity<RegionGraphType, 50, ScoreValue, 256>"

0 commit comments

Comments
 (0)