Skip to content

Commit df6d5a0

Browse files
Donglai Wei and claude committed
Revert extractRegionGraph change, use merge_dust with matched scoring
Removing the _mergedUntil filter from extractRegionGraph caused bad results because stale edges have inconsistent scores and non-root endpoint IDs after agglomeration. Instead, use merge_dust (rebuilds graph via buildRegionGraphOnly) with the OneMinus wrapper stripped so it uses the same scoring function as agglomeration (e.g. HistogramQuantileAffinity p85). This gives correct and consistent edge weights without reusing the problematic internal agglomeration graph. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4f4d64a commit df6d5a0

2 files changed

Lines changed: 64 additions & 61 deletions

File tree

connectomics/decoding/decoders/waterz.py

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ def _merge_function_to_scoring(shorthand: str) -> str:
7373
)
7474

7575

76+
def _strip_oneminus(scoring_function: str) -> str:
77+
"""Strip ``OneMinus<...>`` or ``One255Minus<...>`` wrapper.
78+
79+
``merge_dust`` / ``buildRegionGraphOnly`` expects the raw scoring
80+
function (high score = strong connection), not the inverted wrapper
81+
used by the agglomeration priority queue.
82+
"""
83+
for prefix in ("OneMinus<", "One255Minus<"):
84+
if scoring_function.startswith(prefix) and scoring_function.endswith(">"):
85+
return scoring_function[len(prefix):-1]
86+
return scoring_function
87+
7688

7789
def decode_waterz(
7890
predictions: np.ndarray,
@@ -147,11 +159,11 @@ def decode_waterz(
147159
min_instance_size: Minimum instance size in voxels. Instances smaller
148160
than this are removed (set to background). Set to 0 to disable.
149161
Default: 0
150-
dust_merge: Enable dust postprocessing. Reuses the agglomeration's
151-
full region graph (with accumulated scoring statistics) via
152-
``waterz.merge_segments`` — no graph rebuild needed.
153-
When False, the dust merge and dust removal thresholds below
154-
are ignored. Default: True
162+
dust_merge: Enable dust postprocessing. Rebuilds the region graph
163+
via ``waterz.merge_dust`` using the same scoring function as
164+
agglomeration (e.g. p85 histogram quantile), ensuring consistent
165+
edge weights. When False, the dust merge and dust removal
166+
thresholds below are ignored. Default: True
155167
dust_merge_size: Size+affinity dust merge (zwatershed-style).
156168
Segments with fewer voxels than this are merged into their
157169
highest-affinity neighbor. Unlike *min_instance_size* which
@@ -296,50 +308,29 @@ def decode_waterz(
296308
waterz_kwargs["fragments"] = fragments.astype(np.uint64, copy=False)
297309

298310
do_dust_merge = bool(dust_merge) and dust_merge_size > 0
299-
waterz_kwargs["return_region_graph"] = do_dust_merge
311+
312+
# For dust merge, strip OneMinus/One255Minus so buildRegionGraphOnly
313+
# uses the same scoring function as agglomeration (e.g. p85 histogram)
314+
# but returns raw affinities (high = strong) instead of inverted scores.
315+
dust_scoring = _strip_oneminus(scoring_function) if do_dust_merge else ""
300316

301317
# waterz.waterz() runs watershed + region-graph once, then incrementally
302318
# merges for each threshold. Returns all segmentations (copied).
303319
seg_list = waterz.waterz(affs, thresholds=thresholds_list, **waterz_kwargs)
304320

305321
# Post-process each result
306322
processed: List[np.ndarray] = []
307-
for waterz_result in seg_list:
308-
if do_dust_merge:
309-
seg, region_graph = waterz_result
310-
else:
311-
seg = waterz_result
312-
313-
# Size+affinity dust merge reusing the agglomeration's full region
314-
# graph (extractRegionGraph returns all non-deleted edges with
315-
# accumulated scores from the agglomeration process).
323+
for seg in seg_list:
324+
# Size+affinity dust merge via buildRegionGraphOnly with the same
325+
# scoring function as agglomeration (not MeanAffinity default).
316326
if do_dust_merge:
317327
seg = seg.astype(np.uint64, copy=False)
318-
n_edges = len(region_graph)
319-
rg_affs = np.empty(n_edges, dtype=np.float32)
320-
id1 = np.empty(n_edges, dtype=np.uint64)
321-
id2 = np.empty(n_edges, dtype=np.uint64)
322-
# Invert OneMinus/One255Minus scores to raw affinities.
323-
score_max = 255.0 if is_uint8 else 1.0
324-
for idx, edge in enumerate(region_graph):
325-
rg_affs[idx] = score_max - float(edge["score"])
326-
id1[idx] = int(edge["u"])
327-
id2[idx] = int(edge["v"])
328-
if n_edges:
329-
np.clip(rg_affs, 0.0, score_max, out=rg_affs)
330-
order = np.argsort(rg_affs)[::-1]
331-
rg_affs = np.ascontiguousarray(rg_affs[order])
332-
id1 = np.ascontiguousarray(id1[order])
333-
id2 = np.ascontiguousarray(id2[order])
334-
ids, cnts = np.unique(seg, return_counts=True)
335-
max_id = int(ids.max()) if len(ids) else 0
336-
counts = np.zeros(max_id + 1, dtype=np.uint64)
337-
counts[ids] = cnts
338-
waterz.merge_segments(
339-
seg, rg_affs, id1, id2, counts,
328+
waterz.merge_dust(
329+
seg, affs,
340330
size_th=dust_merge_size,
341331
weight_th=dust_merge_affinity,
342332
dust_th=dust_remove_size,
333+
scoring_function=dust_scoring,
343334
)
344335
# Branch merge: resolve false splits via z-slice IOU analysis
345336
if branch_merge:

tests/unit/test_decode_waterz.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,26 @@ class _FakeWaterzModule:
1111
"""Minimal waterz stub for testing wrapper behavior."""
1212

1313
def __init__(self):
14-
self.merge_segments_calls = []
14+
self.merge_dust_calls = []
1515
self.waterz_calls = []
1616

1717
def waterz(self, affs, thresholds, **kwargs):
1818
self.waterz_calls.append(kwargs.copy())
1919
seg = np.zeros(affs.shape[1:], dtype=np.uint64)
2020
seg[:, :, :2] = 1
2121
seg[:, :, 2:] = 2
22-
if kwargs.get("return_region_graph", False):
23-
# ScoredEdge dicts from extractRegionGraph.
24-
# OneMinus score 0.2 → affinity = 1.0 - 0.2 = 0.8
25-
rg = [{"u": 1, "v": 2, "score": 0.2}]
26-
return [(seg.copy(), list(rg)) for _ in thresholds]
2722
return [seg.copy() for _ in thresholds]
2823

29-
def merge_segments(self, seg, rg_affs, id1, id2, counts,
30-
size_th, weight_th, dust_th):
31-
self.merge_segments_calls.append(
24+
def merge_dust(self, seg, affs, size_th, weight_th, dust_th,
25+
scoring_function, channels="all"):
26+
self.merge_dust_calls.append(
3227
{
3328
"seg_shape": seg.shape,
34-
"rg_affs": rg_affs.tolist(),
35-
"id1": id1.tolist(),
36-
"id2": id2.tolist(),
37-
"counts": counts.tolist(),
29+
"aff_shape": affs.shape,
3830
"size_th": size_th,
3931
"weight_th": weight_th,
4032
"dust_th": dust_th,
33+
"scoring_function": scoring_function,
4134
}
4235
)
4336

@@ -64,14 +57,13 @@ def test_decode_waterz_skips_dust_postprocessing_when_disabled(monkeypatch):
6457
"scoring_function": "OneMinus<HistogramQuantileAffinity<RegionGraphType, 50, ScoreValue, 256>>",
6558
"aff_threshold_low": 0.0001,
6659
"aff_threshold_high": 0.9999,
67-
"return_region_graph": False,
6860
}
6961
]
70-
assert fake_waterz.merge_segments_calls == []
62+
assert fake_waterz.merge_dust_calls == []
7163

7264

73-
def test_decode_waterz_reuses_agglomeration_region_graph_for_dust(monkeypatch):
74-
"""Dust merge reuses agglomeration's region graph with inverted scores."""
65+
def test_decode_waterz_dust_merge_uses_same_scoring_function(monkeypatch):
66+
"""Dust merge rebuilds graph with same scoring as agglomeration (OneMinus stripped)."""
7567
fake_waterz = _FakeWaterzModule()
7668
monkeypatch.setattr(waterz_decoder, "waterz", fake_waterz)
7769
monkeypatch.setattr(waterz_decoder, "WATERZ_AVAILABLE", True)
@@ -81,6 +73,7 @@ def test_decode_waterz_reuses_agglomeration_region_graph_for_dust(monkeypatch):
8173
waterz_decoder.decode_waterz(
8274
predictions,
8375
thresholds=0.4,
76+
merge_function="aff85_his256",
8477
dust_merge=True,
8578
dust_merge_size=100,
8679
dust_merge_affinity=0.3,
@@ -89,21 +82,40 @@ def test_decode_waterz_reuses_agglomeration_region_graph_for_dust(monkeypatch):
8982

9083
assert fake_waterz.waterz_calls == [
9184
{
92-
"scoring_function": "OneMinus<HistogramQuantileAffinity<RegionGraphType, 50, ScoreValue, 256>>",
85+
"scoring_function": "OneMinus<HistogramQuantileAffinity<RegionGraphType, 85, ScoreValue, 256>>",
9386
"aff_threshold_low": 0.0001,
9487
"aff_threshold_high": 0.9999,
95-
"return_region_graph": True,
9688
}
9789
]
98-
assert fake_waterz.merge_segments_calls == [
90+
assert fake_waterz.merge_dust_calls == [
9991
{
10092
"seg_shape": (4, 4, 4),
101-
"rg_affs": [0.800000011920929],
102-
"id1": [1],
103-
"id2": [2],
104-
"counts": [0, 32, 32],
93+
"aff_shape": (3, 4, 4, 4),
10594
"size_th": 100,
10695
"weight_th": 0.3,
10796
"dust_th": 50,
97+
"scoring_function": "HistogramQuantileAffinity<RegionGraphType, 85, ScoreValue, 256>",
10898
}
10999
]
100+
101+
102+
def test_decode_waterz_dust_merge_strips_one255minus(monkeypatch):
103+
"""One255Minus wrapper is also stripped for dust merge scoring."""
104+
fake_waterz = _FakeWaterzModule()
105+
monkeypatch.setattr(waterz_decoder, "waterz", fake_waterz)
106+
monkeypatch.setattr(waterz_decoder, "WATERZ_AVAILABLE", True)
107+
108+
predictions = np.ones((3, 4, 4, 4), dtype=np.float32)
109+
110+
waterz_decoder.decode_waterz(
111+
predictions,
112+
thresholds=0.4,
113+
merge_function="aff50_his256_ran255",
114+
dust_merge=True,
115+
dust_merge_size=100,
116+
dust_merge_affinity=0.3,
117+
dust_remove_size=50,
118+
)
119+
120+
call = fake_waterz.merge_dust_calls[0]
121+
assert call["scoring_function"] == "HistogramQuantileAffinity<RegionGraphType, 50, ScoreValue, 256>"

0 commit comments

Comments
 (0)