diff --git a/.dockerignore b/.dockerignore index 66349c8a4..09e319073 100644 --- a/.dockerignore +++ b/.dockerignore @@ -106,6 +106,7 @@ venv.bak/ # Visual Code .vscode/ +*.code-workspace # terraform .terraform/ diff --git a/.gitignore b/.gitignore index 044d3b64c..e39b39890 100644 --- a/.gitignore +++ b/.gitignore @@ -115,6 +115,7 @@ venv.bak/ # local dev stuff +*.code-workspace .claude/ .devcontainer/ *.ipynb diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 5f28ab80b..434248a9b 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -5,19 +5,26 @@ steps: args: ["-c", "docker login --username=$$USERNAME --password=$$PASSWORD"] secretEnv: ["USERNAME", "PASSWORD"] + # Build + push in one BuildKit invocation using a docker-container + # builder (required for registry-type cache export). The builder + # `--use` setting is client-side and doesn't persist across cloudbuild + # steps, so create + use + build must happen in a single step. + # Registry cache at the fixed :buildcache tag lets unchanged stages + # (conda env, bigtable emulator, pip install) reuse the previous + # build's exact layer artifacts, so already-warm nodes only download + # what actually changed on pull. - name: "gcr.io/cloud-builders/docker" entrypoint: "bash" args: - "-c" - | - DOCKER_BUILDKIT=1 docker build -t $$USERNAME/pychunkedgraph:$TAG_NAME . - timeout: 600s - secretEnv: ["USERNAME"] - - # Push the final image to Dockerhub - - name: "gcr.io/cloud-builders/docker" - entrypoint: "bash" - args: ["-c", "docker push $$USERNAME/pychunkedgraph:$TAG_NAME"] + docker buildx create --use --name pcg-builder --driver docker-container + docker buildx build \ + --cache-from type=registry,ref=$$USERNAME/pychunkedgraph:buildcache \ + --cache-to type=registry,ref=$$USERNAME/pychunkedgraph:buildcache,mode=max \ + --push \ + -t $$USERNAME/pychunkedgraph:$TAG_NAME . + timeout: 1800s secretEnv: ["USERNAME"] availableSecrets: diff --git a/docs/precomputed_ocdbt_hybrid.md b/docs/precomputed_ocdbt_hybrid.md new file mode 100644 index 000000000..d1caafd5f --- /dev/null +++ b/docs/precomputed_ocdbt_hybrid.md @@ -0,0 +1,101 @@ +# Hybrid base: precomputed + OCDBT fork (proposal) + +Status: proposal, not implemented. Open question is whether storage and ingest-compute savings justify the read-path complexity. + +## Problem + +PCG ingest copies the entire watershed segmentation into `/ocdbt/base/` in OCDBT format before any CG edit can happen. Per-CG forks at `/ocdbt//` store only the deltas from SV splits. Two costs follow: + +- **Storage**: roughly 2× the segmentation footprint per dataset — original precomputed plus full OCDBT copy. +- **Ingest compute**: a per-chunk pass that reads the precomputed and writes it through the OCDBT driver. Hours of cluster time on TB-scale datasets. + +Both costs are paid up-front, before any user has done a single edit. The proposal here: skip the base copy and serve unedited chunks directly from the raw precomputed directory. Per-CG OCDBT forks remain as the delta store. + +## Why the current architecture has the base copy + +Today's per-CG read spec is: + +``` +neuroglancer_precomputed + └─ kvstore: ocdbt + ├─ base: kvstack [base_layer, fork_manifest, fork_data] + └─ config: { compression, max_inline_value_bytes, ... } +``` + +When a reader asks for chunk key `8_8_40/1024-..._0-128`: + +1. The `neuroglancer_precomputed` driver passes the chunk key to its kvstore (the OCDBT driver). +2. OCDBT looks up the key **in its B+tree**. The B+tree's leaves map chunk keys to values. +3. If the key isn't in the B+tree, OCDBT returns not-found. It does not consult the kvstack any further. + +The three kvstack layers serve OCDBT's *internal* storage (B+tree manifest + node blobs + leaf blobs) — they have no visibility into chunk-key lookups. So the OCDBT B+tree must contain every chunk key the reader will ever ask for, and that's why ingest copies the whole watershed: to populate the B+tree. + +## What tensorstore primitives provide + +Confirmed against tensorstore docs: + +- **`kvstack` routes by exact / prefix match, with no fallthrough on miss.** A layer that claims a key range absorbs misses — they return `state='missing'` and do not cascade to the next layer. So we can't put raw precomputed below an OCDBT layer in a kvstack and expect kvstack to fall through when OCDBT doesn't have a key. +- **No native overlay/fallback kvstore driver.** `kvstack` is the only composition primitive at the kvstore level; it's precedence-based, not fallthrough. +- **OCDBT has no external-blob references.** B+tree leaves either inline the value or point to a data file under the OCDBT directory. There's no way to make a leaf reference a raw GCS precomputed file. +- **Array-level `stack` / `ts.overlay`** layers arrays by spatial domain. In overlapping regions, the later layer takes absolute precedence — missing-in-later does not fall back to earlier. + +No single tensorstore primitive provides "try OCDBT delta first, fall through to raw precomputed on miss." + +## Architectural options + +### A — Two-stage read at the pcg layer + +PCG reads open two handles: the OCDBT fork for the delta, and a raw `neuroglancer_precomputed` reader for the watershed base. For any voxel region, issue both reads and merge with "delta wins where present, base fills the rest." + +- **Pros**: works inside pcg (`lookup_svs_from_seg`, sanity checks, debug tools) without any tensorstore changes. +- **Cons**: every pcg caller that uses `meta.ws_ocdbt` needs to route through a new merging reader. Neuroglancer doesn't benefit — it still gets a single kvstore spec from `dataset_info`. Either NG runs two layers itself (Option B) or we stand up a server-side proxy that does the merge before serving. + +### B — NG-side layer stack + +`dataset_info` publishes two precomputed layers: the raw watershed (read-only base) and the per-CG OCDBT fork (delta). NG composites them — visible segmentation is whichever has data at a given chunk. + +- **Pros**: no change to pcg's read path. Pushes the architecture complexity into the viewer. +- **Cons**: requires NG to treat "missing chunk in delta" as "fall through to base," not "render as background." Default NG behavior is the latter, so a viewer-side or proxy-side shim is likely needed. + +### C — Custom tensorstore kvstore driver + +A new "fallthrough" kvstore driver: read tries layer N, falls through on miss to layer N−1. Implement upstream in tensorstore or fork-and-maintain. + +- **Pros**: cleanest consumer-facing story — pcg and NG both keep using a single kvstore spec. +- **Cons**: tensorstore kvstore drivers are C++. Non-trivial maintenance surface; review/merge timeline if upstreaming. + +### D — Lazy base population (not a win on its own) + +Skip the ingest copy; copy a chunk from precomputed to OCDBT on first edit. Saves ingest compute. Does **not** save storage for reads — unedited chunks still 404 in OCDBT for a reader that doesn't have a fallback. Only useful in combination with A/B/C. + +## Recommendation + +Measure first. Confirm the actual storage and ingest-compute savings on a real dataset and weigh against the engineering cost of A/B/C. + +If the savings justify the work, **A + B together** is the most pragmatic path: +- A gives pcg a single merged-read API. Edits, sanity checks, debug tooling keep working. +- B avoids standing up a proxy service for the viewer by letting NG handle the overlay. + +Both require upstream verification: +- **For A**: confirm that `(x0:x1, y0:y1, z0:z1)` reads on an OCDBT with sparse keys surface missing-ness *per chunk* at the `neuroglancer_precomputed` array layer (not per-region, not silently fill-valued). +- **For B**: confirm NG's segmentation loader can be configured to fall through gaps in one layer to another. If it can't, build a small server-side merging shim — at which point Option A's reader becomes that shim and B reduces to "publish two specs." + +C is the cleanest design but carries the highest cost. Pursue only if A/B turn out to have unworkable semantics. + +## Open questions before any implementation + +1. Does OCDBT's `read_result.state == 'missing'` surface per-chunk at the `neuroglancer_precomputed` array layer, or does the array silently fill missing chunks with fill-value? Verifiable by opening an OCDBT with sparse keys and reading a region that spans present + missing chunks. +2. Does NG distinguish "chunk returned as missing" from "chunk is all fill-value"? If not, a viewer-side overlay needs a shim regardless. +3. What's the actual delta volume per CG over its lifetime? If SV splits eventually touch a significant fraction of chunks, the storage win shrinks toward zero — at which point the simpler architecture (today's full base copy) wins on engineering cost. + +## Files to start from when implementing + +- `pychunkedgraph/graph/ocdbt.py` — spec construction (`build_cg_ocdbt_spec`), base population (`create_base_ocdbt`), fork setup (`fork_base_manifest`). +- `pychunkedgraph/ingest/cli.py`, `pychunkedgraph/ingest/cluster.py` — current base-copy flow. +- `pychunkedgraph/graph/utils/generic.py::get_local_segmentation` — single pcg read entry point that would need the two-stage merge in Option A. + +## Verification (per chosen option) + +- **A**: unit test that simulates a partial-delta OCDBT + raw precomputed and confirms the pcg reader returns the correct labels for spans crossing both. +- **B**: configure an NG link with both layers against a test dataset; compare the rendered segmentation to a known-good reference at edited and unedited regions. +- **C**: a tensorstore build with the new driver passes a fallthrough test (missing key in upper layer resolves from lower layer). diff --git a/docs/sv_splitting.md b/docs/sv_splitting.md new file mode 100644 index 000000000..d6b56a94b --- /dev/null +++ b/docs/sv_splitting.md @@ -0,0 +1,140 @@ +# Supervoxel splitting + +## What it is + +A *supervoxel split* bisects one physical supervoxel — a connected region in the raw segmentation — along a user-seeded cut. The user supplies a source coordinate and a sink coordinate inside one supervoxel; the system finds a cut surface separating them and assigns new supervoxel IDs to each half, writing the updated segmentation and the corresponding graph hierarchy. + +This only runs on segmentations stored in OCDBT (a writable, append-only segmentation backend). With a read-only segmentation backend the split path is never entered; the multicut instead surfaces a precondition error asking the user to pick different source/sink points. + +## Why it's needed + +The graph is stored in **chunks**: the segmentation volume is partitioned into a regular 3D grid, and each chunk owns its own set of supervoxel IDs. When a physical supervoxel spans a chunk boundary, it is artificially cut into multiple graph-level supervoxel IDs — one per chunk — with infinite-affinity *cross-chunk edges* connecting the pieces so the graph still represents one physical object. + +The multicut algorithm runs on a local graph around source and sink. If it finds that source and sink sit inside the same cross-chunk-connected component — i.e., in the same physical supervoxel — a clean graph cut cannot separate them without first **splitting that physical supervoxel at the voxel level** and giving the resulting halves fresh IDs. That voxel-level cut is what the split flow does. The multicut runs again against the refreshed graph and produces the graph-level edges to remove. + +## End-to-end flow + +``` +Split request (source coord, sink coord) + │ + ▼ +Resolve coords → current supervoxel IDs at those pixels + │ + ▼ +┌───────────────────────────────────────────────────────────────────────┐ +│ ROOT LOCK (held across the whole operation) │ +│ │ +│ multicut: │ +│ build local subgraph around source/sink │ +│ stitch cross-chunk-connected SVs via inf-affinity edges │ +│ run mincut between source and sink │ +│ result ─► one of: │ +│ ● clean cut → edges to remove │ +│ ● SV split needed → cross-chunk-representative mapping │ +│ ● precondition → surface to user, abort │ +│ │ +│ if SV split needed: │ +│ ┌───────────────────────────────────────────────────────────────┐ │ +│ │ L2 CHUNK LOCK (spatial; sparse set + 1-chunk margin) │ │ +│ │ │ │ +│ │ for each cross-chunk rep linking source to sink: │ │ +│ │ bbs/bbe ◄ envelope of rep's pieces' chunk coords │ │ +│ │ read seg in [bbs-1, bbe+1] │ │ +│ │ (1-voxel shell → anchor voxels for edge routing) │ │ +│ │ compute voxel-level cut between seeds │ │ +│ │ allocate fresh SV IDs per chunk to each half │ │ +│ │ route existing cross-chunk edges onto the new fragments │ │ +│ │ write seg — only chunks that actually received new IDs │ │ +│ │ write hierarchy rows (lineage + new cross-chunk edges) │ │ +│ └───────────────────────────────────────────────────────────────┘ │ +│ │ +│ refresh source/sink IDs: │ +│ look up the new IDs in the in-memory split output │ +│ (bit-identical to what just landed on storage; no extra read) │ +│ │ +│ multicut (retry against post-split graph): │ +│ result ─► clean cut → edges to remove │ +│ │ still-split-needed → surface precondition error │ +│ │ +│ commit the cut: │ +│ remove graph-level edges │ +│ produce new roots │ +│ write hierarchy rows + operation log │ +└───────────────────────────────────────────────────────────────────────┘ + │ + ▼ +Release root lock — edit is durable + │ + ▼ +Publish pubsub message; when an SV split ran it carries the list of +base-resolution bounding boxes that were rewritten + │ + ▼ +┌───────────────────────────────────────────────────────────────┐ +│ Async downsample worker │ +│ │ +│ partition each published bbox into pyramid blocks │ +│ (cube regions aligned to the coarsest MIP's chunk grid; │ +│ two distinct blocks never share a storage chunk at │ +│ any MIP level) │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ PYRAMID BLOCK LOCK (separate lock family from L2) │ │ +│ │ │ │ +│ │ for each pyramid block: │ │ +│ │ read base resolution │ │ +│ │ downsample through every coarser MIP │ │ +│ │ write only tiles whose footprint intersects │ │ +│ │ a published bbox │ │ +│ └──────────────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────────────────┘ +``` + +### Notes on the flow + +- **"SV split required" is a return value, not an exception.** The multicut returns one of several tagged outcomes so the caller dispatches with a straight branch. Nothing uses raise/catch for control flow, which is what allows the root lock to stay held across the detect-then-split-then-commit sequence without the exception unwinding the lock. + +- **The cross-chunk-representative mapping** comes out of the multicut for free: as part of building its local graph it stitches every cross-chunk-connected group of graph-level supervoxels into one node and records the mapping. That map tells the split step which supervoxels are artificially-cut pieces of one physical SV, and which of them sit on a source→sink bridge. + +- **The split is per-representative.** If two unrelated physical supervoxels both need splitting in one edit (rare but possible), each is handled in its own pass under the same L2 chunk lock. + +## Concurrency design + +Two races exist at the segmentation layer even with root locks in place: + +- **Same-root race.** Without care, the root lock could drop between "detect split needed" and "perform split", letting another edit on the same root slip in and race for the same supervoxel pieces. +- **Cross-root spatial race.** Two edits on entirely distinct roots can target supervoxels whose pieces live in overlapping chunks. Root locks don't serialize them; segmentation writes would clobber each other. + +The split flow closes both: + +- **Root lock scope covers the full operation.** Detection, supervoxel-level split, retry detection, commit — all under one root lock. Same-root interleaving is impossible; any other edit on the root waits for this one to finish. + +- **L2 chunk lock covers the supervoxel-level split only.** Inside the root lock, the split step additionally acquires a spatial lock on every L2 chunk it will read or write. Keyed by chunk, so edits on different roots but overlapping chunks serialize here. Released as soon as the split writes land; the graph-level commit afterwards runs under the root lock alone. + +### How the spatial lock set is computed + +For each cross-chunk representative being split, take the envelope of the chunks its cross-chunk-connected pieces live in — the exact chunks whose voxels will be rewritten. Expand by one voxel (= at most one L2 chunk of margin in each direction), because the edge-routing step reads a 1-voxel shell outside the rewritten region to see neighboring supervoxels' labels. Union the per-representative chunk sets, sort deterministically so workers with overlapping sets never acquire in opposite orders, lock once. + +The envelope comes from the supervoxels' own chunk coordinates — no coordinate padding, no resolution-axis assumption. The chunks locked are exactly the chunks the split will touch, plus the 1-chunk margin the shell read requires. + +### How the write scope is kept minimal + +Only chunks that actually receive new supervoxel IDs get written to storage. Gap chunks that happen to sit inside an envelope but contain no cross-chunk-connected pieces, and neighbor chunks read only for the edge-routing shell, are never written. The segmentation backend is append-only, so writing unchanged bytes would inflate the on-disk delta for no real change. Writing exactly the changed chunks keeps the delta proportional to the user's edit. + +### Why the post-split ID refresh is safe without an extra read + +After the split lands, the caller-supplied source and sink supervoxel IDs reference now-superseded supervoxels. The retry multicut needs the *current* IDs at the source and sink pixels — the subgraph fetch returns only live supervoxels, so a mincut asking about superseded ones would fail to find its endpoints. + +The in-memory segmentation block produced by the split is bitwise identical to what was just written to storage, and the storage write is synchronous (we wait for it) and happens under the L2 chunk lock (so nothing else can have mutated those voxels). Looking up source/sink coords in that block returns the same IDs a storage re-read would — no extra round-trip needed. + +### Worker crash mid-write + +A worker that dies — or raises from the persist block — inside the indefinite L2 chunk lock's scope leaves the lock cells set and the op-log row's `L2ChunkLockScope` populated with the exact chunks being written. Future ops on any of those chunks refuse to start — the crashed state is isolated, not amplified. An operator runs the recovery flow described in [sv_splitting_recovery.md](sv_splitting_recovery.md) to revert the partial writes and replay the op. + +## Invariants + +- A supervoxel split and its graph-level commit are one atomic operation. Either both land or neither does, under a single root lock. +- Within the supervoxel-split step, concurrent splits on overlapping L2 chunks serialize. No two operations write segmentation to the same chunk at the same time. +- Supervoxel-level writes touch only chunks whose voxels actually changed. Gap chunks between cross-chunk-connected pieces and neighbor chunks read for edge routing are untouched. +- After the commit, readers at the operation's timestamp see new supervoxel IDs in the cut region and new roots reflecting the cut. +- Coarser MIP levels are eventually consistent with the base scale, lagging at most until the async downsample worker processes the operation's pubsub message. diff --git a/docs/sv_splitting_edges.md b/docs/sv_splitting_edges.md new file mode 100644 index 000000000..d19bc08e8 --- /dev/null +++ b/docs/sv_splitting_edges.md @@ -0,0 +1,132 @@ +# Edge updates after a supervoxel split + +## Context + +A supervoxel split rewrites voxels inside a bbox: a single old SV is replaced by N new fragments (one per chunk × per side of the cut). Every atomic edge that referenced the old SV — to neighbors inside the same root, to neighbors in a different root, and to other pieces of the same physical supervoxel — must now reference an appropriate fragment instead, or the graph hierarchy diverges from the new segmentation. + +Edge update is the second half of `split_supervoxel`. The first half produced a labeled bbox, an `old_new_map` (`old_sv_id → set[new_sv_ids]`), and a `new_id_label_map` (`new_sv_id → cut-side label`). This document covers what happens from there. + +## Algorithm overview + +``` +inputs from voxel-level split + ├─ new_seg bbox volume with new SV IDs in place of old + ├─ old_new_map which old SVs got split, and into which new IDs + └─ new_id_label_map for each new ID, which side of the cut it's on + +update_edges (edges_sv.py): + 1. fetch atomic subgraph inside bbox, rooted at the rep's root + 2. dedupe edges, drop self-loops + 3. group by partner-root vs split-root → active / inactive + 4. for each old SV: + inactive partners → broadcast edge to every fragment + active partners → expand split partners, match by label/proximity + intra-fragment → low-affinity edges between every fragment pair + 5. validate (no cross-label inf bridges, no self-loops, completeness) + 6. return new (edges, affinities, areas) + +add_new_edges (edges_sv.py): + 1. duplicate bidirectional, group by L2 parent chunk + 2. per chunk: append to SplitEdges (history) and rewrite + CompactedSplitEdges (snapshot, with stale rows filtered) +``` + +## Inputs to `update_edges` + +- `cg, root_id, bbox` — the rep's root and the bbox the voxel-level cut acted on. +- `new_seg` — segmentation in the read window (bbox + 1-voxel shell). The shell is what makes anchor lookups work for unsplit pieces of the rep on the other side of a chunk boundary; without it, cross-chunk edges from those pieces would route to whatever happens to lie at the boundary face, not to the actual fragment the cut produced. +- `old_new_map` — drives which edges need re-routing. +- `new_id_label_map` — used to pair fragments with the same cut-side label across cross-chunk edges. + +`update_edges` calls `cg.get_subgraph(root_id, bbox, bbox_is_coordinate=True)`. This returns every atomic edge whose endpoint sits in the bbox under the rep's root. That set already includes both intra-cut edges (between split SVs) and the cross-chunk-shell edges to neighbors outside the rewritten region. + +After fetch, edges are sorted within each pair, deduped, and self-loops filtered. The remaining set is the input to classification. + +## Classification + +For each edge, the partner's root determines the routing path. `sv_root_map` is built from one batched `cg.get_roots(...)` over all unique partners. + +### Inactive partner (`partner_root != root_id`) + +The partner sits in a different agglomerated object. The split's cut-side has no semantic relationship to that neighbor — *any* fragment of the old SV that touched the neighbor's voxels still touches them after the split. **Broadcast**: for each old SV split into N fragments, copy the edge to every fragment, preserving affinity and area. + +This intentionally over-creates edges. They cost nothing if both endpoints stay in different roots forever; they collapse harmlessly into a single root-level edge if the two roots later merge. + +### Active partner (`partner_root == root_id`) + +The partner is inside the same agglomerated object as the rep — the partner is either: + +- another piece of the same physical SV (cross-chunk-connected), +- a different SV in the same root reachable via L2 hierarchy. + +For active partners, edges are routed based on affinity type: + +#### Inf-affinity, partner also split + +The partner SV is itself in `old_new_map` (e.g. it's another piece of the rep that was rewritten). We need each new fragment of the old SV to connect to the *matching-label* new fragment of the partner — the one on the same cut-side. `_match_by_label` does this lookup via `new_id_label_map`. If no fragment of the partner shares the source SV's label (rare, indicates a partial split), fallback to the closest fragment by distance. + +#### Inf-affinity, partner unsplit + +This is the cross-chunk edge to a piece of the rep that the bbox didn't include — by construction with the seed-driven bbox, these are the rep's far-away pieces that keep their old IDs. The unsplit partner has no `new_id_label_map` entry. + +**Critical**: do *not* broadcast this edge to all fragments. An unsplit partner connected via inf-affinity to fragments on both sides of the cut would form an uncuttable bridge — a future mincut on this object would route through `frag_a → unsplit_partner → frag_b` with infinite affinity and never separate them. So `_match_inf_unsplit` assigns the edge to exactly one fragment: the one closest to the partner. + +`validate_split_edges` enforces this with check (A): no inf-affinity edge from an unsplit partner to fragments with different cut-side labels. + +#### Finite-affinity (regular) + +Real adjacency edges between SVs based on per-pair affinity. `_match_by_proximity` assigns the edge to *every* fragment within `cg.meta.sv_split_threshold` voxels of the partner, fallback to closest if none qualify. Multiple fragments may legitimately neighbor the partner; the threshold preserves the original adjacency where it actually exists. + +### Intra-fragment edges + +For each old SV split into multiple new fragments, add a low-affinity (0.001) edge between every pair of fragments. These are cuttable by future mincut operations — they record that the fragments share a graph-level neighborhood (they came from the same SV) without forcing them to stay agglomerated. Without these edges, an entirely-disconnected fragment of an old SV would have no link to the rest of the object; with them, the standard mincut machinery handles the relationship correctly. + +## Distance computation + +Distances drive both proximity matching and the closest-fragment fallback. Each fragment gets a `cKDTree` over its voxels (built from `build_coords_by_label(new_seg)`). + +- **Partner inside bbox**: build a kdtree on the partner's voxels too. For each fragment, the smaller-tree-queries-larger heuristic minimizes work; result is the minimum voxel distance. +- **Active partner outside bbox**: the partner's voxels aren't in `new_seg`. `_compute_boundary_distances` uses the partner's chunk coordinate to determine which face of the source chunk the edge crosses, then measures each fragment's distance to that boundary plane. This is an over-estimate for non-boundary-aligned partners but it's the only signal available without extra reads. + +## Validation + +`validate_split_edges` checks four invariants and raises `PostconditionError` on any violation. Failures abort the operation cleanly under the indefinite L2 chunk lock; the recovery flow then handles cleanup. + +| Check | Why | +|-------|-----| +| (A) No inf-affinity bridges between cut-sides via an unsplit partner | Would be uncuttable by future mincuts | +| (B) No self-loops | Indicates a routing bug; would skew degree counts and break some traversal assumptions | +| (C) Every old SV has at least one replacement edge from its fragments | Catches old SVs that vanished from the edge set entirely (would orphan them in the hierarchy) | +| (D) All fragment pairs of each old SV are connected | Confirms the intra-fragment low-affinity edges were emitted | + +These run before any bigtable write, so the validation is the last line of defense before the writes commit under the lock. + +## Persisting: `add_new_edges` + +The new edges are batched into bigtable per L2 chunk. Two columns get written per chunk per op: + +### `SplitEdges` (history) + +An append-only column. Each split op writes its new edges as a fresh cell with the op's logical timestamp. Time-travel reads at any timestamp T walk all cells with `ts ≤ T`, then apply the stale-edge resolution path to filter out edges whose endpoints have been superseded by later ops. This is the authoritative store for historical reads. + +### `CompactedSplitEdges` (snapshot) + +A latest-only column for fast current-time reads. On each op: + +1. Read the previous compacted cell (if any) plus its matching `CompactedAffinity` and `CompactedArea`. +2. Filter out rows whose endpoints reference any old SV in `old_new_map.keys()` (these are the SVs that just got split — their edges are stale). +3. Concatenate the new rows. +4. Write the whole thing as one fresh cell. + +Current-time readers can take this single cell directly without history walks or stale-edge resolution. + +The chunk grouping uses each edge's first endpoint's L2 parent chunk: `cg.get_chunk_ids_from_node_ids(cg.get_parents(nodes))`. Parent chunks (not the SV's own L1 chunk) is the correct routing — the edge belongs to the chunk where its endpoint lives in the L2 hierarchy. Bidirectional duplication ensures every edge is owned by both endpoints' chunks; readers picking up either side find it. + +Both writes use `time_stamp=task.operation_ts`, so all rows from one op land at the same logical time. Concurrent SV-splits on disjoint chunks don't interfere because they write disjoint chunk rows. + +## Invariants + +- For every old SV in `old_new_map`, every atomic edge that referenced it in the pre-split graph has at least one corresponding edge among its fragments after the split. +- No inf-affinity edge crosses cut-sides through an unsplit partner. +- Every cross-chunk piece of the rep that the bbox didn't include keeps its old ID and its existing edges resolve unchanged (because no edge in those rows references the now-split SVs at endpoints — the routing only touches edges whose endpoints are in the bbox or its 1-voxel shell). +- `SplitEdges` and `CompactedSplitEdges` agree at the latest timestamp: the compacted snapshot is the result of replaying the history through the stale-edge filter. diff --git a/docs/sv_splitting_recovery.md b/docs/sv_splitting_recovery.md new file mode 100644 index 000000000..cc9efaa57 --- /dev/null +++ b/docs/sv_splitting_recovery.md @@ -0,0 +1,66 @@ +# Supervoxel split recovery + +## What it is + +A recovery path for supervoxel-split operations that crash mid-write, leaving partial state in segmentation and per-chunk locks held indefinitely. An operator runs a one-shot command that reverts the crashed op's partial segmentation writes and re-runs the op from scratch, producing a clean successful edit and freeing the affected chunks for future work. + +## When it applies + +Every supervoxel split writes two things atomically from the point of view of the operation: +- Segmentation chunks — the voxel-level split, where each L2 chunk touched by the split gets fresh supervoxel IDs at the voxels that moved to a new fragment. +- Graph hierarchy rows — lineage from the old supervoxels to the new ones plus the cross-chunk edges linking the new fragments. + +Both writes happen under an indefinite L2 chunk lock covering the exact chunks being rewritten. If the worker running the op dies before the lock's context exits — process kill, pod eviction, hardware failure, OOM — or raises a caught exception inside the persist block, the lock stays held and the op-log row's `L2ChunkLockScope` stays populated with the affected chunk IDs. From that moment on, any new op whose chunk set overlaps the stuck op's chunks refuses to start, blocking further corruption. + +The authoritative signal that an op is stuck is `L2ChunkLockScope` being non-empty — the clean `__exit__` path clears it on success. Either a crash (`Status=CREATED`, exit never ran) or a caught exception (`Status=FAILED`, held-cells-on-exception path) keeps the scope set. + +The operator runs recovery when the lock has been held long enough that the worker is definitively gone, not merely slow. A minimum-age threshold (10 minutes is a reasonable default) distinguishes stuck ops from ops still in flight. + +## Concurrent edits on other regions keep working + +The indefinite lock is per L2 chunk. While op X is stuck on chunks `{C1, C2, C3}`, another op Y on chunks `{C4, C5}` sees no indefinite cell on its chunks and proceeds normally. Its writes advance the latest OCDBT manifest. By the time the operator gets to recovery, the manifest has moved past the stuck op's `OperationTimeStamp` and other regions of segmentation reflect Y's (and any subsequent ops') work. + +This is important: the recovery must not undo Y's changes. It also cannot rely on a single "read pre-op segmentation" pin because that would return pre-Y state outside the stuck op's own chunks, and the replay would overwrite neighbor state with stale values. + +## Why pre-op pinning is not enough on its own + +A supervoxel split reads more than just its own chunks. To route existing cross-chunk edges onto the new fragments, the split reads a one-voxel shell around its chunk envelope — supervoxel IDs from neighboring chunks serve as anchors for the re-routed edges. + +If the replay opened segmentation with a pin at the stuck op's `OperationTimeStamp` and then read that shell, the neighbor voxels would show their pre-op state, not their current state. If Y had split a supervoxel in one of those neighbor chunks in the interim, the pinned read would see the old neighbor IDs, and the replay would route its cross-chunk edges to supervoxel IDs that no longer exist. Graph corruption. + +So the replay cannot read the world through a single pinned view. It must see latest state for the neighbor shell and clean pre-op state for its own chunks. + +## Cleanup-then-replay + +Recovery proceeds in two steps. + +**Cleanup.** For each chunk in the stuck op's durably-recorded scope, the operator reads the chunk's pre-op voxel values from a segmentation handle pinned at the op's `OperationTimeStamp`, and writes those values back to the latest (unpinned) handle. The result: those chunks, at the latest manifest, now show pre-op segmentation — as if the crashed op had never started. Neighbor chunks and every chunk outside the stuck op's scope are untouched, so any concurrent op's work is preserved. + +**Replay.** The operator then re-runs the op under the privileged-repair path. The run reads latest state, which is now consistent — clean pre-op values on the stuck op's own chunks, current state everywhere else — and goes through the normal edit flow. It allocates fresh supervoxel IDs, re-computes the split, writes new segmentation and hierarchy, and lands the op-log row at `SUCCESS`. + +When the replay's indefinite lock context exits, it issues value-matched releases on every chunk in the scope. Because the replay re-uses the crashed op's operation ID, the value match succeeds and the pre-existing indefinite cells are deleted. The chunks are free for new ops again. + +## Orphans in segmentation history + +OCDBT is append-only. The crashed op's partial segmentation writes still exist in the store's commit history — they are not deleted, only overshadowed. At the latest manifest, the cleanup step has overwritten them with pre-op values, so a normal (unpinned) read returns the pre-op state and the replay's fresh writes take effect on top of that. Readers that explicitly pin a historical version between the crash and the replay will still see the partial writes as a snapshot, but readers at latest never observe them. + +The orphan supervoxel IDs allocated by the crashed op are never referenced by any hierarchy row — the crashed op never wrote its hierarchy rows to completion, and the replay allocated a new set of IDs. From the graph's perspective those orphan IDs do not exist. + +## Operator workflow + +1. **List stuck ops.** The operator runs the list command with a minimum-age threshold. It returns op-log rows whose `L2ChunkLockScope` is still populated past that age (excluding any that have reached `SUCCESS`), along with each op's user ID, timestamp, age, status, and the number of chunks in its recorded scope. Ops too young to classify are skipped. + +2. **Inspect.** For each candidate, confirm from logs or monitoring that the worker that submitted the op is definitively dead — not, for example, paused on a long-running multicut. The minimum-age threshold exists to reduce false positives but the operator retains final judgment. + +3. **Replay.** The operator runs the replay command with the op ID. Before any destructive step the replay cross-checks the recorded scope against live lock state: for every chunk in `L2ChunkLockScope` it reads back the `Concurrency.IndefiniteLock` cell and verifies it's held by the op being replayed. Any discrepancy (cell missing, or held by a different op) aborts the replay loudly — a stale scope could otherwise have cleanup revert chunks another op legitimately owns. On clean verification, cleanup reverts the op's partial writes, then the privileged-repair path reruns the op. On success, the op-log row shows `SUCCESS` and the previously-held indefinite lock cells are released. + +4. **Verify.** A second list invocation should no longer include the op. Any new ops that were waiting on the affected chunks proceed. + +If the replay itself fails — for example, the operator's judgment about the worker's status was wrong and the original worker comes back — the replay surfaces the error and leaves the op-log row and lock state as it found them. The operator investigates, potentially clears the lock manually via direct bigtable tools, and tries again. + +## Invariants + +- A stuck op's durable scope record (written before any segmentation or hierarchy write begins) lets recovery locate every chunk that might have received a partial write, without a bigtable-wide scan. +- Cleanup only touches chunks in the stuck op's scope. Neighbor state and any concurrent ops' changes are preserved byte-for-byte. +- The replay sees a consistent world: pre-op values on the stuck op's own chunks (from the cleanup), current state on every other chunk (from the latest manifest). +- After successful replay, the op-log row is at `SUCCESS`, all indefinite cells previously held by that op are released, and the affected chunks are available to new ops. The op's original intent — the edit the user asked for — is realized with a fresh set of supervoxel IDs. diff --git a/pychunkedgraph/app/app_utils.py b/pychunkedgraph/app/app_utils.py index 9d69c3650..888b39ae4 100644 --- a/pychunkedgraph/app/app_utils.py +++ b/pychunkedgraph/app/app_utils.py @@ -229,20 +229,28 @@ def ccs(coordinates_nm_): return ccs coordinates = np.array(coordinates, dtype=int) - coordinates_nm = coordinates * cg.meta.resolution - max_dist_steps = np.array([4, 8, 14, 28], dtype=float) * np.mean(cg.meta.resolution) - - node_ids = np.array(node_ids, dtype=np.uint64) if len(coordinates.shape) != 2: raise cg_exceptions.BadRequest( f"Could not determine supervoxel ID for coordinates " f"{coordinates} - Validation stage." ) - # Fast path: all node_ids are L1 and OCDBT — single seg read for all coords - if cg.meta.ocdbt_seg and np.all(cg.get_chunk_layers(np.unique(node_ids)) == 1): + # OCDBT: always read the current segmentation at the click coords, + # regardless of node_ids layer. + # - 2D slice click: NG sends `node_id` = L1 SV from the slice view. + # That slice can be stale after an SV split; the seg read returns + # the current SV at that voxel (which may have a different root). + # - 3D mesh click: NG sends `node_id` = root; no L1 SV is attached, + # so we have to look it up against current seg anyway. + # `node_ids` are not used as a constraint here. Stale UI surfaces + # downstream as "different roots" with the sv_id->root diagnostic + # mapping added in operation.py / cutting.py. + if cg.meta.ocdbt_seg: return lookup_svs_from_seg(cg.meta, coordinates) + coordinates_nm = coordinates * cg.meta.resolution + max_dist_steps = np.array([4, 8, 14, 28], dtype=float) * np.mean(cg.meta.resolution) + node_ids = np.array(node_ids, dtype=np.uint64) atomic_ids = np.zeros(len(coordinates), dtype=np.uint64) for node_id in np.unique(node_ids): node_id_m = node_ids == node_id diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 0ff758c2d..33a232547 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -2,6 +2,7 @@ import json import os +import pickle import time from datetime import datetime, timezone from functools import reduce @@ -9,8 +10,8 @@ import numpy as np import pandas as pd -import fastremap from flask import current_app, g, jsonify, make_response, request +from messagingclient import MessagingClient from pytz import UTC from pychunkedgraph import __version__, get_logger @@ -25,7 +26,6 @@ exceptions as cg_exceptions, ) from pychunkedgraph.graph.analysis import pathing -from pychunkedgraph.graph.edits_sv import split_supervoxel from pychunkedgraph.graph.misc import get_contact_sites from pychunkedgraph.debug.sv_split import check_unsplit_sv_bridges from pychunkedgraph.graph.operation import GraphEditOperation @@ -322,15 +322,13 @@ def publish_edit( is_priority=True, remesh: bool = True, ): - import pickle - - from messagingclient import MessagingClient - + downsample = bool(result.seg_bbox) attributes = { "table_id": table_id, "user_id": user_id, "remesh_priority": "true" if is_priority else "false", "remesh": "true" if remesh else "false", + "downsample": "true" if downsample else "false", } payload = { "operation_id": int(result.operation_id), @@ -338,6 +336,13 @@ def publish_edit( "new_root_ids": result.new_root_ids.tolist(), "old_root_ids": result.old_root_ids.tolist(), } + if downsample: + # Each entry is the base-resolution bbox of one supervoxel split's + # writes. Kept as a list (not merged) so the worker only rewrites + # tiles whose base footprint actually changed. + payload["seg_bboxes"] = [ + [bbs.tolist(), bbe.tolist()] for bbs, bbe in result.seg_bbox + ] exchange = os.getenv("PYCHUNKEDGRAPH_EDITS_EXCHANGE", "pychunkedgraph") c = MessagingClient() @@ -434,84 +439,6 @@ def _get_sources_and_sinks(cg: ChunkedGraph, data): return (source_ids, sink_ids, source_coords, sink_coords) -def split_with_sv_splits(cg, data, user_id="test", mincut=True): - """Remove edges with automatic supervoxel splitting when needed. - - Attempts remove_edges. If source/sink SVs share a cross-chunk representative, - splits the overlapping SVs in the segmentation and retries. - """ - sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) - logger.note(f"pre-split: sources={sources}, sinks={sinks}") - t0 = time.time() - try: - ret = cg.remove_edges( - user_id=user_id, - source_ids=sources, - sink_ids=sinks, - source_coords=source_coords, - sink_coords=sink_coords, - mincut=mincut, - ) - logger.note(f"remove_edges ({time.time() - t0:.2f}s)") - except cg_exceptions.SupervoxelSplitRequiredError as e: - logger.note(f"sv split required ({time.time() - t0:.2f}s): {e}") - sources_remapped = fastremap.remap( - sources, - e.sv_remapping, - preserve_missing_labels=True, - in_place=False, - ) - sinks_remapped = fastremap.remap( - sinks, - e.sv_remapping, - preserve_missing_labels=True, - in_place=False, - ) - logger.note(f"remapped sources={sources_remapped}, sinks={sinks_remapped}") - overlap_mask = np.isin(sources_remapped, sinks_remapped) - logger.note(f"overlapping reps: {np.unique(sources_remapped[overlap_mask])}") - t1 = time.time() - for rep in np.unique(sources_remapped[overlap_mask]): - _mask0 = sources_remapped == rep - _mask1 = sinks_remapped == rep - split_supervoxel( - cg, - sources[_mask0][0], - source_coords[_mask0], - sink_coords[_mask1], - e.operation_id, - sv_remapping=e.sv_remapping, - ) - logger.note(f"sv splits done ({time.time() - t1:.2f}s)") - - sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) - logger.note(f"post-split: sources={sources}, sinks={sinks}") - t1 = time.time() - try: - ret = cg.remove_edges( - user_id=user_id, - source_ids=sources, - sink_ids=sinks, - source_coords=source_coords, - sink_coords=sink_coords, - mincut=mincut, - ) - except cg_exceptions.SupervoxelSplitRequiredError as e2: - # The cross-chunk representative group extends beyond the split - # bbox. Unsplit SVs inside the bbox still have inf edges to SVs - # outside, bridging source and sink through the broader component. - - logger.note(f"retry still requires sv split") - # check_unsplit_sv_bridges(cg, e2.sv_remapping, sources, sinks) - raise cg_exceptions.PreconditionError( - "Supervoxel split succeeded but the split region is too small " - "to fully separate source and sink. " - "Try placing source and sink points farther apart." - ) from e2 - logger.note(f"remove_edges after sv split ({time.time() - t1:.2f}s)") - return ret - - def handle_split(table_id): current_app.table_id = table_id user_id = str(g.auth_user.get("id", current_app.user_id)) @@ -523,8 +450,17 @@ def handle_split(table_id): cg = app_utils.get_cg(table_id, skip_cache=True) current_app.logger.debug(data) + sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) + logger.note(f"split inputs: sources={sources}, sinks={sinks}") try: - ret = split_with_sv_splits(cg, data, user_id, mincut) + ret = cg.remove_edges( + user_id=user_id, + source_ids=sources, + sink_ids=sinks, + source_coords=source_coords, + sink_coords=sink_coords, + mincut=mincut, + ) except cg_exceptions.LockingError as e: raise cg_exceptions.InternalServerError(e) except cg_exceptions.PreconditionError as e: diff --git a/pychunkedgraph/graph/cutting.py b/pychunkedgraph/graph/cutting.py index e49cc9ded..95031c6d9 100644 --- a/pychunkedgraph/graph/cutting.py +++ b/pychunkedgraph/graph/cutting.py @@ -5,7 +5,8 @@ import graph_tool import graph_tool.flow -from typing import Tuple +from dataclasses import dataclass +from typing import Tuple, Union from typing import Sequence from typing import Iterable @@ -19,6 +20,40 @@ DEBUG_MODE = False +@dataclass +class Cut: + """Multicut produced a clean partition — these SV-pair edges are to be cut.""" + + atomic_edges: np.ndarray # shape (N, 2) + + +@dataclass +class PreviewCut: + """Multicut in preview mode — connected components after the proposed cut. + + `illegal_split` flags cases where the cut isolates source or sink. + """ + + supervoxel_ccs: list + illegal_split: bool + + +@dataclass +class SvSplitRequired: + """Multicut could not partition without first splitting a supervoxel. + + Carries the cross-chunk-representative remapping the caller needs to + run the actual SV split. Returned (not raised) from run_multicut; the + SupervoxelSplitRequiredError that surfaces this condition is caught + inside run_multicut and never escapes as control flow. + """ + + sv_remapping: dict # old_sv_id -> rep_sv_id + + +MulticutResult = Union[Cut, PreviewCut, SvSplitRequired] + + class IsolatingCutException(Exception): """Raised when mincut would split off one of the labeled supervoxel exactly. This is used to trigger a PostconditionError with a custom message. @@ -668,21 +703,38 @@ def run_multicut( path_augment: bool = True, disallow_isolating_cut: bool = True, sv_split_supported: bool = False, -): - local_mincut_graph = LocalMincutGraph( - edges.get_pairs(), - edges.affinities, - source_ids, - sink_ids, - split_preview, - path_augment, - disallow_isolating_cut=disallow_isolating_cut, - sv_split_supported=sv_split_supported, - ) - atomic_edges = local_mincut_graph.compute_mincut() - if len(atomic_edges) == 0: +) -> MulticutResult: + """Run the multicut and return either the cut edges or an SV-split request. + + When `sv_split_supported=True`, the "source and sink share a cross-chunk + rep" condition is returned as `SvSplitRequired` rather than raised — + `SupervoxelSplitRequiredError` is an implementation detail of + `LocalMincutGraph` unwinding, caught at this boundary so it never + drives control flow in callers. + """ + try: + local_mincut_graph = LocalMincutGraph( + edges.get_pairs(), + edges.affinities, + source_ids, + sink_ids, + split_preview, + path_augment, + disallow_isolating_cut=disallow_isolating_cut, + sv_split_supported=sv_split_supported, + ) + mincut_output = local_mincut_graph.compute_mincut() + except SupervoxelSplitRequiredError as err: + return SvSplitRequired(err.sv_remapping) + + if split_preview: + # compute_mincut returns (ccs, illegal_split) in preview mode. + supervoxel_ccs, illegal_split = mincut_output + return PreviewCut(supervoxel_ccs, illegal_split) + + if len(mincut_output) == 0: raise PostconditionError(f"Mincut failed. Try with a different set of points.") - return atomic_edges + return Cut(mincut_output) def run_split_preview( @@ -695,11 +747,15 @@ def run_split_preview( path_augment: bool = True, disallow_isolating_cut: bool = True, ): - root_ids = set( - cg.get_roots(np.concatenate([source_ids, sink_ids]), assert_roots=True) - ) + sink_and_source_ids = np.concatenate([source_ids, sink_ids]) + roots = cg.get_roots(sink_and_source_ids, assert_roots=True) + root_ids = set(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same object.") + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sources={list(source_ids)} sinks={list(sink_ids)} " + f"sv_id->root: {dict(zip(sink_and_source_ids.tolist(), roots.tolist()))}" + ) bbox = get_bounding_box(source_coords, sink_coords, bb_offset) l2id_agglomeration_d, edges = cg.get_subgraph( @@ -713,7 +769,7 @@ def run_split_preview( mask0 = np.isin(edges.node_ids1, supervoxels) mask1 = np.isin(edges.node_ids2, supervoxels) edges = edges[mask0 & mask1] - edges_to_remove, illegal_split = run_multicut( + result = run_multicut( edges, source_ids, sink_ids, @@ -722,8 +778,14 @@ def run_split_preview( disallow_isolating_cut=disallow_isolating_cut, sv_split_supported=cg.meta.ocdbt_seg, ) + if isinstance(result, SvSplitRequired): + # Preview callers can't perform an SV split; surface as a precondition. + raise PreconditionError( + "Supervoxel split required to cut these source/sink points; " + "preview is not available until an edit is applied." + ) - if len(edges_to_remove) == 0: + assert isinstance(result, PreviewCut), f"unexpected preview result type: {result!r}" + if len(result.supervoxel_ccs) == 0: raise PostconditionError("Mincut could not find any edges to remove.") - - return edges_to_remove, illegal_split + return result.supervoxel_ccs, result.illegal_split diff --git a/pychunkedgraph/graph/downsample.py b/pychunkedgraph/graph/downsample.py new file mode 100644 index 000000000..7a28305c4 --- /dev/null +++ b/pychunkedgraph/graph/downsample.py @@ -0,0 +1,342 @@ +"""Async mip-pyramid downsample worker support. + +An SV split writes at base resolution only; coarser mips are produced +afterwards by a pubsub worker that consumes this module's primitives. + +Work is organized into `pyramid_block`s. A block is a cubic physical +region sized so that at the coarsest scale in the pyramid it equals +exactly one storage chunk. Because every finer scale's chunk grid is a +power-of-2 refinement of the coarsest, a block aligned at the coarsest +scale is automatically aligned at every finer scale — so two different +blocks never share a storage chunk at any mip. That is what makes a +single lock per block safe. + +Within a block we pick one of two code paths: + 1. Fast in-memory path: read the affected base region once, call + tinybrain with `num_mips=K` (all mips at once), write each mip's + output. Used when the base read fits a memory budget — the typical + case because the SV-split bbox is bounded by the /split endpoint + (source+sink coords + small padding). + 2. Per-mip fallback: read the previous mip, tinybrain one step, write. + K storage round-trips instead of 1. Kept for pathological inputs + whose base read would exceed the memory budget. + +Uniform downsample factor (e.g. 2x2x2) across all non-base scales is +assumed and asserted. +""" + +import numpy as np +import tinybrain + +from pychunkedgraph import get_logger + +logger = get_logger(__name__) + +# Default memory budget for the in-memory path's base read. +# uint64 segmentation is 8 bytes/voxel; 1 GiB ≈ 512^3 voxels. Edits +# produced by the /split endpoint are bounded far below this. +DEFAULT_MEMORY_BUDGET_BYTES = 1 << 30 + + +def num_output_mips(meta) -> int: + """Count of non-base scales — what the worker actually writes.""" + return len(meta.ws_ocdbt_scales) - 1 + + +def uniform_factor(meta) -> tuple: + """Per-axis downsample factor between consecutive scales. + + tinybrain takes one factor tuple per call, so the factor must be + constant across the pyramid. Asserts rather than silently producing + wrong mips for a dataset with mixed factors. + """ + resolutions = [np.array(r, dtype=float) for r in meta.ws_ocdbt_resolutions] + factors = [ + tuple((resolutions[i] / resolutions[i - 1]).astype(int)) + for i in range(1, len(resolutions)) + ] + assert all( + f == factors[0] for f in factors + ), f"non-uniform downsample factors {factors}" + return factors[0] + + +def _chunk_size_at_scale(meta, scale_idx: int) -> np.ndarray: + """Storage chunk size at a given scale (excluding the channel dim).""" + return np.array( + meta.ws_ocdbt_scales[scale_idx].chunk_layout.read_chunk.shape[:3], dtype=int + ) + + +def block_shape(meta) -> np.ndarray: + """pyramid_block size in base-resolution voxels. + + Chosen so that at the coarsest scale K the block equals exactly one + storage chunk — which transitively aligns it to every finer scale's + chunk grid. + """ + K = num_output_mips(meta) + coarsest_chunk = _chunk_size_at_scale(meta, K) + factor = np.array(uniform_factor(meta), dtype=int) + return coarsest_chunk * factor**K + + +def blocks_for_bbox(meta, bbs, bbe) -> list: + """Block coords intersected by a base-resolution bbox. + + Bbox is rounded outward to the block grid — a tiny bbox inside one + block still yields that one block coord. Returns sorted list of + `(bx, by, bz)` ints for deadlock-free lock acquisition. + """ + shape = block_shape(meta) + lo = np.asarray(bbs, dtype=int) // shape + hi = -(-np.asarray(bbe, dtype=int) // shape) + coords = [ + (int(bx), int(by), int(bz)) + for bx in range(lo[0], hi[0]) + for by in range(lo[1], hi[1]) + for bz in range(lo[2], hi[2]) + ] + return sorted(coords) + + +def block_base_bbox(meta, block_coord) -> tuple: + """Inverse of `blocks_for_bbox` for a single coord — base-voxel bbox.""" + shape = block_shape(meta) + lo = np.asarray(block_coord, dtype=int) * shape + hi = lo + shape + return lo, hi + + +def _seg_bboxes_to_np(seg_bboxes): + return [ + (np.asarray(bbs, dtype=int), np.asarray(bbe, dtype=int)) + for bbs, bbe in seg_bboxes + ] + + +def _affected_region_base(meta, block_coord, seg_bboxes_np): + """Base-voxel region covering all tiles this block will write, at any mip. + + Starts from the union of (seg bbox ∩ block ∩ volume) then aligns + outward to the coarsest mip's base-voxel grid (= factor**K per axis). + That alignment both makes the region tinybrain-valid for num_mips=K + and guarantees clean chunk-aligned writes at every mip (coarsest + alignment refines down to every finer scale). + + Returns `(base_lo, base_hi)` or `None` if no overlap. + """ + K = num_output_mips(meta) + factor = np.array(uniform_factor(meta), dtype=int) + align = factor**K + + block_lo, block_hi = block_base_bbox(meta, block_coord) + vol_lo = meta.voxel_bounds[:, 0] + vol_hi = meta.voxel_bounds[:, 1] + clipped_lo = np.maximum(block_lo, vol_lo) + clipped_hi = np.minimum(block_hi, vol_hi) + if np.any(clipped_hi <= clipped_lo): + return None + + union_lo, union_hi = None, None + for sb, eb in seg_bboxes_np: + ilo = np.maximum(sb, clipped_lo) + ihi = np.minimum(eb, clipped_hi) + if np.any(ihi <= ilo): + continue + union_lo = ilo if union_lo is None else np.minimum(union_lo, ilo) + union_hi = ihi if union_hi is None else np.maximum(union_hi, ihi) + if union_lo is None: + return None + + base_lo = (union_lo // align) * align + base_hi = -(-union_hi // align) * align + # Keep within the clipped block. Block corners are factor**K-aligned + # (block_shape is a multiple of factor**K), so this clip preserves + # alignment. + base_lo = np.maximum(base_lo, clipped_lo) + base_hi = np.minimum(base_hi, clipped_hi) + if np.any(base_hi <= base_lo): + return None + return base_lo, base_hi + + +def _process_block_in_memory(meta, base_region, K, factor): + """Read base once, tinybrain all mips, write each output. + + Assumes the base region is factor**K-aligned in size (which is what + `_affected_region_base` returns) so tinybrain with num_mips=K emits + clean integer voxel counts at every mip. + """ + base_lo, base_hi = base_region + base = meta.ws_ocdbt_scales[0] + arr = ( + base[ + base_lo[0] : base_hi[0], + base_lo[1] : base_hi[1], + base_lo[2] : base_hi[2], + :, + ] + .read() + .result() + ) + mips = tinybrain.downsample_segmentation( + arr, factor=tuple(int(f) for f in factor), num_mips=K, sparse=False + ) + for m, out in enumerate(mips, start=1): + scale = factor**m + mip_lo = base_lo // scale + mip_hi = base_hi // scale + dst = meta.ws_ocdbt_scales[m] + dst[ + mip_lo[0] : mip_hi[0], + mip_lo[1] : mip_hi[1], + mip_lo[2] : mip_hi[2], + :, + ].write(out).result() + + +def _affected_region_at_mip( + block_lo_base, + block_hi_base, + vol_lo, + vol_hi, + seg_bboxes_base, + mip: int, + factor: np.ndarray, + mip_chunk: np.ndarray, +): + """Write region at this mip in mip-local voxel coords. + + Union of seg bboxes ∩ block ∩ volume, aligned outward to this mip's + storage-chunk grid. Returns `(mip_lo, mip_hi)` or None. + """ + scale = factor**mip + clipped_lo = np.maximum(block_lo_base, vol_lo) + clipped_hi = np.minimum(block_hi_base, vol_hi) + if np.any(clipped_hi <= clipped_lo): + return None + + union_lo, union_hi = None, None + for sb, eb in seg_bboxes_base: + ilo = np.maximum(sb, clipped_lo) + ihi = np.minimum(eb, clipped_hi) + if np.any(ihi <= ilo): + continue + union_lo = ilo if union_lo is None else np.minimum(union_lo, ilo) + union_hi = ihi if union_hi is None else np.maximum(union_hi, ihi) + if union_lo is None: + return None + + mip_lo = union_lo // scale + mip_hi = -(-union_hi // scale) + mip_lo = (mip_lo // mip_chunk) * mip_chunk + mip_hi = -(-mip_hi // mip_chunk) * mip_chunk + + vol_lo_mip = vol_lo // scale + vol_hi_mip = -(-vol_hi // scale) + mip_lo = np.maximum(mip_lo, vol_lo_mip) + mip_hi = np.minimum(mip_hi, vol_hi_mip) + if np.any(mip_hi <= mip_lo): + return None + return mip_lo, mip_hi + + +def _process_block_per_mip(meta, block_coord, seg_bboxes_np, K, factor): + """Fallback path: process one mip at a time. + + Used when the full in-memory base read would exceed the memory + budget. Each mip reads the prior mip from storage, does one + tinybrain step, writes. + + Safe across mip boundaries only because the caller holds the block + lock — no other task can write the storage chunks this block owns, + so reading mip N here always sees what we wrote at mip N in the + previous iteration. + """ + vol_lo = meta.voxel_bounds[:, 0] + vol_hi = meta.voxel_bounds[:, 1] + block_lo_base, block_hi_base = block_base_bbox(meta, block_coord) + + for mip in range(1, K + 1): + mip_chunk = _chunk_size_at_scale(meta, mip) + region = _affected_region_at_mip( + block_lo_base, + block_hi_base, + vol_lo, + vol_hi, + seg_bboxes_np, + mip, + factor, + mip_chunk, + ) + if region is None: + continue + mip_lo, mip_hi = region + src = meta.ws_ocdbt_scales[mip - 1] + src_lo = mip_lo * factor + src_hi = mip_hi * factor + arr = ( + src[ + src_lo[0] : src_hi[0], + src_lo[1] : src_hi[1], + src_lo[2] : src_hi[2], + :, + ] + .read() + .result() + ) + out = tinybrain.downsample_segmentation( + arr, factor=tuple(int(f) for f in factor), num_mips=1, sparse=False + )[0] + dst = meta.ws_ocdbt_scales[mip] + dst[ + mip_lo[0] : mip_hi[0], + mip_lo[1] : mip_hi[1], + mip_lo[2] : mip_hi[2], + :, + ].write(out).result() + + +def process_block( + meta, + block_coord, + seg_bboxes, + memory_budget_bytes: int = DEFAULT_MEMORY_BUDGET_BYTES, +): + """Downsample one pyramid_block through every non-base mip. + + Atomic within the block: caller must hold the block lock. Picks the + in-memory path when the base read fits the memory budget, falls + back to the per-mip path otherwise. + + Reads and writes only the aligned region covering `seg_bboxes` + inside the block; the rest of the block is untouched. Region + alignment rounds outward to the coarsest mip's grid so the aligned + region is always tinybrain-valid and chunk-aligned at every mip. + + Args: + meta: ChunkedGraphMeta with `ws_ocdbt_scales` / `ws_ocdbt_resolutions`. + block_coord: (bx, by, bz) block grid coord. + seg_bboxes: iterable of `(bbs, bbe)` base-voxel bbox pairs from + the SV splits that triggered this job. + """ + K = num_output_mips(meta) + factor = np.array(uniform_factor(meta), dtype=int) + seg_bboxes_np = _seg_bboxes_to_np(seg_bboxes) + + region = _affected_region_base(meta, block_coord, seg_bboxes_np) + if region is None: + return + base_lo, base_hi = region + + bytes_per_voxel = meta.ws_ocdbt_scales[0].dtype.numpy_dtype.itemsize + base_bytes = int(np.prod(base_hi - base_lo)) * bytes_per_voxel + if base_bytes <= memory_budget_bytes: + _process_block_in_memory(meta, region, K, factor) + else: + logger.info( + f"block {block_coord} base read {base_bytes / 1e9:.2f} GB exceeds " + f"budget {memory_budget_bytes / 1e9:.2f} GB; using per-mip path" + ) + _process_block_per_mip(meta, block_coord, seg_bboxes_np, K, factor) diff --git a/pychunkedgraph/graph/edges_sv.py b/pychunkedgraph/graph/edges_sv.py index ea9354990..19a2c2eb3 100644 --- a/pychunkedgraph/graph/edges_sv.py +++ b/pychunkedgraph/graph/edges_sv.py @@ -23,8 +23,8 @@ Distance computation: For partners within the segmentation bbox, distances are precomputed via kdtree pairwise distances. For active partners outside the bbox (e.g. - cross-chunk fragments excluded by _get_whole_sv's bbox clipping), distances - are computed from each new fragment's kdtree to the partner's chunk boundary. + cross-chunk fragments not in the rep's CC member set), distances are + computed from each new fragment's kdtree to the partner's chunk boundary. """ from __future__ import annotations diff --git a/pychunkedgraph/graph/edits_sv.py b/pychunkedgraph/graph/edits_sv.py index a0ce5b98b..be605c6e4 100644 --- a/pychunkedgraph/graph/edits_sv.py +++ b/pychunkedgraph/graph/edits_sv.py @@ -3,8 +3,10 @@ """ import time +from dataclasses import dataclass from datetime import datetime -from collections import defaultdict, deque +from collections import defaultdict +from typing import TYPE_CHECKING, List, Tuple import fastremap import numpy as np @@ -12,7 +14,6 @@ from pychunkedgraph import get_logger from pychunkedgraph.graph import ( attributes, - ChunkedGraph, cache as cache_utils, basetypes, serializers, @@ -20,54 +21,301 @@ from pychunkedgraph.graph.chunks.utils import chunks_overlapping_bbox from pychunkedgraph.graph.cutting_sv import split_supervoxel_helper from pychunkedgraph.graph.edges_sv import update_edges, add_new_edges -from pychunkedgraph.graph.ocdbt import write_seg from pychunkedgraph.graph.utils import get_local_segmentation -from pychunkedgraph.io.edges import get_chunk_edges + +if TYPE_CHECKING: + from pychunkedgraph.graph.chunkedgraph import ChunkedGraph logger = get_logger(__name__) -def _get_whole_sv( - cg: ChunkedGraph, node: basetypes.NODE_ID, min_coord, max_coord -) -> set: - all_chunks = [ - (x, y, z) - for x in range(min_coord[0], max_coord[0]) - for y in range(min_coord[1], max_coord[1]) - for z in range(min_coord[2], max_coord[2]) - ] - edges = get_chunk_edges(cg.meta.data_source.EDGES, all_chunks) - cx_edges = edges["cross"].get_pairs() - if len(cx_edges) == 0: - return {node} - - explored_nodes = set([node]) - queue = deque([node]) - while queue: - vertex = queue.popleft() - mask = cx_edges[:, 0] == vertex - neighbors = cx_edges[mask][:, 1] - - if len(neighbors) > 0: - neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) - min_mask = (neighbor_coords >= min_coord).all(axis=1) - max_mask = (neighbor_coords < max_coord).all(axis=1) - neighbors = neighbors[min_mask & max_mask] - - for neighbor in neighbors: - if neighbor not in explored_nodes: - explored_nodes.add(neighbor) - queue.append(neighbor) - return explored_nodes - - -def _update_chunks(cg, chunks_bbox_map, seg, result_seg, bb_start): +@dataclass +class SvSplitTask: + """One SV-split task per cross-chunk rep. + + Produced by `plan_sv_splits` (pure, no IO), consumed by + `split_supervoxel`. `src_mask`/`sink_mask` are positional masks + back into the caller's `source_ids`/`sink_ids` arrays so the + aggregator can splice the per-task fresh IDs in at the right + positions. + """ + + sv_id: int + src_coords: np.ndarray + sink_coords: np.ndarray + src_mask: np.ndarray + sink_mask: np.ndarray + bbs: np.ndarray + bbe: np.ndarray + + +@dataclass +class SvSplitOutcome: + """Output of `split_supervoxel` for one task. Aggregated into + `SplitResult` by `split_supervoxels`.""" + + seg_bbox: Tuple[np.ndarray, np.ndarray] + src_new_ids: np.ndarray + sink_new_ids: np.ndarray + # Per-chunk OCDBT write payloads for this task. + seg_write_pairs: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] + bigtable_rows: list + + +@dataclass +class SplitResult: + """Pure planner output of `split_supervoxels`. + + The caller (`MulticutOperation._apply`) performs the actual writes + under the L2 chunk locks: + - `seg_writes` is fed to `write_seg_chunks` as one flat parallel batch. + - `bigtable_rows` is written via `cg.client.write` in one batch. + """ + + seg_bboxes: List[Tuple[np.ndarray, np.ndarray]] + source_ids_fresh: np.ndarray + sink_ids_fresh: np.ndarray + # Flat list across all tasks: (voxel_slices, data_block) per OCDBT + # chunk write. `voxel_slices` is a 3-tuple of `slice` objects; the + # caller appends the channel slice and writes to `meta.ws_ocdbt`. + seg_writes: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] + bigtable_rows: list + + +def _coords_bbox( + cg: "ChunkedGraph", + src_coords_rep: np.ndarray, + sink_coords_rep: np.ndarray, +) -> tuple: + """Base-voxel bbox covering the user's source/sink seeds plus a margin. + + The cut surface lives between the user-placed source and sink + voxels; voxels of the rep that are far from those seeds never + contribute to the cut. So the read region is the seeds' envelope, + not the rep's full chunk envelope — for a physical SV cut into many + pieces across chunks, this can be orders of magnitude smaller. + + The margin is one CG chunk on each side. It matches the existing + L2 chunk lock margin and the 1-voxel shell read in + `split_supervoxel`, and gives `split_supervoxel_helper` headroom + around the seeds for the cut surface to travel along the SV. + + Pieces of the rep that fall outside the bbox keep their existing + IDs — they aren't read here and aren't rewritten. Cross-chunk-edge + routing for boundary-adjacent pieces is handled by the 1-voxel + shell at read time; cross-chunk edges entirely between unsplit + pieces don't change because their IDs don't change. + """ + coords = np.concatenate([src_coords_rep, sink_coords_rep], axis=0) + margin = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + bbs = np.clip(coords.min(axis=0) - margin, vol_start, vol_end) + bbe = np.clip(coords.max(axis=0) + margin, vol_start, vol_end) + return bbs, bbe + + +def _l2_chunks_for_splits(cg: "ChunkedGraph", per_rep_bboxes: list) -> list[int]: + """Layer-2 chunk IDs every rep's split will read or write. + + Reads extend 1 voxel past `[bbs, bbe]` so `update_edges` has anchor + voxels for cross-chunk neighbors; the lock must cover those neighbor + chunks too, hence the `bbs - 1` / `bbe + 1` expansion. Clipped to + volume bounds so a bbox on the volume edge doesn't enumerate phantom + negative-index chunks. Sorted for deterministic lock-acquire order + (L2ChunkLock relies on sorted input for deadlock avoidance). + """ + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + chunk_size = cg.meta.graph_config.CHUNK_SIZE + chunk_coords = set() + for bbs, bbe in per_rep_bboxes: + read_lo = np.clip(bbs - 1, vol_start, vol_end) + read_hi = np.clip(bbe + 1, vol_start, vol_end) + chunk_coords.update( + chunks_overlapping_bbox(read_lo, read_hi, chunk_size).keys() + ) + return sorted( + int(cg.get_chunk_id(layer=2, x=x, y=y, z=z)) for (x, y, z) in chunk_coords + ) + + +def _overlapping_reps( + *, + sv_remapping: dict, + source_ids: np.ndarray, + sink_ids: np.ndarray, + source_coords: np.ndarray, + sink_coords: np.ndarray, +): + """Yield per-rep data for every rep that links source and sink. + + A rep is a cross-chunk-representative SV shared by at least one + source and one sink in `sv_remapping`. These are the SVs that must + be split before the multicut can partition source from sink. + + Yields `(sv_id, src_coords_rep, sink_coords_rep, src_mask, sink_mask)`: + sv_id — one of the rep's source SV IDs, used as the + seed for `split_supervoxel`. + src_coords_rep — slice of source_coords whose SV maps to this rep. + sink_coords_rep — slice of sink_coords whose SV maps to this rep. + src_mask — positional boolean mask over source_ids; the + caller uses it to splice per-rep results back + into the full source arrays. + sink_mask — same, for sink_ids. + + Keyword-only signature — positional source/sink args of the same + shape are easy to swap without noticing. + """ + sources_remapped = fastremap.remap( + source_ids, sv_remapping, preserve_missing_labels=True, in_place=False + ) + sinks_remapped = fastremap.remap( + sink_ids, sv_remapping, preserve_missing_labels=True, in_place=False + ) + overlap_mask = np.isin(sources_remapped, sinks_remapped) + for rep in np.unique(sources_remapped[overlap_mask]): + src_mask = sources_remapped == rep + sink_mask = sinks_remapped == rep + yield ( + source_ids[src_mask][0], + source_coords[src_mask], + sink_coords[sink_mask], + src_mask, + sink_mask, + ) + + +def plan_sv_splits( + cg: "ChunkedGraph", + *, + sv_remapping: dict, + source_ids: np.ndarray, + sink_ids: np.ndarray, + source_coords: np.ndarray, + sink_coords: np.ndarray, +) -> Tuple[List[SvSplitTask], list]: + """Compute one `SvSplitTask` per rep and the L2 chunk set the splits + will touch. + + Pure function — no bigtable/OCDBT IO, no locks. Lets the caller + acquire the L2 chunk locks (both temporal and indefinite) around + `split_supervoxels` without recomputing the plan inside. + + Returns `(tasks, chunk_ids)` — `tasks` feeds `split_supervoxels`, + `chunk_ids` is the sorted union of read-expanded L2 chunks the full + operation touches. + """ + tasks: List[SvSplitTask] = [] + for ( + sv_id, + src_coords_rep, + sink_coords_rep, + src_mask, + sink_mask, + ) in _overlapping_reps( + sv_remapping=sv_remapping, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + ): + bbs, bbe = _coords_bbox(cg, src_coords_rep, sink_coords_rep) + tasks.append( + SvSplitTask( + sv_id=sv_id, + src_coords=src_coords_rep, + sink_coords=sink_coords_rep, + src_mask=src_mask, + sink_mask=sink_mask, + bbs=bbs, + bbe=bbe, + ) + ) + chunk_ids = _l2_chunks_for_splits(cg, [(t.bbs, t.bbe) for t in tasks]) + return tasks, chunk_ids + + +def split_supervoxels( + cg: "ChunkedGraph", + *, + tasks: List[SvSplitTask], + sv_remapping: dict, + source_ids: np.ndarray, + sink_ids: np.ndarray, + operation_id: int, + timestamp: datetime = None, +) -> SplitResult: + """Pure planner for the SV-split step. Returns a `SplitResult` with + all the data the caller needs to persist under locks. + + Does **not** write — the caller (`MulticutOperation._apply`) owns + the L2 chunk lock lifecycle and fires the OCDBT + bigtable writes + inside `IndefiniteL2ChunkLock`. + + Must be called inside the caller's `L2ChunkLock` for the + `plan.chunk_ids` set — the seg reads inside `split_supervoxel` need + to be consistent with concurrent writers. + + `timestamp` is the op's logical write time; threaded down to every + `mutate_row` in the persist block so all new-SV cells land at the + same logical time (atomic visibility for `parent_ts`-filtered + readers, and deterministic replay via `override_ts`). + + Fields on the returned `SplitResult`: + seg_bboxes: per-task base-resolution `(bbs, bbe)` — downsample + worker input. + source_ids_fresh / sink_ids_fresh: input `source_ids`/`sink_ids` + with positions touched by an overlap task replaced by the + new SV ID that now lives at that coord. Untouched positions + stay unchanged. Feeds the retry multicut. + seg_writes: flat list of `(voxel_slices, data)` pairs across all + tasks — one tensorstore write per pair, fired in parallel. + bigtable_rows: flattened rows from `copy_parents_and_add_lineage` + + `add_new_edges` across all tasks. + """ + source_ids_fresh = np.asarray(source_ids, dtype=basetypes.NODE_ID).copy() + sink_ids_fresh = np.asarray(sink_ids, dtype=basetypes.NODE_ID).copy() + + seg_bboxes = [] + seg_writes: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] = [] + bigtable_rows: list = [] + for task in tasks: + out = split_supervoxel( + cg, + task, + operation_id, + sv_remapping=sv_remapping, + time_stamp=timestamp, + ) + seg_bboxes.append(out.seg_bbox) + source_ids_fresh[task.src_mask] = out.src_new_ids + sink_ids_fresh[task.sink_mask] = out.sink_new_ids + seg_writes.extend(out.seg_write_pairs) + bigtable_rows.extend(out.bigtable_rows) + return SplitResult( + seg_bboxes=seg_bboxes, + source_ids_fresh=source_ids_fresh, + sink_ids_fresh=sink_ids_fresh, + seg_writes=seg_writes, + bigtable_rows=bigtable_rows, + ) + + +def _update_chunks(cg: "ChunkedGraph", chunks_bbox_map, seg, result_seg, bb_start): """Process all chunks in a single pass: assign new SV IDs to split fragments. - For each chunk overlapping the split bbox, finds split labels and - batch-allocates new IDs. No multiprocessing needed. + Returns `(results, change_chunks)`: + results: per-chunk (indices, old_values, new_values, label_id_map) + tuples; consumed by `_parse_results`. + change_chunks: `(chunk_coord, chunk_bbox)` for the chunks whose + voxels received new SV IDs. `write_seg_chunks` uses this to + rewrite only those chunks (skipping gap chunks that had no + split activity keeps the OCDBT delta proportional to actual + label changes). """ results = [] + change_chunks = [] for chunk_coord, chunk_bbox in chunks_bbox_map.items(): x, y, z = chunk_coord chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) @@ -102,7 +350,8 @@ def _update_chunks(cg, chunks_bbox_map, seg, result_seg, bb_start): _old_values = np.concatenate(_old_values) _new_values = np.concatenate(_new_values) results.append((_indices, _old_values, _new_values, _label_id_map)) - return results + change_chunks.append((chunk_coord, chunk_bbox)) + return results, change_chunks def _voxel_crop(bbs, bbe, bbs_, bbe_): @@ -136,55 +385,62 @@ def _parse_results(results, seg, bbs, bbe): def split_supervoxel( - cg: ChunkedGraph, - sv_id: basetypes.NODE_ID, - source_coords: np.ndarray, - sink_coords: np.ndarray, + cg: "ChunkedGraph", + task: SvSplitTask, operation_id: int, + *, sv_remapping: dict, - verbose: bool = False, time_stamp: datetime = None, -) -> dict[int, set]: - """ - Lookups coordinates of given supervoxel in segmentation. - Finds its counterparts split by chunk boundaries and splits them as a whole. - Updates the segmentation with new IDs. + verbose: bool = False, +) -> SvSplitOutcome: + """Split one cross-chunk-connected SV into connected components. + + `task.bbs` / `task.bbe` are the base-voxel bbox covering the user's + source and sink seeds plus a one-chunk margin — `plan_sv_splits` + pre-computed this via `_coords_bbox`. The bbox is driven by where + the user wants the cut, not by the rep's full chunk envelope; rep + pieces outside the bbox aren't read and keep their existing IDs. + + `time_stamp` is the op's logical write time; threaded through to + `copy_parents_and_add_lineage` + `add_new_edges` so every new-SV + mutation lands at the same timestamp. """ + sv_id = task.sv_id + source_coords = task.src_coords + sink_coords = task.sink_coords + bbs = task.bbs + bbe = task.bbe + vol_start = cg.meta.voxel_bounds[:, 0] vol_end = cg.meta.voxel_bounds[:, 1] - chunk_size = cg.meta.graph_config.CHUNK_SIZE - _coords = np.concatenate([source_coords, sink_coords]) - _padding = np.array([cg.meta.resolution[-1] * 2] * 3) / cg.meta.resolution - - bbs = np.clip((np.min(_coords, 0) - _padding).astype(int), vol_start, vol_end) - bbe = np.clip((np.max(_coords, 0) + _padding).astype(int), vol_start, vol_end) - chunk_min, chunk_max = bbs // chunk_size, np.ceil(bbe / chunk_size).astype(int) - bbs, bbe = chunk_min * chunk_size, chunk_max * chunk_size logger.note(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}; res {cg.meta.resolution}") - logger.note(f"chunk and padding {chunk_size}; {_padding}") - logger.note(f"bbox and chunk min max {(bbs, bbe)}; {(chunk_min, chunk_max)}") + logger.note(f"bbox: {(bbs, bbe)}") - t0 = time.time() rep = sv_remapping.get(sv_id, sv_id) - all_svs = np.array( - [sv for sv, r in sv_remapping.items() if r == rep], - dtype=basetypes.NODE_ID, - ) - coords = cg.get_chunk_coordinates_multiple(all_svs) - in_bbox = (coords >= chunk_min).all(axis=1) & (coords < chunk_max).all(axis=1) - cut_supervoxels = set(all_svs[in_bbox].tolist()) - supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) - logger.note( - f"whole sv {sv_id} -> {supervoxel_ids.tolist()} ({time.time() - t0:.2f}s)" - ) + rep_pieces = {int(sv) for sv, r in sv_remapping.items() if r == rep} - # one voxel overlap for neighbors + # one voxel overlap for neighbors — update_edges needs anchor voxels + # from neighboring SVs to route existing cross-chunk edges onto the + # new fragments. bbs_ = np.clip(bbs - 1, vol_start, vol_end) bbe_ = np.clip(bbe + 1, vol_start, vol_end) t0 = time.time() seg = get_local_segmentation(cg.meta, bbs_, bbe_).squeeze() logger.note(f"segmentation read {seg.shape} ({time.time() - t0:.2f}s)") + # Narrow the rep to pieces actually present in the bbox seg. Pieces + # of the rep whose voxels lie outside the seed-driven bbox don't + # appear in `seg` and so don't contribute to `binary_seg` anyway — + # carrying them in `cut_supervoxels` is just log noise plus inflated + # `unsplit` diff churn. + seg_ids = {int(x) for x in fastremap.unique(seg) if x != 0} + cut_supervoxels = rep_pieces & seg_ids + supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) + logger.note( + f"whole sv {sv_id} -> {supervoxel_ids.tolist()} " + f"({len(rep_pieces) - len(cut_supervoxels)} rep pieces outside bbox)" + ) + binary_seg = np.isin(seg, supervoxel_ids) voxel_overlap_crop = _voxel_crop(bbs, bbe, bbs_, bbe_) t0 = time.time() @@ -199,11 +455,12 @@ def split_supervoxel( chunks_bbox_map = chunks_overlapping_bbox(bbs, bbe, cg.meta.graph_config.CHUNK_SIZE) t0 = time.time() - results = _update_chunks( + results, change_chunks = _update_chunks( cg, chunks_bbox_map, seg[voxel_overlap_crop], split_result, bbs ) logger.note( - f"chunk updates {len(chunks_bbox_map)} chunks, {len(results)} with splits ({time.time() - t0:.2f}s)" + f"chunk updates {len(chunks_bbox_map)} chunks, " + f"{len(change_chunks)} with splits ({time.time() - t0:.2f}s)" ) seg_cropped = seg[voxel_overlap_crop].copy() @@ -238,26 +495,62 @@ def split_supervoxel( ) logger.note(f"edge update ({time.time() - t0:.2f}s)") - rows0 = copy_parents_and_add_lineage(cg, operation_id, old_new_map) + rows0 = copy_parents_and_add_lineage( + cg, operation_id, old_new_map, time_stamp=time_stamp + ) rows1 = add_new_edges(cg, edges_tuple, old_new_map, time_stamp=time_stamp) rows = rows0 + rows1 - t0 = time.time() - write_seg(cg.meta, bbs, bbe, new_seg) - cg.client.write(rows) - logger.note(f"write seg + {len(rows)} rows ({time.time() - t0:.2f}s)") - return old_new_map, edges_tuple + # Prepare per-chunk OCDBT write payloads. The caller batches these + # across all tasks into one parallel tensorstore write — no serial + # per-task loop. + seg_write_pairs: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] = [] + for _, chunk_bbox in change_chunks: + lo, hi = chunk_bbox[0], chunk_bbox[1] + local_lo = lo - bbs + local_hi = hi - bbs + data = new_seg[ + local_lo[0] : local_hi[0], + local_lo[1] : local_hi[1], + local_lo[2] : local_hi[2], + ] + voxel_slices = tuple(slice(int(s), int(e)) for s, e in zip(lo, hi)) + seg_write_pairs.append((voxel_slices, data)) + + # Per-coord fresh IDs: bit-identical to what a post-write seg read + # would return — new_seg is what the caller is about to write, and + # the caller holds the L2 chunk lock when it does, so the storage + # round-trip would see the same bytes. + local_src = (np.asarray(source_coords, dtype=int) - bbs).astype(int) + local_sink = (np.asarray(sink_coords, dtype=int) - bbs).astype(int) + src_new_ids = new_seg[tuple(local_src.T)] + sink_new_ids = new_seg[tuple(local_sink.T)] + return SvSplitOutcome( + seg_bbox=(bbs, bbe), + src_new_ids=src_new_ids, + sink_new_ids=sink_new_ids, + seg_write_pairs=seg_write_pairs, + bigtable_rows=rows, + ) def copy_parents_and_add_lineage( - cg: ChunkedGraph, + cg: "ChunkedGraph", operation_id: int, old_new_map: dict, + *, + time_stamp: datetime = None, ) -> list: - """ - Copy parents column from `old_id` to each of `new_ids`. - This makes it easy to get old hierarchy with `new_ids` using an older timestamp. - Link `old_id` and `new_ids` to create a lineage at supervoxel layer. + """Copy parent pointers from old SVs onto their new-ID fragments + and write the lineage (FormerIdentity / NewIdentity) + L2 Child + list updates. + + `time_stamp` is the op's logical write time — used for every new-SV + cell this function writes so a `parent_ts`-filtered reader sees the + op atomically. The Parent-copy and Child-list writes deliberately + preserve the old cell's timestamp (so pre-op readers still see the + old hierarchy via the old timestamp). + Returns a list of mutations to be persisted. """ result = [] @@ -275,7 +568,11 @@ def copy_parents_and_add_lineage( attributes.OperationLogs.OperationID: operation_id, } result.append( - cg.client.mutate_row(serializers.serialize_uint64(new_id), val_dict) + cg.client.mutate_row( + serializers.serialize_uint64(new_id), + val_dict, + time_stamp=time_stamp, + ) ) for cell in parent_cells_map[old_id]: cache_utils.update(cg.cache.parents_cache, [new_id], cell.value) @@ -291,7 +588,11 @@ def copy_parents_and_add_lineage( attributes.Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID) } result.append( - cg.client.mutate_row(serializers.serialize_uint64(old_id), val_dict) + cg.client.mutate_row( + serializers.serialize_uint64(old_id), + val_dict, + time_stamp=time_stamp, + ) ) children_cells_map = cg.client.read_nodes( diff --git a/pychunkedgraph/graph/locks.py b/pychunkedgraph/graph/locks.py index 47a63dacf..bdff4f789 100644 --- a/pychunkedgraph/graph/locks.py +++ b/pychunkedgraph/graph/locks.py @@ -1,6 +1,7 @@ +import hashlib +import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Union -from typing import Sequence +from typing import Sequence, Union from collections import defaultdict import networkx as nx @@ -8,7 +9,7 @@ from pychunkedgraph import get_logger -from . import exceptions +from . import attributes, exceptions, serializers from .types import empty_1d from .lineage import lineage_graph @@ -165,6 +166,14 @@ def __enter__(self): return self def __exit__(self, exception_type, exception_value, traceback): + if exception_type is not None: + # Partial bigtable hierarchy writes may have landed before + # the exception propagated. Keep the indefinite cells held + # so subsequent ops on these roots refuse to acquire — + # forces operator recovery (`repair_operation(..., unlock= + # True)`) rather than letting a silent corruption slip into + # further edits. + return if self.acquired: max_workers = min(8, max(1, len(self.root_ids))) with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -181,3 +190,382 @@ def __exit__(self, exception_type, exception_value, traceback): future.result() except Exception as e: logger.warning(f"Failed to unlock root: {e}") + + +def _downsample_block_lock_row_key(block_coord) -> bytes: + """Row key for one pyramid_block's downsample lock cell. + + Hash-prefixed so spatially-clustered block coords — common when a + team edits the same region — scatter across bigtable tablets instead + of piling up in one lexicographic range, which would hot-spot a + single tablet under concurrent load. + + 26 bytes total: + - 2-byte blake2b hash of the packed coord (tablet distribution). + - 24 bytes of packed coord (big-endian uint64 per axis). + uint64 per axis tracks the existing node-id width and puts no cap on + the block grid. The full coord in the key guarantees uniqueness even + if two coords share the 2-byte hash prefix. + """ + bx, by, bz = (int(c) for c in block_coord) + packed = ( + bx.to_bytes(8, "big", signed=False) + + by.to_bytes(8, "big", signed=False) + + bz.to_bytes(8, "big", signed=False) + ) + return hashlib.blake2b(packed, digest_size=2).digest() + packed + + +class DownsampleBlockLock: + """Lock a set of pyramid_blocks for the lifetime of a downsample task. + + The downsample worker holds one across read → tinybrain → write for + every block it touches. All-or-nothing: on partial acquisition we + release what we got and retry with backoff; on repeated failure we + raise so the pubsub message ends up un-acked and redelivered. + + Uses `cg.client.lock_by_row_key` with hash-prefixed row keys — the + generic row-key lock primitive in kvdbclient — so these rows never + collide with node-id-keyed root locks even though both use the same + `Concurrency.Lock` column. + """ + + __slots__ = ["cg", "block_coords", "operation_id", "acquired_keys"] + + # Retry budget for partial-acquire failures. Each attempt releases + # anything it got in the previous pass, then re-acquires from scratch. + _MAX_ACQUIRE_ATTEMPTS = 7 + _ACQUIRE_BACKOFF_BASE_SEC = 0.5 + + def __init__( + self, + cg, + block_coords: Sequence, + operation_id: np.uint64, + ) -> None: + self.cg = cg + # Sort so every `__enter__` uses a consistent acquisition order + # across workers — reduces contention between workers whose block + # sets overlap. Sort is on the coord tuple (not the hashed row + # key) so the order is stable and debuggable. + self.block_coords = sorted( + (int(bx), int(by), int(bz)) for bx, by, bz in block_coords + ) + self.operation_id = np.uint64(operation_id) + self.acquired_keys: list = [] + + def __enter__(self): + for attempt in range(self._MAX_ACQUIRE_ATTEMPTS): + self.acquired_keys = [] + all_ok = True + for coord in self.block_coords: + row_key = _downsample_block_lock_row_key(coord) + if self.cg.client.lock_by_row_key(row_key, self.operation_id): + self.acquired_keys.append(row_key) + else: + all_ok = False + break + if all_ok: + return self + self._release_acquired() + time.sleep(self._ACQUIRE_BACKOFF_BASE_SEC * (2**attempt)) + raise exceptions.LockingError( + f"Could not acquire downsample block locks for coords " + f"{self.block_coords} after {self._MAX_ACQUIRE_ATTEMPTS} attempts" + ) + + def __exit__(self, exception_type, exception_value, traceback): + self._release_acquired() + + def _release_acquired(self): + if not self.acquired_keys: + return + max_workers = min(8, max(1, len(self.acquired_keys))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + self.cg.client.unlock_by_row_key, key, self.operation_id + ) + for key in self.acquired_keys + ] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.warning(f"Failed to unlock downsample block: {e}") + self.acquired_keys = [] + + def renew(self) -> bool: + """Extend expiry on every held lock. Returns False if any failed.""" + ok = True + for key in self.acquired_keys: + if not self.cg.client.renew_lock_by_row_key(key, self.operation_id): + logger.warning(f"Failed to renew downsample block lock {key!r}") + ok = False + return ok + + +def _l2_chunk_lock_row_key(chunk_id) -> bytes: + """Row key for one L2 chunk's spatial lock cell. + + Hash-prefixed so spatially-clustered chunk IDs scatter across + bigtable tablets instead of piling up in one lexicographic range, + which would hot-spot a single tablet under concurrent load. + + 10 bytes total: + - 2-byte blake2b hash of the chunk_id (tablet distribution). + - 8 bytes of big-endian uint64 chunk_id. + chunk_id already encodes layer+xyz in its bits, so the full key is + unique per L2 chunk. + """ + packed = int(chunk_id).to_bytes(8, "big", signed=False) + return hashlib.blake2b(packed, digest_size=2).digest() + packed + + +class L2ChunkLock: + """Lock a set of L2 chunks to serialize SV splits that touch them. + + Closes the cross-root spatial race: two SV splits on overlapping L2 + chunks but distinct roots acquire disjoint root-lock sets and would + otherwise race on seg state. This lock is held across the + `split_supervoxel` loop (seg write + SV-level hierarchy row write) + so the pair commits atomically. + + All-or-nothing: on partial acquisition we release what we got and + retry with backoff; on repeated failure we raise `LockingError`. + + Uses `cg.client.lock_by_row_key` — the generic row-key lock in + kvdbclient — with a row-key namespace distinct from root and + downsample block locks (all three share `attributes.Concurrency.Lock` + under the hood; the row key disambiguates). + """ + + __slots__ = [ + "cg", + "chunk_ids", + "operation_id", + "privileged_mode", + "acquired_keys", + ] + + # Retry budget for partial-acquire failures. Each attempt releases + # anything it got in the previous pass, then re-acquires from scratch. + _MAX_ACQUIRE_ATTEMPTS = 7 + _ACQUIRE_BACKOFF_BASE_SEC = 0.5 + + def __init__( + self, + cg, + chunk_ids: Sequence[int], + operation_id: np.uint64, + *, + privileged_mode: bool = False, + ) -> None: + self.cg = cg + # Sort so every `__enter__` uses a consistent acquisition order + # across workers — reduces contention when overlapping lock sets + # would otherwise race AB/BA. + self.chunk_ids = sorted(int(c) for c in chunk_ids) + self.operation_id = np.uint64(operation_id) + self.privileged_mode = privileged_mode + self.acquired_keys: list = [] + + def __enter__(self): + if self.privileged_mode: + # Replay path: the crashed op's `IndefiniteL2ChunkLock` cells + # are still set on these chunks (that's what's blocking new + # ops), and `lock_by_row_key_with_indefinite` would refuse. + # Mirror `RootLock`/`IndefiniteRootLock`'s privileged escape + # hatch — skip temporal acquire, the indefinite cells are + # our de-facto lock and they'll be released by the inner + # `IndefiniteL2ChunkLock(privileged_mode=True)` on exit. + return self + for attempt in range(self._MAX_ACQUIRE_ATTEMPTS): + self.acquired_keys = [] + all_ok = True + for chunk_id in self.chunk_ids: + row_key = _l2_chunk_lock_row_key(chunk_id) + # `_with_indefinite`: the temporal acquire must also + # refuse if the indefinite column is set. Closes the + # crash-recovery race — a worker that died holding + # `IndefiniteL2ChunkLock` leaves the indefinite cell + # set, and the next op must see it rather than silently + # racing into partial state. + if self.cg.client.lock_by_row_key_with_indefinite( + row_key, self.operation_id + ): + self.acquired_keys.append(row_key) + else: + all_ok = False + break + if all_ok: + return self + self._release_acquired() + time.sleep(self._ACQUIRE_BACKOFF_BASE_SEC * (2**attempt)) + raise exceptions.LockingError( + f"Could not acquire L2 chunk locks for chunks {self.chunk_ids} " + f"after {self._MAX_ACQUIRE_ATTEMPTS} attempts" + ) + + def __exit__(self, exception_type, exception_value, traceback): + self._release_acquired() + + def _release_acquired(self): + if not self.acquired_keys: + return + max_workers = min(8, max(1, len(self.acquired_keys))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + self.cg.client.unlock_by_row_key, key, self.operation_id + ) + for key in self.acquired_keys + ] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.warning(f"Failed to unlock L2 chunk: {e}") + self.acquired_keys = [] + + def renew(self) -> bool: + """Extend expiry on every held lock. Returns False if any failed.""" + ok = True + for key in self.acquired_keys: + if not self.cg.client.renew_lock_by_row_key(key, self.operation_id): + logger.warning(f"Failed to renew L2 chunk lock {key!r}") + ok = False + return ok + + +class IndefiniteL2ChunkLock: + """Upgrade held-temporal L2 chunk locks to indefinite. + + Structurally mirrors `IndefiniteRootLock`: acquired inside the + temporal lock (`L2ChunkLock`) context after preconditions are + established, and held across the write phase. Doesn't expire — the + cell persists on bigtable until explicitly released (or operator + recovery clears it), so a worker that dies with writes in flight + leaves the chunks marked indefinitely-held. + + The temporal `L2ChunkLock` must already be held by the same + `operation_id`; the acquire filter for temporal now rejects on + indefinite cells, so future temporal acquires on these chunks + refuse until this lock is released. + + Durable scope: `__enter__` writes `chunk_ids` to the op-log row's + `OperationLogs.L2ChunkLockScope` column. This persists through a + worker crash, giving `stuck_ops replay` the exact chunk set to + clean up without a bigtable-wide lock-row scan. + + `privileged_mode=True` is the operator recovery escape hatch: + skips the acquire step (the cells already exist, held by this same + op_id from the crashed attempt), pre-populates `acquired_keys` from + `chunk_ids` so `__exit__` still value-matches-releases those cells, + and does not re-write the op-log scope column. + """ + + __slots__ = ["cg", "chunk_ids", "operation_id", "privileged_mode", "acquired_keys"] + + def __init__( + self, + cg, + chunk_ids: Sequence[int], + operation_id: np.uint64, + *, + privileged_mode: bool = False, + ) -> None: + self.cg = cg + self.chunk_ids = sorted(int(c) for c in chunk_ids) + self.operation_id = np.uint64(operation_id) + self.privileged_mode = privileged_mode + self.acquired_keys: list = [] + + def __enter__(self): + if self.privileged_mode: + # Recovery path: crashed op's indefinite cells already exist + # under this op_id. Populate acquired_keys so __exit__'s + # value-matched release deletes them after the replay writes + # succeed. + self.acquired_keys = [_l2_chunk_lock_row_key(c) for c in self.chunk_ids] + return self + for chunk_id in self.chunk_ids: + row_key = _l2_chunk_lock_row_key(chunk_id) + if not self.cg.client.lock_by_row_key_indefinitely( + row_key, self.operation_id + ): + # Partial acquire: release what we got and fail. No + # retry — an indefinite cell belongs to a currently- + # running or crashed op and won't clear on its own. + self._release_acquired() + raise exceptions.LockingError( + f"Could not upgrade L2 chunk {chunk_id} to indefinite lock " + f"(another op holds it)" + ) + self.acquired_keys.append(row_key) + self._write_scope_to_op_log() + return self + + def __exit__(self, exception_type, exception_value, traceback): + if exception_type is not None: + # Partial OCDBT seg / bigtable SV-hierarchy writes may have + # landed before the exception propagated. Leave the + # indefinite cells held and the op-log scope intact so + # subsequent ops refuse at `L2ChunkLock` acquire — forces + # operator recovery (`stuck_ops replay`) rather than + # leaking orphan SV IDs into downstream reads. + return + self._release_acquired() + self._clear_scope_on_op_log() + + def _write_scope_to_op_log(self): + """Record the chunk scope on the op-log row before seg/bigtable + writes begin. A worker crash after this point leaves both the + per-chunk indefinite cells AND this field set, so recovery can + locate the partial-write region without a bigtable scan. + """ + row_key = serializers.serialize_uint64(self.operation_id) + scope = np.asarray(self.chunk_ids, dtype=np.uint64) + entry = self.cg.client.mutate_row( + row_key, + {attributes.OperationLogs.L2ChunkLockScope: scope}, + ) + self.cg.client.write([entry]) + + def _clear_scope_on_op_log(self): + """Clear the scope record on normal exit — op completed or was + cleanly rolled back, so no partial state needs recovery. Overwrites + with an empty array; a subsequent `read_log_entries` returns an + empty scope (recovery skips). Best-effort; failures here are + logged but don't propagate. + """ + try: + row_key = serializers.serialize_uint64(self.operation_id) + empty = np.array([], dtype=np.uint64) + entry = self.cg.client.mutate_row( + row_key, + {attributes.OperationLogs.L2ChunkLockScope: empty}, + ) + self.cg.client.write([entry]) + except Exception as e: + logger.warning(f"Failed to clear L2ChunkLockScope on op-log row: {e}") + + def _release_acquired(self): + if not self.acquired_keys: + return + max_workers = min(8, max(1, len(self.acquired_keys))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + self.cg.client.unlock_indefinitely_locked_by_row_key, + key, + self.operation_id, + ) + for key in self.acquired_keys + ] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + logger.warning(f"Failed to unlock indefinite L2 chunk: {e}") + self.acquired_keys = [] diff --git a/pychunkedgraph/graph/meta.py b/pychunkedgraph/graph/meta.py index 2d2d1d289..d4331f3aa 100644 --- a/pychunkedgraph/graph/meta.py +++ b/pychunkedgraph/graph/meta.py @@ -8,7 +8,14 @@ import numpy as np from cloudvolume import CloudVolume -from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt +from pychunkedgraph.graph.ocdbt import ( + OcdbtConfig, + build_cg_ocdbt_spec, + ensure_fork_synced, + fork_exists, + get_seg_source_and_destination_ocdbt, + read_populate_meta, +) from .utils.generic import compute_bitmasks from .chunks.utils import get_chunks_boundary @@ -50,6 +57,29 @@ ) +def _redis_cached_json(key: str, loader): + """Return JSON-decoded value at ``key`` in Redis, or call ``loader()`` and + write the result through. Spares distributed workers from re-fetching the + same GCS object on every CG instantiation. Silently bypasses Redis if it + is unreachable; returns ``loader()`` directly in that case. + """ + redis = None + try: + redis = get_redis_connection() + cached = redis.get(key) + if cached is not None: + return json.loads(cached) + except Exception: + redis = None + value = loader() + if value is not None and redis is not None: + try: + redis.set(key, json.dumps(value)) + except Exception: + ... + return value + + class ChunkedGraphMeta: def __init__( self, graph_config: GraphConfig, data_source: DataSource, custom_data: Dict = {} @@ -71,6 +101,7 @@ def __init__( self._layer_count = None self._bitmasks = None self._ocdbt_seg = None + self._ocdbt_config_cached = None @property def graph_id(self): @@ -93,28 +124,57 @@ def custom_data(self): def ws_cv(self): if self._ws_cv: return self._ws_cv + ws = self._data_source.WATERSHED + info = _redis_cached_json( + f"ws_cv_info_cached:{ws}", + lambda: CloudVolume(ws, progress=False).info, + ) + self._ws_cv = CloudVolume(ws, info=info, progress=False) + return self._ws_cv - cache_key = f"{self.graph_config.ID}:ws_cv_info_cached" - try: - # try reading a cached info file for distributed workers - # useful to avoid md5 errors on high gcs load - redis = get_redis_connection() - cached_info = json.loads(redis.get(cache_key)) - self._ws_cv = CloudVolume( - self._data_source.WATERSHED, info=cached_info, progress=False + @property + def ocdbt_config(self) -> OcdbtConfig: + """Per-CG OCDBT settings with precedence info-file > custom_data > defaults. + + The watershed's ``/ocdbt/.populated/meta.json`` is the authoritative + on-disk source for fields that affect the OCDBT format (compression, + max_inline_value_bytes, populate_layer). custom_data fills per-CG + fields (enabled, sv_split_threshold) and anything the info file + doesn't pin. Both layers fall through to dataclass defaults. + + The info-file fetch goes through a Redis cache (same pattern as + ``ws_cv``) so distributed workers don't re-read the same GCS + object on every CG instantiation. Result is also cached in + instance state after first access. Legacy ``custom_data["seg"]`` + shape is read when ``"ocdbt_config"`` is absent so pre-refactor + CGs still open. + """ + if self._ocdbt_config_cached is not None: + return self._ocdbt_config_cached + + meta_d = self._custom_data.get("ocdbt_config") + if meta_d is None: + seg = self._custom_data.get("seg", {}) + meta_d = { + "enabled": bool(seg.get("ocdbt", False)), + "sv_split_threshold": int(seg.get("sv_split_threshold", 10)), + } + + info_d = None + ws = self._data_source.WATERSHED + if ws: + info_d = _redis_cached_json( + f"ocdbt_info_cached:{ws}", + lambda: read_populate_meta(ws), ) - except Exception: - self._ws_cv = CloudVolume(self._data_source.WATERSHED, progress=False) - try: - redis.set(cache_key, json.dumps(self._ws_cv.info)) - except Exception: - ... - return self._ws_cv + + self._ocdbt_config_cached = OcdbtConfig.resolve(meta_d, info_d) + return self._ocdbt_config_cached @property def ocdbt_seg(self) -> bool: if self._ocdbt_seg is None: - self._ocdbt_seg = self._custom_data.get("seg", {}).get("ocdbt", False) + self._ocdbt_seg = self.ocdbt_config.enabled return self._ocdbt_seg @property @@ -131,9 +191,19 @@ def ws_ocdbt_scales(self): """ assert self.ocdbt_seg, "make sure this pcg has segmentation in ocdbt format" if self._ws_ocdbt_scales is None: + ws = self.data_source.WATERSHED + assert fork_exists(ws, self.graph_id), ( + f"ocdbt fork missing at {ws}/ocdbt/{self.graph_id}/ — " + "create it via fork_base_manifest or the seg_ocdbt notebook" + ) + # Refresh the fork manifest from base if it's stale and edit-free. + # See ensure_fork_synced docstring; without this, post-fork-creation + # populate writes to base are invisible through the kvstack view + # and reads return zeros. + ensure_fork_synced(ws, self.graph_id) _, self._ws_ocdbt_scales, self._ws_ocdbt_resolutions = ( get_seg_source_and_destination_ocdbt( - self.data_source.WATERSHED, self.graph_id + ws, self.graph_id, self.ocdbt_config ) ) return self._ws_ocdbt_scales @@ -272,7 +342,7 @@ def READ_ONLY(self): @property def sv_split_threshold(self) -> int: - return self._custom_data.get("seg", {}).get("sv_split_threshold", 10) + return self.ocdbt_config.sv_split_threshold @property def split_bounding_offset(self): @@ -296,12 +366,22 @@ def dataset_info(self) -> Dict: "n_layers": self.layer_count, "spatial_bit_masks": self.bitmasks, "ocdbt_seg": self.ocdbt_seg, - # Per-CG delta OCDBT path. Neuroglancer must open this - # via the kvstack spec from build_cg_ocdbt_spec() to see - # both base + delta data. Opening it as plain OCDBT only - # sees the delta. - "ocdbt_path": ( - f"ocdbt/{self.graph_id}" if self._graph_config.ID else None + # Full kvstore spec a reader hands to tensorstore's + # `neuroglancer_precomputed` driver. Server owns the + # contract — paths, data prefixes, and OCDBT config + # (e.g. `max_inline_value_bytes`) are all resolved + # here, so readers don't duplicate configuration and + # future schema changes are picked up on re-fetch. + # Readers pass this verbatim as `kvstore`; add a + # `version` field for time-travel reads. + "ocdbt_kvstore_spec": ( + build_cg_ocdbt_spec( + self._data_source.WATERSHED, + self.graph_id, + self.ocdbt_config, + ) + if self.ocdbt_seg and self._graph_config.ID + else None ), }, } diff --git a/pychunkedgraph/graph/ocdbt/TENSORSTORE_REFERENCE.md b/pychunkedgraph/graph/ocdbt/TENSORSTORE_REFERENCE.md new file mode 100644 index 000000000..b47f47c47 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/TENSORSTORE_REFERENCE.md @@ -0,0 +1,134 @@ +# tensorstore OCDBT reference + +Every entry below was verified by probing tensorstore directly (intentional-bad-value + spec round-trip) against the binary in this workspace's venv. Re-verify if the tensorstore version changes. + +## OCDBT kvstore spec — top-level fields + +Sibling of `driver: "ocdbt"`: + +| Field | Type | Default | Notes | +|---|---|---|---| +| `base` | kvstore spec or URL | — | underlying kvstore (gcs/file/s3/…) | +| `manifest` | kvstore spec or URL | (under `base`) | the manifest *can* live in a separate kvstore from data | +| `config` | object | `{}` | see Config sub-fields below | +| `assume_config` | bool | `false` | skip reading the existing config from the manifest (use with care) | +| `coordinator` | ocdbt_coordinator resource | named ref `"ocdbt_coordinator"` | enables distributed mode when set | +| `cache_pool` | cache_pool resource | named ref `"cache_pool"` | | +| `data_copy_concurrency` | data_copy_concurrency resource | named ref | | +| `target_data_file_size` | uint64 | driver default | when a single commit's d/ writes exceed this, the writer rolls a new d/ file | +| `experimental_read_coalescing_threshold_bytes` | uint64 | — | | +| `experimental_read_coalescing_merged_bytes` | uint64 | — | | +| `experimental_read_coalescing_interval` | uint64 | — | | +| `btree_node_data_prefix` | string | `"d/"` | path prefix for btree-node files | +| `value_data_prefix` | string | `"d/"` | path prefix for value files | +| `version_tree_node_data_prefix` | string | `"d/"` | path prefix for version-tree files | +| `path` | string | `""` | sub-prefix in the kvstore | + +**Not fields**: `data_file_prefixes`, `version_spec`, `recheck_cached*`, `transaction`, `btree_writer_concurrency`, `manifest_kind` (lives under `config`). + +## OCDBT `config` sub-fields + +| Field | Type | tensorstore default | Notes | +|---|---|---|---| +| `compression` | object | `{}` (none) | `{"id": "zstd", "level": N}` — zstd level 1–22 | +| `max_inline_value_bytes` | uint64 | `100` | values ≤ this size live inline in the btree leaf bytes; larger values get written to a d/ file and the mutation carries only an `IndirectDataReference`. In distributed mode this **directly bounds cooperator-forwarded RPC size**: inline values are carried inside the `WriteRequest.mutations` field, so a leaf's batch blows past the 4 MiB gRPC max-receive whenever multiple inline values pile up on one node. Source: `distributed/btree_writer.cc` `StagePending`. Setting low (≤ a few KB) pushes chunk values out-of-line → small mutations → small RPCs. | +| `max_decoded_node_bytes` | uint64 | `8388608` (8 MiB) | btree node split threshold. Larger nodes → shallower tree → fewer per-commit node touches. Setting this *smaller* than the default INCREASES per-commit forwarded bytes — empirically went from ~8 MiB to ~23 MiB RPCs when set to 1 MiB. | +| `version_tree_arity_log2` | int | — | controls version tree branching; rarely tuned | +| `manifest_kind` | enum | `"single"` | `"single"` or `"numbered"` (manifest history retained — needed for time-travel reads) | +| `uuid` | string | (auto) | 32-hex per-base UUID assigned at create time | + +**Not fields**: `data_file_prefixes`, `data_file_prefix`, `btree_node_arity_log2`, `version_tree_node_arity`. + +## `ocdbt_coordinator` context resource + +| Field | Type | Default | Notes | +|---|---|---|---| +| `address` | string | — | `"host:port"` of the DistributedCoordinatorServer | +| `lease_duration` | duration string (`"1s"`, `"500ms"`, etc.) | — | how long a lease holder owns a btree node | +| `security` | object | `{method: "insecure"}` | requires `method` key. This build has **no** security methods registered (build flag) — all calls cleartext. | + +## `DistributedCoordinatorServer({...})` + +| Field | Type | Default | Notes | +|---|---|---|---| +| `bind_addresses` | list[string] | one ephemeral port | gRPC server bind address(es). `.port` after construction gives the ephemeral port. | +| `security` | object | insecure | same shape as the resource's security | + +**There is NO Python knob for the gRPC server's max-receive message size.** The 4 MiB default is set inside tensorstore's gRPC server builder. Confirmed by strings on the binary: no `TENSORSTORE_*` env var, no spec/resource field, no Context resource that maps to `grpc.max_receive_message_length`. + +## Distributed vs non-distributed write paths + +The OCDBT driver picks one of two compiled implementations at open time: + +- **non-distributed** (`btree_writer.cc`): coordinator absent. Each commit writes the manifest itself. Concurrent writers race the manifest CAS; losers retry; their pre-commit d/ writes become orphans. +- **distributed** (`distributed/btree_writer.cc`, `cooperator_*.cc`): coordinator present. One lease holder per btree node serializes commits. Other cooperators **forward their mutations over gRPC** to the lease holder. + +### Constraints unique to distributed mode + +1. **`ts.Transaction(atomic=True)` is incompatible.** "Cannot read/write … as single atomic transaction" — verified on (info + chunk) and on (cross-key). A plain `ts.Transaction()` still batches all writes into one OCDBT commit; only the *atomicity* across keys is lost. +2. **Cooperator-forwarded RPC ≤ ~4 MiB.** Carries (btree node delta) + (value bytes for keys committed into that node). +3. **Disjoint user-key writes still trigger forwarding.** Leases are per btree node, not per user-key range. Two workers writing distinct keys into the same node → one forwards to the other. + +## Cooperator batching + +`cooperator_submit_mutation_batch.cc` `SendToPeer` is the gRPC sender. The `WriteRequest` proto has `repeated bytes mutations` — each entry is one encoded `BtreeNodeWriteMutation` destined for the same leaf. The encoded mutation embeds the value_reference inline if it's an `absl::Cord`, or carries just an `IndirectDataReference` (small struct) otherwise. So **what's actually on the wire per RPC = (small request header) + Σ encoded mutations**, and each encoded mutation's size is dominated by its value bytes IF the value is inline. + +Threshold for inline-vs-ref is `max_inline_value_bytes` (see config table). That's the real lever for RPC size. + +What changes RPC size (verified by production dumps): +- `max_inline_value_bytes=1 MiB`, default node bytes → RPCs 5–8 MiB (inline chunks pile up in the batch) +- `max_inline_value_bytes=1 MiB` + `max_decoded_node_bytes=1 MiB` → RPCs up to 23 MiB (smaller nodes ≠ smaller RPCs) +- `max_inline_value_bytes=1 MiB` + dst `chunk_size` halved → RPCs grew to 12 MiB (more mutations per node → bigger batches) +- `max_inline_value_bytes=4 KiB` (chunks go out-of-line) → mutations carry only refs; RPC = small header + N×(key + ref + generation) → fits 4 MiB regardless of value sizes (this is the path our code takes) + +## Defaults visible from spec round-trip + +```json +{ + "assume_config": false, + "btree_node_data_prefix": "d/", + "config": {}, + "coordinator": "ocdbt_coordinator", + "cache_pool": "cache_pool", + "data_copy_concurrency": "data_copy_concurrency", + "value_data_prefix": "d/", + "version_tree_node_data_prefix": "d/" +} +``` + +## Env vars + +- `OCDBT_COORDINATOR_HOST`, `OCDBT_COORDINATOR_PORT`: **NO EFFECT**. Not referenced anywhere in the binary. Address must go in spec's `coordinator.address`. +- `TENSORSTORE_VERBOSE_LOGGING`: comma-separated tag list to stderr. Tags include `ocdbt`, `coordinator`. + +Other `TENSORSTORE_*` vars exist (CA paths, S3/GCS concurrency, etc.) — grep the binary. + +## On-disk layout + +- `manifest.ocdbt` at the base — root btree node + current data file refs. +- `d/` — directory of "data files" each holding concatenated values + (optionally) btree node bytes + version-tree node bytes. +- Each commit creates **at least one** d/ file holding all values + nodes for that commit, then a CAS-update of `manifest.ocdbt`. +- `target_data_file_size` controls when a single commit splits its d/ writes across files. + +## How this maps onto pychunkedgraph + +- `OcdbtConfig` (`pychunkedgraph/graph/ocdbt/meta.py`) → `compression: zstd 12`, `max_inline_value_bytes = 4 KiB`. The 4 KiB threshold keeps small metadata (info JSON, populate markers) inline while forcing every chunk value out-of-line into d/ files — this is what keeps cooperator RPCs under the 4 MiB gRPC ceiling. +- `create_base_ocdbt` / `open_base_ocdbt` pass `config.ts_config()` so the same OCDBT config persists across opens. +- `populate_chunk` (`pychunkedgraph/ingest/ocdbt.py`) opens the base with `coordinator_address` (distributed mode). +- `copy_ws_bbox_multiscale` uses **non-atomic** `ts.Transaction()` because of the distributed-mode constraint above. +- `_dump_failure_to_gcs` writes JSON failure forensics when `ERROR_DUMP` env is set. + +## Empirically tried and ruled out + +- `OCDBT_COORDINATOR_HOST/PORT` env vars — no effect. +- Bumping gRPC max-receive via env / channel arg / spec field — no such knob. +- Smaller `dst chunk_size` alone — RPC size grew (more mutations per node). +- Smaller `max_decoded_node_bytes` alone — RPC size grew (more per-commit node touches). +- `--ocdbt-edges` legacy path — decommissioned, removed. +- `ts.Transaction(atomic=True)` with distributed coordinator — incompatible. + +## Open observations (not verified at production scale) + +- `lease_duration` may reduce cross-cooperator forwarding if held long enough that a worker's whole task lands on its own nodes. +- `target_data_file_size` may affect manifest growth but not RPC size. +- Switching dst encoding from `compressed_segmentation` to `raw` would make per-value size predictable (`chunk_volume × bytes_per_voxel`), bypassing the dense-region pathological CS encoding (one observed key encoded to 23 MiB at 256×256×64). diff --git a/pychunkedgraph/graph/ocdbt/__init__.py b/pychunkedgraph/graph/ocdbt/__init__.py new file mode 100644 index 000000000..633964e1c --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/__init__.py @@ -0,0 +1,57 @@ +"""Public API for the OCDBT-backed segmentation store. + +See ``main.py`` for the architectural notes. This module just re-exports +the names that external callers (ingest, edits, runtime, tests) reach for. +""" + +from .meta import OcdbtConfig +from .utils import ( + _layer_bbox, + _read_source_scales, + base_exists, + fork_exists, + is_chunk_populated, + mark_chunk_populated, + read_populate_meta, + write_populate_meta, +) +from .main import ( + _mode_downsample, + build_cg_ocdbt_spec, + copy_ws_bbox_multiscale, + copy_ws_chunk, + copy_ws_chunk_multiscale, + create_base_ocdbt, + ensure_fork_synced, + fork_base_manifest, + get_seg_source_and_destination_ocdbt, + open_base_ocdbt, + propagate_to_coarser_scales, + wipe_base_ocdbt, + write_seg_chunks, +) + +__all__ = [ + "OcdbtConfig", + "_layer_bbox", + "_mode_downsample", + "_read_source_scales", + "base_exists", + "build_cg_ocdbt_spec", + "copy_ws_bbox_multiscale", + "copy_ws_chunk", + "copy_ws_chunk_multiscale", + "create_base_ocdbt", + "ensure_fork_synced", + "fork_base_manifest", + "fork_exists", + "get_seg_source_and_destination_ocdbt", + "is_chunk_populated", + "mark_chunk_populated", + "open_base_ocdbt", + "propagate_to_coarser_scales", + "read_populate_meta", + "wipe_base_ocdbt", + "write_populate_meta", + "write_seg_chunks", +] diff --git a/pychunkedgraph/graph/ocdbt/debug.py b/pychunkedgraph/graph/ocdbt/debug.py new file mode 100644 index 000000000..b64376695 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/debug.py @@ -0,0 +1,143 @@ +"""Diagnostic plumbing for OCDBT failures. + +Humanize-count for log lines, generic failure envelope (host/pod/versions/ +traceback/timestamp), bbox-failure payload builder, and a GCS dump helper +that writes per-task forensic JSON under ``$ERROR_DUMP/__.json``. +Kept out of ``main.py`` and ``utils.py`` so the core OCDBT code stays +free of import bloat that's only used on failure paths. +""" + +import json +import logging +import os +import socket +import sys +import traceback +from datetime import datetime, timezone +from os import environ +from typing import Optional + +import tensorstore as ts + +_logger = logging.getLogger(__name__) + + +def humanize_count(n: int) -> str: + """Compact count for log lines: 1234567 → '1.2M', 950 → '950'.""" + for unit, scale in (("G", 1_000_000_000), ("M", 1_000_000), ("K", 1_000)): + if n >= scale: + return f"{n / scale:.1f}{unit}" + return str(n) + + +def failure_envelope(exc: BaseException, dump_tag: Optional[str]) -> dict: + """Generic metadata for any failure dump — host, pod, versions, + timestamp, traceback, coordinator env. Caller merges with the + failure-specific fields to build the final payload. + """ + return { + "timestamp_utc": datetime.now(timezone.utc).isoformat(), + "dump_tag": dump_tag, + "host": { + "hostname": socket.gethostname(), + "pid": os.getpid(), + "pod_name": environ.get("MY_POD_NAME"), + "pod_ip": environ.get("MY_POD_IP"), + "node_name": environ.get("MY_NODE_NAME"), + }, + "versions": { + "tensorstore": getattr(ts, "__version__", None), + "python": sys.version, + }, + "ocdbt_coordinator_env": { + "OCDBT_COORDINATOR_HOST": environ.get("OCDBT_COORDINATOR_HOST"), + "OCDBT_COORDINATOR_PORT": environ.get("OCDBT_COORDINATOR_PORT"), + }, + "exception": { + "type": type(exc).__name__, + "module": type(exc).__module__, + "message": str(exc), + "traceback": traceback.format_exc(), + }, + } + + +def bbox_failure_payload( + exc: BaseException, + dump_tag: Optional[str], + bbox_lo, + bbox_hi, + resolutions, + per_scale, + dst_handle, + src_handle, +) -> dict: + """Build the full diagnostic dict for a ``copy_ws_bbox_multiscale`` + commit failure. + + Merges generic ``failure_envelope`` metadata with bbox-specific + fields (per-scale shape / chunk / key-count / raw-bytes, src+dst + kvstore specs). Spec dumps are wrapped in try/except so a malformed + handle doesn't shadow the original exception. + """ + try: + dst_spec = dst_handle.kvstore.spec().to_json() + except Exception as e: + dst_spec = f"" + try: + src_spec = src_handle.kvstore.spec().to_json() + except Exception as e: + src_spec = f"" + total_voxels = sum(p[2] for p in per_scale) + total_raw = sum(p[3] for p in per_scale) + total_keys = sum(p[5] for p in per_scale) + return { + **failure_envelope(exc, dump_tag), + "bbox_lo": [int(c) for c in bbox_lo], + "bbox_hi": [int(c) for c in bbox_hi], + "resolutions": [list(map(int, r)) for r in resolutions], + "n_scales": len(per_scale), + "total_voxels": total_voxels, + "total_raw_bytes": total_raw, + "total_keys": total_keys, + "per_scale": [ + { + "scale_index": i, + "dims": list(dims), + "voxels": nvox, + "raw_bytes": raw_bytes, + "chunk_shape": list(chunk_shape), + "n_keys": n_keys, + "max_raw_per_key_bytes": max_per_key, + } + for i, dims, nvox, raw_bytes, chunk_shape, n_keys, max_per_key in per_scale + ], + "dst_kvstore_spec": dst_spec, + "src_kvstore_spec": src_spec, + } + + +def dump_failure_to_gcs(payload: dict, dump_tag: str) -> Optional[str]: + """Write a per-task failure report to ``$ERROR_DUMP/__.json``. + + Returns the full path or None (env unset, dump_tag empty, or write + error). ``dump_tag`` carries the calling-context identifier (graph + id, layer, coords, …) so multiple experiments can share one + ``ERROR_DUMP`` bucket without collisions. + """ + root = environ.get("ERROR_DUMP", "").strip() + if not root or not dump_tag: + return None + if not root.endswith("/"): + root += "/" + utc = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ") + rel = f"{dump_tag}__{utc}.json" + full = root + rel + try: + ts.KvStore.open(root).result().write( + rel, json.dumps(payload, indent=2).encode("utf-8") + ).result() + return full + except Exception as e: + _logger.warning("failed to write ERROR_DUMP at %s: %r", full, e) + return None diff --git a/pychunkedgraph/graph/ocdbt.py b/pychunkedgraph/graph/ocdbt/main.py similarity index 50% rename from pychunkedgraph/graph/ocdbt.py rename to pychunkedgraph/graph/ocdbt/main.py index fb12cb7d0..e6f0baf91 100644 --- a/pychunkedgraph/graph/ocdbt.py +++ b/pychunkedgraph/graph/ocdbt/main.py @@ -1,114 +1,84 @@ -"""OCDBT-backed neuroglancer_precomputed segmentation store. +"""OCDBT-backed neuroglancer_precomputed segmentation store — public API. Architecture: one immutable base OCDBT per watershed + one delta OCDBT per ChunkedGraph. Reads merge base + delta via tensorstore's kvstack driver. -Writes land in the delta via OCDBT's *_data_prefix options. +Writes land in the delta via OCDBT's ``*_data_prefix`` options. Multi-scale (MIP pyramid) is supported: the source watershed's info JSON drives the scale layout. All scales share one OCDBT kvstore; the precomputed driver prefixes keys by scale key automatically. + +Versioned reads +--------------- +Every OCDBT commit gets a monotonically-increasing ``generation_number`` and +an ``absl::Now()``-stamped ``commit_time`` (nanoseconds since epoch). The +tensorstore OCDBT driver lets callers pin a read-only open to a prior version +via the ``version`` spec field; accepts either an integer generation number +or an ISO-8601 UTC timestamp string. The timestamp form requires a ``Z`` +suffix (not ``+00:00``) and is interpreted as ``commit_time <= T`` — the open +returns the latest version at or before the pinned time. + +The commit_time itself cannot be overridden by the caller: OCDBT stamps each +commit from the writer's local clock (``absl::Now()`` in +``btree_writer_commit_operation.cc``). This means we can't make OCDBT commits +align exactly with a caller-provided operation timestamp. What the L2 chunk +lock guarantees instead: no other writer can commit to our chunks while we +hold the lock, so any timestamp captured under the lock before our first +commit is a valid pin for "pre-op state of our chunks." + +Retention: the OCDBT spec exposes no pruning fields. All versions are +retained by default. """ -import json +from os import environ import numpy as np import tensorstore as ts from pychunkedgraph import get_logger -logger = get_logger(__name__) - -OCDBT_SEG_COMPRESSION_LEVEL = 12 - -OCDBT_CONFIG = { - "compression": {"id": "zstd", "level": OCDBT_SEG_COMPRESSION_LEVEL}, - # Inline chunk values into B+tree leaves so they share the leaf's zstd - # compression context. Default (100 bytes) puts every chunk in its own - # out-of-line blob with independent zstd framing → ~7x bloat on GCS. - # 512 KiB captures every compressed_segmentation chunk we've measured. - "max_inline_value_bytes": 524288, -} - - -def _read_source_scales(ws_path): - """Read the source precomputed `info` JSON to get scale count and resolutions. - - The leading '/' in '/info' is required for GCS — without it the read - returns empty. - """ - kvs = ts.KvStore.open(ws_path).result() - info = json.loads(kvs.read("/info").result().value) - return info["scales"] - - -def _open_precomputed_scale(kvstore, scale_index, create=False, **schema_kw): - """Open one neuroglancer_precomputed scale on top of a kvstore spec.""" - spec = { - "driver": "neuroglancer_precomputed", - "kvstore": kvstore, - "scale_index": scale_index, - } - return ts.open(spec, create=create, **schema_kw).result() - - -def _schema_from_src(src_handle): - """Extract the schema kwargs needed to open a matching destination.""" - s = src_handle.schema - return dict( - rank=s.rank, - dtype=s.dtype, - codec=s.codec, - domain=s.domain, - shape=s.shape, - chunk_layout=s.chunk_layout, - dimension_units=s.dimension_units, - ) - +from .debug import bbox_failure_payload, dump_failure_to_gcs +from .meta import OcdbtConfig +from .utils import ( + _base_ocdbt_path, + _ensure_trailing_slash, + _open_precomputed_scale, + _read_source_scales, + _schema_from_src, + base_exists, + fork_exists, +) -# --------------------------------------------------------------------------- -# Base OCDBT (shared, immutable after ingest) -# --------------------------------------------------------------------------- - - -def _ensure_trailing_slash(path): - """Ensure kvstore paths end with / so they're treated as directories.""" - return path if path.endswith("/") else path + "/" - - -def _base_ocdbt_path(ws_path): - return _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/base") - - -def base_exists(ws_path: str) -> bool: - """Check if the base OCDBT has already been created for this watershed.""" - base = _base_ocdbt_path(ws_path) - kvs = ts.KvStore.open(base).result() - result = kvs.read("manifest.ocdbt").result() - return result.value is not None and len(result.value) > 0 +logger = get_logger(__name__) -def create_base_ocdbt(ws_path: str): - """One-time bootstrap: create the shared base OCDBT at /ocdbt/base/. +def create_base_ocdbt(ws_path: str, config: OcdbtConfig): + """One-time bootstrap: create the shared base OCDBT at ``/ocdbt/base/``. Wipes any existing base first, then opens each scale with create=True so the info JSON is built from the source. Populating the base with - actual chunk data happens separately via copy_ws_chunk_multiscale - during the per-chunk ingest tasks. + actual chunk data happens separately via ``copy_ws_chunk_multiscale`` + or ``copy_ws_bbox_multiscale`` during the per-chunk ingest tasks. Returns (src_list, dst_list, resolutions) for the caller to use with - copy_ws_chunk_multiscale. + the copy helpers. """ base = _base_ocdbt_path(ws_path) - # Wipe existing base for a clean slate. + # Wipe via the underlying GCS/file driver, NOT through the ocdbt + # driver. Opening as ocdbt on an empty dir creates a default-config + # `manifest.ocdbt` stub (max_inline_value_bytes=100); on a dir with + # an existing manifest it only clears the B+tree, leaving the + # manifest's config in place. Either way the subsequent open with a + # different config mismatches. try: - kvs = ts.KvStore.open({"driver": "ocdbt", "base": base}).result() + kvs = ts.KvStore.open(base).result() kvs.delete_range(ts.KvStore.KeyRange()).result() except Exception: pass scales = _read_source_scales(ws_path) resolutions = [s["resolution"] for s in scales] - base_kvstore = {"driver": "ocdbt", "base": base, "config": dict(OCDBT_CONFIG)} + base_kvstore = {"driver": "ocdbt", "base": base, "config": config.ts_config()} src_list, dst_list = [], [] for i in range(len(scales)): @@ -126,26 +96,39 @@ def create_base_ocdbt(ws_path: str): def wipe_base_ocdbt(ws_path: str): """Wipe the base OCDBT entirely (for --reset-ocdbt).""" base = _base_ocdbt_path(ws_path) + # Wipe via the underlying GCS/file driver so the manifest file is + # deleted too. Opening as ocdbt only clears the B+tree. try: - kvs = ts.KvStore.open({"driver": "ocdbt", "base": base}).result() + kvs = ts.KvStore.open(base).result() kvs.delete_range(ts.KvStore.KeyRange()).result() except Exception: pass -def open_base_ocdbt(ws_path: str): +def open_base_ocdbt( + ws_path: str, config: OcdbtConfig, coordinator_address: str | None = None +): """Open the existing base OCDBT (read/write) for populating during ingest. Used by per-chunk ingest tasks that copy precomputed data into the shared base. NOT used at runtime — runtime always goes through the per-CG fork - spec via get_seg_source_and_destination_ocdbt. + spec via ``get_seg_source_and_destination_ocdbt``. + + ``coordinator_address`` (``"host:port"``) routes every OCDBT commit + through a ``DistributedCoordinatorServer`` so parallel workers don't + race the shared manifest's CAS — the only thing that prevents the + orphan ``d/`` file explosion. Required for any concurrent writer; the + arg is optional so single-process callers (e.g. tests, notebooks) can + skip it. Returns (src_list, dst_list, resolutions). """ base = _base_ocdbt_path(ws_path) scales = _read_source_scales(ws_path) resolutions = [s["resolution"] for s in scales] - base_kvstore = {"driver": "ocdbt", "base": base, "config": dict(OCDBT_CONFIG)} + base_kvstore = {"driver": "ocdbt", "base": base, "config": config.ts_config()} + if coordinator_address: + base_kvstore["coordinator"] = {"address": coordinator_address} src_list, dst_list = [], [] for i in range(len(scales)): @@ -158,20 +141,31 @@ def open_base_ocdbt(ws_path: str): return src_list, dst_list, resolutions -# --------------------------------------------------------------------------- -# Per-CG delta (fork of the base) -# --------------------------------------------------------------------------- - - -def build_cg_ocdbt_spec(ws_path: str, graph_id: str) -> dict: +def build_cg_ocdbt_spec( + ws_path: str, + graph_id: str, + config: OcdbtConfig, + *, + pinned_at: "int | str | None" = None, +) -> dict: """Open-time kvstore spec for a CG's OCDBT, backed by a shared immutable base. - The fork directory and its manifest are created automatically by - `fork_base_manifest` as part of CG creation — no manual setup. + This function is a pure spec-constructor — it doesn't materialize + the fork. The fork's ``manifest.ocdbt`` must exist before ``ts.open`` + on this spec will succeed; it's created by ``fork_base_manifest`` + (invoked from the ingest CLI's OCDBT path or the ``seg_ocdbt`` + notebook). ``ChunkedGraphMeta.ws_ocdbt_scales`` asserts presence via + ``fork_exists`` so callers get a clear error instead of a tensorstore + internal failure. - All three kvstack layers below AND all three `*_data_prefix` options + All three kvstack layers below AND all three ``*_data_prefix`` options are load-bearing; removing any of them causes fork writes to leak into the immutable base (verified empirically). + + When ``pinned_at`` is set, the opened kvstore is read-only and returns + state as of the specified version. Accepts an integer generation + number (exact) or an ISO-8601 UTC timestamp string with ``Z`` suffix + (interpreted as ``commit_time <= T``). """ base = _base_ocdbt_path(ws_path) fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") @@ -199,19 +193,22 @@ def build_cg_ocdbt_spec(ws_path: str, graph_id: str) -> dict: "base": _ensure_trailing_slash(fork_dir + data_prefix), } - return { + spec = { "driver": "ocdbt", "base": { "driver": "kvstack", "layers": [base_layer, fork_manifest_layer, fork_data_layer], }, - "config": dict(OCDBT_CONFIG), + "config": config.ts_config(), # Steer every kind of OCDBT write under `_d/` so the # fork_data_layer catches them. "value_data_prefix": data_prefix, "btree_node_data_prefix": data_prefix, "version_tree_node_data_prefix": data_prefix, } + if pinned_at is not None: + spec["version"] = pinned_at + return spec def fork_base_manifest(ws_path: str, graph_id: str, wipe_existing: bool = False): @@ -237,19 +234,80 @@ def fork_base_manifest(ws_path: str, graph_id: str, wipe_existing: bool = False) fork_kvs.write("manifest.ocdbt", manifest).result() -def get_seg_source_and_destination_ocdbt(ws_path: str, graph_id: str) -> tuple: +def ensure_fork_synced(ws_path: str, graph_id: str) -> bool: + """Self-heal stale fork manifests by re-snapshotting from base. + + ``setup_base`` calls ``fork_base_manifest`` once at graph creation — + before populate has committed most of its writes to base. Any base + commits after that don't propagate into the fork's manifest, so + kvstack-routed reads through the fork miss every chunk written to + base after fork creation. Symptom: meshing reads return all zeros. + + This helper refreshes the fork's manifest from base whenever both + are true: + 1. Fork's manifest differs from base's manifest (stale snapshot). + 2. The fork's ``_d/`` data prefix is empty (no edits yet). + + Edit-free is the safety guard: any SV-split write puts value/btree + files under that prefix BEFORE updating the fork manifest, so a + non-empty prefix means an edit either landed or is in flight, and + we must not overwrite the fork manifest in either case. With an + empty prefix the refresh is race-free vs concurrent populate (which + writes only to base, never to the fork). + + Returns True if the fork manifest was refreshed. + """ + if not fork_exists(ws_path, graph_id): + return False + base = _base_ocdbt_path(ws_path) + fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") + base_kvs = ts.KvStore.open(base).result() + fork_kvs = ts.KvStore.open(fork_dir).result() + base_manifest = base_kvs.read("manifest.ocdbt").result().value + fork_manifest = fork_kvs.read("manifest.ocdbt").result().value + if base_manifest == fork_manifest: + return False + data_prefix = f"{graph_id}_d/" + edit_files = fork_kvs.list( + ts.KvStore.KeyRange(data_prefix, data_prefix[:-1] + chr(ord("/") + 1)) + ).result() + if len(edit_files) > 0: + # Fork has edits; can't safely overwrite its manifest. + logger.warning( + f"fork {fork_dir} has {len(edit_files)} edit files but its " + f"manifest is stale vs base. Auto-refresh skipped to preserve " + f"edits. Call fork_base_manifest(..., wipe_existing=True) " + f"explicitly if you want to drop edits and re-snapshot." + ) + return False + fork_kvs.write("manifest.ocdbt", base_manifest).result() + logger.note(f"refreshed fork manifest at {fork_dir} from base (no edits)") + return True + + +def get_seg_source_and_destination_ocdbt( + ws_path: str, + graph_id: str, + config: OcdbtConfig, + *, + pinned_at: "int | str | None" = None, +) -> tuple: """Open source watershed + CG's delta OCDBT destination (all scales). Always uses the fork-based kvstack spec. Requires the base to exist and the fork's manifest to be present (set up at ingest time). + When ``pinned_at`` is set, the destination OCDBT handles are opened + read-only at that version — used by the recovery path to read + pre-op seg values via ``ChunkedGraphMeta.pinned_seg_reads``. + Returns: (src_list, dst_list, resolutions): per-scale TensorStore handles and [x,y,z] resolutions. """ scales = _read_source_scales(ws_path) resolutions = [s["resolution"] for s in scales] - cg_kvstore = build_cg_ocdbt_spec(ws_path, graph_id) + cg_kvstore = build_cg_ocdbt_spec(ws_path, graph_id, config, pinned_at=pinned_at) src_list, dst_list = [], [] for i in range(len(scales)): @@ -327,6 +385,85 @@ def copy_ws_chunk_multiscale( dst[x0:x1, y0:y1, z0:z1].write(data).result() +def copy_ws_bbox_multiscale( + src_list, + dst_list, + resolutions, + bbox_lo: np.ndarray, + bbox_hi: np.ndarray, + dump_tag: str | None = None, +): + """Copy a base-resolution voxel bbox across all MIP scales under one + transaction so the whole multi-scale write lands as a single OCDBT commit. + + The transaction (not ``atomic=True``) is what's load-bearing: it batches + every per-chunk underlying-kvstore write across every scale into one + commit, so the d/ file count for one call is constant in bbox size and + grows only with scale count. ``atomic=True`` would add cross-key + isolation but is rejected by tensorstore's distributed-OCDBT path — + when the kvstore is opened with a ``coordinator``, atomic transactions + cannot span multiple keys (verified empirically). Non-atomic still + batches; the coordinator handles concurrency by serializing the commit + on the wire. + + Passing the source TensorStore directly into ``write(...)`` lets + tensorstore stream the copy without materializing an intermediate + numpy array in Python — peak RSS drops by roughly one scale's + worth versus the read-into-numpy-then-write pattern. + """ + assert len(src_list) == len(dst_list) == len(resolutions) + dump_enabled = bool(environ.get("ERROR_DUMP")) + base_res = np.array(resolutions[0]) + txn = ts.Transaction() + # per_scale rows are only populated when dump_enabled, so the failure + # path has enough context for the structured GCS dump without paying any + # bookkeeping cost on the happy path. + per_scale: list = [] + for i, (src, dst) in enumerate(zip(src_list, dst_list)): + factor = (np.array(resolutions[i]) / base_res).astype(int) + x0, y0, z0 = bbox_lo // factor + x1, y1, z1 = bbox_hi // factor + if x1 <= x0 or y1 <= y0 or z1 <= z0: + continue + if dump_enabled: + dims = (int(x1 - x0), int(y1 - y0), int(z1 - z0)) + nvox = dims[0] * dims[1] * dims[2] + bpv = int(np.dtype(dst.dtype.numpy_dtype).itemsize) + # The precomputed driver's read_chunk shape includes a channel + # axis; the spatial chunk shape is the first three dims. + chunk_shape = tuple(int(s) for s in dst.chunk_layout.read_chunk.shape[:3]) + n_keys = int( + np.prod( + [int(np.ceil(d / c)) if c else 0 for d, c in zip(dims, chunk_shape)] + ) + ) + max_raw_per_key = int(np.prod(chunk_shape)) * bpv + per_scale.append( + (i, dims, nvox, nvox * bpv, chunk_shape, n_keys, max_raw_per_key) + ) + dst.with_transaction(txn)[x0:x1, y0:y1, z0:z1].write( + src[x0:x1, y0:y1, z0:z1] + ).result() + try: + txn.commit_async().result() + except Exception as exc: + if dump_enabled: + payload = bbox_failure_payload( + exc, + dump_tag, + bbox_lo, + bbox_hi, + resolutions, + per_scale, + dst_list[0], + src_list[0], + ) + path = dump_failure_to_gcs(payload, dump_tag) + if path: + logger.note(f"OCDBT commit failure dump → {path}") + raise + + def _mode_downsample(data: np.ndarray, factors: tuple) -> np.ndarray: """Mode downsample 4D segmentation array [X,Y,Z,C] by per-axis factors. @@ -408,22 +545,32 @@ def propagate_to_coarser_scales(dst_scales, resolutions, base_slices): prev_slices = target_slices -def write_seg(meta, bbs, bbe, data): - """Write segmentation at base scale and propagate to coarser scales. +def write_seg_chunks(meta, seg_writes): + """Write a flat batch of pre-sliced L2 chunks to OCDBT in parallel. - Single entry point for all SV-split-time segmentation writes. Builds - the tensorstore slices from the bounding box and adds the channel - dimension, so callers just pass the 3D bbox + 3D data. + ``seg_writes`` is the aggregated output of ``edits_sv.split_supervoxels`` + across every rep in an operation — each pair is one L2 chunk's worth + of ``(voxel_slices, data)``. Flattening across reps matters: one + ``write_seg_chunks`` call fires every chunk write in one parallel + tensorstore batch instead of serializing rep-by-rep. + + Only chunks that actually received new SV IDs appear here; gap + chunks between cross-chunk-connected pieces and neighbor chunks the + overlap read touched are skipped by the split planner. + + Coarser MIP levels stay the downsample worker's job — it picks up + the pubsub message ``publish_edit`` sends after this returns. Args: - meta: ChunkedGraphMeta with ws_ocdbt_scales and ws_ocdbt_resolutions. - bbs: (3,) array — start of the region in base-resolution voxels. - bbe: (3,) array — end of the region in base-resolution voxels. - data: 3D numpy array of new segmentation IDs. + meta: ChunkedGraphMeta with ``ws_ocdbt`` (base-scale handle). + seg_writes: iterable of ``(voxel_slices, data)`` pairs, where + ``voxel_slices`` is a 3-tuple of ``slice`` objects covering one + L2 chunk's x/y/z extent and ``data`` is the 3D label block + (shape matches the slice extents). """ - slices = tuple(slice(int(s), int(e)) for s, e in zip(bbs, bbe)) - meta.ws_ocdbt[slices + (slice(None),)] = data[..., np.newaxis] - if len(meta.ws_ocdbt_scales) > 1: - propagate_to_coarser_scales( - meta.ws_ocdbt_scales, meta.ws_ocdbt_resolutions, slices - ) + futures = [ + meta.ws_ocdbt[voxel_slices + (slice(None),)].write(data[..., np.newaxis]) + for voxel_slices, data in seg_writes + ] + for f in futures: + f.result() diff --git a/pychunkedgraph/graph/ocdbt/meta.py b/pychunkedgraph/graph/ocdbt/meta.py new file mode 100644 index 000000000..df18b4dca --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/meta.py @@ -0,0 +1,78 @@ +"""OcdbtConfig dataclass — single source of truth for per-CG OCDBT settings.""" + +from dataclasses import asdict, dataclass, field +from typing import Dict, Optional + + +@dataclass +class OcdbtConfig: + """Per-CG OCDBT settings, persisted in ``ChunkedGraphMeta.custom_data["ocdbt_config"]``. + + Carries both ingest-time choices (populate base? at which layer?) and + tensorstore kvstore options (compression, inline byte cap) that must + stay consistent for the lifetime of the OCDBT base. Built once from + the dataset yaml's ``ocdbt_config:`` section and stored alongside the + CG so every code path that opens an OCDBT store reads back the same + values. + """ + + enabled: bool = False + populate_base: bool = False + populate_layer: int = 3 + sv_split_threshold: int = 10 + compression: Dict = field(default_factory=lambda: {"id": "zstd", "level": 12}) + # Inline-vs-out-of-line threshold. Values ≤ this size live in the btree + # leaf bytes; larger values get written to a d/ file and the mutation + # carries only an IndirectDataReference. This directly determines + # cooperator-forwarded RPC size in distributed mode: inline values are + # carried inside the gRPC WriteRequest's `mutations` field, so a leaf's + # batch can blow past tensorstore's hardcoded 4 MiB gRPC max-receive + # whenever multiple inline values pile up on the same node. Verified + # by reading btree_writer.cc StagePending in v0.1.81. + # + # 4 KiB keeps small metadata (info JSON ~1.5 KB, populate-marker files) + # inline while forcing every segmentation chunk value out-of-line — + # chunks compress to 100s of KB even for the smallest scales. With + # chunk bytes out-of-line the WriteRequest stays tiny regardless of + # how many keys a worker commits at once. Tradeoff vs the previous + # 1 MiB cap: each chunk now has its own zstd-framed d/ blob instead of + # sharing a leaf's compression context, which can cost a few percent + # of compression ratio (much less than the originally-feared "7× + # bloat", which only applied at the 100-byte default). + max_inline_value_bytes: int = 4096 + + @classmethod + def from_dict(cls, d: Optional[Dict]) -> "OcdbtConfig": + """Build from a dict. Unknown keys are ignored so older configs don't + break newer code, and newer fields default in when older configs are + loaded. + """ + if not d: + return cls() + known = {f.name for f in cls.__dataclass_fields__.values()} + return cls(**{k: v for k, v in d.items() if k in known}) + + @classmethod + def resolve(cls, *dicts: Optional[Dict]) -> "OcdbtConfig": + """Layered merge: later dicts override earlier ones, all over defaults. + + Use to express precedence — e.g. ``resolve(yaml_dict, info_file_dict)`` + gives info-file values priority over yaml-supplied ones, with + dataclass defaults filling anything neither side specifies. + ``None`` and empty dicts are no-ops. + """ + merged: Dict = {} + for d in dicts: + if d: + merged.update(d) + return cls.from_dict(merged) + + def to_dict(self) -> Dict: + return asdict(self) + + def ts_config(self) -> Dict: + """The subset that belongs inside a tensorstore OCDBT kvstore ``config``.""" + return { + "compression": dict(self.compression), + "max_inline_value_bytes": self.max_inline_value_bytes, + } diff --git a/pychunkedgraph/graph/ocdbt/utils.py b/pychunkedgraph/graph/ocdbt/utils.py new file mode 100644 index 000000000..4c04e5799 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/utils.py @@ -0,0 +1,156 @@ +"""Internal helpers for the OCDBT package. + +Path builders, schema extraction, populate-marker IO, layer-bbox math. +Not part of the public API except for the marker IO and ``_layer_bbox`` +which the ingest worker uses across the package boundary. +""" + +import json +from typing import Optional + +import numpy as np +import tensorstore as ts +from tenacity import ( + retry, + retry_if_exception_message, + stop_after_attempt, + wait_exponential, +) + +# tensorstore raises ValueError with an absl/grpc status-code prefix. Retry +# only the transient classes — DNS hiccups, deadline-blown reads, server +# 5xx — so a single flaky GCS call doesn't kill the populate task. Persistent +# errors (NOT_FOUND, INVALID_ARGUMENT, RESOURCE_EXHAUSTED, …) propagate. +_transient = retry( + retry=retry_if_exception_message( + match=r"^(UNAVAILABLE|DEADLINE_EXCEEDED|ABORTED|INTERNAL):" + ), + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=0.5, min=0.5, max=8), + reraise=True, +) + + +def _ensure_trailing_slash(path: str) -> str: + """Ensure kvstore paths end with / so they're treated as directories.""" + return path if path.endswith("/") else path + "/" + + +def _base_ocdbt_path(ws_path: str) -> str: + return _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/base") + + +def _populate_markers_path(ws_path: str) -> str: + return _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/.populated") + + +def _marker_key(layer: int, coords) -> str: + return f"l{int(layer)}_{int(coords[0])}_{int(coords[1])}_{int(coords[2])}" + + +def _read_source_scales(ws_path: str): + """Read the source precomputed ``info`` JSON to get scale count and resolutions. + + The leading '/' in '/info' is required for GCS — without it the read + returns empty. + """ + kvs = ts.KvStore.open(ws_path).result() + info = json.loads(kvs.read("/info").result().value) + return info["scales"] + + +def _open_precomputed_scale( + kvstore, scale_index: int, create: bool = False, **schema_kw +): + """Open one neuroglancer_precomputed scale on top of a kvstore spec.""" + spec = { + "driver": "neuroglancer_precomputed", + "kvstore": kvstore, + "scale_index": scale_index, + } + return ts.open(spec, create=create, **schema_kw).result() + + +def _schema_from_src(src_handle) -> dict: + """Extract the schema kwargs needed to open a matching destination. + + ``domain`` already carries both extent and origin (voxel_offset). Passing + ``shape`` alongside conflicts with non-zero-origin sources because shape + implies origin=0 — tensorstore refuses to merge ``[0, N)`` with + ``[offset, offset+N)``. + """ + s = src_handle.schema + return dict( + rank=s.rank, + dtype=s.dtype, + codec=s.codec, + domain=s.domain, + chunk_layout=s.chunk_layout, + dimension_units=s.dimension_units, + ) + + +@_transient +def is_chunk_populated(ws_path: str, layer: int, coords) -> bool: + """Check whether this chunk's precomputed→OCDBT copy has already completed. + + Markers live outside the OCDBT keyspace at + ``/ocdbt/.populated/l___`` so retried ingest tasks + don't re-copy chunks and bloat the database with redundant versioned + writes. + """ + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + result = kvs.read(_marker_key(layer, coords)).result() + return result.value is not None and len(result.value) > 0 + + +@_transient +def mark_chunk_populated(ws_path: str, layer: int, coords) -> None: + """Record that this chunk's precomputed→OCDBT copy completed.""" + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + kvs.write(_marker_key(layer, coords), b"1").result() + + +@_transient +def read_populate_meta(ws_path: str) -> Optional[dict]: + """Return the per-base populate config dict, or None if not yet written.""" + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + r = kvs.read("meta.json").result() + if r.value is None or len(r.value) == 0: + return None + return json.loads(r.value) + + +@_transient +def write_populate_meta(ws_path: str, meta: dict) -> None: + """Persist the per-base populate config (layer, etc.) alongside markers.""" + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + kvs.write("meta.json", json.dumps(meta).encode()).result() + + +def base_exists(ws_path: str) -> bool: + """Check if the base OCDBT has already been created for this watershed.""" + base = _base_ocdbt_path(ws_path) + kvs = ts.KvStore.open(base).result() + result = kvs.read("manifest.ocdbt").result() + return result.value is not None and len(result.value) > 0 + + +def fork_exists(ws_path: str, graph_id: str) -> bool: + """Check if this ChunkedGraph's fork has been initialized.""" + fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") + kvs = ts.KvStore.open(fork_dir).result() + result = kvs.read("manifest.ocdbt").result() + return result.value is not None and len(result.value) > 0 + + +def _layer_bbox(meta, layer: int, coords) -> tuple: + """Base-resolution voxel bbox of a chunk at this layer.""" + chunk_size = np.array(meta.graph_config.CHUNK_SIZE, dtype=int) + layer_chunk_size = chunk_size * (1 << (layer - 2)) + coords = np.array(coords, dtype=int) + vol_start = meta.voxel_bounds[:, 0] + vol_end = meta.voxel_bounds[:, 1] + lo = coords * layer_chunk_size + vol_start + hi = np.minimum(lo + layer_chunk_size, vol_end) + return lo, hi diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 73ad898d8..60a2305c5 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -21,15 +21,17 @@ from . import locks from . import edits +from . import edits_sv from . import types +from .ocdbt import write_seg_chunks from pychunkedgraph.graph import attributes from .edges import Edges from .edges.utils import get_edges_status from pychunkedgraph.graph import basetypes from pychunkedgraph.graph import serializers from .cache import CacheService -from .cutting import run_multicut -from .exceptions import PreconditionError, SupervoxelSplitRequiredError +from .cutting import Cut, SvSplitRequired, run_multicut +from .exceptions import PreconditionError from .exceptions import PostconditionError from .utils.generic import get_bounding_box as get_bbox from pychunkedgraph.graph import get_valid_timestamp @@ -50,7 +52,9 @@ class GraphEditOperation(ABC): "do_sanity_check", ] Result = namedtuple( - "Result", ["operation_id", "new_root_ids", "new_lvl2_ids", "old_root_ids"] + "Result", + ["operation_id", "new_root_ids", "new_lvl2_ids", "old_root_ids", "seg_bbox"], + defaults=(None,), ) def __init__( @@ -460,11 +464,6 @@ def execute( new_lvl2_ids=new_lvl2_ids, old_root_ids=root_ids, ) - except SupervoxelSplitRequiredError as err: - # no need for self.cg.cache = None, the cache must be retained after sv split - raise SupervoxelSplitRequiredError( - str(err), err.sv_remapping, operation_id=lock.operation_id - ) from err except PreconditionError as err: self.cg.cache = None raise PreconditionError(err) from err @@ -550,6 +549,10 @@ def _write( new_root_ids=new_root_ids, new_lvl2_ids=new_lvl2_ids, old_root_ids=old_root_ids, + # Only set when the operation actually ran SV splits (MulticutOperation + # populates this; other operations leave the attr absent and it defaults + # to None via the Result namedtuple's default). + seg_bbox=getattr(self, "seg_bboxes", None) or None, ) @@ -630,13 +633,14 @@ def _update_root_ids(self) -> np.ndarray: def _apply( self, *, operation_id, timestamp ) -> Tuple[np.ndarray, np.ndarray, List[Any]]: - root_ids = set( - self.cg.get_roots( - self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts - ) - ) + sv_ids = self.added_edges.ravel() + roots = self.cg.get_roots(sv_ids, assert_roots=True, time_stamp=self.parent_ts) + root_ids = set(roots) if len(root_ids) < 2 and not self.allow_same_segment_merge: - raise PreconditionError("Supervoxels must belong to different objects.") + raise PreconditionError( + f"Supervoxels must belong to different objects. " + f"sv_id->root: {dict(zip(sv_ids.tolist(), roots.tolist()))}" + ) atomic_edges = self.added_edges fake_edge_rows = [] @@ -761,33 +765,26 @@ def __init__( assert np.sum(layers) == layers.size, "IDs must be supervoxels." def _update_root_ids(self) -> np.ndarray: - root_ids = np.unique( - self.cg.get_roots( - self.removed_edges.ravel(), - assert_roots=True, - time_stamp=self.parent_ts, - ) - ) + sv_ids = self.removed_edges.ravel() + roots = self.cg.get_roots(sv_ids, assert_roots=True, time_stamp=self.parent_ts) + root_ids = np.unique(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same object.") + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sv_id->root: {dict(zip(sv_ids.tolist(), roots.tolist()))}" + ) return root_ids def _apply( self, *, operation_id, timestamp ) -> Tuple[np.ndarray, np.ndarray, List[Any]]: - if ( - len( - set( - self.cg.get_roots( - self.removed_edges.ravel(), - assert_roots=True, - time_stamp=self.parent_ts, - ) - ) + sv_ids = self.removed_edges.ravel() + roots = self.cg.get_roots(sv_ids, assert_roots=True, time_stamp=self.parent_ts) + if len(set(roots)) > 1: + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sv_id->root: {dict(zip(sv_ids.tolist(), roots.tolist()))}" ) - > 1 - ): - raise PreconditionError("Supervoxels must belong to the same object.") with TimeIt("remove_edges", self.cg.graph_id, operation_id): return edits.remove_edges( @@ -866,6 +863,11 @@ class MulticutOperation(GraphEditOperation): "path_augment", "disallow_isolating_cut", "do_sanity_check", + # Base-resolution bboxes of SV splits done as part of this op, one + # per rep. Populated only when the multicut hit SvSplitRequired and + # split_supervoxels actually ran. Surfaced on the Result so the + # downsample worker knows which regions to re-mip. + "seg_bboxes", ] def __init__( @@ -893,6 +895,7 @@ def __init__( self.path_augment = path_augment self.disallow_isolating_cut = disallow_isolating_cut self.do_sanity_check = do_sanity_check + self.seg_bboxes = [] ids = np.concatenate([self.source_ids, self.sink_ids]).astype(basetypes.NODE_ID) layers = self.cg.get_chunk_layers(ids) @@ -902,30 +905,113 @@ def _update_root_ids(self) -> np.ndarray: sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids)).astype( basetypes.NODE_ID ) - root_ids = np.unique( - self.cg.get_roots( - sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts - ) + roots = self.cg.get_roots( + sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts ) + root_ids = np.unique(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same segment.") + raise PreconditionError( + f"Supervoxels must belong to the same segment. " + f"sources={self.source_ids.tolist()} sinks={self.sink_ids.tolist()} " + f"sv_id->root: {dict(zip(sink_and_source_ids.tolist(), roots.tolist()))}" + ) return root_ids def _apply( self, *, operation_id, timestamp ) -> Tuple[np.ndarray, np.ndarray, List[Any]]: - # Verify that sink and source are from the same root object - root_ids = set( - self.cg.get_roots( - np.concatenate([self.source_ids, self.sink_ids]).astype( - basetypes.NODE_ID - ), - assert_roots=True, - time_stamp=self.parent_ts, + result = self._run_multicut(operation_id) + if isinstance(result, SvSplitRequired): + # Running under GraphEditOperation.execute's RootLock — no same-root + # edit can interleave between the SV split and the retry multicut. + # `plan_sv_splits` returns the chunk scope for both locks below, + # `split_supervoxels` is a pure planner that computes the full + # payload. Writes happen here inside nested L2 chunk locks: + # - `L2ChunkLock` (temporal) spans the seg reads (inside + # `split_supervoxels`) and the writes, so no concurrent + # op can mutate our chunks mid-compute. + # - `IndefiniteL2ChunkLock` is scoped tightly to the writes + # only. A worker death inside it leaves the indefinite + # cell set on every chunk row in scope, blocking future + # ops until operator replay clears them. + tasks, chunk_ids = edits_sv.plan_sv_splits( + self.cg, + sv_remapping=result.sv_remapping, + source_ids=self.source_ids, + sink_ids=self.sink_ids, + source_coords=self.source_coords, + sink_coords=self.sink_coords, + ) + with locks.L2ChunkLock( + self.cg, + chunk_ids, + operation_id, + privileged_mode=self.privileged_mode, + ): + sv_result = edits_sv.split_supervoxels( + self.cg, + tasks=tasks, + sv_remapping=result.sv_remapping, + source_ids=self.source_ids, + sink_ids=self.sink_ids, + operation_id=operation_id, + timestamp=timestamp, + ) + with locks.IndefiniteL2ChunkLock( + self.cg, + chunk_ids, + operation_id, + privileged_mode=self.privileged_mode, + ): + write_seg_chunks(self.cg.meta, sv_result.seg_writes) + self.cg.client.write(sv_result.bigtable_rows) + self.seg_bboxes = sv_result.seg_bboxes + self.source_ids = sv_result.source_ids_fresh + self.sink_ids = sv_result.sink_ids_fresh + result = self._run_multicut(operation_id) + if isinstance(result, SvSplitRequired): + raise PreconditionError( + "Supervoxel split succeeded but source and sink remain " + "connected; place source and sink farther apart." + ) + + assert isinstance(result, Cut), f"unexpected multicut result: {result!r}" + self.removed_edges = result.atomic_edges + if not self.removed_edges.size: + raise PostconditionError("Mincut could not find any edges to remove.") + + with TimeIt("remove_edges", self.cg.graph_id, operation_id): + return edits.remove_edges( + self.cg, + operation_id=operation_id, + atomic_edges=self.removed_edges, + time_stamp=timestamp, + parent_ts=self.parent_ts, + do_sanity_check=self.do_sanity_check, ) + + def _run_multicut(self, operation_id): + """Build the local subgraph and run multicut; returns the tagged result. + + Factored so `_apply` can call it twice — once for initial detection + and again after an SV split to get fresh atomic_edges against the + post-split graph topology. + """ + sink_and_source_ids = np.concatenate([self.source_ids, self.sink_ids]).astype( + basetypes.NODE_ID + ) + roots = self.cg.get_roots( + sink_and_source_ids, + assert_roots=True, + time_stamp=self.parent_ts, ) + root_ids = set(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same object.") + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sources={self.source_ids.tolist()} sinks={self.sink_ids.tolist()} " + f"sv_id->root: {dict(zip(sink_and_source_ids.tolist(), roots.tolist()))}" + ) bbox = get_bbox( self.source_coords, @@ -936,7 +1022,6 @@ def _apply( l2id_agglomeration_d, edges_tuple = self.cg.get_subgraph( root_ids.pop(), bbox=bbox, bbox_is_coordinate=True ) - edges = reduce(lambda x, y: x + y, edges_tuple, Edges([], [])) supervoxels = np.concatenate( [agg.supervoxels for agg in l2id_agglomeration_d.values()] @@ -948,7 +1033,7 @@ def _apply( raise PreconditionError("No local edges found.") with TimeIt("multicut", self.cg.graph_id, operation_id): - self.removed_edges = run_multicut( + return run_multicut( edges, self.source_ids, self.sink_ids, @@ -956,18 +1041,6 @@ def _apply( disallow_isolating_cut=self.disallow_isolating_cut, sv_split_supported=self.cg.meta.ocdbt_seg, ) - if not self.removed_edges.size: - raise PostconditionError("Mincut could not find any edges to remove.") - - with TimeIt("remove_edges", self.cg.graph_id, operation_id): - return edits.remove_edges( - self.cg, - operation_id=operation_id, - atomic_edges=self.removed_edges, - time_stamp=timestamp, - parent_ts=self.parent_ts, - do_sanity_check=self.do_sanity_check, - ) def _create_log_record( self, diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py index 4ebe0533e..64d37a5bc 100644 --- a/pychunkedgraph/graph/utils/generic.py +++ b/pychunkedgraph/graph/utils/generic.py @@ -173,8 +173,9 @@ def get_local_segmentation(meta, bbox_start, bbox_end, mip: int = 0) -> np.ndarr def lookup_svs_from_seg(meta, coordinates): """Read SV IDs directly from OCDBT segmentation at given coordinates.""" - bbox_start = np.min(coordinates, axis=0) - bbox_end = np.max(coordinates, axis=0) + 1 + coordinates = np.asarray(coordinates, dtype=int) + bbox_start = coordinates.min(axis=0) + bbox_end = coordinates.max(axis=0) + 1 seg = get_local_segmentation(meta, bbox_start, bbox_end)[..., 0] - local_coords = coordinates - bbox_start - return np.array([seg[tuple(c)] for c in local_coords], dtype=np.uint64) + local = coordinates - bbox_start + return seg[local[:, 0], local[:, 1], local[:, 2]].astype(np.uint64) diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py index 7f7d8f927..60b1d0799 100644 --- a/pychunkedgraph/graph/utils/id_helpers.py +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -128,10 +128,14 @@ def get_atomic_ids_from_coords( """ import fastremap - if parent_id_layer == 1 and meta.ocdbt_seg: + if meta.ocdbt_seg: + # Unified path: any OCDBT lookup reads the current seg at the + # coords, ignoring the user-supplied parent (the parent may be + # stale after an SV split, or — for 3D mesh clicks — not an L1 + # SV at all). See handle_supervoxel_id_lookup for the rationale. return lookup_svs_from_seg(meta, coordinates) - if parent_id_layer == 1 and not meta.ocdbt_seg: + if parent_id_layer == 1: return np.array([parent_id] * len(coordinates), dtype=np.uint64) coordinates_nm = coordinates * np.array(meta.resolution) diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index bd2f07626..811f19891 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -5,6 +5,8 @@ """ import os +from functools import partial +from time import sleep from pychunkedgraph import configure_logging, DEBUG @@ -14,23 +16,20 @@ from .cluster import create_atomic_chunk, create_parent_chunk, enqueue_l2_tasks from .manager import IngestionManager +from .ocdbt import coordinator, setup_base from .utils import ( bootstrap, - chunk_id_str, + job_type_guard, print_completion_rate, print_status, + purge_layer_state, queue_layer_helper, - job_type_guard, + requeue_chunk, ) from .simple_tests import run_all from .create.parent_layer import add_parent_chunk from ..graph.chunkedgraph import ChunkedGraph -from ..graph.ocdbt import ( - base_exists, - create_base_ocdbt, - fork_base_manifest, - wipe_base_ocdbt, -) +from ..graph.ocdbt import OcdbtConfig from ..utils.redis import get_redis_connection, keys as r_keys group_name = "ingest" @@ -52,75 +51,76 @@ def flush_redis(): @ingest_cli.command("graph") @click.argument("graph_id", type=str) -@click.argument("dataset", type=click.Path(exists=True)) -@click.option("--ocdbt", is_flag=True, help="Precomputed supervoxel seg into ocdbt.") +@click.argument("dataset", type=click.Path(exists=True), required=False) +@click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") +@click.option( + "--retry", + "-r", + is_flag=True, + help="Re-run setup against the existing table (no cg.create()).", +) @click.option( - "--sv-split-threshold", - type=int, - default=10, - help="Distance threshold for SV split edge matching.", + "--skip-queue", + "-s", + is_flag=True, + help="Set up everything but don't enqueue L2 tasks.", ) -@click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") -@click.option("--retry", is_flag=True, help="Rerun without creating a new table.") @click.option( - "--reset-ocdbt", + "--test", + "-t", is_flag=True, - help="Wipe base AND this CG's delta OCDBT, then recreate from scratch.", + help="Test 8 chunks at the center of dataset.", ) -@click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @job_type_guard(group_name) def ingest_graph( graph_id: str, dataset: click.Path, - ocdbt: bool, - sv_split_threshold: int, raw: bool, retry: bool, - reset_ocdbt: bool, + skip_queue: bool, test: bool, ): - """ - Main ingest command. - Takes ingest config from a yaml file and queues atomic tasks. + """Main ingest command. Takes config from yaml, queues atomic tasks. + + Purely about the bigtable graph: creates the table and enqueues L2 + tasks. OCDBT base + fork creation happens in ``ingest layer N`` when + N matches ``ocdbt_populate_layer``; that's the single owner of the + OCDBT lifecycle. + + ``--retry`` reuses the existing IngestionManager from redis and skips + ``cg.create()``. Pair with ``--skip-queue`` to skip L2 enqueue too. """ redis = get_redis_connection() - redis.set(r_keys.JOB_TYPE, group_name) - with open(dataset, "r") as stream: - config = yaml.safe_load(stream) - if test: configure_logging(level=DEBUG) - meta, ingest_config, client_info = bootstrap(graph_id, config, raw, test) - cg = ChunkedGraph(meta=meta, client_info=client_info) - if not retry: + if retry: + imanager_pickle = redis.get(r_keys.INGESTION_MANAGER) + if imanager_pickle is None: + raise click.ClickException( + f"--retry requires an existing `{group_name}` job in redis. " + f"Run without --retry to start a new job." + ) + imanager = IngestionManager.from_pickle(imanager_pickle) + else: + if dataset is None: + raise click.ClickException("dataset is required unless --retry is passed.") + redis.set(r_keys.JOB_TYPE, group_name) + with open(dataset, "r") as stream: + config = yaml.safe_load(stream) + meta, ingest_config, client_info, ocdbt_config_dict = bootstrap( + graph_id, config, raw, test + ) + cg = ChunkedGraph(meta=meta, client_info=client_info) cg.create() - - needs_base = False - if ocdbt: - ws = cg.meta.data_source.WATERSHED - cg.meta.custom_data["seg"] = { - "ocdbt": True, - "sv_split_threshold": sv_split_threshold, - } - cg.update_meta(cg.meta, overwrite=True) - - if reset_ocdbt: - wipe_base_ocdbt(ws) - - needs_base = not base_exists(ws) - if needs_base: - create_base_ocdbt(ws) - - fork_base_manifest(ws, graph_id, wipe_existing=retry or reset_ocdbt) - - imanager = IngestionManager( - ingest_config, - meta, - ocdbt_seg=ocdbt, - ocdbt_populate_base=needs_base, - ) - enqueue_l2_tasks(imanager, create_atomic_chunk) + imanager = IngestionManager( + ingest_config, + meta, + ocdbt_config=ocdbt_config_dict, + ) + + if not skip_queue: + enqueue_l2_tasks(imanager, create_atomic_chunk) os._exit(0) @@ -140,33 +140,104 @@ def pickle_imanager(graph_id: str, dataset: click.Path, raw: bool): except yaml.YAMLError as exc: print(exc) - meta, ingest_config, _ = bootstrap(graph_id, config=config, raw=raw) - imanager = IngestionManager(ingest_config, meta) + meta, ingest_config, _, ocdbt_config_dict = bootstrap( + graph_id, config=config, raw=raw + ) + imanager = IngestionManager(ingest_config, meta, ocdbt_config=ocdbt_config_dict) imanager.redis.set(r_keys.JOB_TYPE, group_name) @ingest_cli.command("layer") @click.argument("parent_layer", type=int) +@click.option( + "--queue-only", + "-q", + is_flag=True, + help="Only enqueue tasks; do not start the OCDBT coordinator. " + "Use when a coordinator is already running in another process.", +) +@click.option( + "--ocdbt-only", + "-o", + is_flag=True, + help="Workers run only OCDBT populate (skip add_parent_chunk). " + "Requires the OCDBT populate layer.", +) +@click.option( + "--ingest-only", + "-i", + is_flag=True, + help="Workers run only add_parent_chunk (skip OCDBT populate). " + "Use when the OCDBT base is already populated for this layer.", +) @job_type_guard(group_name) -def queue_layer(parent_layer): +def queue_layer(parent_layer, queue_only, ocdbt_only, ingest_only): """ Queue all chunk tasks at a given layer. Must be used when all the chunks at `parent_layer - 1` have completed. + + When this layer is the OCDBT populate layer, this command also owns the + OCDBT lifecycle: idempotently creates the base + fork via ``setup_base`` + and starts a ``DistributedCoordinatorServer`` so every worker's commit + routes through one process (eliminates manifest-CAS races and orphan + ``d/`` files). Stays in the foreground until killed. + + Flags: + ``--queue-only`` skips the coordinator (one is assumed running elsewhere). + ``--ocdbt-only`` task body = OCDBT populate only. + ``--ingest-only`` task body = add_parent_chunk only. """ assert parent_layer > 2, "This command is for layers 3 and above." + if ocdbt_only and ingest_only: + raise click.ClickException( + "--ocdbt-only and --ingest-only are mutually exclusive." + ) redis = get_redis_connection() imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - queue_layer_helper(parent_layer, imanager, create_parent_chunk) + + is_populate_layer = imanager.is_ocdbt_populate_layer(parent_layer) + if ocdbt_only and not is_populate_layer: + raise click.ClickException( + "--ocdbt-only requires running at the OCDBT populate layer." + ) + + if is_populate_layer: + # Single owner of the OCDBT lifecycle: create base + fork if + # missing, reconcile config with on-disk meta, then re-pickle + # imanager so queued workers read the resolved config. + resolved = setup_base(imanager.cg, OcdbtConfig.from_dict(imanager.ocdbt_config)) + imanager.ocdbt_config = resolved.to_dict() + imanager.redis.set(r_keys.INGESTION_MANAGER, imanager.serialized(pickled=True)) + + mode = "ocdbt" if ocdbt_only else ("ingest" if ingest_only else "full") + task_fn = ( + partial(create_parent_chunk, mode=mode) + if mode != "full" + else create_parent_chunk + ) + + # Coordinator only matters when OCDBT populate will actually run. + needs_coordinator = ( + is_populate_layer and mode in ("full", "ocdbt") and not queue_only + ) + if needs_coordinator: + with coordinator(imanager.redis): + queue_layer_helper(parent_layer, imanager, task_fn) + while True: + sleep(60) + else: + queue_layer_helper(parent_layer, imanager, task_fn) @ingest_cli.command("status") +@click.option("--refresh", type=int, default=5, help="Seconds between redis polls.") @job_type_guard(group_name) -def ingest_status(): +def ingest_status(refresh: int): """Print ingest status to console by layer.""" redis = get_redis_connection() try: imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - print_status(imanager, redis) + print_status(imanager, redis, refresh_seconds=refresh) except TypeError as err: print(f"\nNo current `{group_name}` job found in redis: {err}") @@ -177,23 +248,7 @@ def ingest_status(): @job_type_guard(group_name) def ingest_chunk(queue: str, chunk_info): """Manually queue chunk when a job is stuck for whatever reason.""" - redis = get_redis_connection() - imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - layer, coords = chunk_info[0], chunk_info[1:] - - func = create_parent_chunk - args = (layer, coords) - if layer == 2: - func = create_atomic_chunk - args = (coords,) - queue = imanager.get_task_queue(queue) - queue.enqueue( - func, - job_id=chunk_id_str(layer, coords), - job_timeout=f"{int(layer * layer)}m", - result_ttl=0, - args=args, - ) + requeue_chunk(queue, chunk_info, create_atomic_chunk, create_parent_chunk) @ingest_cli.command("chunk_local") @@ -228,3 +283,14 @@ def rate(layer: int, span: int): @job_type_guard(group_name) def run_tests(graph_id): run_all(ChunkedGraph(graph_id=graph_id)) + + +@ingest_cli.command("purge_layer") +@click.argument("layer", type=int) +@click.confirmation_option(prompt="Purge ALL redis state for this layer?") +@job_type_guard(group_name) +def purge_layer(layer: int): + """Drop the per-layer RQ queue + registries + completion set so the + layer can be re-run from a previous layer's backup.""" + purge_layer_state(get_redis_connection(), layer) + click.echo(f"purged redis state for layer {layer}") diff --git a/pychunkedgraph/ingest/cli_upgrade.py b/pychunkedgraph/ingest/cli_upgrade.py index 3a3ccb2e1..83e9e53c8 100644 --- a/pychunkedgraph/ingest/cli_upgrade.py +++ b/pychunkedgraph/ingest/cli_upgrade.py @@ -4,44 +4,30 @@ cli for running upgrade """ -from time import sleep - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) - import click -import tensorstore as ts from flask.cli import AppGroup -from pychunkedgraph import __version__ + +from pychunkedgraph import __version__, get_logger from pychunkedgraph.graph.meta import GraphConfig from . import IngestConfig -from .cluster import ( - convert_edges_to_ocdbt, - enqueue_l2_tasks, - upgrade_atomic_chunk, - upgrade_parent_chunk, -) +from .cluster import enqueue_l2_tasks, upgrade_atomic_chunk, upgrade_parent_chunk from .manager import IngestionManager +from .ocdbt import setup_base from .utils import ( - chunk_id_str, + job_type_guard, print_completion_rate, print_status, queue_layer_helper, - start_ocdbt_server, - job_type_guard, + requeue_chunk, ) from ..graph.chunkedgraph import ChunkedGraph, ChunkedGraphMeta -from ..graph.ocdbt import ( - base_exists, - create_base_ocdbt, - fork_base_manifest, - wipe_base_ocdbt, -) +from ..graph.ocdbt import OcdbtConfig from ..utils.redis import get_redis_connection from ..utils.redis import keys as r_keys +logger = get_logger(__name__) + group_name = "upgrade" upgrade_cli = AppGroup(group_name) @@ -63,26 +49,18 @@ def flush_redis(): @click.argument("graph_id", type=str) @click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @click.option("--ocdbt", is_flag=True, help="Enable ocdbt seg (SV splitting support).") -@click.option("--ocdbt-edges", is_flag=True, help="Convert edges to ocdbt kv store.") @click.option( "--sv-split-threshold", type=int, default=10, help="Distance threshold for SV split edge matching.", ) -@click.option( - "--reset-ocdbt", - is_flag=True, - help="Wipe base AND this CG's delta OCDBT, then recreate from scratch.", -) @job_type_guard(group_name) def upgrade_graph( graph_id: str, test: bool, ocdbt: bool, - ocdbt_edges: bool, sv_split_threshold: int, - reset_ocdbt: bool, ): """ Main upgrade command. Queues atomic tasks. @@ -103,38 +81,18 @@ def upgrade_graph( cg = ChunkedGraph(graph_id=graph_id) if ocdbt: - ws = cg.meta.data_source.WATERSHED - cg.meta.custom_data["seg"] = { - "ocdbt": True, - "sv_split_threshold": sv_split_threshold, - } - cg.update_meta(cg.meta, overwrite=True) + ocdbt_cfg = OcdbtConfig.from_dict(cg.meta.custom_data.get("ocdbt_config")) + ocdbt_cfg.enabled = True + ocdbt_cfg.sv_split_threshold = sv_split_threshold + setup_base(cg, ocdbt_cfg) logger.note(f"enabled ocdbt seg with sv_split_threshold={sv_split_threshold}") - - if reset_ocdbt: - wipe_base_ocdbt(ws) - - if not base_exists(ws): - create_base_ocdbt(ws) - - fork_base_manifest(ws, graph_id, wipe_existing=reset_ocdbt) try: cg.client.create_column_family("4") except Exception: ... imanager = IngestionManager(ingest_config, cg.meta) - if ocdbt_edges: - server = ts.ocdbt.DistributedCoordinatorServer() - start_ocdbt_server(imanager, server) - - fn = convert_edges_to_ocdbt if ocdbt_edges else upgrade_atomic_chunk - enqueue_l2_tasks(imanager, fn) - - if ocdbt_edges: - logger.note("All tasks queued. Keep this alive for ocdbt coordinator server.") - while True: - sleep(60) + enqueue_l2_tasks(imanager, upgrade_atomic_chunk) @upgrade_cli.command("layer") @@ -153,13 +111,14 @@ def queue_layer(parent_layer: int, splits: int = 0): @upgrade_cli.command("status") +@click.option("--refresh", type=int, default=5, help="Seconds between redis polls.") @job_type_guard(group_name) -def upgrade_status(): +def upgrade_status(refresh: int): """Print upgrade status to console.""" redis = get_redis_connection() try: imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - print_status(imanager, redis, upgrade=True) + print_status(imanager, redis, upgrade=True, refresh_seconds=refresh) except TypeError as err: print(f"\nNo current `{group_name}` job found in redis: {err}") @@ -170,23 +129,7 @@ def upgrade_status(): @job_type_guard(group_name) def upgrade_chunk(queue: str, chunk_info): """Manually queue chunk when a job is stuck for whatever reason.""" - redis = get_redis_connection() - imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - layer, coords = chunk_info[0], chunk_info[1:] - - func = upgrade_parent_chunk - args = (layer, coords) - if layer == 2: - func = upgrade_atomic_chunk - args = (coords,) - queue = imanager.get_task_queue(queue) - queue.enqueue( - func, - job_id=chunk_id_str(layer, coords), - job_timeout=f"{int(layer * layer)}m", - result_ttl=0, - args=args, - ) + requeue_chunk(queue, chunk_info, upgrade_atomic_chunk, upgrade_parent_chunk) @upgrade_cli.command("rate") diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 4e8149ead..36d111f1a 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -5,19 +5,20 @@ """ from os import environ - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) from time import sleep from typing import Callable, Dict, Iterable, Tuple, Sequence import numpy as np from rq import Queue as RQueue, Retry +from pychunkedgraph import get_logger + +logger = get_logger(__name__) + from .utils import chunk_id_str, get_chunks_not_done, randomize_grid_points from .manager import IngestionManager +from .ocdbt import get_coordinator_address, populate_chunk from .ran_agglomeration import ( get_active_edges, read_raw_edge_data, @@ -27,11 +28,10 @@ from .create.parent_layer import add_parent_chunk from .upgrade.atomic_layer import update_chunk as update_atomic_chunk from .upgrade.parent_layer import update_chunk as update_parent_chunk -from ..graph.edges import EDGE_TYPES, Edges, put_edges +from ..graph.edges import EDGE_TYPES from ..graph import ChunkedGraph, ChunkedGraphMeta -from ..graph.ocdbt import copy_ws_chunk_multiscale, open_base_ocdbt +from ..graph.ocdbt import is_chunk_populated from ..graph.chunks.hierarchy import get_children_chunk_coords -from ..graph.basetypes import NODE_ID from ..io.edges import get_chunk_edges from ..io.components import get_chunk_components from ..utils.redis import keys as r_keys, get_redis_connection @@ -64,18 +64,47 @@ def _post_task_completion( def create_parent_chunk( parent_layer: int, parent_coords: Sequence[int], + mode: str = "full", ) -> None: + """One parent-chunk task. ``mode`` (bound at queue time via partial) + selects which halves run: + ``full`` : OCDBT populate (if eligible) + add_parent_chunk + ``ocdbt`` : only OCDBT populate (skip add_parent_chunk) + ``ingest`` : only add_parent_chunk (skip OCDBT populate) + + ``_post_task_completion`` always runs so the layer's progress tracking + in redis stays consistent. + + OCDBT populate runs FIRST so any failure aborts the task BEFORE graph + mutation; otherwise a half-built graph would force corrupt-state retries. + """ imanager = _get_imanager() - add_parent_chunk( - imanager.cg, - parent_layer, - parent_coords, - get_children_chunk_coords( - imanager.cg_meta, + + do_ocdbt = mode in ("full", "ocdbt") and imanager.is_ocdbt_populate_layer( + parent_layer + ) + do_ingest = mode in ("full", "ingest") + + if do_ocdbt: + ws = imanager.cg.meta.data_source.WATERSHED + if not is_chunk_populated(ws, parent_layer, parent_coords): + address = get_coordinator_address(imanager.redis) + populate_chunk( + imanager, ws, parent_layer, parent_coords, coordinator_address=address + ) + + if do_ingest: + add_parent_chunk( + imanager.cg, parent_layer, parent_coords, - ), - ) + get_children_chunk_coords( + imanager.cg_meta, + parent_layer, + parent_coords, + ), + ) + _post_task_completion(imanager, parent_layer, parent_coords) @@ -146,21 +175,6 @@ def create_atomic_chunk(coords: Sequence[int]): for k, v in chunk_edges_active.items(): logger.debug(f"active_{k}: {len(v)}") - if imanager.ocdbt_seg and imanager.ocdbt_populate_base: - # Populate the shared base OCDBT with precomputed chunks (one-time - # per watershed). Uses the raw base handles, NOT the per-CG fork - # spec — the fork only stores SV-split deltas. - src_list, dst_list, resolutions = open_base_ocdbt( - imanager.cg.meta.data_source.WATERSHED - ) - copy_ws_chunk_multiscale( - src_list, - dst_list, - resolutions, - imanager.cg.meta.graph_config.CHUNK_SIZE, - coords, - imanager.cg.meta.voxel_bounds, - ) _post_task_completion(imanager, 2, coords) @@ -172,47 +186,6 @@ def upgrade_atomic_chunk(coords: Sequence[int]): _post_task_completion(imanager, 2, coords) -def convert_edges_to_ocdbt(coords: Sequence[int]): - """ - Convert edges stored per chunk to ajacency list in the tensorstore ocdbt kv store. - """ - imanager = _get_imanager() - coords = np.array(list(coords), dtype=int) - chunk_edges_all, mapping = _get_atomic_chunk_data(imanager, coords) - - node_ids1 = [] - node_ids2 = [] - affinities = [] - areas = [] - for edges in chunk_edges_all.values(): - node_ids1.extend(edges.node_ids1) - node_ids2.extend(edges.node_ids2) - affinities.extend(edges.affinities) - areas.extend(edges.areas) - - edges = Edges(node_ids1, node_ids2, affinities=affinities, areas=areas) - nodes = np.concatenate( - [edges.node_ids1, edges.node_ids2, np.fromiter(mapping.keys(), dtype=NODE_ID)] - ) - nodes = np.unique(nodes) - - chunk_id = imanager.cg.get_chunk_id(layer=1, x=coords[0], y=coords[1], z=coords[2]) - chunk_ids = imanager.cg.get_chunk_ids_from_node_ids(nodes) - - host = imanager.redis.get("OCDBT_COORDINATOR_HOST").decode() - port = imanager.redis.get("OCDBT_COORDINATOR_PORT").decode() - environ["OCDBT_COORDINATOR_HOST"] = host - environ["OCDBT_COORDINATOR_PORT"] = port - logger.note(f"OCDBT Coordinator address {host}:{port}") - - put_edges( - f"{imanager.cg.meta.data_source.EDGES}/ocdbt", - nodes[chunk_ids == chunk_id], - edges, - ) - _post_task_completion(imanager, 2, coords) - - def _get_test_chunks(meta: ChunkedGraphMeta): """Chunks at the center most likely not to be empty""" parent_coords = np.array(meta.layer_chunk_bounds[3]) // 2 diff --git a/pychunkedgraph/ingest/create/atomic_layer.py b/pychunkedgraph/ingest/create/atomic_layer.py index 30043710d..69ea8a709 100644 --- a/pychunkedgraph/ingest/create/atomic_layer.py +++ b/pychunkedgraph/ingest/create/atomic_layer.py @@ -11,6 +11,8 @@ import numpy as np +from pychunkedgraph import get_logger + from ...graph import attributes, basetypes, serializers, get_valid_timestamp from ...graph.chunkedgraph import ChunkedGraph from ...graph.edges import Edges @@ -19,6 +21,8 @@ from ...graph.utils.flatgraph import build_gt_graph from ...graph.utils.flatgraph import connected_components +logger = get_logger(__name__) + def add_atomic_chunk( cg: ChunkedGraph, @@ -28,6 +32,10 @@ def add_atomic_chunk( time_stamp: Optional[datetime.datetime] = None, ): chunk_node_ids, chunk_edge_ids = _get_chunk_nodes_and_edges(chunk_edges_d, isolated) + logger.note( + f"L2 chunk {tuple(coords)}: nodes={len(chunk_node_ids):,} " + f"edges={len(chunk_edge_ids):,}" + ) if not chunk_node_ids.size: return diff --git a/pychunkedgraph/ingest/create/parent_layer.py b/pychunkedgraph/ingest/create/parent_layer.py index a12d2b858..1d36f9edd 100644 --- a/pychunkedgraph/ingest/create/parent_layer.py +++ b/pychunkedgraph/ingest/create/parent_layer.py @@ -12,6 +12,9 @@ import fastremap import numpy as np + +from pychunkedgraph import get_logger + from ...graph import types, attributes, basetypes, serializers, get_valid_timestamp from ...utils.general import chunked from ...graph.utils import flatgraph @@ -22,6 +25,8 @@ from .cross_edges import get_children_chunk_cross_edges from .cross_edges import get_chunk_nodes_cross_edge_layer +logger = get_logger(__name__) + def add_parent_chunk( cg: ChunkedGraph, @@ -50,6 +55,11 @@ def add_parent_chunk( raw_ccs = flatgraph.connected_components(graph) # connected components with indices connected_components = [graph_ids[cc] for cc in raw_ccs] + logger.note( + f"L{layer_id} chunk {tuple(coords)}: nodes={len(connected_components):,} " + f"cx_edges={len(cx_edges):,}" + ) + _write_connected_components( cg, layer_id, diff --git a/pychunkedgraph/ingest/manager.py b/pychunkedgraph/ingest/manager.py index 915538320..566558e05 100644 --- a/pychunkedgraph/ingest/manager.py +++ b/pychunkedgraph/ingest/manager.py @@ -15,8 +15,7 @@ def __init__( self, config: IngestConfig, chunkedgraph_meta: ChunkedGraphMeta, - ocdbt_seg: bool = False, - ocdbt_populate_base: bool = False, + ocdbt_config: dict = None, _from_pickle: bool = False, ): self._config = config @@ -25,8 +24,7 @@ def __init__( self._redis = None self._task_queues = {} self._from_pickle = _from_pickle - self.ocdbt_seg = ocdbt_seg - self.ocdbt_populate_base = ocdbt_populate_base + self.ocdbt_config = ocdbt_config or {} if not _from_pickle: # initiate redis and store serialized state @@ -55,12 +53,34 @@ def redis(self): self._redis.set(r_keys.INGESTION_MANAGER, self.serialized(pickled=True)) return self._redis + @property + def ocdbt_seg(self) -> bool: + return bool(self.ocdbt_config.get("enabled")) + + @property + def ocdbt_populate_base(self) -> bool: + return bool(self.ocdbt_config.get("populate_base")) + + @property + def ocdbt_populate_layer(self) -> int: + return int(self.ocdbt_config.get("populate_layer", 3)) + + def is_ocdbt_populate_layer(self, layer: int) -> bool: + """True iff OCDBT is enabled, base-populate is on, AND the given + layer matches the configured populate layer. Single guard for any + code that branches on 'should this layer touch OCDBT?'. + """ + return ( + self.ocdbt_seg + and self.ocdbt_populate_base + and layer == self.ocdbt_populate_layer + ) + def serialized(self, pickled=False): params = { "config": self._config, "chunkedgraph_meta": self._chunkedgraph_meta, - "ocdbt_seg": self.ocdbt_seg, - "ocdbt_populate_base": self.ocdbt_populate_base, + "ocdbt_config": self.ocdbt_config, } if pickled: return pickle.dumps(params) diff --git a/pychunkedgraph/ingest/ocdbt.py b/pychunkedgraph/ingest/ocdbt.py new file mode 100644 index 000000000..a8b10666f --- /dev/null +++ b/pychunkedgraph/ingest/ocdbt.py @@ -0,0 +1,128 @@ +"""OCDBT-specific ingest helpers. + +Single home for everything OCDBT-related at the ingest layer: + * coordinator-server lifecycle (`coordinator`) + * per-chunk populate task (`populate_chunk`), used from `create_parent_chunk` + * shared base setup (`setup_base`), used by both ingest and upgrade CLIs +""" + +from contextlib import contextmanager +from os import environ + +import tensorstore as ts + +from pychunkedgraph import get_logger + +from ..graph.ocdbt import ( + OcdbtConfig, + _layer_bbox, + base_exists, + copy_ws_bbox_multiscale, + create_base_ocdbt, + fork_base_manifest, + mark_chunk_populated, + open_base_ocdbt, + read_populate_meta, + write_populate_meta, +) + +logger = get_logger(__name__) + +_COORD_HOST_KEY = "OCDBT_COORDINATOR_HOST" +_COORD_PORT_KEY = "OCDBT_COORDINATOR_PORT" + + +@contextmanager +def coordinator(redis): + """Start a ``DistributedCoordinatorServer`` and advertise its address in + Redis so parallel populate workers route every OCDBT commit through this + one server — no manifest-CAS races, no orphan ``d/`` files. + + The server lives as long as the ``with`` block does; on exit the Redis + advertisement is cleared so a stale address can't outlive the server. + Caller blocks inside the ``with`` body (e.g. ``while True: sleep(60)``) + to keep the server reference alive across the populate phase. + """ + server = ts.ocdbt.DistributedCoordinatorServer() + host = environ.get("MY_POD_IP", "localhost") + redis.set(_COORD_HOST_KEY, host) + redis.set(_COORD_PORT_KEY, str(server.port)) + logger.note(f"OCDBT Coordinator listening at {host}:{server.port}") + try: + yield server + finally: + redis.delete(_COORD_HOST_KEY, _COORD_PORT_KEY) + logger.note("OCDBT Coordinator advertisement cleared.") + + +def get_coordinator_address(redis) -> str: + """Return the advertised ``"host:port"`` for the OCDBT coordinator. + + The address goes into the OCDBT kvstore spec's ``coordinator`` field — + the only routing knob tensorstore actually honors (verified against + the tensorstore binary; ``OCDBT_COORDINATOR_HOST/PORT`` env vars are + not consulted). + + Distributed callers MUST go through this getter so the populate fails + loudly when the coordinator isn't advertised — uncoordinated parallel + commits race the shared manifest and leak orphan ``d/`` files, the + exact bug this code exists to prevent. + """ + host = redis.get(_COORD_HOST_KEY) + port = redis.get(_COORD_PORT_KEY) + if not host or not port: + raise RuntimeError( + "OCDBT coordinator address not advertised in Redis " + f"({_COORD_HOST_KEY}/{_COORD_PORT_KEY} unset). " + "Run `flask ingest layer N` (with N == ocdbt_populate_layer) to " + "start the coordinator before queuing populate workers." + ) + return f"{host.decode()}:{port.decode()}" + + +def populate_chunk( + imanager, ws: str, layer: int, coords, coordinator_address: str | None = None +) -> None: + """One LN parent-layer task's OCDBT populate. + + When ``coordinator_address`` is set, every commit routes through that + server (mandatory for distributed workers — see ``get_coordinator_address``). + Single-process callers (notebooks, local one-off runs) can omit it and + write directly; safe as long as no other writer is committing concurrently. + + Copies the base-resolution bbox at every scale under one atomic + transaction and records the per-chunk completion marker. + """ + cfg = OcdbtConfig.from_dict(imanager.ocdbt_config) + src_list, dst_list, resolutions = open_base_ocdbt( + ws, cfg, coordinator_address=coordinator_address + ) + lo, hi = _layer_bbox(imanager.cg.meta, layer, coords) + coord_str = "_".join(str(int(c)) for c in coords) + dump_tag = f"{imanager.cg.meta.graph_id}/L{layer}/{coord_str}" + logger.note(f"L{layer} OCDBT populate {tuple(int(c) for c in coords)}") + copy_ws_bbox_multiscale(src_list, dst_list, resolutions, lo, hi, dump_tag=dump_tag) + mark_chunk_populated(ws, layer, coords) + + +def setup_base(cg, ocdbt_cfg: OcdbtConfig) -> OcdbtConfig: + """Idempotent OCDBT base + fork setup, shared by ingest and upgrade. + + Creates the base if missing; reconciles the yaml/CLI-supplied config + with the on-disk populate_meta (info-file wins per + ``OcdbtConfig.resolve``); persists the resolved config to + ``cg.meta.custom_data["ocdbt_config"]``; forks the manifest for this + CG. Returns the resolved OcdbtConfig. To wipe and start over, use + ``gcloud storage rm -r gs:///ocdbt/`` before invoking. + """ + ws = cg.meta.data_source.WATERSHED + if not base_exists(ws): + create_base_ocdbt(ws, ocdbt_cfg) + info = read_populate_meta(ws) + resolved = OcdbtConfig.resolve(ocdbt_cfg.to_dict(), info) + if resolved.populate_base: + write_populate_meta(ws, resolved.to_dict()) + cg.meta.custom_data["ocdbt_config"] = resolved.to_dict() + cg.update_meta(cg.meta, overwrite=True) + fork_base_manifest(ws, cg.meta.graph_id) + return resolved diff --git a/pychunkedgraph/ingest/utils.py b/pychunkedgraph/ingest/utils.py index d69756104..fbd55dedc 100644 --- a/pychunkedgraph/ingest/utils.py +++ b/pychunkedgraph/ingest/utils.py @@ -1,29 +1,45 @@ # pylint: disable=invalid-name, missing-docstring import functools - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) -import math, random, sys +import math +import sys from os import environ from time import sleep -from typing import Any, Generator, Tuple +from typing import Dict, Generator, Tuple import numpy as np -import tensorstore as ts -from rq import Queue, Retry, Worker -from rq.worker import WorkerStatus +from kvdbclient import BigTableConfig, HBaseConfig +from rich import box +from rich.console import Group +from rich.live import Live +from rich.panel import Panel +from rich.rule import Rule +from rich.table import Table +from rich.text import Text +from rq import Queue, Retry +from rq.registry import ( + CanceledJobRegistry, + DeferredJobRegistry, + FailedJobRegistry, + FinishedJobRegistry, + ScheduledJobRegistry, + StartedJobRegistry, +) +from rq.worker_registration import WORKERS_BY_QUEUE_KEY + +from pychunkedgraph import get_logger from . import IngestConfig from .manager import IngestionManager -from ..graph.meta import ChunkedGraphMeta, DataSource, GraphConfig from ..graph import BackendClientInfo -from kvdbclient import BigTableConfig, HBaseConfig +from ..graph.meta import ChunkedGraphMeta, DataSource, GraphConfig +from ..graph.ocdbt import OcdbtConfig from ..utils.general import chunked from ..utils.redis import get_redis_connection from ..utils.redis import keys as r_keys +logger = get_logger(__name__) + chunk_id_str = lambda layer, coords: f"{layer}_{'_'.join(map(str, coords))}" @@ -32,8 +48,13 @@ def bootstrap( config: dict, raw: bool = False, test_run: bool = False, -) -> Tuple[ChunkedGraphMeta, IngestConfig, BackendClientInfo]: - """Parse config loaded from a yaml file.""" +) -> Tuple[ChunkedGraphMeta, IngestConfig, BackendClientInfo, Dict]: + """Parse config loaded from a yaml file. + + Returns ``(meta, ingest_config, client_info, ocdbt_config_dict)`` where the + ocdbt config dict is sanitized through ``OcdbtConfig.from_dict(...).to_dict()`` + so unknown yaml keys are dropped and missing fields take dataclass defaults. + """ ingest_config = IngestConfig( **config.get("ingest_config", {}), USE_RAW_EDGES=raw, @@ -55,7 +76,8 @@ def bootstrap( data_source = DataSource(**config["data_source"]) meta = ChunkedGraphMeta(graph_config, data_source) - return (meta, ingest_config, client_info) + ocdbt_config_dict = OcdbtConfig.from_dict(config.get("ocdbt_config")).to_dict() + return (meta, ingest_config, client_info, ocdbt_config_dict) def move_up(lines: int = 1): @@ -95,16 +117,6 @@ def postprocess_edge_data(im, edge_dict): raise ValueError(f"Unknown data_version: {data_version}") -def start_ocdbt_server(imanager: IngestionManager, server: Any): - spec = {"driver": "ocdbt", "base": f"{imanager.cg.meta.data_source.EDGES}/ocdbt"} - spec["coordinator"] = {"address": f"localhost:{server.port}"} - ts.KvStore.open(spec).result() - imanager.redis.set("OCDBT_COORDINATOR_PORT", str(server.port)) - ocdbt_host = environ.get("MY_POD_IP", "localhost") - imanager.redis.set("OCDBT_COORDINATOR_HOST", ocdbt_host) - logger.note(f"OCDBT Coordinator address {ocdbt_host}:{server.port}") - - def randomize_grid_points(X: int, Y: int, Z: int) -> Generator[int, int, int]: indices = np.arange(X * Y * Z) np.random.shuffle(indices) @@ -148,64 +160,237 @@ def print_completion_rate(imanager: IngestionManager, layer: int, span: int = 30 move_up() -def print_status(imanager: IngestionManager, redis, upgrade: bool = False): +def _workers_busy_per_queue(redis, worker_keys_per_layer): + """For each layer's set of worker keys, return parallel (workers, busy) + string lists — "-" / "-" when no workers are registered for that layer. + + Two-round-trip approach: caller already fetched the SMEMBERS sets; this + function pipelines HGET state for every worker key and counts busy. + """ + state_pipe = redis.pipeline() + for keys in worker_keys_per_layer: + for wk in keys: + state_pipe.hget(wk, "state") + states = state_pipe.execute() if any(worker_keys_per_layer) else [] + + workers, busy = [], [] + idx = 0 + for keys in worker_keys_per_layer: + total = len(keys) + b = 0 + for _ in keys: + if states[idx] == b"busy": + b += 1 + idx += 1 + workers.append(f"{total}" if total else "-") + busy.append(f"{b}" if total else "-") + return workers, busy + + +def _layer_keys(layers) -> list: + """Stable per-layer redis keys (completed-set, queue list, failed zset, workers set). + + Returned once before the refresh loop so each refresh skips Queue / + FailedJobRegistry construction and the lazy rq.registry import. """ - Helper to print status to console. + return [ + ( + f"{layer}c", + f"rq:queue:l{layer}", + f"rq:failed:l{layer}", + WORKERS_BY_QUEUE_KEY % f"l{layer}", + ) + for layer in layers + ] + + +def _layer_status(redis, layer_keys): + """Pipelined fetch of job_type + per-layer counts + busy-worker ratios.""" + pipeline = redis.pipeline() + pipeline.get(r_keys.JOB_TYPE) + for completed_key, queue_key, failed_key, workers_key in layer_keys: + pipeline.scard(completed_key) + pipeline.llen(queue_key) + pipeline.zcard(failed_key) + pipeline.smembers(workers_key) + results = pipeline.execute() + + job_type = results[0].decode() if results[0] else "not_available" + completed, queued, failed, worker_keys_per_layer = [], [], [], [] + for i in range(1, len(results), 4): + completed.append(results[i]) + queued.append(results[i + 1]) + failed.append(results[i + 2]) + worker_keys_per_layer.append(results[i + 3]) + + workers, busy = _workers_busy_per_queue(redis, worker_keys_per_layer) + return job_type, completed, queued, failed, workers, busy + + +def _sized_table(columns: list, rows: list, **table_kwargs) -> Table: + """Build a Rich Table whose column widths are sized to the actual data. + + `columns` is a list of (name, justify) tuples. + `rows` is a list of tuples of cell strings (one per column). + Each column gets width = max(len(name), max(len(cell)) over rows) so Rich + never wraps or crops because no column is implicitly squeezed. + """ + table = Table( + box=None, + pad_edge=False, + padding=(0, 2), + show_header=True, + header_style="bold", + **table_kwargs, + ) + for col_idx, (name, justify) in enumerate(columns): + width = max(len(name), max((len(row[col_idx]) for row in rows), default=0)) + # Header wrapped in Text so any brackets in `name` render literally + # rather than being parsed as Rich markup tags. + table.add_column( + Text(name, style="bold"), justify=justify, width=width, no_wrap=True + ) + for row in rows: + table.add_row(*row) + return table + + +def _aligned_kv_table(pairs: list, widths: list) -> Table: + """One-data-row mini-table with externally-provided per-column widths.""" + table = Table( + box=None, pad_edge=False, padding=(0, 1), show_header=True, header_style="bold" + ) + for (name, _), w in zip(pairs, widths): + table.add_column(name, justify="left", width=w, no_wrap=True) + table.add_row(*(v for _, v in pairs)) + return table + + +def _header_renderables(imanager: IngestionManager) -> list: + """Graph and ocdbt rows as mini-tables sharing column widths so columns line up.""" + graph_pairs = [ + ("version", str(imanager.cg.version)), + ("graph_id", imanager.cg.graph_id), + ("chunk_size", str(imanager.cg.meta.graph_config.CHUNK_SIZE)), + ] + ocdbt_pairs = [] + if imanager.ocdbt_seg: + ocdbt_pairs = [ + ("ocdbt", str(imanager.ocdbt_seg)), + ("populate_base", str(imanager.ocdbt_populate_base)), + ("populate_layer", str(imanager.ocdbt_populate_layer)), + ] + + # Per-column width = max length seen in EITHER row's header or value at that index. + n = max(len(graph_pairs), len(ocdbt_pairs)) + widths = [] + for i in range(n): + sizes = [] + if i < len(graph_pairs): + sizes.append(len(graph_pairs[i][0])) + sizes.append(len(graph_pairs[i][1])) + if i < len(ocdbt_pairs): + sizes.append(len(ocdbt_pairs[i][0])) + sizes.append(len(ocdbt_pairs[i][1])) + widths.append(max(sizes)) + + out = [_aligned_kv_table(graph_pairs, widths)] + if ocdbt_pairs: + out.append(Rule(style="dim")) + out.append(_aligned_kv_table(ocdbt_pairs, widths)) + return out + + +def _status_table( + layers, layer_counts, completed, queued, failed, workers, busy +) -> Table: + """One row per layer with progress, queue, and worker stats.""" + columns = [ + ("layer", "center"), + ("queued", "right"), + ("completed", "right"), + ("total", "right"), + ("progress", "right"), + ("failed", "right"), + ("workers", "right"), + ("busy", "right"), + ] + rows = [] + for layer, done, count, q, f, w, b in zip( + layers, completed, layer_counts, queued, failed, workers, busy + ): + pct = math.floor((done / count) * 100) if count else 0 + rows.append( + ( + str(layer), + f"{q:,}", + f"{done:,}", + f"{count:,}", + f"{pct}%", + f"{f:,}", + str(w), + str(b), + ) + ) + return _sized_table(columns, rows) + + +def _status_renderable( + imanager, + layers, + layer_counts, + job_type, + completed, + queued, + failed, + workers, + busy, +): + """Combine header rows + per-layer table inside one Panel; job_type goes in the title.""" + body = Group( + *_header_renderables(imanager), + Rule(style="dim"), + _status_table(layers, layer_counts, completed, queued, failed, workers, busy), + ) + return Panel( + body, + title=job_type, + title_align="left", + box=box.ROUNDED, + padding=(0, 1), + expand=False, + ) + + +def print_status( + imanager: IngestionManager, + redis, + upgrade: bool = False, + refresh_seconds: int = 5, +): + """ + Print status to console. If `upgrade=True`, status does not include the root layer, since there is no need to update cross edges for root ids. + `refresh_seconds` is how often redis is re-polled between redraws. """ layers = range(2, imanager.cg_meta.layer_count + 1) if upgrade: layers = range(2, imanager.cg_meta.layer_count) - - def _refresh_status(): - pipeline = redis.pipeline() - pipeline.get(r_keys.JOB_TYPE) - worker_busy = ["-"] * len(layers) - for layer in layers: - pipeline.scard(f"{layer}c") - queue = Queue(f"l{layer}", connection=redis) - pipeline.llen(queue.key) - pipeline.zcard(queue.failed_job_registry.key) - - results = pipeline.execute() - job_type = "not_available" - if results[0] is not None: - job_type = results[0].decode() - completed = [] - queued = [] - failed = [] - for i in range(1, len(results), 3): - result = results[i : i + 3] - completed.append(result[0]) - queued.append(result[1]) - failed.append(result[2]) - return job_type, completed, queued, failed, worker_busy - - job_type, completed, queued, failed, worker_busy = _refresh_status() - layer_counts = imanager.cg_meta.layer_chunk_counts - header = ( - f"\njob_type: \t{job_type}" - f"\nversion: \t{imanager.cg.version}" - f"\ngraph_id: \t{imanager.cg.graph_id}" - f"\nchunk_size: \t{imanager.cg.meta.graph_config.CHUNK_SIZE}" - "\n\nlayer status:" - ) - print(header) - while True: - for layer, done, count in zip(layers, completed, layer_counts): - print( - f"{layer}\t| {done:9} / {count} \t| {math.floor((done/count)*100):6}%" - ) + layer_keys = _layer_keys(layers) - print("\n\nqueue status:") - for layer, q, f, wb in zip(layers, queued, failed, worker_busy): - print(f"l{layer}\t| queued: {q:<10} failed: {f:<10} busy: {wb}") + def render(): + return _status_renderable( + imanager, layers, layer_counts, *_layer_status(redis, layer_keys) + ) - sleep(1) - _, completed, queued, failed, worker_busy = _refresh_status() - move_up(lines=2 * len(layers) + 3) + # Start Live with a placeholder so the panel paints instantly; the first + # real fetch (which includes redis connection setup) replaces it. + with Live(Text("loading…"), screen=False) as live: + while True: + live.update(render()) + sleep(refresh_seconds) def queue_layer_helper( @@ -267,6 +452,54 @@ def queue_layer_helper( logger.note(f"Queued {len(job_datas)} chunks.") +_RQ_REGISTRY_CLASSES = ( + FailedJobRegistry, + StartedJobRegistry, + DeferredJobRegistry, + ScheduledJobRegistry, + FinishedJobRegistry, + CanceledJobRegistry, +) + + +def purge_layer_state(redis, layer: int) -> None: + """Reset per-layer state so a layer can be re-run from a previous + layer's backup: drop the RQ queue (deletes jobs too), wipe each RQ + registry by its own ``.key`` attribute (so we don't hardcode RQ's + internal key naming), and clear the pychunkedgraph completion set + ``f"{layer}c"``. + """ + name = f"l{layer}" + Queue(name=name, connection=redis).delete(delete_jobs=True) + for cls in _RQ_REGISTRY_CLASSES: + redis.delete(cls(name=name, connection=redis).key) + redis.delete(f"{layer}c") + + +def requeue_chunk(queue_name: str, chunk_info, atomic_fn, parent_fn): + """Body of the ``chunk`` CLI command (shared by ingest and upgrade). + + Loads the manager from Redis, dispatches ``atomic_fn`` for L2 or + ``parent_fn`` for L3+, and enqueues a single task with the standard + job_id / timeout convention. + """ + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + layer, coords = chunk_info[0], chunk_info[1:] + if layer == 2: + fn, args = atomic_fn, (coords,) + else: + fn, args = parent_fn, (layer, coords) + queue = imanager.get_task_queue(queue_name) + queue.enqueue( + fn, + job_id=chunk_id_str(layer, coords), + job_timeout=f"{int(layer * layer)}m", + result_ttl=0, + args=args, + ) + + def job_type_guard(job_type: str): def decorator_job_type_guard(func): @functools.wraps(func) diff --git a/pychunkedgraph/repair/stuck_ops.py b/pychunkedgraph/repair/stuck_ops.py new file mode 100644 index 000000000..4f22b508f --- /dev/null +++ b/pychunkedgraph/repair/stuck_ops.py @@ -0,0 +1,290 @@ +"""Operator recovery for SV-split ops that crashed mid-write. + +A crash inside `IndefiniteL2ChunkLock`'s critical section leaves the +per-chunk `Concurrency.IndefiniteLock` cells set *and* records the +chunk scope on the op-log row's `OperationLogs.L2ChunkLockScope` field. +Ops on other (non-overlapping) chunks continue to succeed and advance +the OCDBT manifest while the stuck op sits there blocking its own +chunks. + +Recovery = cleanup + replay. The cleanup step reverts the stuck op's +partial OCDBT writes by copying pre-op voxel values (read from a +version-pinned OCDBT handle at the op's `OperationTimeStamp`) back to +the latest manifest. The replay then runs the op normally via the +existing `repair.edits.repair_operation` path — reads latest (clean on +the stuck op's chunks, current on everyone else's), writes fresh SV +IDs, and `IndefiniteL2ChunkLock`'s privileged-mode exit deletes the +crashed op's pre-existing cells. + +See `docs/sv_splitting_recovery.md` for the full architecture and +correctness argument. +""" + +import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timedelta, timezone + +import numpy as np + +from pychunkedgraph import get_logger +from pychunkedgraph.graph import ChunkedGraph, attributes +from pychunkedgraph.graph.chunks.utils import get_chunk_coordinates +from pychunkedgraph.graph.locks import _l2_chunk_lock_row_key +from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt +from pychunkedgraph.repair.edits import repair_operation + +logger = get_logger(__name__) + + +def _operation_ts_to_pin(operation_ts: datetime) -> str: + """Convert an op-log `OperationTimeStamp` to the OCDBT `version` + string format — ISO-8601 UTC with `Z` suffix, microsecond + precision. OCDBT's binder rejects `+00:00`. + """ + if operation_ts.tzinfo is None: + operation_ts = operation_ts.replace(tzinfo=timezone.utc) + else: + operation_ts = operation_ts.astimezone(timezone.utc) + return operation_ts.isoformat().replace("+00:00", "Z") + + +def _chunk_voxel_slices(cg: ChunkedGraph, chunk_id: int) -> tuple: + """Voxel-space slice tuple covering one L2 chunk, clipped to volume bounds.""" + coords = get_chunk_coordinates(cg.meta, np.uint64(chunk_id)) + chunk_size = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) + voxel_bounds = cg.meta.voxel_bounds + lo = coords * chunk_size + voxel_bounds[:, 0] + hi = np.minimum(lo + chunk_size, voxel_bounds[:, 1]) + return tuple(slice(int(s), int(e)) for s, e in zip(lo, hi)) + + +def list_stuck(cg: ChunkedGraph, min_age: timedelta = timedelta(minutes=10)) -> list: + """Return op-log entries whose `L2ChunkLockScope` is set past `min_age`, + excluding successfully-completed ops. + + The authoritative signal for a stuck op is "scope recorded" — + `IndefiniteL2ChunkLock.__enter__` writes it before any seg/bigtable + write and its clean `__exit__` clears it. An op whose scope is + still populated is either a worker crash (Status=CREATED, Fix 1's + `__exit__` short-circuit never ran) or an exception during the + persist block (Status=EXCEPTION, Fix 1 held the cells on the way + out). Either way it's still holding `Concurrency.IndefiniteLock` + cells on the listed chunks and blocking any new op that overlaps. + + Ops that reach `SUCCESS` normally have scope cleared — we defensively + filter them out in case `_clear_scope_on_op_log`'s best-effort write + failed and logged. Failed ops that never touched the persist block + (e.g. a PreconditionError from multicut) have no scope and don't + show up here; they're not blocking anything. + """ + now = datetime.now(timezone.utc) + cutoff = now - min_age + entries = cg.client.read_log_entries() + stuck = [] + success_code = attributes.OperationLogs.StatusCodes.SUCCESS.value + for op_id, entry in entries.items(): + scope = entry.get(attributes.OperationLogs.L2ChunkLockScope) + if scope is None or len(scope) == 0: + continue + if entry.get(attributes.OperationLogs.Status) == success_code: + continue + op_ts = entry.get(attributes.OperationLogs.OperationTimeStamp) + if op_ts is None: + continue + if op_ts.tzinfo is None: + op_ts = op_ts.replace(tzinfo=timezone.utc) + if op_ts > cutoff: + continue + stuck.append( + { + "op_id": int(op_id), + "operation_ts": op_ts, + "age": now - op_ts, + "user_id": entry.get(attributes.OperationLogs.UserID), + "l2_chunk_scope": scope, + "status": entry.get(attributes.OperationLogs.Status), + } + ) + stuck.sort(key=lambda r: r["op_id"]) + return stuck + + +def cleanup_partial_writes(cg: ChunkedGraph, op_id: int) -> int: + """Revert a stuck op's partial OCDBT writes to pre-op voxel values. + + Reads each chunk in the op's `L2ChunkLockScope` through an OCDBT + handle pinned at the op's `OperationTimeStamp` (which pre-dates any + of its commits), then writes those pre-op values back to the latest + manifest. Overwrites the crashed op's partial seg writes at the + same chunk keys; neighbor chunks are untouched, preserving any + concurrent ops' updates. + + Returns the number of chunks rewritten. + """ + log_entries = cg.client.read_log_entries(operation_ids=[np.uint64(op_id)]) + if not log_entries: + raise ValueError(f"No op-log row for op_id={op_id}") + entry = log_entries[np.uint64(op_id)] + + scope = entry.get(attributes.OperationLogs.L2ChunkLockScope) + if scope is None or len(scope) == 0: + logger.info(f"op {op_id} has no L2ChunkLockScope — nothing to clean up") + return 0 + + operation_ts = entry.get(attributes.OperationLogs.OperationTimeStamp) + if operation_ts is None: + raise ValueError(f"op {op_id} has no OperationTimeStamp") + pin_str = _operation_ts_to_pin(operation_ts) + + # Pinned read handle (read-only at pre-op version) vs. unpinned + # write handle (latest). Tensorstore refuses writes on version-pinned + # kvstores, so the two paths use separate handles. + _, pinned_scales, _ = get_seg_source_and_destination_ocdbt( + cg.meta.data_source.WATERSHED, + cg.meta.graph_id, + cg.meta.ocdbt_config, + pinned_at=pin_str, + ) + pinned_ws = pinned_scales[0] + latest_ws = cg.meta.ws_ocdbt + + def _revert_chunk(chunk_id: int) -> None: + voxel_slices = _chunk_voxel_slices(cg, int(chunk_id)) + pre_op = pinned_ws[voxel_slices + (slice(None),)].read().result() + latest_ws[voxel_slices + (slice(None),)].write(pre_op).result() + + # Parallel read-then-write per chunk. Bounded pool so large scopes + # don't saturate tensorstore's internal concurrency. + max_workers = min(16, max(1, len(scope))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_revert_chunk, int(c)) for c in scope] + for future in as_completed(futures): + future.result() + + logger.info(f"op {op_id}: reverted {len(scope)} partial chunk writes") + return len(scope) + + +def _verify_indefinite_cells(cg: ChunkedGraph, op_id: int, scope) -> list: + """Check each chunk in `scope` actually has `Concurrency.IndefiniteLock` + held by `op_id`. Returns the list of chunk IDs whose cell is missing + or held by a different op_id — an empty list means everything is + consistent. + + Guards `replay` against acting on a stale scope: if cells aren't + actually held (operator already ran replay, manual intervention, + any bug that released cells without clearing scope), `cleanup_ + partial_writes` would revert chunks that another op may have + legitimately written to in the meantime. Refusing loudly is safer + than assuming. + """ + lock_column = attributes.Concurrency.IndefiniteLock + expected = np.uint64(op_id) + discrepancies = [] + for chunk_id in scope: + row_key = _l2_chunk_lock_row_key(int(chunk_id)) + cells = cg.client._read_byte_row(row_key, columns=lock_column) + if not cells: + discrepancies.append(int(chunk_id)) + continue + held_by = cells[0].value if hasattr(cells[0], "value") else None + if held_by != expected: + discrepancies.append(int(chunk_id)) + return discrepancies + + +def replay(cg: ChunkedGraph, op_id: int): + """Recovery: verify locks, clean up partial OCDBT writes, then run + the op normally. + + Before any destructive step, read back the per-chunk + `Concurrency.IndefiniteLock` cells listed in the op's + `L2ChunkLockScope` and confirm they're still held by `op_id`. If + any are missing or held by another op, raise and do nothing — + proceeding would have `cleanup_partial_writes` revert chunks we + don't actually own. + + On clean verification, `cleanup_partial_writes` reverts the op's + partial OCDBT writes, then `repair.edits.repair_operation` reruns + `operation.execute(..., privileged_mode=True, parent_ts=)`. `IndefiniteL2ChunkLock.__enter__` in privileged mode + populates `acquired_keys` from the scope so `__exit__` releases the + crashed op's pre-existing indefinite cells after the replay writes + land. + """ + log_entries = cg.client.read_log_entries(operation_ids=[np.uint64(op_id)]) + if not log_entries: + raise ValueError(f"No op-log row for op_id={op_id}") + entry = log_entries[np.uint64(op_id)] + scope = entry.get(attributes.OperationLogs.L2ChunkLockScope) + if scope is None or len(scope) == 0: + raise RuntimeError( + f"op {op_id} has no L2ChunkLockScope — not a stuck SV-split op. " + "If the op failed cleanly, the client should re-submit under a " + "fresh op_id rather than replay." + ) + + mismatched = _verify_indefinite_cells(cg, op_id, scope) + if mismatched: + raise RuntimeError( + f"op {op_id}: L2ChunkLockScope lists chunks {[int(c) for c in scope]}, " + f"but the following chunks do not have Concurrency.IndefiniteLock " + f"held by op_id={op_id}: {mismatched}. Refusing to replay — the " + "recorded scope disagrees with live lock state. Possible causes: " + "replay already ran, cells were manually cleared, or a different " + "op acquired these chunks. Investigate before retrying." + ) + + cleanup_partial_writes(cg, op_id) + return repair_operation(cg, op_id, unlock=True) + + +def _main(): + parser = argparse.ArgumentParser( + description="Recover stuck SV-split operations via cleanup + replay." + ) + sub = parser.add_subparsers(dest="cmd", required=True) + + p_list = sub.add_parser( + "list", + help="List stuck ops (L2ChunkLockScope still populated past min-age).", + ) + p_list.add_argument("--graph", required=True, help="Graph ID.") + p_list.add_argument( + "--min-age", + type=int, + default=10, + help="Minimum age in minutes before an op is considered stuck (default: 10).", + ) + + p_replay = sub.add_parser( + "replay", help="Clean up partial writes and replay a stuck op." + ) + p_replay.add_argument("--graph", required=True, help="Graph ID.") + p_replay.add_argument("--op-id", type=int, required=True, help="Op ID to replay.") + + args = parser.parse_args() + cg = ChunkedGraph(graph_id=args.graph) + + if args.cmd == "list": + stuck = list_stuck(cg, min_age=timedelta(minutes=args.min_age)) + if not stuck: + print("No stuck ops.") + return + for row in stuck: + scope_size = ( + len(row["l2_chunk_scope"]) if row["l2_chunk_scope"] is not None else 0 + ) + print( + f"op {row['op_id']}: user={row['user_id']} " + f"ts={row['operation_ts'].isoformat()} " + f"age={row['age']} " + f"l2_chunks={scope_size}" + ) + elif args.cmd == "replay": + result = replay(cg, args.op_id) + print(f"replay complete: {result}") + + +if __name__ == "__main__": + _main() diff --git a/pychunkedgraph/tests/graph/test_cutting.py b/pychunkedgraph/tests/graph/test_cutting.py index 89cf4969d..4411d876c 100644 --- a/pychunkedgraph/tests/graph/test_cutting.py +++ b/pychunkedgraph/tests/graph/test_cutting.py @@ -4,8 +4,10 @@ import pytest from pychunkedgraph.graph.cutting import ( + Cut, IsolatingCutException, LocalMincutGraph, + PreviewCut, merge_cross_chunk_edges_graph_tool, run_multicut, ) @@ -336,8 +338,9 @@ def test_basic_split(self): path_augment=True, disallow_isolating_cut=False, ) - assert len(result) > 0 - result_set = set(map(tuple, result)) + assert isinstance(result, Cut) + assert len(result.atomic_edges) > 0 + result_set = set(map(tuple, result.atomic_edges)) assert (2, 3) in result_set or (3, 2) in result_set def test_basic_split_direct(self): @@ -354,8 +357,9 @@ def test_basic_split_direct(self): path_augment=False, disallow_isolating_cut=False, ) - assert len(result) > 0 - result_set = set(map(tuple, result)) + assert isinstance(result, Cut) + assert len(result.atomic_edges) > 0 + result_set = set(map(tuple, result.atomic_edges)) assert (2, 3) in result_set or (3, 2) in result_set def test_no_edges_raises(self): @@ -377,7 +381,7 @@ def test_no_edges_raises(self): ) def test_split_preview_mode(self): - """run_multicut with split_preview=True returns (ccs, illegal_split).""" + """run_multicut with split_preview=True returns a PreviewCut.""" node_ids1 = np.array([1, 2, 3], dtype=np.uint64) node_ids2 = np.array([2, 3, 4], dtype=np.uint64) affinities = np.array([0.9, 0.05, 0.9], dtype=np.float32) @@ -391,10 +395,10 @@ def test_split_preview_mode(self): path_augment=False, disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - assert isinstance(supervoxel_ccs, list) - assert len(supervoxel_ccs) >= 2 - assert isinstance(illegal_split, bool) + assert isinstance(result, PreviewCut) + assert isinstance(result.supervoxel_ccs, list) + assert len(result.supervoxel_ccs) >= 2 + assert isinstance(result.illegal_split, bool) class TestMergeCrossChunkEdgesOverlap: @@ -641,7 +645,7 @@ class TestRunMulticutSplitPreview: """Test run_multicut in split_preview mode returns correct structure.""" def test_split_preview_returns_ccs_and_flag(self): - """run_multicut with split_preview=True should return (ccs, illegal_split).""" + """run_multicut with split_preview=True should return a PreviewCut.""" node_ids1 = np.array([1, 2, 3], dtype=np.uint64) node_ids2 = np.array([2, 3, 4], dtype=np.uint64) affinities = np.array([0.9, 0.01, 0.9], dtype=np.float32) @@ -656,15 +660,15 @@ def test_split_preview_returns_ccs_and_flag(self): disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - assert isinstance(supervoxel_ccs, list) - assert len(supervoxel_ccs) >= 2 - assert isinstance(illegal_split, bool) + assert isinstance(result, PreviewCut) + assert isinstance(result.supervoxel_ccs, list) + assert len(result.supervoxel_ccs) >= 2 + assert isinstance(result.illegal_split, bool) # Source side CC - assert 1 in supervoxel_ccs[0] + assert 1 in result.supervoxel_ccs[0] # Sink side CC - assert 4 in supervoxel_ccs[1] + assert 4 in result.supervoxel_ccs[1] def test_split_preview_with_path_augment(self): """run_multicut with split_preview=True and path_augment=True.""" @@ -682,12 +686,12 @@ def test_split_preview_with_path_augment(self): disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - assert len(supervoxel_ccs) >= 2 + assert isinstance(result, PreviewCut) + assert len(result.supervoxel_ccs) >= 2 # Source side - assert 1 in supervoxel_ccs[0] + assert 1 in result.supervoxel_ccs[0] # Sink side - assert 5 in supervoxel_ccs[1] + assert 5 in result.supervoxel_ccs[1] def test_split_preview_larger_graph(self): """split_preview on a larger graph with a clear cut point.""" @@ -709,14 +713,14 @@ def test_split_preview_larger_graph(self): disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - source_cc = set(supervoxel_ccs[0]) - sink_cc = set(supervoxel_ccs[1]) + assert isinstance(result, PreviewCut) + source_cc = set(result.supervoxel_ccs[0]) + sink_cc = set(result.supervoxel_ccs[1]) # Source cluster assert {1, 2, 3}.issubset(source_cc) # Sink cluster assert {4, 5, 6}.issubset(sink_cc) - assert not illegal_split + assert not result.illegal_split class TestLocalMincutGraphWithLogger: @@ -1040,7 +1044,7 @@ class TestRunSplitPreview: """ def test_basic_split_preview(self): - """run_multicut with split_preview should return CCs and a flag.""" + """run_multicut with split_preview should return a PreviewCut.""" edges_sv = Edges( np.array([1, 2, 3, 4], dtype=np.uint64), np.array([2, 3, 4, 5], dtype=np.uint64), @@ -1049,16 +1053,17 @@ def test_basic_split_preview(self): ) sources = np.array([1], dtype=np.uint64) sinks = np.array([5], dtype=np.uint64) - ccs, illegal_split = run_multicut( + result = run_multicut( edges_sv, sources, sinks, split_preview=True, disallow_isolating_cut=False, ) - assert isinstance(ccs, list) - assert isinstance(illegal_split, bool) - assert len(ccs) >= 2 + assert isinstance(result, PreviewCut) + assert isinstance(result.supervoxel_ccs, list) + assert isinstance(result.illegal_split, bool) + assert len(result.supervoxel_ccs) >= 2 def test_split_preview_with_areas(self): """Split preview with areas provided.""" @@ -1070,7 +1075,7 @@ def test_split_preview_with_areas(self): ) sources = np.array([10], dtype=np.uint64) sinks = np.array([40], dtype=np.uint64) - ccs, illegal_split = run_multicut( + result = run_multicut( edges_sv, sources, sinks, @@ -1078,12 +1083,10 @@ def test_split_preview_with_areas(self): path_augment=False, disallow_isolating_cut=False, ) - assert isinstance(ccs, list) - assert len(ccs) >= 2 - # Source side should contain 10 - assert 10 in ccs[0] - # Sink side should contain 40 - assert 40 in ccs[1] + assert isinstance(result, PreviewCut) + assert len(result.supervoxel_ccs) >= 2 + assert 10 in result.supervoxel_ccs[0] + assert 40 in result.supervoxel_ccs[1] def test_split_preview_path_augment(self): """Split preview with path_augment=True.""" @@ -1094,7 +1097,7 @@ def test_split_preview_path_augment(self): ) sources = np.array([1], dtype=np.uint64) sinks = np.array([6], dtype=np.uint64) - ccs, illegal_split = run_multicut( + result = run_multicut( edges_sv, sources, sinks, @@ -1102,11 +1105,11 @@ def test_split_preview_path_augment(self): path_augment=True, disallow_isolating_cut=False, ) - assert isinstance(ccs, list) - assert len(ccs) >= 2 - assert 1 in ccs[0] - assert 6 in ccs[1] - assert not illegal_split + assert isinstance(result, PreviewCut) + assert len(result.supervoxel_ccs) >= 2 + assert 1 in result.supervoxel_ccs[0] + assert 6 in result.supervoxel_ccs[1] + assert not result.illegal_split class TestFilterGraphCCsWithLogger: diff --git a/pychunkedgraph/tests/graph/test_downsample.py b/pychunkedgraph/tests/graph/test_downsample.py new file mode 100644 index 000000000..2eb799334 --- /dev/null +++ b/pychunkedgraph/tests/graph/test_downsample.py @@ -0,0 +1,309 @@ +"""Tests for pychunkedgraph.graph.downsample.""" + +import shutil +import tempfile +import threading +import time +from types import SimpleNamespace + +import numpy as np +import pytest +import tensorstore as ts + +from pychunkedgraph.graph import downsample as ds +from pychunkedgraph.graph.locks import ( + DownsampleBlockLock, + _downsample_block_lock_row_key, +) +from pychunkedgraph.graph import exceptions +from pychunkedgraph.tests.helpers import ( + RowKeyLockRegistry, + make_cg_with_row_key_lock_registry, +) + + +@pytest.fixture +def local_ocdbt(): + """3-scale file-backed OCDBT store with factor (2,2,1) between scales. + + Matches the fixture in test_ocdbt.py so downsample behaviour can be + exercised end-to-end against real tensorstore handles. + """ + tmpdir = tempfile.mkdtemp() + base = f"file://{tmpdir}/ocdbt/base" + mm = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + + def mk(size, resolution, extra_mm=None): + spec = { + "driver": "neuroglancer_precomputed", + "kvstore": {"driver": "ocdbt", "base": base}, + "scale_metadata": { + "size": size, + "resolution": resolution, + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + "chunk_size": [32, 32, 32], + }, + } + if extra_mm: + spec["multiscale_metadata"] = extra_mm + return ts.open(spec, create=True).result() + + scales = [ + mk([64, 64, 32], [4, 4, 40], extra_mm=mm), + mk([32, 32, 32], [8, 8, 40]), + mk([16, 16, 32], [16, 16, 40]), + ] + resolutions = [[4, 4, 40], [8, 8, 40], [16, 16, 40]] + + yield {"scales": scales, "resolutions": resolutions} + shutil.rmtree(tmpdir) + + +def _make_meta(local_ocdbt_, voxel_bounds=None): + """Minimal ChunkedGraphMeta stand-in with only the attributes downsample reads.""" + scales = local_ocdbt_["scales"] + if voxel_bounds is None: + # Full volume from scale 0. + dom = scales[0].domain + voxel_bounds = np.array( + [ + [dom[0].inclusive_min, dom[0].exclusive_max], + [dom[1].inclusive_min, dom[1].exclusive_max], + [dom[2].inclusive_min, dom[2].exclusive_max], + ], + dtype=int, + ) + return SimpleNamespace( + ws_ocdbt_scales=scales, + ws_ocdbt_resolutions=local_ocdbt_["resolutions"], + voxel_bounds=voxel_bounds, + ) + + +class TestBlockGeometry: + def test_num_output_mips(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + assert ds.num_output_mips(meta) == 2 + + def test_uniform_factor(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + assert ds.uniform_factor(meta) == (2, 2, 1) + + def test_non_uniform_factor_asserts(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + meta.ws_ocdbt_resolutions = [[4, 4, 40], [8, 8, 40], [8, 16, 40]] + with pytest.raises(AssertionError): + ds.uniform_factor(meta) + + def test_block_shape_covers_one_coarsest_chunk(self, local_ocdbt): + # coarsest chunk = 32 mip-2 voxels per axis; factor^2 = (4,4,1). + # Block = 32 * (4,4,1) = (128, 128, 32) base voxels. + meta = _make_meta(local_ocdbt) + assert tuple(ds.block_shape(meta).tolist()) == (128, 128, 32) + + def test_blocks_for_bbox_single(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + # Tiny bbox entirely inside block (0,0,0). + blocks = ds.blocks_for_bbox(meta, [10, 10, 5], [20, 20, 10]) + assert blocks == [(0, 0, 0)] + + def test_blocks_for_bbox_spans_block_boundary(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + # Block shape = (128,128,32). Bbox from (120,0,0) to (200,50,10) + # crosses the x-axis boundary at 128. + blocks = ds.blocks_for_bbox(meta, [120, 0, 0], [200, 50, 10]) + assert blocks == sorted([(0, 0, 0), (1, 0, 0)]) + + def test_block_base_bbox_roundtrip(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + lo, hi = ds.block_base_bbox(meta, (0, 0, 0)) + assert tuple(lo.tolist()) == (0, 0, 0) + assert tuple(hi.tolist()) == (128, 128, 32) + + lo, hi = ds.block_base_bbox(meta, (2, 1, 0)) + assert tuple(lo.tolist()) == (256, 128, 0) + assert tuple(hi.tolist()) == (384, 256, 32) + + +class TestProcessBlockInMemory: + def test_writes_to_every_non_base_scale(self, local_ocdbt): + """Base region intersected by bbox propagates to mip 1 and mip 2.""" + scales = local_ocdbt["scales"] + # Seed base with a constant label. + data = np.full((32, 32, 32), 7, dtype=np.uint64) + scales[0][0:32, 0:32, 0:32, :].write(data[..., np.newaxis]).result() + + meta = _make_meta(local_ocdbt) + # Block (0,0,0) has shape (128,128,32); only its (0..32, 0..32, 0..32) + # subregion has real data — the rest is zeros. + ds.process_block( + meta, (0, 0, 0), [(np.array([0, 0, 0]), np.array([32, 32, 32]))] + ) + + mip1 = scales[1][0:16, 0:16, 0:32, :].read().result() + mip2 = scales[2][0:8, 0:8, 0:32, :].read().result() + assert (mip1 == 7).all() + assert (mip2 == 7).all() + + def test_region_outside_bbox_stays_zero(self, local_ocdbt): + """Mip tiles whose base footprint misses the bbox are not written.""" + scales = local_ocdbt["scales"] + # Seed base with 3 inside the edit bbox only. + edit_data = np.full((16, 16, 16), 3, dtype=np.uint64) + scales[0][0:16, 0:16, 0:16, :].write(edit_data[..., np.newaxis]).result() + + meta = _make_meta(local_ocdbt) + ds.process_block( + meta, (0, 0, 0), [(np.array([0, 0, 0]), np.array([16, 16, 16]))] + ) + + # Tile inside edit: written with label 3. + mip1_inside = scales[1][0:8, 0:8, 0:16, :].read().result() + assert (mip1_inside == 3).all() + # Tile outside edit (far corner of block): still zero. + mip1_outside = scales[1][12:16, 12:16, 16:32, :].read().result() + assert (mip1_outside == 0).all() + + +class TestProcessBlockDispatcher: + def test_selects_in_memory_when_under_budget(self, local_ocdbt, monkeypatch): + """Typical small affected region → in-memory path.""" + calls = {"in_memory": 0, "per_mip": 0} + monkeypatch.setattr( + ds, + "_process_block_in_memory", + lambda *a, **kw: calls.__setitem__("in_memory", calls["in_memory"] + 1), + ) + monkeypatch.setattr( + ds, + "_process_block_per_mip", + lambda *a, **kw: calls.__setitem__("per_mip", calls["per_mip"] + 1), + ) + meta = _make_meta(local_ocdbt) + ds.process_block( + meta, (0, 0, 0), [(np.array([0, 0, 0]), np.array([16, 16, 16]))] + ) + assert calls == {"in_memory": 1, "per_mip": 0} + + def test_selects_per_mip_when_over_budget(self, local_ocdbt, monkeypatch): + """When the base read would exceed budget, the per-mip path runs.""" + calls = {"in_memory": 0, "per_mip": 0} + monkeypatch.setattr( + ds, + "_process_block_in_memory", + lambda *a, **kw: calls.__setitem__("in_memory", calls["in_memory"] + 1), + ) + monkeypatch.setattr( + ds, + "_process_block_per_mip", + lambda *a, **kw: calls.__setitem__("per_mip", calls["per_mip"] + 1), + ) + meta = _make_meta(local_ocdbt) + ds.process_block( + meta, + (0, 0, 0), + [(np.array([0, 0, 0]), np.array([128, 128, 32]))], + memory_budget_bytes=1, # force the fallback + ) + assert calls == {"in_memory": 0, "per_mip": 1} + + +class TestDownsampleBlockRowKey: + def test_length(self): + assert len(_downsample_block_lock_row_key((0, 0, 0))) == 26 + + def test_deterministic(self): + assert _downsample_block_lock_row_key( + (7, 8, 9) + ) == _downsample_block_lock_row_key((7, 8, 9)) + + def test_distinct_coords_distinct_keys(self): + a = _downsample_block_lock_row_key((1, 0, 0)) + b = _downsample_block_lock_row_key((0, 1, 0)) + assert a != b + + def test_hash_prefix_scatters(self): + """Adjacent block coords should not produce adjacent row keys (the whole + point of the hash prefix).""" + # Gather hash prefixes for a line of adjacent coords; they should span + # many distinct first-bytes, not cluster in one byte. + prefixes = {_downsample_block_lock_row_key((i, 0, 0))[0] for i in range(128)} + assert len(prefixes) > 32 + + +class TestDownsampleBlockLock: + def test_acquire_and_release(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + with DownsampleBlockLock(cg, [(0, 0, 0), (1, 0, 0)], np.uint64(42)): + assert len(registry._held) == 2 + assert registry._held == {} + + def test_non_overlapping_concurrent(self): + """Two locks on disjoint block sets can coexist.""" + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + l1 = DownsampleBlockLock(cg, [(0, 0, 0)], np.uint64(1)) + l2 = DownsampleBlockLock(cg, [(5, 5, 5)], np.uint64(2)) + l1.__enter__() + l2.__enter__() + assert len(registry._held) == 2 + l1.__exit__(None, None, None) + l2.__exit__(None, None, None) + assert registry._held == {} + + def test_overlapping_contends(self, monkeypatch): + """Two overlapping acquisitions serialize: second blocks until first releases.""" + # Short backoff so the waiting thread retries quickly after release. + monkeypatch.setattr(DownsampleBlockLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.05) + + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + + l1 = DownsampleBlockLock(cg, [(0, 0, 0)], np.uint64(1)) + l1.__enter__() + + second_entered = threading.Event() + second_failed = threading.Event() + + def second(): + lock = DownsampleBlockLock(cg, [(0, 0, 0)], np.uint64(2)) + try: + lock.__enter__() + second_entered.set() + lock.__exit__(None, None, None) + except exceptions.LockingError: + second_failed.set() + + t = threading.Thread(target=second) + t.start() + time.sleep(0.2) + # l1 is still holding; second should not have entered. + assert not second_entered.is_set() + # Now release; second should succeed on its next retry. + l1.__exit__(None, None, None) + t.join(timeout=2.0) + assert second_entered.is_set() + assert not second_failed.is_set() + assert registry._held == {} + + def test_partial_acquire_released_on_failure(self, monkeypatch): + """If any coord in the set fails to lock, prior ones are released.""" + monkeypatch.setattr(DownsampleBlockLock, "_MAX_ACQUIRE_ATTEMPTS", 2) + monkeypatch.setattr(DownsampleBlockLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.01) + + registry = RowKeyLockRegistry() + # Pre-hold (1,0,0) so the second coord always fails. + registry.lock_by_row_key( + _downsample_block_lock_row_key((1, 0, 0)), np.uint64(99) + ) + + cg = make_cg_with_row_key_lock_registry(registry) + lock = DownsampleBlockLock(cg, [(0, 0, 0), (1, 0, 0)], np.uint64(1)) + with pytest.raises(exceptions.LockingError): + lock.__enter__() + # Only (1,0,0) should remain held, by the pre-existing holder. + assert len(registry._held) == 1 + only_key = next(iter(registry._held)) + assert only_key == _downsample_block_lock_row_key((1, 0, 0)) diff --git a/pychunkedgraph/tests/graph/test_edits_sv.py b/pychunkedgraph/tests/graph/test_edits_sv.py index bced0a070..34a4de109 100644 --- a/pychunkedgraph/tests/graph/test_edits_sv.py +++ b/pychunkedgraph/tests/graph/test_edits_sv.py @@ -6,9 +6,11 @@ from unittest.mock import MagicMock, patch from pychunkedgraph.graph.edits_sv import ( + _coords_bbox, _voxel_crop, _parse_results, copy_parents_and_add_lineage, + plan_sv_splits, ) from pychunkedgraph.graph import attributes, basetypes @@ -237,3 +239,130 @@ def test_operation_id_stored(self): assert val_dict[attributes.OperationLogs.OperationID] == 99 op_id_found = True assert op_id_found + + def test_time_stamp_threaded_to_new_sv_writes(self): + """New-SV writes (FormerIdentity/OperationID on new, NewIdentity + on old) land at `time_stamp`. Parent-copy and Child-list writes + preserve the old cell's timestamp so pre-op readers still see + the old hierarchy. + """ + from datetime import datetime, timezone + + old = np.uint64(10) + new1 = np.uint64(101) + parent = np.uint64(1000) + + old_cell_ts = 42 # old cell's timestamp, preserved on Parent/Child copies + op_ts = datetime(2026, 4, 23, tzinfo=timezone.utc) # op's logical write time + + parent_cells_map = {old: [_FakeCell(parent, timestamp=old_cell_ts)]} + children_cells_map = { + parent: [ + _FakeCell( + np.array([old], dtype=basetypes.NODE_ID), timestamp=old_cell_ts + ) + ] + } + cg = self._make_cg(parent_cells_map, children_cells_map) + + copy_parents_and_add_lineage( + cg, operation_id=7, old_new_map={old: {new1}}, time_stamp=op_ts + ) + + # Classify each mutate_row call by which column it writes. + for call in cg.client.mutate_row.call_args_list: + val_dict = call[0][1] + kw = call[1] + ts = kw.get("time_stamp") + cols = set(val_dict.keys()) + + if attributes.Hierarchy.FormerIdentity in cols: + # New-SV lineage write — should use op's time_stamp. + assert ts == op_ts, f"FormerIdentity write ts={ts}, expected {op_ts}" + elif attributes.Hierarchy.NewIdentity in cols: + # Old-SV NewIdentity write — should use op's time_stamp. + assert ts == op_ts, f"NewIdentity write ts={ts}, expected {op_ts}" + elif attributes.Hierarchy.Parent in cols: + # Copied-parent write — preserves old cell's timestamp. + assert ( + ts == old_cell_ts + ), f"Parent-copy write ts={ts}, expected {old_cell_ts}" + elif attributes.Hierarchy.Child in cols: + # Updated-children write on L2 parent — preserves old timestamp. + assert ( + ts == old_cell_ts + ), f"Child-list write ts={ts}, expected {old_cell_ts}" + + +# ============================================================ +# Tests: _coords_bbox / plan_sv_splits bbox is seed-driven, not rep-driven +# ============================================================ +class TestCoordsBbox: + def _make_cg(self, chunk_size=(64, 64, 64), volume=(1024, 1024, 1024)): + cg = MagicMock() + cg.meta.graph_config.CHUNK_SIZE = list(chunk_size) + cg.meta.voxel_bounds = np.array( + [[0, volume[0]], [0, volume[1]], [0, volume[2]]] + ) + cg.get_chunk_id.side_effect = lambda layer, x, y, z: ( + (layer << 60) | (x << 40) | (y << 20) | z + ) + return cg + + def test_envelope_around_seeds_with_one_chunk_margin(self): + cg = self._make_cg(chunk_size=(64, 64, 64)) + src = np.array([[100, 200, 300]]) + sink = np.array([[150, 250, 350]]) + bbs, bbe = _coords_bbox(cg, src, sink) + # min - chunk_size, max + chunk_size, clipped to volume bounds. + np.testing.assert_array_equal(bbs, np.array([100 - 64, 200 - 64, 300 - 64])) + np.testing.assert_array_equal(bbe, np.array([150 + 64, 250 + 64, 350 + 64])) + + def test_clipped_to_volume_bounds(self): + cg = self._make_cg(chunk_size=(64, 64, 64), volume=(256, 256, 256)) + src = np.array([[10, 10, 10]]) + sink = np.array([[250, 250, 250]]) + bbs, bbe = _coords_bbox(cg, src, sink) + # Lower seed - 64 = -54 → clipped to 0; upper seed + 64 = 314 → clipped to 256. + np.testing.assert_array_equal(bbs, np.array([0, 0, 0])) + np.testing.assert_array_equal(bbe, np.array([256, 256, 256])) + + def test_plan_sv_splits_bbox_independent_of_rep_extent(self): + """The returned per-task bbox follows the seeds, not the rep's + cross-chunk pieces. A rep whose pieces span the whole volume + produces the same tight bbox as a rep with one piece, given the + same src/sink coords. + """ + cg = self._make_cg(chunk_size=(64, 64, 64), volume=(1024, 1024, 1024)) + + # Two source/sink IDs that map to the same rep — the SV-split + # trigger condition. The rep's other pieces (b..z) sit far from + # the seeds. They would have ballooned the old `_rep_bbox`; the + # new `_coords_bbox` ignores them. + rep = np.uint64(1) + sv_remapping = { + np.uint64(10): rep, # src + np.uint64(20): rep, # sink + **{np.uint64(100 + i): rep for i in range(28)}, # 28 distant pieces + } + + source_ids = np.array([10], dtype=basetypes.NODE_ID) + sink_ids = np.array([20], dtype=basetypes.NODE_ID) + source_coords = np.array([[100, 200, 300]]) + sink_coords = np.array([[150, 250, 350]]) + + tasks, _ = plan_sv_splits( + cg, + sv_remapping=sv_remapping, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + ) + assert len(tasks) == 1 + np.testing.assert_array_equal( + tasks[0].bbs, np.array([100 - 64, 200 - 64, 300 - 64]) + ) + np.testing.assert_array_equal( + tasks[0].bbe, np.array([150 + 64, 250 + 64, 350 + 64]) + ) diff --git a/pychunkedgraph/tests/graph/test_locks.py b/pychunkedgraph/tests/graph/test_locks.py index 97da9334c..0d82262a9 100644 --- a/pychunkedgraph/tests/graph/test_locks.py +++ b/pychunkedgraph/tests/graph/test_locks.py @@ -1,10 +1,23 @@ +import threading +import time from time import sleep from datetime import datetime, timedelta, UTC import numpy as np import pytest -from ..helpers import create_chunk, to_label +from ..helpers import ( + RowKeyLockRegistry, + create_chunk, + make_cg_with_row_key_lock_registry, + to_label, +) +from ...graph import attributes, exceptions +from ...graph.locks import ( + IndefiniteL2ChunkLock, + L2ChunkLock, + _l2_chunk_lock_row_key, +) from ...graph.lineage import get_future_root_ids from ...ingest.create.parent_layer import add_parent_chunk @@ -702,6 +715,30 @@ def test_indefiniterootlock_exit_handles_exception(self): # Should not raise lock.__exit__(None, None, None) + def test_indefiniterootlock_exit_holds_on_exception_path(self): + """When `__exit__` is called with a propagating exception, cells + stay held — partial bigtable hierarchy writes may have landed + and further ops must refuse until operator recovery runs. + """ + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100), np.uint64(101)]) + cg.client.lock_roots_indefinitely.return_value = ( + True, + list(root_ids), + [], + ) + + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=defaultdict(list), + ) + lock.__enter__() + lock.__exit__(ValueError, ValueError("boom"), None) + + cg.client.unlock_indefinitely_locked_root.assert_not_called() + class TestIndefiniteRootLockComputesFutureRootIds: def test_indefiniterootlock_computes_future_root_ids(self): @@ -750,3 +787,266 @@ def test_rootlock_as_context_manager(self): assert lock.lock_acquired is True cg.client.unlock_root.assert_called_once() + + +class TestL2ChunkLockRowKey: + def test_length(self): + assert len(_l2_chunk_lock_row_key(0)) == 10 + + def test_deterministic(self): + assert _l2_chunk_lock_row_key(0xDEADBEEF) == _l2_chunk_lock_row_key(0xDEADBEEF) + + def test_distinct_chunks_distinct_keys(self): + assert _l2_chunk_lock_row_key(42) != _l2_chunk_lock_row_key(43) + + def test_hash_prefix_scatters(self): + """Adjacent chunk IDs should not cluster in one first-byte prefix — + that's the whole point of the hash prefix.""" + prefixes = {_l2_chunk_lock_row_key(i)[0] for i in range(256)} + # blake2b over 8 bytes of changing input distributes uniformly. + assert len(prefixes) > 128 + + +class TestL2ChunkLock: + def test_acquire_and_release(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + with L2ChunkLock(cg, [np.uint64(1), np.uint64(2)], np.uint64(42)): + assert len(registry._held) == 2 + assert registry._held == {} + + def test_non_overlapping_concurrent(self): + """Disjoint chunk sets can coexist — no shared row keys.""" + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + l1 = L2ChunkLock(cg, [np.uint64(1)], np.uint64(1)) + l2 = L2ChunkLock(cg, [np.uint64(5)], np.uint64(2)) + l1.__enter__() + l2.__enter__() + assert len(registry._held) == 2 + l1.__exit__(None, None, None) + l2.__exit__(None, None, None) + assert registry._held == {} + + def test_overlapping_contends(self, monkeypatch): + """Two overlapping acquisitions serialize: second blocks until first releases.""" + monkeypatch.setattr(L2ChunkLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.05) + + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + + l1 = L2ChunkLock(cg, [np.uint64(7)], np.uint64(1)) + l1.__enter__() + + second_entered = threading.Event() + second_failed = threading.Event() + + def second(): + lock = L2ChunkLock(cg, [np.uint64(7)], np.uint64(2)) + try: + lock.__enter__() + second_entered.set() + lock.__exit__(None, None, None) + except exceptions.LockingError: + second_failed.set() + + t = threading.Thread(target=second) + t.start() + time.sleep(0.2) + assert not second_entered.is_set() + l1.__exit__(None, None, None) + t.join(timeout=2.0) + assert second_entered.is_set() + assert not second_failed.is_set() + assert registry._held == {} + + def test_partial_acquire_released_on_failure(self, monkeypatch): + """If any chunk in the set fails to lock, prior ones are released.""" + monkeypatch.setattr(L2ChunkLock, "_MAX_ACQUIRE_ATTEMPTS", 2) + monkeypatch.setattr(L2ChunkLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.01) + + registry = RowKeyLockRegistry() + registry.lock_by_row_key(_l2_chunk_lock_row_key(np.uint64(2)), np.uint64(99)) + + cg = make_cg_with_row_key_lock_registry(registry) + lock = L2ChunkLock(cg, [np.uint64(1), np.uint64(2)], np.uint64(1)) + with pytest.raises(exceptions.LockingError): + lock.__enter__() + # Only chunk 2 remains held, by the pre-existing holder. + assert len(registry._held) == 1 + assert next(iter(registry._held)) == _l2_chunk_lock_row_key(np.uint64(2)) + + def test_privileged_mode_skips_acquire(self): + """Replay path: indefinite cells from the crashed op are still + set, so a normal temporal acquire would refuse. Privileged mode + bypasses the acquire entirely — the indefinite cells are the + de-facto lock and the inner `IndefiniteL2ChunkLock(privileged=True)` + releases them on exit. + """ + registry = RowKeyLockRegistry() + # Crashed op's indefinite cells block a normal temporal acquire. + crashed_op = np.uint64(42) + for c in (np.uint64(1), np.uint64(2)): + registry.lock_by_row_key_indefinitely(_l2_chunk_lock_row_key(c), crashed_op) + + cg = make_cg_with_row_key_lock_registry(registry) + + # Normal acquire refuses because indefinite is held. + normal = L2ChunkLock(cg, [np.uint64(1), np.uint64(2)], np.uint64(99)) + with pytest.raises(exceptions.LockingError): + normal.__enter__() + + # Privileged acquire — called from replay with the same op_id as + # the crashed op — skips the acquire and returns cleanly. + priv = L2ChunkLock( + cg, [np.uint64(1), np.uint64(2)], crashed_op, privileged_mode=True + ) + priv.__enter__() + priv.__exit__(None, None, None) + # Indefinite cells still held (privileged-L2ChunkLock doesn't + # touch them — that's IndefiniteL2ChunkLock(privileged=True)'s job). + assert len(registry._held_indefinite) == 2 + + +class TestIndefiniteL2ChunkLock: + """`IndefiniteL2ChunkLock` lifecycle: acquire + scope write on enter, + release + scope clear on exit; privileged mode releases pre-existing + cells left by a crashed op. + """ + + def _scope_mutate_calls(self, cg): + """Extract (row_key, scope_value) from cg.client.mutate_row calls + that set `L2ChunkLockScope`. Lets tests assert on what was written.""" + calls = [] + for call in cg.client.mutate_row.call_args_list: + row_key, val_dict = call[0][:2] + if attributes.OperationLogs.L2ChunkLockScope in val_dict: + calls.append( + (row_key, val_dict[attributes.OperationLogs.L2ChunkLockScope]) + ) + return calls + + def test_enter_writes_scope_and_acquires_cells(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + chunks = [np.uint64(3), np.uint64(1), np.uint64(2)] + op_id = np.uint64(42) + + lock = IndefiniteL2ChunkLock(cg, chunks, op_id) + lock.__enter__() + try: + # Every chunk now has an indefinite cell. + assert len(registry._held_indefinite) == 3 + # Scope written to op-log row; value is the sorted chunk list. + scope_calls = self._scope_mutate_calls(cg) + non_empty = [c for c in scope_calls if len(c[1]) > 0] + assert len(non_empty) == 1 + assert list(non_empty[0][1]) == [1, 2, 3] + finally: + lock.__exit__(None, None, None) + + def test_exit_releases_cells_and_clears_scope(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + chunks = [np.uint64(1), np.uint64(2)] + with IndefiniteL2ChunkLock(cg, chunks, np.uint64(42)): + pass + # Cells released. + assert registry._held_indefinite == {} + # Scope cleared: one write of an empty array to L2ChunkLockScope. + empty_calls = [c for c in self._scope_mutate_calls(cg) if len(c[1]) == 0] + assert len(empty_calls) == 1 + + def test_privileged_mode_releases_preexisting(self): + """Crashed op left indefinite cells under its op_id; the replay + re-enters with privileged_mode=True and the `__exit__` is expected + to delete those pre-existing cells (value-matched by op_id). + """ + registry = RowKeyLockRegistry() + op_id = np.uint64(42) + chunks = [np.uint64(10), np.uint64(20)] + for c in chunks: + assert registry.lock_by_row_key_indefinitely( + _l2_chunk_lock_row_key(c), op_id + ) + assert len(registry._held_indefinite) == 2 + + cg = make_cg_with_row_key_lock_registry(registry) + with IndefiniteL2ChunkLock(cg, chunks, op_id, privileged_mode=True): + # Privileged enter skips acquire, so pre-existing cells persist. + assert len(registry._held_indefinite) == 2 + # Privileged mode does not re-write the scope either; only the + # clear-on-exit writes `L2ChunkLockScope`. + assert self._scope_mutate_calls(cg) == [] + # Exit released the pre-existing cells. + assert registry._held_indefinite == {} + + def test_double_acquire_fails(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + op_a = np.uint64(1) + op_b = np.uint64(2) + with IndefiniteL2ChunkLock(cg, [np.uint64(5)], op_a): + lock_b = IndefiniteL2ChunkLock(cg, [np.uint64(5)], op_b) + with pytest.raises(exceptions.LockingError): + lock_b.__enter__() + # Op A's cell still held. + assert len(registry._held_indefinite) == 1 + + def test_replay_nested_privileged_clears_crashed_cells(self): + """Replay lock-dance against a crashed op's pre-existing cells. + + Simulates what `MulticutOperation._apply` does during replay: + `with L2ChunkLock(privileged=True): with IndefiniteL2ChunkLock( + privileged=True): ...`. Both locks must succeed despite indefinite + cells being pre-held, and the inner `__exit__` must release them. + + This regresses the bug where `L2ChunkLock` lacked a privileged + escape hatch — the temporal acquire would refuse because + `lock_by_row_key_with_indefinite` sees the crashed op's + indefinite cell. + """ + registry = RowKeyLockRegistry() + crashed_op = np.uint64(42) + chunks = [np.uint64(1), np.uint64(2), np.uint64(3)] + # Seed crashed op's indefinite cells. + for c in chunks: + registry.lock_by_row_key_indefinitely(_l2_chunk_lock_row_key(c), crashed_op) + assert len(registry._held_indefinite) == 3 + + cg = make_cg_with_row_key_lock_registry(registry) + + # Replay's exact lock-dance from operation.py _apply. + with L2ChunkLock(cg, chunks, crashed_op, privileged_mode=True): + with IndefiniteL2ChunkLock(cg, chunks, crashed_op, privileged_mode=True): + # Simulated replay writes would happen here; we just + # assert the locks entered without raising. + pass + # Crashed op's cells released. + assert registry._held_indefinite == {} + + def test_exit_holds_on_exception_path(self): + """When `__exit__` is called with a propagating exception, cells + stay held and the op-log scope is NOT cleared — partial OCDBT / + bigtable writes may exist and subsequent ops must refuse until + operator recovery runs. + """ + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + chunks = [np.uint64(1), np.uint64(2)] + op_id = np.uint64(42) + + lock = IndefiniteL2ChunkLock(cg, chunks, op_id) + lock.__enter__() + # Enter wrote scope + held cells. + assert len(registry._held_indefinite) == 2 + scope_writes = self._scope_mutate_calls(cg) + assert any(len(v) > 0 for _, v in scope_writes) + + # Simulate an exception propagating through the `with` block. + lock.__exit__(ValueError, ValueError("boom"), None) + + # Cells still held, scope not cleared (no empty-array mutate). + assert len(registry._held_indefinite) == 2 + empty_writes = [(k, v) for k, v in self._scope_mutate_calls(cg) if len(v) == 0] + assert empty_writes == [] diff --git a/pychunkedgraph/tests/graph/test_meta.py b/pychunkedgraph/tests/graph/test_meta.py index fc24f9917..d8f8134d9 100644 --- a/pychunkedgraph/tests/graph/test_meta.py +++ b/pychunkedgraph/tests/graph/test_meta.py @@ -582,8 +582,9 @@ def test_ws_ocdbt_asserts_when_not_ocdbt(self): with pytest.raises(AssertionError, match="ocdbt"): _ = meta.ws_ocdbt + @patch("pychunkedgraph.graph.meta.fork_exists", return_value=True) @patch("pychunkedgraph.graph.meta.get_seg_source_and_destination_ocdbt") - def test_ws_ocdbt_returns_base_scale(self, mock_get_ocdbt): + def test_ws_ocdbt_returns_base_scale(self, mock_get_ocdbt, _mock_fork_exists): gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) @@ -602,8 +603,9 @@ def test_ws_ocdbt_returns_base_scale(self, mock_get_ocdbt): assert meta.ws_ocdbt_resolutions == [[4, 4, 40], [8, 8, 40]] mock_get_ocdbt.assert_called_once_with("gs://bucket/ws", "test_graph") + @patch("pychunkedgraph.graph.meta.fork_exists", return_value=True) @patch("pychunkedgraph.graph.meta.get_seg_source_and_destination_ocdbt") - def test_ws_ocdbt_cached(self, mock_get_ocdbt): + def test_ws_ocdbt_cached(self, mock_get_ocdbt, _mock_fork_exists): gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) diff --git a/pychunkedgraph/tests/graph/test_multicut.py b/pychunkedgraph/tests/graph/test_multicut.py index 590476ffd..4edd5962f 100644 --- a/pychunkedgraph/tests/graph/test_multicut.py +++ b/pychunkedgraph/tests/graph/test_multicut.py @@ -2,8 +2,7 @@ import pytest from ...graph.edges import Edges -from ...graph import exceptions -from ...graph.cutting import run_multicut +from ...graph.cutting import Cut, SvSplitRequired, run_multicut class TestGraphMultiCut: @@ -25,13 +24,15 @@ def test_cut_multi_tree(self, gen_graph): source_ids = np.array([1, 2], dtype=np.uint64) sink_ids = np.array([5, 6], dtype=np.uint64) - cut_edges = run_multicut( + result = run_multicut( edges, source_ids, sink_ids, path_augment=False, disallow_isolating_cut=False, ) + assert isinstance(result, Cut) + cut_edges = result.atomic_edges assert cut_edges.shape[0] > 0 # Verify the cut actually separates sources from sinks @@ -64,14 +65,19 @@ def test_path_augmented_multicut(self, sv_data): edges = Edges( sv_edges[:, 0], sv_edges[:, 1], affinities=sv_affinity, areas=sv_area ) - cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) - assert cut_edges_aug.shape[0] == 350 + result = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) + assert isinstance(result, Cut) + assert result.atomic_edges.shape[0] == 350 - with pytest.raises(exceptions.SupervoxelSplitRequiredError): - run_multicut( - edges, - sv_sources, - sv_sinks, - path_augment=False, - sv_split_supported=True, - ) + # Without path augmentation on this fixture, source/sink share a + # cross-chunk representative — returned as SvSplitRequired when + # sv_split_supported=True (no exception escapes run_multicut). + sv_result = run_multicut( + edges, + sv_sources, + sv_sinks, + path_augment=False, + sv_split_supported=True, + ) + assert isinstance(sv_result, SvSplitRequired) + assert sv_result.sv_remapping # non-empty mapping diff --git a/pychunkedgraph/tests/graph/test_ocdbt.py b/pychunkedgraph/tests/graph/test_ocdbt.py index efa6ed048..501229ada 100644 --- a/pychunkedgraph/tests/graph/test_ocdbt.py +++ b/pychunkedgraph/tests/graph/test_ocdbt.py @@ -4,6 +4,8 @@ import os import shutil import tempfile +import time +from datetime import datetime, timezone import numpy as np import pytest @@ -13,6 +15,13 @@ from pychunkedgraph.graph import ocdbt as ocdbt_mod from pychunkedgraph.graph.meta import ChunkedGraphMeta, GraphConfig, DataSource +SCALE_META_BASE = { + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + "chunk_size": [32, 32, 32], +} +MULTISCALE_META = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + def _make_mock_src(num_scales=2): """Build a mock TensorStore source handle with a copyable schema.""" @@ -62,7 +71,9 @@ def _setup_ts_mock(mock_ts, num_scales=2): class TestBuildCgOcdbtSpec: def test_spec_structure(self): """build_cg_ocdbt_spec returns the expected kvstack-layered spec.""" - spec = ocdbt_mod.build_cg_ocdbt_spec("gs://bucket/ws", "my_graph") + spec = ocdbt_mod.build_cg_ocdbt_spec( + "gs://bucket/ws", "my_graph", ocdbt_mod.OcdbtConfig() + ) assert spec["driver"] == "ocdbt" layers = spec["base"]["layers"] assert len(layers) == 3 @@ -80,41 +91,33 @@ def test_spec_structure(self): class TestForkBaseManifest: - def test_copies_manifest(self): + """Byte-level behavior of `fork_base_manifest` — manifest copy + wipe.""" + + def test_copies_manifest(self, local_ocdbt): """fork_base_manifest copies the base manifest via tensorstore kvstore.""" - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - # Create a real base OCDBT with a manifest. - base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() - base_kvs.write("manifest.ocdbt", b"fake_manifest_bytes").result() + ws = local_ocdbt["ws"] + base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() + base_kvs.write("manifest.ocdbt", b"fake_manifest_bytes").result() - ocdbt_mod.fork_base_manifest(ws, "my_graph") + ocdbt_mod.fork_base_manifest(ws, "my_graph") - fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() - result = fork_kvs.read("manifest.ocdbt").result() - assert result.value == b"fake_manifest_bytes" - finally: - shutil.rmtree(tmpdir) + fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() + assert fork_kvs.read("manifest.ocdbt").result().value == b"fake_manifest_bytes" - def test_wipe_existing_cleans_fork_dir(self): + def test_wipe_existing_cleans_fork_dir(self, local_ocdbt): """wipe_existing=True removes the fork directory before copying.""" - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() - base_kvs.write("manifest.ocdbt", b"manifest_v1").result() + ws = local_ocdbt["ws"] + base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() + base_kvs.write("manifest.ocdbt", b"manifest_v1").result() - fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() - fork_kvs.write("stale_file", b"stale").result() + fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() + fork_kvs.write("stale_file", b"stale").result() - ocdbt_mod.fork_base_manifest(ws, "my_graph", wipe_existing=True) + ocdbt_mod.fork_base_manifest(ws, "my_graph", wipe_existing=True) - fork_kvs2 = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() - assert fork_kvs2.read("manifest.ocdbt").result().value == b"manifest_v1" - assert len(fork_kvs2.read("stale_file").result().value) == 0 - finally: - shutil.rmtree(tmpdir) + fork_kvs2 = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() + assert fork_kvs2.read("manifest.ocdbt").result().value == b"manifest_v1" + assert len(fork_kvs2.read("stale_file").result().value) == 0 class TestModeDownsample: @@ -216,45 +219,80 @@ def test_boundary_clipping(self): @pytest.fixture def local_ocdbt(): - """Create a local precomputed multi-scale OCDBT store. - - Builds 3 scales (factors 2,2,1 between each) with known segmentation IDs - so downsampling behaviour and propagation can be asserted against exact - values. Returns paths + handles for tests to work against directly. + """Shared OCDBT test environment. + + Creates a local 3-scale precomputed base OCDBT (factors 2,2,1 per + level) and exposes helpers for fork-based tests. Every OCDBT test + that needs real storage uses this fixture — no duplicated tmpdir + scaffolding. + + Yields: + tmpdir: on-disk workspace (cleaned up on teardown). + ws: `file://{tmpdir}` URL — what `build_cg_ocdbt_spec` expects. + base: base OCDBT kvstore URL. + scales: 3 precomputed handles on the base (multi-scale tests). + resolutions: per-scale [x,y,z] resolution arrays. + make_fork(graph_id, *, scale_index=0, pinned_at=None): opens a + precomputed handle through a fork of the base. Creates the + fork on first call per `graph_id` and reuses it thereafter; + repeated calls with the same id never re-copy the manifest + (which would clobber fork writes). """ tmpdir = tempfile.mkdtemp() - base = f"file://{tmpdir}/ocdbt/base" - - mm = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + ws = f"file://{tmpdir}" + base = f"{ws}/ocdbt/base" - def mk(scale_idx, size, resolution, extra_mm=None): + def _mk_scale(size, resolution, *, include_mm): + # Match OcdbtConfig defaults so forks (which always use them) don't + # trip the "Configuration mismatch on max_inline_value_bytes" check. spec = { "driver": "neuroglancer_precomputed", - "kvstore": {"driver": "ocdbt", "base": base}, + "kvstore": { + "driver": "ocdbt", + "base": base, + "config": ocdbt_mod.OcdbtConfig().ts_config(), + }, "scale_metadata": { "size": size, "resolution": resolution, - "encoding": "compressed_segmentation", - "compressed_segmentation_block_size": [8, 8, 8], - "chunk_size": [32, 32, 32], + **SCALE_META_BASE, }, } - if extra_mm: - spec["multiscale_metadata"] = extra_mm + if include_mm: + spec["multiscale_metadata"] = MULTISCALE_META return ts.open(spec, create=True).result() scales = [ - mk(0, [64, 64, 32], [4, 4, 40], extra_mm=mm), - mk(1, [32, 32, 32], [8, 8, 40]), - mk(2, [16, 16, 32], [16, 16, 40]), + _mk_scale([64, 64, 32], [4, 4, 40], include_mm=True), + _mk_scale([32, 32, 32], [8, 8, 40], include_mm=False), + _mk_scale([16, 16, 32], [16, 16, 40], include_mm=False), ] resolutions = [[4, 4, 40], [8, 8, 40], [16, 16, 40]] + _created_forks = set() + + def make_fork(graph_id, *, scale_index=0, pinned_at=None): + if graph_id not in _created_forks: + ocdbt_mod.fork_base_manifest(ws, graph_id) + _created_forks.add(graph_id) + spec = ocdbt_mod.build_cg_ocdbt_spec( + ws, graph_id, ocdbt_mod.OcdbtConfig(), pinned_at=pinned_at + ) + return ts.open( + { + "driver": "neuroglancer_precomputed", + "kvstore": spec, + "scale_index": scale_index, + } + ).result() + yield { "tmpdir": tmpdir, + "ws": ws, "base": base, "scales": scales, "resolutions": resolutions, + "make_fork": make_fork, } shutil.rmtree(tmpdir) @@ -398,220 +436,235 @@ def test_repeated_update_reflects_latest_base(self, local_ocdbt): assert (scales[2][0:4, 0:4, 0:16, :].read().result() == 2).all() -class TestWriteSeg: - def test_writes_base_and_propagates(self, local_ocdbt): - """`write_seg` writes to base scale AND propagates to all coarser scales.""" +class TestWriteSegChunks: + """`write_seg_chunks` now takes a flat list of (slices, data) pairs. + + `edits_sv.split_supervoxels` is responsible for producing this list + across all reps so the outer rep loop is a pure data gather — + tensorstore writes fire in one parallel batch. + """ + + def test_writes_only_supplied_chunks(self, local_ocdbt): + """Chunks absent from `seg_writes` stay untouched (OCDBT delta + stays proportional to the actual SV change).""" scales = local_ocdbt["scales"] - res = local_ocdbt["resolutions"] meta = MagicMock() meta.ws_ocdbt = scales[0] - meta.ws_ocdbt_scales = scales - meta.ws_ocdbt_resolutions = res - data = np.full((16, 16, 16), 55, dtype=np.uint64) - ocdbt_mod.write_seg(meta, [0, 0, 0], [16, 16, 16], data) + # One chunk at [0..32] with label 55. The adjacent chunk at + # [32..64] is NOT in the write list, so it should stay zero. + chunk_data = np.full((32, 32, 32), 55, dtype=np.uint64) + seg_writes = [ + ( + (slice(0, 32), slice(0, 32), slice(0, 32)), + chunk_data, + ) + ] + ocdbt_mod.write_seg_chunks(meta, seg_writes) - # Base scale: written region has label 55. - assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 55).all() - # Coarser scales: propagated. - assert (scales[1][0:8, 0:8, 0:16, :].read().result() == 55).all() - assert (scales[2][0:4, 0:4, 0:16, :].read().result() == 55).all() + assert (scales[0][0:32, 0:32, 0:32, :].read().result() == 55).all() + assert (scales[0][32:64, 0:32, 0:32, :].read().result() == 0).all() + # Coarser scales untouched — downsample worker's job. + assert (scales[1][0:16, 0:16, 0:32, :].read().result() == 0).all() + assert (scales[2][0:8, 0:8, 0:32, :].read().result() == 0).all() - def test_single_scale_skips_propagation(self, local_ocdbt): - """With only one scale in the list, propagation is a no-op (no IndexError).""" + def test_multiple_chunks_in_one_batch(self, local_ocdbt): + """Multiple chunks (e.g. from different reps) fire in one call.""" + scales = local_ocdbt["scales"] meta = MagicMock() - meta.ws_ocdbt = local_ocdbt["scales"][0] - meta.ws_ocdbt_scales = [local_ocdbt["scales"][0]] - meta.ws_ocdbt_resolutions = [local_ocdbt["resolutions"][0]] + meta.ws_ocdbt = scales[0] - data = np.full((8, 8, 8), 99, dtype=np.uint64) - ocdbt_mod.write_seg(meta, [0, 0, 0], [8, 8, 8], data) - assert (meta.ws_ocdbt[0:8, 0:8, 0:8, :].read().result() == 99).all() + seg_writes = [ + ( + (slice(0, 32), slice(0, 32), slice(0, 32)), + np.full((32, 32, 32), 11, dtype=np.uint64), + ), + ( + (slice(32, 64), slice(0, 32), slice(0, 32)), + np.full((32, 32, 32), 22, dtype=np.uint64), + ), + ] + ocdbt_mod.write_seg_chunks(meta, seg_writes) + assert (scales[0][0:32, 0:32, 0:32, :].read().result() == 11).all() + assert (scales[0][32:64, 0:32, 0:32, :].read().result() == 22).all() -class TestMetaToForkEndToEnd: - """Full path: ChunkedGraphMeta.ws_ocdbt_scales → real kvstack fork → read/write.""" + def test_offset_region(self, local_ocdbt): + """Writes at a non-origin offset land in the right chunk.""" + scales = local_ocdbt["scales"] + meta = MagicMock() + meta.ws_ocdbt = scales[0] - def test_meta_opens_fork_and_merges_base(self): - """meta.ws_ocdbt_scales opens a real kvstack-backed OCDBT and reads - merge base + fork correctly. + seg_writes = [ + ( + (slice(32, 64), slice(0, 32), slice(0, 32)), + np.full((32, 32, 32), 99, dtype=np.uint64), + ) + ] + ocdbt_mod.write_seg_chunks(meta, seg_writes) - Only `_read_source_scales` is mocked (it reads `/info` which is a - GCS-only key). The full meta → build_cg_ocdbt_spec → kvstack → - OCDBT → read/write path is exercised for real. - """ - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - MM = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} - SCALE = { - "size": [64, 64, 32], - "resolution": [4, 4, 40], - "encoding": "compressed_segmentation", - "compressed_segmentation_block_size": [8, 8, 8], - "chunk_size": [32, 32, 32], - } - FAKE_SCALES = [ - { - "resolution": [4, 4, 40], - "size": [64, 64, 32], - "chunk_sizes": [[32, 32, 32]], - "encoding": "compressed_segmentation", - "compressed_segmentation_block_size": [8, 8, 8], - } - ] - - # Source precomputed — needed by get_seg_source_and_destination_ocdbt - # to open the source handle and copy its schema. - ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": f"{ws}/", - "multiscale_metadata": MM, - "scale_metadata": SCALE, - }, - create=True, - ).result() + assert (scales[0][32:64, 0:32, 0:32, :].read().result() == 99).all() + assert (scales[0][0:32, 0:32, 0:32, :].read().result() == 0).all() - # Create base OCDBT with known data. - base_kvstore = { - "driver": "ocdbt", - "base": f"{ws}/ocdbt/base/", - "config": dict(ocdbt_mod.OCDBT_CONFIG), - } - base_store = ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": base_kvstore, - "multiscale_metadata": MM, - "scale_metadata": SCALE, + +class TestWsOcdbtScalesProperty: + """`ChunkedGraphMeta.ws_ocdbt_scales` opens a fork over the shared base. + + Full path exercised: property → build_cg_ocdbt_spec → kvstack → OCDBT + read/write. Only `_read_source_scales` is mocked (it reads `/info` + which lives on the source watershed, not the OCDBT fork). + """ + + def test_opens_fork_and_merges_base(self, local_ocdbt): + ws = local_ocdbt["ws"] + + # Source precomputed at ws root — needed by + # get_seg_source_and_destination_ocdbt to copy the schema. + ts.open( + { + "driver": "neuroglancer_precomputed", + "kvstore": f"{ws}/", + "multiscale_metadata": MULTISCALE_META, + "scale_metadata": { + "size": [64, 64, 32], + "resolution": [4, 4, 40], + **SCALE_META_BASE, }, - create=True, - ).result() - base_store[...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) - - # Fork for graph "test_cg". - ocdbt_mod.fork_base_manifest(f"{ws}/", "test_cg") - - gc = GraphConfig(ID="test_cg", CHUNK_SIZE=[32, 32, 32]) - ds = DataSource(WATERSHED=f"{ws}/", DATA_VERSION=4) - meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) - - # Mock only _read_source_scales ('/info' is GCS-only). - with patch.object( - ocdbt_mod, "_read_source_scales", return_value=FAKE_SCALES - ): - scales = meta.ws_ocdbt_scales - assert len(scales) == 1 - - # Read: should see base data. - r = scales[0][0:16, 0:16, 0:16, :].read().result() - assert (r == 50).all(), f"fork should see base, got {np.unique(r)}" - - # Write via the fork handle. - scales[0][0:16, 0:16, 0:16, :] = np.full( - (16, 16, 16, 1), 7, dtype=np.uint64 - ) + }, + create=True, + ).result() - # Read back: edited = 7, untouched = 50. - assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 7).all() - assert (scales[0][32:48, 0:16, 0:16, :].read().result() == 50).all() - - # Base unchanged. - base_ro = ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": base_kvstore, - } - ).result() - assert (base_ro[0:16, 0:16, 0:16, :].read().result() == 50).all() - finally: - shutil.rmtree(tmpdir) + # Seed base scale 0 with a known value via the fixture's handle. + local_ocdbt["scales"][0][...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) + gc = GraphConfig(ID="ws_scales_cg", CHUNK_SIZE=[32, 32, 32]) + ds = DataSource(WATERSHED=f"{ws}/", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) -class TestForkIsolation: - """End-to-end: two forks on the same base, writes isolated, base immutable.""" + # Trigger fork creation through the same helper the property will use. + local_ocdbt["make_fork"]("ws_scales_cg") - def test_two_forks_isolated(self): - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - # Build a base OCDBT with known data. - MM = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} - SCALE = { - "size": [64, 64, 32], + fake_scales = [ + { "resolution": [4, 4, 40], + "size": [64, 64, 32], + "chunk_sizes": [[32, 32, 32]], "encoding": "compressed_segmentation", "compressed_segmentation_block_size": [8, 8, 8], - "chunk_size": [32, 32, 32], } - base_kvstore = { - "driver": "ocdbt", - "base": f"{ws}/ocdbt/base/", - "config": dict(ocdbt_mod.OCDBT_CONFIG), - } - base_store = ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": base_kvstore, - "multiscale_metadata": MM, - "scale_metadata": SCALE, - }, - create=True, - ).result() - base_store[...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) - - base_path = f"{tmpdir}/ocdbt/base" - base_files_before = set( - os.path.relpath(os.path.join(r, f), base_path) - for r, _, fs in os.walk(base_path) - for f in fs + ] + with patch.object( + ocdbt_mod.main, "_read_source_scales", return_value=fake_scales + ): + scales = meta.ws_ocdbt_scales + assert len(scales) == 1 + + # Fork sees base data. + assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 50).all() + + # Write to the fork and confirm isolation. + scales[0][0:16, 0:16, 0:16, :] = np.full( + (16, 16, 16, 1), 7, dtype=np.uint64 ) + assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 7).all() + assert (scales[0][32:48, 0:16, 0:16, :].read().result() == 50).all() - # Fork A and B via fork_base_manifest. - ocdbt_mod.fork_base_manifest(ws, "fork_a") - ocdbt_mod.fork_base_manifest(ws, "fork_b") - - def open_fork(gid): - spec = ocdbt_mod.build_cg_ocdbt_spec(ws, gid) - return ts.open( - {"driver": "neuroglancer_precomputed", "kvstore": spec}, - ).result() - - fork_a = open_fork("fork_a") - fork_b = open_fork("fork_b") - - # Both see base data. - assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 50).all() - assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() - - # Write different values to each fork. - fork_a[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 1, dtype=np.uint64) - fork_b[32:48, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 2, dtype=np.uint64) - - # Each fork sees ONLY its own edit + base for the rest. - assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 1).all() - assert (fork_a[32:48, 0:16, 0:16, :].read().result() == 50).all() - assert (fork_b[32:48, 0:16, 0:16, :].read().result() == 2).all() - assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() - - # Base is unchanged. - base_files_after = set( - os.path.relpath(os.path.join(r, f), base_path) - for r, _, fs in os.walk(base_path) - for f in fs - ) - assert ( - base_files_before == base_files_after - ), f"base was mutated: new={base_files_after - base_files_before}" - - # Fork writes went to their own directories. - fork_a_files = os.listdir(f"{tmpdir}/ocdbt/fork_a") - fork_b_files = os.listdir(f"{tmpdir}/ocdbt/fork_b") - assert any("fork_a_d" in f for f in fork_a_files) - assert any("fork_b_d" in f for f in fork_b_files) - finally: - shutil.rmtree(tmpdir) + # Base still reports the original value (fork write didn't leak). + assert ( + local_ocdbt["scales"][0][0:16, 0:16, 0:16, :].read().result() == 50 + ).all() + + +class TestForkIsolation: + """Two forks on the same base: writes isolated, base immutable.""" + + def test_two_forks_isolated(self, local_ocdbt): + tmpdir = local_ocdbt["tmpdir"] + # Seed base scale 0 with a known value. + local_ocdbt["scales"][0][...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) + + base_path = f"{tmpdir}/ocdbt/base" + base_files_before = { + os.path.relpath(os.path.join(r, f), base_path) + for r, _, fs in os.walk(base_path) + for f in fs + } + + fork_a = local_ocdbt["make_fork"]("fork_a") + fork_b = local_ocdbt["make_fork"]("fork_b") + + # Both see base data. + assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 50).all() + assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() + + # Write different values to each fork. + fork_a[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 1, dtype=np.uint64) + fork_b[32:48, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 2, dtype=np.uint64) + + # Each fork sees ONLY its own edit + base for the rest. + assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 1).all() + assert (fork_a[32:48, 0:16, 0:16, :].read().result() == 50).all() + assert (fork_b[32:48, 0:16, 0:16, :].read().result() == 2).all() + assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() + + # Base files unchanged (no new bytes written under ocdbt/base/). + base_files_after = { + os.path.relpath(os.path.join(r, f), base_path) + for r, _, fs in os.walk(base_path) + for f in fs + } + assert ( + base_files_before == base_files_after + ), f"base was mutated: new={base_files_after - base_files_before}" + + # Fork writes went to their own directories. + assert any("fork_a_d" in f for f in os.listdir(f"{tmpdir}/ocdbt/fork_a")) + assert any("fork_b_d" in f for f in os.listdir(f"{tmpdir}/ocdbt/fork_b")) + + +class TestPinnedAt: + """Versioned reads: pinning a fork to a prior generation/timestamp + returns pre-write state; default (unpinned) returns latest. + + Documents both pin forms OCDBT accepts — integer generation (exact) + and ISO-8601 UTC timestamp with `Z` suffix (commit_time upper bound). + """ + + def test_pin_by_generation_and_by_timestamp(self, local_ocdbt): + # Seed base so fork reads see data even before the first fork write. + local_ocdbt["scales"][0][...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) + + fork = local_ocdbt["make_fork"]("pin_cg") + + # Write v1 then v2 at the same voxels. Capture pin markers between + # the two writes so pre-v2 state is what each pin should return. + fork[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 1, dtype=np.uint64) + + fork_manifest_kvs = ts.KvStore.open( + f"{local_ocdbt['ws']}/ocdbt/pin_cg/" + ).result() + pin_gen = ts.ocdbt.dump(fork_manifest_kvs).result()["versions"][-1][ + "generation_number" + ] + + time.sleep(0.01) + pin_ts = datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") + time.sleep(0.01) + + fork[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 2, dtype=np.uint64) + + fork_latest = local_ocdbt["make_fork"]("pin_cg") + assert (fork_latest[0:16, 0:16, 0:16, :].read().result() == 2).all() + + fork_gen = local_ocdbt["make_fork"]("pin_cg", pinned_at=pin_gen) + assert (fork_gen[0:16, 0:16, 0:16, :].read().result() == 1).all() + + fork_ts = local_ocdbt["make_fork"]("pin_cg", pinned_at=pin_ts) + assert (fork_ts[0:16, 0:16, 0:16, :].read().result() == 1).all() + + # Untouched region still shows base data under every pin. + for handle in (fork_latest, fork_gen, fork_ts): + assert (handle[32:48, 0:16, 0:16, :].read().result() == 50).all() class TestCopyWsChunkMultiscale: diff --git a/pychunkedgraph/tests/graph/test_stuck_ops.py b/pychunkedgraph/tests/graph/test_stuck_ops.py new file mode 100644 index 000000000..823c7595e --- /dev/null +++ b/pychunkedgraph/tests/graph/test_stuck_ops.py @@ -0,0 +1,380 @@ +"""Tests for pychunkedgraph.repair.stuck_ops — cleanup + replay path +for SV-split ops that crashed mid-write. + +The heavy test (`test_cleanup_reverts_partial_writes_to_pre_op`) +exercises the full cleanup flow against a real local OCDBT store — it +writes a known pre-op state, snapshots the manifest, writes simulated +"partial" data, constructs an op-log row with `L2ChunkLockScope` and +`OperationTimeStamp`, and asserts that cleanup reverts the scoped +chunks to pre-op values while leaving neighbor chunks alone. +""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import tensorstore as ts + +from pychunkedgraph.graph import attributes, ocdbt as ocdbt_mod +from pychunkedgraph.graph.chunks.utils import get_chunk_coordinates +from pychunkedgraph.graph.locks import _l2_chunk_lock_row_key +from pychunkedgraph.graph.meta import ChunkedGraphMeta, DataSource, GraphConfig +from pychunkedgraph.repair import stuck_ops + +# Pick up the shared `local_ocdbt` fixture from test_ocdbt. +from .test_ocdbt import local_ocdbt # noqa: F401 + + +class TestListStuck: + """`list_stuck` surfaces ops with non-empty `L2ChunkLockScope` past + `min_age` whose Status isn't SUCCESS — i.e. still holding + `Concurrency.IndefiniteLock` cells somewhere.""" + + def _entry(self, status, age_seconds, user="u", scope=None): + now = datetime.now(timezone.utc) + entry = { + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationTimeStamp: now + - timedelta(seconds=age_seconds), + attributes.OperationLogs.UserID: user, + } + if scope is not None: + entry[attributes.OperationLogs.L2ChunkLockScope] = np.asarray( + scope, dtype=np.uint64 + ) + return entry + + def _cg(self, entries): + cg = MagicMock() + cg.client.read_log_entries.return_value = entries + return cg + + def test_filters_out_success_with_scope(self): + """Defensive: a SUCCESS op with stale scope (if + `_clear_scope_on_op_log` ever failed silently) must not be + listed as stuck.""" + success = attributes.OperationLogs.StatusCodes.SUCCESS.value + created = attributes.OperationLogs.StatusCodes.CREATED.value + cg = self._cg( + { + np.uint64(1): self._entry(success, 900, scope=[10, 20]), + np.uint64(2): self._entry(created, 900, scope=[10, 20]), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=1)) + assert [r["op_id"] for r in stuck] == [2] + + def test_filters_out_empty_scope(self): + """Ops that never touched the persist block (no scope) are not + stuck via L2 locks — they're outside `stuck_ops`' concern.""" + created = attributes.OperationLogs.StatusCodes.CREATED.value + exception = attributes.OperationLogs.StatusCodes.EXCEPTION.value + cg = self._cg( + { + np.uint64(1): self._entry(created, 900), # no scope + np.uint64(2): self._entry(exception, 900), # no scope + np.uint64(3): self._entry(created, 900, scope=[42]), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=1)) + assert [r["op_id"] for r in stuck] == [3] + + def test_surfaces_exception_path_with_scope(self): + """After Fix 1, a Python exception during the persist block + leaves cells held + scope set but Status=EXCEPTION. The op must + be listed so the operator can recover it.""" + exception = attributes.OperationLogs.StatusCodes.EXCEPTION.value + cg = self._cg( + { + np.uint64(42): self._entry( + exception, 900, user="alice", scope=[100, 200] + ), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=1)) + assert len(stuck) == 1 + row = stuck[0] + assert row["op_id"] == 42 + assert row["status"] == exception + assert list(row["l2_chunk_scope"]) == [100, 200] + + def test_filters_out_young_ops(self): + created = attributes.OperationLogs.StatusCodes.CREATED.value + cg = self._cg( + { + np.uint64(1): self._entry(created, 10, scope=[1]), # too young + np.uint64(2): self._entry(created, 3600, scope=[2]), # an hour old + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=10)) + assert [r["op_id"] for r in stuck] == [2] + + def test_returns_scope_and_user(self): + created = attributes.OperationLogs.StatusCodes.CREATED.value + cg = self._cg( + { + np.uint64(7): self._entry(created, 1800, user="op", scope=[100, 200]), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=10)) + assert len(stuck) == 1 + row = stuck[0] + assert row["op_id"] == 7 + assert row["user_id"] == "op" + assert list(row["l2_chunk_scope"]) == [100, 200] + assert row["age"] > timedelta(minutes=10) + + +class TestVerifyIndefiniteCells: + """`_verify_indefinite_cells` reads each chunk's indefinite-lock cell + and reports any that don't match the expected op_id.""" + + class _Cell: + def __init__(self, value): + self.value = value + + def _cg(self, cells_by_row_key): + cg = MagicMock() + + def read(row_key, columns=None): + return cells_by_row_key.get(row_key, []) + + cg.client._read_byte_row.side_effect = read + return cg + + def test_all_held_by_same_op(self): + op_id = 42 + scope = [np.uint64(1), np.uint64(2)] + cells = { + stuck_ops._l2_chunk_lock_row_key(1): [self._Cell(np.uint64(op_id))], + stuck_ops._l2_chunk_lock_row_key(2): [self._Cell(np.uint64(op_id))], + } + cg = self._cg(cells) + assert stuck_ops._verify_indefinite_cells(cg, op_id, scope) == [] + + def test_cell_missing_flagged(self): + op_id = 42 + scope = [np.uint64(1), np.uint64(2)] + cells = { + stuck_ops._l2_chunk_lock_row_key(1): [self._Cell(np.uint64(op_id))], + # chunk 2 has no cell + } + cg = self._cg(cells) + discrepancies = stuck_ops._verify_indefinite_cells(cg, op_id, scope) + assert discrepancies == [2] + + def test_cell_held_by_different_op_flagged(self): + op_id = 42 + other_op = np.uint64(99) + scope = [np.uint64(1), np.uint64(2)] + cells = { + stuck_ops._l2_chunk_lock_row_key(1): [self._Cell(other_op)], + stuck_ops._l2_chunk_lock_row_key(2): [self._Cell(np.uint64(op_id))], + } + cg = self._cg(cells) + discrepancies = stuck_ops._verify_indefinite_cells(cg, op_id, scope) + assert discrepancies == [1] + + +class TestReplayVerifies: + """`replay` refuses to call cleanup_partial_writes or repair_operation + when the recorded scope disagrees with live indefinite-lock state.""" + + def test_replay_refuses_when_cells_missing(self, monkeypatch): + op_id = 77 + scope = np.asarray([1, 2], dtype=np.uint64) + + cg = MagicMock() + cg.client.read_log_entries.return_value = { + np.uint64(op_id): { + attributes.OperationLogs.L2ChunkLockScope: scope, + attributes.OperationLogs.OperationTimeStamp: datetime.now(timezone.utc), + } + } + # No cells held on either chunk. + cg.client._read_byte_row.return_value = [] + + # Spy on the destructive steps — neither should be called. + cleanup_called = {"v": False} + repair_called = {"v": False} + monkeypatch.setattr( + stuck_ops, + "cleanup_partial_writes", + lambda *a, **k: cleanup_called.__setitem__("v", True), + ) + monkeypatch.setattr( + stuck_ops, + "repair_operation", + lambda *a, **k: repair_called.__setitem__("v", True), + ) + + with pytest.raises(RuntimeError, match="Refusing to replay"): + stuck_ops.replay(cg, op_id) + assert not cleanup_called["v"] + assert not repair_called["v"] + + def test_replay_refuses_when_empty_scope(self, monkeypatch): + op_id = 77 + cg = MagicMock() + cg.client.read_log_entries.return_value = { + np.uint64(op_id): { + attributes.OperationLogs.OperationTimeStamp: datetime.now(timezone.utc), + } + } + cleanup_called = {"v": False} + monkeypatch.setattr( + stuck_ops, + "cleanup_partial_writes", + lambda *a, **k: cleanup_called.__setitem__("v", True), + ) + + with pytest.raises(RuntimeError, match="not a stuck SV-split op"): + stuck_ops.replay(cg, op_id) + assert not cleanup_called["v"] + + +class TestCleanupPartialWrites: + """Cleanup reverts partial OCDBT writes using pinned reads of pre-op state.""" + + def _meta_with_fork(self, local_ocdbt_fixture, graph_id): + """Build a real ChunkedGraphMeta pointing at the fixture's fork so + `ws_ocdbt` reads/writes go through the same kvstack as production. + + Creates a matching source precomputed at the watershed root so + `get_seg_source_and_destination_ocdbt` and `ws_cv` both work. + Sets `layer_count` explicitly to bypass `ws_cv.bounds` inference. + """ + ws = local_ocdbt_fixture["ws"] + mm = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + scale_metadata = { + "size": [64, 64, 32], + "resolution": [4, 4, 40], + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + "chunk_size": [32, 32, 32], + } + ts.open( + { + "driver": "neuroglancer_precomputed", + "kvstore": f"{ws}/", + "multiscale_metadata": mm, + "scale_metadata": scale_metadata, + }, + create=True, + ).result() + + local_ocdbt_fixture["make_fork"](graph_id) + + gc = GraphConfig( + ID=graph_id, + CHUNK_SIZE=np.array([32, 32, 32], dtype=int), + ) + ds = DataSource(WATERSHED=f"{ws}/", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) + meta.layer_count = 3 # avoids lazy cloudvolume layer inference + return meta + + def _capture_fork_pin(self, local_ocdbt_fixture, graph_id): + """Return an ISO-8601 `Z`-suffix pin string for the fork's current + manifest commit — the pre-op timestamp for cleanup to pin on. + """ + fork_manifest_kvs = ts.KvStore.open( + f"{local_ocdbt_fixture['ws']}/ocdbt/{graph_id}/" + ).result() + manifest = ts.ocdbt.dump(fork_manifest_kvs).result() + # commit_time is recorded as int ns since epoch; use a timestamp + # just past the last commit as the pin so the upper-bound filter + # picks up everything written so far. + last_ns = manifest["versions"][-1]["commit_time"] + return datetime.fromtimestamp(last_ns / 1e9 + 0.001, tz=timezone.utc) + + def test_cleanup_reverts_partial_writes_to_pre_op(self, local_ocdbt): + """Write known pre-op state, snapshot time, write partial state to + one chunk, simulate a stuck op with that chunk in scope, and + confirm cleanup reverts the chunk while leaving a non-scoped + neighbor chunk untouched. + """ + fixture = local_ocdbt + + meta = self._meta_with_fork(fixture, "stuck_cg") + fork_scale0 = fixture["make_fork"]("stuck_cg") + + # Pre-op state: chunk 0 region filled with 111, chunk 1 with 222. + # Chunk grid is at base resolution with 32^3 voxels per chunk. + fork_scale0[0:32, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 111, dtype=np.uint64 + ) + fork_scale0[32:64, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 222, dtype=np.uint64 + ) + + # Snapshot pin timestamp just after the pre-op writes. + pre_op_pin_dt = self._capture_fork_pin(fixture, "stuck_cg") + + # Partial "crash" writes: overwrite chunk 0 with garbage, touch + # chunk 1 too to prove scope-boundedness (scope will only list + # chunk 0, so chunk 1's garbage must persist after cleanup). + fork_scale0[0:32, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 999, dtype=np.uint64 + ) + fork_scale0[32:64, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 888, dtype=np.uint64 + ) + + # Chunk IDs for chunk-coord (0,0,0) and (1,0,0) at layer 2. + chunk_id_0 = _chunk_id_from_coord(meta, layer=2, coord=(0, 0, 0)) + chunk_id_1 = _chunk_id_from_coord(meta, layer=2, coord=(1, 0, 0)) + + # Sanity: scope chunk decodes back to the right coord. + assert tuple(get_chunk_coordinates(meta, chunk_id_0)) == (0, 0, 0) + + # Synthetic op-log row with scope=[chunk_id_0] and OperationTimeStamp=pre_op_pin. + op_id = 777 + op_log_row = { + attributes.OperationLogs.L2ChunkLockScope: np.asarray( + [chunk_id_0], dtype=np.uint64 + ), + attributes.OperationLogs.OperationTimeStamp: pre_op_pin_dt, + } + + cg = MagicMock() + cg.meta = meta + cg.client.read_log_entries.return_value = {np.uint64(op_id): op_log_row} + + # `_read_source_scales` reads `/info` from the watershed via + # tensorstore's kvstore interface — fine on GCS, not on file://. + # Bypass with a fake scale list matching the test's scale 0. + fake_scales = [ + { + "resolution": [4, 4, 40], + "size": [64, 64, 32], + "chunk_sizes": [[32, 32, 32]], + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + } + ] + with patch.object(ocdbt_mod, "_read_source_scales", return_value=fake_scales): + reverted = stuck_ops.cleanup_partial_writes(cg, op_id) + assert reverted == 1 + + # Scoped chunk reverted to pre-op. + scoped = fork_scale0[0:32, 0:32, 0:32, :].read().result() + assert ( + scoped == 111 + ).all(), f"scoped chunk not reverted: unique={np.unique(scoped)}" + # Non-scoped neighbor still has its post-crash "garbage" (888) — + # cleanup does not touch it. + neighbor = fork_scale0[32:64, 0:32, 0:32, :].read().result() + assert ( + neighbor == 888 + ).all(), f"neighbor chunk erroneously reverted: unique={np.unique(neighbor)}" + + +def _chunk_id_from_coord(meta, layer, coord): + """Encode (layer, x, y, z) into a chunk ID using the graph's bitmasks.""" + from pychunkedgraph.graph.chunks.utils import get_chunk_id + + return get_chunk_id( + meta, layer=layer, x=int(coord[0]), y=int(coord[1]), z=int(coord[2]) + ) diff --git a/pychunkedgraph/tests/helpers.py b/pychunkedgraph/tests/helpers.py index c41d629f6..009fec730 100644 --- a/pychunkedgraph/tests/helpers.py +++ b/pychunkedgraph/tests/helpers.py @@ -1,4 +1,6 @@ +import threading from functools import reduce +from unittest.mock import MagicMock import numpy as np @@ -109,3 +111,76 @@ def get_layer_chunk_bounds( layer_bounds = atomic_chunk_bounds / (2 ** (layer - 2)) layer_bounds_d[layer] = np.ceil(layer_bounds).astype(int) return layer_bounds_d + + +class RowKeyLockRegistry: + """Thread-safe in-memory stand-in for kvdbclient's row-key lock API. + + Matches the full `cg.client.lock_by_row_key*` / `unlock_by_row_key*` + / `renew_lock_by_row_key` surface — including the indefinite-column + variants — so row-key-based lock primitives (DownsampleBlockLock, + L2ChunkLock, IndefiniteL2ChunkLock, …) can be exercised without a + bigtable emulator. + + Two separate maps, one per column. The "with_indefinite" temporal + acquire refuses if either map holds the row, mirroring the filter + union that `lock_by_row_key_with_indefinite` uses on bigtable. + """ + + def __init__(self): + self._lock = threading.Lock() + self._held = {} + self._held_indefinite = {} + + def lock_by_row_key(self, row_key, operation_id): + with self._lock: + if row_key in self._held: + return False + self._held[row_key] = operation_id + return True + + def lock_by_row_key_with_indefinite(self, row_key, operation_id): + with self._lock: + if row_key in self._held or row_key in self._held_indefinite: + return False + self._held[row_key] = operation_id + return True + + def lock_by_row_key_indefinitely(self, row_key, operation_id): + with self._lock: + if row_key in self._held_indefinite: + return False + self._held_indefinite[row_key] = operation_id + return True + + def unlock_by_row_key(self, row_key, operation_id): + with self._lock: + if self._held.get(row_key) == operation_id: + del self._held[row_key] + return True + return False + + def unlock_indefinitely_locked_by_row_key(self, row_key, operation_id): + with self._lock: + if self._held_indefinite.get(row_key) == operation_id: + del self._held_indefinite[row_key] + return True + return False + + def renew_lock_by_row_key(self, row_key, operation_id): + with self._lock: + return self._held.get(row_key) == operation_id + + +def make_cg_with_row_key_lock_registry(registry: RowKeyLockRegistry): + """Attach a `RowKeyLockRegistry` to a `MagicMock` cg.client.""" + cg = MagicMock() + cg.client.lock_by_row_key = registry.lock_by_row_key + cg.client.lock_by_row_key_with_indefinite = registry.lock_by_row_key_with_indefinite + cg.client.lock_by_row_key_indefinitely = registry.lock_by_row_key_indefinitely + cg.client.unlock_by_row_key = registry.unlock_by_row_key + cg.client.unlock_indefinitely_locked_by_row_key = ( + registry.unlock_indefinitely_locked_by_row_key + ) + cg.client.renew_lock_by_row_key = registry.renew_lock_by_row_key + return cg diff --git a/pychunkedgraph/tests/ingest/test_ingest_utils.py b/pychunkedgraph/tests/ingest/test_ingest_utils.py index 4c5bdf0af..400ce3a0c 100644 --- a/pychunkedgraph/tests/ingest/test_ingest_utils.py +++ b/pychunkedgraph/tests/ingest/test_ingest_utils.py @@ -44,10 +44,13 @@ def test_from_config(self): }, "ingest_config": {}, } - meta, ingest_config, client_info = bootstrap("test_graph", config=config) + meta, ingest_config, client_info, ocdbt_config_dict = bootstrap( + "test_graph", config=config + ) assert meta.graph_config.ID == "test_graph" assert meta.graph_config.FANOUT == 2 assert ingest_config.USE_RAW_EDGES is False + assert isinstance(ocdbt_config_dict, dict) class TestPostprocessEdgeData: @@ -329,7 +332,6 @@ def my_func(): # ===================================================================== # Additional pure unit tests # ===================================================================== -from pychunkedgraph.ingest.utils import start_ocdbt_server class TestGetChunksNotDoneWithSplits: @@ -390,57 +392,6 @@ def test_get_chunks_not_done_splits_coord_str_format(self): assert call_args[0][1] == ["2_3_4_0"] -class TestStartOcdbtServer: - """Test start_ocdbt_server function.""" - - @patch("pychunkedgraph.ingest.utils.ts") - @patch.dict("os.environ", {"MY_POD_IP": "10.0.0.1"}) - def test_start_ocdbt_server(self, mock_ts): - """start_ocdbt_server should open a KvStore and set redis keys.""" - imanager = MagicMock() - imanager.cg.meta.data_source.EDGES = "gs://bucket/edges" - mock_redis = MagicMock() - imanager.redis = mock_redis - - server = MagicMock() - server.port = 12345 - - mock_kv_future = MagicMock() - mock_ts.KvStore.open.return_value = mock_kv_future - - start_ocdbt_server(imanager, server) - - # Verify tensorstore was called with the right spec - call_args = mock_ts.KvStore.open.call_args[0][0] - assert call_args["driver"] == "ocdbt" - assert "gs://bucket/edges/ocdbt" in call_args["base"] - assert call_args["coordinator"]["address"] == "localhost:12345" - mock_kv_future.result.assert_called_once() - - # Verify redis keys were set - mock_redis.set.assert_any_call("OCDBT_COORDINATOR_PORT", "12345") - mock_redis.set.assert_any_call("OCDBT_COORDINATOR_HOST", "10.0.0.1") - - @patch("pychunkedgraph.ingest.utils.ts") - @patch.dict("os.environ", {}, clear=True) - def test_start_ocdbt_server_default_host(self, mock_ts): - """When MY_POD_IP is not set, should default to localhost.""" - imanager = MagicMock() - imanager.cg.meta.data_source.EDGES = "gs://bucket/edges" - mock_redis = MagicMock() - imanager.redis = mock_redis - - server = MagicMock() - server.port = 9999 - - mock_kv_future = MagicMock() - mock_ts.KvStore.open.return_value = mock_kv_future - - start_ocdbt_server(imanager, server) - - mock_redis.set.assert_any_call("OCDBT_COORDINATOR_HOST", "localhost") - - class TestPostprocessEdgeDataNoneValues: """Test postprocess_edge_data when edge_dict values are None.""" diff --git a/requirements.in b/requirements.in index 143d90399..42f6cd592 100644 --- a/requirements.in +++ b/requirements.in @@ -14,6 +14,7 @@ pyyaml cachetools werkzeug tensorstore +rich edt connected-components-3d scikit-image @@ -28,8 +29,9 @@ task-queue>=2.14.0 messagingclient>0.3.0 dracopy>=1.5.0 datastoreflex>=0.5.0 -kvdbclient>0.5.0 +kvdbclient>=0.7.0 zstandard>=0.23.0 +tinybrain>=1.7.0 # Conda only - use requirements.yml (or install manually): # graph-tool \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index af29a75bd..c0d5fe147 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile --output-file=requirements.txt requirements.in +# pip-compile requirements.in # attrs==25.4.0 # via @@ -195,17 +195,21 @@ jsonschema==4.26.0 # python-jsonschema-objects jsonschema-specifications==2025.9.1 # via jsonschema -kvdbclient==0.6.0 +kvdbclient==0.7.0 # via -r requirements.in lazy-loader==0.4 # via scikit-image markdown==3.10.2 # via python-jsonschema-objects +markdown-it-py==4.2.0 + # via rich markupsafe==3.0.3 # via # flask # jinja2 # werkzeug +mdurl==0.1.2 + # via markdown-it-py messagingclient==0.4.0 # via -r requirements.in microviewer==1.20.0 @@ -244,6 +248,7 @@ numpy==2.4.2 # task-queue # tensorstore # tifffile + # tinybrain # zmesh opentelemetry-api==1.39.1 # via @@ -320,7 +325,9 @@ pybind11==3.0.2 pycparser==3.0 # via cffi pygments==2.19.2 - # via pytest + # via + # pytest + # rich pysimdjson==7.0.2 # via cloud-volume pytest==9.0.2 @@ -359,6 +366,8 @@ requests==2.32.5 # kvdbclient # middle-auth-client # task-queue +rich==15.0.0 + # via -r requirements.in rpds-py==0.30.0 # via # jsonschema @@ -397,6 +406,8 @@ tensorstore==0.1.81 # via -r requirements.in tifffile==2026.3.3 # via scikit-image +tinybrain==1.7.0 + # via -r requirements.in tqdm==4.67.3 # via # cloud-files diff --git a/workers/downsample_worker.py b/workers/downsample_worker.py new file mode 100644 index 000000000..ae436a9aa --- /dev/null +++ b/workers/downsample_worker.py @@ -0,0 +1,86 @@ +# pylint: disable=invalid-name, missing-docstring, logging-fstring-interpolation + +"""Pubsub worker that updates coarser segmentation mips after an SV split. + +Consumes the same edits exchange the mesh worker uses, but binds its own +queue and filters on the `downsample="true"` attribute set by +`publish_edit` when `result.seg_bbox` is populated. For each block the +SV-split touched, acquires the block's lock, runs the in-memory / +per-mip pyramid writer, releases. +""" + +import gc +import logging +import pickle +from os import getenv + +from messagingclient import MessagingClient + +from pychunkedgraph.graph import ChunkedGraph +from pychunkedgraph.graph.downsample import blocks_for_bbox, process_block +from pychunkedgraph.graph.locks import DownsampleBlockLock + +PCG_CACHE = {} + +INFO_HIGH = 25 +logging.basicConfig( + level=INFO_HIGH, + format="%(asctime)s %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) + + +def callback(payload): + # Filter by attribute rather than queue binding so all edit-triggered + # workers can share the same exchange. Split edits set + # `downsample=true`; merges/undos/redos/rollbacks don't. + if payload.attributes.get("downsample") != "true": + return + + data = pickle.loads(payload.data) + op_id = int(data["operation_id"]) + table_id = payload.attributes["table_id"] + seg_bboxes = data.get("seg_bboxes") + if not seg_bboxes: + return + + try: + cg = PCG_CACHE[table_id] + except KeyError: + cg = ChunkedGraph(graph_id=table_id) + PCG_CACHE[table_id] = cg + + # Defensive: non-OCDBT graphs have no coarser scales to write to. + seg_cfg = cg.meta.custom_data.get("seg", {}) + if not seg_cfg.get("ocdbt"): + logging.log( + INFO_HIGH, + f"graph {table_id} not OCDBT-backed; skipping downsample op {op_id}", + ) + return + + # Each published bbox is one SV split's write region. Collapse the + # list into the union of blocks touched so we lock/process each + # block exactly once even if two bboxes share blocks. + unique_blocks = set() + for bbs, bbe in seg_bboxes: + unique_blocks.update(blocks_for_bbox(cg.meta, bbs, bbe)) + block_list = sorted(unique_blocks) + + logging.log( + INFO_HIGH, + f"downsampling {len(block_list)} block(s) for op {op_id} graph {table_id}", + ) + with DownsampleBlockLock(cg, block_list, op_id): + for block in block_list: + process_block(cg.meta, block, seg_bboxes) + logging.log(INFO_HIGH, f"downsample complete op {op_id} graph {table_id}") + gc.collect() + + +c = MessagingClient() +downsample_queue = getenv("PYCHUNKEDGRAPH_DOWNSAMPLE_QUEUE") +assert ( + downsample_queue is not None +), "env PYCHUNKEDGRAPH_DOWNSAMPLE_QUEUE not specified." +c.consume(downsample_queue, callback)