diff --git a/AGENTS.md b/AGENTS.md index 745ce10..9e76207 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -124,7 +124,7 @@ Version bumping: `bumpver update --patch`, `--minor`, or `--major`. - Scripts that instantiate `SegmentChestTotalSegmentator` must guard the top-level invocation with `if __name__ == "__main__":` on Windows (`torch.multiprocessing` requires it). -- Single quotes for strings; double quotes for docstrings. Keep lines at or +- Double quotes for strings and docstrings. Keep lines at or below 88 characters. - Full type hints are required under strict mypy. Use `Optional[X]`, not `X | None`. diff --git a/CLAUDE.md b/CLAUDE.md index bf268f0..af531fa 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -148,7 +148,7 @@ Document via docstrings and inline comments. ## Code Style -- Single quotes for strings; double quotes for docstrings +- Double quotes for strings and docstrings - Full type hints (`mypy` strict; `disallow_untyped_defs = true`) - `Optional[X]` not `X | None` (ruff `UP007` suppressed) - Breaking changes are acceptable — backward compatibility is not a priority diff --git a/docs/api/utilities/index.rst b/docs/api/utilities/index.rst index 0e4d2c0..62277bb 100644 --- a/docs/api/utilities/index.rst +++ b/docs/api/utilities/index.rst @@ -21,6 +21,7 @@ Quick Links **Utility Modules**: * :doc:`image_tools` - Image processing utilities + * :doc:`labelmap_tools` - Labelmap to registration-mask conversion * :doc:`transform_tools` - Transform operations * :doc:`contour_tools` - Contour processing * :doc:`image_conversion` - 4D image to 3D time-series utilities @@ -34,6 +35,7 @@ Module Documentation :maxdepth: 2 image_tools + labelmap_tools transform_tools contour_tools image_conversion diff --git a/docs/api/utilities/labelmap_tools.rst b/docs/api/utilities/labelmap_tools.rst new file mode 100644 index 0000000..7b453b5 --- /dev/null +++ b/docs/api/utilities/labelmap_tools.rst @@ -0,0 +1,19 @@ +==================================== +Labelmap Tools +==================================== + +.. currentmodule:: physiomotion4d + +Convert segmentation labelmaps into binary registration masks, with optional +label exclusion and physically isotropic dilation. + +Module Reference +================ + +.. automodule:: physiomotion4d.labelmap_tools + :members: + :undoc-members: + +.. rubric:: Navigation + +:doc:`index` | :doc:`image_tools` | :doc:`transform_tools` diff --git a/docs/contributing.rst b/docs/contributing.rst index 4c62e97..d6b5d96 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -141,7 +141,7 @@ PhysioMotion4D follows strict code quality standards using modern, fast tooling. Formatting and Linting with Ruff --------------------------------- -We use **Ruff** for all formatting and linting (line length: 88, single quotes): +We use **Ruff** for all formatting and linting (line length: 88, double quotes): .. code-block:: bash diff --git a/docs/developer/registration_images.rst b/docs/developer/registration_images.rst index dab4078..1c0d023 100644 --- a/docs/developer/registration_images.rst +++ b/docs/developer/registration_images.rst @@ -25,7 +25,11 @@ Basic Pattern registered = registrar.get_registered_image() The result dictionary contains ``forward_transform``, ``inverse_transform``, -and ``loss``. +and ``loss``. Applying the right one is critical and direction-dependent: +``forward_transform`` warps the moving image onto the fixed grid, while +``inverse_transform`` warps moving points/landmarks into fixed space (image and +point warps use opposite transforms). See +:doc:`transform_conventions` for the full rules. Time Series =========== @@ -57,5 +61,6 @@ Development Notes See Also ======== +* :doc:`transform_conventions` * :doc:`../api/registration/index` * :doc:`workflows` diff --git a/docs/developer/registration_models.rst b/docs/developer/registration_models.rst index 3102f1a..466b904 100644 --- a/docs/developer/registration_models.rst +++ b/docs/developer/registration_models.rst @@ -45,9 +45,14 @@ Development Notes * Convert volumetric meshes to surfaces before surface registration when needed. * Treat ITK/PyVista coordinate transforms as high-risk and add focused tests. * Keep synthetic test meshes small and deterministic. +* ``RegisterModelsPCA`` returns ``forward_point_transform`` / + ``inverse_point_transform``. These are **point** transforms whose orientation + is opposite to the image-registration transforms; see + :doc:`transform_conventions` before applying them to images or meshes. See Also ======== +* :doc:`transform_conventions` * :doc:`../api/model_registration/index` * :doc:`workflows` diff --git a/docs/developer/transform_conventions.rst b/docs/developer/transform_conventions.rst new file mode 100644 index 0000000..ee80cf0 --- /dev/null +++ b/docs/developer/transform_conventions.rst @@ -0,0 +1,137 @@ +=============================== +Transform Direction Conventions +=============================== + +Registration in PhysioMotion4D produces a pair of transforms, and choosing the +wrong one of the pair is the single most common registration mistake. The rules +are simple but easy to get backwards, because **warping an image and warping a +point require opposite transforms**, and because **model (PCA) registration +returns its transforms in the opposite orientation from image registration**. + +Read this page before applying any transform to an image, mask, contour, or +landmark. + +The two transform families +=========================== + +Image registration + :class:`physiomotion4d.RegisterImagesANTS`, + :class:`physiomotion4d.RegisterImagesICON`, and + :class:`physiomotion4d.RegisterImagesGreedy` register a *moving* image to a + *fixed* image and return a dict with ``forward_transform`` and + ``inverse_transform``. :class:`physiomotion4d.RegisterTimeSeriesImages` + returns the list-valued ``forward_transforms`` / ``inverse_transforms``. + +Model (PCA) registration + :class:`physiomotion4d.RegisterModelsPCA` deforms a *template* model toward + a *target* (patient) and, via ``compute_pca_transforms()``, returns + ``forward_point_transform`` and ``inverse_point_transform``. These are + **point transforms**, oriented opposite to the image-registration transforms + (see `PCA point transforms`_ below). + +Image warping vs. point warping use opposite transforms +======================================================== + +ITK resampling is a *pull-back* operation. To build the warped image on the +fixed grid, :func:`TransformTools.transform_image` (an ``itk.ResampleImageFilter``) +visits every fixed-grid sample ``q`` and looks up the moving image at +``transform.TransformPoint(q)``. The transform it needs therefore maps +**fixed-space coordinates to moving-space coordinates**. + +Warping a *point* (landmark, contour vertex, mesh node) is a *push-forward* +operation: :func:`TransformTools.transform_pvcontour` / +:func:`TransformTools.transform_dataset` apply ``transform.TransformPoint(p)`` +directly to each input point. To move a moving-space landmark to its location in +the fixed image, the transform must map **moving-space coordinates to +fixed-space coordinates** -- the inverse of the image-warp transform. + +So for the **same** moving-to-fixed registration result: + +.. list-table:: Image registration: which transform to apply + :header-rows: 1 + :widths: 50 25 25 + + * - Goal + - Transform + - Helper + * - Warp the **moving image** into fixed space (onto the fixed grid) + - ``forward_transform`` + - :func:`TransformTools.transform_image` + * - Warp **moving points / contours / landmarks** into fixed space + - ``inverse_transform`` + - :func:`TransformTools.transform_pvcontour` + * - Warp the **fixed image** into moving space (e.g. time-series reconstruction) + - ``inverse_transform`` + - :func:`TransformTools.transform_image` + * - Warp **fixed points / contours / landmarks** into moving space + - ``forward_transform`` + - :func:`TransformTools.transform_pvcontour` + +The first two rows are the everyday case (warping the registered moving data +into the fixed/reference frame): the **image uses** ``forward_transform``, the +**points use** ``inverse_transform``. The last two rows are the mirror image; +:meth:`physiomotion4d.RegisterTimeSeriesImages.reconstruct_time_series` is the +canonical consumer of ``inverse_transform`` for image warping (it resamples the +fixed image back onto each moving frame's grid). + +.. note:: + + All three image-registration backends (ANTS, ICON, Greedy) follow this same + convention. ``transform_image(moving, forward_transform, fixed)`` is the + correct call to warp the moving image onto the fixed grid for every backend. + +PCA point transforms +==================== + +:class:`physiomotion4d.RegisterModelsPCA` builds ``forward_point_transform`` +directly from the template-to-target point displacement, so +``forward_point_transform.TransformPoint(template_point)`` returns the +corresponding *target* point. As a **point** map it goes template (moving) to +target (fixed) -- which is the same orientation as image registration's +``inverse_transform``, and therefore the **opposite** orientation of image +registration's ``forward_transform``. + +Concretely, treating the template as the moving object and the patient/target as +the fixed object: + +.. list-table:: Same goal, opposite transform names across the two families + :header-rows: 1 + :widths: 50 25 25 + + * - Goal + - Image registration + - PCA model registration + * - Warp the **image** (moving/template space -> fixed/target grid) + - ``forward_transform`` + - ``inverse_point_transform`` + * - Warp **points / meshes** (moving/template -> fixed/target) + - ``inverse_transform`` + - ``forward_point_transform`` + +In other words, ``forward_point_transform`` plays the role that +``inverse_transform`` plays for image registration, and +``inverse_point_transform`` plays the role of ``forward_transform``. Deforming +the template mesh onto the patient (the usual PCA use, performed internally by +``transform_template_model()`` and ``transform_point()``) uses +``forward_point_transform``; resampling an image with the PCA result uses +``inverse_point_transform``. + +Rule of thumb +============= + +* **Images pull back; points push forward.** For one registration result, the + image and the points always use the two *different* members of the transform + pair. +* **Image into the reference frame** -> ``forward_transform`` (image + registration) / ``inverse_point_transform`` (PCA). +* **Points into the reference frame** -> ``inverse_transform`` (image + registration) / ``forward_point_transform`` (PCA). +* When in doubt, warp a known landmark and a small image patch and confirm they + land in the same place before trusting a pipeline. + +See Also +======== + +* :doc:`registration_images` +* :doc:`registration_models` +* :doc:`utilities` diff --git a/docs/index.rst b/docs/index.rst index 458de99..320a9fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -162,6 +162,7 @@ per-tutorial implementation details. developer/segmentation developer/registration_images developer/registration_models + developer/transform_conventions developer/usd_generation developer/utilities diff --git a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py b/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py deleted file mode 100644 index addadea..0000000 --- a/experiments/Heart-GatedCT_To_USD/test_compare_registration_speed.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python -# %% [markdown] -# # Compare registration speed: Greedy vs ANTs vs ICON -# -# This notebook times **Greedy**, **ANTs**, and **ICON** when registering two time points of CT from the Slicer-Heart-CT data (TruncalValve 4D CT). -# -# **Prerequisites:** Run `0-download_and_convert_4d_to_3d.py` first so that `data/Slicer-Heart-CT/` contains the 4D NRRD and the 3D slice series (`slice_000.mha`, `slice_001.mha`, ...), and `results/slice_fixed.mha` exists. - -# %% -import os -import time - -import itk -import matplotlib.pyplot as plt -import pandas as pd -from itk import TubeTK as ttk - -from physiomotion4d.test_tools import TestTools -from physiomotion4d.register_images_ants import RegisterImagesANTS -from physiomotion4d.register_images_greedy import RegisterImagesGreedy -from physiomotion4d.register_images_icon import RegisterImagesICON - -_HERE = os.path.dirname(os.path.abspath(__file__)) - -# %% -data_dir = os.path.join(_HERE, "..", "..", "data", "Slicer-Heart-CT") -output_dir = os.path.join(_HERE, "results") -os.makedirs(output_dir, exist_ok=True) - -# Fixed = reference time point; moving = time point to align to fixed -fixed_image_path = os.path.join(output_dir, "slice_fixed.mha") -moving_image_path = os.path.join(data_dir, "slice_000.mha") - -if not os.path.exists(fixed_image_path): - raise FileNotFoundError( - f"Fixed image not found: {fixed_image_path}. " - "Run 0-download_and_convert_4d_to_3d.py first." - ) -if not os.path.exists(moving_image_path): - raise FileNotFoundError( - f"Moving image not found: {moving_image_path}. " - "Run 0-download_and_convert_4d_to_3d.py first." - ) - -fixed_image = itk.imread(fixed_image_path) -moving_image = itk.imread(moving_image_path) -print(f"Fixed image: {itk.size(fixed_image)}, spacing {itk.spacing(fixed_image)}") -print(f"Moving image: {itk.size(moving_image)}, spacing {itk.spacing(moving_image)}") - -# %% [markdown] -# ## Optional: downsample for faster comparison -# -# Set `downsample_factor = 1.0` to use full resolution (slower). Use e.g. `0.5` to halve each dimension for a quicker run. - -# %% -downsample_factor = 0.5 # 1.0 = full resolution - -if downsample_factor != 1.0: - resampler_f = ttk.ResampleImage.New(Input=fixed_image) - resampler_f.SetResampleFactor([downsample_factor] * 3) - resampler_f.Update() - fixed_image = resampler_f.GetOutput() - - resampler_m = ttk.ResampleImage.New(Input=moving_image) - resampler_m.SetResampleFactor([downsample_factor] * 3) - resampler_m.Update() - moving_image = resampler_m.GetOutput() - print(f"Downsampled to factor {downsample_factor}") - print(f" Fixed: {itk.size(fixed_image)}") - print(f" Moving: {itk.size(moving_image)}") -else: - print("Using full resolution.") - -# %% [markdown] -# ## Run each method and record time -# -# All three use **deformable** registration (Greedy: affine + deformable; ANTs: SyN; ICON: deep learning). Settings are chosen for a fair comparison with reduced iterations so the notebook runs in a few minutes. - -# %% -results_list = [] - -# --- Greedy (deformable) --- -try: - reg_g = RegisterImagesGreedy() - reg_g.set_modality("ct") - reg_g.set_transform_type("Deformable") - reg_g.set_number_of_iterations([10, 5, 2]) - reg_g.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_g = reg_g.register(moving_image) - elapsed_g = time.perf_counter() - t0 - - loss_g = out_g.get("loss") - results_list.append( - { - "method": "Greedy", - "time_sec": round(elapsed_g, 2), - "loss": float(loss_g) if loss_g is not None else None, - } - ) - print(f"Greedy: {elapsed_g:.2f} s") -except Exception as e: - results_list.append({"method": "Greedy", "time_sec": None, "loss": None}) - print(f"Greedy: failed - {e}") - -# --- ANTs (deformable SyN) --- -try: - reg_a = RegisterImagesANTS() - reg_a.set_modality("ct") - reg_a.set_transform_type("Deformable") - reg_a.set_number_of_iterations([10, 5, 2]) # reduced for speed - reg_a.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_a = reg_a.register(moving_image) - elapsed_a = time.perf_counter() - t0 - - loss_a = out_a.get("loss") - results_list.append( - { - "method": "ANTs", - "time_sec": round(elapsed_a, 2), - "loss": float(loss_a) if loss_a is not None else None, - } - ) - print(f"ANTs: {elapsed_a:.2f} s") -except Exception as e: - results_list.append({"method": "ANTs", "time_sec": None, "loss": None}) - print(f"ANTs: failed - {e}") - -# --- ICON (deformable, GPU) --- -try: - reg_i = RegisterImagesICON() - reg_i.set_modality("ct") - reg_i.set_number_of_iterations(50) - reg_i.set_fixed_image(fixed_image) - - t0 = time.perf_counter() - out_i = reg_i.register(moving_image) - elapsed_i = time.perf_counter() - t0 - - loss_i = out_i.get("loss") - results_list.append( - { - "method": "ICON", - "time_sec": round(elapsed_i, 2), - "loss": float(loss_i) if loss_i is not None else None, - } - ) - print(f"ICON: {elapsed_i:.2f} s") -except Exception as e: - results_list.append({"method": "ICON", "time_sec": None, "loss": None}) - print(f"ICON: failed - {e}") - -df = pd.DataFrame(results_list) - -# %% -print(df) - -# %% -fig, ax = plt.subplots(figsize=(6, 4)) -valid = df["time_sec"].notna() -if valid.any(): - methods = df.loc[valid, "method"] - times = df.loc[valid, "time_sec"] - ax.bar(methods, times, color=["#2ecc71", "#3498db", "#9b59b6"]) - ax.set_ylabel("Time (seconds)") - ax.set_title("Registration time: two time points (Slicer-Heart-CT)") - plt.tight_layout() - if not TestTools.running_as_test(): - plt.show() -else: - print("No successful runs to plot.") - -# %% [markdown] -# ## Notes -# -# - **Greedy**: CPU-based, often faster than ANTs for comparable quality; see [Greedy](https://greedy.readthedocs.io/) and [picsl-greedy](https://pypi.org/project/picsl-greedy/). -# - **ANTs**: CPU-based, very widely used; typically slower than Greedy for similar settings. -# - **ICON**: GPU-based (UniGradIcon); speed depends on GPU. Loss values are not directly comparable across methods. -# - For a quicker comparison, use `downsample_factor = 0.5` or reduce `number_of_iterations` further. diff --git a/experiments/LongitudinalRegistration/1-finetune_icon.py b/experiments/LongitudinalRegistration/1-finetune_icon.py deleted file mode 100644 index 968a078..0000000 --- a/experiments/LongitudinalRegistration/1-finetune_icon.py +++ /dev/null @@ -1,186 +0,0 @@ -# %% [markdown] -# # Fine-tune uniGradICON on Duke 4D Gated CT Data -# -# Discovers per-patient gated CT images and their precomputed -# SegmentHeartSimpleware labelmaps and applies the project-wide fixed 80/20 -# train/test split (sort patients in ``ref_data_dir`` by filename; the first -# 80% are train, the last 20% are test). The train cohort is handed to -# :class:`WorkflowFineTuneICONRegistration`, which builds the paired dataset -# JSON, YAML config, and derived loss-function masks, then launches -# ``unigradicon.finetuning.finetune`` as a subprocess. -# -# ``2-recon_4d_icon_eval.py`` re-derives the same split from the same sorted -# patient list — no cached split file is needed. -# -# Each patient directory under ``src_data_dir_base`` is one ``subject_id``; -# all of that patient's gated time-point frames form a paired training group. -# Frames whose labelmap is missing on disk are dropped from the dataset. - -# %% -import os -from pathlib import Path - -import itk - -from physiomotion4d import WorkflowFineTuneICONRegistration -from physiomotion4d.register_images_icon import RegisterImagesICON - -# %% [markdown] -# ## 1. Configure data, output locations, and the train/test split - -# %% -ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") -src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") -segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") - -# Where the workflow writes the dataset JSON, YAML config, derived masks, and -# the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to -# ``output_dir / fine_tune_name``. -output_dir = Path("./results") -fine_tune_name = "icon_finetuned" - -# Fixed train/test split: sort patients in ``ref_data_dir`` by filename; -# first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies -# the same rule so the two scripts agree without a cached split record. -train_fraction = 0.8 - -# Local clone of uniGradICON (feat-add-finetuning branch) — prepended to -# PYTHONPATH so the subprocess picks up the local source instead of the -# installed package. Set to ``None`` to use the pip-installed unigradicon. -unigradicon_src_path: Path | None = Path(__file__).parent / "uniGradICON" / "src" - -# %% [markdown] -# ## 2. Enumerate patients and apply the fixed 80/20 split -# -# Sort ``ref_data_dir`` by filename to produce the canonical patient order. -# The first 80% become the train cohort; the last 20% are the held-out test -# cohort that ``2-recon_4d_icon_eval.py`` will evaluate. - -# %% -ref_files = sorted( - p - for p in ref_data_dir.iterdir() - if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] -) -all_patient_ids = [p.name[:6] for p in ref_files] -print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") - -if len(all_patient_ids) < 2: - raise FileNotFoundError( - f"Need at least 2 patients to form a train/test split; " - f"discovered {len(all_patient_ids)} under {ref_data_dir}" - ) - -n_train = max( - 1, - min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))), -) -train_subjects = all_patient_ids[:n_train] -test_subjects = all_patient_ids[n_train:] -print(f" Train (first {n_train}): {train_subjects}") -print(f" Test (last {len(test_subjects)}): {test_subjects}") - -# %% [markdown] -# ## 3. Gather the train cohort's gated frames and labelmaps -# -# For each train-cohort patient, list gated frames in -# ``src_data_dir_base / `` (excluding ``"nop"`` non-gated -# references) and pair each frame with its -# ``_labelmap.nii.gz`` under ``segmentation_dir_base / ``. -# Patients with no source directory or no valid frames are skipped here only -# — they remain part of the canonical train list above, but contribute no -# training data. Missing labelmaps are recorded as ``None`` so the workflow -# skips just that frame. - -# %% -train_image_files: list[list[str]] = [] -train_segmentation_files: list[list[str | None]] = [] -valid_train_subjects: list[str] = [] - -for patient_id in train_subjects: - src_dir = src_data_dir_base / patient_id - seg_dir = segmentation_dir_base / patient_id - - if not src_dir.is_dir(): - print(f" Skipping {patient_id}: source dir {src_dir} not found") - continue - - frame_names = sorted( - f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") - ) - if not frame_names: - print(f" Skipping {patient_id}: no valid frames in {src_dir}") - continue - - image_paths = [str(src_dir / f) for f in frame_names] - seg_paths: list[str | None] = [] - for f in frame_names: - labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") - seg_paths.append(str(labelmap) if labelmap.exists() else None) - - train_image_files.append(image_paths) - train_segmentation_files.append(seg_paths) - valid_train_subjects.append(patient_id) - - n_seg = sum(1 for s in seg_paths if s is not None) - print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") - -# %% [markdown] -# ## 4. Pre-compute loss-function masks next to each labelmap -# -# Use :meth:`RegisterImagesICON.create_mask` (``>0`` threshold + 5 mm -# physical-radius dilation) to derive each frame's binary heart-ROI mask and -# write it as ``_mask.nii.gz`` in the labelmap's own directory. -# Pre-computing here means the workflow does not have to re-derive masks -# during ``run_fine_tuning`` and the same masks are reused by downstream -# evaluation scripts. - -# %% -mask_dilation_mm = 5.0 -train_mask_files: list[list[str | None]] = [] -for image_paths, seg_paths in zip( - train_image_files, train_segmentation_files, strict=True -): - mask_paths: list[str | None] = [] - for seg_path in seg_paths: - if seg_path is None: - mask_paths.append(None) - continue - seg_p = Path(seg_path) - stem = seg_p.name - stem = stem[:-7] if stem.endswith(".nii.gz") else seg_p.stem - mask_p = seg_p.parent / f"{stem}_mask.nii.gz" - if not mask_p.exists(): - mask = RegisterImagesICON.create_mask( - itk.imread(str(seg_p)), dilation_mm=mask_dilation_mm - ) - itk.imwrite(mask, str(mask_p), compression=True) - mask_paths.append(str(mask_p)) - train_mask_files.append(mask_paths) - -# %% [markdown] -# ## 5. Fine-tune uniGradICON on the train cohort -# -# The workflow consumes both the labelmaps (for paired-with-seg training and -# ``use_label``) and the pre-computed masks (for ``loss_function_masking``) -# and launches ``unigradicon.finetuning.finetune`` as a subprocess. The -# final checkpoint lands at -# :meth:`WorkflowFineTuneICONRegistration.expected_weights_path`, which is -# the default ``--finetuned-weights-path`` read by ``2-recon_4d_icon_eval.py``. - -# %% -workflow = WorkflowFineTuneICONRegistration( - subject_image_files=train_image_files, - output_dir=output_dir, - fine_tune_name=fine_tune_name, - subject_ids=valid_train_subjects, - subject_segmentation_files=train_segmentation_files, - subject_mask_files=train_mask_files, - mask_dilation_mm=mask_dilation_mm, - unigradicon_src_path=unigradicon_src_path, - epochs=100, -) - -weights_path = workflow.run_fine_tuning() -print(f"\nFine-tuning complete. Expected weights at: {weights_path}") -print(f"Held-out test cohort (for 2-recon_4d_icon_eval.py): {test_subjects}") diff --git a/experiments/LongitudinalRegistration/1-initial_registration.py b/experiments/LongitudinalRegistration/1-initial_registration.py new file mode 100644 index 0000000..e9074b7 --- /dev/null +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -0,0 +1,806 @@ +# %% [markdown] +# # Pre-registration: compare ANTS vs Greedy vs ICON on the Duke gated CT cohort +# +# Registers every gated CT time-point of every Duke patient under +# ``ref_data_dir`` (100% of the cohort -- no train/test split) to that +# patient's reference image, using three backends in turn: +# +# * :class:`RegisterImagesANTS` (CPU, SyN deformable) +# * :class:`RegisterImagesGreedy` (CPU, deformable) +# * :class:`RegisterImagesICON` (GPU, uniGradICON deformable) +# +# For each frame the script records wall-clock registration time, writes +# the warped/resampled moving image to disk, warps the moving labelmap +# into reference space to compute per-label Dice, and warps the moving +# landmarks into reference space to compute squared-error landmark +# statistics (mm^2) against the reference landmarks. +# +# Inputs (same data as ``1-finetune_icon.py``): +# * ``ref_data_dir / pm*_ref.nii.gz`` -- per-patient reference CT +# * ``src_data_dir_base / / *.nii.gz`` -- gated CT frames +# * ``segmentation_dir_base / / _labelmap.nii.gz`` +# -- per-frame multi-label segmentations +# * ``segmentation_dir_base / / _labelmap_mask.nii.gz`` +# -- pre-computed loss-function masks (re-derived on the fly if absent, +# matching the 3 mm dilation used by ``1-finetune_icon.py``) +# * ``segmentation_dir_base / / _landmark.mrk.json`` +# -- per-frame 3D Slicer Markups landmarks in LPS +# +# Outputs under ``results/``: +# * ``ants///.mha``, +# ``greedy///.mha`` and +# ``icon///.mha`` -- warped moving image +# per time point, alongside the forward/inverse transforms (``.hdf``), +# the warped ``_labelmap.mha`` and its warped +# loss-function mask ``_labelmap_mask.mha``, +# and the warped ``_landmark.mrk.json`` +# * ``registration_landmarks_.csv`` -- per-landmark squared errors +# * ``registration_dice_.csv`` -- per-label Dice +# * ``registration_summary_.csv`` -- per-(subject, method, timepoint) +# registration time, per-frame total time, mean Dice, MSE, RMSE +# * ``registration_timing_.csv`` -- per-step wall-clock seconds, +# appended live as each frame's steps complete (register, write_transforms, +# warp_image, warp_labelmap, warp_mask, dice, landmarks, frame_total) +# * ``registration_timing_summary_.csv`` -- per-(method, step) count, +# mean, and total seconds, written once at the end of the run +# +# Run interactively cell-by-cell; all paths are hard-coded. + +# %% +import csv +import re +import time +from pathlib import Path +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d.labelmap_tools import LabelmapTools +from physiomotion4d.landmark_tools import LandmarkTools +from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + +# %% [markdown] +# ## 1. Hard-coded paths and configuration + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") +segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") + +_HERE = Path(__file__).parent +output_dir = _HERE / "results" +output_dir.mkdir(parents=True, exist_ok=True) + +# Reference frames in gated_nii are named ``_ref.nii.gz``; every +# other ``.nii.gz`` (excluding ``nop`` non-gated references) is a gated +# time point. Timepoint tag ``g###`` is extracted from each filename. +exclude_tokens = ["nop"] +ref_suffix = "_ref" +timepoint_re = re.compile(r"_g(?P[0-9]{3})") + +# Mask dilation matches 1-finetune_icon.py so any masks we have to +# derive here are identical to the ones written by the fine-tune script. +mask_dilation_mm = 3.0 +labelmap_tools = LabelmapTools() + +# Iteration schedules. Kept modest for a cohort-wide comparison; raise +# either list for higher accuracy at the cost of runtime. ANTS and Greedy +# take a multi-resolution list; ICON takes a single per-pair iterative +# optimization step count (0 disables it, using the pretrained forward pass +# alone). +number_of_iterations_ANTS = [40, 20, 10] +number_of_iterations_greedy = [40, 20, 10] +number_of_iterations_ICON = 50 + +# Optional uniGradICON checkpoint (".trch") to load instead of the default +# pretrained weights under ``network_weights/unigradicon1.0/``. When None, +# the default pretrained weights are used. +icon_weights_path: Optional[Path] = None + +methods: list[str] = ["ANTS", "Greedy", "ICON"] + +# Debug knob: when non-empty, only these patient IDs are processed. +# Set to ``[]`` (or ``None``) to run the full cohort. +debug_subjects: list[str] = [] # ["pm0002"] + +run_stamp = time.time() +detail_landmarks_file = output_dir / f"registration_landmarks_{run_stamp}.csv" +detail_dice_file = output_dir / f"registration_dice_{run_stamp}.csv" +summary_file = output_dir / f"registration_summary_{run_stamp}.csv" +# Per-step wall-clock times, appended live as each frame's steps complete. +timing_detail_file = output_dir / f"registration_timing_{run_stamp}.csv" +# Per-(method, step) timing aggregates, written once at the end of the run. +timing_summary_file = output_dir / f"registration_timing_summary_{run_stamp}.csv" +for previous in ( + detail_landmarks_file, + detail_dice_file, + summary_file, + timing_detail_file, + timing_summary_file, +): + if previous.exists(): + previous.unlink() + +# %% [markdown] +# ## 2. Enumerate the full patient cohort +# +# Sort ``ref_data_dir`` by filename so the patient order is stable. +# Every patient is processed -- no train/test split. + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") +if debug_subjects: + cohort = [pid for pid in all_patient_ids if pid in debug_subjects] + print( + f"DEBUG: restricting cohort to {debug_subjects} -> " + f"{len(cohort)} matching patients" + ) +else: + cohort = all_patient_ids +print(f"Patient cohort: {cohort}") + +# %% [markdown] +# ## 3. Helpers: labelmap warping, per-label Dice, landmark squared error + +# %% +landmark_tools = LandmarkTools() +transform_tools = TransformTools() + +# Per-step timing records (subject, method, timepoint, step, seconds), +# accumulated in memory for the end-of-run timing summary and mirrored live +# into timing_detail_file as each step finishes. +timing_rows: list[dict[str, object]] = [] + + +def record_step_time( + subject_id: str, + method_name: str, + timepoint: str, + step: str, + seconds: float, +) -> None: + """Report a single processing step's wall-clock time. + + Prints the time immediately, appends a row to ``timing_detail_file`` so + progress is visible while the run is still going, and stores the same + row in ``timing_rows`` for the end-of-run timing summary. + """ + print(f" [time] {step:<18}{seconds:8.2f} s", flush=True) + timing_rows.append( + { + "subject_id": subject_id, + "method": method_name, + "timepoint": timepoint, + "step": step, + "seconds": float(seconds), + } + ) + with timing_detail_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow(["subject_id", "method", "timepoint", "step", "seconds"]) + writer.writerow([subject_id, method_name, timepoint, step, f"{seconds:.6f}"]) + + +def per_label_dice( + fixed_labelmap: itk.Image, warped_labelmap: itk.Image +) -> dict[int, float]: + """Return ``{label_id: Dice}`` for every positive label present in + either the fixed or the warped labelmap. + + Arrays come back from :func:`itk.array_from_image` in shape + ``(Z, Y, X)`` (numpy reverses ITK's index order); we compare element-wise + so the axis convention does not matter as long as both labelmaps live + on the same reference grid (guaranteed because ``warped_labelmap`` was + resampled with ``fixed_labelmap`` as the reference image). + """ + fixed_array = itk.array_from_image(fixed_labelmap) + warped_array = itk.array_from_image(warped_labelmap) + labels = sorted( + {int(v) for v in np.unique(fixed_array)} + | {int(v) for v in np.unique(warped_array)} + ) + labels = [label for label in labels if label > 0] + + dice_by_label: dict[int, float] = {} + for label in labels: + a = fixed_array == label + b = warped_array == label + denom = int(a.sum()) + int(b.sum()) + if denom == 0: + continue + intersection = int(np.logical_and(a, b).sum()) + dice_by_label[label] = 2.0 * intersection / denom + return dice_by_label + + +def warp_landmarks( + inverse_transform: itk.Transform, + moving_landmarks: dict[str, tuple[float, float, float]], +) -> dict[str, tuple[float, float, float]]: + """Warp every moving landmark into reference space. + + Point/landmark warping uses ``inverse_transform`` -- the moving-space -> + fixed-space point map -- which is the opposite of the transform used to + warp the moving image onto the fixed grid (images pull back; points push + forward). Returns a ``{label: (x, y, z)}`` dict in LPS. See + docs/developer/transform_conventions. + """ + return { + name: tuple(float(c) for c in inverse_transform.TransformPoint(point)) + for name, point in moving_landmarks.items() + } + + +def landmark_squared_errors( + warped_landmarks: dict[str, tuple[float, float, float]], + reference_landmarks: dict[str, tuple[float, float, float]], +) -> list[tuple[str, float]]: + """Return per-landmark squared Euclidean error in mm^2 between the + reference-space ``warped_landmarks`` and the matching reference + landmarks, in sorted-name order. + """ + shared = sorted(warped_landmarks.keys() & reference_landmarks.keys()) + errors: list[tuple[str, float]] = [] + for name in shared: + diff = np.asarray(warped_landmarks[name], dtype=np.float64) - np.asarray( + reference_landmarks[name], dtype=np.float64 + ) + errors.append((name, float(np.dot(diff, diff)))) + return errors + + +def load_or_derive_mask(labelmap: itk.Image, mask_path: Path) -> itk.Image: + """Return the cached ``_labelmap_mask.nii.gz`` next to the + labelmap, or derive it via + :meth:`LabelmapTools.convert_labelmap_to_mask` (threshold ``>0`` plus + 3 mm physical-radius dilation) and write it out so subsequent runs and + the ICON eval reuse the same mask. + """ + # Force mask update + # if mask_path.exists(): + # return itk.imread(str(mask_path)) + mask = labelmap_tools.convert_labelmap_to_mask( + labelmap, + dilation_in_mm=mask_dilation_mm, + exclude_labels=[1, 2, 3, 4], + # These are labels for the interior of the heart chambers (the LV, RV, LA, RA) + ) + itk.imwrite(mask, str(mask_path), compression=True) + return mask + + +# %% [markdown] +# ## 4. Drive the comparison: every patient x every method +# +# For each patient: load the reference image, labelmap, mask, and +# landmarks; load every gated frame (excluding ``nop`` and ``_ref``) with +# its labelmap, mask, and landmarks; then register each frame to the +# reference under both backends. Each frame starts from identity so the +# ANTS-vs-Greedy comparison is independent across frames. + +# %% +summary_rows: list[dict[str, object]] = [] + +# (subject_id, method, timepoint) for frames that produced no usable +# landmark errors -- either no landmark file or no labels shared with the +# reference. Echoed in a highlighted block at the end of the run. +frames_missing_landmarks: list[tuple[str, str, str]] = [] + +for subject_index, subject_id in enumerate(cohort): + print(f"\n=== Subject {subject_index + 1}/{len(cohort)}: {subject_id} ===") + src_dir = src_data_dir_base / subject_id + seg_dir = segmentation_dir_base / subject_id + + if not src_dir.is_dir(): + print(f" Skipping {subject_id}: source dir {src_dir} not found") + continue + if not seg_dir.is_dir(): + print(f" Skipping {subject_id}: segmentation dir {seg_dir} not found") + continue + + # Locate this patient's reference frame in gated_nii (matches the + # `_ref.nii.gz` filename under ref_data_dir). + ref_file = next((p for p in ref_files if p.name.startswith(subject_id)), None) + if ref_file is None: + print(f" Skipping {subject_id}: no reference image found") + continue + ref_stem = ref_file.name[:-7] + ref_labelmap_path = seg_dir / f"{ref_stem}_labelmap.nii.gz" + ref_mask_path = seg_dir / f"{ref_stem}_labelmap_mask.nii.gz" + ref_landmark_path = seg_dir / f"{ref_stem}_landmark.mrk.json" + if not ref_labelmap_path.exists() or not ref_landmark_path.exists(): + print( + f" Skipping {subject_id}: missing reference labelmap or " + f"landmarks under {seg_dir}" + ) + continue + + fixed_image = itk.imread(str(ref_file), pixel_type=itk.F) + fixed_labelmap = itk.imread(str(ref_labelmap_path)) + fixed_mask = load_or_derive_mask(fixed_labelmap, ref_mask_path) + reference_landmarks = landmark_tools.read_landmarks_3dslicer(ref_landmark_path) + + # Gated moving frames (exclude `nop` and the `_ref` frame itself). + gated_files = sorted( + p + for p in src_dir.glob("*.nii.gz") + if not any(token in p.name for token in exclude_tokens) + and not p.name.endswith(f"{ref_suffix}.nii.gz") + ) + moving_records: list[dict[str, object]] = [] + for image_path in gated_files: + stem = image_path.name[:-7] + labelmap_path = seg_dir / f"{stem}_labelmap.nii.gz" + mask_path = seg_dir / f"{stem}_labelmap_mask.nii.gz" + landmark_path = seg_dir / f"{stem}_landmark.mrk.json" + if not labelmap_path.exists(): + print(f" Dropping {stem}: no labelmap at {labelmap_path}") + continue + match = timepoint_re.search(image_path.name) + if match is None: + print(f" Dropping {stem}: no g### timepoint tag in name") + continue + moving_records.append( + { + "stem": stem, + "timepoint": match.group("timepoint"), + "image_path": image_path, + "labelmap_path": labelmap_path, + "mask_path": mask_path, + "landmark_path": landmark_path if landmark_path.exists() else None, + } + ) + if not moving_records: + print(f" Skipping {subject_id}: no usable gated frames") + continue + + print(f" {len(moving_records)} moving frames; reference {ref_file.name}") + + print(f" Loading {len(moving_records)} moving images / labelmaps / masks ...") + moving_images = [] + moving_labelmaps = [] + moving_masks = [] + moving_landmarks_list: list[Optional[dict[str, tuple[float, float, float]]]] = [] + for r_index, r in enumerate(moving_records): + print( + f" [{r_index + 1}/{len(moving_records)}] g{r['timepoint']} {r['stem']}" + ) + moving_image = itk.imread(str(r["image_path"]), pixel_type=itk.F) + labelmap = itk.imread(str(r["labelmap_path"])) + moving_images.append(moving_image) + moving_labelmaps.append(labelmap) + moving_masks.append(load_or_derive_mask(labelmap, r["mask_path"])) + landmark_path = r["landmark_path"] + if landmark_path is None: + moving_landmarks_list.append(None) + else: + moving_landmarks_list.append( + landmark_tools.read_landmarks_3dslicer(landmark_path) + ) + + for method_name in methods: + print(f"\n --- Method: {method_name} ---") + if method_name == "ANTS": + reg = RegisterImagesANTS() + reg.set_number_of_iterations(number_of_iterations_ANTS) + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + reg.set_metric("CC") + elif method_name == "Greedy": + reg = RegisterImagesGreedy() + reg.set_number_of_iterations(number_of_iterations_greedy) + reg.set_transform_type("Deformable") + # NCC ("CC") beats MeanSquares for same-modality CT registration. + reg.set_metric("CC") + else: # ICON: GPU deep-learning deformable registration. + reg = RegisterImagesICON() + reg.set_number_of_iterations(number_of_iterations_ICON) + if icon_weights_path is not None: + reg.set_weights_path(str(icon_weights_path)) + reg.set_modality("ct") + reg.set_mask_dilation(mask_dilation_mm) + reg.set_fixed_image(fixed_image) + reg.set_fixed_mask(fixed_mask) + + method_dir = output_dir / method_name.lower() / subject_id + method_dir.mkdir(parents=True, exist_ok=True) + + method_t_start = time.perf_counter() + for index, record in enumerate(moving_records): + timepoint = record["timepoint"] + stem = record["stem"] + print( + f" [{method_name} {index + 1}/{len(moving_records)}] " + f"g{timepoint} registering ...", + flush=True, + ) + + frame_total_start = time.perf_counter() + frame_t_start = frame_total_start + reg_result = reg.register( + moving_image=moving_images[index], + moving_mask=moving_masks[index], + ) + frame_elapsed = time.perf_counter() - frame_t_start + + forward_transform = reg_result["forward_transform"] + inverse_transform = reg_result["inverse_transform"] + frame_loss = float(reg_result["loss"]) + print(f" done in {frame_elapsed:.1f} s, loss={frame_loss:.4f}") + record_step_time( + subject_id, method_name, timepoint, "register", frame_elapsed + ) + + step_t_start = time.perf_counter() + itk.transformwrite( + forward_transform, + str(method_dir / f"{stem}_fwd.hdf"), + compression=True, + ) + itk.transformwrite( + inverse_transform, + str(method_dir / f"{stem}_inv.hdf"), + compression=True, + ) + record_step_time( + subject_id, + method_name, + timepoint, + "write_transforms", + time.perf_counter() - step_t_start, + ) + + # Warp the moving image into reference space and save it + # (forward_transform resamples the moving image onto the fixed grid). + step_t_start = time.perf_counter() + warped_image = transform_tools.transform_image( + moving_images[index], + forward_transform, + fixed_image, + interpolation_method="linear", + ) + itk.imwrite( + warped_image, + str(method_dir / f"{stem}.mha"), + compression=True, + ) + record_step_time( + subject_id, + method_name, + timepoint, + "warp_image", + time.perf_counter() - step_t_start, + ) + + # Warp the moving labelmap onto the fixed grid (forward_transform; + # nearest neighbour preserves label IDs) for per-label Dice. + step_t_start = time.perf_counter() + warped_labelmap = transform_tools.transform_image( + moving_labelmaps[index], + forward_transform, + fixed_labelmap, + interpolation_method="nearest", + ) + itk.imwrite( + warped_labelmap, + str(method_dir / f"{stem}_labelmap.mha"), + compression=True, + ) + record_step_time( + subject_id, + method_name, + timepoint, + "warp_labelmap", + time.perf_counter() - step_t_start, + ) + + # Warp the moving loss-function mask onto the fixed grid + # (forward_transform; nearest neighbour preserves the binary ROI) + # so downstream fine-tuning reuses it instead of re-deriving a + # mask from the warped labelmap. + step_t_start = time.perf_counter() + warped_mask = transform_tools.transform_image( + moving_masks[index], + forward_transform, + fixed_mask, + interpolation_method="nearest", + ) + itk.imwrite( + warped_mask, + str(method_dir / f"{stem}_labelmap_mask.mha"), + compression=True, + ) + record_step_time( + subject_id, + method_name, + timepoint, + "warp_mask", + time.perf_counter() - step_t_start, + ) + + step_t_start = time.perf_counter() + dice_by_label = per_label_dice(fixed_labelmap, warped_labelmap) + with detail_dice_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + ["subject_id", "method", "timepoint", "label", "dice"] + ) + for label, dice in dice_by_label.items(): + writer.writerow([subject_id, method_name, timepoint, label, dice]) + mean_dice = ( + float(np.mean(list(dice_by_label.values()))) + if dice_by_label + else float("nan") + ) + record_step_time( + subject_id, + method_name, + timepoint, + "dice", + time.perf_counter() - step_t_start, + ) + + # Warp the moving landmarks into reference space, save them next + # to the transforms, then score squared error vs the reference. + step_t_start = time.perf_counter() + moving_landmarks = moving_landmarks_list[index] + if moving_landmarks is None: + sq_errors: list[tuple[str, float]] = [] + else: + warped_landmarks = warp_landmarks(inverse_transform, moving_landmarks) + landmark_tools.write_landmarks_3dslicer( + warped_landmarks, + str(method_dir / f"{stem}_landmark.mrk.json"), + ) + sq_errors = landmark_squared_errors( + warped_landmarks, reference_landmarks + ) + with detail_landmarks_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + [ + "subject_id", + "method", + "timepoint", + "name", + "sq_err_mm2", + ] + ) + for name, sq_err in sq_errors: + writer.writerow([subject_id, method_name, timepoint, name, sq_err]) + record_step_time( + subject_id, + method_name, + timepoint, + "landmarks", + time.perf_counter() - step_t_start, + ) + + sq_values = np.asarray([e for _, e in sq_errors], dtype=np.float64) + if sq_values.size: + mse_mm2 = float(np.mean(sq_values)) + rmse_mm = float(np.sqrt(mse_mm2)) + else: + mse_mm2 = float("nan") + rmse_mm = float("nan") + # Highlight frames with no usable landmarks so they are not + # silently scored as NaN in the CSV / summary table. + reason = ( + "no landmark file" + if moving_landmarks is None + else "no landmarks shared with reference" + ) + frames_missing_landmarks.append((subject_id, method_name, timepoint)) + print( + f" >>> WARNING: {subject_id} {method_name} " + f"g{timepoint} has NO landmarks ({reason})", + flush=True, + ) + + frame_total = time.perf_counter() - frame_total_start + record_step_time( + subject_id, method_name, timepoint, "frame_total", frame_total + ) + + summary_rows.append( + { + "subject_id": subject_id, + "method": method_name, + "timepoint": timepoint, + "time_sec": float(frame_elapsed), + "frame_total_sec": float(frame_total), + "loss": frame_loss, + "n_labels": int(len(dice_by_label)), + "mean_dice": mean_dice, + "n_landmarks": int(sq_values.size), + "mse_mm2": mse_mm2, + "rmse_mm": rmse_mm, + } + ) + + method_elapsed = time.perf_counter() - method_t_start + print( + f" [{method_name}] subject {subject_id} total " + f"{method_elapsed:.1f} s " + f"({method_elapsed / len(moving_records):.1f} s/frame)" + ) + +# %% [markdown] +# ## 5. Write the per-(subject, method, timepoint) summary CSV + +# %% +if summary_rows: + with summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(summary_rows[0].keys())) + writer.writeheader() + writer.writerows(summary_rows) + print(f"\nWrote summary: {summary_file}") + print(f"Wrote landmarks: {detail_landmarks_file}") + print(f"Wrote dice: {detail_dice_file}") + print(f"Wrote timing: {timing_detail_file}") +else: + print("\nNo frames processed; nothing to summarize.") + +# %% [markdown] +# ## 5b. Highlight frames that produced no landmark errors + +# %% +if frames_missing_landmarks: + banner = "!" * 70 + print(f"\n{banner}") + print( + f"WARNING: {len(frames_missing_landmarks)} frame(s) missing ALL " + f"landmarks (scored as NaN):" + ) + for subject_id, method_name, timepoint in frames_missing_landmarks: + print(f" - {subject_id} {method_name} g{timepoint}") + print(banner) +else: + print("\nAll processed frames had at least one scored landmark.") + +# %% [markdown] +# ## 6. Per-method aggregate table across the whole cohort +# +# Reports mean per-frame registration time, mean / median / p95 of the +# squared landmark errors (mm^2), the matching RMSE in mm, and the mean +# per-label Dice averaged across (subject, timepoint, label) entries. + +# %% +if summary_rows: + sq_by_method: dict[str, list[float]] = {} + with detail_landmarks_file.open(newline="", encoding="utf-8") as fh: + for row in csv.DictReader(fh): + sq_by_method.setdefault(row["method"], []).append(float(row["sq_err_mm2"])) + + dice_by_method: dict[str, list[float]] = {} + with detail_dice_file.open(newline="", encoding="utf-8") as fh: + for row in csv.DictReader(fh): + dice_by_method.setdefault(row["method"], []).append(float(row["dice"])) + + time_by_method: dict[str, list[float]] = {} + for row in summary_rows: + method_name = str(row["method"]) + time_by_method.setdefault(method_name, []).append(float(row["time_sec"])) + + header = ( + f"{'Method':<10}{'N_lm':>8}{'MSE(mm2)':>12}{'RMSE(mm)':>12}" + f"{'p95(mm2)':>12}{'meanDice':>12}{'sec/frame':>12}" + ) + print() + print("=" * len(header)) + print(f"Pre-registration comparison ({len(all_patient_ids)} patients)") + print("=" * len(header)) + print(header) + print("-" * len(header)) + for method_name in methods: + sq_arr = np.asarray(sq_by_method.get(method_name, []), dtype=np.float64) + dice_arr = np.asarray(dice_by_method.get(method_name, []), dtype=np.float64) + time_arr = np.asarray(time_by_method.get(method_name, []), dtype=np.float64) + if sq_arr.size == 0: + print(f"{method_name:<10}{0:>8}{'':>12}{'':>12}{'':>12}{'':>12}{'':>12}") + continue + mse = float(np.mean(sq_arr)) + rmse = float(np.sqrt(mse)) + p95 = float(np.percentile(sq_arr, 95)) + mean_dice_val = float(np.mean(dice_arr)) if dice_arr.size else float("nan") + mean_time = float(np.mean(time_arr)) if time_arr.size else float("nan") + print( + f"{method_name:<10}" + f"{sq_arr.size:>8}" + f"{mse:>12.3f}" + f"{rmse:>12.3f}" + f"{p95:>12.3f}" + f"{mean_dice_val:>12.3f}" + f"{mean_time:>12.2f}" + ) + print("=" * len(header)) + +# %% [markdown] +# ## 7. Per-(method, step) timing summary +# +# Aggregates the live per-step timings into mean and total wall-clock +# seconds per (method, step), printed as a table and written to +# ``timing_summary_file``. ``frame_total`` is the end-to-end per-frame +# time (register + all warps/writes + scoring); the other rows are its +# components. + +# %% +if timing_rows: + # Preserve the pipeline order in which steps are timed; any unexpected + # step name is appended in first-seen order so nothing is dropped. + step_order = [ + "register", + "write_transforms", + "warp_image", + "warp_labelmap", + "warp_mask", + "dice", + "landmarks", + "frame_total", + ] + seconds_by_method_step: dict[str, dict[str, list[float]]] = {} + for row in timing_rows: + method_name = str(row["method"]) + step = str(row["step"]) + seconds = float(row["seconds"]) + seconds_by_method_step.setdefault(method_name, {}).setdefault(step, []).append( + seconds + ) + if step not in step_order: + step_order.append(step) + + timing_summary_rows: list[dict[str, object]] = [] + timing_header = ( + f"{'Method':<10}{'Step':<18}{'N':>6}{'mean_sec':>12}{'total_sec':>12}" + ) + print() + print("=" * len(timing_header)) + print("Timing summary (wall-clock seconds)") + print("=" * len(timing_header)) + print(timing_header) + print("-" * len(timing_header)) + for method_name in methods: + step_times = seconds_by_method_step.get(method_name, {}) + if not step_times: + continue + for step in step_order: + values = step_times.get(step) + if not values: + continue + arr = np.asarray(values, dtype=np.float64) + mean_sec = float(np.mean(arr)) + total_sec = float(np.sum(arr)) + timing_summary_rows.append( + { + "method": method_name, + "step": step, + "n": int(arr.size), + "mean_sec": mean_sec, + "total_sec": total_sec, + } + ) + print( + f"{method_name:<10}{step:<18}{arr.size:>6}" + f"{mean_sec:>12.2f}{total_sec:>12.2f}" + ) + print("-" * len(timing_header)) + print("=" * len(timing_header)) + + with timing_summary_file.open("w", newline="", encoding="utf-8") as fh: + writer = csv.DictWriter(fh, fieldnames=list(timing_summary_rows[0].keys())) + writer.writeheader() + writer.writerows(timing_summary_rows) + print(f"Wrote timing summary: {timing_summary_file}") diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py new file mode 100644 index 0000000..8ed8f6d --- /dev/null +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -0,0 +1,292 @@ +# %% [markdown] +# # Fine-tune uniGradICON on Duke 4D Gated CT Data +# +# Discovers per-patient gated CT images and their precomputed +# SegmentHeartSimpleware labelmaps and applies the project-wide fixed 80/20 +# train/test split (sort patients in ``ref_data_dir`` by filename; the first +# 80% are train, the last 20% are test). The train cohort is handed to +# :class:`WorkflowFineTuneICONRegistration`, which builds the paired dataset +# JSON, YAML config, and derived loss-function masks, then launches +# ``unigradicon.finetuning.finetune`` as a subprocess. +# +# ``2-recon_4d_icon_eval.py`` re-derives the same split from the same sorted +# patient list — no cached split file is needed. +# +# Each patient directory under ``src_data_dir_base`` is one ``subject_id``; +# all of that patient's gated time-point frames form a paired training group. +# Frames whose labelmap is missing on disk are dropped from the dataset. +# +# In addition to the original ``gated_nii`` frames, each patient's training +# group is augmented with that patient's ANTS- and Greedy-warped frames +# written by ``1-preregistration.py`` (warped image + labelmap per gated +# frame, under ``output_dir / / ``). Because the warped +# frames are merged into the *same* ``subject_id`` group, uniGradICON pairs the +# original gated frames and both backends' pre-registered frames together. + +# %% +import os +from pathlib import Path +from typing import Optional + +import itk + +from physiomotion4d import WorkflowFineTuneICONRegistration +from physiomotion4d.labelmap_tools import LabelmapTools + +# %% [markdown] +# ## 1. Configure data, output locations, and the train/test split + +# %% +ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") +src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") +segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") + +# Where the workflow writes the dataset JSON, YAML config, derived masks, and +# the uniGradICON ``checkpoints/`` tree. experiment_dir resolves to +# ``output_dir / fine_tune_name``. +_HERE = Path(__file__).parent +output_dir = _HERE / "results" +fine_tune_name = "icon_finetuned" + +# Pre-registration augmentation: ``1-preregistration.py`` warps every gated +# moving frame into reference space with these backends and writes the warped +# image + labelmap under ``preregistration_dir / .lower() / +# ``. Those warped frames are merged into each patient's training +# group below (section 4b). +preregistration_dir = output_dir +preregistration_methods = ["ANTS", "greedy"] + +# Fixed train/test split: sort patients in ``ref_data_dir`` by filename; +# first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies +# the same rule so the two scripts agree without a cached split record. +train_fraction = 0.8 + +# Local clone of uniGradICON (feat-add-finetuning branch) — prepended to +# PYTHONPATH so the subprocess picks up the local source instead of the +# installed package. Set to ``None`` to use the pip-installed unigradicon. +unigradicon_src_path: Optional[Path] = Path(__file__).parent / "uniGradICON" / "src" + +# %% [markdown] +# ## 2. Enumerate patients and apply the fixed 80/20 split +# +# Sort ``ref_data_dir`` by filename to produce the canonical patient order. +# The first 80% become the train cohort; the last 20% are the held-out test +# cohort that ``2-recon_4d_icon_eval.py`` will evaluate. + +# %% +ref_files = sorted( + p + for p in ref_data_dir.iterdir() + if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +) +all_patient_ids = [p.name[:6] for p in ref_files] +print(f"Found {len(all_patient_ids)} patients under {ref_data_dir}") + +if len(all_patient_ids) < 2: + raise FileNotFoundError( + f"Need at least 2 patients to form a train/test split; " + f"discovered {len(all_patient_ids)} under {ref_data_dir}" + ) + +n_train = max( + 1, + min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))), +) +train_subjects = all_patient_ids[:n_train] +test_subjects = all_patient_ids[n_train:] +print(f" Train (first {n_train}): {train_subjects}") +print(f" Test (last {len(test_subjects)}): {test_subjects}") + +# %% [markdown] +# ## 3. Gather the train cohort's gated frames and labelmaps +# +# For each train-cohort patient, list gated frames in +# ``src_data_dir_base / `` (excluding ``"nop"`` non-gated +# references) and pair each frame with its +# ``_labelmap.nii.gz`` under ``segmentation_dir_base / ``. +# Patients with no source directory or no valid frames are skipped here only +# — they remain part of the canonical train list above, but contribute no +# training data. Missing labelmaps are recorded as ``None`` so the workflow +# skips just that frame. + +# %% +train_image_files: list[list[str]] = [] +train_segmentation_files: list[list[Optional[str]]] = [] +valid_train_subjects: list[str] = [] + +for patient_id in train_subjects: + src_dir = src_data_dir_base / patient_id + seg_dir = segmentation_dir_base / patient_id + + if not src_dir.is_dir(): + print(f" Skipping {patient_id}: source dir {src_dir} not found") + continue + + frame_names = sorted( + f for f in os.listdir(src_dir) if "nop" not in f and f.endswith(".nii.gz") + ) + if not frame_names: + print(f" Skipping {patient_id}: no valid frames in {src_dir}") + continue + + image_paths = [str(src_dir / f) for f in frame_names] + seg_paths: list[Optional[str]] = [] + for f in frame_names: + labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") + seg_paths.append(str(labelmap) if labelmap.exists() else None) + + train_image_files.append(image_paths) + train_segmentation_files.append(seg_paths) + valid_train_subjects.append(patient_id) + + n_seg = sum(1 for s in seg_paths if s is not None) + print(f" {patient_id}: {len(image_paths)} frames, {n_seg} with labelmap") + +# %% [markdown] +# ## 4. Pre-compute loss-function masks next to each labelmap +# +# Use :meth:`LabelmapTools.convert_labelmap_to_mask` (``>0`` threshold + 5 mm +# physical-radius dilation) to derive each frame's binary heart-ROI mask and +# write it as ``_mask.nii.gz`` in the labelmap's own directory. +# Pre-computing here means the workflow does not have to re-derive masks +# during ``run_fine_tuning`` and the same masks are reused by downstream +# evaluation scripts. + +# %% +mask_dilation_mm = 5.0 +labelmap_tools = LabelmapTools() + + +def derive_mask_for(labelmap_path: Path) -> str: + """Create (or reuse) a loss-function mask next to ``labelmap_path``. + + Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm + via :meth:`LabelmapTools.convert_labelmap_to_mask`, writing the result as + ``_mask.nii.gz`` in the labelmap's own directory. Handles + both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha`` + (pre-registration warped labelmaps). Returns the mask path as a string; + existing masks on disk are reused unmodified. + """ + name = labelmap_path.name + if name.endswith(".nii.gz"): + stem = name[:-7] + elif name.endswith(".mha"): + stem = name[:-4] + else: + stem = labelmap_path.stem + mask_p = labelmap_path.parent / f"{stem}_mask.nii.gz" + if not mask_p.exists(): + mask = labelmap_tools.convert_labelmap_to_mask( + itk.imread(str(labelmap_path)), dilation_in_mm=mask_dilation_mm + ) + itk.imwrite(mask, str(mask_p), compression=True) + return str(mask_p) + + +train_mask_files: list[list[Optional[str]]] = [] +for seg_paths in train_segmentation_files: + train_mask_files.append( + [derive_mask_for(Path(s)) if s is not None else None for s in seg_paths] + ) + +# %% [markdown] +# ## 4b. Merge ANTS / Greedy pre-registered frames into each training group +# +# ``1-preregistration.py`` warps every gated moving frame into reference space +# with the ANTS and Greedy backends, writing ``.mha`` (warped image), +# ``_labelmap.mha`` (warped labelmap), and ``_deformation_grid.mha`` +# under ``preregistration_dir / / ``. Here those warped +# frames + labelmaps (with derived loss masks) are appended to the *same* +# patient's training group, so uniGradICON pairs the original gated frames and +# both backends' pre-registered frames together (they share a ``subject_id``). +# Patients/methods with no pre-registration output on disk are skipped. + + +# %% +def gather_warped_frames(method_dir: Path) -> tuple[list[str], list[Optional[str]]]: + """Return ``(warped_image_paths, warped_labelmap_paths)`` for one + ``preregistration_dir / / `` directory. + + Enumerates the warped moving images (``.mha``), excluding the + ``_labelmap.mha``, ``_labelmap_mask.mha``, and ``_deformation_grid.mha`` + companions, and pairs each with its ``_labelmap.mha`` (``None`` when + that labelmap is absent). Returns empty lists when ``method_dir`` does + not exist. + """ + if not method_dir.is_dir(): + return [], [] + companion_suffixes = ( + "_labelmap.mha", + "_labelmap_mask.mha", + "_deformation_grid.mha", + ) + image_paths: list[str] = [] + labelmap_paths: list[Optional[str]] = [] + for mha in sorted(method_dir.glob("*.mha")): + if mha.name.endswith(companion_suffixes): + continue + stem = mha.name[:-4] + labelmap = method_dir / f"{stem}_labelmap.mha" + image_paths.append(str(mha)) + labelmap_paths.append(str(labelmap) if labelmap.exists() else None) + return image_paths, labelmap_paths + + +for subject_index, patient_id in enumerate(valid_train_subjects): + for method_name in preregistration_methods: + method_dir = preregistration_dir / method_name.lower() / patient_id + warped_images, warped_labelmaps = gather_warped_frames(method_dir) + if not warped_images: + print( + f" {patient_id}/{method_name}: no pre-registered frames " + f"in {method_dir}" + ) + continue + warped_masks: list[Optional[str]] = [] + for lm in warped_labelmaps: + if lm is None: + warped_masks.append(None) + continue + # 1-preregistration.py writes the warped loss mask next to the + # warped labelmap; prefer it, deriving one only if it is absent. + warped_mask = Path(f"{lm[:-4]}_mask.mha") + warped_masks.append( + str(warped_mask) if warped_mask.exists() else derive_mask_for(Path(lm)) + ) + train_image_files[subject_index].extend(warped_images) + train_segmentation_files[subject_index].extend(warped_labelmaps) + train_mask_files[subject_index].extend(warped_masks) + n_seg = sum(1 for lm in warped_labelmaps if lm is not None) + print( + f" {patient_id}/{method_name}: +{len(warped_images)} warped frames, " + f"{n_seg} with labelmap" + ) + +# %% [markdown] +# ## 5. Fine-tune uniGradICON on the train cohort +# +# Each train group now holds the original gated frames plus the merged ANTS +# and Greedy pre-registered frames (section 4b). The workflow consumes both +# the labelmaps (for paired-with-seg training) and the pre-computed masks (for +# ``loss_function_masking``) +# and launches ``unigradicon.finetuning.finetune`` as a subprocess. The +# final checkpoint lands at +# :meth:`WorkflowFineTuneICONRegistration.expected_weights_path`, which is +# the default ``--finetuned-weights-path`` read by ``2-recon_4d_icon_eval.py``. + +# %% +workflow = WorkflowFineTuneICONRegistration( + subject_image_files=train_image_files, + output_dir=output_dir, + fine_tune_name=fine_tune_name, + subject_ids=valid_train_subjects, + subject_segmentation_files=train_segmentation_files, + subject_mask_files=train_mask_files, + mask_dilation_mm=mask_dilation_mm, + unigradicon_src_path=unigradicon_src_path, + epochs=500, +) + +weights_path = workflow.run_fine_tuning() +print(f"\nFine-tuning complete. Expected weights at: {weights_path}") +print(f"Held-out test cohort (for 2-recon_4d_icon_eval.py): {test_subjects}") diff --git a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py similarity index 83% rename from experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py rename to experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py index 83ff66c..bd462b4 100644 --- a/experiments/LongitudinalRegistration/2-recon_4d_icon_eval.py +++ b/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py @@ -25,8 +25,8 @@ import numpy as np from physiomotion4d import RegisterTimeSeriesImages +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.landmark_tools import LandmarkTools -from physiomotion4d.register_images_icon import RegisterImagesICON # %% [markdown] # ## 1. Hard-coded paths and configuration @@ -35,15 +35,21 @@ ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") timepoint_base_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_base_dir = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -output_dir = Path("./results") -finetuned_weights_path = Path( - "./results/icon_finetuned/checkpoints/Finetune_multi_final.trch" + +_HERE = Path(__file__).parent +output_dir = _HERE / "results" +finetuned_weights_path = ( + output_dir + / "icon_finetuned" + / "icon_finetuned_model" + / "checkpoints" + / "network_weights_final.trch" ) train_fraction = 0.8 -icon_iterations = 20 +icon_iterations = None reference_percentile = 0.70 -exclude_tokens = ("nop", "dia", "sys", "_ref") +exclude_tokens = ["nop"] timepoint_re = re.compile(r"_g(?P[0-9]{3})") methods: list[tuple[str, Optional[Path]]] = [ @@ -87,12 +93,13 @@ # Landmarks are read with :meth:`LandmarkTools.read_landmarks_3dslicer` — # they were written as ``_landmark.mrk.json`` (3D Slicer Markups JSON, # LPS) by ``0-cardiacGatedCT_segment_and_landmark.py``. Binary registration -# masks come from :meth:`RegisterImagesICON.create_mask` (``>0`` threshold -# plus 5 mm dilation by default), matching the loss-function masks used +# masks come from :meth:`LabelmapTools.convert_labelmap_to_mask` (``>0`` +# threshold plus 5 mm dilation), matching the loss-function masks used # during fine-tuning in ``1-finetune_icon.py``. # %% landmark_tools = LandmarkTools() +labelmap_tools = LabelmapTools() # %% [markdown] @@ -103,15 +110,20 @@ for subject_id in test_subjects: source_dir = timepoint_base_dir / subject_id + print(f"Source directory: {source_dir}") + seg_dir = segmentation_base_dir / subject_id + print(f"Segmentation directory: {seg_dir}") image_files = [ p for p in sorted(source_dir.glob("*.nii.gz")) if not any(t in p.name for t in exclude_tokens) ] + print(f"Found {len(image_files)} image files") stems = [p.name[:-7] for p in image_files] labelmap_files = [seg_dir / f"{s}_labelmap.nii.gz" for s in stems] + mask_files = [seg_dir / f"{s}_labelmap_mask.nii.gz" for s in stems] landmark_files = [seg_dir / f"{s}_landmark.mrk.json" for s in stems] timepoints = [timepoint_re.search(p.name).group("timepoint") for p in image_files] @@ -122,17 +134,34 @@ ) fixed_image = itk.imread(str(image_files[reference_index]), pixel_type=itk.F) - fixed_mask = RegisterImagesICON.create_mask( - itk.imread(str(labelmap_files[reference_index])) - ) + fixed_labelmap = itk.imread(str(labelmap_files[reference_index])) + if mask_files[reference_index].exists(): + fixed_mask = itk.imread(str(mask_files[reference_index])) + else: + fixed_mask = labelmap_tools.convert_labelmap_to_mask( + fixed_labelmap, dilation_in_mm=5.0 + ) + itk.imwrite(fixed_mask, str(mask_files[reference_index]), compression=True) reference_landmarks = landmark_tools.read_landmarks_3dslicer( landmark_files[reference_index] ) moving_images = [itk.imread(str(p), pixel_type=itk.F) for p in image_files] - moving_masks = [ - RegisterImagesICON.create_mask(itk.imread(str(p))) for p in labelmap_files + moving_labelmaps = [itk.imread(str(p)) for p in labelmap_files] + moving_landmarks = [ + landmark_tools.read_landmarks_3dslicer(str(p)) for p in landmark_files ] + moving_masks = [] + for index, p in enumerate(mask_files): + if not p.exists(): + mask = labelmap_tools.convert_labelmap_to_mask( + moving_labelmaps[index], dilation_in_mm=5.0 + ) + itk.imwrite(mask, str(p), compression=True) + moving_masks.append(mask) + else: + mask = itk.imread(str(p)) + moving_masks.append(mask) for method_name, weights_path in methods: print(f" Method: {method_name}") @@ -147,7 +176,7 @@ result = registrar.register_time_series( moving_images=moving_images, moving_masks=moving_masks, - moving_labelmaps=None, + moving_labelmaps=moving_labelmaps, reference_frame=reference_index, register_reference=False, prior_weight=0.0, @@ -178,9 +207,7 @@ # inverse_transform follows the ITK resampler convention — it maps # moving-grid points back to reference-grid points, which is what # we need to warp time-point landmarks into reference space. - timepoint_landmarks = landmark_tools.read_landmarks_3dslicer( - landmark_files[index] - ) + timepoint_landmarks = moving_landmarks[index] shared = sorted(timepoint_landmarks.keys() & reference_landmarks.keys()) errors: list[tuple[str, float]] = [] for name in shared: diff --git a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py b/experiments/LongitudinalRegistration/4-recon_4d_all_eval.py similarity index 98% rename from experiments/LongitudinalRegistration/3-run_registration_method_comparison.py rename to experiments/LongitudinalRegistration/4-recon_4d_all_eval.py index 5467cad..3a81f10 100644 --- a/experiments/LongitudinalRegistration/3-run_registration_method_comparison.py +++ b/experiments/LongitudinalRegistration/4-recon_4d_all_eval.py @@ -494,9 +494,15 @@ def run_method_for_subject( if artifacts.landmark_file is not None: timepoint_landmarks = read_landmarks(artifacts.landmark_file) + # Warp the reference landmarks into the timepoint (moving) space to + # compare against this timepoint's landmarks. Warping reference -> + # time POINTS uses forward_transform (the fixed -> moving point map), + # which is the opposite of the reference_to_time IMAGE above (images + # pull back, points push forward). See + # docs/developer/transform_conventions. direct_landmarks = transform_landmarks( reference_landmarks, - inverse_transform, + forward_transform, ) direct_errors = landmark_errors(direct_landmarks, timepoint_landmarks) write_error_details( diff --git a/experiments/LongitudinalRegistration/registration_test.py b/experiments/LongitudinalRegistration/registration_test.py new file mode 100644 index 0000000..0c66ef3 --- /dev/null +++ b/experiments/LongitudinalRegistration/registration_test.py @@ -0,0 +1,134 @@ +# %% [markdown] +# # Registration test: pm0003 time point 20 -> time point 60 +# +# Registers pm0003 gated CT time point 20 (moving) to time point 60 +# (fixed) with deformable registration, then warps time point 20 into +# time point 60's space and writes it to disk. +# +# Switch backends by editing the single ``method`` variable below +# ("ANTS", "ICON", or "Greedy"). All paths are hard-coded; run the +# cells top to bottom. + +# %% +import time +from pathlib import Path + +import itk + +from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy +from physiomotion4d.register_images_icon import RegisterImagesICON +from physiomotion4d.transform_tools import TransformTools + +# %% [markdown] +# ## 1. Configuration and hard-coded paths +# +# Change ``method`` to switch backends. Time point 20 is the moving +# image; time point 60 is the fixed image. + +# %% +method = "Greedy" # one of: "ANTS", "ICON", "Greedy" + +data_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii/pm0003") +moving_path = data_dir / "pm0003_dupr_135-0094_135_4700_g020_s2.000_n0058_11.nii.gz" +fixed_path = data_dir / "pm0003_dupr_135-0094_135_4700_g060_s2.000_n0058_15.nii.gz" + +output_dir = Path(__file__).parent / "results" / "registration_test" +output_dir.mkdir(parents=True, exist_ok=True) +output_path = output_dir / f"pm0003_g020_to_g060_{method.lower()}.mha" + +# %% [markdown] +# ## 2. Load the fixed (time point 60) and moving (time point 20) images + +# %% +fixed_image = itk.imread(str(fixed_path), pixel_type=itk.F) +moving_image = itk.imread(str(moving_path), pixel_type=itk.F) +print(f"Fixed (g060): {fixed_path.name}") +print(f"Moving (g020): {moving_path.name}") + +# %% [markdown] +# ## 3. Build and configure the registration backend +# +# ANTS and Greedy share ``set_transform_type``/``set_metric`` and take a +# per-level iteration list; ICON takes a single iteration count and has +# no transform-type/metric setters. + +# %% +if method == "ANTS": + reg = RegisterImagesANTS() + reg.set_transform_type("Deformable") + reg.set_metric("MeanSquares") + reg.set_number_of_iterations([40, 20, 10]) +elif method == "Greedy": + reg = RegisterImagesGreedy() + reg.set_transform_type("Deformable") + # NCC (CC) beats SSD for same-modality CT; tighter update-field smoothing + # (first sigma) captures more cardiac motion while staying diffeomorphic. + reg.set_metric("CC") + reg.set_number_of_iterations([40, 20, 10]) + reg.deformable_smoothing = "1.0vox 0.5vox" +elif method == "ICON": + reg = RegisterImagesICON() + reg.set_number_of_iterations(50) +else: + raise ValueError(f"Unknown method: {method}") + +reg.set_modality("ct") +reg.set_fixed_image(fixed_image) + +# %% [markdown] +# ## 4. Register time point 20 to time point 60 + +# %% +t_start = time.perf_counter() +reg_result = reg.register(moving_image=moving_image) +elapsed = time.perf_counter() - t_start + +forward_transform = reg_result["forward_transform"] +loss = float(reg_result["loss"]) +print(f"{method} registration done in {elapsed:.1f} s, loss={loss:.4f}") + +# %% [markdown] +# ## 5. Warp time point 20 into time point 60's space and save +# +# ``forward_transform`` is the transform consumed by ``transform_image`` to +# resample the moving image onto the fixed grid (it supplies the fixed->moving +# sampling map the ITK resampler needs). ``inverse_transform`` is the opposite +# direction, used to warp the fixed image onto the moving grid (e.g. in +# ``RegisterTimeSeriesImages.reconstruct_time_series``). This holds for all +# three backends (ANTS, ICON, Greedy). + +# %% +transform_tools = TransformTools() +warp_t_start = time.perf_counter() +warped_image = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", +) +itk.imwrite(warped_image, str(output_path), compression=True) +warp_elapsed = time.perf_counter() - warp_t_start +print(f"Wrote warped time point 20 -> 60: {output_path}") + +# %% [markdown] +# ## 6. Timing report +# +# Wall-clock seconds for the registration and the warp/write step. The +# registration time dominates and is the figure to compare across backends; +# for ICON it includes the one-time network load on this first (and only) +# pair. + +# %% +total_elapsed = elapsed + warp_elapsed +print() +print("=" * 44) +print(f"Timing report ({method})") +print("=" * 44) +print(f"{'Step':<22}{'seconds':>12}") +print("-" * 44) +print(f"{'register':<22}{elapsed:>12.2f}") +print(f"{'warp + write':<22}{warp_elapsed:>12.2f}") +print("-" * 44) +print(f"{'total':<22}{total_elapsed:>12.2f}") +print("=" * 44) diff --git a/pyproject.toml b/pyproject.toml index cefcec4..70d0aa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -280,7 +280,6 @@ minversion = "7.0" addopts = [ "--strict-markers", "--strict-config", - "-W", "always", "--cov=physiomotion4d", "--cov-report=term-missing", "--cov-report=html", @@ -288,6 +287,17 @@ addopts = [ ] testpaths = ["tests"] pythonpath = ["."] +# Surface every warning by default ("always"), then silence the third-party +# ITK/SWIG binding noise emitted while wrapped C++ types are defined at import +# on CPython >=3.12 ("builtin type Swig... has no __module__ attribute"). +# Those come from the bindings, not our code, and would otherwise drown the +# warnings summary. Order matters: the trailing "ignore" entries are applied +# last and so take precedence over "always" for these specific messages. +filterwarnings = [ + "always", + "ignore:builtin type Swig", + "ignore:builtin type swigvarlink", +] markers = [ "unit: marks tests as unit tests (fast, isolated)", "integration: marks tests as integration tests (slower, multiple components)", @@ -398,12 +408,12 @@ lines-after-imports = 2 max-complexity = 10 [tool.ruff.lint.flake8-quotes] -inline-quotes = "single" +inline-quotes = "double" multiline-quotes = "double" docstring-quotes = "double" [tool.ruff.format] -quote-style = "single" +quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" diff --git a/src/physiomotion4d/__init__.py b/src/physiomotion4d/__init__.py index 5d00e32..63fc3a7 100644 --- a/src/physiomotion4d/__init__.py +++ b/src/physiomotion4d/__init__.py @@ -43,6 +43,7 @@ # Utility classes from .image_tools import ImageTools +from .labelmap_tools import LabelmapTools from .landmark_tools import LandmarkTools # Base classes @@ -106,6 +107,7 @@ "PhysioMotion4DBase", # Utility classes "ImageTools", + "LabelmapTools", "LandmarkTools", "TestTools", "TransformTools", diff --git a/src/physiomotion4d/contour_tools.py b/src/physiomotion4d/contour_tools.py index 9f3028e..6e72efe 100644 --- a/src/physiomotion4d/contour_tools.py +++ b/src/physiomotion4d/contour_tools.py @@ -230,7 +230,7 @@ def create_mask_from_mesh( # Direction: use identity for now (axis-aligned), will be handled by resampling # Flip Z axis to match ITK convention - ref_dir = np.array(binary_image.GetDirection()) + ref_dir = itk.array_from_matrix(binary_image.GetDirection()) ref_dir[2, 2] = -ref_dir[2, 2] binary_image.SetDirection(ref_dir) diff --git a/src/physiomotion4d/image_tools.py b/src/physiomotion4d/image_tools.py index c6e4023..c031603 100644 --- a/src/physiomotion4d/image_tools.py +++ b/src/physiomotion4d/image_tools.py @@ -279,9 +279,12 @@ def flip_image( flip1 = False flip2 = False if flip_and_make_identity: - flip0 = np.array(in_image.GetDirection())[0, 0] < 0 - flip1 = np.array(in_image.GetDirection())[1, 1] < 0 - flip2 = np.array(in_image.GetDirection())[2, 2] < 0 + # itk.array_from_matrix avoids itk.Matrix.__array__, whose missing + # copy keyword triggers a numpy>=2.0 DeprecationWarning. + direction = itk.array_from_matrix(in_image.GetDirection()) + flip0 = direction[0, 0] < 0 + flip1 = direction[1, 1] < 0 + flip2 = direction[2, 2] < 0 if flip_x: flip0 = True if flip_y: diff --git a/src/physiomotion4d/labelmap_tools.py b/src/physiomotion4d/labelmap_tools.py new file mode 100644 index 0000000..ff5ad78 --- /dev/null +++ b/src/physiomotion4d/labelmap_tools.py @@ -0,0 +1,100 @@ +""" +Labelmap Tools for PhysioMotion4D + +This module provides the :class:`LabelmapTools` class with the definitive +utility for turning a multi-label (or binary) segmentation labelmap into a +binary registration mask, optionally excluding specific labels and dilating +the result by a physical radius in millimeters. +""" + +import logging +from typing import Optional + +import itk +import numpy as np + +from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase + + +class LabelmapTools(PhysioMotion4DBase): + """ + Utilities for converting segmentation labelmaps into registration masks. + + A labelmap is an ``itk.Image`` of integer labels where ``0`` is background + and each positive value identifies an anatomical structure. A registration + mask is a binary ``itk.Image`` where every foreground voxel is ``1``. This + class centralizes the labelmap-to-mask conversion so that thresholding, + label exclusion, and physically isotropic dilation are performed + identically everywhere in the platform. + + Example: + >>> tools = LabelmapTools() + >>> # Binary mask of every labeled voxel, dilated 5 mm + >>> mask = tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) + >>> # Exclude the table/background labels 8 and 9 before masking + >>> mask = tools.convert_labelmap_to_mask( + ... labelmap, dilation_in_mm=5.0, exclude_labels=[8, 9] + ... ) + """ + + def __init__(self, log_level: int | str = logging.INFO) -> None: + """Initialize LabelmapTools. + + Args: + log_level: Logging level (default: logging.INFO) + """ + super().__init__(class_name=self.__class__.__name__, log_level=log_level) + + def convert_labelmap_to_mask( + self, + labelmap: itk.Image, + dilation_in_mm: float = 0.0, + exclude_labels: Optional[list[int]] = None, + ) -> itk.Image: + """Convert a labelmap into a binary registration mask. + + Any voxel whose label is in ``exclude_labels`` is set to background + first; every remaining non-zero voxel becomes foreground (``1``). The + binary mask is then dilated by ``dilation_in_mm`` millimeters of + physical radius. The radius is converted into per-axis voxel counts + from the labelmap's spacing so the dilation is physically isotropic + even on anisotropic grids; each per-axis count is clamped to at least + 1 voxel when ``dilation_in_mm > 0``. + + Axis ordering: the labelmap is a scalar 3D ``itk.Image`` in ITK + world-axis order (X, Y, Z). All thresholding is performed on the numpy + view (Z, Y, X) and written back through ``CopyInformation``, so origin, + spacing, and direction are preserved. + + Args: + labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel + that is not excluded is treated as foreground. + dilation_in_mm: Physical radius of the binary dilation in + millimeters. Pass ``0`` (or negative) to skip dilation and + return the raw thresholded mask. Default 0.0. + exclude_labels: Optional list of integer label values to force + to background before thresholding. When ``None`` (the default) + no labels are excluded. + + Returns: + ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as + ``labelmap`` (origin, spacing, direction copied from the input). + """ + arr = itk.array_from_image(labelmap) + if exclude_labels: + arr = np.where(np.isin(arr, exclude_labels), 0, arr) + mask_arr = (arr > 0).astype(np.uint8) + mask = itk.image_from_array(mask_arr) + mask.CopyInformation(labelmap) + + if dilation_in_mm <= 0: + return mask + + spacing = labelmap.GetSpacing() + radius = itk.Size[3]() + for i in range(3): + radius[i] = max(1, int(round(dilation_in_mm / float(spacing[i])))) + structuring_element = itk.FlatStructuringElement[3].Ball(radius) + return itk.binary_dilate_image_filter( + mask, kernel=structuring_element, foreground_value=1 + ) diff --git a/src/physiomotion4d/register_images_ants.py b/src/physiomotion4d/register_images_ants.py index 182afaf..7412763 100644 --- a/src/physiomotion4d/register_images_ants.py +++ b/src/physiomotion4d/register_images_ants.py @@ -17,6 +17,7 @@ import itk import numpy as np from numpy.typing import NDArray + from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.transform_tools import TransformTools @@ -199,9 +200,7 @@ def _itk_to_ants_image( image_dimension = len(spatial_shape) - direction = np.asarray(itk_image.GetDirection()).reshape( - (image_dimension, image_dimension) - ) + direction = itk.array_from_matrix(itk_image.GetDirection()) spacing = list(itk_image.GetSpacing()) origin = list(itk_image.GetOrigin()) @@ -526,16 +525,23 @@ def registration_method( region of interest in the moving image moving_image_pre (ants.core.ANTsImage, optional): Pre-processed moving image in ANTs format. If None, preprocessing is performed automatically - initial_forward_transform (itk.Transform, optional): Initial transform from moving - to fixed space. Can be any ITK transform type (Affine, Rigid, - DisplacementField, Composite, etc.). Will be converted to ANTs - format automatically. The returned transforms will include this - initial transform composed with the registration result. + initial_forward_transform (itk.Transform, optional): Initial + forward transform (same convention as the returned + forward_transform: used to warp the moving image onto the fixed + grid). Can be any ITK transform type (Affine, Rigid, + DisplacementField, Composite, etc.). It is applied by pre-warping + the moving image onto the fixed grid before registration; the + returned transforms compose this initial alignment with the + registration refinement. Returns: dict: Dictionary containing: - - "forward_transform": Transformation from moving to fixed - - "inverse_transform": Transformation from fixed to moving + - "forward_transform": Warps the moving image onto the fixed + grid (warping moving points/landmarks into fixed space uses + "inverse_transform" instead -- image and point warps use + opposite transforms; see + docs/developer/transform_conventions) + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Loss value from the registration Note: @@ -543,11 +549,13 @@ def registration_method( consistent. The forward and inverse transforms are stored separately by ANTs. - IMPORTANT: ANTs registration does NOT include the initial_transform - in its output fwdtransforms/invtransforms. This method automatically - composes the initial transform with the registration result, so the - returned transforms include both the initial alignment and - the registration refinement. + IMPORTANT: the initial transform is applied by pre-warping the + moving image onto the fixed grid (the same approach as + RegisterImagesICON) rather than via ants.registration's + initial_transform argument, which mishandles matrix (affine/ + translation) initials. This method composes the initial transform + with the registration result, so the returned transforms include + both the initial alignment and the registration refinement. Implementation details: - Uses ANTs registration with configurable transform types @@ -584,24 +592,29 @@ def registration_method( if self.fixed_image_pre is None: self.fixed_image_pre = self.preprocess(self.fixed_image, self.modality) - # Convert initial ITK transform to ANTs format if provided - initial_transform: str | list[str] = "identity" + # Apply any initial transform by pre-warping the moving image onto the + # fixed grid (the same approach RegisterImagesICON uses), instead of + # passing it to ants.registration as an initial_transform. ANTS + # mishandles matrix (affine/translation) initial transforms, badly + # corrupting the result; pre-warping keeps the composition below + # self-consistent for any initial transform type. The registration then + # solves the residual and the composition recovers the full transform. if initial_forward_transform is not None: - self.log_info("Converting initial ITK transform to ANTs format...") - initial_transform = self.itk_transform_to_ANTSfile( - itk_tfm=initial_forward_transform, - reference_image=self.fixed_image, - output_filename="initial_transform_temp.mat", + self.log_info("Pre-warping moving image with initial transform...") + transform_tools = TransformTools() + self.moving_image_pre = transform_tools.transform_image( + self.moving_image_pre, + initial_forward_transform, + self.fixed_image, ) - self.log_info("Initial transform converted successfully") transform_type = None if self.transform_type == "Deformable": transform_type = "antsRegistrationSyNQuick[so]" elif self.transform_type == "Affine": - transform_type = "antsRegistrationAffineQuick[so]" + transform_type = "Affine" elif self.transform_type == "Rigid": - transform_type = "antsRegistrationRigidQuick[so]" + transform_type = "Rigid" else: self.log_error("Invalid transform type: %s", self.transform_type) raise ValueError(f"Invalid transform type: {self.transform_type}") @@ -627,13 +640,36 @@ def registration_method( elif self.metric == "MeanSquares": syn_metric = "meansquares" + # antsRegistration --dimensionality 3 --float 0 \ + # --output [$thisfolder/pennTemplate_to_${sub}_,$thisfolder/pennTemplate_to_${sub}_Warped.nii.gz] \ + # --interpolation Linear \ + # --winsorize-image-intensities [0.005,0.995] \ + # --use-histogram-matching 0 \ + # --initial-moving-transform [$t1brain,$template,1] \ + # --transform Rigid[0.1] \ + # --metric MI[$t1brain,$template,1,32,Regular,0.25] \ + # --convergence [1000x500x250x100,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # --transform Affine[0.1] \ + # --metric MI[$t1brain,$template,1,32,Regular,0.25] \ + # --convergence [1000x500x250x100,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # --transform SyN[0.1,3,0] \ + # --metric CC[$t1brain,$template,1,4] \ + # --convergence [100x70x50x20,1e-6,10] \ + # --shrink-factors 8x4x2x1 \ + # --smoothing-sigmas 3x2x1x0vox \ + # -x $brainlesionmask + if self.fixed_mask is not None and self.moving_mask is not None: registration_result = ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), mask=self._itk_to_ants_image(self.fixed_mask), moving=self._itk_to_ants_image(self.moving_image_pre), moving_mask=self._itk_to_ants_image(self.moving_mask), - initial_transform=[initial_transform], + initial_transform=["identity"], type_of_transform=transform_type, aff_metric=aff_metric, syn_metric=syn_metric, @@ -646,7 +682,7 @@ def registration_method( registration_result = ants.registration( fixed=self._itk_to_ants_image(self.fixed_image_pre), moving=self._itk_to_ants_image(self.moving_image_pre), - initial_transform=[initial_transform], + initial_transform=["identity"], type_of_transform=transform_type, aff_metric=aff_metric, syn_metric=syn_metric, diff --git a/src/physiomotion4d/register_images_base.py b/src/physiomotion4d/register_images_base.py index 463e26f..1c397fb 100644 --- a/src/physiomotion4d/register_images_base.py +++ b/src/physiomotion4d/register_images_base.py @@ -19,9 +19,8 @@ from typing import Any, Optional, Union import itk -import numpy as np -from itk import TubeTK as ttk +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.transform_tools import TransformTools @@ -59,8 +58,8 @@ class and implement the register() method. ... def registration_method(self, moving_image, **kwargs): ... # Implement specific registration algorithm ... return { - ... 'forward_transform': tfm_forward, # Moving → Fixed - ... 'inverse_transform': tfm_inverse, # Fixed → Moving + ... 'forward_transform': tfm_forward, # warps moving image -> fixed grid + ... 'inverse_transform': tfm_inverse, # warps fixed image -> moving grid ... 'loss': 0.0, ... } >>> @@ -68,8 +67,8 @@ class and implement the register() method. >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(reference_image) >>> result = registrar.register(moving_image) - >>> forward_tfm = result['forward_transform'] # Moving → Fixed - >>> inverse_tfm = result['inverse_transform'] # Fixed → Moving + >>> forward_tfm = result['forward_transform'] # warps moving image -> fixed grid + >>> inverse_tfm = result['inverse_transform'] # warps fixed image -> moving grid """ def __init__(self, log_level: int | str = logging.INFO) -> None: @@ -84,6 +83,8 @@ def __init__(self, log_level: int | str = logging.INFO) -> None: """ super().__init__(class_name=self.__class__.__name__, log_level=log_level) + self.labelmap_tools = LabelmapTools(log_level=log_level) + self.net: Any = None self.modality: str = "ct" @@ -180,16 +181,10 @@ def set_fixed_mask(self, fixed_mask: Optional[itk.Image]) -> None: if self.fixed_image is None: raise ValueError("Fixed image must be set before setting a fixed mask.") - mask_arr = itk.GetArrayFromImage(fixed_mask) - mask_arr = np.where(mask_arr > 0, 1, 0) - self.fixed_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + self.fixed_mask = self.labelmap_tools.convert_labelmap_to_mask( + fixed_mask, dilation_in_mm=self.mask_dilation_mm + ) self.fixed_mask.CopyInformation(self.fixed_image) - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(self.fixed_mask) - imMath.Dilate( - int(self.mask_dilation_mm / self.fixed_image.GetSpacing()[0]), 1, 0 - ) - self.fixed_mask = imMath.GetOutputUChar() def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: """Set the fixed image labelmap (multi-label segmentation). @@ -251,8 +246,11 @@ def registration_method( Returns: dict: Dictionary containing: - - "forward_transform": Transform that warps moving image into fixed space - - "inverse_transform": Transform that warps fixed image into moving space + - "forward_transform": Warps the moving image onto the fixed + grid. Warping moving points/landmarks into fixed space uses + "inverse_transform" instead (see register() and + docs/developer/transform_conventions). + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Registration loss/metric value Raises: @@ -283,13 +281,24 @@ def register( Returns: dict: Dictionary containing transformation results: - - "forward_transform": Transforms moving image to fixed space (warps moving → fixed) - - "inverse_transform": Transforms fixed image to moving space (warps fixed → moving) + - "forward_transform": Warps the moving IMAGE onto the fixed + grid, i.e. transform_image(moving, forward_transform, fixed). + - "inverse_transform": Warps the fixed IMAGE onto the moving + grid, i.e. transform_image(fixed, inverse_transform, moving). - "loss": Registration loss/metric value Note: - - forward_transform: Use this to warp the moving image to match the fixed image - - inverse_transform: Use this to warp the fixed image to match the moving image + Image warps and point/landmark warps use OPPOSITE members of the + transform pair, because ITK image resampling pulls back (it maps a + fixed-grid sample to the moving image) while point transforms push + forward (they map a point to its corresponding location): + + - Warp the moving image into fixed space -> forward_transform + - Warp moving points/landmarks into fixed -> inverse_transform + - Warp the fixed image into moving space -> inverse_transform + - Warp fixed points/landmarks into moving -> forward_transform + + See docs/developer/transform_conventions for the full discussion. Raises: NotImplementedError: This method must be implemented by subclasses @@ -313,16 +322,10 @@ def register( new_moving_mask = moving_mask if moving_mask is not None: - mask_arr = itk.GetArrayFromImage(moving_mask) - mask_arr = np.where(mask_arr > 0, 1, 0) - new_moving_mask = itk.GetImageFromArray(mask_arr.astype(np.uint8)) + new_moving_mask = self.labelmap_tools.convert_labelmap_to_mask( + moving_mask, dilation_in_mm=self.mask_dilation_mm + ) new_moving_mask.CopyInformation(moving_image) - if self.mask_dilation_mm > 0: - imMath = ttk.ImageMath.New(new_moving_mask) - imMath.Dilate( - int(self.mask_dilation_mm / moving_image.GetSpacing()[0]), 1, 0 - ) - new_moving_mask = imMath.GetOutputUChar() self.moving_image = moving_image self.moving_image_pre = moving_image_pre diff --git a/src/physiomotion4d/register_images_greedy.py b/src/physiomotion4d/register_images_greedy.py index d3caed8..8a7a6cb 100644 --- a/src/physiomotion4d/register_images_greedy.py +++ b/src/physiomotion4d/register_images_greedy.py @@ -12,6 +12,8 @@ from __future__ import annotations import logging +import os +import tempfile from typing import Any, Optional, Union import itk @@ -135,6 +137,30 @@ def _greedy_iterations_str(self) -> str: """Format iterations as Greedy -n string (e.g. 40x20x10).""" return "x".join(str(i) for i in self.number_of_iterations) + def _write_affine_matrix_file(self, mat_4x4: NDArray[np.float64]) -> str: + """Write a 4x4 RAS affine matrix to a temporary Greedy ``.mat`` file. + + Greedy's in-memory interface corrupts the heap when a numpy affine + matrix is supplied as an initial transform (``-ia``/``-it``); passing a + file path instead avoids that native crash. Greedy reads a plain-text + 4x4 RAS matrix, which is what ``numpy.savetxt`` writes here. + + Args: + mat_4x4: 4x4 affine matrix in RAS (Greedy) convention. + + Returns: + Path to the written temporary ``.mat`` file. The caller is + responsible for deleting it. + """ + mat_4x4 = np.asarray(mat_4x4, dtype=np.float64) + if mat_4x4.shape != (4, 4): + raise ValueError(f"Expected 4x4 matrix, got shape {mat_4x4.shape}") + fd, path = tempfile.mkstemp(suffix=".mat", prefix="greedy_aff_") + os.close(fd) + np.savetxt(path, mat_4x4, fmt="%.10f") + self.log_debug("Wrote Greedy affine init matrix to %s", path) + return path + def _matrix_to_itk_affine(self, mat_4x4: NDArray[np.float64]) -> itk.Transform: """Convert 4x4 affine matrix to ITK AffineTransform.""" mat_4x4 = np.asarray(mat_4x4, dtype=np.float64) @@ -195,17 +221,26 @@ def _registration_method_affine_or_rigid( cmd += " -gm fixed_mask -mm moving_mask" kwargs["fixed_mask"] = fixed_mask_sitk kwargs["moving_mask"] = moving_mask_sitk + # Greedy crashes (heap corruption) when an initial affine is passed as an + # in-memory matrix; write it to a temp file and pass the path instead. + initial_affine_file: Optional[str] = None if initial_affine is not None: - cmd += " -ia aff_initial" - kwargs["aff_initial"] = initial_affine + initial_affine_file = self._write_affine_matrix_file(initial_affine) + cmd += f" -ia {initial_affine_file}" - g.execute(cmd, **kwargs) + self.log_debug("Greedy affine/rigid command: %s", cmd) + try: + g.execute(cmd, **kwargs) + finally: + if initial_affine_file is not None: + os.remove(initial_affine_file) mat = np.array(g["aff_out"], dtype=np.float64) try: ml = g.metric_log() loss = float(ml[-1]["TotalPerPixelMetric"][-1]) if ml else 0.0 except Exception: loss = 0.0 + self.log_info("Greedy affine/rigid registration loss: %s", loss) return mat, loss def _registration_method_deformable( @@ -230,17 +265,21 @@ def _registration_method_deformable( cmd_aff += " -gm fixed_mask -mm moving_mask" kwargs_aff["fixed_mask"] = fixed_mask_sitk kwargs_aff["moving_mask"] = moving_mask_sitk + self.log_debug("Greedy deformable affine-init command: %s", cmd_aff) g.execute(cmd_aff, **kwargs_aff) initial_affine = np.array(g["aff_init"], dtype=np.float64) + self.log_info("Greedy deformable affine init complete") + # Greedy crashes (heap corruption) when the affine init is passed as an + # in-memory matrix via -it; write it to a temp file and pass the path. + initial_affine_file = self._write_affine_matrix_file(initial_affine) cmd_def = ( - f"-i fixed moving -it aff_init -n {iterations_str} " + f"-i fixed moving -it {initial_affine_file} -n {iterations_str} " f"-m {metric_str} -s {self.deformable_smoothing} -o warp_out" ) kwargs_def = { "fixed": fixed_sitk, "moving": moving_sitk, - "aff_init": initial_affine, "warp_out": None, } if fixed_mask_sitk is not None and moving_mask_sitk is not None: @@ -248,13 +287,18 @@ def _registration_method_deformable( kwargs_def["fixed_mask"] = fixed_mask_sitk kwargs_def["moving_mask"] = moving_mask_sitk - g.execute(cmd_def, **kwargs_def) + self.log_debug("Greedy deformable command: %s", cmd_def) + try: + g.execute(cmd_def, **kwargs_def) + finally: + os.remove(initial_affine_file) warp_out = g["warp_out"] try: ml = g.metric_log() loss = float(ml[-1]["TotalPerPixelMetric"][-1]) if ml else 0.0 except Exception: loss = 0.0 + self.log_info("Greedy deformable registration loss: %s", loss) return initial_affine, warp_out, loss def registration_method( @@ -270,6 +314,13 @@ def registration_method( Converts ITK images to SimpleITK, runs Greedy (affine and/or deformable), then converts outputs back to ITK transforms. Composes with initial_forward_transform when provided. + + Returns a dict with "forward_transform", "inverse_transform", and + "loss". As with the other image-registration backends, + forward_transform warps the moving image onto the fixed grid and + inverse_transform warps the fixed image onto the moving grid; point and + landmark warps use the opposite transform from image warps (see + docs/developer/transform_conventions). """ if self.fixed_image is None or self.fixed_image_pre is None: raise ValueError("Fixed image must be set before registration.") @@ -371,13 +422,17 @@ def registration_method( ) disp_tfm = itk.DisplacementFieldTransform[itk.D, 3].New() disp_tfm.SetDisplacementField(disp_itk) - # Forward = moving -> fixed: first affine then deformable in Greedy + # forward_transform is consumed by transform_image(moving, ..., + # fixed) to warp the moving image onto the fixed grid, so it holds + # Greedy's raw affine+warp (Greedy applies the affine first, then + # the warp). inverse_transform is the numerically inverted field, + # used to warp the fixed image onto the moving grid. This matches + # RegisterImagesANTS/ICON and RegisterTimeSeriesImages. forward_composite = itk.CompositeTransform[itk.D, 3].New() if aff_tfm is not None: forward_composite.AddTransform(aff_tfm) forward_composite.AddTransform(disp_tfm) forward_transform = forward_composite - # Inverse: inverse warp then inverse affine inv_disp = TransformTools().invert_displacement_field_transform(disp_tfm) inv_aff = itk.AffineTransform[itk.D, 3].New() if aff_tfm is not None: diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 28bd03d..3d28dad 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -10,6 +10,8 @@ """ import logging +import pathlib +import sys from typing import Optional, Union import icon_registration as icon @@ -88,6 +90,10 @@ def set_weights_path(self, weights_path: str) -> None: pretrained weights. Clears any previously loaded network so the new weights are applied on the next call to register(). + Also, use this to specify the path to store the downloaded weights. The + file must not exist for the weights to be downloaded correctly. Typical + suffix is ".trch". + Args: weights_path: Path to a uniGradICON checkpoint, e.g. "results/duke_4d_finetune/checkpoints/network_weights_100" @@ -185,16 +191,21 @@ def registration_method( Returns: dict: Dictionary containing: - - "forward_transform": transform moving image into fixed space - - "inverse_transform": transform fixed image to moving space + - "forward_transform": Warps the moving image onto the fixed + grid (warping moving points/landmarks into fixed space uses + "inverse_transform" instead -- image and point warps use + opposite transforms; see + docs/developer/transform_conventions) + - "inverse_transform": Warps the fixed image onto the moving grid - "loss": Loss value from the registration Note: The transformations are inverse consistent, meaning - forward_transform ≈ inverse(inverse_transform). - The inverse_transform is used to warp the fixed image - to the moving image space. The forward_transform is used - to warp the moving image to the fixed image space. + forward_transform is approximately inverse(inverse_transform). + Use forward_transform to warp the moving image onto the fixed grid, + and inverse_transform to warp the fixed image onto the moving grid. + Point/landmark warps use the opposite transform from image warps + (see docs/developer/transform_conventions). Implementation details: - Uses UniGradIcon with LNCC loss function @@ -292,13 +303,29 @@ def _ensure_net(self) -> None: """ if self.net is not None: return + main_module = sys.modules.get("__main__") + main_file = getattr(main_module, "__file__", None) + top_dir = pathlib.Path.cwd() + if main_file is not None: + top_dir = pathlib.Path(main_file).resolve().parent if self.use_multi_modality: + if self.weights_path is None: + self.weights_path = str( + top_dir + / "network_weights" + / "multigradicon1.0" + / "Step_2_final.trch" + ) self.net = get_multigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, weights_location=self.weights_path, ) else: + if self.weights_path is None: + self.weights_path = str( + top_dir / "network_weights" / "unigradicon1.0" / "Step_2_final.trch" + ) self.net = get_unigradicon( loss_fn=icon.LNCC(sigma=5), apply_intensity_conservation_loss=self.use_mass_preservation, @@ -346,42 +373,6 @@ def _image_to_resized_tensor( tensor, size=shape[2:], mode="trilinear", align_corners=False ) - @staticmethod - def create_mask(labelmap: itk.Image, dilation_mm: float = 5.0) -> itk.Image: - """Create a binary registration mask from a labelmap. - - Thresholds the labelmap at ``>0`` (so every non-zero label becomes - foreground) and dilates the result by ``dilation_mm`` millimeters of - physical radius. The radius is converted into per-axis voxel counts - from the labelmap's spacing so the dilation is physically isotropic - even on anisotropic grids; each per-axis count is clamped to at least - 1 voxel when ``dilation_mm > 0``. - - Args: - labelmap: Multi-label or binary ``itk.Image``. Any non-zero voxel - is treated as foreground. - dilation_mm: Physical radius of the binary dilation in - millimeters. Pass ``0`` (or negative) to skip dilation and - return the raw ``>0`` mask. Default 5.0 mm. - - Returns: - ``itk.Image[itk.UC, 3]`` binary mask in the same physical space as - ``labelmap`` (origin, spacing, direction copied from the input). - """ - arr = (itk.array_from_image(labelmap) > 0).astype(np.uint8) - mask = itk.image_from_array(arr) - mask.CopyInformation(labelmap) - if dilation_mm <= 0: - return mask - spacing = labelmap.GetSpacing() - radius = itk.Size[3]() - for i in range(3): - radius[i] = max(1, int(round(dilation_mm / float(spacing[i])))) - structuring_element = itk.FlatStructuringElement[3].Ball(radius) - return itk.binary_dilate_image_filter( - mask, kernel=structuring_element, foreground_value=1 - ) - def _mask_to_resized_tensor( self, mask: itk.Image, shape: torch.Size ) -> torch.Tensor: diff --git a/src/physiomotion4d/register_models_distance_maps.py b/src/physiomotion4d/register_models_distance_maps.py index 4b8090e..c03751f 100644 --- a/src/physiomotion4d/register_models_distance_maps.py +++ b/src/physiomotion4d/register_models_distance_maps.py @@ -41,7 +41,7 @@ >>> >>> # Access results >>> aligned_model = result['registered_model'] - >>> forward_transform = result['forward_transform'] # Moving to fixed transform + >>> forward_transform = result['forward_transform'] # warps moving image -> fixed grid """ import logging @@ -49,9 +49,9 @@ import itk import pyvista as pv -from itk import TubeTK as ttk from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON @@ -74,8 +74,15 @@ class RegisterModelsDistanceMaps(PhysioMotion4DBase): - **Optional**: ICON deep learning refinement after any mode **Transform Convention:** - - forward_transform: Moving → fixed space transformation - - inverse_transform: Fixed → moving space transformation + These are the underlying image-registration (ANTs/ICON) transforms, so + they follow the image convention (see + docs/developer/transform_conventions): + + - forward_transform: warps the moving image/mask onto the fixed grid. + Warping the moving MODEL points/landmarks onto the fixed model uses + inverse_transform instead (image and point warps use opposite + transforms). + - inverse_transform: warps the fixed image/mask onto the moving grid. Attributes: moving_model (pv.PolyData): Surface model to be aligned @@ -148,6 +155,7 @@ def __init__( # Utilities self.transform_tools = TransformTools() self.contour_tools = ContourTools() + self.labelmap_tools = LabelmapTools(log_level=log_level) # Registration instances self.registrar_ANTS = RegisterImagesANTS(log_level=log_level) @@ -194,12 +202,9 @@ def _create_masks_from_models(self) -> None: mask = self.contour_tools.create_mask_from_mesh( self.fixed_model, self.reference_image ) - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int( - self.roi_dilation_mm / self.reference_image.GetSpacing()[0] + self.fixed_mask_roi_image = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=self.roi_dilation_mm ) - imMath.Dilate(dilation_voxels, 1, 0) - self.fixed_mask_roi_image = imMath.GetOutput() # Create moving mask self.moving_mask_image = self.contour_tools.create_distance_map( @@ -216,9 +221,9 @@ def _create_masks_from_models(self) -> None: mask = self.contour_tools.create_mask_from_mesh( self.moving_model, self.reference_image ) - imMath = ttk.ImageMath.New(self.moving_mask_image) - imMath.Dilate(dilation_voxels, 1, 0) - self.moving_mask_roi_image = imMath.GetOutputUChar() + self.moving_mask_roi_image = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=self.roi_dilation_mm + ) self.log_info("Mask generation complete") diff --git a/src/physiomotion4d/register_models_icp.py b/src/physiomotion4d/register_models_icp.py index b3b2b61..3236f02 100644 --- a/src/physiomotion4d/register_models_icp.py +++ b/src/physiomotion4d/register_models_icp.py @@ -61,10 +61,16 @@ class RegisterModelsICP(PhysioMotion4DBase): - **Affine transform type**: Centroid alignment → Rigid ICP → Affine ICP **Transform Convention:** - - forward_point_transform: moving → fixed space transformation - (This is the inverse of the transform used to wrap the moving image to the - fixed image) - - inverse_point_transform: moving → fixed space transformation + These are POINT transforms (applied with TransformPoint, e.g. via + TransformTools.transform_pvcontour), so their orientation is opposite to + the image-registration transforms (see + docs/developer/transform_conventions): + + - forward_point_transform: maps moving points -> fixed points; use it to + warp the moving model/landmarks onto the fixed model. This is the + inverse of the transform that would warp the moving IMAGE onto the + fixed grid. + - inverse_point_transform: maps fixed points -> moving points. Attributes: moving_model (pv.PolyData): Surface model to be aligned diff --git a/src/physiomotion4d/register_models_pca.py b/src/physiomotion4d/register_models_pca.py index d230e0a..339b34a 100644 --- a/src/physiomotion4d/register_models_pca.py +++ b/src/physiomotion4d/register_models_pca.py @@ -42,10 +42,15 @@ class RegisterModelsPCA(PhysioMotion4DBase): pca_coefficients (np.ndarray): Optimized PCA coefficients registered_model (pv.DataSet): Final registered and deformed model post_pca_transform (itk.Transform): Transform to apply after PCA registration - forward_point_transform (itk.DisplacementFieldTransform): Forward displacement field transform - (Does not include the post-PCA transform) - inverse_point_transform (itk.DisplacementFieldTransform): Inverse displacement field transform - (Does not include the post-PCA transform) + forward_point_transform (itk.DisplacementFieldTransform): POINT transform + mapping template points -> registered/target points; use it to warp + the template model/landmarks onto the target. Its orientation is + opposite to an image-registration forward_transform (see + docs/developer/transform_conventions). Does not include the post-PCA + transform. + inverse_point_transform (itk.DisplacementFieldTransform): POINT transform + mapping target points -> template points. Does not include the + post-PCA transform. Example: >>> # Load PCA model data @@ -741,8 +746,14 @@ def compute_pca_transforms(self, reference_image: itk.Image) -> dict: Returns: Dictionary containing: - - 'forward_point_transform': Forward displacement field transform - - 'inverse_point_transform': Inverse displacement field transform + - 'forward_point_transform': POINT transform mapping template + points -> target points (warps the template onto the target) + - 'inverse_point_transform': POINT transform mapping target + points -> template points + + Note: + These are point transforms, oriented opposite to image-registration + transforms; see docs/developer/transform_conventions. """ assert self.registered_model_pca_deformation is not None, ( "PCA deformation must be computed" diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 5e14f5a..89dd02d 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -75,8 +75,8 @@ class RegisterTimeSeriesImages(RegisterImagesBase): ... prior_weight=0.5, ... ) >>> - >>> forward_tfms = result['forward_transforms'] # Moving → Fixed - >>> inverse_tfms = result['inverse_transforms'] # Fixed → Moving + >>> forward_tfms = result['forward_transforms'] # warp moving images -> fixed grid + >>> inverse_tfms = result['inverse_transforms'] # warp fixed image -> moving grids >>> losses = result['losses'] >>> >>> # Reconstruct time series with optional upsampling @@ -205,6 +205,16 @@ def set_fixed_mask(self, fixed_mask: Optional[itk.Image]) -> None: """ self.fixed_mask = fixed_mask + def set_fixed_labelmap(self, fixed_labelmap: Optional[itk.Image]) -> None: + """Set a labelmap for the fixed image region of interest. + + This passes through to the underlying registration method. + + Args: + fixed_labelmap (itk.Image): Labelmap defining ROI + """ + self.fixed_labelmap = fixed_labelmap + def register_time_series( self, moving_images: list[itk.Image], @@ -247,10 +257,14 @@ def register_time_series( Returns: dict: Dictionary containing results: - - "forward_transforms" (list[itk.Transform]): Transforms from moving to fixed - space for each image (warps moving → fixed) - - "inverse_transforms" (list[itk.Transform]): Transforms from fixed to moving - space for each image (warps fixed → moving) + - "forward_transforms" (list[itk.Transform]): one per image; + each warps its moving image onto the fixed grid (warping + moving points/landmarks into fixed space uses the matching + inverse transform instead -- see + docs/developer/transform_conventions) + - "inverse_transforms" (list[itk.Transform]): one per image; + each warps the fixed image onto that moving image's grid + (used by reconstruct_time_series) - "losses" (list[float]): Registration loss value for each image Raises: @@ -277,6 +291,7 @@ def register_time_series( >>> result = registrar.register_time_series( ... moving_images=image_list, ... moving_masks=mask_list, # Optional + ... moving_labelmaps=labelmap_list, # Optional ... reference_frame=5, ... register_reference=True, ... prior_weight=0.5, @@ -630,7 +645,8 @@ def reconstruct_time_series( Args: moving_images (list[itk.Image]): List of moving images to reconstruct inverse_transforms (list[itk.Transform]): List of inverse transforms - (one per moving image) from fixed space to moving space + (one per moving image), each used to warp the fixed image onto + that moving image's grid upsample_to_fixed_resolution (bool, optional): If True, reconstructed images will be upsampled to isotropic resolution (mean of fixed image's X and Y spacing) while maintaining their original origin and direction. diff --git a/src/physiomotion4d/segment_anatomy_base.py b/src/physiomotion4d/segment_anatomy_base.py index ca2a788..8368d6b 100644 --- a/src/physiomotion4d/segment_anatomy_base.py +++ b/src/physiomotion4d/segment_anatomy_base.py @@ -557,29 +557,6 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: """ raise NotImplementedError("This method should be implemented by the subclass.") - def dilate_mask(self, mask: itk.image, dilation: int) -> itk.image: - """ - Dilate a binary mask using morphological operations. - - Expands the mask regions by the specified number of pixels to create - larger regions of interest. Useful for creating candidate regions or - ensuring complete coverage of anatomical structures. - - Args: - mask (itk.image): The binary mask to dilate - dilation (int): Number of pixels to dilate in each direction - - Returns: - itk.image: The dilated binary mask - - Example: - >>> dilated_heart = segmenter.dilate_mask(heart_mask, 5) - """ - imMath = tube.ImageMath.New(mask) - imMath.Dilate(dilation, 1, 0) - dilated_mask = imMath.GetOutputUChar() - return dilated_mask - def segment( self, input_image: itk.image, diff --git a/src/physiomotion4d/segment_heart_simpleware.py b/src/physiomotion4d/segment_heart_simpleware.py index 15a1031..1d59e4a 100644 --- a/src/physiomotion4d/segment_heart_simpleware.py +++ b/src/physiomotion4d/segment_heart_simpleware.py @@ -338,8 +338,8 @@ def segmentation_method(self, preprocessed_image: itk.image) -> itk.image: ) if mask_image is not None: - in_direction = np.array(preprocessed_image.GetDirection()) - out_direction = np.array(mask_image.GetDirection()) + in_direction = itk.array_from_matrix(preprocessed_image.GetDirection()) + out_direction = itk.array_from_matrix(mask_image.GetDirection()) flip = [False, False, False] for i in range(3): if np.sign(out_direction[i, i]) != np.sign(in_direction[i, i]): diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index 223a5c0..3b91e0a 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -42,8 +42,8 @@ import numpy as np import yaml +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.register_time_series_images import RegisterTimeSeriesImages from physiomotion4d.transform_tools import TransformTools @@ -95,7 +95,9 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): traceability; not consumed by uniGradICON fine-tuning itself. mask_dilation_mm (float): Millimeters of physical-radius binary dilation applied to the >0 labelmap when deriving the loss-masking - binary mask via :meth:`RegisterImagesICON.create_mask`. + binary mask via :meth:`LabelmapTools.convert_labelmap_to_mask`. + mask_exclude_labels (Optional[list[int]]): Labels to exclude from the mask. + Default is None. mask_dir (Optional[Path]): Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk. @@ -148,13 +150,14 @@ def __init__( similarity: str = "lncc", lambda_value: float = 1.5, dice_loss_weight: float = 0.5, - lncc_sigma: int = 5, + lncc_sigma: int = 1, ct_window: tuple[float, float] = (-1000.0, 1000.0), is_ct: bool = True, gpus: Optional[list[int]] = None, eval_period: int = 10, save_period: int = 50, mask_dilation_mm: float = 5.0, + mask_exclude_labels: Optional[list[int]] = None, mask_dir: Optional[Path] = None, unigradicon_src_path: Optional[Path] = None, log_level: Union[int, str] = logging.INFO, @@ -176,7 +179,7 @@ def __init__( form ``subject_0000``, ``subject_0001``, ... Must be unique. subject_segmentation_files: Per-subject multi-label segmentation (labelmap) paths matching ``subject_image_files``. ``None`` - disables paired-with-seg training (no ``use_label``). + disables paired-with-seg training. Individual ``None`` entries inside the inner lists skip just those frames when paired-with-seg training is enabled. subject_mask_files: Per-subject binary mask paths matching @@ -205,8 +208,8 @@ def __init__( mask_dilation_mm: Physical radius (millimeters) of binary dilation applied to the >0 labelmap when deriving the loss-masking binary mask via - :meth:`RegisterImagesICON.create_mask`. Ignored when no - segmentations are supplied. Default 5.0 mm. + :meth:`LabelmapTools.convert_labelmap_to_mask`. Ignored when + no segmentations are supplied. Default 5.0 mm. mask_dir: Directory where derived binary masks are written and looked up. ``None`` (default) writes each derived mask next to its source labelmap on disk @@ -277,14 +280,19 @@ def __init__( self.gpus = list(gpus) if gpus is not None else [0] self.eval_period = eval_period self.save_period = save_period + self.mask_exclude_labels = mask_exclude_labels self.mask_dilation_mm = float(mask_dilation_mm) self.unigradicon_src_path = ( Path(unigradicon_src_path) if unigradicon_src_path is not None else None ) self.transform_tools = TransformTools() + self.labelmap_tools = LabelmapTools(log_level=log_level) self.registrar: Optional[RegisterTimeSeriesImages] = None + self._use_segmentations: Optional[bool] = None + self._use_masks: Optional[bool] = None + self._dataset_json_path: Optional[Path] = None self._config_yaml_path: Optional[Path] = None @@ -309,48 +317,21 @@ def _validate_companion_shape( f"subject_image_files[{i}] length ({len(images)})" ) - @property - def uses_segmentations(self) -> bool: - """Whether at least one segmentation file is supplied for training. - - Drives the uniGradICON ``training.use_label`` flag. - """ - return self._any_non_none(self.subject_segmentation_files) - - @property - def uses_masks(self) -> bool: - """Whether the dataset will have a ``mask`` field on every kept entry. - - True when explicit masks are supplied OR when segmentations are supplied - (since masks are then derived). Drives the uniGradICON - ``training.loss_function_masking`` flag. - """ - return self._any_non_none(self.subject_mask_files) or self.uses_segmentations - - @staticmethod - def _any_non_none( - companion: Optional[list[list[Optional[str]]]], - ) -> bool: - """Return True when ``companion`` contains at least one non-``None`` entry.""" - if companion is None: - return False - for inner in companion: - for item in inner: - if item is not None: - return True - return False - @staticmethod def _posix(path: Union[str, Path]) -> str: """Return a forward-slashed string path (uniGradICON expects POSIX paths).""" return str(path).replace("\\", "/") - def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: + def _derive_mask( + self, + labelmap_path: Union[str, Path], + ) -> Path: """Create (or reuse) a dilated binary mask from a multi-label labelmap. Threshold the labelmap at ``>0`` and dilate by ``mask_dilation_mm`` mm - of physical radius via :meth:`RegisterImagesICON.create_mask` to widen - the ROI for loss-function masking. + of physical radius via + :meth:`LabelmapTools.convert_labelmap_to_mask` to widen the ROI for + loss-function masking. When :attr:`mask_dir` is ``None`` (the default) the mask is written next to the source labelmap as @@ -378,13 +359,17 @@ def _derive_mask(self, labelmap_path: Union[str, Path]) -> Path: return mask_path labelmap = itk.imread(str(labelmap_path)) - mask = RegisterImagesICON.create_mask( - labelmap, dilation_mm=self.mask_dilation_mm + mask = self.labelmap_tools.convert_labelmap_to_mask( + labelmap, + dilation_in_mm=self.mask_dilation_mm, + exclude_labels=self.mask_exclude_labels, ) itk.imwrite(mask, str(mask_path), compression=True) return mask_path - def prepare_dataset(self) -> Path: + def prepare_dataset( + self, use_segmentations: bool = True, use_masks: bool = True + ) -> Path: """Write the uniGradICON dataset JSON from the configured file lists. Builds one entry per image with ``image``, optional ``segmentation``, @@ -406,8 +391,9 @@ def prepare_dataset(self) -> Path: does not exist on disk. """ self.experiment_dir.mkdir(parents=True, exist_ok=True) - use_seg = self.uses_segmentations - use_mask = self.uses_masks + + self._use_segmentations = use_segmentations + self._use_masks = use_masks dataset_entries: list[dict[str, str]] = [] for subject_index, image_files in enumerate(self.subject_image_files): @@ -416,16 +402,24 @@ def prepare_dataset(self) -> Path: if self.subject_ids is not None else f"subject_{subject_index:04d}" ) - seg_list = ( - self.subject_segmentation_files[subject_index] - if self.subject_segmentation_files is not None - else [None] * len(image_files) - ) - mask_list = ( - self.subject_mask_files[subject_index] - if self.subject_mask_files is not None - else [None] * len(image_files) - ) + seg_list: list[Optional[str]] + if not use_segmentations: + seg_list = [None] * len(image_files) + else: + seg_list = ( + self.subject_segmentation_files[subject_index] + if self.subject_segmentation_files is not None + else [None] * len(image_files) + ) + mask_list: list[Optional[str]] + if not use_masks: + mask_list = [None] * len(image_files) + else: + mask_list = ( + self.subject_mask_files[subject_index] + if self.subject_mask_files is not None + else [None] * len(image_files) + ) landmark_list = ( self.subject_landmark_files[subject_index] if self.subject_landmark_files is not None @@ -444,7 +438,7 @@ def prepare_dataset(self) -> Path: "subject_id": subject_id, } - if use_seg: + if use_segmentations: if seg_file is None or not Path(seg_file).exists(): self.log_warning( "Skipping %s: segmentation missing for paired-with-seg " @@ -455,7 +449,7 @@ def prepare_dataset(self) -> Path: continue entry["segmentation"] = self._posix(seg_file) - if use_mask: + if use_masks: if mask_file is not None and Path(mask_file).exists(): resolved_mask: Path = Path(mask_file) elif seg_file is not None and Path(seg_file).exists(): @@ -529,8 +523,8 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "lambda": self.lambda_value, "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, - "loss_function_masking": self.uses_masks, - "use_label": self.uses_segmentations, + "loss_function_masking": self._use_masks, + "use_label": False, "roi_masking": False, }, "datasets": [ @@ -735,8 +729,8 @@ def apply_registration( self.log_info("ICON weights: %s", weights_path) fixed_mask = ( - RegisterImagesICON.create_mask( - reference_segmentation, dilation_mm=self.mask_dilation_mm + self.labelmap_tools.convert_labelmap_to_mask( + reference_segmentation, dilation_in_mm=self.mask_dilation_mm ) if reference_segmentation is not None else None @@ -745,8 +739,8 @@ def apply_registration( if moving_segmentations is not None: moving_masks = [ ( - RegisterImagesICON.create_mask( - seg, dilation_mm=self.mask_dilation_mm + self.labelmap_tools.convert_labelmap_to_mask( + seg, dilation_in_mm=self.mask_dilation_mm ) if seg is not None else None diff --git a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index a533f4f..0e2c384 100644 --- a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -30,6 +30,7 @@ import pyvista as pv from physiomotion4d.contour_tools import ContourTools +from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_icon import RegisterImagesICON @@ -199,6 +200,7 @@ def __init__( # Utilities (needed for create_reference_image when patient_image is None) self.transform_tools = TransformTools() self.contour_tools = ContourTools() + self.labelmap_tools = LabelmapTools() if patient_image is not None: self.patient_image = patient_image @@ -319,11 +321,9 @@ def _auto_generate_mask( if dilate_mm is None: dilate_mm = self.mask_dilation_mm if dilate_mm > 0: - ttk = _load_tubetk() - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int(dilate_mm / self.patient_image.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - mask = imMath.GetOutputUChar() + mask = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=dilate_mm + ) self.log_info("Masks auto-generated successfully.") @@ -349,11 +349,9 @@ def _auto_generate_roi_mask( # Generate model ROI mask roi = None if dilate_mm > 0: - ttk = _load_tubetk() - imMath = ttk.ImageMath.New(mask) - dilation_voxels = int(dilate_mm / mask.GetSpacing()[0]) - imMath.Dilate(dilation_voxels, 1, 0) - roi = imMath.GetOutputUChar() + roi = self.labelmap_tools.convert_labelmap_to_mask( + mask, dilation_in_mm=dilate_mm + ) else: roi = mask diff --git a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py index ebfb732..b3105c5 100644 --- a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py @@ -61,8 +61,10 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): registration_method (str): Registration method ('ANTS', 'ICON', or 'ANTS_ICON') number_of_iterations: Iterations for registration registrar (RegisterTimeSeriesImages): Internal registration object - forward_transforms (list[itk.Transform]): Forward transforms (moving → fixed) - inverse_transforms (list[itk.Transform]): Inverse transforms (fixed → moving) + forward_transforms (list[itk.Transform]): one per frame; each warps its + moving image onto the fixed grid + inverse_transforms (list[itk.Transform]): one per frame; each warps the + fixed image onto that frame's moving grid (used for reconstruction) losses (list[float]): Registration loss values reconstructed_images (list[itk.Image]): Reconstructed high-resolution images @@ -260,10 +262,11 @@ def register_time_series(self) -> dict: Returns: dict: Dictionary containing: - - 'forward_transforms' (list[itk.Transform]): Transforms from moving - to fixed space (warps moving → fixed) - - 'inverse_transforms' (list[itk.Transform]): Transforms from fixed - to moving space (warps fixed → moving) + - 'forward_transforms' (list[itk.Transform]): one per frame; + each warps its moving image onto the fixed grid + - 'inverse_transforms' (list[itk.Transform]): one per frame; + each warps the fixed image onto that frame's moving grid + (see docs/developer/transform_conventions) - 'losses' (list[float]): Registration loss value for each image Raises: diff --git a/tests/test_image_tools.py b/tests/test_image_tools.py index 7e48b7c..100a35c 100644 --- a/tests/test_image_tools.py +++ b/tests/test_image_tools.py @@ -428,7 +428,7 @@ def test_flip_and_make_identity_sets_direction_to_identity( direction = np.diag([-1.0, 1.0, 1.0]) itk_image = _make_synthetic_itk_image(shape_xyz, arr=arr, direction=direction) out = image_tools.flip_image(itk_image, flip_and_make_identity=True) - out_direction = np.array(out.GetDirection()) + out_direction = itk.array_from_matrix(out.GetDirection()) identity = np.eye(3) assert np.allclose(out_direction, identity), ( "flip_and_make_identity should set direction to identity" @@ -450,7 +450,7 @@ def test_flip_and_make_identity_with_mask_sets_both_directions_to_identity( itk_image, in_mask=itk_mask, flip_and_make_identity=True ) for im, name in [(out_image, "image"), (out_mask, "mask")]: - dir_mat = np.array(im.GetDirection()) + dir_mat = itk.array_from_matrix(im.GetDirection()) assert np.allclose(dir_mat, np.eye(3)), ( f"flip_and_make_identity should set {name} direction to identity" ) diff --git a/tests/test_labelmap_tools.py b/tests/test_labelmap_tools.py new file mode 100644 index 0000000..e6f15cf --- /dev/null +++ b/tests/test_labelmap_tools.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +""" +Tests for LabelmapTools functionality. + +Covers thresholding a multi-label labelmap into a binary registration mask, +physically isotropic dilation that respects per-axis spacing, and forcing +selected labels to background via ``exclude_labels``. +""" + +from __future__ import annotations + +import itk +import numpy as np +import pytest + +from physiomotion4d.labelmap_tools import LabelmapTools + + +class TestLabelmapTools: + """Test suite for LabelmapTools.convert_labelmap_to_mask.""" + + @pytest.fixture + def labelmap_tools(self) -> LabelmapTools: + """Create LabelmapTools instance.""" + return LabelmapTools() + + def test_threshold_without_dilation(self, labelmap_tools: LabelmapTools) -> None: + """Every non-zero label becomes foreground; no dilation grows it.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 # non-zero label id + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + mask = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=0.0) + mask_arr = itk.array_from_image(mask) + + assert set(np.unique(mask_arr).tolist()) == {0, 1} + assert int(mask_arr.sum()) == 1 + assert mask_arr[2, 2, 2] == 1 + + def test_dilation_grows_mask(self, labelmap_tools: LabelmapTools) -> None: + """Positive dilation_in_mm grows the mask but keeps the seed voxel.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[2, 2, 2] = 3 + labelmap = itk.image_from_array(arr) + # Unit isotropic spacing so dilation_in_mm == voxel radius. + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + dilated = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=1.0) + dilated_arr = itk.array_from_image(dilated) + + assert int(dilated_arr.sum()) > 1 + assert dilated_arr[2, 2, 2] == 1 + + def test_dilation_respects_anisotropic_spacing( + self, labelmap_tools: LabelmapTools + ) -> None: + """A 5 mm radius covers more voxels along the finely spaced axis.""" + arr = np.zeros((11, 11, 11), dtype=np.uint8) + arr[5, 5, 5] = 1 + labelmap = itk.image_from_array(arr) + # numpy axes are (Z, Y, X); ITK spacing is (X, Y, Z). Make X coarse + # (5 mm/voxel -> 1-voxel radius) and Z fine (1 mm/voxel -> 5-voxel + # radius) so the per-axis radius differs. + labelmap.SetSpacing([5.0, 1.0, 1.0]) + + dilated = itk.array_from_image( + labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=5.0) + ) + + # Z axis (numpy axis 0) reaches 5 voxels out; X axis (numpy axis 2) + # only 1 voxel out. + assert dilated[0, 5, 5] == 1 + assert dilated[10, 5, 5] == 1 + assert dilated[5, 5, 0] == 0 + assert dilated[5, 5, 10] == 0 + + def test_exclude_labels_removes_voxels(self, labelmap_tools: LabelmapTools) -> None: + """Excluded labels become background before thresholding.""" + arr = np.zeros((5, 5, 5), dtype=np.uint8) + arr[1, 1, 1] = 2 # kept + arr[3, 3, 3] = 7 # excluded + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([1.0, 1.0, 1.0]) + + mask_arr = itk.array_from_image( + labelmap_tools.convert_labelmap_to_mask( + labelmap, dilation_in_mm=0.0, exclude_labels=[7] + ) + ) + + assert mask_arr[1, 1, 1] == 1 + assert mask_arr[3, 3, 3] == 0 + assert int(mask_arr.sum()) == 1 + + def test_preserves_image_information(self, labelmap_tools: LabelmapTools) -> None: + """Origin, spacing, and direction are copied from the labelmap.""" + arr = np.zeros((4, 4, 4), dtype=np.uint8) + arr[2, 2, 2] = 1 + labelmap = itk.image_from_array(arr) + labelmap.SetSpacing([0.5, 1.0, 2.0]) + labelmap.SetOrigin([10.0, -5.0, 3.0]) + + mask = labelmap_tools.convert_labelmap_to_mask(labelmap, dilation_in_mm=0.0) + + assert list(mask.GetSpacing()) == [0.5, 1.0, 2.0] + assert list(mask.GetOrigin()) == [10.0, -5.0, 3.0] diff --git a/tests/test_register_images_ants.py b/tests/test_register_images_ants.py index 8102c58..61c778d 100644 --- a/tests/test_register_images_ants.py +++ b/tests/test_register_images_ants.py @@ -20,6 +20,27 @@ from physiomotion4d.transform_tools import TransformTools +def _foreground_ncc( + reference_arr: np.ndarray, warped_arr: np.ndarray, foreground: np.ndarray +) -> float: + """Normalized cross-correlation over a foreground mask (higher = better). + + Args: + reference_arr: Reference image array (e.g. the fixed image), axes (Z, Y, X). + warped_arr: Warped image array on the same grid/axes as ``reference_arr``. + foreground: Boolean mask (same shape) selecting the voxels to score. + + Returns: + NCC in [-1, 1] over the foreground voxels (nan if degenerate). + """ + a = reference_arr[foreground].astype(np.float64) + b = warped_arr[foreground].astype(np.float64) + a0 = a - a.mean() + b0 = b - b.mean() + denom = float(np.sqrt((a0 * a0).sum() * (b0 * b0).sum())) + return float((a0 * b0).sum() / denom) if denom > 0 else float("nan") + + @pytest.mark.slow class TestRegisterImagesANTS: """Test suite for ANTs-based image registration.""" @@ -309,6 +330,218 @@ def test_registration_with_initial_transform( print("Registration with initial transform complete") + def test_initial_transform_composition_metrics( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + test_directories: dict[str, Path], + ) -> None: + """Verify the initial_forward_transform composition path with metrics. + + Exercises the two initial-transform inputs the platform actually uses + (identity and a prior deformable forward_transform, as in prior-based + time-series registration) and confirms the composed forward_transform + warps the moving image onto the fixed grid. Scored with foreground NCC + over the brightest 30% of the fixed image (tissue/blood pool). See + docs/developer/transform_conventions. + + Asserted facts: + * a plain registration improves on the unregistered pair, + * an identity initial reproduces the baseline exactly (the + composition machinery is a structurally correct no-op; note an + identity AffineTransform is itself a matrix initial), + * a prior-deformable initial reaches the no-initial baseline quality + (the composition recovers the full transform). + + The initial transform is applied by pre-warping the moving image (as in + RegisterImagesICON), which keeps the composition self-consistent for any + initial transform type. + """ + output_dir = test_directories["output"] + reg_output_dir = output_dir / "registration_ANTS" + reg_output_dir.mkdir(exist_ok=True) + + # Pick two phases that are far apart in the cycle so there is real motion. + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + # Moving and fixed share the acquisition grid (split from one 4D image), + # so the moving array is directly comparable for the unregistered score. + moving_arr = itk.array_from_image(moving_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + + transform_tools = TransformTools() + + def warp_score(forward_transform: Any) -> float: + warped = transform_tools.transform_image( + moving_image, + forward_transform, + fixed_image, + interpolation_method="linear", + ) + return _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + + ncc_unregistered = _foreground_ncc(fixed_arr, moving_arr, foreground) + + # Baseline: no initial transform. + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + baseline = registrar_ANTS.register(moving_image=moving_image) + ncc_baseline = warp_score(baseline["forward_transform"]) + + # Identity initial: the composition machinery must be a no-op. + identity = itk.AffineTransform[itk.D, 3].New() + identity.SetIdentity() + registrar_identity = RegisterImagesANTS() + registrar_identity.set_modality("ct") + registrar_identity.set_fixed_image(fixed_image) + identity_result = registrar_identity.register( + moving_image=moving_image, initial_forward_transform=identity + ) + ncc_identity = warp_score(identity_result["forward_transform"]) + + # Prior deformable initial: the realistic time-series prior use case. + registrar_prior = RegisterImagesANTS() + registrar_prior.set_modality("ct") + registrar_prior.set_fixed_image(fixed_image) + prior_result = registrar_prior.register( + moving_image=moving_image, + initial_forward_transform=baseline["forward_transform"], + ) + ncc_prior = warp_score(prior_result["forward_transform"]) + + print("\nANTS initial-transform composition metrics (foreground NCC):") + print(f" unregistered: {ncc_unregistered:.4f}") + print(f" baseline (no initial): {ncc_baseline:.4f}") + print(f" identity initial: {ncc_identity:.4f}") + print(f" prior-deformable init: {ncc_prior:.4f}") + + warped_prior = transform_tools.transform_image( + moving_image, + prior_result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + itk.imwrite( + warped_prior, + str(reg_output_dir / "ants_warped_prior_initial.mha"), + compression=True, + ) + + # Registration must improve alignment over the unregistered pair. + assert ncc_baseline > ncc_unregistered, ( + f"Baseline registration did not improve alignment: " + f"{ncc_baseline:.4f} <= {ncc_unregistered:.4f}" + ) + # Identity initial must reproduce the baseline (composition is a no-op). + assert abs(ncc_identity - ncc_baseline) < 0.03, ( + f"Identity initial transform changed the result: " + f"identity={ncc_identity:.4f} vs baseline={ncc_baseline:.4f}" + ) + # A prior-deformable initial must reach the no-initial baseline quality + # (the composition recovers the full transform). + assert ncc_prior >= ncc_baseline - 0.03, ( + f"Prior-initial composition did not reach baseline quality: " + f"{ncc_prior:.4f} < {ncc_baseline:.4f} - 0.03" + ) + + def test_initial_transform_matrix_composition( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + ) -> None: + """A matrix (translation/affine) initial composes without corruption. + + Regression guard for the previously-broken matrix initial_transform + path: feeding a translation initial used to corrupt the composition + (foreground NCC far below the unregistered pair). With the moving image + pre-warped by the initial, the composed forward_transform must align the + moving image onto the fixed grid at least as well as the unregistered + pair. + """ + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + ncc_unregistered = _foreground_ncc( + fixed_arr, itk.array_from_image(moving_image), foreground + ) + + translation = itk.TranslationTransform[itk.D, 3].New() + translation.SetOffset([-5.0, -5.0, -5.0]) + + registrar_ANTS.set_modality("ct") + registrar_ANTS.set_fixed_image(fixed_image) + result = registrar_ANTS.register( + moving_image=moving_image, initial_forward_transform=translation + ) + + transform_tools = TransformTools() + warped = transform_tools.transform_image( + moving_image, + result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + ncc = _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + print( + f"\nMatrix-initial composed NCC={ncc:.4f} (unregistered={ncc_unregistered:.4f})" + ) + assert ncc > ncc_unregistered, ( + f"Matrix-initial composition worse than unregistered: " + f"{ncc:.4f} <= {ncc_unregistered:.4f}" + ) + + def test_affine_and_rigid_transform_types( + self, + registrar_ANTS: RegisterImagesANTS, + test_images: list[Any], + ) -> None: + """Affine and Rigid transform types run and improve alignment. + + Regression guard for the ANTS preset names: ``set_transform_type`` + previously mapped Affine/Rigid to ``antsRegistration{Affine,Rigid}Quick`` + preset strings that do not exist in antspy, raising ValueError. Each + type must now run and warp the moving image onto the fixed grid at least + as well as the unregistered pair. + """ + fixed_image = test_images[0] + moving_image = test_images[min(10, len(test_images) - 1)] + + fixed_arr = itk.array_from_image(fixed_image) + threshold = float(np.percentile(fixed_arr, 70.0)) + foreground = fixed_arr > threshold + ncc_unregistered = _foreground_ncc( + fixed_arr, itk.array_from_image(moving_image), foreground + ) + + transform_tools = TransformTools() + for transform_type in ("Rigid", "Affine"): + registrar = RegisterImagesANTS() + registrar.set_modality("ct") + registrar.set_transform_type(transform_type) + registrar.set_fixed_image(fixed_image) + result = registrar.register(moving_image=moving_image) + warped = transform_tools.transform_image( + moving_image, + result["forward_transform"], + fixed_image, + interpolation_method="linear", + ) + ncc = _foreground_ncc(fixed_arr, itk.array_from_image(warped), foreground) + print( + f"\n{transform_type} transform NCC={ncc:.4f} " + f"(unregistered={ncc_unregistered:.4f})" + ) + assert ncc > ncc_unregistered, ( + f"{transform_type} registration did not improve alignment: " + f"{ncc:.4f} <= {ncc_unregistered:.4f}" + ) + def test_multiple_registrations( self, registrar_ANTS: RegisterImagesANTS, test_images: list[Any] ) -> None: diff --git a/tests/test_workflow_fine_tune_icon_registration.py b/tests/test_workflow_fine_tune_icon_registration.py index 11cde75..a90d9ea 100644 --- a/tests/test_workflow_fine_tune_icon_registration.py +++ b/tests/test_workflow_fine_tune_icon_registration.py @@ -20,7 +20,6 @@ import pytest import yaml -from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.workflow_fine_tune_icon_registration import ( WorkflowFineTuneICONRegistration, ) @@ -128,7 +127,7 @@ def test_init_rejects_mismatched_subject_ids_length(tmp_path: Path) -> None: ) -def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: +def test_use_segmentations_and_use_masks_flags(tmp_path: Path) -> None: """The two helper flags reflect supplied companions independently.""" base: dict[str, Any] = { "subject_image_files": [["a"]], @@ -136,45 +135,20 @@ def test_uses_segmentations_and_uses_masks_flags(tmp_path: Path) -> None: "fine_tune_name": "x", } none_wf = WorkflowFineTuneICONRegistration(**base) - assert not none_wf.uses_segmentations - assert not none_wf.uses_masks + assert not none_wf.use_segmentations + assert not none_wf.use_masks seg_only = WorkflowFineTuneICONRegistration( **base, subject_segmentation_files=[["seg.nii.gz"]] ) - assert seg_only.uses_segmentations - assert seg_only.uses_masks # derived from segs + assert seg_only.use_segmentations + assert seg_only.use_masks # derived from segs mask_only = WorkflowFineTuneICONRegistration( **base, subject_mask_files=[["mask.nii.gz"]] ) - assert not mask_only.uses_segmentations - assert mask_only.uses_masks - - -# --------------------------------------------------------------------------- -# RegisterImagesICON.create_mask (in-memory dilation, used by the workflow) -# --------------------------------------------------------------------------- - - -def test_create_mask_thresholds_and_dilates() -> None: - """Single-voxel labelmap becomes a binary mask whose dilation grows it.""" - arr = np.zeros((5, 5, 5), dtype=np.uint8) - arr[2, 2, 2] = 3 # non-zero label id - labelmap = itk.image_from_array(arr) - # Unit isotropic spacing so dilation_mm == voxel radius. - labelmap.SetSpacing([1.0, 1.0, 1.0]) - - no_dilate = RegisterImagesICON.create_mask(labelmap, dilation_mm=0.0) - no_dilate_arr = itk.array_from_image(no_dilate) - assert set(np.unique(no_dilate_arr).tolist()) == {0, 1} - assert int(no_dilate_arr.sum()) == 1 - - dilated = RegisterImagesICON.create_mask(labelmap, dilation_mm=1.0) - dilated_arr = itk.array_from_image(dilated) - assert int(dilated_arr.sum()) > 1 - # Original foreground voxel stays foreground. - assert dilated_arr[2, 2, 2] == 1 + assert not mask_only.use_segmentations + assert mask_only.use_masks # --------------------------------------------------------------------------- diff --git a/tutorials/tutorial_08_dirlab_pca_time_series.py b/tutorials/tutorial_08_dirlab_pca_time_series.py index 232e8ed..75cd063 100644 --- a/tutorials/tutorial_08_dirlab_pca_time_series.py +++ b/tutorials/tutorial_08_dirlab_pca_time_series.py @@ -170,9 +170,16 @@ def run_tutorial() -> dict[str, Any]: compression=True, ) + # Warp the reference-space fitted mesh into this phase's space. + # Warping reference -> phase POINTS uses the forward transform + # (the fixed -> moving point map), which is the opposite of the + # transform used to warp an image into phase space (images pull + # back, points push forward). The forward transform is named + # "phase_to_reference" after its image-warp role. See + # docs/developer/transform_conventions. phase_mesh = transform_tools.transform_pvcontour( fitted_reference_mesh, - reference_to_phase, + phase_to_reference, with_deformation_magnitude=True, ) phase_mesh_file = meshes_dir / f"{phase_name}_pca_fit.vtp"