From 55f9045280117e2b2bf234320f63464d911304e3 Mon Sep 17 00:00:00 2001 From: Yuning Du Date: Wed, 6 May 2026 16:48:40 +0100 Subject: [PATCH 01/17] Add OASIS U-Net inference example --- README.md | 11 + examples/OASIS_inference_plot.py | 623 +++++++++++++++++++++++++++ mri_recon/reconstruction/__init__.py | 7 +- mri_recon/reconstruction/deep.py | 103 +++++ pyproject.toml | 1 + tests/test_reconstructions.py | 32 ++ 6 files changed, 776 insertions(+), 1 deletion(-) create mode 100644 examples/OASIS_inference_plot.py diff --git a/README.md b/README.md index c1274e0..5807b76 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ MRI reconstruction playground for the MRI Metrics project. | `RAMReconstructor` | `ram` | Deep learning | Wrapper around the DeepInverse RAM model, with input normalization based on the adjoint reconstruction. | | `DeepImagePriorReconstructor` | `dip` | Deep learning | Deep Image Prior reconstruction using an untrained convolutional decoder optimized at inference time. | | `FastMRISinglecoilUnetReconstructor` | `unet` | Deep learning | Wrapper around the pretrained fastMRI single-coil U-Net, returning a magnitude-based reconstruction with a zero imaginary channel. | +| `OASISSinglecoilUnetReconstructor` | `oasis-unet` | Deep learning | Wrapper around a trained OASIS single-coil U-Net checkpoint, reusing the shared fastMRI-derived U-Net module. | ## Implemented Distortions @@ -48,6 +49,16 @@ uv sync uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" ``` +## OASIS Inference Example + +Run the OASIS plotting example with a local OASIS root folder. By default, it uses the packaged split and checkpoint manifest under `reconstruction_only/`. + +```bash +python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data --acceleration 4 +``` + +Pass `--checkpoint /path/to/checkpoint.ckpt` to use a different trained OASIS U-Net checkpoint. 
+ ## Pre-commit Install the local tooling and register the git hook: diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py new file mode 100644 index 0000000..7d6643e --- /dev/null +++ b/examples/OASIS_inference_plot.py @@ -0,0 +1,623 @@ +"""Inference OASIS reconstructors for k-space distortion operators. + +Usage: + python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data +""" + +from __future__ import annotations + +import argparse +import contextlib +import json +import os +import sys +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Sequence, Union + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +import numpy as np +import torch +import deepinv as dinv +from torch.utils.data import DataLoader, Dataset + +try: + import nibabel as nib +except ImportError as exc: + raise ImportError( + "The OASIS example requires nibabel. Install the project dependencies " + "or add nibabel to your environment before running this script." 
+ ) from exc + +from mri_recon.distortions import ( + AnisotropicResolutionReduction, + BaseDistortion, + DistortedKspaceMultiCoilMRI, + GaussianKspaceBiasField, + GaussianNoiseDistortion, + HannTaperResolutionReduction, + IsotropicResolutionReduction, + KaiserTaperResolutionReduction, + OffCenterAnisotropicGaussianKspaceBiasField, + PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, + RotationalMotionDistortion, + SegmentedTranslationMotionDistortion, + TranslationMotionDistortion, +) +from mri_recon.reconstruction import ( + ConjugateGradientReconstructor, + DeepImagePriorReconstructor, + OASISSinglecoilUnetReconstructor, + RAMReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + ZeroFilledReconstructor, +) + +REPO_ROOT = Path(__file__).resolve().parents[1] +REPORT_DIR = Path("reports") / "oasis_inference_plot" +DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" +DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" + +REPORT_DIR.mkdir(parents=True, exist_ok=True) + +ALGORITHMS = [ + # "zero-filled", + # "conjugate-gradient", + # "ram", + # "dip", + "tv-pgd", + # "wavelet-fista", + # "tv-fista", + # "tv-pdhg", + "oasis-unet", +] +DISTORTIONS = [ + # "Phase-encode ghosting", + # "Segmented translation motion", + # "Translation motion", + # "Rotational motion", + # "Off-center anisotropic Gaussian bias field", + # "Gaussian bias field", + # "Anisotropic LP", + # "Hann taper LP", + # "Kaiser taper LP", + "Radial high-pass emphasis", + # "Gaussian noise", + # "Isotropic LP", +] +METRICS = [ + "PSNR", + "NMSE", + "SSIM", + "HaarPSI", + "SharpnessIndex", + "BlurStrength", +] + + +@contextlib.contextmanager +def temp_seed( + rng: np.random.RandomState, + seed: Optional[Union[int, tuple[int, ...]]] = None, +): + """Temporarily set a NumPy random seed.""" + + if seed is None: + yield + return + + state = 
rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +class MaskFunc: + """Random Cartesian undersampling mask matching the packaged OASIS checkpoints.""" + + def __init__( + self, + center_fractions: Sequence[float], + accelerations: Sequence[int], + seed: Optional[int] = None, + ) -> None: + if len(center_fractions) != len(accelerations): + raise ValueError("center_fractions and accelerations must have the same length.") + + self.center_fractions = list(center_fractions) + self.accelerations = list(accelerations) + self.rng = np.random.RandomState(seed) + + def __call__( + self, + shape: Sequence[int], + seed: Optional[Union[int, tuple[int, ...]]] = None, + ) -> torch.Tensor: + """Create a broadcastable mask for k-space shaped ``(..., H, W)``.""" + + if len(shape) < 2: + raise ValueError("Mask shape must have at least two dimensions.") + + with temp_seed(self.rng, seed): + center_fraction = self.rng.choice(self.center_fractions) + acceleration = self.rng.choice(self.accelerations) + num_cols = shape[-1] + num_low_freqs = round(num_cols * center_fraction) + + center_mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs) // 2 + center_mask[pad : pad + num_low_freqs] = 1 + + accel_prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + accel_mask = self.rng.uniform(size=num_cols) < accel_prob + + mask = np.maximum(center_mask, accel_mask.astype(np.float32)) + mask_shape = [1 for _ in shape] + mask_shape[-1] = num_cols + return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + +class OasisSliceDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes listed in a split CSV.""" + + def __init__( + self, + split_csv: Path, + data_path: Path, + sample_rate: float = 1.0, + cache_size: int = 2, + ) -> None: + self.split_csv = Path(split_csv) + self.data_path = Path(data_path) + if not 0 < sample_rate <= 1.0: + raise ValueError("sample_rate must be in the 
range (0, 1].") + self.sample_rate = sample_rate + self.cache_size = max(0, cache_size) + self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() + self.raw_samples = self._create_sample_list() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.raw_samples) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + subject_id, slice_num = self.raw_samples[index] + target_np = self._read_raw_slice(subject_id, slice_num) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _create_sample_list(self) -> list[tuple[str, int]]: + samples: list[tuple[str, int]] = [] + with self.split_csv.open("r", encoding="utf-8") as handle: + for line in handle: + row = [item.strip() for item in line.split(",")] + if not row or not row[0]: + continue + try: + total_slices = int(row[-1]) + except ValueError: + continue + + subject_id = row[0] + if self.sample_rate >= 1.0: + start = 0 + stop = total_slices + else: + mid = round(total_slices / 2) + half_span = round(total_slices * self.sample_rate / 2) + start = max(0, mid - half_span) + stop = min(total_slices, mid + half_span) + + for slice_num in range(start, stop): + samples.append((subject_id, slice_num)) + return samples + + def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: + volume = self._get_volume(subject_id) + return np.ascontiguousarray(volume[slice_num], dtype=np.float32) + + def _get_volume(self, subject_id: str) -> np.ndarray: + if self.cache_size > 0 and subject_id in self._volume_cache: + self._volume_cache.move_to_end(subject_id) + return self._volume_cache[subject_id] + + image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if not matches: + raise FileNotFoundError( 
+ f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." + ) + + image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + + if self.cache_size > 0: + self._volume_cache[subject_id] = volume + if len(self._volume_cache) > self.cache_size: + self._volume_cache.popitem(last=False) + + return volume + + +def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: + """Convert k-space tensor to log-magnitude image for visualization.""" + + if kspace.ndim == 4: + kspace = kspace[0] + if kspace.ndim != 3 or kspace.shape[0] != 2: + raise ValueError( + f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" + ) + + kspace = kspace.detach().cpu() + kspace_complex = torch.view_as_complex(kspace.permute(1, 2, 0).contiguous()) + magnitude = torch.log1p(torch.abs(kspace_complex)) + + lower = torch.quantile(magnitude, 0.05) + upper = torch.quantile(magnitude, 0.995) + if float(upper) > float(lower): + magnitude = magnitude.clamp(lower, upper) + magnitude = (magnitude - lower) / (upper - lower) + else: + mag_max = float(magnitude.max()) + if mag_max > 0.0: + magnitude = magnitude / mag_max + + return torch.sqrt(magnitude) + + +def save_kspace_plot( + y_clean: torch.Tensor, + y_distorted: torch.Tensor, + save_fn: Path, + distortion_name: str, +) -> None: + """Save clean and distorted k-space magnitude plots.""" + + images = [ + ("Original k-space", _kspace_to_log_magnitude(y_clean)), + ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), + ] + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + fig.suptitle(f"Distortion: {distortion_name}") + for ax, (title, image) in zip(axes, images, strict=True): + ax.imshow(image.numpy(), cmap="magma") + ax.set_title(title) + ax.axis("off") + fig.savefig(save_fn, dpi=200, bbox_inches="tight") + plt.close(fig) + + +def 
resolve_oasis_checkpoint( + checkpoint: Optional[Path], + acceleration: int, + manifest_path: Path, +) -> Path: + """Resolve an explicit or packaged OASIS checkpoint path.""" + + if checkpoint is not None: + return checkpoint.expanduser().resolve() + + with manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + checkpoints = manifest.get("checkpoints", {}) + key = str(acceleration) + if key not in checkpoints: + available = ", ".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." + ) + + filename = Path(checkpoints[key]["filename"]) + if filename.is_absolute(): + return filename + return (manifest_path.parent.parent / filename).resolve() + + +def choose_algorithm( + name: str, + checkpoint_file: Path, + img_size: tuple = (640, 368), + device: torch.device = "cpu", + verbose: bool = False, +) -> dinv.models.Reconstructor: + """Construct a reconstructor by selector name.""" + + match name: + case "zero-filled": + return ZeroFilledReconstructor() + case "conjugate-gradient": + return ConjugateGradientReconstructor(max_iter=20) + case "ram": + return RAMReconstructor(default_sigma=0.05, device=device) + case "dip": + return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) + case "tv-pgd": + return TVPGDReconstructor(n_iter=100, verbose=verbose) + case "tv-fista": + return TVFISTAReconstructor(n_iter=200, verbose=verbose) + case "tv-pdhg": + return TVPDHGReconstructor(n_iter=100, verbose=verbose) + case "wavelet-fista": + return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) + case "oasis-unet" | "unet": + return OASISSinglecoilUnetReconstructor( + checkpoint_file=str(checkpoint_file), + device=device, + ) + case _: + raise ValueError(f"Unknown algorithm {name!r}") + + +def choose_distortion(name: str) -> BaseDistortion: + """Construct a k-space distortion by display name.""" + + match name: + case 
"Phase-encode ghosting": + return PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + case "Anisotropic LP": + return AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.25, + ) + case "Hann taper LP": + return HannTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + ) + case "Kaiser taper LP": + return KaiserTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + beta=8.6, + ) + case "Radial high-pass emphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) + case "Isotropic LP": + return IsotropicResolutionReduction(radius_fraction=0.1) + case "Off-center anisotropic Gaussian bias field": + return OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.2, + width_y_fraction=0.35, + center_x_fraction=0.15, + center_y_fraction=-0.1, + edge_gain=0.3, + ) + case "Translation motion": + return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) + case "Rotational motion": + return RotationalMotionDistortion(angle_radians=torch.pi / 6) + case "Segmented translation motion": + return SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 20.0, 50.0, -50.0), + shift_y_pixels=(0.0, 10.0, -20.0, 20.0), + ) + case "Gaussian bias field": + return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + case "Gaussian noise": + return GaussianNoiseDistortion(sigma=0.00001) + case _: + raise ValueError(f"Unknown distortion {name!r}") + + +def choose_metric(name: str) -> dinv.metric.Metric: + """Construct a DeepInverse metric by selector name.""" + + match name: + case "PSNR": + return dinv.metric.PSNR(max_pixel=None, complex_abs=True) + case "NMSE": + return dinv.metric.NMSE(complex_abs=True) + case "SSIM": + return dinv.metric.SSIM(max_pixel=None, complex_abs=True) + case "HaarPSI": + return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) + case "BlurStrength": + 
return dinv.metric.BlurStrength(complex_abs=True) + case "SharpnessIndex": + return dinv.metric.SharpnessIndex(complex_abs=True) + case _: + raise ValueError(f"Unknown metric {name!r}") + + +def build_parser() -> argparse.ArgumentParser: + """Build the command-line parser.""" + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--source", + type=Path, + required=True, + help="OASIS root directory containing subject folders.", + ) + parser.add_argument( + "--split_csv", + type=Path, + default=DEFAULT_SPLIT_CSV, + help="CSV listing OASIS subjects and slice counts.", + ) + parser.add_argument( + "--manifest", + type=Path, + default=DEFAULT_MANIFEST_PATH, + help="Checkpoint manifest JSON.", + ) + parser.add_argument( + "--checkpoint", + type=Path, + default=None, + help="Explicit OASIS U-Net checkpoint. Overrides --acceleration.", + ) + parser.add_argument( + "--acceleration", + type=int, + default=4, + help="Packaged OASIS checkpoint acceleration factor.", + ) + parser.add_argument( + "--center_fraction", + type=float, + default=0.08, + help="Center fraction used by the random Cartesian sampling mask.", + ) + parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) + parser.add_argument( + "--algorithm", + type=str, + default="", + choices=ALGORITHMS, + help="Reconstruction algorithm applied to distorted OASIS k-space.", + ) + parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") + parser.add_argument( + "--sample_rate", + type=float, + default=1.0, + help="Fraction of slices per volume to include from the split CSV.", + ) + parser.add_argument("--volume_cache_size", type=int, default=2) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose output for reconstructors that support it.", + ) + return parser + + +def main() -> None: + """Run OASIS inference plots.""" + + args = build_parser().parse_args() + args.source = 
args.source.expanduser().resolve() + args.split_csv = args.split_csv.expanduser().resolve() + args.manifest = args.manifest.expanduser().resolve() + checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) + + device = dinv.utils.get_device() + dataset = OasisSliceDataset( + split_csv=args.split_csv, + data_path=args.source, + sample_rate=args.sample_rate, + cache_size=args.volume_cache_size, + ) + dataloader = DataLoader(dataset, batch_size=1, shuffle=False) + mask_func = MaskFunc( + center_fractions=[args.center_fraction], + accelerations=[args.acceleration], + ) + metrics = [choose_metric(m) for m in METRICS] + + for i, batch in enumerate(iter(dataloader)): + if i >= args.num_samples: + break + + x = batch["x"].to(device) + subject_id = batch["subject_id"][0] + slice_num = int(batch["slice_num"][0]) + mask = mask_func(x.shape, seed=tuple(map(ord, subject_id))).to(device) + mask_2d = mask.reshape(-1).view(1, -1).expand(x.shape[-2], x.shape[-1]) + + for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: + distortion = choose_distortion(distortion_name) + + physics_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + mask=mask_2d, + img_size=(1, 2, *x.shape[-2:]), + coil_maps=1, + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, + mask=mask_2d, + img_size=(1, 2, *x.shape[-2:]), + coil_maps=1, + device=device, + ) + + y = physics_clean(x) + y_distorted = physics(x) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}..." 
+ ) + + algo = choose_algorithm( + algo_name, + checkpoint_file=checkpoint_file, + img_size=x.shape[-2:], + device=device, + verbose=args.verbose, + ).to(device) + + x_uncorrected = algo(y_distorted, physics_clean) + x_corrected = algo(y_distorted, physics) + + dinv.utils.plot( + { + "Ground truth OASIS slice": x, + "Distorted ksp, CG recon": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join( + f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" + for m in metrics + ), + "\n".join( + f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" + for m in metrics + ), + ], + show=False, + close=True, + suptitle=( + f"Algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}" + ), + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + + print("done!") + + +if __name__ == "__main__": + main() diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index 07dec7d..a11a909 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -1,4 +1,9 @@ -from .deep import RAMReconstructor, DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor +from .deep import ( + RAMReconstructor, + DeepImagePriorReconstructor, + FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, +) from .classic import ( ZeroFilledReconstructor, ConjugateGradientReconstructor, diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 5d507c4..ba520f9 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -163,3 +163,106 @@ def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tenso out = self.model(x_in) * std + mu # (B, 1, H, W) return torch.cat([out, torch.zeros_like(out)], dim=1) + + +def 
_load_unet_checkpoint_state( + checkpoint_path: Path, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Load U-Net weights from a plain or Lightning-style checkpoint.""" + + try: + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + except TypeError: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + elif isinstance(checkpoint, dict): + state_dict = checkpoint + else: + raise ValueError(f"Unsupported checkpoint format in {checkpoint_path}.") + + if any(key.startswith("unet.") for key in state_dict): + return { + key[len("unet.") :]: value + for key, value in state_dict.items() + if key.startswith("unet.") + } + + if any(key.startswith("model.unet.") for key in state_dict): + return { + key[len("model.unet.") :]: value + for key, value in state_dict.items() + if key.startswith("model.unet.") + } + + if any(key.startswith("module.") for key in state_dict): + return { + key[len("module.") :]: value + for key, value in state_dict.items() + if key.startswith("module.") + } + + return state_dict + + +class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): + """ + Wrapper for a trained OASIS single-coil U-Net reconstruction model. + + The model reuses the repository's fastMRI-derived :class:`Unet` module, but + loads an OASIS checkpoint supplied by the caller. The forward pass converts + k-space to a zero-filled magnitude image, applies per-slice instance + normalization, runs the U-Net, then rescales the prediction back to the + adjoint-image intensity range. + + :param str checkpoint_file: Path to the trained OASIS U-Net checkpoint. + :param torch.device device: Device on which to run inference. 
+ """ + + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 32, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + + def __init__( + self, + checkpoint_file: str, + device: torch.device = None, + ) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + self.model = Unet(**self.UNET_KWARGS) + self.model.load_state_dict( + _load_unet_checkpoint_state(checkpoint_path, device), + strict=True, + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + """Reconstruct a magnitude image from measured k-space.""" + + x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) + mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu + + return torch.cat([out, torch.zeros_like(out)], dim=1) diff --git a/pyproject.toml b/pyproject.toml index d7874ec..ce0b587 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "fastmri>=0.3.0", "h5py>=3.16.0", "matplotlib>=3.9.0", + "nibabel>=5.3.2", "numpy>=2.4.3", "ptwt>=1.0.1", "pydicom>=3.0.1", diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index b331a6b..c2689ed 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -12,6 +12,7 @@ RAMReconstructor, DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, ) from mri_recon.reconstruction.classic import ( ZeroFilledReconstructor, @@ -118,3 +119,34 @@ def test_fastmri_singlecoil_unet_reconstructor(device, tmp_path, monkeypatch): assert x_hat.shape == x.shape assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) + + 
+def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, tmp_path): + """ + Test the OASIS UNet reconstructor with a Lightning-style checkpoint. + """ + weights_path = tmp_path / "oasis_unet_test_weights.ckpt" + state_dict = Unet(**OASISSinglecoilUnetReconstructor.UNET_KWARGS).state_dict() + torch.save( + {"state_dict": {f"unet.{key}": value for key, value in state_dict.items()}}, + weights_path, + ) + + x = dinv.utils.load_example( + "butterfly.png", img_size=(32, 32), grayscale=True, resize_mode="resize", device=device + ) + x = torch.cat([x, torch.zeros_like(x)], dim=1) + + model = OASISSinglecoilUnetReconstructor( + checkpoint_file=str(weights_path), + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=device + ) + + y = physics(x) + x_hat = model(y, physics) + + assert x_hat.shape == x.shape + assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) From cc62b053fcf1ece80a21de195daabc68dd54bfb1 Mon Sep 17 00:00:00 2001 From: Yuning Du Date: Thu, 7 May 2026 19:15:25 +0100 Subject: [PATCH 02/17] Add OASIS inference example --- README.md | 161 +- examples/OASIS_inference_plot.py | 1355 +++++++++-------- mri_recon/distortions/undersampling.py | 549 +++---- mri_recon/reconstruction/__init__.py | 28 +- mri_recon/reconstruction/deep.py | 536 +++---- tests/test_distortions.py | 1854 ++++++++++++------------ 6 files changed, 2305 insertions(+), 2178 deletions(-) diff --git a/README.md b/README.md index e606f6d..52e7456 100644 --- a/README.md +++ b/README.md @@ -1,86 +1,87 @@ -# mri_recon - -[![CI](https://github.com/MatthiasLen/mri_recon/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/MatthiasLen/mri_recon/actions/workflows/ci.yml) - -MRI reconstruction playground for the MRI Metrics project. 
- -## Implemented Reconstruction Algorithms - -| Class | Selector name | Family | Summary | -| --- | --- | --- | --- | -| `ZeroFilledReconstructor` | `zero-filled` | Direct baseline | Returns the adjoint reconstruction, i.e. the standard zero-filled inverse FFT baseline. | -| `ConjugateGradientReconstructor` | `conjugate-gradient` | Classical iterative | Uses `physics.A_dagger(...)` to solve the inverse problem with conjugate-gradient style least-squares reconstruction. | -| `TVPGDReconstructor` | `tv-pgd` | Variational iterative | Proximal gradient descent with an L2 data term and total-variation prior. | -| `WaveletFISTAReconstructor` | `wavelet-fista` | Variational iterative | FISTA with an L1 wavelet prior for sparse regularization in a wavelet basis. | -| `TVFISTAReconstructor` | `tv-fista` | Variational iterative | FISTA with total-variation regularization. | -| `TVPDHGReconstructor` | `tv-pdhg` | Variational iterative | Primal-dual hybrid gradient / Chambolle-Pock optimization with total-variation regularization. | -| `RAMReconstructor` | `ram` | Deep learning | Wrapper around the DeepInverse RAM model, with input normalization based on the adjoint reconstruction. | -| `DeepImagePriorReconstructor` | `dip` | Deep learning | Deep Image Prior reconstruction using an untrained convolutional decoder optimized at inference time. | -| `FastMRISinglecoilUnetReconstructor` | `unet` | Deep learning | Wrapper around the pretrained fastMRI single-coil U-Net, returning a magnitude-based reconstruction with a zero imaginary channel. | -| `OASISSinglecoilUnetReconstructor` | `oasis-unet` | Deep learning | Wrapper around a trained OASIS single-coil U-Net checkpoint, reusing the shared fastMRI-derived U-Net module. | - -## Implemented Distortions - -| Class | Selector name | Family | Summary | -| --- | --- | --- | --- | -| `BaseDistortion` | `None` | Identity | Leaves the k-space unchanged and serves as the no-distortion baseline. 
| -| `SelfAdjointMultiplicativeMaskDistortion` | `None` | Abstract base | Super class for self-adjoint distortions that apply a real-valued elementwise multiplicative mask; subclasses implement `_mask`. | -| `IsotropicResolutionReduction` | `Isotropic LP` | Resolution loss | Applies a circular low-pass mask in k-space to remove high frequencies isotropically. | -| `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Applies an axis-aligned rectangular low-pass mask with separate cutoffs along `kx` and `ky`. | -| `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Applies a circular low-pass mask with a raised-cosine transition band to soften the cutoff. | -| `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Applies a circular low-pass mask with a Kaiser transition band for adjustable cutoff smoothness. | -| `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates Cartesian acquisition undersampling with optional contiguous ACS center retention plus uniform-random, variable-density-random, or equispaced peripheral sampling. | -| `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | -| `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. | -| `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an off-center anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`. | -| `GaussianNoiseDistortion` | `Gaussian noise` | Noise | Adds independent zero-mean Gaussian noise to the stored real and imaginary k-space channels. 
| -| `TranslationMotionDistortion` | `Translation motion` | Motion | Applies a rigid in-plane translation as a unit-modulus phase ramp in k-space. | -| `RotationalMotionDistortion` | `Rotational motion` | Motion | Applies a rigid in-plane rotation about the image center by rotating the image-domain object and transforming back to k-space. | -| `SegmentedTranslationMotionDistortion` | `Segmented translation motion` | Motion | Splits Cartesian k-space into acquisition segments and applies a different translation phase ramp to each segment. | -| `PhaseEncodeGhostingDistortion` | `Phase-encode ghosting` | Ghosting | Applies periodic line-wise phase and magnitude inconsistency to create phase-encode ghost replicas. | - -## uv Environment Notes - -This project uses `uv.lock` and pins PyTorch through `uv` package indexes in `pyproject.toml`. - -On Windows and Linux, `uv sync` installs the CUDA 12.8 PyTorch wheels. On macOS, it falls back to CPU wheels. - -```bash -uv sync -uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" -``` - +# mri_recon + +[![CI](https://github.com/MatthiasLen/mri_recon/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/MatthiasLen/mri_recon/actions/workflows/ci.yml) + +MRI reconstruction playground for the MRI Metrics project. + +## Implemented Reconstruction Algorithms + +| Class | Selector name | Family | Summary | +| --- | --- | --- | --- | +| `ZeroFilledReconstructor` | `zero-filled` | Direct baseline | Returns the adjoint reconstruction, i.e. the standard zero-filled inverse FFT baseline. | +| `ConjugateGradientReconstructor` | `conjugate-gradient` | Classical iterative | Uses `physics.A_dagger(...)` to solve the inverse problem with conjugate-gradient style least-squares reconstruction. | +| `TVPGDReconstructor` | `tv-pgd` | Variational iterative | Proximal gradient descent with an L2 data term and total-variation prior. 
| +| `WaveletFISTAReconstructor` | `wavelet-fista` | Variational iterative | FISTA with an L1 wavelet prior for sparse regularization in a wavelet basis. | +| `TVFISTAReconstructor` | `tv-fista` | Variational iterative | FISTA with total-variation regularization. | +| `TVPDHGReconstructor` | `tv-pdhg` | Variational iterative | Primal-dual hybrid gradient / Chambolle-Pock optimization with total-variation regularization. | +| `RAMReconstructor` | `ram` | Deep learning | Wrapper around the DeepInverse RAM model, with input normalization based on the adjoint reconstruction. | +| `DeepImagePriorReconstructor` | `dip` | Deep learning | Deep Image Prior reconstruction using an untrained convolutional decoder optimized at inference time. | +| `FastMRISinglecoilUnetReconstructor` | `unet` | Deep learning | Wrapper around the pretrained fastMRI single-coil U-Net, returning a magnitude-based reconstruction with a zero imaginary channel. | +| `OASISSinglecoilUnetReconstructor` | `oasis-unet` | Deep learning | Wrapper around a trained OASIS single-coil U-Net checkpoint, reusing the shared fastMRI-derived U-Net module. | + +## Implemented Distortions + +| Class | Selector name | Family | Summary | +| --- | --- | --- | --- | +| `BaseDistortion` | `None` | Identity | Leaves the k-space unchanged and serves as the no-distortion baseline. | +| `SelfAdjointMultiplicativeMaskDistortion` | `None` | Abstract base | Super class for self-adjoint distortions that apply a real-valued elementwise multiplicative mask; subclasses implement `_mask`. | +| `IsotropicResolutionReduction` | `Isotropic LP` | Resolution loss | Applies a circular low-pass mask in k-space to remove high frequencies isotropically. | +| `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Applies an axis-aligned rectangular low-pass mask with separate cutoffs along `kx` and `ky`. 
| +| `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Applies a circular low-pass mask with a raised-cosine transition band to soften the cutoff. | +| `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Applies a circular low-pass mask with a Kaiser transition band for adjustable cutoff smoothness. | +| `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates Cartesian acquisition undersampling with optional contiguous ACS center retention plus uniform-random, variable-density-random, or equispaced peripheral sampling. | +| `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | +| `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. | +| `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an off-center anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`. | +| `GaussianNoiseDistortion` | `Gaussian noise` | Noise | Adds independent zero-mean Gaussian noise to the stored real and imaginary k-space channels. | +| `TranslationMotionDistortion` | `Translation motion` | Motion | Applies a rigid in-plane translation as a unit-modulus phase ramp in k-space. | +| `RotationalMotionDistortion` | `Rotational motion` | Motion | Applies a rigid in-plane rotation about the image center by rotating the image-domain object and transforming back to k-space. | +| `SegmentedTranslationMotionDistortion` | `Segmented translation motion` | Motion | Splits Cartesian k-space into acquisition segments and applies a different translation phase ramp to each segment. 
| +| `PhaseEncodeGhostingDistortion` | `Phase-encode ghosting` | Ghosting | Applies periodic line-wise phase and magnitude inconsistency to create phase-encode ghost replicas. | + +## uv Environment Notes + +This project uses `uv.lock` and pins PyTorch through `uv` package indexes in `pyproject.toml`. + +On Windows and Linux, `uv sync` installs the CUDA 12.8 PyTorch wheels. On macOS, it falls back to CPU wheels. + +```bash +uv sync +uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" +``` + ## OASIS Inference Example Run the OASIS plotting example with a local OASIS root folder. By default, it uses the packaged split and checkpoint manifest under `reconstruction_only/`. +Download the `reconstruction_only/` folder, including data splits and checkpoints, from [Google Drive](https://drive.google.com/drive/folders/1YPmjiQxy3odiUq8gwYqwOGhoXbGRcSXp?usp=drive_link). ```bash python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data --acceleration 4 -``` - -Pass `--checkpoint /path/to/checkpoint.ckpt` to use a different trained OASIS U-Net checkpoint. - -## Pre-commit - -Install the local tooling and register the git hook: - -```bash -uv sync -uv run pre-commit install -``` - -Run the hook suite manually across the repository: - -```bash -uv run pre-commit run --all-files -``` - -GitHub Actions runs the same `pre-commit` command in CI and also runs the test suite with `uv run pytest`. - -## Contributing - -1. **Pre-commit hooks** – install and run them before pushing (see [Pre-commit](#pre-commit) above). CI enforces the same checks. -2. **Docstrings** – add a NumPy-style docstring to every public function, method, and class. Include a one-line summary, `Parameters`, and `Returns` sections where applicable. -3. **README updates** – if you add a new reconstructor or distortion, append a row to the corresponding table. Keep descriptions concise (one sentence). -4. 
**Tests** – add or update tests under `tests/` for any new behaviour. Run the full suite with `uv run pytest` before opening a PR. -5. **Branching** – open a feature branch, keep commits focused, and open a pull request against `main`. +``` + +Pass `--checkpoint /path/to/checkpoint.ckpt` to use a different trained OASIS U-Net checkpoint. + +## Pre-commit + +Install the local tooling and register the git hook: + +```bash +uv sync +uv run pre-commit install +``` + +Run the hook suite manually across the repository: + +```bash +uv run pre-commit run --all-files +``` + +GitHub Actions runs the same `pre-commit` command in CI and also runs the test suite with `uv run pytest`. + +## Contributing + +1. **Pre-commit hooks** – install and run them before pushing (see [Pre-commit](#pre-commit) above). CI enforces the same checks. +2. **Docstrings** – add a NumPy-style docstring to every public function, method, and class. Include a one-line summary, `Parameters`, and `Returns` sections where applicable. +3. **README updates** – if you add a new reconstructor or distortion, append a row to the corresponding table. Keep descriptions concise (one sentence). +4. **Tests** – add or update tests under `tests/` for any new behaviour. Run the full suite with `uv run pytest` before opening a PR. +5. **Branching** – open a feature branch, keep commits focused, and open a pull request against `main`. diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index 7d6643e..1523d10 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -1,623 +1,732 @@ -"""Inference OASIS reconstructors for k-space distortion operators. 
- -Usage: - python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data -""" - -from __future__ import annotations - -import argparse -import contextlib -import json -import os -import sys -from collections import OrderedDict -from pathlib import Path -from typing import Optional, Sequence, Union - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import matplotlib.pyplot as plt -import numpy as np -import torch -import deepinv as dinv -from torch.utils.data import DataLoader, Dataset - -try: - import nibabel as nib -except ImportError as exc: - raise ImportError( - "The OASIS example requires nibabel. Install the project dependencies " - "or add nibabel to your environment before running this script." - ) from exc - -from mri_recon.distortions import ( - AnisotropicResolutionReduction, - BaseDistortion, - DistortedKspaceMultiCoilMRI, - GaussianKspaceBiasField, - GaussianNoiseDistortion, - HannTaperResolutionReduction, - IsotropicResolutionReduction, - KaiserTaperResolutionReduction, - OffCenterAnisotropicGaussianKspaceBiasField, - PhaseEncodeGhostingDistortion, - RadialHighPassEmphasisDistortion, - RotationalMotionDistortion, - SegmentedTranslationMotionDistortion, - TranslationMotionDistortion, -) -from mri_recon.reconstruction import ( - ConjugateGradientReconstructor, - DeepImagePriorReconstructor, - OASISSinglecoilUnetReconstructor, - RAMReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - ZeroFilledReconstructor, -) - -REPO_ROOT = Path(__file__).resolve().parents[1] -REPORT_DIR = Path("reports") / "oasis_inference_plot" -DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" -DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" - -REPORT_DIR.mkdir(parents=True, exist_ok=True) - -ALGORITHMS = [ - # "zero-filled", - # "conjugate-gradient", - # "ram", - # "dip", - 
"tv-pgd", - # "wavelet-fista", - # "tv-fista", - # "tv-pdhg", - "oasis-unet", -] -DISTORTIONS = [ - # "Phase-encode ghosting", - # "Segmented translation motion", - # "Translation motion", - # "Rotational motion", - # "Off-center anisotropic Gaussian bias field", - # "Gaussian bias field", - # "Anisotropic LP", - # "Hann taper LP", - # "Kaiser taper LP", - "Radial high-pass emphasis", - # "Gaussian noise", - # "Isotropic LP", -] -METRICS = [ - "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", -] - - -@contextlib.contextmanager -def temp_seed( - rng: np.random.RandomState, - seed: Optional[Union[int, tuple[int, ...]]] = None, -): - """Temporarily set a NumPy random seed.""" - - if seed is None: - yield - return - - state = rng.get_state() - rng.seed(seed) - try: - yield - finally: - rng.set_state(state) - - -class MaskFunc: - """Random Cartesian undersampling mask matching the packaged OASIS checkpoints.""" - - def __init__( - self, - center_fractions: Sequence[float], - accelerations: Sequence[int], - seed: Optional[int] = None, - ) -> None: - if len(center_fractions) != len(accelerations): - raise ValueError("center_fractions and accelerations must have the same length.") - - self.center_fractions = list(center_fractions) - self.accelerations = list(accelerations) - self.rng = np.random.RandomState(seed) - - def __call__( - self, - shape: Sequence[int], - seed: Optional[Union[int, tuple[int, ...]]] = None, - ) -> torch.Tensor: - """Create a broadcastable mask for k-space shaped ``(..., H, W)``.""" - - if len(shape) < 2: - raise ValueError("Mask shape must have at least two dimensions.") - - with temp_seed(self.rng, seed): - center_fraction = self.rng.choice(self.center_fractions) - acceleration = self.rng.choice(self.accelerations) - num_cols = shape[-1] - num_low_freqs = round(num_cols * center_fraction) - - center_mask = np.zeros(num_cols, dtype=np.float32) - pad = (num_cols - num_low_freqs) // 2 - center_mask[pad : pad + 
num_low_freqs] = 1 - - accel_prob = (num_cols / acceleration - num_low_freqs) / ( - num_cols - num_low_freqs - ) - accel_mask = self.rng.uniform(size=num_cols) < accel_prob - - mask = np.maximum(center_mask, accel_mask.astype(np.float32)) - mask_shape = [1 for _ in shape] - mask_shape[-1] = num_cols - return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) - - -class OasisSliceDataset(Dataset): - """Load 2D OASIS slices from Analyze/NIfTI volumes listed in a split CSV.""" - - def __init__( - self, - split_csv: Path, - data_path: Path, - sample_rate: float = 1.0, - cache_size: int = 2, - ) -> None: - self.split_csv = Path(split_csv) - self.data_path = Path(data_path) - if not 0 < sample_rate <= 1.0: - raise ValueError("sample_rate must be in the range (0, 1].") - self.sample_rate = sample_rate - self.cache_size = max(0, cache_size) - self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() - self.raw_samples = self._create_sample_list() - - def __len__(self) -> int: - """Return the number of available slices.""" - - return len(self.raw_samples) - - def __getitem__(self, index: int) -> dict[str, object]: - """Return one complex-valued OASIS slice in repo tensor convention.""" - - subject_id, slice_num = self.raw_samples[index] - target_np = self._read_raw_slice(subject_id, slice_num) - real = torch.from_numpy(target_np) - x = torch.stack([real, torch.zeros_like(real)], dim=0) - return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} - - def _create_sample_list(self) -> list[tuple[str, int]]: - samples: list[tuple[str, int]] = [] - with self.split_csv.open("r", encoding="utf-8") as handle: - for line in handle: - row = [item.strip() for item in line.split(",")] - if not row or not row[0]: - continue - try: - total_slices = int(row[-1]) - except ValueError: - continue - - subject_id = row[0] - if self.sample_rate >= 1.0: - start = 0 - stop = total_slices - else: - mid = round(total_slices / 2) - half_span = round(total_slices * 
self.sample_rate / 2) - start = max(0, mid - half_span) - stop = min(total_slices, mid + half_span) - - for slice_num in range(start, stop): - samples.append((subject_id, slice_num)) - return samples - - def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: - volume = self._get_volume(subject_id) - return np.ascontiguousarray(volume[slice_num], dtype=np.float32) - - def _get_volume(self, subject_id: str) -> np.ndarray: - if self.cache_size > 0 and subject_id in self._volume_cache: - self._volume_cache.move_to_end(subject_id) - return self._volume_cache[subject_id] - - image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" - matches = sorted(image_glob.glob("*t88_gfc.img")) - if not matches: - raise FileNotFoundError( - f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." - ) - - image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) - volume = np.ascontiguousarray( - np.transpose(np.squeeze(image_data), (1, 0, 2)), - dtype=np.float32, - ) - - if self.cache_size > 0: - self._volume_cache[subject_id] = volume - if len(self._volume_cache) > self.cache_size: - self._volume_cache.popitem(last=False) - - return volume - - -def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to log-magnitude image for visualization.""" - - if kspace.ndim == 4: - kspace = kspace[0] - if kspace.ndim != 3 or kspace.shape[0] != 2: - raise ValueError( - f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" - ) - - kspace = kspace.detach().cpu() - kspace_complex = torch.view_as_complex(kspace.permute(1, 2, 0).contiguous()) - magnitude = torch.log1p(torch.abs(kspace_complex)) - - lower = torch.quantile(magnitude, 0.05) - upper = torch.quantile(magnitude, 0.995) - if float(upper) > float(lower): - magnitude = magnitude.clamp(lower, upper) - magnitude = (magnitude - lower) / (upper - lower) - else: - mag_max = float(magnitude.max()) - if mag_max > 
0.0: - magnitude = magnitude / mag_max - - return torch.sqrt(magnitude) - - -def save_kspace_plot( - y_clean: torch.Tensor, - y_distorted: torch.Tensor, - save_fn: Path, - distortion_name: str, -) -> None: - """Save clean and distorted k-space magnitude plots.""" - - images = [ - ("Original k-space", _kspace_to_log_magnitude(y_clean)), - ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), - ] - - fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) - fig.suptitle(f"Distortion: {distortion_name}") - for ax, (title, image) in zip(axes, images, strict=True): - ax.imshow(image.numpy(), cmap="magma") - ax.set_title(title) - ax.axis("off") - fig.savefig(save_fn, dpi=200, bbox_inches="tight") - plt.close(fig) - - -def resolve_oasis_checkpoint( - checkpoint: Optional[Path], - acceleration: int, - manifest_path: Path, -) -> Path: - """Resolve an explicit or packaged OASIS checkpoint path.""" - - if checkpoint is not None: - return checkpoint.expanduser().resolve() - - with manifest_path.open("r", encoding="utf-8") as handle: - manifest = json.load(handle) - - checkpoints = manifest.get("checkpoints", {}) - key = str(acceleration) - if key not in checkpoints: - available = ", ".join(sorted(checkpoints)) - raise ValueError( - f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." 
- ) - - filename = Path(checkpoints[key]["filename"]) - if filename.is_absolute(): - return filename - return (manifest_path.parent.parent / filename).resolve() - - -def choose_algorithm( - name: str, - checkpoint_file: Path, - img_size: tuple = (640, 368), - device: torch.device = "cpu", - verbose: bool = False, -) -> dinv.models.Reconstructor: - """Construct a reconstructor by selector name.""" - - match name: - case "zero-filled": - return ZeroFilledReconstructor() - case "conjugate-gradient": - return ConjugateGradientReconstructor(max_iter=20) - case "ram": - return RAMReconstructor(default_sigma=0.05, device=device) - case "dip": - return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) - case "tv-pgd": - return TVPGDReconstructor(n_iter=100, verbose=verbose) - case "tv-fista": - return TVFISTAReconstructor(n_iter=200, verbose=verbose) - case "tv-pdhg": - return TVPDHGReconstructor(n_iter=100, verbose=verbose) - case "wavelet-fista": - return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) - case "oasis-unet" | "unet": - return OASISSinglecoilUnetReconstructor( - checkpoint_file=str(checkpoint_file), - device=device, - ) - case _: - raise ValueError(f"Unknown algorithm {name!r}") - - -def choose_distortion(name: str) -> BaseDistortion: - """Construct a k-space distortion by display name.""" - - match name: - case "Phase-encode ghosting": - return PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=1.0, - ) - case "Anisotropic LP": - return AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=0.25, - ) - case "Hann taper LP": - return HannTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - ) - case "Kaiser taper LP": - return KaiserTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - beta=8.6, - ) - case "Radial high-pass emphasis": - return 
RadialHighPassEmphasisDistortion(alpha=0.4) - case "Isotropic LP": - return IsotropicResolutionReduction(radius_fraction=0.1) - case "Off-center anisotropic Gaussian bias field": - return OffCenterAnisotropicGaussianKspaceBiasField( - width_x_fraction=0.2, - width_y_fraction=0.35, - center_x_fraction=0.15, - center_y_fraction=-0.1, - edge_gain=0.3, - ) - case "Translation motion": - return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) - case "Rotational motion": - return RotationalMotionDistortion(angle_radians=torch.pi / 6) - case "Segmented translation motion": - return SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 20.0, 50.0, -50.0), - shift_y_pixels=(0.0, 10.0, -20.0, 20.0), - ) - case "Gaussian bias field": - return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) - case "Gaussian noise": - return GaussianNoiseDistortion(sigma=0.00001) - case _: - raise ValueError(f"Unknown distortion {name!r}") - - -def choose_metric(name: str) -> dinv.metric.Metric: - """Construct a DeepInverse metric by selector name.""" - - match name: - case "PSNR": - return dinv.metric.PSNR(max_pixel=None, complex_abs=True) - case "NMSE": - return dinv.metric.NMSE(complex_abs=True) - case "SSIM": - return dinv.metric.SSIM(max_pixel=None, complex_abs=True) - case "HaarPSI": - return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) - case "BlurStrength": - return dinv.metric.BlurStrength(complex_abs=True) - case "SharpnessIndex": - return dinv.metric.SharpnessIndex(complex_abs=True) - case _: - raise ValueError(f"Unknown metric {name!r}") - - -def build_parser() -> argparse.ArgumentParser: - """Build the command-line parser.""" - - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--source", - type=Path, - required=True, - help="OASIS root directory containing subject folders.", - ) - parser.add_argument( - "--split_csv", - type=Path, - default=DEFAULT_SPLIT_CSV, - help="CSV listing OASIS subjects and 
slice counts.", - ) - parser.add_argument( - "--manifest", - type=Path, - default=DEFAULT_MANIFEST_PATH, - help="Checkpoint manifest JSON.", - ) - parser.add_argument( - "--checkpoint", - type=Path, - default=None, - help="Explicit OASIS U-Net checkpoint. Overrides --acceleration.", - ) - parser.add_argument( - "--acceleration", - type=int, - default=4, - help="Packaged OASIS checkpoint acceleration factor.", - ) - parser.add_argument( - "--center_fraction", - type=float, - default=0.08, - help="Center fraction used by the random Cartesian sampling mask.", - ) - parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) - parser.add_argument( - "--algorithm", - type=str, - default="", - choices=ALGORITHMS, - help="Reconstruction algorithm applied to distorted OASIS k-space.", - ) - parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") - parser.add_argument( - "--sample_rate", - type=float, - default=1.0, - help="Fraction of slices per volume to include from the split CSV.", - ) - parser.add_argument("--volume_cache_size", type=int, default=2) - parser.add_argument( - "--verbose", - action="store_true", - help="Enable verbose output for reconstructors that support it.", - ) - return parser - - -def main() -> None: - """Run OASIS inference plots.""" - - args = build_parser().parse_args() - args.source = args.source.expanduser().resolve() - args.split_csv = args.split_csv.expanduser().resolve() - args.manifest = args.manifest.expanduser().resolve() - checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) - - device = dinv.utils.get_device() - dataset = OasisSliceDataset( - split_csv=args.split_csv, - data_path=args.source, - sample_rate=args.sample_rate, - cache_size=args.volume_cache_size, - ) - dataloader = DataLoader(dataset, batch_size=1, shuffle=False) - mask_func = MaskFunc( - center_fractions=[args.center_fraction], - accelerations=[args.acceleration], - ) - 
metrics = [choose_metric(m) for m in METRICS] - - for i, batch in enumerate(iter(dataloader)): - if i >= args.num_samples: - break - - x = batch["x"].to(device) - subject_id = batch["subject_id"][0] - slice_num = int(batch["slice_num"][0]) - mask = mask_func(x.shape, seed=tuple(map(ord, subject_id))).to(device) - mask_2d = mask.reshape(-1).view(1, -1).expand(x.shape[-2], x.shape[-1]) - - for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: - distortion = choose_distortion(distortion_name) - - physics_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), - mask=mask_2d, - img_size=(1, 2, *x.shape[-2:]), - coil_maps=1, - device=device, - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, - mask=mask_2d, - img_size=(1, 2, *x.shape[-2:]), - coil_maps=1, - device=device, - ) - - y = physics_clean(x) - y_distorted = physics(x) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", - distortion_name, - ) - - for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: - print( - f"Evaluating algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}..." 
- ) - - algo = choose_algorithm( - algo_name, - checkpoint_file=checkpoint_file, - img_size=x.shape[-2:], - device=device, - verbose=args.verbose, - ).to(device) - - x_uncorrected = algo(y_distorted, physics_clean) - x_corrected = algo(y_distorted, physics) - - dinv.utils.plot( - { - "Ground truth OASIS slice": x, - "Distorted ksp, CG recon": x_distorted, - f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, - f"Distorted ksp, {algo_name} recon, corrected": x_corrected, - }, - subtitles=[ - "", - "", - "\n".join( - f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" - for m in metrics - ), - "\n".join( - f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" - for m in metrics - ), - ], - show=False, - close=True, - suptitle=( - f"Algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}" - ), - save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", - fontsize=3, - ) - - print("done!") - - -if __name__ == "__main__": - main() +"""Inference OASIS reconstructors for k-space distortion operators. + +Usage: + python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from collections import OrderedDict +from pathlib import Path +from typing import Optional + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +import numpy as np +import torch +import deepinv as dinv +from torch.utils.data import DataLoader, Dataset + +try: + import nibabel as nib +except ImportError as exc: + raise ImportError( + "The OASIS example requires nibabel. Install the project dependencies " + "or add nibabel to your environment before running this script." 
+ ) from exc + +from mri_recon.distortions import ( + AnisotropicResolutionReduction, + BaseDistortion, + CartesianUndersampling, + DistortedKspaceMultiCoilMRI, + GaussianKspaceBiasField, + GaussianNoiseDistortion, + HannTaperResolutionReduction, + IsotropicResolutionReduction, + KaiserTaperResolutionReduction, + OffCenterAnisotropicGaussianKspaceBiasField, + PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, + RotationalMotionDistortion, + SegmentedTranslationMotionDistortion, + TranslationMotionDistortion, +) +from mri_recon.reconstruction import ( + ConjugateGradientReconstructor, + DeepImagePriorReconstructor, + OASISSinglecoilUnetReconstructor, + RAMReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + ZeroFilledReconstructor, +) + +REPO_ROOT = Path(__file__).resolve().parents[1] +REPORT_DIR = Path("reports") / "oasis_inference_plot" +DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" +DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" + +REPORT_DIR.mkdir(parents=True, exist_ok=True) + +ALGORITHMS = [ + # "zero-filled", + # "conjugate-gradient", + # "ram", + # "dip", + "tv-pgd", + # "wavelet-fista", + # "tv-fista", + # "tv-pdhg", + "oasis-unet", + "unet", +] +DISTORTIONS = [ + "None", + "Cartesian undersampling (variable density)", + "Cartesian undersampling (uniform random)", + "Cartesian undersampling (uniform random, zero ACS)", + "Cartesian undersampling (equispaced)", + "Cartesian undersampling (equispaced, zero ACS)", + "Phase-encode ghosting", + "Segmented translation motion", + "Translation motion", + "Rotational motion", + "Off-center anisotropic Gaussian bias field", + "Gaussian bias field", + "Anisotropic LP", + "Hann taper LP", + "Kaiser taper LP", + "Radial high-pass emphasis", + "Gaussian noise", + "Isotropic LP", +] +METRICS = [ + "PSNR", + "NMSE", + "SSIM", + "HaarPSI", + 
"SharpnessIndex", + "BlurStrength", +] + + +class OasisSliceDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes. + + Parameters + ---------- + split_csv : Path + CSV file listing OASIS subjects and slice counts. + data_path : Path + Root directory containing OASIS subject folders. + sample_rate : float, optional + Fraction of slices to include from each volume. + cache_size : int, optional + Number of loaded volumes to keep in memory. + """ + + def __init__( + self, + split_csv: Path, + data_path: Path, + sample_rate: float = 1.0, + cache_size: int = 2, + ) -> None: + self.split_csv = Path(split_csv) + self.data_path = Path(data_path) + if not 0 < sample_rate <= 1.0: + raise ValueError("sample_rate must be in the range (0, 1].") + self.sample_rate = sample_rate + self.cache_size = max(0, cache_size) + self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() + self.raw_samples = self._create_sample_list() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.raw_samples) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + subject_id, slice_num = self.raw_samples[index] + target_np = self._read_raw_slice(subject_id, slice_num) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _create_sample_list(self) -> list[tuple[str, int]]: + samples: list[tuple[str, int]] = [] + with self.split_csv.open("r", encoding="utf-8") as handle: + for line in handle: + row = [item.strip() for item in line.split(",")] + if not row or not row[0]: + continue + try: + total_slices = int(row[-1]) + except ValueError: + continue + + subject_id = row[0] + if self.sample_rate >= 1.0: + start = 0 + stop = total_slices + else: + mid = round(total_slices / 2) + half_span = round(total_slices * self.sample_rate / 2) + 
start = max(0, mid - half_span) + stop = min(total_slices, mid + half_span) + + for slice_num in range(start, stop): + samples.append((subject_id, slice_num)) + return samples + + def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: + volume = self._get_volume(subject_id) + return np.ascontiguousarray(volume[slice_num], dtype=np.float32) + + def _get_volume(self, subject_id: str) -> np.ndarray: + if self.cache_size > 0 and subject_id in self._volume_cache: + self._volume_cache.move_to_end(subject_id) + return self._volume_cache[subject_id] + + image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if not matches: + raise FileNotFoundError( + f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." + ) + + image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + + if self.cache_size > 0: + self._volume_cache[subject_id] = volume + if len(self._volume_cache) > self.cache_size: + self._volume_cache.popitem(last=False) + + return volume + + +def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: + """Convert k-space tensor to log-magnitude image for visualization.""" + + kspace = kspace.detach().cpu() + if kspace.ndim == 5: + kspace = kspace[0] + if kspace.ndim == 4: + if kspace.shape[0] == 1 and kspace.shape[1] == 2: + kspace = kspace[0] + elif kspace.shape[0] != 2: + raise ValueError( + "Expected k-space with shape (2, H, W), (1, 2, H, W), " + f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" + ) + if kspace.ndim != 3 and kspace.ndim != 4: + raise ValueError( + "Expected k-space with shape (2, H, W), (1, 2, H, W), " + f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" + ) + if kspace.shape[0] != 2: + raise ValueError(f"Expected real/imaginary channel first, got {tuple(kspace.shape)}") + + kspace_complex = 
torch.view_as_complex(torch.movedim(kspace, 0, -1).contiguous()) + magnitude = torch.abs(kspace_complex) + if magnitude.ndim == 3: + magnitude = torch.sqrt(torch.sum(magnitude.square(), dim=0)) + magnitude = torch.log1p(magnitude) + + lower = torch.quantile(magnitude, 0.05) + upper = torch.quantile(magnitude, 0.995) + if float(upper) > float(lower): + magnitude = magnitude.clamp(lower, upper) + magnitude = (magnitude - lower) / (upper - lower) + else: + mag_max = float(magnitude.max()) + if mag_max > 0.0: + magnitude = magnitude / mag_max + + return torch.sqrt(magnitude) + + +def save_kspace_plot( + y_clean: torch.Tensor, + y_distorted: torch.Tensor, + save_fn: Path, + distortion_name: str, +) -> None: + """Save clean and distorted k-space magnitude plots.""" + + images = [ + ("Original k-space", _kspace_to_log_magnitude(y_clean)), + ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), + ] + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + fig.suptitle(f"Distortion: {distortion_name}") + for ax, (title, image) in zip(axes, images, strict=True): + ax.imshow(image.numpy(), cmap="magma") + ax.set_title(title) + ax.axis("off") + fig.savefig(save_fn, dpi=200, bbox_inches="tight") + plt.close(fig) + + +def image_to_kspace(x: torch.Tensor) -> torch.Tensor: + """Convert channel-first complex images to centered k-space.""" + + x_complex = torch.view_as_complex(x.movedim(1, -1).contiguous()) + y_complex = torch.fft.fftshift( + torch.fft.fft2(x_complex, dim=(-2, -1), norm="ortho"), + dim=(-2, -1), + ) + return torch.view_as_real(y_complex).movedim(-1, 1).contiguous() + + +def kspace_to_image(y: torch.Tensor) -> torch.Tensor: + """Convert centered channel-first k-space to complex images.""" + + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + x_complex = torch.fft.ifft2( + torch.fft.ifftshift(y_complex, dim=(-2, -1)), + dim=(-2, -1), + norm="ortho", + ) + return torch.view_as_real(x_complex).movedim(-1, 1).contiguous() + + 
+class OasisCenteredFFTPhysics: + """Physics adapter matching the OASIS U-Net FFT convention. + + Parameters + ---------- + distortion : BaseDistortion + K-space distortion applied after the centered FFT. + """ + + def __init__(self, distortion: BaseDistortion) -> None: + self.distortion = distortion + + def A(self, x: torch.Tensor) -> torch.Tensor: + """Apply centered FFT and k-space distortion. + + Parameters + ---------- + x : torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Distorted centered k-space tensor. + """ + + return self.distortion.A(image_to_kspace(x)) + + def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: + """Apply adjoint distortion and centered inverse FFT. + + Parameters + ---------- + y : torch.Tensor + Distorted centered k-space tensor. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + """ + + return kspace_to_image(self.distortion.A_adjoint(y)) + + +def resolve_oasis_checkpoint( + checkpoint: Optional[Path], + acceleration: int, + manifest_path: Path, +) -> Path: + """Resolve an explicit or packaged OASIS checkpoint path. + + Parameters + ---------- + checkpoint : Path or None + User-provided checkpoint path. + acceleration : int + Acceleration key used when loading from the manifest. + manifest_path : Path + JSON manifest with packaged checkpoint metadata. + + Returns + ------- + Path + Resolved checkpoint path. + """ + + if checkpoint is not None: + return checkpoint.expanduser().resolve() + + with manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + checkpoints = manifest.get("checkpoints", {}) + key = str(acceleration) + if key not in checkpoints: + available = ", ".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." 
+ ) + + filename = Path(checkpoints[key]["filename"]) + if filename.is_absolute(): + return filename + return (manifest_path.parent.parent / filename).resolve() + + +def choose_algorithm( + name: str, + checkpoint_file: Path, + img_size: tuple = (640, 368), + device: torch.device = "cpu", + verbose: bool = False, +) -> dinv.models.Reconstructor: + """Construct a reconstructor by selector name.""" + + match name: + case "zero-filled": + return ZeroFilledReconstructor() + case "conjugate-gradient": + return ConjugateGradientReconstructor(max_iter=20) + case "ram": + return RAMReconstructor(default_sigma=0.05, device=device) + case "dip": + return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) + case "tv-pgd": + return TVPGDReconstructor(n_iter=100, verbose=verbose) + case "tv-fista": + return TVFISTAReconstructor(n_iter=200, verbose=verbose) + case "tv-pdhg": + return TVPDHGReconstructor(n_iter=100, verbose=verbose) + case "wavelet-fista": + return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) + case "oasis-unet" | "unet": + return OASISSinglecoilUnetReconstructor( + checkpoint_file=str(checkpoint_file), + device=device, + ) + case _: + raise ValueError(f"Unknown algorithm {name!r}") + + +def choose_distortion( + name: str, + acceleration: int, + center_fraction: float, +) -> BaseDistortion: + """Construct a k-space distortion by display name.""" + + match name: + case "None": + return BaseDistortion() + case "Cartesian undersampling (variable density)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=center_fraction, + pattern="variable_density_random", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (uniform random)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=center_fraction, + pattern="uniform_random", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (uniform random, zero ACS)": + return CartesianUndersampling( 
+ keep_fraction=1.0 / acceleration, + center_fraction=0.0, + pattern="uniform_random", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (equispaced)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=center_fraction, + pattern="equispaced", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (equispaced, zero ACS)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=0.0, + pattern="equispaced", + axis=-1, + seed=42, + ) + case "Phase-encode ghosting": + return PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + case "Anisotropic LP": + return AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.25, + ) + case "Hann taper LP": + return HannTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + ) + case "Kaiser taper LP": + return KaiserTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + beta=8.6, + ) + case "Radial high-pass emphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) + case "Isotropic LP": + return IsotropicResolutionReduction(radius_fraction=0.1) + case "Off-center anisotropic Gaussian bias field": + return OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.2, + width_y_fraction=0.35, + center_x_fraction=0.15, + center_y_fraction=-0.1, + edge_gain=0.3, + ) + case "Translation motion": + return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) + case "Rotational motion": + return RotationalMotionDistortion(angle_radians=torch.pi / 6) + case "Segmented translation motion": + return SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 20.0, 50.0, -50.0), + shift_y_pixels=(0.0, 10.0, -20.0, 20.0), + ) + case "Gaussian bias field": + return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + case "Gaussian noise": + return 
GaussianNoiseDistortion(sigma=0.00001) + case _: + raise ValueError(f"Unknown distortion {name!r}") + + +def choose_metric(name: str) -> dinv.metric.Metric: + """Construct a DeepInverse metric by selector name.""" + + match name: + case "PSNR": + return dinv.metric.PSNR(max_pixel=None, complex_abs=True) + case "NMSE": + return dinv.metric.NMSE(complex_abs=True) + case "SSIM": + return dinv.metric.SSIM(max_pixel=None, complex_abs=True) + case "HaarPSI": + return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) + case "BlurStrength": + return dinv.metric.BlurStrength(complex_abs=True) + case "SharpnessIndex": + return dinv.metric.SharpnessIndex(complex_abs=True) + case _: + raise ValueError(f"Unknown metric {name!r}") + + +def build_parser() -> argparse.ArgumentParser: + """Build the command-line parser.""" + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--source", + type=Path, + required=True, + help="OASIS root directory containing subject folders.", + ) + parser.add_argument( + "--split_csv", + type=Path, + default=DEFAULT_SPLIT_CSV, + help="CSV listing OASIS subjects and slice counts.", + ) + parser.add_argument( + "--manifest", + type=Path, + default=DEFAULT_MANIFEST_PATH, + help="Checkpoint manifest JSON.", + ) + parser.add_argument( + "--checkpoint", + type=Path, + default=None, + help="Explicit OASIS U-Net checkpoint. 
Overrides --acceleration.", + ) + parser.add_argument( + "--acceleration", + type=int, + default=4, + help="Packaged OASIS checkpoint acceleration factor.", + ) + parser.add_argument( + "--center_fraction", + type=float, + default=0.08, + help="Center fraction used by the Cartesian undersampling distortion.", + ) + parser.add_argument( + "--distortion", + type=str, + default="Cartesian undersampling (uniform random)", + choices=DISTORTIONS, + ) + parser.add_argument( + "--algorithm", + type=str, + default="unet", + choices=ALGORITHMS, + help="Reconstruction algorithm applied to distorted OASIS k-space.", + ) + parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") + parser.add_argument( + "--sample_rate", + type=float, + default=0.6, + help="Fraction of slices per volume to include from the split CSV.", + ) + parser.add_argument("--volume_cache_size", type=int, default=2) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose output for reconstructors that support it.", + ) + return parser + + +def main() -> None: + """Run OASIS inference plots.""" + + args = build_parser().parse_args() + args.source = args.source.expanduser().resolve() + args.split_csv = args.split_csv.expanduser().resolve() + args.manifest = args.manifest.expanduser().resolve() + checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) + + device = dinv.utils.get_device() + dataset = OasisSliceDataset( + split_csv=args.split_csv, + data_path=args.source, + sample_rate=args.sample_rate, + cache_size=args.volume_cache_size, + ) + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + metrics = [choose_metric(m) for m in METRICS] + + for i, batch in enumerate(iter(dataloader)): + if i >= args.num_samples: + break + + x = batch["x"].to(device) + subject_id = batch["subject_id"][0] + slice_num = int(batch["slice_num"][0]) + + for distortion_name in [args.distortion]: + distortion = 
choose_distortion( + distortion_name, + acceleration=args.acceleration, + center_fraction=args.center_fraction, + ) + + physics_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + img_size=(1, 2, *x.shape[-2:]), + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, + img_size=(1, 2, *x.shape[-2:]), + device=device, + ) + oasis_physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) + oasis_physics = OasisCenteredFFTPhysics(distortion) + + y = image_to_kspace(x) + y_distorted = distortion.A(y) + y_physics_distorted = physics(x) + x_distorted = kspace_to_image(y_distorted) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}..." + ) + + algo = choose_algorithm( + algo_name, + checkpoint_file=checkpoint_file, + img_size=x.shape[-2:], + device=device, + verbose=args.verbose, + ).to(device) + + if algo_name in {"oasis-unet", "unet"}: + y_eval = y_distorted + eval_physics_clean = oasis_physics_clean + eval_physics = oasis_physics + else: + y_eval = y_physics_distorted + eval_physics_clean = physics_clean + eval_physics = physics + + x_uncorrected = algo(y_eval, eval_physics_clean) + x_corrected = algo(y_eval, eval_physics) + uncorrected_scores = [ + f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" for m in metrics + ] + corrected_scores = [ + f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" for m in metrics + ] + print(f" uncorrected: {', '.join(uncorrected_scores)}") + print(f" corrected: {', '.join(corrected_scores)}") + + dinv.utils.plot( + { + "Ground truth OASIS slice": x, + "Distorted ksp, zero-filled": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, 
corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join(uncorrected_scores), + "\n".join(corrected_scores), + ], + show=False, + close=True, + suptitle=( + f"Algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}" + ), + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + + print("done!") + + +if __name__ == "__main__": + main() diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index 4a41674..afa6c39 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -1,274 +1,275 @@ -"""Cartesian k-space undersampling distortions.""" - -from __future__ import annotations - -import torch - -from mri_recon.distortions.base import SelfAdjointMultiplicativeMaskDistortion - - -PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} - -# Lorentzian offset that controls how steeply the variable-density weights -# fall off with normalized distance from k-space center. -# Smaller values → steeper density gradient; 0.05 gives ~20x weight ratio -# between the ACS boundary and the k-space edge. -_VD_WEIGHT_OFFSET: float = 0.05 - - -class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): - """Cartesian k-space undersampling along phase-encode direction. - - This distortion simulates true MRI acquisition undersampling by applying a - binary sampling mask along the phase-encode direction (default). The mask - keeps a contiguous center region (ACS - Auto-Calibration Signal) fully - sampled and randomly undersamples the periphery. - - The mask is deterministic given a shape and seed, ensuring reproducibility. - - :param float keep_fraction: Fraction of phase-encode lines to keep in - ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. - :param float center_fraction: Fraction of phase-encode lines reserved for - the contiguous, fully-sampled ACS region in ``[0, 1]``. 
Defaults to - ``0.5 * keep_fraction``, leaving part of the acquisition budget for - randomized peripheral line sampling. Set to ``0`` to sample without a - guaranteed ACS block. - :param str pattern: Peripheral sampling pattern. Supported values are - ``"uniform_random"``, ``"variable_density_random"``, and - ``"equispaced"``. Defaults to ``"variable_density_random"``. - :param int axis: Axis along which to apply undersampling. Default is -2 - (phase-encode for 4D tensors), can also be -3 for 5D tensors. - :param int | None seed: Random seed for reproducible mask generation. - If None, uses unseeded randomness (not recommended for reproducibility). - Has no effect when ``pattern="equispaced"`` because that pattern is - fully deterministic and uses no randomness. - """ - - def __init__( - self, - keep_fraction: float = 0.25, - center_fraction: float | None = None, - pattern: str = "variable_density_random", - axis: int = -2, - seed: int | None = None, - ) -> None: - super().__init__() - - if not 0.0 < keep_fraction <= 1.0: - raise ValueError(f"keep_fraction must be in (0, 1], got {keep_fraction}") - - if axis not in (-2, -3): - raise ValueError(f"axis must be -2 or -3, got {axis}") - - if pattern not in PATTERNS: - raise ValueError(f"pattern must be one of {sorted(PATTERNS)}, got {pattern!r}") - - if center_fraction is None: - # Reserve half of the kept lines for a contiguous ACS block and - # leave the remainder for peripheral sampling. 
- center_fraction = 0.5 * keep_fraction - elif not 0.0 <= center_fraction <= 1.0: - raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") - - if center_fraction > keep_fraction: - raise ValueError( - f"center_fraction ({center_fraction}) must not exceed " - f"keep_fraction ({keep_fraction})" - ) - - self.keep_fraction = keep_fraction - self.center_fraction = center_fraction - self.pattern = pattern - self.axis = axis - self.seed = seed - self._cached_mask = None - self._cached_shape = None - self._cached_device: torch.device | None = None - - def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: - """Generate a binary Cartesian undersampling mask. - - The mask is applied along the specified axis (phase-encode by default). - It keeps a contiguous center region fully sampled when requested and - randomly undersamples the periphery to achieve the desired keep_fraction. - - :param tuple[int, ...] shape: k-space tensor shape. - :param torch.device device: Device for the mask. - :returns: Binary mask broadcastable to shape. - :rtype: torch.Tensor - """ - # Cache the mask keyed on both shape and device so that repeated GPU - # forward passes do not perform a CPU→GPU copy on every call. 
- if ( - self._cached_mask is not None - and self._cached_shape == shape - and self._cached_device == device - ): - return self._cached_mask - - # Get the size along the undersampling axis - axis_size = shape[self.axis] - - # Set random seed for reproducibility - rng_state = None - if self.seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.seed) - - try: - # Generate the 1D mask along the specified axis - mask_1d = self._generate_1d_mask(axis_size) - - # Expand the 1D mask to match the full shape - mask = self._expand_mask_to_shape(mask_1d, shape) - finally: - # Restore random state if we set a seed - if self.seed is not None and rng_state is not None: - torch.set_rng_state(rng_state) - - # Move to the target device and cache, including the device. - mask = mask.to(device) - self._cached_mask = mask - self._cached_shape = shape - self._cached_device = device - - return mask - - def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: - """Generate a 1D binary mask along the undersampling axis. - - :param int axis_size: Size of the axis along which to undersample. - :returns: 1D binary mask of shape (axis_size,). - :rtype: torch.Tensor - """ - # Calculate number of lines to keep and center region size. - num_keep = max(1, int(round(axis_size * self.keep_fraction))) - num_center = int(round(axis_size * self.center_fraction)) - - # Ensure center is not larger than total kept lines - num_center = min(num_center, num_keep) - - # Initialize mask (all zeros) - mask = torch.zeros(axis_size, dtype=torch.float32) - - # Keep the contiguous center region (ACS) when requested. - center_start = (axis_size - num_center) // 2 - center_end = center_start + num_center - if num_center > 0: - mask[center_start:center_end] = 1.0 - - peripheral_indices = self._peripheral_indices(center_start, center_end, axis_size) - - # Sample from the peripheral region with the requested pattern. 
- if num_keep > num_center: - num_peripheral = num_keep - num_center - selected_indices = self._select_peripheral_indices( - peripheral_indices=peripheral_indices, - num_peripheral=num_peripheral, - axis_size=axis_size, - ) - mask[selected_indices] = 1.0 - - return mask - - def _peripheral_indices( - self, - center_start: int, - center_end: int, - axis_size: int, - ) -> torch.Tensor: - """Return line indices outside the contiguous ACS region.""" - - return torch.cat( - [ - torch.arange(0, center_start, dtype=torch.long), - torch.arange(center_end, axis_size, dtype=torch.long), - ] - ) - - def _select_peripheral_indices( - self, - peripheral_indices: torch.Tensor, - num_peripheral: int, - axis_size: int, - ) -> torch.Tensor: - """Select peripheral lines according to the configured sampling pattern.""" - - if self.pattern == "uniform_random": - permutation = torch.randperm(len(peripheral_indices))[:num_peripheral] - return peripheral_indices[permutation] - - if self.pattern == "variable_density_random": - center = 0.5 * (axis_size - 1) - distances = torch.abs(peripheral_indices.to(torch.float32) - center) - normalized_distances = distances / max(center, 1.0) - - # Favor low-frequency lines near the ACS boundary while preserving - # non-zero probability across the periphery. - weights = 1.0 / (_VD_WEIGHT_OFFSET + normalized_distances.square()) - selected_positions = torch.multinomial( - weights, - num_samples=num_peripheral, - replacement=False, - ) - return peripheral_indices[selected_positions] - - if self.pattern == "equispaced": - return self._select_equispaced_indices(peripheral_indices, num_peripheral) - - raise RuntimeError(f"Unsupported pattern {self.pattern!r}") - - def _select_equispaced_indices( - self, - peripheral_indices: torch.Tensor, - num_peripheral: int, - ) -> torch.Tensor: - """Select evenly spaced peripheral indices. - - Lines are spaced uniformly within the peripheral index array. 
- Because the peripheral array has a gap at the ACS region, the spacing - between selected k-space lines will be approximately doubled near the - ACS boundary compared to the spacing in the outer periphery. - This pattern is fully deterministic and unaffected by ``seed``. - """ - - if num_peripheral >= len(peripheral_indices): - return peripheral_indices - - # The early-return above guarantees step > 1.0, which in turn means - # floor((idx + 0.5) * step) is strictly increasing, so no collision - # resolution is needed. - step = len(peripheral_indices) / num_peripheral - positions = ( - (torch.arange(num_peripheral, dtype=torch.float32) * step + 0.5 * step).floor().long() - ) - - return peripheral_indices[positions] - - def _expand_mask_to_shape( - self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] - ) -> torch.Tensor: - """Expand a 1D mask to the full k-space shape. - - The 1D mask is applied along the specified axis and broadcast to all - other dimensions. - - :param torch.Tensor mask_1d: 1D binary mask. - :param tuple[int, ...] target_shape: Target shape for expansion. - :returns: Mask broadcastable to target_shape. - :rtype: torch.Tensor - """ - # Convert negative axis to positive - ndim = len(target_shape) - axis = self.axis if self.axis >= 0 else ndim + self.axis - - # Create the full shape with 1s for broadcast dimensions - expand_shape = list(target_shape) - for i in range(len(expand_shape)): - if i != axis: - expand_shape[i] = 1 - - # Reshape the 1D mask to the expand shape - mask = mask_1d.reshape(expand_shape) - - return mask +"""Cartesian k-space undersampling distortions.""" + +from __future__ import annotations + +import torch + +from mri_recon.distortions.base import SelfAdjointMultiplicativeMaskDistortion + + +PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} + +# Lorentzian offset that controls how steeply the variable-density weights +# fall off with normalized distance from k-space center. 
+# Smaller values → steeper density gradient; 0.05 gives ~20x weight ratio +# between the ACS boundary and the k-space edge. +_VD_WEIGHT_OFFSET: float = 0.05 + + +class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): + """Cartesian k-space undersampling along phase-encode direction. + + This distortion simulates true MRI acquisition undersampling by applying a + binary sampling mask along the phase-encode direction (default). The mask + keeps a contiguous center region (ACS - Auto-Calibration Signal) fully + sampled and randomly undersamples the periphery. + + The mask is deterministic given a shape and seed, ensuring reproducibility. + + :param float keep_fraction: Fraction of phase-encode lines to keep in + ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. + :param float center_fraction: Fraction of phase-encode lines reserved for + the contiguous, fully-sampled ACS region in ``[0, 1]``. Defaults to + ``0.5 * keep_fraction``, leaving part of the acquisition budget for + randomized peripheral line sampling. Set to ``0`` to sample without a + guaranteed ACS block. + :param str pattern: Peripheral sampling pattern. Supported values are + ``"uniform_random"``, ``"variable_density_random"``, and + ``"equispaced"``. Defaults to ``"variable_density_random"``. + :param int axis: Axis along which to apply undersampling. Default is -2 + (phase-encode for 4D tensors), can also be -1 for readout/column + masking or -3 for the slice/depth axis in 5D tensors. + :param int | None seed: Random seed for reproducible mask generation. + If None, uses unseeded randomness (not recommended for reproducibility). + Has no effect when ``pattern="equispaced"`` because that pattern is + fully deterministic and uses no randomness. 
+ """ + + def __init__( + self, + keep_fraction: float = 0.25, + center_fraction: float | None = None, + pattern: str = "variable_density_random", + axis: int = -2, + seed: int | None = None, + ) -> None: + super().__init__() + + if not 0.0 < keep_fraction <= 1.0: + raise ValueError(f"keep_fraction must be in (0, 1], got {keep_fraction}") + + if axis not in (-1, -2, -3): + raise ValueError(f"axis must be -1, -2, or -3, got {axis}") + + if pattern not in PATTERNS: + raise ValueError(f"pattern must be one of {sorted(PATTERNS)}, got {pattern!r}") + + if center_fraction is None: + # Reserve half of the kept lines for a contiguous ACS block and + # leave the remainder for peripheral sampling. + center_fraction = 0.5 * keep_fraction + elif not 0.0 <= center_fraction <= 1.0: + raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") + + if center_fraction > keep_fraction: + raise ValueError( + f"center_fraction ({center_fraction}) must not exceed " + f"keep_fraction ({keep_fraction})" + ) + + self.keep_fraction = keep_fraction + self.center_fraction = center_fraction + self.pattern = pattern + self.axis = axis + self.seed = seed + self._cached_mask = None + self._cached_shape = None + self._cached_device: torch.device | None = None + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + """Generate a binary Cartesian undersampling mask. + + The mask is applied along the specified axis (phase-encode by default). + It keeps a contiguous center region fully sampled when requested and + randomly undersamples the periphery to achieve the desired keep_fraction. + + :param tuple[int, ...] shape: k-space tensor shape. + :param torch.device device: Device for the mask. + :returns: Binary mask broadcastable to shape. + :rtype: torch.Tensor + """ + # Cache the mask keyed on both shape and device so that repeated GPU + # forward passes do not perform a CPU→GPU copy on every call. 
+ if ( + self._cached_mask is not None + and self._cached_shape == shape + and self._cached_device == device + ): + return self._cached_mask + + # Get the size along the undersampling axis + axis_size = shape[self.axis] + + # Set random seed for reproducibility + rng_state = None + if self.seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.seed) + + try: + # Generate the 1D mask along the specified axis + mask_1d = self._generate_1d_mask(axis_size) + + # Expand the 1D mask to match the full shape + mask = self._expand_mask_to_shape(mask_1d, shape) + finally: + # Restore random state if we set a seed + if self.seed is not None and rng_state is not None: + torch.set_rng_state(rng_state) + + # Move to the target device and cache, including the device. + mask = mask.to(device) + self._cached_mask = mask + self._cached_shape = shape + self._cached_device = device + + return mask + + def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: + """Generate a 1D binary mask along the undersampling axis. + + :param int axis_size: Size of the axis along which to undersample. + :returns: 1D binary mask of shape (axis_size,). + :rtype: torch.Tensor + """ + # Calculate number of lines to keep and center region size. + num_keep = max(1, int(round(axis_size * self.keep_fraction))) + num_center = int(round(axis_size * self.center_fraction)) + + # Ensure center is not larger than total kept lines + num_center = min(num_center, num_keep) + + # Initialize mask (all zeros) + mask = torch.zeros(axis_size, dtype=torch.float32) + + # Keep the contiguous center region (ACS) when requested. + center_start = (axis_size - num_center) // 2 + center_end = center_start + num_center + if num_center > 0: + mask[center_start:center_end] = 1.0 + + peripheral_indices = self._peripheral_indices(center_start, center_end, axis_size) + + # Sample from the peripheral region with the requested pattern. 
+ if num_keep > num_center: + num_peripheral = num_keep - num_center + selected_indices = self._select_peripheral_indices( + peripheral_indices=peripheral_indices, + num_peripheral=num_peripheral, + axis_size=axis_size, + ) + mask[selected_indices] = 1.0 + + return mask + + def _peripheral_indices( + self, + center_start: int, + center_end: int, + axis_size: int, + ) -> torch.Tensor: + """Return line indices outside the contiguous ACS region.""" + + return torch.cat( + [ + torch.arange(0, center_start, dtype=torch.long), + torch.arange(center_end, axis_size, dtype=torch.long), + ] + ) + + def _select_peripheral_indices( + self, + peripheral_indices: torch.Tensor, + num_peripheral: int, + axis_size: int, + ) -> torch.Tensor: + """Select peripheral lines according to the configured sampling pattern.""" + + if self.pattern == "uniform_random": + permutation = torch.randperm(len(peripheral_indices))[:num_peripheral] + return peripheral_indices[permutation] + + if self.pattern == "variable_density_random": + center = 0.5 * (axis_size - 1) + distances = torch.abs(peripheral_indices.to(torch.float32) - center) + normalized_distances = distances / max(center, 1.0) + + # Favor low-frequency lines near the ACS boundary while preserving + # non-zero probability across the periphery. + weights = 1.0 / (_VD_WEIGHT_OFFSET + normalized_distances.square()) + selected_positions = torch.multinomial( + weights, + num_samples=num_peripheral, + replacement=False, + ) + return peripheral_indices[selected_positions] + + if self.pattern == "equispaced": + return self._select_equispaced_indices(peripheral_indices, num_peripheral) + + raise RuntimeError(f"Unsupported pattern {self.pattern!r}") + + def _select_equispaced_indices( + self, + peripheral_indices: torch.Tensor, + num_peripheral: int, + ) -> torch.Tensor: + """Select evenly spaced peripheral indices. + + Lines are spaced uniformly within the peripheral index array. 
+ Because the peripheral array has a gap at the ACS region, the spacing + between selected k-space lines will be approximately doubled near the + ACS boundary compared to the spacing in the outer periphery. + This pattern is fully deterministic and unaffected by ``seed``. + """ + + if num_peripheral >= len(peripheral_indices): + return peripheral_indices + + # The early-return above guarantees step > 1.0, which in turn means + # floor((idx + 0.5) * step) is strictly increasing, so no collision + # resolution is needed. + step = len(peripheral_indices) / num_peripheral + positions = ( + (torch.arange(num_peripheral, dtype=torch.float32) * step + 0.5 * step).floor().long() + ) + + return peripheral_indices[positions] + + def _expand_mask_to_shape( + self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] + ) -> torch.Tensor: + """Expand a 1D mask to the full k-space shape. + + The 1D mask is applied along the specified axis and broadcast to all + other dimensions. + + :param torch.Tensor mask_1d: 1D binary mask. + :param tuple[int, ...] target_shape: Target shape for expansion. + :returns: Mask broadcastable to target_shape. 
+ :rtype: torch.Tensor + """ + # Convert negative axis to positive + ndim = len(target_shape) + axis = self.axis if self.axis >= 0 else ndim + self.axis + + # Create the full shape with 1s for broadcast dimensions + expand_shape = list(target_shape) + for i in range(len(expand_shape)): + if i != axis: + expand_shape[i] = 1 + + # Reshape the 1D mask to the expand shape + mask = mask_1d.reshape(expand_shape) + + return mask diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index a11a909..eafddf3 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -1,14 +1,14 @@ -from .deep import ( - RAMReconstructor, - DeepImagePriorReconstructor, - FastMRISinglecoilUnetReconstructor, - OASISSinglecoilUnetReconstructor, -) -from .classic import ( - ZeroFilledReconstructor, - ConjugateGradientReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, -) +from .deep import ( + RAMReconstructor, + DeepImagePriorReconstructor, + FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, +) +from .classic import ( + ZeroFilledReconstructor, + ConjugateGradientReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, +) diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index ba520f9..6da8ac1 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -1,268 +1,268 @@ -from pathlib import Path - -import deepinv as dinv -import torch - -from ._fastmri_unet import Unet -from ..utils import download_file_with_sha256, matches_sha256 - - -class RAMReconstructor(dinv.models.Reconstructor): - """ - Wrapper for RAM from DeepInverse. - Normalises input by magnitude of adjoint. - - :param float default_sigma: default sigma for RAM input. Overriden if physics already has a sigma (e.g. in a Gaussian noise model) at inference time. 
- """ - - def __init__(self, default_sigma=0.05, device: torch.device = None) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - self.model = dinv.models.RAM(device=device) - self.default_sigma = default_sigma - - def forward(self, y, physics): - _x_adj = physics.A_adjoint(y) - scale = torch.quantile(_x_adj, 0.99) - - physics_norm = physics.compute_norm(torch.randn_like(_x_adj)).item() - physics_adjointness = physics.adjointness_test(torch.randn_like(_x_adj)).item() - - if physics_norm > 1.2 or physics_norm < 0.8: - raise ValueError( - f"RAM reconstructor requires physics norm = 1 but got {physics_norm:.4f}" - ) - if physics_adjointness > 0.1 or physics_adjointness < -0.1: - raise ValueError( - f"RAM reconstructor requires physics adjointness = 0 but got {physics_adjointness:.4f}" - ) - - sigma = ( - None - if hasattr(physics, "noise_model") and hasattr(physics.noise_model, "sigma") - else self.default_sigma - ) - - with torch.no_grad(): - return self.model(y / scale, physics, sigma=sigma) * scale - - -class DeepImagePriorReconstructor(dinv.models.Reconstructor): - """ - Wrapper for Deep Image Prior from DeepInverse. - - :param tuple img_size: image size of the output. Defaults to (640, 368) - :param int n_iter: number of iterations to fit the DIP. Defaults to 100. - """ - - def __init__( - self, - img_size: tuple = (640, 368), - n_iter: int = 100, - verbose: bool = True, - ) -> None: - super().__init__() - - lr = 1e-2 # learning rate for the optimizer. - channels = 64 # number of channels per layer in the decoder. - in_size = [2, 2] # size of the input to the decoder. 
- - self.model = dinv.models.DeepImagePrior( - dinv.models.ConvDecoder( - img_size=(2, *img_size[-2:]), in_size=in_size, channels=channels - ), - learning_rate=lr, - iterations=n_iter, - verbose=verbose, - input_size=[channels] + in_size, - ) - - def forward(self, y, physics): - return self.model(y, physics) - - -class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): - """ - Wrapper for pretrained UNet from FastMRI singlecoil knee challenge. - - Note: this model was trained for accelerated MRI reconstruction and may not have good performance on other degradations. - - Note: this model discards complex information and only returns the magnitude image. - - NOTE: this model was trained on both train+val splits of the challenge (i.e. trained on singlecoil_train, singlecoil_val). - - The pretrained fastMRI model expects magnitude images that are normalized per slice, - so this wrapper matches that preprocessing and rescales the output back to the - original adjoint-image intensity range. - - See https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples for more details. 
- """ - - MODEL_URL = ( - "https://dl.fbaipublicfiles.com/fastMRI/trained_models/unet/" - "knee_sc_leaderboard_state_dict.pt" - ) - MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" - MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" - UNET_KWARGS = { - "in_chans": 1, - "out_chans": 1, - "chans": 256, - "num_pool_layers": 4, - "drop_prob": 0.0, - } - - def __init__(self, device: torch.device = None, state_dict_file: str = None) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - self.model = Unet(**self.UNET_KWARGS) - - state_dict_path = ( - Path(state_dict_file).expanduser() - if state_dict_file is not None - else Path(__file__).resolve().parents[2] / self.MODEL_FILENAME - ) - - if state_dict_file is None: - if not matches_sha256(state_dict_path, self.MODEL_SHA256): - download_file_with_sha256( - self.MODEL_URL, - state_dict_path, - self.MODEL_SHA256, - label="FastMRI UNet checkpoint", - ) - elif not state_dict_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {state_dict_path}") - - self.model.load_state_dict( - torch.load(state_dict_path, map_location=device, weights_only=True) - ) - self.model.eval() - self.model.to(device) - - def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: - x_in = physics.A_adjoint(y) - - x_in = dinv.utils.complex_abs(x_in, keepdim=True) - - # Match the fastMRI normalization used for training, then rescale the - # predicted magnitude image back to the original adjoint-image intensity range. 
- mu = x_in.mean(dim=(-2, -1), keepdim=True) - std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-8 - x_in = (x_in - mu) / std - - with torch.no_grad(): - out = self.model(x_in) * std + mu # (B, 1, H, W) - - return torch.cat([out, torch.zeros_like(out)], dim=1) - - -def _load_unet_checkpoint_state( - checkpoint_path: Path, - device: torch.device, -) -> dict[str, torch.Tensor]: - """Load U-Net weights from a plain or Lightning-style checkpoint.""" - - try: - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - except TypeError: - checkpoint = torch.load(checkpoint_path, map_location=device) - - if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: - state_dict = checkpoint["model_state_dict"] - elif isinstance(checkpoint, dict): - state_dict = checkpoint - else: - raise ValueError(f"Unsupported checkpoint format in {checkpoint_path}.") - - if any(key.startswith("unet.") for key in state_dict): - return { - key[len("unet.") :]: value - for key, value in state_dict.items() - if key.startswith("unet.") - } - - if any(key.startswith("model.unet.") for key in state_dict): - return { - key[len("model.unet.") :]: value - for key, value in state_dict.items() - if key.startswith("model.unet.") - } - - if any(key.startswith("module.") for key in state_dict): - return { - key[len("module.") :]: value - for key, value in state_dict.items() - if key.startswith("module.") - } - - return state_dict - - -class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): - """ - Wrapper for a trained OASIS single-coil U-Net reconstruction model. - - The model reuses the repository's fastMRI-derived :class:`Unet` module, but - loads an OASIS checkpoint supplied by the caller. 
The forward pass converts - k-space to a zero-filled magnitude image, applies per-slice instance - normalization, runs the U-Net, then rescales the prediction back to the - adjoint-image intensity range. - - :param str checkpoint_file: Path to the trained OASIS U-Net checkpoint. - :param torch.device device: Device on which to run inference. - """ - - UNET_KWARGS = { - "in_chans": 1, - "out_chans": 1, - "chans": 32, - "num_pool_layers": 4, - "drop_prob": 0.0, - } - - def __init__( - self, - checkpoint_file: str, - device: torch.device = None, - ) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - checkpoint_path = Path(checkpoint_file).expanduser() - if not checkpoint_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - - self.model = Unet(**self.UNET_KWARGS) - self.model.load_state_dict( - _load_unet_checkpoint_state(checkpoint_path, device), - strict=True, - ) - self.model.eval() - self.model.to(device) - - def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: - """Reconstruct a magnitude image from measured k-space.""" - - x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) - mu = x_in.mean(dim=(-2, -1), keepdim=True) - std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 - x_in = (x_in - mu) / std - - with torch.no_grad(): - out = self.model(x_in) * std + mu - - return torch.cat([out, torch.zeros_like(out)], dim=1) +from pathlib import Path + +import deepinv as dinv +import torch + +from ._fastmri_unet import Unet +from ..utils import download_file_with_sha256, matches_sha256 + + +class RAMReconstructor(dinv.models.Reconstructor): + """ + Wrapper for RAM from DeepInverse. + Normalises input by magnitude of adjoint. + + :param float default_sigma: default sigma for RAM input. Overriden if physics already has a sigma (e.g. in a Gaussian noise model) at inference time. 
+ """ + + def __init__(self, default_sigma=0.05, device: torch.device = None) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + self.model = dinv.models.RAM(device=device) + self.default_sigma = default_sigma + + def forward(self, y, physics): + _x_adj = physics.A_adjoint(y) + scale = torch.quantile(_x_adj, 0.99) + + physics_norm = physics.compute_norm(torch.randn_like(_x_adj)).item() + physics_adjointness = physics.adjointness_test(torch.randn_like(_x_adj)).item() + + if physics_norm > 1.2 or physics_norm < 0.8: + raise ValueError( + f"RAM reconstructor requires physics norm = 1 but got {physics_norm:.4f}" + ) + if physics_adjointness > 0.1 or physics_adjointness < -0.1: + raise ValueError( + f"RAM reconstructor requires physics adjointness = 0 but got {physics_adjointness:.4f}" + ) + + sigma = ( + None + if hasattr(physics, "noise_model") and hasattr(physics.noise_model, "sigma") + else self.default_sigma + ) + + with torch.no_grad(): + return self.model(y / scale, physics, sigma=sigma) * scale + + +class DeepImagePriorReconstructor(dinv.models.Reconstructor): + """ + Wrapper for Deep Image Prior from DeepInverse. + + :param tuple img_size: image size of the output. Defaults to (640, 368) + :param int n_iter: number of iterations to fit the DIP. Defaults to 100. + """ + + def __init__( + self, + img_size: tuple = (640, 368), + n_iter: int = 100, + verbose: bool = True, + ) -> None: + super().__init__() + + lr = 1e-2 # learning rate for the optimizer. + channels = 64 # number of channels per layer in the decoder. + in_size = [2, 2] # size of the input to the decoder. 
class DeepImagePriorReconstructor(dinv.models.Reconstructor):
    """
    Wrapper for Deep Image Prior from DeepInverse.

    :param tuple img_size: image size of the output. Defaults to (640, 368)
    :param int n_iter: number of iterations to fit the DIP. Defaults to 100.
    """

    # Fixed decoder hyper-parameters.
    _LR = 1e-2  # learning rate for the optimizer
    _CHANNELS = 64  # number of channels per layer in the decoder
    _IN_SIZE = [2, 2]  # spatial size of the input to the decoder

    def __init__(
        self,
        img_size: tuple = (640, 368),
        n_iter: int = 100,
        verbose: bool = True,
    ) -> None:
        super().__init__()

        decoder = dinv.models.ConvDecoder(
            img_size=(2, *img_size[-2:]),
            in_size=self._IN_SIZE,
            channels=self._CHANNELS,
        )
        self.model = dinv.models.DeepImagePrior(
            decoder,
            learning_rate=self._LR,
            iterations=n_iter,
            verbose=verbose,
            input_size=[self._CHANNELS] + self._IN_SIZE,
        )

    def forward(self, y, physics):
        """Fit the untrained decoder to the measurements and return the reconstruction."""
        return self.model(y, physics)
class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor):
    """
    Wrapper for pretrained UNet from FastMRI singlecoil knee challenge.

    Note: this model was trained for accelerated MRI reconstruction and may not have good performance on other degradations.

    Note: this model discards complex information and only returns the magnitude image.

    NOTE: this model was trained on both train+val splits of the challenge (i.e. trained on singlecoil_train, singlecoil_val).

    The pretrained fastMRI model expects magnitude images that are normalized per slice,
    so this wrapper matches that preprocessing and rescales the output back to the
    original adjoint-image intensity range.

    See https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples for more details.
    """

    MODEL_URL = (
        "https://dl.fbaipublicfiles.com/fastMRI/trained_models/unet/"
        "knee_sc_leaderboard_state_dict.pt"
    )
    MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e"
    MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt"
    UNET_KWARGS = {
        "in_chans": 1,
        "out_chans": 1,
        "chans": 256,
        "num_pool_layers": 4,
        "drop_prob": 0.0,
    }

    def __init__(self, device: torch.device = None, state_dict_file: str = None) -> None:
        super().__init__()

        if device is None:
            device = torch.device("cpu")
        self.device = device

        self.model = Unet(**self.UNET_KWARGS)

        # Resolve the checkpoint location: an explicit file wins, otherwise
        # fall back to the repository-root default location.
        if state_dict_file is not None:
            weights_path = Path(state_dict_file).expanduser()
            if not weights_path.exists():
                raise FileNotFoundError(f"Checkpoint not found: {weights_path}")
        else:
            weights_path = Path(__file__).resolve().parents[2] / self.MODEL_FILENAME
            # Download (or re-download) whenever the local copy fails the hash check.
            if not matches_sha256(weights_path, self.MODEL_SHA256):
                download_file_with_sha256(
                    self.MODEL_URL,
                    weights_path,
                    self.MODEL_SHA256,
                    label="FastMRI UNet checkpoint",
                )

        self.model.load_state_dict(
            torch.load(weights_path, map_location=device, weights_only=True)
        )
        self.model.eval()
        self.model.to(device)

    def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor:
        """Reconstruct a magnitude image (zero imaginary channel) from k-space."""
        magnitude = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True)

        # Match the fastMRI per-slice normalization used during training, then
        # rescale the prediction back to the adjoint-image intensity range.
        mean = magnitude.mean(dim=(-2, -1), keepdim=True)
        stdev = magnitude.std(dim=(-2, -1), keepdim=True) + 1e-8
        normalized = (magnitude - mean) / stdev

        with torch.no_grad():
            prediction = self.model(normalized) * stdev + mean  # (B, 1, H, W)

        return torch.cat([prediction, torch.zeros_like(prediction)], dim=1)
+ mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-8 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu # (B, 1, H, W) + + return torch.cat([out, torch.zeros_like(out)], dim=1) + + +def _load_unet_checkpoint_state( + checkpoint_path: Path, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Load U-Net weights from a plain or Lightning-style checkpoint.""" + + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + # Accept either a checkpoint with "state_dict" or a plain state_dict + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + + # If Lightning-style keys like 'unet.*' exist, strip the 'unet.' prefix + if any(key.startswith("unet.") for key in state_dict): + return { + key[len("unet.") :]: value + for key, value in state_dict.items() + if key.startswith("unet.") + } + + return state_dict + + +class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): + """Wrapper for a trained OASIS single-coil U-Net model. + + The model reuses the repository's fastMRI-derived :class:`Unet` module, but + loads an OASIS checkpoint supplied by the caller. The forward pass converts + k-space to a zero-filled magnitude image, applies per-slice instance + normalization, runs the U-Net, then rescales the prediction back to the + adjoint-image intensity range. + + Parameters + ---------- + checkpoint_file : str + Path to the trained OASIS U-Net checkpoint. + device : torch.device, optional + Device on which to run inference. 
class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor):
    """Wrapper for a trained OASIS single-coil U-Net model.

    The model reuses the repository's fastMRI-derived :class:`Unet` module, but
    loads an OASIS checkpoint supplied by the caller. The forward pass converts
    k-space to a zero-filled magnitude image, applies per-slice instance
    normalization, runs the U-Net, then rescales the prediction back to the
    adjoint-image intensity range.

    :param str checkpoint_file: Path to the trained OASIS U-Net checkpoint.
    :param torch.device device: Device on which to run inference. Defaults to CPU.
    :raises FileNotFoundError: If ``checkpoint_file`` does not exist.
    """

    UNET_KWARGS = {
        "in_chans": 1,
        "out_chans": 1,
        "chans": 32,
        "num_pool_layers": 4,
        "drop_prob": 0.0,
    }

    def __init__(
        self,
        checkpoint_file: str,
        device: torch.device = None,
    ) -> None:
        super().__init__()

        if device is None:
            device = torch.device("cpu")
        self.device = device

        checkpoint_path = Path(checkpoint_file).expanduser()
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # Normalize the checkpoint into a bare-Unet state dict
        # (handles plain and Lightning-style checkpoints).
        state_dict = _load_unet_checkpoint_state(checkpoint_path, device)

        self.model = Unet(**self.UNET_KWARGS)
        self.model.load_state_dict(
            state_dict,
            strict=True,
        )
        self.model.eval()
        self.model.to(device)

    def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor:
        """Reconstruct a magnitude image from measured k-space.

        :param torch.Tensor y: Measured k-space tensor.
        :param dinv.physics.Physics physics: Physics operator used to form the
            adjoint input image.
        :returns: Two-channel reconstruction with a zero imaginary channel.
        :rtype: torch.Tensor
        """

        x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True)
        # Per-slice instance normalization matching the training preprocessing;
        # the prediction is rescaled back to the adjoint-image range below.
        mu = x_in.mean(dim=(-2, -1), keepdim=True)
        std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11
        x_in = (x_in - mu) / std

        with torch.no_grad():
            out = self.model(x_in) * std + mu

        return torch.cat([out, torch.zeros_like(out)], dim=1)
"Isotropic LP": - return IsotropicResolutionReduction(radius_fraction=0.6) - case "Anisotropic LP": - return AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=0.35, - ) - case "Hann taper LP": - return HannTaperResolutionReduction( - radius_fraction=0.6, - transition_fraction=0.25, - ) - case "Kaiser taper LP": - return KaiserTaperResolutionReduction( - radius_fraction=0.6, - transition_fraction=0.25, - beta=8.6, - ) - case "Cartesian undersampling": - return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.2, - seed=42, - ) - case "Radial high-pass emphasis": - return RadialHighPassEmphasisDistortion(alpha=0.4) - case "Gaussian bias field": - return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) - case "Off-center anisotropic Gaussian bias field": - return OffCenterAnisotropicGaussianKspaceBiasField( - width_x_fraction=0.2, - width_y_fraction=0.35, - center_x_fraction=0.15, - center_y_fraction=-0.1, - edge_gain=0.3, - ) - case "Phase-encode ghosting": - return PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=1.0, - ) - case "Translation motion": - return TranslationMotionDistortion(shift_x_pixels=8.0, shift_y_pixels=4.0) - case "Rotational motion": - return RotationalMotionDistortion(angle_radians=torch.pi / 6) - case "Segmented translation motion": - return SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 2.0, 5.0, 5.0), - shift_y_pixels=(0.0, 1.0, 2.0, 2.0), - ) - case _: - raise ValueError(f"Unknown distortion {name!r}") - - -@pytest.fixture -def device(): - return "cpu" - - -@pytest.mark.parametrize("name", DISTORTIONS) -@pytest.mark.parametrize( - "img_size", [(1, 2, 256, 256), (1, 2, 4, 256, 256)] -) # singlecoil and multicoil -def test_distortion_properties(name, img_size, device): - """ - Test exact operators for adjointness and norm preservation, and verify that - approximate or attenuation operators remain 
shape-preserving and non-expansive. - """ - distortion = choose_distortion(name) - y = torch.randn(img_size, device=device) - x_dummy = torch.randn(1, 2, *img_size[-2:], device=device) - - if name in EXACT_OPERATOR_DISTORTIONS: - assert distortion.adjointness_test(x_dummy) < 0.01 - assert abs(distortion.compute_norm(x_dummy, squared=False) - 1) < 0.01 - elif name in NON_EXPANSIVE_DISTORTIONS: - y_distorted = distortion.A(x_dummy) - assert y_distorted.shape == x_dummy.shape - assert torch.max(torch.abs(y_distorted)) <= torch.max(torch.abs(x_dummy)) + 1e-6 - - if len(img_size) == 4: # singlecoil - coil_maps = None - elif len(img_size) == 5: # multicoil - coil_maps = torch.ones(1, *img_size[-3:], device=device, dtype=torch.complex64) / sqrt( - img_size[-3] - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, img_size=(1, 2, *y.shape[-2:]), coil_maps=coil_maps, device=device - ) - - if name in EXACT_OPERATOR_DISTORTIONS: - assert physics.adjointness_test(x_dummy) < 0.01 - assert abs(physics.compute_norm(x_dummy, squared=False) - 1) < 0.01 - elif name in NON_EXPANSIVE_DISTORTIONS: - y_physics = physics.A(x_dummy) - y_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), - img_size=(1, 2, *y.shape[-2:]), - coil_maps=coil_maps, - device=device, - ).A(x_dummy) - assert torch.max(torch.abs(y_physics)) <= torch.max(torch.abs(y_clean)) + 1e-6 - - -def test_gaussian_noise_distortion_preserves_shape_and_changes_values(device): - distortion = GaussianNoiseDistortion(sigma=0.1) - y = torch.zeros((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - - assert y_distorted.shape == y.shape - assert y_distorted.dtype == y.dtype - assert not torch.equal(y_distorted, y) - - -def test_gaussian_noise_distortion_zero_sigma_is_identity(device): - distortion = GaussianNoiseDistortion(sigma=0.0) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - - assert torch.equal(y_distorted, y) - - -def 
test_anisotropic_resolution_reduction_zeroes_only_filtered_axis(device): - distortion = AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=0.3, - ) - y = torch.ones((1, 2, 9, 11), device=device) - - y_distorted = distortion.A(y) - mask = distortion._mask(y.shape, y.device) - - assert torch.all(y_distorted == y * mask) - assert torch.all(mask[y.shape[-2] // 2, :] == 1) - assert torch.all(mask[0, :] == 0) - - -def test_anisotropic_resolution_reduction_identity_at_full_cutoffs(device): - distortion = AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=1.0, - ) - y = torch.randn((1, 2, 32, 32), device=device) - - y_distorted = distortion.A(y) - - assert torch.equal(y_distorted, y) - - -def test_hann_taper_resolution_reduction_has_smooth_transition(device): - distortion = HannTaperResolutionReduction( - radius_fraction=0.8, - transition_fraction=0.5, - ) - - mask = distortion._mask((1, 2, 33, 33), torch.device(device)) - - assert mask[16, 16] == pytest.approx(1.0) - assert mask[0, 0] == pytest.approx(0.0) - assert torch.any((mask > 0.0) & (mask < 1.0)) - - -def test_hann_taper_resolution_reduction_zero_transition_matches_hard_cutoff(device): - hard = IsotropicResolutionReduction(radius_fraction=0.6) - smooth = HannTaperResolutionReduction(radius_fraction=0.6, transition_fraction=0.0) - y = torch.randn((1, 2, 64, 64), device=device) - - assert torch.equal(smooth.A(y), hard.A(y)) - - -def test_kaiser_taper_resolution_reduction_has_smooth_transition(device): - distortion = KaiserTaperResolutionReduction( - radius_fraction=0.8, - transition_fraction=0.5, - beta=8.6, - ) - - mask = distortion._mask((1, 2, 33, 33), torch.device(device)) - - assert mask[16, 16] == pytest.approx(1.0) - assert mask[0, 0] == pytest.approx(0.0) - assert torch.any((mask > 0.0) & (mask < 1.0)) - - -def test_kaiser_taper_resolution_reduction_zero_transition_matches_hard_cutoff(device): - hard = IsotropicResolutionReduction(radius_fraction=0.6) - 
smooth = KaiserTaperResolutionReduction( - radius_fraction=0.6, - transition_fraction=0.0, - beta=8.6, - ) - y = torch.randn((1, 2, 64, 64), device=device) - - assert torch.equal(smooth.A(y), hard.A(y)) - - -def test_radial_high_pass_emphasis_distortion_boosts_edges_more_than_center(device): - distortion = RadialHighPassEmphasisDistortion(alpha=0.4) - shape = (1, 2, 33, 33) - center_y = shape[-2] // 2 - center_x = shape[-1] // 2 - - mask = distortion._mask(shape, torch.device(device)) - - assert mask[center_y, center_x] == pytest.approx(1.0) - assert mask[0, 0] == pytest.approx(1.0 + distortion.alpha) - assert torch.all(mask >= 1.0) - assert mask[center_y, center_x + 4] == pytest.approx(1.0) - assert torch.any((mask > 1.0) & (mask < 1.0 + distortion.alpha)) - - -def test_radial_high_pass_emphasis_distortion_zero_alpha_is_identity(device): - distortion = RadialHighPassEmphasisDistortion(alpha=0.0) - y = torch.randn((1, 2, 64, 64), device=device) - - assert torch.equal(distortion.A(y), y) - - -def test_radial_high_pass_emphasis_distortion_respects_custom_band(device): - distortion = RadialHighPassEmphasisDistortion( - alpha=0.4, - boost_start_radius=0.7, - boost_end_radius=0.95, - ) - shape = (1, 2, 65, 65) - center_y = shape[-2] // 2 - center_x = shape[-1] // 2 - - mask = distortion._mask(shape, torch.device(device)) - - assert mask[center_y, center_x + 12] == pytest.approx(1.0) - assert mask[0, 0] == pytest.approx(1.0 + distortion.alpha) - - -def test_centered_isotropic_bias_matches_anisotropic_special_case(device): - centered = GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) - anisotropic = OffCenterAnisotropicGaussianKspaceBiasField( - width_x_fraction=0.35, - width_y_fraction=0.35, - center_x_fraction=0.0, - center_y_fraction=0.0, - edge_gain=0.4, - ) - y = torch.randn((1, 2, 64, 64), device=device) - - y_centered = centered.A(y) - y_anisotropic = anisotropic.A(y) - - assert torch.allclose(y_centered, y_anisotropic) - - -def 
test_translation_motion_zero_shift_is_identity(device): - distortion = TranslationMotionDistortion(shift_x_pixels=0.0, shift_y_pixels=0.0) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - - assert torch.equal(y_distorted, y) - - -def test_translation_motion_preserves_kspace_magnitude(device): - distortion = TranslationMotionDistortion(shift_x_pixels=8.0, shift_y_pixels=4.0) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) - - assert torch.allclose(torch.abs(y_distorted_complex), torch.abs(y_complex)) - - -def test_translation_motion_produces_requested_image_shift(device): - distortion = TranslationMotionDistortion(shift_x_pixels=8.0, shift_y_pixels=4.0) - image = torch.zeros((64, 64), dtype=torch.complex64, device=device) - image[20, 18] = 1.0 - - kspace = torch.fft.fftshift(torch.fft.fft2(image)) - y = torch.view_as_real(kspace).movedim(-1, 0).unsqueeze(0).contiguous() - - y_distorted = distortion.A(y) - kspace_distorted = torch.view_as_complex(y_distorted[0].movedim(0, -1).contiguous()) - image_distorted = torch.fft.ifft2(torch.fft.ifftshift(kspace_distorted)) - max_position = torch.nonzero(torch.abs(image_distorted) == torch.abs(image_distorted).max())[ - 0 - ].tolist() - - assert max_position == [24, 26] - - -def test_translation_motion_rejects_invalid_kspace_shape(device): - distortion = TranslationMotionDistortion(shift_x_pixels=1.0, shift_y_pixels=1.0) - y = torch.randn((1, 64, 64), device=device) - - with pytest.raises(ValueError, match="Expected k-space with shape"): - distortion.A(y) - - -def test_translation_motion_rejects_invalid_channel_dimension(device): - distortion = TranslationMotionDistortion(shift_x_pixels=1.0, shift_y_pixels=1.0) - y = torch.randn((1, 3, 64, 64), device=device) - - with pytest.raises(ValueError, 
match="channel dimension of size 2"): - distortion.A(y) - - -def test_translation_motion_rejects_non_floating_tensor(device): - distortion = TranslationMotionDistortion(shift_x_pixels=1.0, shift_y_pixels=1.0) - y = torch.zeros((1, 2, 64, 64), device=device, dtype=torch.int64) - - with pytest.raises(TypeError, match="floating-point"): - distortion.A(y) - - -def test_rotational_motion_zero_angle_is_identity(device): - distortion = RotationalMotionDistortion(angle_radians=0.0) - y = torch.randn((1, 2, 64, 64), device=device) - - assert torch.equal(distortion.A(y), y) - assert torch.equal(distortion.A_adjoint(y), y) - - -def test_rotational_motion_rejects_invalid_kspace_shape(device): - distortion = RotationalMotionDistortion(angle_radians=torch.pi / 8) - y = torch.randn((1, 64, 64), device=device) - - with pytest.raises(ValueError, match="Expected k-space with shape"): - distortion.A(y) - - -def test_rotational_motion_rejects_non_floating_tensor(device): - distortion = RotationalMotionDistortion(angle_radians=torch.pi / 8) - y = torch.zeros((1, 2, 64, 64), device=device, dtype=torch.int64) - - with pytest.raises(TypeError, match="floating-point"): - distortion.A(y) - - -def test_rotational_motion_rotates_image_content(device): - angle_radians = -0.5 * torch.pi - distortion = RotationalMotionDistortion(angle_radians=angle_radians) - image = torch.zeros((63, 63), dtype=torch.complex64, device=device) - image[31, 40] = 1.0 - - kspace = torch.fft.fftshift(torch.fft.fft2(image)) - y = torch.view_as_real(kspace).movedim(-1, 0).unsqueeze(0).contiguous() - - y_distorted = distortion.A(y) - kspace_distorted = torch.view_as_complex(y_distorted[0].movedim(0, -1).contiguous()) - image_distorted = torch.fft.ifft2(torch.fft.ifftshift(kspace_distorted)) - magnitude = torch.abs(image_distorted) - max_index = magnitude.reshape(-1).argmax() - max_position = torch.tensor(torch.unravel_index(max_index, magnitude.shape), device=device) - - assert torch.equal(max_position, 
torch.tensor([40, 32], device=device)) - - -def test_rotational_motion_uses_matched_adjoint(device): - distortion = RotationalMotionDistortion(angle_radians=torch.pi / 6) - x = torch.randn((1, 2, 64, 64), device=device) - y = torch.randn((1, 2, 64, 64), device=device) - - lhs = torch.sum(distortion.A(x) * y) - rhs = torch.sum(x * distortion.A_adjoint(y)) - - assert torch.allclose(lhs, rhs, atol=1e-4, rtol=1e-4) - - -def test_phase_encode_ghosting_zero_phase_and_unit_scale_is_identity(device): - distortion = PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=0.0, - corrupted_line_scale=1.0, - ) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - - assert torch.equal(y_distorted, y) - - -def test_phase_encode_ghosting_modulates_selected_lines(device): - distortion = PhaseEncodeGhostingDistortion( - line_period=3, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=0.75, - ghost_axis=-2, - ) - y = torch.randn((2, 2, 3, 18, 20), device=device) - - y_distorted = distortion.A(y) - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) - - line_factor = torch.polar( - torch.tensor(0.75, device=device), - torch.tensor(torch.pi / 2, device=device), - ) - for line_index in range(y.shape[-2]): - if (line_index - distortion.line_offset) % distortion.line_period == 0: - expected = y_complex[:, :, line_index, :] * line_factor - else: - expected = y_complex[:, :, line_index, :] - actual = y_distorted_complex[:, :, line_index, :] - assert torch.allclose(actual, expected) - - -def test_phase_encode_ghosting_creates_half_fov_replica_for_partial_alternating_phase(device): - distortion = PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=1.0, - ) - image = torch.zeros((32, 32), dtype=torch.complex64, device=device) - image[5, 
7] = 1.0 - - kspace = torch.fft.fftshift(torch.fft.fft2(image)) - y = torch.view_as_real(kspace).movedim(-1, 0).unsqueeze(0).contiguous() - - y_distorted = distortion.A(y) - ghosted_kspace = torch.view_as_complex(y_distorted[0].movedim(0, -1).contiguous()) - ghosted_image = torch.fft.ifft2(torch.fft.ifftshift(ghosted_kspace)) - - peak_locations = torch.nonzero(torch.abs(ghosted_image) > 0.3) - - assert peak_locations.shape[0] == 2 - assert [5, 7] in peak_locations.tolist() - assert [21, 7] in peak_locations.tolist() - - -def test_segmented_translation_motion_zero_shift_is_identity(device): - distortion = SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 0.0, 0.0), - shift_y_pixels=(0.0, 0.0, 0.0), - ) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - - assert torch.equal(y_distorted, y) - - -def test_segmented_translation_motion_matches_phase_ramp_per_phase_encode_segment(device): - distortion = SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 1.5, -2.0), - shift_y_pixels=(0.0, 3.0, -1.0), - segment_axis=-2, - ) - y = torch.randn((2, 2, 3, 18, 20), device=device) - - y_distorted = distortion.A(y) - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) - - for segment_slice, shift_x, shift_y in zip( - distortion._segment_slices(y.shape), - distortion.shift_x_pixels, - distortion.shift_y_pixels, - strict=True, - ): - ramp = distortion._phase_ramp(y.shape, y.device, shift_x, shift_y) - expected = y_complex[:, :, segment_slice, :] * ramp[segment_slice, :] - actual = y_distorted_complex[:, :, segment_slice, :] - assert torch.allclose(actual, expected) - - -def test_segmented_translation_motion_matches_phase_ramp_per_readout_segment(device): - distortion = SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 2.0), - shift_y_pixels=(0.0, 1.0), - segment_axis=-1, - ) - y = torch.randn((1, 2, 16, 14), 
device=device) - - y_distorted = distortion.A(y) - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) - - for segment_slice, shift_x, shift_y in zip( - distortion._segment_slices(y.shape), - distortion.shift_x_pixels, - distortion.shift_y_pixels, - strict=True, - ): - ramp = distortion._phase_ramp(y.shape, y.device, shift_x, shift_y) - expected = y_complex[:, :, segment_slice] * ramp[:, segment_slice] - actual = y_distorted_complex[:, :, segment_slice] - assert torch.allclose(actual, expected) - - -def test_segmented_translation_motion_rejects_too_many_segments(device): - distortion = SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 1.0, 2.0, 3.0, 4.0), - shift_y_pixels=(0.0, 1.0, 2.0, 3.0, 4.0), - ) - y = torch.randn((1, 2, 4, 64), device=device) - - with pytest.raises(ValueError, match="non-empty segments"): - distortion.A(y) - - -def test_segmented_translation_motion_keeps_zero_motion_segment_and_modulates_shifted_segment( - device, -): - distortion = SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 4.0), - shift_y_pixels=(0.0, 2.0), - ) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) - first_segment, second_segment = distortion._segment_slices(y.shape) - - first_segment_expected = y_complex[:, first_segment, :] - first_segment_actual = y_distorted_complex[:, first_segment, :] - - ramp = distortion._phase_ramp( - y.shape, - y.device, - shift_x_pixels=distortion.shift_x_pixels[1], - shift_y_pixels=distortion.shift_y_pixels[1], - ) - second_segment_expected = y_complex[:, second_segment, :] * ramp[second_segment, :] - second_segment_actual = y_distorted_complex[:, second_segment, :] - - assert torch.allclose(first_segment_actual, first_segment_expected) - 
assert torch.allclose(second_segment_actual, second_segment_expected) - - -@pytest.mark.parametrize( - "distortion_cls", - [ - IsotropicResolutionReduction, - AnisotropicResolutionReduction, - HannTaperResolutionReduction, - KaiserTaperResolutionReduction, - RadialHighPassEmphasisDistortion, - ], -) -def test_resolution_reduction_classes_inherit_from_self_adjoint_multiplicative_mask( - distortion_cls, -): - """Verify that all resolution-reduction classes are subclasses of the shared super class.""" - assert issubclass(distortion_cls, SelfAdjointMultiplicativeMaskDistortion) - assert issubclass(distortion_cls, BaseDistortion) - - -def test_self_adjoint_multiplicative_mask_distortion_requires_mask_implementation(device): - """Verify that the base super class raises NotImplementedError when _mask is not overridden.""" - - class IncompleteDistortion(SelfAdjointMultiplicativeMaskDistortion): - pass - - distortion = IncompleteDistortion() - y = torch.randn((1, 2, 8, 8), device=device) - - with pytest.raises(NotImplementedError): - distortion.A(y) - - -def test_self_adjoint_multiplicative_mask_distortion_a_adjoint_equals_a(device): - """Verify that A_adjoint equals A for a concrete SelfAdjointMultiplicativeMaskDistortion.""" - distortion = IsotropicResolutionReduction(radius_fraction=0.7) - y = torch.randn((1, 2, 32, 32), device=device) - - assert torch.equal(distortion.A(y), distortion.A_adjoint(y)) - - -def test_cartesian_undersampling_rejects_invalid_keep_fraction(device): - """Verify that keep_fraction must be in (0, 1].""" - with pytest.raises(ValueError, match="keep_fraction must be in"): - CartesianUndersampling(keep_fraction=0.0) - - with pytest.raises(ValueError, match="keep_fraction must be in"): - CartesianUndersampling(keep_fraction=1.5) - - -def test_cartesian_undersampling_rejects_invalid_center_fraction(device): - """Verify that center_fraction must be in [0, 1].""" - with pytest.raises(ValueError, match="center_fraction must be in"): - 
CartesianUndersampling(keep_fraction=0.5, center_fraction=-0.1) - - with pytest.raises(ValueError, match="center_fraction must be in"): - CartesianUndersampling(keep_fraction=0.5, center_fraction=1.5) - - -def test_cartesian_undersampling_rejects_center_fraction_exceeding_keep_fraction(device): - """Verify that center_fraction cannot exceed keep_fraction.""" - with pytest.raises(ValueError, match="center_fraction.*must not exceed"): - CartesianUndersampling(keep_fraction=0.2, center_fraction=0.5) - - -def test_cartesian_undersampling_rejects_invalid_axis(device): - """Verify that axis must be -2 or -3.""" - with pytest.raises(ValueError, match="axis must be -2 or -3"): - CartesianUndersampling(axis=-1) - - with pytest.raises(ValueError, match="axis must be -2 or -3"): - CartesianUndersampling(axis=0) - - -def test_cartesian_undersampling_rejects_invalid_pattern(device): - """Verify that pattern must be one of the supported sampling strategies.""" - with pytest.raises(ValueError, match="pattern must be one of"): - CartesianUndersampling(pattern="spiral") - - -def test_cartesian_undersampling_identity_at_full_keep(device): - """Verify that keep_fraction=1.0 is the identity.""" - distortion = CartesianUndersampling(keep_fraction=1.0) - y = torch.randn((1, 2, 64, 64), device=device) - - y_distorted = distortion.A(y) - - assert torch.equal(y_distorted, y) - - -def test_cartesian_undersampling_default_center_fraction_leaves_peripheral_budget(device): - """Verify that the default ACS fraction is strictly smaller than keep_fraction.""" - distortion = CartesianUndersampling(keep_fraction=0.3) - - assert distortion.center_fraction == pytest.approx(0.15) - assert distortion.center_fraction < distortion.keep_fraction - assert distortion.pattern == "variable_density_random" - - -def test_cartesian_undersampling_achieves_target_keep_rate(device): - """Verify that the mask achieves approximately the target keep_fraction.""" - distortion = 
CartesianUndersampling(keep_fraction=0.25, center_fraction=0.15, seed=42) - shape = (1, 2, 64, 100) # Use different sizes to test phase-encode and readout - mask = distortion._mask(shape, torch.device(device)) - - # Count the number of non-zero elements along the phase-encode axis (-2) - num_ones = torch.sum(mask, dim=-2) # Sum along phase-encode axis - total_lines = shape[-2] - actual_keep_fraction = num_ones[0, 0, 0].item() / total_lines - - # Allow 2-3% tolerance for rounding - assert abs(actual_keep_fraction - 0.25) < 0.03 - - -def test_cartesian_undersampling_preserves_center_acs_region(device): - """Verify that the center ACS region is fully sampled.""" - distortion = CartesianUndersampling(keep_fraction=0.4, center_fraction=0.3, seed=42) - shape = (1, 2, 100, 100) - mask = distortion._mask(shape, torch.device(device)) - - # Extract the 1D mask along phase-encode axis - mask_1d = mask[0, 0, :, 0] - - # Center region should be fully sampled (all ones) - total_lines = shape[-2] - center_lines = int(round(total_lines * 0.3)) - center_start = (total_lines - center_lines) // 2 - center_end = center_start + center_lines - - assert torch.all(mask_1d[center_start:center_end] == 1.0) - - -def test_cartesian_undersampling_zero_center_fraction_has_no_forced_acs(device): - """Verify that center_fraction=0 does not force a contiguous ACS block.""" - distortion = CartesianUndersampling( - keep_fraction=0.5, - center_fraction=0.0, - pattern="equispaced", - ) - shape = (1, 2, 64, 64) - mask_1d = distortion._mask(shape, torch.device(device))[0, 0, :, 0] - - center_slice = mask_1d[30:34] - - assert not torch.all(center_slice == 1.0) - assert torch.sum(mask_1d).item() == 32 - - -def test_cartesian_undersampling_mask_is_deterministic_with_seed(device): - """Verify that the mask is reproducible with the same seed.""" - distortion1 = CartesianUndersampling(keep_fraction=0.3, seed=123) - distortion2 = CartesianUndersampling(keep_fraction=0.3, seed=123) - - shape = (1, 2, 32, 32) - 
mask1 = distortion1._mask(shape, torch.device(device)) - mask2 = distortion2._mask(shape, torch.device(device)) - - assert torch.equal(mask1, mask2) - - -def test_cartesian_undersampling_mask_differs_with_different_seed(device): - """Verify that different seeds produce different masks.""" - # Use a larger keep_fraction so peripheral region has samples to randomize - distortion1 = CartesianUndersampling(keep_fraction=0.5, center_fraction=0.2, seed=123) - distortion2 = CartesianUndersampling(keep_fraction=0.5, center_fraction=0.2, seed=456) - - shape = (1, 2, 64, 64) - mask1 = distortion1._mask(shape, torch.device(device)) - mask2 = distortion2._mask(shape, torch.device(device)) - - # Masks should be different (with very high probability) - assert not torch.equal(mask1, mask2) - - -def test_cartesian_undersampling_variable_density_biases_toward_center(device): - """Verify that variable-density sampling favors lines closer to k-space center.""" - shape = (1, 2, 128, 64) - uniform = CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, - pattern="uniform_random", - seed=7, - ) - variable_density = CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, - pattern="variable_density_random", - seed=7, - ) - - uniform_mask = uniform._mask(shape, torch.device(device))[0, 0, :, 0] - variable_density_mask = variable_density._mask(shape, torch.device(device))[0, 0, :, 0] - center = 0.5 * (shape[-2] - 1) - distances = torch.abs(torch.arange(shape[-2], dtype=torch.float32) - center) - - uniform_peripheral_mean = distances[(uniform_mask == 1.0) & (distances > 0)].mean() - variable_density_peripheral_mean = distances[ - (variable_density_mask == 1.0) & (distances > 0) - ].mean() - - assert variable_density_peripheral_mean < uniform_peripheral_mean - - -def test_cartesian_undersampling_equispaced_pattern_is_seed_independent(device): - """Verify that the equispaced pattern is deterministic and does not depend on seed.""" - shape = (1, 2, 64, 64) - 
distortion1 = CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, - pattern="equispaced", - seed=123, - ) - distortion2 = CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, - pattern="equispaced", - seed=456, - ) - - mask1 = distortion1._mask(shape, torch.device(device)) - mask2 = distortion2._mask(shape, torch.device(device)) - - assert torch.equal(mask1, mask2) - - -def test_cartesian_undersampling_zero_acs_equispaced_half_keep_samples_every_second_line(device): - """Verify equispaced half sampling with no ACS selects every second line.""" - distortion = CartesianUndersampling( - keep_fraction=0.5, - center_fraction=0.0, - pattern="equispaced", - ) - - mask_1d = distortion._mask((1, 2, 64, 64), torch.device(device))[0, 0, :, 0] - sampled_indices = torch.where(mask_1d == 1.0)[0] - - assert torch.equal(sampled_indices, torch.arange(1, 64, 2, device=sampled_indices.device)) - - -def test_cartesian_undersampling_mask_caching(device): - """Verify that the mask is cached and reused for the same shape.""" - distortion = CartesianUndersampling(keep_fraction=0.3, seed=42) - shape = (1, 2, 32, 32) - - mask1 = distortion._mask(shape, torch.device(device)) - mask2 = distortion._mask(shape, torch.device(device)) - - # Should be the same object (cached) - assert mask1.data_ptr() == mask2.data_ptr() - - -def test_cartesian_undersampling_is_self_adjoint(device): - """Verify adjointness property for CartesianUndersampling.""" - distortion = CartesianUndersampling(keep_fraction=0.3, seed=42) - x = torch.randn((1, 2, 32, 32), device=device) - y = torch.randn((1, 2, 32, 32), device=device) - - # Test adjointness: = - lhs = torch.sum(distortion.A(x) * y) - rhs = torch.sum(x * distortion.A_adjoint(y)) - - assert torch.allclose(lhs, rhs, atol=1e-5) - - -def test_cartesian_undersampling_zero_center_fraction_is_self_adjoint(device): - """Verify adjointness still holds when no ACS block is forced.""" - distortion = 
CartesianUndersampling(keep_fraction=0.3, center_fraction=0.0, seed=42) - x = torch.randn((1, 2, 32, 32), device=device) - y = torch.randn((1, 2, 32, 32), device=device) - - lhs = torch.sum(distortion.A(x) * y) - rhs = torch.sum(x * distortion.A_adjoint(y)) - - assert torch.allclose(lhs, rhs, atol=1e-5) - - -def test_cartesian_undersampling_zeros_undersampled_lines(device): - """Verify that undersampled lines (mask=0) result in zero k-space.""" - distortion = CartesianUndersampling(keep_fraction=0.25, center_fraction=0.2, seed=42) - shape = (1, 2, 64, 64) - y = torch.ones(shape, device=device) # All ones - - y_distorted = distortion.A(y) - mask = distortion._mask(shape, torch.device(device)) - - # Expand mask to match y shape for comparison - expanded_mask = mask.expand(shape) - - # Undersampled lines should be zero - zero_mask = expanded_mask == 0 - assert torch.all(y_distorted[zero_mask] == 0.0) - - # Sampled lines should be unchanged - sampled_mask = expanded_mask == 1 - assert torch.all(y_distorted[sampled_mask] == y[sampled_mask]) - - -def test_cartesian_undersampling_works_with_3d_tensor(device): - """Verify that CartesianUndersampling works with 5D tensors (3D MRI).""" - distortion = CartesianUndersampling(keep_fraction=0.3, center_fraction=0.25, seed=42, axis=-3) - y = torch.randn((1, 2, 8, 32, 32), device=device) # 3D k-space - - y_distorted = distortion.A(y) - - assert y_distorted.shape == y.shape - - -def test_cartesian_undersampling_with_distorted_kspace_physics(device): - """Verify that CartesianUndersampling works with DistortedKspaceMultiCoilMRI.""" - distortion = CartesianUndersampling(keep_fraction=0.3, seed=42) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, - img_size=(1, 2, 32, 32), - device=device, - ) - x = torch.randn((1, 2, 32, 32), device=device) - - y = physics.A(x) - - assert y.shape == (1, 2, 32, 32) +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from math 
import sqrt +import pytest +import torch + +from mri_recon.distortions import ( + AnisotropicResolutionReduction, + BaseDistortion, + CartesianUndersampling, + DistortedKspaceMultiCoilMRI, + GaussianKspaceBiasField, + GaussianNoiseDistortion, + HannTaperResolutionReduction, + IsotropicResolutionReduction, + KaiserTaperResolutionReduction, + OffCenterAnisotropicGaussianKspaceBiasField, + PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, + RotationalMotionDistortion, + SegmentedTranslationMotionDistortion, + SelfAdjointMultiplicativeMaskDistortion, + TranslationMotionDistortion, +) + +DISTORTIONS = [ + "None", + "Isotropic LP", + "Anisotropic LP", + "Hann taper LP", + "Kaiser taper LP", + "Cartesian undersampling", + "Radial high-pass emphasis", + "Gaussian bias field", + "Off-center anisotropic Gaussian bias field", + "Phase-encode ghosting", + "Translation motion", + "Rotational motion", + "Segmented translation motion", +] + +EXACT_OPERATOR_DISTORTIONS = { + "None", + "Isotropic LP", + "Anisotropic LP", + "Cartesian undersampling", + "Phase-encode ghosting", + "Translation motion", + "Segmented translation motion", + "Rotational motion", +} +NON_EXPANSIVE_DISTORTIONS = { + "Off-center anisotropic Gaussian bias field", + "Gaussian bias field", +} + + +def choose_distortion(name): + match name: + case "None": + return BaseDistortion() + case "Isotropic LP": + return IsotropicResolutionReduction(radius_fraction=0.6) + case "Anisotropic LP": + return AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.35, + ) + case "Hann taper LP": + return HannTaperResolutionReduction( + radius_fraction=0.6, + transition_fraction=0.25, + ) + case "Kaiser taper LP": + return KaiserTaperResolutionReduction( + radius_fraction=0.6, + transition_fraction=0.25, + beta=8.6, + ) + case "Cartesian undersampling": + return CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.2, + seed=42, + ) + case "Radial high-pass emphasis": + 
return RadialHighPassEmphasisDistortion(alpha=0.4) + case "Gaussian bias field": + return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + case "Off-center anisotropic Gaussian bias field": + return OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.2, + width_y_fraction=0.35, + center_x_fraction=0.15, + center_y_fraction=-0.1, + edge_gain=0.3, + ) + case "Phase-encode ghosting": + return PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + case "Translation motion": + return TranslationMotionDistortion(shift_x_pixels=8.0, shift_y_pixels=4.0) + case "Rotational motion": + return RotationalMotionDistortion(angle_radians=torch.pi / 6) + case "Segmented translation motion": + return SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 2.0, 5.0, 5.0), + shift_y_pixels=(0.0, 1.0, 2.0, 2.0), + ) + case _: + raise ValueError(f"Unknown distortion {name!r}") + + +@pytest.fixture +def device(): + return "cpu" + + +@pytest.mark.parametrize("name", DISTORTIONS) +@pytest.mark.parametrize( + "img_size", [(1, 2, 256, 256), (1, 2, 4, 256, 256)] +) # singlecoil and multicoil +def test_distortion_properties(name, img_size, device): + """ + Test exact operators for adjointness and norm preservation, and verify that + approximate or attenuation operators remain shape-preserving and non-expansive. 
+ """ + distortion = choose_distortion(name) + y = torch.randn(img_size, device=device) + x_dummy = torch.randn(1, 2, *img_size[-2:], device=device) + + if name in EXACT_OPERATOR_DISTORTIONS: + assert distortion.adjointness_test(x_dummy) < 0.01 + assert abs(distortion.compute_norm(x_dummy, squared=False) - 1) < 0.01 + elif name in NON_EXPANSIVE_DISTORTIONS: + y_distorted = distortion.A(x_dummy) + assert y_distorted.shape == x_dummy.shape + assert torch.max(torch.abs(y_distorted)) <= torch.max(torch.abs(x_dummy)) + 1e-6 + + if len(img_size) == 4: # singlecoil + coil_maps = None + elif len(img_size) == 5: # multicoil + coil_maps = torch.ones(1, *img_size[-3:], device=device, dtype=torch.complex64) / sqrt( + img_size[-3] + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, img_size=(1, 2, *y.shape[-2:]), coil_maps=coil_maps, device=device + ) + + if name in EXACT_OPERATOR_DISTORTIONS: + assert physics.adjointness_test(x_dummy) < 0.01 + assert abs(physics.compute_norm(x_dummy, squared=False) - 1) < 0.01 + elif name in NON_EXPANSIVE_DISTORTIONS: + y_physics = physics.A(x_dummy) + y_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + img_size=(1, 2, *y.shape[-2:]), + coil_maps=coil_maps, + device=device, + ).A(x_dummy) + assert torch.max(torch.abs(y_physics)) <= torch.max(torch.abs(y_clean)) + 1e-6 + + +def test_gaussian_noise_distortion_preserves_shape_and_changes_values(device): + distortion = GaussianNoiseDistortion(sigma=0.1) + y = torch.zeros((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + + assert y_distorted.shape == y.shape + assert y_distorted.dtype == y.dtype + assert not torch.equal(y_distorted, y) + + +def test_gaussian_noise_distortion_zero_sigma_is_identity(device): + distortion = GaussianNoiseDistortion(sigma=0.0) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + + assert torch.equal(y_distorted, y) + + +def 
test_anisotropic_resolution_reduction_zeroes_only_filtered_axis(device): + distortion = AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.3, + ) + y = torch.ones((1, 2, 9, 11), device=device) + + y_distorted = distortion.A(y) + mask = distortion._mask(y.shape, y.device) + + assert torch.all(y_distorted == y * mask) + assert torch.all(mask[y.shape[-2] // 2, :] == 1) + assert torch.all(mask[0, :] == 0) + + +def test_anisotropic_resolution_reduction_identity_at_full_cutoffs(device): + distortion = AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=1.0, + ) + y = torch.randn((1, 2, 32, 32), device=device) + + y_distorted = distortion.A(y) + + assert torch.equal(y_distorted, y) + + +def test_hann_taper_resolution_reduction_has_smooth_transition(device): + distortion = HannTaperResolutionReduction( + radius_fraction=0.8, + transition_fraction=0.5, + ) + + mask = distortion._mask((1, 2, 33, 33), torch.device(device)) + + assert mask[16, 16] == pytest.approx(1.0) + assert mask[0, 0] == pytest.approx(0.0) + assert torch.any((mask > 0.0) & (mask < 1.0)) + + +def test_hann_taper_resolution_reduction_zero_transition_matches_hard_cutoff(device): + hard = IsotropicResolutionReduction(radius_fraction=0.6) + smooth = HannTaperResolutionReduction(radius_fraction=0.6, transition_fraction=0.0) + y = torch.randn((1, 2, 64, 64), device=device) + + assert torch.equal(smooth.A(y), hard.A(y)) + + +def test_kaiser_taper_resolution_reduction_has_smooth_transition(device): + distortion = KaiserTaperResolutionReduction( + radius_fraction=0.8, + transition_fraction=0.5, + beta=8.6, + ) + + mask = distortion._mask((1, 2, 33, 33), torch.device(device)) + + assert mask[16, 16] == pytest.approx(1.0) + assert mask[0, 0] == pytest.approx(0.0) + assert torch.any((mask > 0.0) & (mask < 1.0)) + + +def test_kaiser_taper_resolution_reduction_zero_transition_matches_hard_cutoff(device): + hard = IsotropicResolutionReduction(radius_fraction=0.6) + 
smooth = KaiserTaperResolutionReduction( + radius_fraction=0.6, + transition_fraction=0.0, + beta=8.6, + ) + y = torch.randn((1, 2, 64, 64), device=device) + + assert torch.equal(smooth.A(y), hard.A(y)) + + +def test_radial_high_pass_emphasis_distortion_boosts_edges_more_than_center(device): + distortion = RadialHighPassEmphasisDistortion(alpha=0.4) + shape = (1, 2, 33, 33) + center_y = shape[-2] // 2 + center_x = shape[-1] // 2 + + mask = distortion._mask(shape, torch.device(device)) + + assert mask[center_y, center_x] == pytest.approx(1.0) + assert mask[0, 0] == pytest.approx(1.0 + distortion.alpha) + assert torch.all(mask >= 1.0) + assert mask[center_y, center_x + 4] == pytest.approx(1.0) + assert torch.any((mask > 1.0) & (mask < 1.0 + distortion.alpha)) + + +def test_radial_high_pass_emphasis_distortion_zero_alpha_is_identity(device): + distortion = RadialHighPassEmphasisDistortion(alpha=0.0) + y = torch.randn((1, 2, 64, 64), device=device) + + assert torch.equal(distortion.A(y), y) + + +def test_radial_high_pass_emphasis_distortion_respects_custom_band(device): + distortion = RadialHighPassEmphasisDistortion( + alpha=0.4, + boost_start_radius=0.7, + boost_end_radius=0.95, + ) + shape = (1, 2, 65, 65) + center_y = shape[-2] // 2 + center_x = shape[-1] // 2 + + mask = distortion._mask(shape, torch.device(device)) + + assert mask[center_y, center_x + 12] == pytest.approx(1.0) + assert mask[0, 0] == pytest.approx(1.0 + distortion.alpha) + + +def test_centered_isotropic_bias_matches_anisotropic_special_case(device): + centered = GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + anisotropic = OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.35, + width_y_fraction=0.35, + center_x_fraction=0.0, + center_y_fraction=0.0, + edge_gain=0.4, + ) + y = torch.randn((1, 2, 64, 64), device=device) + + y_centered = centered.A(y) + y_anisotropic = anisotropic.A(y) + + assert torch.allclose(y_centered, y_anisotropic) + + +def 
test_translation_motion_zero_shift_is_identity(device): + distortion = TranslationMotionDistortion(shift_x_pixels=0.0, shift_y_pixels=0.0) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + + assert torch.equal(y_distorted, y) + + +def test_translation_motion_preserves_kspace_magnitude(device): + distortion = TranslationMotionDistortion(shift_x_pixels=8.0, shift_y_pixels=4.0) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) + + assert torch.allclose(torch.abs(y_distorted_complex), torch.abs(y_complex)) + + +def test_translation_motion_produces_requested_image_shift(device): + distortion = TranslationMotionDistortion(shift_x_pixels=8.0, shift_y_pixels=4.0) + image = torch.zeros((64, 64), dtype=torch.complex64, device=device) + image[20, 18] = 1.0 + + kspace = torch.fft.fftshift(torch.fft.fft2(image)) + y = torch.view_as_real(kspace).movedim(-1, 0).unsqueeze(0).contiguous() + + y_distorted = distortion.A(y) + kspace_distorted = torch.view_as_complex(y_distorted[0].movedim(0, -1).contiguous()) + image_distorted = torch.fft.ifft2(torch.fft.ifftshift(kspace_distorted)) + max_position = torch.nonzero(torch.abs(image_distorted) == torch.abs(image_distorted).max())[ + 0 + ].tolist() + + assert max_position == [24, 26] + + +def test_translation_motion_rejects_invalid_kspace_shape(device): + distortion = TranslationMotionDistortion(shift_x_pixels=1.0, shift_y_pixels=1.0) + y = torch.randn((1, 64, 64), device=device) + + with pytest.raises(ValueError, match="Expected k-space with shape"): + distortion.A(y) + + +def test_translation_motion_rejects_invalid_channel_dimension(device): + distortion = TranslationMotionDistortion(shift_x_pixels=1.0, shift_y_pixels=1.0) + y = torch.randn((1, 3, 64, 64), device=device) + + with pytest.raises(ValueError, 
match="channel dimension of size 2"): + distortion.A(y) + + +def test_translation_motion_rejects_non_floating_tensor(device): + distortion = TranslationMotionDistortion(shift_x_pixels=1.0, shift_y_pixels=1.0) + y = torch.zeros((1, 2, 64, 64), device=device, dtype=torch.int64) + + with pytest.raises(TypeError, match="floating-point"): + distortion.A(y) + + +def test_rotational_motion_zero_angle_is_identity(device): + distortion = RotationalMotionDistortion(angle_radians=0.0) + y = torch.randn((1, 2, 64, 64), device=device) + + assert torch.equal(distortion.A(y), y) + assert torch.equal(distortion.A_adjoint(y), y) + + +def test_rotational_motion_rejects_invalid_kspace_shape(device): + distortion = RotationalMotionDistortion(angle_radians=torch.pi / 8) + y = torch.randn((1, 64, 64), device=device) + + with pytest.raises(ValueError, match="Expected k-space with shape"): + distortion.A(y) + + +def test_rotational_motion_rejects_non_floating_tensor(device): + distortion = RotationalMotionDistortion(angle_radians=torch.pi / 8) + y = torch.zeros((1, 2, 64, 64), device=device, dtype=torch.int64) + + with pytest.raises(TypeError, match="floating-point"): + distortion.A(y) + + +def test_rotational_motion_rotates_image_content(device): + angle_radians = -0.5 * torch.pi + distortion = RotationalMotionDistortion(angle_radians=angle_radians) + image = torch.zeros((63, 63), dtype=torch.complex64, device=device) + image[31, 40] = 1.0 + + kspace = torch.fft.fftshift(torch.fft.fft2(image)) + y = torch.view_as_real(kspace).movedim(-1, 0).unsqueeze(0).contiguous() + + y_distorted = distortion.A(y) + kspace_distorted = torch.view_as_complex(y_distorted[0].movedim(0, -1).contiguous()) + image_distorted = torch.fft.ifft2(torch.fft.ifftshift(kspace_distorted)) + magnitude = torch.abs(image_distorted) + max_index = magnitude.reshape(-1).argmax() + max_position = torch.tensor(torch.unravel_index(max_index, magnitude.shape), device=device) + + assert torch.equal(max_position, 
torch.tensor([40, 32], device=device)) + + +def test_rotational_motion_uses_matched_adjoint(device): + distortion = RotationalMotionDistortion(angle_radians=torch.pi / 6) + x = torch.randn((1, 2, 64, 64), device=device) + y = torch.randn((1, 2, 64, 64), device=device) + + lhs = torch.sum(distortion.A(x) * y) + rhs = torch.sum(x * distortion.A_adjoint(y)) + + assert torch.allclose(lhs, rhs, atol=1e-4, rtol=1e-4) + + +def test_phase_encode_ghosting_zero_phase_and_unit_scale_is_identity(device): + distortion = PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=0.0, + corrupted_line_scale=1.0, + ) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + + assert torch.equal(y_distorted, y) + + +def test_phase_encode_ghosting_modulates_selected_lines(device): + distortion = PhaseEncodeGhostingDistortion( + line_period=3, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=0.75, + ghost_axis=-2, + ) + y = torch.randn((2, 2, 3, 18, 20), device=device) + + y_distorted = distortion.A(y) + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) + + line_factor = torch.polar( + torch.tensor(0.75, device=device), + torch.tensor(torch.pi / 2, device=device), + ) + for line_index in range(y.shape[-2]): + if (line_index - distortion.line_offset) % distortion.line_period == 0: + expected = y_complex[:, :, line_index, :] * line_factor + else: + expected = y_complex[:, :, line_index, :] + actual = y_distorted_complex[:, :, line_index, :] + assert torch.allclose(actual, expected) + + +def test_phase_encode_ghosting_creates_half_fov_replica_for_partial_alternating_phase(device): + distortion = PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + image = torch.zeros((32, 32), dtype=torch.complex64, device=device) + image[5, 
7] = 1.0 + + kspace = torch.fft.fftshift(torch.fft.fft2(image)) + y = torch.view_as_real(kspace).movedim(-1, 0).unsqueeze(0).contiguous() + + y_distorted = distortion.A(y) + ghosted_kspace = torch.view_as_complex(y_distorted[0].movedim(0, -1).contiguous()) + ghosted_image = torch.fft.ifft2(torch.fft.ifftshift(ghosted_kspace)) + + peak_locations = torch.nonzero(torch.abs(ghosted_image) > 0.3) + + assert peak_locations.shape[0] == 2 + assert [5, 7] in peak_locations.tolist() + assert [21, 7] in peak_locations.tolist() + + +def test_segmented_translation_motion_zero_shift_is_identity(device): + distortion = SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 0.0, 0.0), + shift_y_pixels=(0.0, 0.0, 0.0), + ) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + + assert torch.equal(y_distorted, y) + + +def test_segmented_translation_motion_matches_phase_ramp_per_phase_encode_segment(device): + distortion = SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 1.5, -2.0), + shift_y_pixels=(0.0, 3.0, -1.0), + segment_axis=-2, + ) + y = torch.randn((2, 2, 3, 18, 20), device=device) + + y_distorted = distortion.A(y) + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) + + for segment_slice, shift_x, shift_y in zip( + distortion._segment_slices(y.shape), + distortion.shift_x_pixels, + distortion.shift_y_pixels, + strict=True, + ): + ramp = distortion._phase_ramp(y.shape, y.device, shift_x, shift_y) + expected = y_complex[:, :, segment_slice, :] * ramp[segment_slice, :] + actual = y_distorted_complex[:, :, segment_slice, :] + assert torch.allclose(actual, expected) + + +def test_segmented_translation_motion_matches_phase_ramp_per_readout_segment(device): + distortion = SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 2.0), + shift_y_pixels=(0.0, 1.0), + segment_axis=-1, + ) + y = torch.randn((1, 2, 16, 14), 
device=device) + + y_distorted = distortion.A(y) + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) + + for segment_slice, shift_x, shift_y in zip( + distortion._segment_slices(y.shape), + distortion.shift_x_pixels, + distortion.shift_y_pixels, + strict=True, + ): + ramp = distortion._phase_ramp(y.shape, y.device, shift_x, shift_y) + expected = y_complex[:, :, segment_slice] * ramp[:, segment_slice] + actual = y_distorted_complex[:, :, segment_slice] + assert torch.allclose(actual, expected) + + +def test_segmented_translation_motion_rejects_too_many_segments(device): + distortion = SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 1.0, 2.0, 3.0, 4.0), + shift_y_pixels=(0.0, 1.0, 2.0, 3.0, 4.0), + ) + y = torch.randn((1, 2, 4, 64), device=device) + + with pytest.raises(ValueError, match="non-empty segments"): + distortion.A(y) + + +def test_segmented_translation_motion_keeps_zero_motion_segment_and_modulates_shifted_segment( + device, +): + distortion = SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 4.0), + shift_y_pixels=(0.0, 2.0), + ) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + y_distorted_complex = torch.view_as_complex(y_distorted.movedim(1, -1).contiguous()) + first_segment, second_segment = distortion._segment_slices(y.shape) + + first_segment_expected = y_complex[:, first_segment, :] + first_segment_actual = y_distorted_complex[:, first_segment, :] + + ramp = distortion._phase_ramp( + y.shape, + y.device, + shift_x_pixels=distortion.shift_x_pixels[1], + shift_y_pixels=distortion.shift_y_pixels[1], + ) + second_segment_expected = y_complex[:, second_segment, :] * ramp[second_segment, :] + second_segment_actual = y_distorted_complex[:, second_segment, :] + + assert torch.allclose(first_segment_actual, first_segment_expected) + 
assert torch.allclose(second_segment_actual, second_segment_expected) + + +@pytest.mark.parametrize( + "distortion_cls", + [ + IsotropicResolutionReduction, + AnisotropicResolutionReduction, + HannTaperResolutionReduction, + KaiserTaperResolutionReduction, + RadialHighPassEmphasisDistortion, + ], +) +def test_resolution_reduction_classes_inherit_from_self_adjoint_multiplicative_mask( + distortion_cls, +): + """Verify that all resolution-reduction classes are subclasses of the shared super class.""" + assert issubclass(distortion_cls, SelfAdjointMultiplicativeMaskDistortion) + assert issubclass(distortion_cls, BaseDistortion) + + +def test_self_adjoint_multiplicative_mask_distortion_requires_mask_implementation(device): + """Verify that the base super class raises NotImplementedError when _mask is not overridden.""" + + class IncompleteDistortion(SelfAdjointMultiplicativeMaskDistortion): + pass + + distortion = IncompleteDistortion() + y = torch.randn((1, 2, 8, 8), device=device) + + with pytest.raises(NotImplementedError): + distortion.A(y) + + +def test_self_adjoint_multiplicative_mask_distortion_a_adjoint_equals_a(device): + """Verify that A_adjoint equals A for a concrete SelfAdjointMultiplicativeMaskDistortion.""" + distortion = IsotropicResolutionReduction(radius_fraction=0.7) + y = torch.randn((1, 2, 32, 32), device=device) + + assert torch.equal(distortion.A(y), distortion.A_adjoint(y)) + + +def test_cartesian_undersampling_rejects_invalid_keep_fraction(device): + """Verify that keep_fraction must be in (0, 1].""" + with pytest.raises(ValueError, match="keep_fraction must be in"): + CartesianUndersampling(keep_fraction=0.0) + + with pytest.raises(ValueError, match="keep_fraction must be in"): + CartesianUndersampling(keep_fraction=1.5) + + +def test_cartesian_undersampling_rejects_invalid_center_fraction(device): + """Verify that center_fraction must be in [0, 1].""" + with pytest.raises(ValueError, match="center_fraction must be in"): + 
CartesianUndersampling(keep_fraction=0.5, center_fraction=-0.1) + + with pytest.raises(ValueError, match="center_fraction must be in"): + CartesianUndersampling(keep_fraction=0.5, center_fraction=1.5) + + +def test_cartesian_undersampling_rejects_center_fraction_exceeding_keep_fraction(device): + """Verify that center_fraction cannot exceed keep_fraction.""" + with pytest.raises(ValueError, match="center_fraction.*must not exceed"): + CartesianUndersampling(keep_fraction=0.2, center_fraction=0.5) + + +def test_cartesian_undersampling_rejects_invalid_axis(device): + """Verify that axis must be one of the supported trailing k-space axes.""" + with pytest.raises(ValueError, match="axis must be -1, -2, or -3"): + CartesianUndersampling(axis=0) + + with pytest.raises(ValueError, match="axis must be -1, -2, or -3"): + CartesianUndersampling(axis=-4) + + +def test_cartesian_undersampling_rejects_invalid_pattern(device): + """Verify that pattern must be one of the supported sampling strategies.""" + with pytest.raises(ValueError, match="pattern must be one of"): + CartesianUndersampling(pattern="spiral") + + +def test_cartesian_undersampling_identity_at_full_keep(device): + """Verify that keep_fraction=1.0 is the identity.""" + distortion = CartesianUndersampling(keep_fraction=1.0) + y = torch.randn((1, 2, 64, 64), device=device) + + y_distorted = distortion.A(y) + + assert torch.equal(y_distorted, y) + + +def test_cartesian_undersampling_default_center_fraction_leaves_peripheral_budget(device): + """Verify that the default ACS fraction is strictly smaller than keep_fraction.""" + distortion = CartesianUndersampling(keep_fraction=0.3) + + assert distortion.center_fraction == pytest.approx(0.15) + assert distortion.center_fraction < distortion.keep_fraction + assert distortion.pattern == "variable_density_random" + + +def test_cartesian_undersampling_achieves_target_keep_rate(device): + """Verify that the mask achieves approximately the target keep_fraction.""" + distortion 
= CartesianUndersampling(keep_fraction=0.25, center_fraction=0.15, seed=42) + shape = (1, 2, 64, 100) # Use different sizes to test phase-encode and readout + mask = distortion._mask(shape, torch.device(device)) + + # Count the number of non-zero elements along the phase-encode axis (-2) + num_ones = torch.sum(mask, dim=-2) # Sum along phase-encode axis + total_lines = shape[-2] + actual_keep_fraction = num_ones[0, 0, 0].item() / total_lines + + # Allow 2-3% tolerance for rounding + assert abs(actual_keep_fraction - 0.25) < 0.03 + + +def test_cartesian_undersampling_preserves_center_acs_region(device): + """Verify that the center ACS region is fully sampled.""" + distortion = CartesianUndersampling(keep_fraction=0.4, center_fraction=0.3, seed=42) + shape = (1, 2, 100, 100) + mask = distortion._mask(shape, torch.device(device)) + + # Extract the 1D mask along phase-encode axis + mask_1d = mask[0, 0, :, 0] + + # Center region should be fully sampled (all ones) + total_lines = shape[-2] + center_lines = int(round(total_lines * 0.3)) + center_start = (total_lines - center_lines) // 2 + center_end = center_start + center_lines + + assert torch.all(mask_1d[center_start:center_end] == 1.0) + + +def test_cartesian_undersampling_can_mask_readout_axis(device): + """Verify that axis=-1 masks columns rather than phase-encode rows.""" + distortion = CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.125, + pattern="equispaced", + axis=-1, + ) + shape = (1, 2, 32, 64) + mask = distortion._mask(shape, torch.device(device)) + + assert mask.shape == (1, 1, 1, 64) + assert torch.sum(mask[0, 0, 0]).item() == 16 + assert torch.all(mask.expand(shape)[:, :, 0, :] == mask.expand(shape)[:, :, -1, :]) + + +def test_cartesian_undersampling_zero_center_fraction_has_no_forced_acs(device): + """Verify that center_fraction=0 does not force a contiguous ACS block.""" + distortion = CartesianUndersampling( + keep_fraction=0.5, + center_fraction=0.0, + pattern="equispaced", + ) + 
shape = (1, 2, 64, 64) + mask_1d = distortion._mask(shape, torch.device(device))[0, 0, :, 0] + + center_slice = mask_1d[30:34] + + assert not torch.all(center_slice == 1.0) + assert torch.sum(mask_1d).item() == 32 + + +def test_cartesian_undersampling_mask_is_deterministic_with_seed(device): + """Verify that the mask is reproducible with the same seed.""" + distortion1 = CartesianUndersampling(keep_fraction=0.3, seed=123) + distortion2 = CartesianUndersampling(keep_fraction=0.3, seed=123) + + shape = (1, 2, 32, 32) + mask1 = distortion1._mask(shape, torch.device(device)) + mask2 = distortion2._mask(shape, torch.device(device)) + + assert torch.equal(mask1, mask2) + + +def test_cartesian_undersampling_mask_differs_with_different_seed(device): + """Verify that different seeds produce different masks.""" + # Use a larger keep_fraction so peripheral region has samples to randomize + distortion1 = CartesianUndersampling(keep_fraction=0.5, center_fraction=0.2, seed=123) + distortion2 = CartesianUndersampling(keep_fraction=0.5, center_fraction=0.2, seed=456) + + shape = (1, 2, 64, 64) + mask1 = distortion1._mask(shape, torch.device(device)) + mask2 = distortion2._mask(shape, torch.device(device)) + + # Masks should be different (with very high probability) + assert not torch.equal(mask1, mask2) + + +def test_cartesian_undersampling_variable_density_biases_toward_center(device): + """Verify that variable-density sampling favors lines closer to k-space center.""" + shape = (1, 2, 128, 64) + uniform = CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.125, + pattern="uniform_random", + seed=7, + ) + variable_density = CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.125, + pattern="variable_density_random", + seed=7, + ) + + uniform_mask = uniform._mask(shape, torch.device(device))[0, 0, :, 0] + variable_density_mask = variable_density._mask(shape, torch.device(device))[0, 0, :, 0] + center = 0.5 * (shape[-2] - 1) + distances = 
torch.abs(torch.arange(shape[-2], dtype=torch.float32) - center) + + uniform_peripheral_mean = distances[(uniform_mask == 1.0) & (distances > 0)].mean() + variable_density_peripheral_mean = distances[ + (variable_density_mask == 1.0) & (distances > 0) + ].mean() + + assert variable_density_peripheral_mean < uniform_peripheral_mean + + +def test_cartesian_undersampling_equispaced_pattern_is_seed_independent(device): + """Verify that the equispaced pattern is deterministic and does not depend on seed.""" + shape = (1, 2, 64, 64) + distortion1 = CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.125, + pattern="equispaced", + seed=123, + ) + distortion2 = CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.125, + pattern="equispaced", + seed=456, + ) + + mask1 = distortion1._mask(shape, torch.device(device)) + mask2 = distortion2._mask(shape, torch.device(device)) + + assert torch.equal(mask1, mask2) + + +def test_cartesian_undersampling_zero_acs_equispaced_half_keep_samples_every_second_line(device): + """Verify equispaced half sampling with no ACS selects every second line.""" + distortion = CartesianUndersampling( + keep_fraction=0.5, + center_fraction=0.0, + pattern="equispaced", + ) + + mask_1d = distortion._mask((1, 2, 64, 64), torch.device(device))[0, 0, :, 0] + sampled_indices = torch.where(mask_1d == 1.0)[0] + + assert torch.equal(sampled_indices, torch.arange(1, 64, 2, device=sampled_indices.device)) + + +def test_cartesian_undersampling_mask_caching(device): + """Verify that the mask is cached and reused for the same shape.""" + distortion = CartesianUndersampling(keep_fraction=0.3, seed=42) + shape = (1, 2, 32, 32) + + mask1 = distortion._mask(shape, torch.device(device)) + mask2 = distortion._mask(shape, torch.device(device)) + + # Should be the same object (cached) + assert mask1.data_ptr() == mask2.data_ptr() + + +def test_cartesian_undersampling_is_self_adjoint(device): + """Verify adjointness property for 
CartesianUndersampling.""" + distortion = CartesianUndersampling(keep_fraction=0.3, seed=42) + x = torch.randn((1, 2, 32, 32), device=device) + y = torch.randn((1, 2, 32, 32), device=device) + + # Test adjointness: = + lhs = torch.sum(distortion.A(x) * y) + rhs = torch.sum(x * distortion.A_adjoint(y)) + + assert torch.allclose(lhs, rhs, atol=1e-5) + + +def test_cartesian_undersampling_zero_center_fraction_is_self_adjoint(device): + """Verify adjointness still holds when no ACS block is forced.""" + distortion = CartesianUndersampling(keep_fraction=0.3, center_fraction=0.0, seed=42) + x = torch.randn((1, 2, 32, 32), device=device) + y = torch.randn((1, 2, 32, 32), device=device) + + lhs = torch.sum(distortion.A(x) * y) + rhs = torch.sum(x * distortion.A_adjoint(y)) + + assert torch.allclose(lhs, rhs, atol=1e-5) + + +def test_cartesian_undersampling_zeros_undersampled_lines(device): + """Verify that undersampled lines (mask=0) result in zero k-space.""" + distortion = CartesianUndersampling(keep_fraction=0.25, center_fraction=0.2, seed=42) + shape = (1, 2, 64, 64) + y = torch.ones(shape, device=device) # All ones + + y_distorted = distortion.A(y) + mask = distortion._mask(shape, torch.device(device)) + + # Expand mask to match y shape for comparison + expanded_mask = mask.expand(shape) + + # Undersampled lines should be zero + zero_mask = expanded_mask == 0 + assert torch.all(y_distorted[zero_mask] == 0.0) + + # Sampled lines should be unchanged + sampled_mask = expanded_mask == 1 + assert torch.all(y_distorted[sampled_mask] == y[sampled_mask]) + + +def test_cartesian_undersampling_works_with_3d_tensor(device): + """Verify that CartesianUndersampling works with 5D tensors (3D MRI).""" + distortion = CartesianUndersampling(keep_fraction=0.3, center_fraction=0.25, seed=42, axis=-3) + y = torch.randn((1, 2, 8, 32, 32), device=device) # 3D k-space + + y_distorted = distortion.A(y) + + assert y_distorted.shape == y.shape + + +def 
test_cartesian_undersampling_with_distorted_kspace_physics(device): + """Verify that CartesianUndersampling works with DistortedKspaceMultiCoilMRI.""" + distortion = CartesianUndersampling(keep_fraction=0.3, seed=42) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, + img_size=(1, 2, 32, 32), + device=device, + ) + x = torch.randn((1, 2, 32, 32), device=device) + + y = physics.A(x) + + assert y.shape == (1, 2, 32, 32) From 06baac4db73dea71c6b419438e58da9e2488e533 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Fri, 8 May 2026 13:29:05 +0200 Subject: [PATCH 03/17] Automated SHA256 certified model download for OASIS --- README.md | 14 ++-- examples/OASIS_inference_plot.py | 20 ++---- mri_recon/reconstruction/deep.py | 112 ++++++++++++++++++++++++++++--- mri_recon/utils/__init__.py | 8 ++- mri_recon/utils/io.py | 85 +++++++++++++++++++++-- tests/test_reconstructions.py | 102 ++++++++++++++++++++++++++++ tests/test_utils_io.py | 29 +++++++- uv.lock | 16 +++++ 8 files changed, 349 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 52e7456..4f8adef 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ uv sync uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" ``` -## OASIS Inference Example - -Run the OASIS plotting example with a local OASIS root folder. By default, it uses the packaged split and checkpoint manifest under `reconstruction_only/`. -Download the `reconstruction_only/` folder, including data splits and checkpoints, from [Google Drive](https://drive.google.com/drive/folders/1YPmjiQxy3odiUq8gwYqwOGhoXbGRcSXp?usp=drive_link). - -```bash -python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data --acceleration 4 +## OASIS Inference Example + +Run the OASIS plotting example with a local OASIS root folder. 
On first use, the packaged OASIS checkpoint manifest and the requested checkpoint are downloaded automatically into `reconstruction_only/checkpoints/`. +If you also need the packaged split CSVs, download the `reconstruction_only/` folder from [Google Drive](https://drive.google.com/drive/folders/1YPmjiQxy3odiUq8gwYqwOGhoXbGRcSXp?usp=drive_link). + +```bash +python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data --acceleration 4 ``` Pass `--checkpoint /path/to/checkpoint.ckpt` to use a different trained OASIS U-Net checkpoint. diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index 1523d10..ef1424a 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -7,7 +7,6 @@ from __future__ import annotations import argparse -import json import os import sys from collections import OrderedDict @@ -366,21 +365,10 @@ def resolve_oasis_checkpoint( if checkpoint is not None: return checkpoint.expanduser().resolve() - with manifest_path.open("r", encoding="utf-8") as handle: - manifest = json.load(handle) - - checkpoints = manifest.get("checkpoints", {}) - key = str(acceleration) - if key not in checkpoints: - available = ", ".join(sorted(checkpoints)) - raise ValueError( - f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." 
- ) - - filename = Path(checkpoints[key]["filename"]) - if filename.is_absolute(): - return filename - return (manifest_path.parent.parent / filename).resolve() + return OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( + acceleration=acceleration, + manifest_path=manifest_path, + ) def choose_algorithm( diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 6da8ac1..52823cb 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -1,10 +1,16 @@ +import json from pathlib import Path +from typing import Optional import deepinv as dinv import torch from ._fastmri_unet import Unet -from ..utils import download_file_with_sha256, matches_sha256 +from ..utils import ( + download_file_with_sha256, + download_google_drive_file_with_sha256, + matches_sha256, +) class RAMReconstructor(dinv.models.Reconstructor): @@ -194,15 +200,21 @@ class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): """Wrapper for a trained OASIS single-coil U-Net model. The model reuses the repository's fastMRI-derived :class:`Unet` module, but - loads an OASIS checkpoint supplied by the caller. The forward pass converts + can also download the packaged checkpoint manifest and checkpoint on demand + when no explicit checkpoint path is supplied. The forward pass converts k-space to a zero-filled magnitude image, applies per-slice instance normalization, runs the U-Net, then rescales the prediction back to the adjoint-image intensity range. Parameters ---------- - checkpoint_file : str - Path to the trained OASIS U-Net checkpoint. + checkpoint_file : str, optional + Path to the trained OASIS U-Net checkpoint. If omitted, the reconstructor + downloads the packaged checkpoint for ``acceleration``. + acceleration : int, optional + Packaged checkpoint acceleration factor used when ``checkpoint_file`` is omitted. + manifest_path : str, optional + Override path for the downloaded or cached packaged checkpoint manifest. 
device : torch.device, optional Device on which to run inference. """ @@ -214,10 +226,85 @@ class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): "num_pool_layers": 4, "drop_prob": 0.0, } + ASSET_ROOT = Path(__file__).resolve().parents[2] / "reconstruction_only" + CHECKPOINTS_DIR = ASSET_ROOT / "checkpoints" + MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" + MANIFEST_FILE_ID = "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" + MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" + CHECKPOINT_FILE_IDS = { + "4": "11s6YeM6_YJeD4wcrn24jyMjyj_vX2ANU", + "8": "1w8PDiYpr2xBPXahzRllhZjQT1yoMGXg-", + "10": "1djJ2i0uYP4PT070CS0xx9nNJ41JmSFhh", + } + CHECKPOINT_SHA256 = { + "4": "4fcefa9860cb7895e581a0de8f90bd7f188ae1c0b5e428a4a07519dd2561ac29", + "8": "2cd4c44e3c7a3870adbe5090b2bfaae044f5e3f0b4bcaf2b1fc29969e5e6b9ca", + "10": "90e3d9b17aa0f9fd43aaf090c152edcbdead1b9be41a076594e41098db7befa8", + } + + @classmethod + def ensure_manifest(cls, manifest_path: Optional[Path] = None) -> Path: + """Ensure the packaged OASIS checkpoint manifest exists locally and is verified.""" + + resolved_manifest_path = ( + manifest_path.expanduser().resolve() if manifest_path is not None else cls.MANIFEST_PATH + ) + if not matches_sha256(resolved_manifest_path, cls.MANIFEST_SHA256): + download_google_drive_file_with_sha256( + cls.MANIFEST_FILE_ID, + resolved_manifest_path, + cls.MANIFEST_SHA256, + label="OASIS checkpoint manifest", + ) + return resolved_manifest_path + + @classmethod + def resolve_default_checkpoint( + cls, + acceleration: int, + manifest_path: Optional[Path] = None, + ) -> Path: + """Resolve and download the packaged OASIS checkpoint for a given acceleration.""" + + resolved_manifest_path = cls.ensure_manifest(manifest_path) + with resolved_manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + key = str(acceleration) + checkpoints = manifest.get("checkpoints", {}) + if key not in checkpoints: + available = ", 
".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." + ) + + if key not in cls.CHECKPOINT_FILE_IDS or key not in cls.CHECKPOINT_SHA256: + raise ValueError( + f"No automated download metadata is configured for acceleration {acceleration}." + ) + + filename = Path(checkpoints[key]["filename"]) + checkpoint_path = ( + filename + if filename.is_absolute() + else (resolved_manifest_path.parent.parent / filename) + ).resolve() + + if not matches_sha256(checkpoint_path, cls.CHECKPOINT_SHA256[key]): + download_google_drive_file_with_sha256( + cls.CHECKPOINT_FILE_IDS[key], + checkpoint_path, + cls.CHECKPOINT_SHA256[key], + label=f"OASIS checkpoint x{acceleration}", + ) + + return checkpoint_path def __init__( self, - checkpoint_file: str, + checkpoint_file: str | None = None, + acceleration: int = 4, + manifest_path: str | None = None, device: torch.device = None, ) -> None: super().__init__() @@ -226,9 +313,18 @@ def __init__( device = torch.device("cpu") self.device = device - checkpoint_path = Path(checkpoint_file).expanduser() - if not checkpoint_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + if checkpoint_file is None: + resolved_manifest_path = ( + Path(manifest_path).expanduser() if manifest_path is not None else None + ) + checkpoint_path = self.resolve_default_checkpoint( + acceleration=acceleration, + manifest_path=resolved_manifest_path, + ) + else: + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") # Use the helper to obtain a normalized state_dict (handles plain or Lightning) state_dict = _load_unet_checkpoint_state(checkpoint_path, device) diff --git a/mri_recon/utils/__init__.py b/mri_recon/utils/__init__.py index 61ae0e6..3d17127 100644 --- a/mri_recon/utils/__init__.py +++ b/mri_recon/utils/__init__.py @@ -1,5 +1,11 @@ from .io 
import download_file_with_sha256 as download_file_with_sha256 +from .io import download_google_drive_file_with_sha256 as download_google_drive_file_with_sha256 from .io import format_megabytes as format_megabytes from .io import matches_sha256 as matches_sha256 -__all__ = ["download_file_with_sha256", "format_megabytes", "matches_sha256"] +__all__ = [ + "download_file_with_sha256", + "download_google_drive_file_with_sha256", + "format_megabytes", + "matches_sha256", +] diff --git a/mri_recon/utils/io.py b/mri_recon/utils/io.py index 5aaa962..bde5d0e 100644 --- a/mri_recon/utils/io.py +++ b/mri_recon/utils/io.py @@ -1,6 +1,9 @@ import hashlib +import re import tempfile +from io import BytesIO from pathlib import Path +from urllib.parse import urlencode from urllib.request import urlopen from tqdm.auto import tqdm @@ -21,8 +24,24 @@ def format_megabytes(num_bytes: int) -> str: return f"{num_bytes / (1024 * 1024):.1f} MB" -def download_file_with_sha256( - url: str, +class _BytesResponse: + def __init__(self, payload: bytes): + self._buffer = BytesIO(payload) + self.headers = {"Content-Length": str(len(payload))} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self, size: int = -1) -> bytes: + return self._buffer.read(size) + + +def _download_response_with_sha256( + response_factory, + source: str, destination: Path, expected_sha256: str, *, @@ -30,7 +49,7 @@ def download_file_with_sha256( report_interval_mb: int = 25, ) -> None: destination.parent.mkdir(parents=True, exist_ok=True) - print(f"Downloading {label} from {url} to {destination}. This may take a moment.") + print(f"Downloading {label} from {source} to {destination}. 
This may take a moment.") chunk_size = 1024 * 1024 report_interval = report_interval_mb * 1024 * 1024 @@ -42,7 +61,7 @@ def download_file_with_sha256( ) as handle: tmp_path = Path(handle.name) - with urlopen(url, timeout=30) as response: + with response_factory() as response: total_size = response.headers.get("Content-Length") total_size = int(total_size) if total_size is not None else None with tqdm( @@ -67,3 +86,61 @@ def download_file_with_sha256( if tmp_path is not None: tmp_path.unlink(missing_ok=True) raise + + +def download_file_with_sha256( + url: str, + destination: Path, + expected_sha256: str, + *, + label: str = "file", + report_interval_mb: int = 25, +) -> None: + _download_response_with_sha256( + lambda: urlopen(url, timeout=30), + url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) + + +def download_google_drive_file_with_sha256( + file_id: str, + destination: Path, + expected_sha256: str, + *, + label: str = "file", + report_interval_mb: int = 25, +) -> None: + base_url = "https://drive.usercontent.google.com/download" + request_url = f"{base_url}?{urlencode({'id': file_id, 'export': 'download'})}" + + with urlopen(request_url, timeout=30) as response: + payload = response.read() + + if b"Google Drive - Virus scan warning" not in payload: + _download_response_with_sha256( + lambda: _BytesResponse(payload), + request_url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) + return + + html = payload.decode("utf-8", errors="replace") + uuid_match = re.search(r'name="uuid" value="([^"]+)"', html) + if uuid_match is None: + raise ValueError(f"Could not parse Google Drive confirmation token for file {file_id}.") + + confirmed_url = f"{base_url}?{urlencode({'id': file_id, 'export': 'download', 'confirm': 't', 'uuid': uuid_match.group(1)})}" + download_file_with_sha256( + confirmed_url, + destination, + expected_sha256, + label=label, + 
report_interval_mb=report_interval_mb, + ) diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index c2689ed..9e55c41 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -150,3 +150,105 @@ def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, assert x_hat.shape == x.shape assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) + + +def test_oasis_resolve_default_checkpoint_downloads_manifest_and_checkpoint(tmp_path, monkeypatch): + asset_root = tmp_path / "reconstruction_only" + manifest_path = asset_root / "checkpoints" / "manifest.json" + checkpoint_path = asset_root / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" + manifest_bytes = ( + b'{"checkpoints": {"4": {"filename": "checkpoints/oasis_balanced_seed24_accel4.ckpt"}}}' + ) + checkpoint_bytes = b"fake-oasis-checkpoint" + downloads = [] + + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "MANIFEST_SHA256", + __import__("hashlib").sha256(manifest_bytes).hexdigest(), + ) + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "CHECKPOINT_SHA256", + {"4": __import__("hashlib").sha256(checkpoint_bytes).hexdigest()}, + ) + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "CHECKPOINT_FILE_IDS", + {"4": "checkpoint-file-id"}, + ) + + def fake_download(file_id, destination, expected_sha256, **kwargs): + downloads.append((file_id, destination)) + destination.parent.mkdir(parents=True, exist_ok=True) + if file_id == "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I": + destination.write_bytes(manifest_bytes) + elif file_id == "checkpoint-file-id": + destination.write_bytes(checkpoint_bytes) + else: + raise AssertionError(f"Unexpected file_id: {file_id}") + + monkeypatch.setattr( + "mri_recon.reconstruction.deep.download_google_drive_file_with_sha256", + fake_download, + ) + + resolved = OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( + 4, manifest_path=manifest_path + ) + + assert resolved == 
checkpoint_path.resolve() + assert manifest_path.read_bytes() == manifest_bytes + assert checkpoint_path.read_bytes() == checkpoint_bytes + assert downloads == [ + ("1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I", manifest_path.resolve()), + ("checkpoint-file-id", checkpoint_path.resolve()), + ] + + +def test_oasis_singlecoil_unet_reconstructor_uses_packaged_checkpoint_defaults( + device, tmp_path, monkeypatch +): + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "UNET_KWARGS", + { + "in_chans": 1, + "out_chans": 1, + "chans": 8, + "num_pool_layers": 2, + "drop_prob": 0.0, + }, + ) + + weights_path = tmp_path / "oasis_unet_packaged_weights.ckpt" + state_dict = Unet(**OASISSinglecoilUnetReconstructor.UNET_KWARGS).state_dict() + torch.save( + {"state_dict": {f"unet.{key}": value for key, value in state_dict.items()}}, + weights_path, + ) + + captured = {} + + def fake_resolve(cls, acceleration, manifest_path=None): + captured["acceleration"] = acceleration + captured["manifest_path"] = manifest_path + return weights_path + + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "resolve_default_checkpoint", + classmethod(fake_resolve), + ) + + model = OASISSinglecoilUnetReconstructor( + acceleration=8, + manifest_path=str(tmp_path / "manifest.json"), + device=device, + ) + + assert captured == { + "acceleration": 8, + "manifest_path": tmp_path / "manifest.json", + } + assert isinstance(model.model, Unet) diff --git a/tests/test_utils_io.py b/tests/test_utils_io.py index b984872..3d2c434 100644 --- a/tests/test_utils_io.py +++ b/tests/test_utils_io.py @@ -1,7 +1,7 @@ import hashlib from io import BytesIO -from mri_recon.utils.io import download_file_with_sha256 +from mri_recon.utils.io import download_file_with_sha256, download_google_drive_file_with_sha256 class FakeResponse: @@ -39,3 +39,30 @@ def test_download_file_with_sha256_moves_temp_file_after_close(tmp_path, monkeyp assert destination.read_bytes() == payload assert list(tmp_path.glob("*.tmp")) == [] + + 
+def test_download_google_drive_file_with_sha256_confirms_large_download(tmp_path, monkeypatch): + warning_html = b"""
Google Drive - Virus scan warning""" + payload = b"oasis-checkpoint" + expected_sha256 = hashlib.sha256(payload).hexdigest() + destination = tmp_path / "oasis.ckpt" + requested_urls = [] + + def fake_urlopen(url, timeout=30): + requested_urls.append(url) + if "confirm=t" in url: + return FakeResponse(payload) + return FakeResponse(warning_html) + + monkeypatch.setattr("mri_recon.utils.io.urlopen", fake_urlopen) + + download_google_drive_file_with_sha256( + "file-123", + destination, + expected_sha256, + label="OASIS checkpoint", + report_interval_mb=1, + ) + + assert destination.read_bytes() == payload + assert any("confirm=t" in url and "uuid=uuid-456" in url for url in requested_urls) diff --git a/uv.lock b/uv.lock index 7034d30..d25083a 100644 --- a/uv.lock +++ b/uv.lock @@ -947,6 +947,7 @@ dependencies = [ { name = "fastmri" }, { name = "h5py" }, { name = "matplotlib" }, + { name = "nibabel" }, { name = "numpy" }, { name = "ptwt" }, { name = "pydicom" }, @@ -976,6 +977,7 @@ requires-dist = [ { name = "fastmri", specifier = ">=0.3.0" }, { name = "h5py", specifier = ">=3.16.0" }, { name = "matplotlib", specifier = ">=3.9.0" }, + { name = "nibabel", specifier = ">=5.3.2" }, { name = "numpy", specifier = ">=2.4.3" }, { name = "ptwt", specifier = ">=1.0.1" }, { name = "pydicom", specifier = ">=3.0.1" }, @@ -1117,6 +1119,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] +[[package]] +name = "nibabel" +version = "5.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/a3/01/3d2cc510c616bc8e27be17a063070d9126f69407961594a9ae734ea51121/nibabel-5.4.2.tar.gz", hash = "sha256:d5f4b9076a13178ae7f7acf18c8dbd503ee1c4d5c0c23b85df7be87efcbb49da", size = 4663132, upload-time = "2026-03-11T13:31:52.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/d7/601b6396b33536811668935faa790112266c70661be94555999be431f86f/nibabel-5.4.2-py3-none-any.whl", hash = "sha256:553482c5f1e1034fc312edf6fb7f32236c0056439845d1c29293b7e8c98d4854", size = 3300985, upload-time = "2026-03-11T13:31:50.028Z" }, +] + [[package]] name = "nodeenv" version = "1.10.0" From 06de9faa8ced32988e4ce3321916710b1c44c389 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Fri, 8 May 2026 14:24:25 +0200 Subject: [PATCH 04/17] added common download storage path for all models --- README.md | 2 +- examples/OASIS_inference_plot.py | 4 +++- mri_recon/reconstruction/deep.py | 7 ++++--- tests/test_reconstructions.py | 6 +++--- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 123ab10..95d91f4 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available ## OASIS Inference Example -Run the OASIS plotting example with a local OASIS root folder. On first use, the packaged OASIS checkpoint manifest and the requested checkpoint are downloaded automatically into `reconstruction_only/checkpoints/`. +Run the OASIS plotting example with a local OASIS root folder. On first use, the packaged OASIS checkpoint manifest and the requested checkpoint are downloaded automatically into `downloads/oasis_singlecoil_unet/checkpoints/`. If you also need the packaged split CSVs, download the `reconstruction_only/` folder from [Google Drive](https://drive.google.com/drive/folders/1YPmjiQxy3odiUq8gwYqwOGhoXbGRcSXp?usp=drive_link). 
```bash diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index ef1424a..ac806ba 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -61,7 +61,9 @@ REPO_ROOT = Path(__file__).resolve().parents[1] REPORT_DIR = Path("reports") / "oasis_inference_plot" DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" -DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" +DEFAULT_MANIFEST_PATH = ( + REPO_ROOT / "downloads" / "oasis_singlecoil_unet" / "checkpoints" / "manifest.json" +) REPORT_DIR.mkdir(parents=True, exist_ok=True) diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 52823cb..cb286c3 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -114,6 +114,7 @@ class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): ) MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "fastmri_singlecoil_unet" UNET_KWARGS = { "in_chans": 1, "out_chans": 1, @@ -134,7 +135,7 @@ def __init__(self, device: torch.device = None, state_dict_file: str = None) -> state_dict_path = ( Path(state_dict_file).expanduser() if state_dict_file is not None - else Path(__file__).resolve().parents[2] / self.MODEL_FILENAME + else self.MODEL_DIR / self.MODEL_FILENAME ) if state_dict_file is None: @@ -226,8 +227,8 @@ class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): "num_pool_layers": 4, "drop_prob": 0.0, } - ASSET_ROOT = Path(__file__).resolve().parents[2] / "reconstruction_only" - CHECKPOINTS_DIR = ASSET_ROOT / "checkpoints" + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "oasis_singlecoil_unet" + CHECKPOINTS_DIR = MODEL_DIR / "checkpoints" MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" MANIFEST_FILE_ID = 
"1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index 9e55c41..36117ff 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -153,9 +153,9 @@ def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, def test_oasis_resolve_default_checkpoint_downloads_manifest_and_checkpoint(tmp_path, monkeypatch): - asset_root = tmp_path / "reconstruction_only" - manifest_path = asset_root / "checkpoints" / "manifest.json" - checkpoint_path = asset_root / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" + model_dir = tmp_path / "downloads" / "oasis_singlecoil_unet" + manifest_path = model_dir / "checkpoints" / "manifest.json" + checkpoint_path = model_dir / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" manifest_bytes = ( b'{"checkpoints": {"4": {"filename": "checkpoints/oasis_balanced_seed24_accel4.ckpt"}}}' ) From 4ea5aca38345d660739577d3a6dcf116fa9204e7 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 08:48:17 +0200 Subject: [PATCH 05/17] added detailed explanations related to resolution reduction --- README.md | 10 +- mri_recon/distortions/resolution.py | 165 +++++++++++++++++++------ mri_recon/distortions/undersampling.py | 62 ++++++---- 3 files changed, 170 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 95d91f4..94da93d 100644 --- a/README.md +++ b/README.md @@ -25,11 +25,11 @@ MRI reconstruction playground for the MRI Metrics project. | --- | --- | --- | --- | | `BaseDistortion` | `None` | Identity | Leaves the k-space unchanged and serves as the no-distortion baseline. | | `SelfAdjointMultiplicativeMaskDistortion` | `None` | Abstract base | Super class for self-adjoint distortions that apply a real-valued elementwise multiplicative mask; subclasses implement `_mask`. 
| -| `IsotropicResolutionReduction` | `Isotropic LP` | Resolution loss | Applies a circular low-pass mask in k-space to remove high frequencies isotropically. | -| `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Applies an axis-aligned rectangular low-pass mask with separate cutoffs along `kx` and `ky`. | -| `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Applies a circular low-pass mask with a raised-cosine transition band to soften the cutoff. | -| `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Applies a circular low-pass mask with a Kaiser transition band for adjustable cutoff smoothness. | -| `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates Cartesian acquisition undersampling with optional contiguous ACS center retention plus uniform-random, variable-density-random, or equispaced peripheral sampling. | +| `IsotropicResolutionReduction` | `Isotropic LP` | Resolution loss | Simulates isotropic in-plane MRI resolution reduction at fixed field of view by retaining only a centered circular k-space core; smaller `radius_fraction` removes more high frequencies and produces stronger blur. | +| `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Simulates reduced in-plane Cartesian MRI resolution at fixed field of view by keeping only a centered rectangular k-space region, with separate control of retained readout (`kx_radius_fraction`) and phase-encode (`ky_radius_fraction`) support. | +| `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Simulates isotropic resolution reduction with a Hann-tapered circular k-space cutoff; `radius_fraction` sets the retained support and `transition_fraction` softens the edge to reduce ringing relative to a hard cutoff. 
| +| `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Simulates isotropic resolution reduction with a Kaiser-tapered circular k-space cutoff; `radius_fraction` sets the retained support, `transition_fraction` sets the taper width, and `beta` controls taper steepness. | +| `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates sub-Nyquist Cartesian MRI by skipping lines along one encoding direction while optionally preserving a fully sampled ACS center; unlike resolution reduction, it preserves the original k-space extent and mainly introduces aliasing or incoherent undersampling artifacts rather than simple blur. | | `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | | `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. | | `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an off-center anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`. | diff --git a/mri_recon/distortions/resolution.py b/mri_recon/distortions/resolution.py index 1011abc..399a037 100644 --- a/mri_recon/distortions/resolution.py +++ b/mri_recon/distortions/resolution.py @@ -73,14 +73,33 @@ def _smooth_radial_low_pass_mask( class IsotropicResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Low-pass truncation with a circular mask. - - This applies - ``M_out(kx, ky) = M(kx, ky) * 1[r(kx, ky) <= K]`` - where ``K`` is ``radius_fraction`` on the normalized radial grid. - - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are set to zero. + """Isotropic in-plane resolution reduction with a circular hard cutoff. 
+ + This distortion keeps only a centered circular region of k-space: + ``M_out(kx, ky) = M(kx, ky) * 1[r(kx, ky) <= K]``, where ``K`` is the + retained radial support on the normalized frequency grid. + + In MRI terms, this models reduced in-plane spatial resolution at fixed + field of view by removing high-frequency content equally in all in-plane + directions. The reconstructed image keeps the same matrix size, but fine + detail is lost because the maximum sampled k-space extent is smaller. The + resulting effect is isotropic blur from limited k-space support. + + The parameter ``radius_fraction`` controls how much of the centered + low-frequency region is retained. Smaller values preserve only the k-space + core and therefore produce stronger blur and a broader point-spread + function. A value of ``1.0`` keeps the full sampled support and recovers + the identity operator. + + Compared with :class:`AnisotropicResolutionReduction`, this class applies + the same reduction in all directions rather than separately along readout + and phase encode. Compared with :class:`CartesianUndersampling`, it models + resolution loss by shrinking the sampled support, not by skipping lines + within the original support. + + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Smaller values correspond to + stronger isotropic resolution reduction. """ def __init__(self, radius_fraction: float = 0.6) -> None: @@ -96,17 +115,40 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class AnisotropicResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Axis-aligned low-pass truncation with independent cutoffs. - - This applies a rectangular mask - ``M_out(kx, ky) = M(kx, ky) * 1[|kx| <= Kx] * 1[|ky| <= Ky]`` - where ``Kx`` and ``Ky`` are the normalized cutoffs along the readout and - phase-encode frequency axes respectively. 
- - :param float kx_radius_fraction: Normalized cutoff along the horizontal - frequency axis in ``(0, 1]``. - :param float ky_radius_fraction: Normalized cutoff along the vertical - frequency axis in ``(0, 1]``. + """Axis-aligned Cartesian resolution reduction with independent cutoffs. + + This distortion keeps only a centered rectangular region of Cartesian + k-space and zeros the remaining outer frequencies: + ``M_out(kx, ky) = M(kx, ky) * 1[|kx| <= Kx] * 1[|ky| <= Ky]``. + + In MRI terms, this models reduced in-plane acquisition resolution at fixed + field of view. The reconstructed image keeps the same matrix size, but its + effective spatial resolution decreases because less high-frequency k-space + support is retained. The resulting artifact is blur from limited k-space + extent, not aliasing from undersampling. + + The two parameters control the retained centered k-space support along the + Cartesian encoding axes. ``kx_radius_fraction`` corresponds to the retained + readout-direction frequency extent, while ``ky_radius_fraction`` + corresponds to the retained phase-encode-direction frequency extent. + Reducing either parameter broadens the point-spread function along the + corresponding image direction. A typical MRI-like setting keeps + ``kx_radius_fraction`` close to ``1.0`` and reduces + ``ky_radius_fraction``, reflecting that protocols often sacrifice more + phase-encode resolution than readout resolution. + + This is a hard rectangular cutoff. If a softer edge is desired, use one of + the taper-based resolution distortions instead. + + :param float kx_radius_fraction: Fraction of the original centered + k-space extent retained along the horizontal frequency axis in + ``(0, 1]``. This corresponds to the retained readout-direction + resolution support. ``1.0`` keeps the full sampled readout extent. + :param float ky_radius_fraction: Fraction of the original centered + k-space extent retained along the vertical frequency axis in + ``(0, 1]``. 
This corresponds to the retained phase-encode-direction + resolution support. Smaller values produce stronger blur along the + corresponding image direction. """ def __init__( @@ -126,8 +168,8 @@ def __init__( def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: normalized_kx, normalized_ky = _normalized_axis_frequencies(shape) - # A rectangular passband models direction-dependent resolution loss, - # such as stronger truncation along phase encode than along readout. + # Centered rectangular support models reduced acquired Cartesian + # resolution, often stronger along phase encode than along readout. mask = (normalized_kx <= self.kx_radius_fraction) & ( normalized_ky <= self.ky_radius_fraction ) @@ -135,18 +177,41 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class HannTaperResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Radial low-pass reduction with a Hann transition band. - - The mask equals ``1`` in the low-frequency passband, tapers smoothly to - ``0`` with a raised-cosine profile near the cutoff, and is exactly ``0`` - beyond ``radius_fraction``. + """Isotropic resolution reduction with a Hann-tapered radial cutoff. + + This distortion keeps a centered circular low-frequency region and tapers + the mask smoothly to zero near the cutoff using a raised-cosine Hann + profile. Inside the passband the mask equals ``1``; inside the transition + band it decreases smoothly from ``1`` to ``0``; beyond the cutoff it is + exactly ``0``. + + In MRI terms, this is still a resolution-reduction operator: it suppresses + high spatial frequencies and therefore lowers effective in-plane spatial + resolution at fixed field of view. Relative to + :class:`IsotropicResolutionReduction`, the main difference is not the type + of resolution loss but the edge behavior of the k-space support. 
The smooth + transition reduces ringing that a hard truncation can introduce, at the + cost of making the support edge less sharp. + + ``radius_fraction`` sets the outer radial extent of the retained support. + ``transition_fraction`` sets how much of that outer region is devoted to + the smooth taper. Setting ``transition_fraction=0`` recovers the hard + circular cutoff. Larger transition fractions make the cutoff gentler and + behave more like k-space apodization. + + Compared with :class:`CartesianUndersampling`, this class does not skip + phase-encode lines within the original support. It reduces resolution by + attenuating and removing high frequencies, leading primarily to blur rather + than undersampling aliasing. See https://en.wikipedia.org/wiki/Hann_function for details on the Hann window. - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are fully suppressed. + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Frequencies outside this radius + are fully suppressed. :param float transition_fraction: Fraction of the cutoff radius occupied by - the smooth transition in ``[0, 1]``. ``0`` recovers the hard cutoff. + the smooth Hann transition in ``[0, 1]``. ``0`` recovers the hard + circular cutoff. """ def __init__( @@ -173,20 +238,42 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class KaiserTaperResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Radial low-pass reduction with a Kaiser transition band. - - The mask equals ``1`` in the low-frequency passband, tapers smoothly to - ``0`` with a Kaiser-profile transition near the cutoff, and is exactly - ``0`` beyond ``radius_fraction``. + """Isotropic resolution reduction with a Kaiser-tapered radial cutoff. 
+ + This distortion keeps a centered circular low-frequency region and tapers + the mask smoothly to zero near the cutoff using a Kaiser-window profile. + Inside the passband the mask equals ``1``; inside the transition band it + decreases smoothly from ``1`` to ``0``; beyond the cutoff it is exactly + ``0``. + + In MRI terms, this lowers effective in-plane spatial resolution at fixed + field of view by reducing the retained high-frequency k-space extent. As + with :class:`HannTaperResolutionReduction`, the purpose of the taper is to + soften the hard support edge and reduce ringing, while still producing blur + from lost high-frequency information. + + ``radius_fraction`` sets the outer radial extent of the retained support. + ``transition_fraction`` controls the width of the taper band. ``beta`` + controls the shape of the Kaiser taper within that band: larger values + produce a steeper transition, while smaller positive values produce a more + gradual roll-off. Setting ``transition_fraction=0`` recovers the hard + circular cutoff regardless of ``beta``. + + Compared with :class:`CartesianUndersampling`, this class reduces + resolution by shrinking and tapering the effective k-space support, rather + than by skipping lines from the original grid. The dominant image effect is + blur and apodization, not aliasing from sub-Nyquist sampling. See https://en.wikipedia.org/wiki/Kaiser_window for details on the Kaiser window. - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are fully suppressed. + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Frequencies outside this radius + are fully suppressed. :param float transition_fraction: Fraction of the cutoff radius occupied by - the smooth transition in ``[0, 1]``. ``0`` recovers the hard cutoff. - :param float beta: Positive Kaiser shape parameter. 
Larger values create a - steeper transition inside the taper band. + the smooth Kaiser transition in ``[0, 1]``. ``0`` recovers the hard + circular cutoff. + :param float beta: Positive Kaiser shape parameter controlling how steeply + the taper falls inside the transition band. """ def __init__( diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index afa6c39..e92316e 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -17,32 +17,48 @@ class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): - """Cartesian k-space undersampling along phase-encode direction. - - This distortion simulates true MRI acquisition undersampling by applying a - binary sampling mask along the phase-encode direction (default). The mask - keeps a contiguous center region (ACS - Auto-Calibration Signal) fully - sampled and randomly undersamples the periphery. - - The mask is deterministic given a shape and seed, ensuring reproducibility. - - :param float keep_fraction: Fraction of phase-encode lines to keep in - ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. - :param float center_fraction: Fraction of phase-encode lines reserved for - the contiguous, fully-sampled ACS region in ``[0, 1]``. Defaults to - ``0.5 * keep_fraction``, leaving part of the acquisition budget for - randomized peripheral line sampling. Set to ``0`` to sample without a - guaranteed ACS block. + """Cartesian k-space undersampling along one encoding direction. + + This distortion simulates sub-Nyquist MRI acquisition by keeping only a + subset of Cartesian k-space lines along a chosen axis, phase encode by + default. A contiguous low-frequency center region (ACS) may be retained + fully, while the peripheral lines are sampled using a configurable random + or equispaced pattern. + + This is fundamentally different from resolution reduction. 
In + resolution-reduction distortions, the maximum retained k-space extent is + reduced, which primarily causes blur because high spatial frequencies are + absent. In Cartesian undersampling, the original k-space extent is still + targeted, but many lines inside that extent are skipped. The dominant + consequence is aliasing or incoherent undersampling artifact, not a simple + broader point-spread function. + + ``keep_fraction`` controls the total sampling budget along the chosen axis. + ``center_fraction`` reserves part of that budget for a fully sampled ACS + block near k-space center, which is important for many parallel imaging and + learned reconstruction pipelines. ``pattern`` controls how the remaining + peripheral lines are selected. ``axis`` determines which Cartesian encoding + direction is undersampled. ``seed`` makes the random patterns reproducible. + + The mask is deterministic for a given shape, device, and seed. + + :param float keep_fraction: Fraction of lines kept along the undersampled + axis in ``(0, 1]``. For example, ``0.25`` keeps 25% of the lines and + corresponds to approximately 4x acceleration when applied to a single + encoding direction. + :param float center_fraction: Fraction of lines along the undersampled axis + reserved for a contiguous, fully sampled ACS region in ``[0, 1]``. + Defaults to ``0.5 * keep_fraction`` so that part of the sampling budget + remains available for peripheral sampling. ``0`` disables the ACS block. :param str pattern: Peripheral sampling pattern. Supported values are ``"uniform_random"``, ``"variable_density_random"``, and - ``"equispaced"``. Defaults to ``"variable_density_random"``. - :param int axis: Axis along which to apply undersampling. Default is -2 - (phase-encode for 4D tensors), can also be -1 for readout/column - masking or -3 for the slice/depth axis in 5D tensors. + ``"equispaced"``. 
Variable-density sampling favors low-frequency lines + near the ACS boundary; equispaced sampling is deterministic. + :param int axis: Axis along which to undersample. The default ``-2`` + corresponds to phase encode for the repository's standard 2D k-space + convention. Other values allow undersampling along readout or depth. :param int | None seed: Random seed for reproducible mask generation. - If None, uses unseeded randomness (not recommended for reproducibility). - Has no effect when ``pattern="equispaced"`` because that pattern is - fully deterministic and uses no randomness. + Ignored by the deterministic ``"equispaced"`` pattern. """ def __init__( From 997cccce7a1ab4db3722477ee92148fa1edfe056 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 09:01:43 +0200 Subject: [PATCH 06/17] added detailed explanations related to bias field --- README.md | 2 +- mri_recon/distortions/biasfield.py | 36 +++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 94da93d..225d1a2 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ MRI reconstruction playground for the MRI Metrics project. | `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates sub-Nyquist Cartesian MRI by skipping lines along one encoding direction while optionally preserving a fully sampled ACS center; unlike resolution reduction, it preserves the original k-space extent and mainly introduces aliasing or incoherent undersampling artifacts rather than simple blur. | | `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | | `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. 
| -| `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an off-center anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`. | +| `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`; when centered at DC with `width_x_fraction < width_y_fraction`, it can also approximate a smooth readout-decay-like blur by attenuating high frequencies more strongly along one axis. | | `GaussianNoiseDistortion` | `Gaussian noise` | Noise | Adds independent zero-mean Gaussian noise to the stored real and imaginary k-space channels. | | `TranslationMotionDistortion` | `Translation motion` | Motion | Applies a rigid in-plane translation as a unit-modulus phase ramp in k-space. | | `RotationalMotionDistortion` | `Rotational motion` | Motion | Applies a rigid in-plane rotation about the image center by resampling centered Cartesian k-space. | diff --git a/mri_recon/distortions/biasfield.py b/mri_recon/distortions/biasfield.py index 26ca4f5..a9d7910 100644 --- a/mri_recon/distortions/biasfield.py +++ b/mri_recon/distortions/biasfield.py @@ -80,11 +80,37 @@ class OffCenterAnisotropicGaussianKspaceBiasField(BaseDistortion): The gain peaks at an offset location in k-space, decays with different widths along ``kx`` and ``ky``, and is normalized to unit maximum on the sampled grid. - :param float width_x_fraction: Gaussian width along the normalized ``kx`` direction. - :param float width_y_fraction: Gaussian width along the normalized ``ky`` direction. - :param float center_x_fraction: Center offset along normalized ``kx`` in ``[-1, 1]``. - :param float center_y_fraction: Center offset along normalized ``ky`` in ``[-1, 1]``. - :param float edge_gain: Baseline gain far from the Gaussian peak. Must lie in ``(0, 1]``. 
+ Note: This class can also approximate a readout-decay-like blur when used in a + centered anisotropic configuration. To mimic stronger attenuation along the + readout direction, keep the Gaussian centered at DC with + ``center_x_fraction=0.0`` and ``center_y_fraction=0.0``, choose a narrower + width along the readout axis than along the orthogonal axis, and reduce + ``edge_gain`` below ``1.0``. For example, a setting such as + ``width_x_fraction < width_y_fraction`` with a moderate ``edge_gain`` + produces a smooth directional loss of high-frequency content that can + resemble readout-decay blur. + + This remains a phenomenological approximation rather than an explicit + time-ordered readout-decay model. It applies a smooth multiplicative + k-space weighting, not a physically parameterized echo-train decay. + + :param float width_x_fraction: Gaussian width along the normalized ``kx`` + direction. Smaller values produce stronger attenuation away from the + center along ``kx``. When ``kx`` is the readout axis, choosing + ``width_x_fraction < width_y_fraction`` approximates readout-direction + blur. + :param float width_y_fraction: Gaussian width along the normalized ``ky`` + direction. Larger values preserve more support along ``ky`` relative + to ``kx``. + :param float center_x_fraction: Center offset along normalized ``kx`` in + ``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like + approximation. + :param float center_y_fraction: Center offset along normalized ``ky`` in + ``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like + approximation. + :param float edge_gain: Baseline gain far from the Gaussian peak. Must lie + in ``(0, 1]``. Smaller values strengthen the peripheral attenuation + and therefore the resulting directional blur. 
""" def __init__( From 5e5b69ce1c7e34527ae4b67e2be3a1948c98a93f Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 13:25:36 +0200 Subject: [PATCH 07/17] partial fourier --- README.md | 1 + examples/OASIS_inference_plot.py | 9 ++ examples/fastmri_inference_plot.py | 8 ++ mri_recon/distortions/__init__.py | 2 +- mri_recon/distortions/undersampling.py | 124 +++++++++++++++++++++++++ tests/test_distortions.py | 56 +++++++++++ 6 files changed, 199 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 225d1a2..c7aa9a6 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ MRI reconstruction playground for the MRI Metrics project. | `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Simulates isotropic resolution reduction with a Hann-tapered circular k-space cutoff; `radius_fraction` sets the retained support and `transition_fraction` softens the edge to reduce ringing relative to a hard cutoff. | | `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Simulates isotropic resolution reduction with a Kaiser-tapered circular k-space cutoff; `radius_fraction` sets the retained support, `transition_fraction` sets the taper width, and `beta` controls taper steepness. | | `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates sub-Nyquist Cartesian MRI by skipping lines along one encoding direction while optionally preserving a fully sampled ACS center; unlike resolution reduction, it preserves the original k-space extent and mainly introduces aliasing or incoherent undersampling artifacts rather than simple blur. 
| +| `PartialFourierDistortion` | `Partial Fourier` | Acquisition asymmetry | Simulates partial Fourier MRI acquisition by retaining a contiguous asymmetric region of k-space along one encoding axis; unlike symmetric resolution reduction it preserves one side more fully than the other, and unlike sparse undersampling it keeps a contiguous support rather than skipping lines throughout the original extent. | | `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | | `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. | | `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`; when centered at DC with `width_x_fraction < width_y_fraction`, it can also approximate a smooth readout-decay-like blur by attenuating high frequencies more strongly along one axis. 
| diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index ac806ba..512dc33 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -40,6 +40,7 @@ IsotropicResolutionReduction, KaiserTaperResolutionReduction, OffCenterAnisotropicGaussianKspaceBiasField, + PartialFourierDistortion, PhaseEncodeGhostingDistortion, RadialHighPassEmphasisDistortion, RotationalMotionDistortion, @@ -86,6 +87,7 @@ "Cartesian undersampling (uniform random, zero ACS)", "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", + "Partial Fourier", "Phase-encode ghosting", "Segmented translation motion", "Translation motion", @@ -458,6 +460,13 @@ def choose_distortion( axis=-1, seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-1, + side="high", + ) case "Phase-encode ghosting": return PhaseEncodeGhostingDistortion( line_period=2, diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index e8f23ef..fe52c07 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -37,6 +37,7 @@ "Cartesian undersampling (uniform random, zero ACS)", "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", + "Partial Fourier", "Phase-encode ghosting", "Segmented translation motion", "Segmented rotational motion", @@ -181,6 +182,13 @@ def choose_distortion(name: str) -> BaseDistortion: pattern="equispaced", seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) case "Anisotropic LP": return AnisotropicResolutionReduction( kx_radius_fraction=1.0, diff --git a/mri_recon/distortions/__init__.py b/mri_recon/distortions/__init__.py index 80bbad4..af22fd1 100644 --- a/mri_recon/distortions/__init__.py +++ b/mri_recon/distortions/__init__.py @@ -19,4 +19,4 @@ 
KaiserTaperResolutionReduction, RadialHighPassEmphasisDistortion, ) -from .undersampling import CartesianUndersampling +from .undersampling import CartesianUndersampling, PartialFourierDistortion diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index e92316e..6c0be94 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -8,6 +8,7 @@ PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} +PARTIAL_FOURIER_SIDES = {"low", "high"} # Lorentzian offset that controls how steeply the variable-density weights # fall off with normalized distance from k-space center. @@ -289,3 +290,126 @@ def _expand_mask_to_shape( mask = mask_1d.reshape(expand_shape) return mask + + +class PartialFourierDistortion(SelfAdjointMultiplicativeMaskDistortion): + """Asymmetric contiguous Cartesian mask for partial Fourier acquisition. + + This distortion simulates partial Fourier MRI acquisition by keeping a + contiguous asymmetric region of k-space along one encoding axis while + preserving a centered low-frequency block. Unlike symmetric resolution + reduction, the retained support extends farther on one side of k-space than + the other. Unlike Cartesian undersampling, the retained region is + contiguous rather than sparse throughout the original support. + + The distortion models the acquired k-space mask only. It does not attempt + to reconstruct or infer the missing region with homodyne, POCS, or any + other partial-Fourier-specific reconstruction method. + + :param float partial_fraction: Fraction of lines retained along the chosen + axis in ``[0.5, 1]``. ``1.0`` recovers the identity operator. + :param float center_fraction: Fraction of lines reserved for a centered, + fully retained low-frequency block in ``[0, 1]``. This block must not + exceed ``partial_fraction``. + :param int axis: Axis along which to apply the asymmetric truncation. 
The + default ``-2`` matches the repository's standard phase-encode axis. + :param str side: Side that retains more support outside the centered block. + Supported values are ``"low"`` and ``"high"``. + """ + + def __init__( + self, + partial_fraction: float = 0.7, + center_fraction: float = 0.1, + axis: int = -2, + side: str = "high", + ) -> None: + super().__init__() + + if not 0.5 <= partial_fraction <= 1.0: + raise ValueError(f"partial_fraction must be in [0.5, 1], got {partial_fraction}") + if not 0.0 <= center_fraction <= 1.0: + raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") + if center_fraction > partial_fraction: + raise ValueError( + f"center_fraction ({center_fraction}) must not exceed " + f"partial_fraction ({partial_fraction})" + ) + if axis not in (-1, -2, -3): + raise ValueError(f"axis must be -1, -2, or -3, got {axis}") + if side not in PARTIAL_FOURIER_SIDES: + raise ValueError(f"side must be one of {sorted(PARTIAL_FOURIER_SIDES)}, got {side!r}") + + self.partial_fraction = partial_fraction + self.center_fraction = center_fraction + self.axis = axis + self.side = side + self._cached_mask = None + self._cached_shape = None + self._cached_device: torch.device | None = None + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + """Generate a deterministic partial Fourier mask. + + :param tuple[int, ...] shape: k-space tensor shape. + :param torch.device device: Device for the mask. + :returns: Binary mask broadcastable to ``shape``. 
+ :rtype: torch.Tensor + """ + if ( + self._cached_mask is not None + and self._cached_shape == shape + and self._cached_device == device + ): + return self._cached_mask + + axis_size = shape[self.axis] + mask_1d = self._generate_1d_mask(axis_size) + mask = self._expand_mask_to_shape(mask_1d, shape).to(device) + + self._cached_mask = mask + self._cached_shape = shape + self._cached_device = device + return mask + + def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: + """Generate a 1D contiguous asymmetric partial Fourier mask.""" + num_keep = max(1, int(round(axis_size * self.partial_fraction))) + num_center = int(round(axis_size * self.center_fraction)) + num_center = min(num_center, num_keep) + + mask = torch.zeros(axis_size, dtype=torch.float32) + + center_start = (axis_size - num_center) // 2 + center_end = center_start + num_center + remaining = num_keep - num_center + + low_available = center_start + high_available = axis_size - center_end + + if self.side == "high": + extra_high = min(remaining, high_available) + extra_low = remaining - extra_high + else: + extra_low = min(remaining, low_available) + extra_high = remaining - extra_low + + start = center_start - extra_low + end = center_end + extra_high + + mask[start:end] = 1.0 + return mask + + def _expand_mask_to_shape( + self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] 
+ ) -> torch.Tensor: + """Expand a 1D mask to the full k-space shape.""" + ndim = len(target_shape) + axis = self.axis if self.axis >= 0 else ndim + self.axis + + expand_shape = list(target_shape) + for i in range(len(expand_shape)): + if i != axis: + expand_shape[i] = 1 + + return mask_1d.reshape(expand_shape) diff --git a/tests/test_distortions.py b/tests/test_distortions.py index 846b5ba..7e884f6 100644 --- a/tests/test_distortions.py +++ b/tests/test_distortions.py @@ -18,6 +18,7 @@ IsotropicResolutionReduction, KaiserTaperResolutionReduction, OffCenterAnisotropicGaussianKspaceBiasField, + PartialFourierDistortion, PhaseEncodeGhostingDistortion, RadialHighPassEmphasisDistortion, RotationalMotionDistortion, @@ -34,6 +35,7 @@ "Hann taper LP", "Kaiser taper LP", "Cartesian undersampling", + "Partial Fourier", "Radial high-pass emphasis", "Gaussian bias field", "Off-center anisotropic Gaussian bias field", @@ -49,6 +51,7 @@ "Isotropic LP", "Anisotropic LP", "Cartesian undersampling", + "Partial Fourier", "Phase-encode ghosting", "Translation motion", "Segmented translation motion", @@ -91,6 +94,13 @@ def choose_distortion(name): center_fraction=0.2, seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) case "Radial high-pass emphasis": return RadialHighPassEmphasisDistortion(alpha=0.4) case "Gaussian bias field": @@ -229,6 +239,52 @@ def test_anisotropic_resolution_reduction_identity_at_full_cutoffs(device): assert torch.equal(y_distorted, y) +def test_partial_fourier_distortion_identity_at_full_fraction(device): + distortion = PartialFourierDistortion( + partial_fraction=1.0, + center_fraction=0.1, + axis=-2, + side="high", + ) + y = torch.randn((1, 2, 32, 32), device=device) + + assert torch.equal(distortion.A(y), y) + + +def test_partial_fourier_distortion_retains_contiguous_asymmetric_region(device): + distortion = PartialFourierDistortion( + partial_fraction=0.7, + 
center_fraction=0.1, + axis=-2, + side="high", + ) + + retained_indices = distortion._generate_1d_mask(16).nonzero().flatten().tolist() + + assert retained_indices == list(range(5, 16)) + + +def test_partial_fourier_distortion_side_changes_retained_half(device): + high = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) + low = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="low", + ) + + high_indices = high._generate_1d_mask(16).nonzero().flatten().tolist() + low_indices = low._generate_1d_mask(16).nonzero().flatten().tolist() + + assert high_indices == list(range(5, 16)) + assert low_indices == list(range(0, 11)) + + def test_hann_taper_resolution_reduction_has_smooth_transition(device): distortion = HannTaperResolutionReduction( radius_fraction=0.8, From a753ce3e53de5be65f75655c5f2fccc493ae7210 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 14:05:21 +0200 Subject: [PATCH 08/17] updaed readme --- README.md | 46 +++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index c7aa9a6..489ef1e 100644 --- a/README.md +++ b/README.md @@ -21,25 +21,33 @@ MRI reconstruction playground for the MRI Metrics project. ## Implemented Distortions -| Class | Selector name | Family | Summary | -| --- | --- | --- | --- | -| `BaseDistortion` | `None` | Identity | Leaves the k-space unchanged and serves as the no-distortion baseline. | -| `SelfAdjointMultiplicativeMaskDistortion` | `None` | Abstract base | Super class for self-adjoint distortions that apply a real-valued elementwise multiplicative mask; subclasses implement `_mask`. 
| -| `IsotropicResolutionReduction` | `Isotropic LP` | Resolution loss | Simulates isotropic in-plane MRI resolution reduction at fixed field of view by retaining only a centered circular k-space core; smaller `radius_fraction` removes more high frequencies and produces stronger blur. | -| `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Simulates reduced in-plane Cartesian MRI resolution at fixed field of view by keeping only a centered rectangular k-space region, with separate control of retained readout (`kx_radius_fraction`) and phase-encode (`ky_radius_fraction`) support. | -| `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Simulates isotropic resolution reduction with a Hann-tapered circular k-space cutoff; `radius_fraction` sets the retained support and `transition_fraction` softens the edge to reduce ringing relative to a hard cutoff. | -| `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Simulates isotropic resolution reduction with a Kaiser-tapered circular k-space cutoff; `radius_fraction` sets the retained support, `transition_fraction` sets the taper width, and `beta` controls taper steepness. | -| `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Simulates sub-Nyquist Cartesian MRI by skipping lines along one encoding direction while optionally preserving a fully sampled ACS center; unlike resolution reduction, it preserves the original k-space extent and mainly introduces aliasing or incoherent undersampling artifacts rather than simple blur. | -| `PartialFourierDistortion` | `Partial Fourier` | Acquisition asymmetry | Simulates partial Fourier MRI acquisition by retaining a contiguous asymmetric region of k-space along one encoding axis; unlike symmetric resolution reduction it preserves one side more fully than the other, and unlike sparse undersampling it keeps a contiguous support rather than skipping lines throughout the original extent. 
| -| `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. | -| `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Applies a centered smooth multiplicative Gaussian gain field in k-space. | -| `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Applies an anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`; when centered at DC with `width_x_fraction < width_y_fraction`, it can also approximate a smooth readout-decay-like blur by attenuating high frequencies more strongly along one axis. | -| `GaussianNoiseDistortion` | `Gaussian noise` | Noise | Adds independent zero-mean Gaussian noise to the stored real and imaginary k-space channels. | -| `TranslationMotionDistortion` | `Translation motion` | Motion | Applies a rigid in-plane translation as a unit-modulus phase ramp in k-space. | -| `RotationalMotionDistortion` | `Rotational motion` | Motion | Applies a rigid in-plane rotation about the image center by resampling centered Cartesian k-space. | -| `SegmentedRotationalMotionDistortion` | `Segmented rotational motion` | Motion | Splits Cartesian k-space into acquisition segments and stitches segment-specific centered k-space rotations into one inconsistent scan. | -| `SegmentedTranslationMotionDistortion` | `Segmented translation motion` | Motion | Splits Cartesian k-space into acquisition segments and applies a different translation phase ramp to each segment. | -| `PhaseEncodeGhostingDistortion` | `Phase-encode ghosting` | Ghosting | Applies periodic line-wise phase and magnitude inconsistency to create phase-encode ghost replicas. 
| + +## Implemented Distortions + +| Class | Selector name | Family | Targeted Image Property | Summary | +| --- | --- | --- | --- | --- | +| `BaseDistortion` | `None` | Identity | | Leaves the k-space unchanged and serves as the no-distortion baseline. | +| `SelfAdjointMultiplicativeMaskDistortion` | `None` | Abstract base | | Super class for self-adjoint distortions that apply a real-valued elementwise multiplicative mask; subclasses implement `_mask`. | +| `IsotropicResolutionReduction` | `Isotropic LP` | Resolution loss | Sharpness (Glancing), Edges (Scanning) | Applies a circular low-pass mask in k-space to remove high frequencies isotropically. | +| `AnisotropicResolutionReduction` | `Anisotropic LP` | Resolution loss | Sharpness (Glancing), Edges (Scanning) | Applies an axis-aligned rectangular low-pass mask with separate cutoffs along `kx` and `ky`. | +| `HannTaperResolutionReduction` | `Hann taper LP` | Resolution loss | Sharpness (Glancing), Edges (Scanning), RoI Homogeneity (Glancing) | Applies a circular low-pass mask with a raised-cosine transition band to soften the cutoff. | +| `KaiserTaperResolutionReduction` | `Kaiser taper LP` | Resolution loss | Sharpness (Glancing), Edges (Scanning), RoI Homogeneity (Glancing) | Applies a circular low-pass mask with a Kaiser transition band for adjustable cutoff smoothness. | +| `CartesianUndersampling` | `Cartesian undersampling` | Acquisition undersampling | Local Signal Preservation (Scanning) | Simulates Cartesian acquisition undersampling with optional contiguous ACS center retention plus uniform-random, variable-density-random, or equispaced peripheral sampling. | +| `RadialHighPassEmphasisDistortion` | `Radial high-pass emphasis` | Sharpening | Sharpness (Glancing), Edges (Scanning), Noise Level (Glancing) | Applies a radial gain mask that increasingly boosts high-frequency k-space content toward the sampled edge. 
|
+| `PartialFourierDistortion` | `Partial Fourier` | Acquisition asymmetry | Pixel Resolution | Simulates partial Fourier MRI acquisition by retaining a contiguous asymmetric region of k-space along one encoding axis; unlike symmetric resolution reduction it preserves one side more fully than the other, and unlike sparse undersampling it keeps a contiguous support rather than skipping lines throughout the original extent. |
+| `GaussianKspaceBiasField` | `Gaussian bias field` | Intensity non-uniformity | Intensity Uniformity (Glancing/Full Image) | Applies a centered smooth multiplicative Gaussian gain field in k-space. |
+| `OffCenterAnisotropicGaussianKspaceBiasField` | `Off-center anisotropic Gaussian bias field` | Intensity non-uniformity | Intensity Uniformity (Glancing/Full Image) | Applies an off-center anisotropic Gaussian gain field in k-space with separate widths along `kx` and `ky`. |
+| `GaussianNoiseDistortion` | `Gaussian noise` | Noise | Noise Level (Glancing) | Adds independent zero-mean Gaussian noise to the stored real and imaginary k-space channels. |
+| `TranslationMotionDistortion` | `Translation motion` | Motion | Local Signal Preservation (Scanning) | Applies a rigid in-plane translation as a unit-modulus phase ramp in k-space. |
+| `RotationalMotionDistortion` | `Rotational motion` | Motion | Local Signal Preservation (Scanning) | Applies a rigid in-plane rotation about the image center by resampling centered Cartesian k-space. |
+| `SegmentedRotationalMotionDistortion` | `Segmented rotational motion` | Motion | Local Signal Preservation (Scanning) | Splits Cartesian k-space into acquisition segments and stitches segment-specific centered k-space rotations into one inconsistent scan. |
+| `SegmentedTranslationMotionDistortion` | `Segmented translation motion` | Motion | Local Signal Preservation (Scanning) | Splits Cartesian k-space into acquisition segments and applies a different translation phase ramp to each segment. 
| +| `PhaseEncodeGhostingDistortion` | `Phase-encode ghosting` | Ghosting | Local Signal Preservation (Scanning) | Applies periodic line-wise phase and magnitude inconsistency to create phase-encode ghost replicas. | + + +## Distortions Possibly Planned +- Rician Noise (Noise Level (Glancing)) +- Modification of sensitivity map (RoI Homogeneity (Glancing)) ## uv Environment Notes From c56665b105bb3d39de3f592c6e9f2dce1b2c66e4 Mon Sep 17 00:00:00 2001 From: Yuning Du Date: Wed, 6 May 2026 16:48:40 +0100 Subject: [PATCH 09/17] Add OASIS U-Net inference example --- README.md | 11 + examples/OASIS_inference_plot.py | 623 +++++++++++++++++++++++++++ mri_recon/reconstruction/__init__.py | 7 +- mri_recon/reconstruction/deep.py | 103 +++++ pyproject.toml | 1 + tests/test_reconstructions.py | 32 ++ 6 files changed, 776 insertions(+), 1 deletion(-) create mode 100644 examples/OASIS_inference_plot.py diff --git a/README.md b/README.md index 091d989..1590e9a 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Reconstruction playground for the MRI Recon Metrics Reloaded workgroup. | `RAMReconstructor` | `ram` | Deep learning | Wrapper around the DeepInverse RAM model, with input normalization based on the adjoint reconstruction. | | `DeepImagePriorReconstructor` | `dip` | Deep learning | Deep Image Prior reconstruction using an untrained convolutional decoder optimized at inference time. | | `FastMRISinglecoilUnetReconstructor` | `unet` | Deep learning | Wrapper around the pretrained fastMRI single-coil U-Net, returning a magnitude-based reconstruction with a zero imaginary channel. | +| `OASISSinglecoilUnetReconstructor` | `oasis-unet` | Deep learning | Wrapper around a trained OASIS single-coil U-Net checkpoint, reusing the shared fastMRI-derived U-Net module. 
| ## Implemented Distortions @@ -56,6 +57,16 @@ uv sync uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" ``` +## OASIS Inference Example + +Run the OASIS plotting example with a local OASIS root folder. By default, it uses the packaged split and checkpoint manifest under `reconstruction_only/`. + +```bash +python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data --acceleration 4 +``` + +Pass `--checkpoint /path/to/checkpoint.ckpt` to use a different trained OASIS U-Net checkpoint. + ## Pre-commit Install the local tooling and register the git hook: diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py new file mode 100644 index 0000000..7d6643e --- /dev/null +++ b/examples/OASIS_inference_plot.py @@ -0,0 +1,623 @@ +"""Inference OASIS reconstructors for k-space distortion operators. + +Usage: + python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data +""" + +from __future__ import annotations + +import argparse +import contextlib +import json +import os +import sys +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Sequence, Union + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +import numpy as np +import torch +import deepinv as dinv +from torch.utils.data import DataLoader, Dataset + +try: + import nibabel as nib +except ImportError as exc: + raise ImportError( + "The OASIS example requires nibabel. Install the project dependencies " + "or add nibabel to your environment before running this script." 
+ ) from exc + +from mri_recon.distortions import ( + AnisotropicResolutionReduction, + BaseDistortion, + DistortedKspaceMultiCoilMRI, + GaussianKspaceBiasField, + GaussianNoiseDistortion, + HannTaperResolutionReduction, + IsotropicResolutionReduction, + KaiserTaperResolutionReduction, + OffCenterAnisotropicGaussianKspaceBiasField, + PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, + RotationalMotionDistortion, + SegmentedTranslationMotionDistortion, + TranslationMotionDistortion, +) +from mri_recon.reconstruction import ( + ConjugateGradientReconstructor, + DeepImagePriorReconstructor, + OASISSinglecoilUnetReconstructor, + RAMReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + ZeroFilledReconstructor, +) + +REPO_ROOT = Path(__file__).resolve().parents[1] +REPORT_DIR = Path("reports") / "oasis_inference_plot" +DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" +DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" + +REPORT_DIR.mkdir(parents=True, exist_ok=True) + +ALGORITHMS = [ + # "zero-filled", + # "conjugate-gradient", + # "ram", + # "dip", + "tv-pgd", + # "wavelet-fista", + # "tv-fista", + # "tv-pdhg", + "oasis-unet", +] +DISTORTIONS = [ + # "Phase-encode ghosting", + # "Segmented translation motion", + # "Translation motion", + # "Rotational motion", + # "Off-center anisotropic Gaussian bias field", + # "Gaussian bias field", + # "Anisotropic LP", + # "Hann taper LP", + # "Kaiser taper LP", + "Radial high-pass emphasis", + # "Gaussian noise", + # "Isotropic LP", +] +METRICS = [ + "PSNR", + "NMSE", + "SSIM", + "HaarPSI", + "SharpnessIndex", + "BlurStrength", +] + + +@contextlib.contextmanager +def temp_seed( + rng: np.random.RandomState, + seed: Optional[Union[int, tuple[int, ...]]] = None, +): + """Temporarily set a NumPy random seed.""" + + if seed is None: + yield + return + + state = 
rng.get_state() + rng.seed(seed) + try: + yield + finally: + rng.set_state(state) + + +class MaskFunc: + """Random Cartesian undersampling mask matching the packaged OASIS checkpoints.""" + + def __init__( + self, + center_fractions: Sequence[float], + accelerations: Sequence[int], + seed: Optional[int] = None, + ) -> None: + if len(center_fractions) != len(accelerations): + raise ValueError("center_fractions and accelerations must have the same length.") + + self.center_fractions = list(center_fractions) + self.accelerations = list(accelerations) + self.rng = np.random.RandomState(seed) + + def __call__( + self, + shape: Sequence[int], + seed: Optional[Union[int, tuple[int, ...]]] = None, + ) -> torch.Tensor: + """Create a broadcastable mask for k-space shaped ``(..., H, W)``.""" + + if len(shape) < 2: + raise ValueError("Mask shape must have at least two dimensions.") + + with temp_seed(self.rng, seed): + center_fraction = self.rng.choice(self.center_fractions) + acceleration = self.rng.choice(self.accelerations) + num_cols = shape[-1] + num_low_freqs = round(num_cols * center_fraction) + + center_mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs) // 2 + center_mask[pad : pad + num_low_freqs] = 1 + + accel_prob = (num_cols / acceleration - num_low_freqs) / ( + num_cols - num_low_freqs + ) + accel_mask = self.rng.uniform(size=num_cols) < accel_prob + + mask = np.maximum(center_mask, accel_mask.astype(np.float32)) + mask_shape = [1 for _ in shape] + mask_shape[-1] = num_cols + return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + +class OasisSliceDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes listed in a split CSV.""" + + def __init__( + self, + split_csv: Path, + data_path: Path, + sample_rate: float = 1.0, + cache_size: int = 2, + ) -> None: + self.split_csv = Path(split_csv) + self.data_path = Path(data_path) + if not 0 < sample_rate <= 1.0: + raise ValueError("sample_rate must be in the 
range (0, 1].") + self.sample_rate = sample_rate + self.cache_size = max(0, cache_size) + self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() + self.raw_samples = self._create_sample_list() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.raw_samples) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + subject_id, slice_num = self.raw_samples[index] + target_np = self._read_raw_slice(subject_id, slice_num) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _create_sample_list(self) -> list[tuple[str, int]]: + samples: list[tuple[str, int]] = [] + with self.split_csv.open("r", encoding="utf-8") as handle: + for line in handle: + row = [item.strip() for item in line.split(",")] + if not row or not row[0]: + continue + try: + total_slices = int(row[-1]) + except ValueError: + continue + + subject_id = row[0] + if self.sample_rate >= 1.0: + start = 0 + stop = total_slices + else: + mid = round(total_slices / 2) + half_span = round(total_slices * self.sample_rate / 2) + start = max(0, mid - half_span) + stop = min(total_slices, mid + half_span) + + for slice_num in range(start, stop): + samples.append((subject_id, slice_num)) + return samples + + def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: + volume = self._get_volume(subject_id) + return np.ascontiguousarray(volume[slice_num], dtype=np.float32) + + def _get_volume(self, subject_id: str) -> np.ndarray: + if self.cache_size > 0 and subject_id in self._volume_cache: + self._volume_cache.move_to_end(subject_id) + return self._volume_cache[subject_id] + + image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if not matches: + raise FileNotFoundError( 
+ f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." + ) + + image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + + if self.cache_size > 0: + self._volume_cache[subject_id] = volume + if len(self._volume_cache) > self.cache_size: + self._volume_cache.popitem(last=False) + + return volume + + +def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: + """Convert k-space tensor to log-magnitude image for visualization.""" + + if kspace.ndim == 4: + kspace = kspace[0] + if kspace.ndim != 3 or kspace.shape[0] != 2: + raise ValueError( + f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" + ) + + kspace = kspace.detach().cpu() + kspace_complex = torch.view_as_complex(kspace.permute(1, 2, 0).contiguous()) + magnitude = torch.log1p(torch.abs(kspace_complex)) + + lower = torch.quantile(magnitude, 0.05) + upper = torch.quantile(magnitude, 0.995) + if float(upper) > float(lower): + magnitude = magnitude.clamp(lower, upper) + magnitude = (magnitude - lower) / (upper - lower) + else: + mag_max = float(magnitude.max()) + if mag_max > 0.0: + magnitude = magnitude / mag_max + + return torch.sqrt(magnitude) + + +def save_kspace_plot( + y_clean: torch.Tensor, + y_distorted: torch.Tensor, + save_fn: Path, + distortion_name: str, +) -> None: + """Save clean and distorted k-space magnitude plots.""" + + images = [ + ("Original k-space", _kspace_to_log_magnitude(y_clean)), + ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), + ] + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + fig.suptitle(f"Distortion: {distortion_name}") + for ax, (title, image) in zip(axes, images, strict=True): + ax.imshow(image.numpy(), cmap="magma") + ax.set_title(title) + ax.axis("off") + fig.savefig(save_fn, dpi=200, bbox_inches="tight") + plt.close(fig) + + +def 
resolve_oasis_checkpoint( + checkpoint: Optional[Path], + acceleration: int, + manifest_path: Path, +) -> Path: + """Resolve an explicit or packaged OASIS checkpoint path.""" + + if checkpoint is not None: + return checkpoint.expanduser().resolve() + + with manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + checkpoints = manifest.get("checkpoints", {}) + key = str(acceleration) + if key not in checkpoints: + available = ", ".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." + ) + + filename = Path(checkpoints[key]["filename"]) + if filename.is_absolute(): + return filename + return (manifest_path.parent.parent / filename).resolve() + + +def choose_algorithm( + name: str, + checkpoint_file: Path, + img_size: tuple = (640, 368), + device: torch.device = "cpu", + verbose: bool = False, +) -> dinv.models.Reconstructor: + """Construct a reconstructor by selector name.""" + + match name: + case "zero-filled": + return ZeroFilledReconstructor() + case "conjugate-gradient": + return ConjugateGradientReconstructor(max_iter=20) + case "ram": + return RAMReconstructor(default_sigma=0.05, device=device) + case "dip": + return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) + case "tv-pgd": + return TVPGDReconstructor(n_iter=100, verbose=verbose) + case "tv-fista": + return TVFISTAReconstructor(n_iter=200, verbose=verbose) + case "tv-pdhg": + return TVPDHGReconstructor(n_iter=100, verbose=verbose) + case "wavelet-fista": + return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) + case "oasis-unet" | "unet": + return OASISSinglecoilUnetReconstructor( + checkpoint_file=str(checkpoint_file), + device=device, + ) + case _: + raise ValueError(f"Unknown algorithm {name!r}") + + +def choose_distortion(name: str) -> BaseDistortion: + """Construct a k-space distortion by display name.""" + + match name: + case 
"Phase-encode ghosting": + return PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + case "Anisotropic LP": + return AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.25, + ) + case "Hann taper LP": + return HannTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + ) + case "Kaiser taper LP": + return KaiserTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + beta=8.6, + ) + case "Radial high-pass emphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) + case "Isotropic LP": + return IsotropicResolutionReduction(radius_fraction=0.1) + case "Off-center anisotropic Gaussian bias field": + return OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.2, + width_y_fraction=0.35, + center_x_fraction=0.15, + center_y_fraction=-0.1, + edge_gain=0.3, + ) + case "Translation motion": + return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) + case "Rotational motion": + return RotationalMotionDistortion(angle_radians=torch.pi / 6) + case "Segmented translation motion": + return SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 20.0, 50.0, -50.0), + shift_y_pixels=(0.0, 10.0, -20.0, 20.0), + ) + case "Gaussian bias field": + return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + case "Gaussian noise": + return GaussianNoiseDistortion(sigma=0.00001) + case _: + raise ValueError(f"Unknown distortion {name!r}") + + +def choose_metric(name: str) -> dinv.metric.Metric: + """Construct a DeepInverse metric by selector name.""" + + match name: + case "PSNR": + return dinv.metric.PSNR(max_pixel=None, complex_abs=True) + case "NMSE": + return dinv.metric.NMSE(complex_abs=True) + case "SSIM": + return dinv.metric.SSIM(max_pixel=None, complex_abs=True) + case "HaarPSI": + return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) + case "BlurStrength": + 
return dinv.metric.BlurStrength(complex_abs=True) + case "SharpnessIndex": + return dinv.metric.SharpnessIndex(complex_abs=True) + case _: + raise ValueError(f"Unknown metric {name!r}") + + +def build_parser() -> argparse.ArgumentParser: + """Build the command-line parser.""" + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--source", + type=Path, + required=True, + help="OASIS root directory containing subject folders.", + ) + parser.add_argument( + "--split_csv", + type=Path, + default=DEFAULT_SPLIT_CSV, + help="CSV listing OASIS subjects and slice counts.", + ) + parser.add_argument( + "--manifest", + type=Path, + default=DEFAULT_MANIFEST_PATH, + help="Checkpoint manifest JSON.", + ) + parser.add_argument( + "--checkpoint", + type=Path, + default=None, + help="Explicit OASIS U-Net checkpoint. Overrides --acceleration.", + ) + parser.add_argument( + "--acceleration", + type=int, + default=4, + help="Packaged OASIS checkpoint acceleration factor.", + ) + parser.add_argument( + "--center_fraction", + type=float, + default=0.08, + help="Center fraction used by the random Cartesian sampling mask.", + ) + parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) + parser.add_argument( + "--algorithm", + type=str, + default="", + choices=ALGORITHMS, + help="Reconstruction algorithm applied to distorted OASIS k-space.", + ) + parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") + parser.add_argument( + "--sample_rate", + type=float, + default=1.0, + help="Fraction of slices per volume to include from the split CSV.", + ) + parser.add_argument("--volume_cache_size", type=int, default=2) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose output for reconstructors that support it.", + ) + return parser + + +def main() -> None: + """Run OASIS inference plots.""" + + args = build_parser().parse_args() + args.source = 
args.source.expanduser().resolve() + args.split_csv = args.split_csv.expanduser().resolve() + args.manifest = args.manifest.expanduser().resolve() + checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) + + device = dinv.utils.get_device() + dataset = OasisSliceDataset( + split_csv=args.split_csv, + data_path=args.source, + sample_rate=args.sample_rate, + cache_size=args.volume_cache_size, + ) + dataloader = DataLoader(dataset, batch_size=1, shuffle=False) + mask_func = MaskFunc( + center_fractions=[args.center_fraction], + accelerations=[args.acceleration], + ) + metrics = [choose_metric(m) for m in METRICS] + + for i, batch in enumerate(iter(dataloader)): + if i >= args.num_samples: + break + + x = batch["x"].to(device) + subject_id = batch["subject_id"][0] + slice_num = int(batch["slice_num"][0]) + mask = mask_func(x.shape, seed=tuple(map(ord, subject_id))).to(device) + mask_2d = mask.reshape(-1).view(1, -1).expand(x.shape[-2], x.shape[-1]) + + for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: + distortion = choose_distortion(distortion_name) + + physics_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + mask=mask_2d, + img_size=(1, 2, *x.shape[-2:]), + coil_maps=1, + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, + mask=mask_2d, + img_size=(1, 2, *x.shape[-2:]), + coil_maps=1, + device=device, + ) + + y = physics_clean(x) + y_distorted = physics(x) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}..." 
+ ) + + algo = choose_algorithm( + algo_name, + checkpoint_file=checkpoint_file, + img_size=x.shape[-2:], + device=device, + verbose=args.verbose, + ).to(device) + + x_uncorrected = algo(y_distorted, physics_clean) + x_corrected = algo(y_distorted, physics) + + dinv.utils.plot( + { + "Ground truth OASIS slice": x, + "Distorted ksp, CG recon": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join( + f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" + for m in metrics + ), + "\n".join( + f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" + for m in metrics + ), + ], + show=False, + close=True, + suptitle=( + f"Algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}" + ), + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + + print("done!") + + +if __name__ == "__main__": + main() diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index 07dec7d..a11a909 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -1,4 +1,9 @@ -from .deep import RAMReconstructor, DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor +from .deep import ( + RAMReconstructor, + DeepImagePriorReconstructor, + FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, +) from .classic import ( ZeroFilledReconstructor, ConjugateGradientReconstructor, diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 5d507c4..ba520f9 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -163,3 +163,106 @@ def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tenso out = self.model(x_in) * std + mu # (B, 1, H, W) return torch.cat([out, torch.zeros_like(out)], dim=1) + + +def 
_load_unet_checkpoint_state( + checkpoint_path: Path, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Load U-Net weights from a plain or Lightning-style checkpoint.""" + + try: + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + except TypeError: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + elif isinstance(checkpoint, dict): + state_dict = checkpoint + else: + raise ValueError(f"Unsupported checkpoint format in {checkpoint_path}.") + + if any(key.startswith("unet.") for key in state_dict): + return { + key[len("unet.") :]: value + for key, value in state_dict.items() + if key.startswith("unet.") + } + + if any(key.startswith("model.unet.") for key in state_dict): + return { + key[len("model.unet.") :]: value + for key, value in state_dict.items() + if key.startswith("model.unet.") + } + + if any(key.startswith("module.") for key in state_dict): + return { + key[len("module.") :]: value + for key, value in state_dict.items() + if key.startswith("module.") + } + + return state_dict + + +class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): + """ + Wrapper for a trained OASIS single-coil U-Net reconstruction model. + + The model reuses the repository's fastMRI-derived :class:`Unet` module, but + loads an OASIS checkpoint supplied by the caller. The forward pass converts + k-space to a zero-filled magnitude image, applies per-slice instance + normalization, runs the U-Net, then rescales the prediction back to the + adjoint-image intensity range. + + :param str checkpoint_file: Path to the trained OASIS U-Net checkpoint. + :param torch.device device: Device on which to run inference. 
+ """ + + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 32, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + + def __init__( + self, + checkpoint_file: str, + device: torch.device = None, + ) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + self.model = Unet(**self.UNET_KWARGS) + self.model.load_state_dict( + _load_unet_checkpoint_state(checkpoint_path, device), + strict=True, + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + """Reconstruct a magnitude image from measured k-space.""" + + x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) + mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu + + return torch.cat([out, torch.zeros_like(out)], dim=1) diff --git a/pyproject.toml b/pyproject.toml index b12f61c..611d63c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "fastmri>=0.3.0", "h5py>=3.16.0", "matplotlib>=3.9.0", + "nibabel>=5.3.2", "numpy>=2.4.3", "ptwt>=1.0.1", "pydicom>=3.0.1", diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index b331a6b..c2689ed 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -12,6 +12,7 @@ RAMReconstructor, DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, ) from mri_recon.reconstruction.classic import ( ZeroFilledReconstructor, @@ -118,3 +119,34 @@ def test_fastmri_singlecoil_unet_reconstructor(device, tmp_path, monkeypatch): assert x_hat.shape == x.shape assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) + + 
+def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, tmp_path): + """ + Test the OASIS UNet reconstructor with a Lightning-style checkpoint. + """ + weights_path = tmp_path / "oasis_unet_test_weights.ckpt" + state_dict = Unet(**OASISSinglecoilUnetReconstructor.UNET_KWARGS).state_dict() + torch.save( + {"state_dict": {f"unet.{key}": value for key, value in state_dict.items()}}, + weights_path, + ) + + x = dinv.utils.load_example( + "butterfly.png", img_size=(32, 32), grayscale=True, resize_mode="resize", device=device + ) + x = torch.cat([x, torch.zeros_like(x)], dim=1) + + model = OASISSinglecoilUnetReconstructor( + checkpoint_file=str(weights_path), + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=device + ) + + y = physics(x) + x_hat = model(y, physics) + + assert x_hat.shape == x.shape + assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) From 913a162213baac3d5d8a5fe8f6264258632b2ecb Mon Sep 17 00:00:00 2001 From: Yuning Du Date: Thu, 7 May 2026 19:15:25 +0100 Subject: [PATCH 10/17] Add OASIS inference example --- examples/OASIS_inference_plot.py | 1355 +++++++++++++----------- mri_recon/distortions/undersampling.py | 549 +++++----- mri_recon/reconstruction/__init__.py | 28 +- mri_recon/reconstruction/deep.py | 536 +++++----- 4 files changed, 1289 insertions(+), 1179 deletions(-) diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index 7d6643e..1523d10 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -1,623 +1,732 @@ -"""Inference OASIS reconstructors for k-space distortion operators. 
- -Usage: - python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data -""" - -from __future__ import annotations - -import argparse -import contextlib -import json -import os -import sys -from collections import OrderedDict -from pathlib import Path -from typing import Optional, Sequence, Union - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import matplotlib.pyplot as plt -import numpy as np -import torch -import deepinv as dinv -from torch.utils.data import DataLoader, Dataset - -try: - import nibabel as nib -except ImportError as exc: - raise ImportError( - "The OASIS example requires nibabel. Install the project dependencies " - "or add nibabel to your environment before running this script." - ) from exc - -from mri_recon.distortions import ( - AnisotropicResolutionReduction, - BaseDistortion, - DistortedKspaceMultiCoilMRI, - GaussianKspaceBiasField, - GaussianNoiseDistortion, - HannTaperResolutionReduction, - IsotropicResolutionReduction, - KaiserTaperResolutionReduction, - OffCenterAnisotropicGaussianKspaceBiasField, - PhaseEncodeGhostingDistortion, - RadialHighPassEmphasisDistortion, - RotationalMotionDistortion, - SegmentedTranslationMotionDistortion, - TranslationMotionDistortion, -) -from mri_recon.reconstruction import ( - ConjugateGradientReconstructor, - DeepImagePriorReconstructor, - OASISSinglecoilUnetReconstructor, - RAMReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - ZeroFilledReconstructor, -) - -REPO_ROOT = Path(__file__).resolve().parents[1] -REPORT_DIR = Path("reports") / "oasis_inference_plot" -DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" -DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" - -REPORT_DIR.mkdir(parents=True, exist_ok=True) - -ALGORITHMS = [ - # "zero-filled", - # "conjugate-gradient", - # "ram", - # "dip", - 
"tv-pgd", - # "wavelet-fista", - # "tv-fista", - # "tv-pdhg", - "oasis-unet", -] -DISTORTIONS = [ - # "Phase-encode ghosting", - # "Segmented translation motion", - # "Translation motion", - # "Rotational motion", - # "Off-center anisotropic Gaussian bias field", - # "Gaussian bias field", - # "Anisotropic LP", - # "Hann taper LP", - # "Kaiser taper LP", - "Radial high-pass emphasis", - # "Gaussian noise", - # "Isotropic LP", -] -METRICS = [ - "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", -] - - -@contextlib.contextmanager -def temp_seed( - rng: np.random.RandomState, - seed: Optional[Union[int, tuple[int, ...]]] = None, -): - """Temporarily set a NumPy random seed.""" - - if seed is None: - yield - return - - state = rng.get_state() - rng.seed(seed) - try: - yield - finally: - rng.set_state(state) - - -class MaskFunc: - """Random Cartesian undersampling mask matching the packaged OASIS checkpoints.""" - - def __init__( - self, - center_fractions: Sequence[float], - accelerations: Sequence[int], - seed: Optional[int] = None, - ) -> None: - if len(center_fractions) != len(accelerations): - raise ValueError("center_fractions and accelerations must have the same length.") - - self.center_fractions = list(center_fractions) - self.accelerations = list(accelerations) - self.rng = np.random.RandomState(seed) - - def __call__( - self, - shape: Sequence[int], - seed: Optional[Union[int, tuple[int, ...]]] = None, - ) -> torch.Tensor: - """Create a broadcastable mask for k-space shaped ``(..., H, W)``.""" - - if len(shape) < 2: - raise ValueError("Mask shape must have at least two dimensions.") - - with temp_seed(self.rng, seed): - center_fraction = self.rng.choice(self.center_fractions) - acceleration = self.rng.choice(self.accelerations) - num_cols = shape[-1] - num_low_freqs = round(num_cols * center_fraction) - - center_mask = np.zeros(num_cols, dtype=np.float32) - pad = (num_cols - num_low_freqs) // 2 - center_mask[pad : pad + 
num_low_freqs] = 1 - - accel_prob = (num_cols / acceleration - num_low_freqs) / ( - num_cols - num_low_freqs - ) - accel_mask = self.rng.uniform(size=num_cols) < accel_prob - - mask = np.maximum(center_mask, accel_mask.astype(np.float32)) - mask_shape = [1 for _ in shape] - mask_shape[-1] = num_cols - return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) - - -class OasisSliceDataset(Dataset): - """Load 2D OASIS slices from Analyze/NIfTI volumes listed in a split CSV.""" - - def __init__( - self, - split_csv: Path, - data_path: Path, - sample_rate: float = 1.0, - cache_size: int = 2, - ) -> None: - self.split_csv = Path(split_csv) - self.data_path = Path(data_path) - if not 0 < sample_rate <= 1.0: - raise ValueError("sample_rate must be in the range (0, 1].") - self.sample_rate = sample_rate - self.cache_size = max(0, cache_size) - self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() - self.raw_samples = self._create_sample_list() - - def __len__(self) -> int: - """Return the number of available slices.""" - - return len(self.raw_samples) - - def __getitem__(self, index: int) -> dict[str, object]: - """Return one complex-valued OASIS slice in repo tensor convention.""" - - subject_id, slice_num = self.raw_samples[index] - target_np = self._read_raw_slice(subject_id, slice_num) - real = torch.from_numpy(target_np) - x = torch.stack([real, torch.zeros_like(real)], dim=0) - return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} - - def _create_sample_list(self) -> list[tuple[str, int]]: - samples: list[tuple[str, int]] = [] - with self.split_csv.open("r", encoding="utf-8") as handle: - for line in handle: - row = [item.strip() for item in line.split(",")] - if not row or not row[0]: - continue - try: - total_slices = int(row[-1]) - except ValueError: - continue - - subject_id = row[0] - if self.sample_rate >= 1.0: - start = 0 - stop = total_slices - else: - mid = round(total_slices / 2) - half_span = round(total_slices * 
self.sample_rate / 2) - start = max(0, mid - half_span) - stop = min(total_slices, mid + half_span) - - for slice_num in range(start, stop): - samples.append((subject_id, slice_num)) - return samples - - def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: - volume = self._get_volume(subject_id) - return np.ascontiguousarray(volume[slice_num], dtype=np.float32) - - def _get_volume(self, subject_id: str) -> np.ndarray: - if self.cache_size > 0 and subject_id in self._volume_cache: - self._volume_cache.move_to_end(subject_id) - return self._volume_cache[subject_id] - - image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" - matches = sorted(image_glob.glob("*t88_gfc.img")) - if not matches: - raise FileNotFoundError( - f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." - ) - - image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) - volume = np.ascontiguousarray( - np.transpose(np.squeeze(image_data), (1, 0, 2)), - dtype=np.float32, - ) - - if self.cache_size > 0: - self._volume_cache[subject_id] = volume - if len(self._volume_cache) > self.cache_size: - self._volume_cache.popitem(last=False) - - return volume - - -def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to log-magnitude image for visualization.""" - - if kspace.ndim == 4: - kspace = kspace[0] - if kspace.ndim != 3 or kspace.shape[0] != 2: - raise ValueError( - f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" - ) - - kspace = kspace.detach().cpu() - kspace_complex = torch.view_as_complex(kspace.permute(1, 2, 0).contiguous()) - magnitude = torch.log1p(torch.abs(kspace_complex)) - - lower = torch.quantile(magnitude, 0.05) - upper = torch.quantile(magnitude, 0.995) - if float(upper) > float(lower): - magnitude = magnitude.clamp(lower, upper) - magnitude = (magnitude - lower) / (upper - lower) - else: - mag_max = float(magnitude.max()) - if mag_max > 
0.0: - magnitude = magnitude / mag_max - - return torch.sqrt(magnitude) - - -def save_kspace_plot( - y_clean: torch.Tensor, - y_distorted: torch.Tensor, - save_fn: Path, - distortion_name: str, -) -> None: - """Save clean and distorted k-space magnitude plots.""" - - images = [ - ("Original k-space", _kspace_to_log_magnitude(y_clean)), - ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), - ] - - fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) - fig.suptitle(f"Distortion: {distortion_name}") - for ax, (title, image) in zip(axes, images, strict=True): - ax.imshow(image.numpy(), cmap="magma") - ax.set_title(title) - ax.axis("off") - fig.savefig(save_fn, dpi=200, bbox_inches="tight") - plt.close(fig) - - -def resolve_oasis_checkpoint( - checkpoint: Optional[Path], - acceleration: int, - manifest_path: Path, -) -> Path: - """Resolve an explicit or packaged OASIS checkpoint path.""" - - if checkpoint is not None: - return checkpoint.expanduser().resolve() - - with manifest_path.open("r", encoding="utf-8") as handle: - manifest = json.load(handle) - - checkpoints = manifest.get("checkpoints", {}) - key = str(acceleration) - if key not in checkpoints: - available = ", ".join(sorted(checkpoints)) - raise ValueError( - f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." 
- ) - - filename = Path(checkpoints[key]["filename"]) - if filename.is_absolute(): - return filename - return (manifest_path.parent.parent / filename).resolve() - - -def choose_algorithm( - name: str, - checkpoint_file: Path, - img_size: tuple = (640, 368), - device: torch.device = "cpu", - verbose: bool = False, -) -> dinv.models.Reconstructor: - """Construct a reconstructor by selector name.""" - - match name: - case "zero-filled": - return ZeroFilledReconstructor() - case "conjugate-gradient": - return ConjugateGradientReconstructor(max_iter=20) - case "ram": - return RAMReconstructor(default_sigma=0.05, device=device) - case "dip": - return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) - case "tv-pgd": - return TVPGDReconstructor(n_iter=100, verbose=verbose) - case "tv-fista": - return TVFISTAReconstructor(n_iter=200, verbose=verbose) - case "tv-pdhg": - return TVPDHGReconstructor(n_iter=100, verbose=verbose) - case "wavelet-fista": - return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) - case "oasis-unet" | "unet": - return OASISSinglecoilUnetReconstructor( - checkpoint_file=str(checkpoint_file), - device=device, - ) - case _: - raise ValueError(f"Unknown algorithm {name!r}") - - -def choose_distortion(name: str) -> BaseDistortion: - """Construct a k-space distortion by display name.""" - - match name: - case "Phase-encode ghosting": - return PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=1.0, - ) - case "Anisotropic LP": - return AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=0.25, - ) - case "Hann taper LP": - return HannTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - ) - case "Kaiser taper LP": - return KaiserTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - beta=8.6, - ) - case "Radial high-pass emphasis": - return 
RadialHighPassEmphasisDistortion(alpha=0.4) - case "Isotropic LP": - return IsotropicResolutionReduction(radius_fraction=0.1) - case "Off-center anisotropic Gaussian bias field": - return OffCenterAnisotropicGaussianKspaceBiasField( - width_x_fraction=0.2, - width_y_fraction=0.35, - center_x_fraction=0.15, - center_y_fraction=-0.1, - edge_gain=0.3, - ) - case "Translation motion": - return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) - case "Rotational motion": - return RotationalMotionDistortion(angle_radians=torch.pi / 6) - case "Segmented translation motion": - return SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 20.0, 50.0, -50.0), - shift_y_pixels=(0.0, 10.0, -20.0, 20.0), - ) - case "Gaussian bias field": - return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) - case "Gaussian noise": - return GaussianNoiseDistortion(sigma=0.00001) - case _: - raise ValueError(f"Unknown distortion {name!r}") - - -def choose_metric(name: str) -> dinv.metric.Metric: - """Construct a DeepInverse metric by selector name.""" - - match name: - case "PSNR": - return dinv.metric.PSNR(max_pixel=None, complex_abs=True) - case "NMSE": - return dinv.metric.NMSE(complex_abs=True) - case "SSIM": - return dinv.metric.SSIM(max_pixel=None, complex_abs=True) - case "HaarPSI": - return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) - case "BlurStrength": - return dinv.metric.BlurStrength(complex_abs=True) - case "SharpnessIndex": - return dinv.metric.SharpnessIndex(complex_abs=True) - case _: - raise ValueError(f"Unknown metric {name!r}") - - -def build_parser() -> argparse.ArgumentParser: - """Build the command-line parser.""" - - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--source", - type=Path, - required=True, - help="OASIS root directory containing subject folders.", - ) - parser.add_argument( - "--split_csv", - type=Path, - default=DEFAULT_SPLIT_CSV, - help="CSV listing OASIS subjects and 
slice counts.", - ) - parser.add_argument( - "--manifest", - type=Path, - default=DEFAULT_MANIFEST_PATH, - help="Checkpoint manifest JSON.", - ) - parser.add_argument( - "--checkpoint", - type=Path, - default=None, - help="Explicit OASIS U-Net checkpoint. Overrides --acceleration.", - ) - parser.add_argument( - "--acceleration", - type=int, - default=4, - help="Packaged OASIS checkpoint acceleration factor.", - ) - parser.add_argument( - "--center_fraction", - type=float, - default=0.08, - help="Center fraction used by the random Cartesian sampling mask.", - ) - parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) - parser.add_argument( - "--algorithm", - type=str, - default="", - choices=ALGORITHMS, - help="Reconstruction algorithm applied to distorted OASIS k-space.", - ) - parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") - parser.add_argument( - "--sample_rate", - type=float, - default=1.0, - help="Fraction of slices per volume to include from the split CSV.", - ) - parser.add_argument("--volume_cache_size", type=int, default=2) - parser.add_argument( - "--verbose", - action="store_true", - help="Enable verbose output for reconstructors that support it.", - ) - return parser - - -def main() -> None: - """Run OASIS inference plots.""" - - args = build_parser().parse_args() - args.source = args.source.expanduser().resolve() - args.split_csv = args.split_csv.expanduser().resolve() - args.manifest = args.manifest.expanduser().resolve() - checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) - - device = dinv.utils.get_device() - dataset = OasisSliceDataset( - split_csv=args.split_csv, - data_path=args.source, - sample_rate=args.sample_rate, - cache_size=args.volume_cache_size, - ) - dataloader = DataLoader(dataset, batch_size=1, shuffle=False) - mask_func = MaskFunc( - center_fractions=[args.center_fraction], - accelerations=[args.acceleration], - ) - 
metrics = [choose_metric(m) for m in METRICS] - - for i, batch in enumerate(iter(dataloader)): - if i >= args.num_samples: - break - - x = batch["x"].to(device) - subject_id = batch["subject_id"][0] - slice_num = int(batch["slice_num"][0]) - mask = mask_func(x.shape, seed=tuple(map(ord, subject_id))).to(device) - mask_2d = mask.reshape(-1).view(1, -1).expand(x.shape[-2], x.shape[-1]) - - for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: - distortion = choose_distortion(distortion_name) - - physics_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), - mask=mask_2d, - img_size=(1, 2, *x.shape[-2:]), - coil_maps=1, - device=device, - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, - mask=mask_2d, - img_size=(1, 2, *x.shape[-2:]), - coil_maps=1, - device=device, - ) - - y = physics_clean(x) - y_distorted = physics(x) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", - distortion_name, - ) - - for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: - print( - f"Evaluating algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}..." 
-            )
-
-            algo = choose_algorithm(
-                algo_name,
-                checkpoint_file=checkpoint_file,
-                img_size=x.shape[-2:],
-                device=device,
-                verbose=args.verbose,
-            ).to(device)
-
-            x_uncorrected = algo(y_distorted, physics_clean)
-            x_corrected = algo(y_distorted, physics)
-
-            dinv.utils.plot(
-                {
-                    "Ground truth OASIS slice": x,
-                    "Distorted ksp, CG recon": x_distorted,
-                    f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected,
-                    f"Distorted ksp, {algo_name} recon, corrected": x_corrected,
-                },
-                subtitles=[
-                    "",
-                    "",
-                    "\n".join(
-                        f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}"
-                        for m in metrics
-                    ),
-                    "\n".join(
-                        f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}"
-                        for m in metrics
-                    ),
-                ],
-                show=False,
-                close=True,
-                suptitle=(
-                    f"Algo {algo_name}, distortion {distortion_name}, "
-                    f"subject {subject_id}, slice {slice_num}"
-                ),
-                save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png",
-                fontsize=3,
-            )
-
-    print("done!")
-
-
-if __name__ == "__main__":
-    main()
+"""Run inference with OASIS reconstructors for k-space distortion operators.
+
+Usage:
+    python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import sys
+from collections import OrderedDict
+from pathlib import Path
+from typing import Optional
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import deepinv as dinv
+from torch.utils.data import DataLoader, Dataset
+
+try:
+    import nibabel as nib
+except ImportError as exc:
+    raise ImportError(
+        "The OASIS example requires nibabel. Install the project dependencies "
+        "or add nibabel to your environment before running this script."
+ ) from exc + +from mri_recon.distortions import ( + AnisotropicResolutionReduction, + BaseDistortion, + CartesianUndersampling, + DistortedKspaceMultiCoilMRI, + GaussianKspaceBiasField, + GaussianNoiseDistortion, + HannTaperResolutionReduction, + IsotropicResolutionReduction, + KaiserTaperResolutionReduction, + OffCenterAnisotropicGaussianKspaceBiasField, + PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, + RotationalMotionDistortion, + SegmentedTranslationMotionDistortion, + TranslationMotionDistortion, +) +from mri_recon.reconstruction import ( + ConjugateGradientReconstructor, + DeepImagePriorReconstructor, + OASISSinglecoilUnetReconstructor, + RAMReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + ZeroFilledReconstructor, +) + +REPO_ROOT = Path(__file__).resolve().parents[1] +REPORT_DIR = Path("reports") / "oasis_inference_plot" +DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" +DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" + +REPORT_DIR.mkdir(parents=True, exist_ok=True) + +ALGORITHMS = [ + # "zero-filled", + # "conjugate-gradient", + # "ram", + # "dip", + "tv-pgd", + # "wavelet-fista", + # "tv-fista", + # "tv-pdhg", + "oasis-unet", + "unet", +] +DISTORTIONS = [ + "None", + "Cartesian undersampling (variable density)", + "Cartesian undersampling (uniform random)", + "Cartesian undersampling (uniform random, zero ACS)", + "Cartesian undersampling (equispaced)", + "Cartesian undersampling (equispaced, zero ACS)", + "Phase-encode ghosting", + "Segmented translation motion", + "Translation motion", + "Rotational motion", + "Off-center anisotropic Gaussian bias field", + "Gaussian bias field", + "Anisotropic LP", + "Hann taper LP", + "Kaiser taper LP", + "Radial high-pass emphasis", + "Gaussian noise", + "Isotropic LP", +] +METRICS = [ + "PSNR", + "NMSE", + "SSIM", + "HaarPSI", + 
"SharpnessIndex", + "BlurStrength", +] + + +class OasisSliceDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes. + + Parameters + ---------- + split_csv : Path + CSV file listing OASIS subjects and slice counts. + data_path : Path + Root directory containing OASIS subject folders. + sample_rate : float, optional + Fraction of slices to include from each volume. + cache_size : int, optional + Number of loaded volumes to keep in memory. + """ + + def __init__( + self, + split_csv: Path, + data_path: Path, + sample_rate: float = 1.0, + cache_size: int = 2, + ) -> None: + self.split_csv = Path(split_csv) + self.data_path = Path(data_path) + if not 0 < sample_rate <= 1.0: + raise ValueError("sample_rate must be in the range (0, 1].") + self.sample_rate = sample_rate + self.cache_size = max(0, cache_size) + self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() + self.raw_samples = self._create_sample_list() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.raw_samples) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + subject_id, slice_num = self.raw_samples[index] + target_np = self._read_raw_slice(subject_id, slice_num) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _create_sample_list(self) -> list[tuple[str, int]]: + samples: list[tuple[str, int]] = [] + with self.split_csv.open("r", encoding="utf-8") as handle: + for line in handle: + row = [item.strip() for item in line.split(",")] + if not row or not row[0]: + continue + try: + total_slices = int(row[-1]) + except ValueError: + continue + + subject_id = row[0] + if self.sample_rate >= 1.0: + start = 0 + stop = total_slices + else: + mid = round(total_slices / 2) + half_span = round(total_slices * self.sample_rate / 2) + 
start = max(0, mid - half_span) + stop = min(total_slices, mid + half_span) + + for slice_num in range(start, stop): + samples.append((subject_id, slice_num)) + return samples + + def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: + volume = self._get_volume(subject_id) + return np.ascontiguousarray(volume[slice_num], dtype=np.float32) + + def _get_volume(self, subject_id: str) -> np.ndarray: + if self.cache_size > 0 and subject_id in self._volume_cache: + self._volume_cache.move_to_end(subject_id) + return self._volume_cache[subject_id] + + image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if not matches: + raise FileNotFoundError( + f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." + ) + + image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + + if self.cache_size > 0: + self._volume_cache[subject_id] = volume + if len(self._volume_cache) > self.cache_size: + self._volume_cache.popitem(last=False) + + return volume + + +def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: + """Convert k-space tensor to log-magnitude image for visualization.""" + + kspace = kspace.detach().cpu() + if kspace.ndim == 5: + kspace = kspace[0] + if kspace.ndim == 4: + if kspace.shape[0] == 1 and kspace.shape[1] == 2: + kspace = kspace[0] + elif kspace.shape[0] != 2: + raise ValueError( + "Expected k-space with shape (2, H, W), (1, 2, H, W), " + f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" + ) + if kspace.ndim != 3 and kspace.ndim != 4: + raise ValueError( + "Expected k-space with shape (2, H, W), (1, 2, H, W), " + f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" + ) + if kspace.shape[0] != 2: + raise ValueError(f"Expected real/imaginary channel first, got {tuple(kspace.shape)}") + + kspace_complex = 
torch.view_as_complex(torch.movedim(kspace, 0, -1).contiguous()) + magnitude = torch.abs(kspace_complex) + if magnitude.ndim == 3: + magnitude = torch.sqrt(torch.sum(magnitude.square(), dim=0)) + magnitude = torch.log1p(magnitude) + + lower = torch.quantile(magnitude, 0.05) + upper = torch.quantile(magnitude, 0.995) + if float(upper) > float(lower): + magnitude = magnitude.clamp(lower, upper) + magnitude = (magnitude - lower) / (upper - lower) + else: + mag_max = float(magnitude.max()) + if mag_max > 0.0: + magnitude = magnitude / mag_max + + return torch.sqrt(magnitude) + + +def save_kspace_plot( + y_clean: torch.Tensor, + y_distorted: torch.Tensor, + save_fn: Path, + distortion_name: str, +) -> None: + """Save clean and distorted k-space magnitude plots.""" + + images = [ + ("Original k-space", _kspace_to_log_magnitude(y_clean)), + ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), + ] + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + fig.suptitle(f"Distortion: {distortion_name}") + for ax, (title, image) in zip(axes, images, strict=True): + ax.imshow(image.numpy(), cmap="magma") + ax.set_title(title) + ax.axis("off") + fig.savefig(save_fn, dpi=200, bbox_inches="tight") + plt.close(fig) + + +def image_to_kspace(x: torch.Tensor) -> torch.Tensor: + """Convert channel-first complex images to centered k-space.""" + + x_complex = torch.view_as_complex(x.movedim(1, -1).contiguous()) + y_complex = torch.fft.fftshift( + torch.fft.fft2(x_complex, dim=(-2, -1), norm="ortho"), + dim=(-2, -1), + ) + return torch.view_as_real(y_complex).movedim(-1, 1).contiguous() + + +def kspace_to_image(y: torch.Tensor) -> torch.Tensor: + """Convert centered channel-first k-space to complex images.""" + + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + x_complex = torch.fft.ifft2( + torch.fft.ifftshift(y_complex, dim=(-2, -1)), + dim=(-2, -1), + norm="ortho", + ) + return torch.view_as_real(x_complex).movedim(-1, 1).contiguous() + + 
+class OasisCenteredFFTPhysics: + """Physics adapter matching the OASIS U-Net FFT convention. + + Parameters + ---------- + distortion : BaseDistortion + K-space distortion applied after the centered FFT. + """ + + def __init__(self, distortion: BaseDistortion) -> None: + self.distortion = distortion + + def A(self, x: torch.Tensor) -> torch.Tensor: + """Apply centered FFT and k-space distortion. + + Parameters + ---------- + x : torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Distorted centered k-space tensor. + """ + + return self.distortion.A(image_to_kspace(x)) + + def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: + """Apply adjoint distortion and centered inverse FFT. + + Parameters + ---------- + y : torch.Tensor + Distorted centered k-space tensor. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + """ + + return kspace_to_image(self.distortion.A_adjoint(y)) + + +def resolve_oasis_checkpoint( + checkpoint: Optional[Path], + acceleration: int, + manifest_path: Path, +) -> Path: + """Resolve an explicit or packaged OASIS checkpoint path. + + Parameters + ---------- + checkpoint : Path or None + User-provided checkpoint path. + acceleration : int + Acceleration key used when loading from the manifest. + manifest_path : Path + JSON manifest with packaged checkpoint metadata. + + Returns + ------- + Path + Resolved checkpoint path. + """ + + if checkpoint is not None: + return checkpoint.expanduser().resolve() + + with manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + checkpoints = manifest.get("checkpoints", {}) + key = str(acceleration) + if key not in checkpoints: + available = ", ".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." 
+ ) + + filename = Path(checkpoints[key]["filename"]) + if filename.is_absolute(): + return filename + return (manifest_path.parent.parent / filename).resolve() + + +def choose_algorithm( + name: str, + checkpoint_file: Path, + img_size: tuple = (640, 368), + device: torch.device = "cpu", + verbose: bool = False, +) -> dinv.models.Reconstructor: + """Construct a reconstructor by selector name.""" + + match name: + case "zero-filled": + return ZeroFilledReconstructor() + case "conjugate-gradient": + return ConjugateGradientReconstructor(max_iter=20) + case "ram": + return RAMReconstructor(default_sigma=0.05, device=device) + case "dip": + return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) + case "tv-pgd": + return TVPGDReconstructor(n_iter=100, verbose=verbose) + case "tv-fista": + return TVFISTAReconstructor(n_iter=200, verbose=verbose) + case "tv-pdhg": + return TVPDHGReconstructor(n_iter=100, verbose=verbose) + case "wavelet-fista": + return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) + case "oasis-unet" | "unet": + return OASISSinglecoilUnetReconstructor( + checkpoint_file=str(checkpoint_file), + device=device, + ) + case _: + raise ValueError(f"Unknown algorithm {name!r}") + + +def choose_distortion( + name: str, + acceleration: int, + center_fraction: float, +) -> BaseDistortion: + """Construct a k-space distortion by display name.""" + + match name: + case "None": + return BaseDistortion() + case "Cartesian undersampling (variable density)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=center_fraction, + pattern="variable_density_random", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (uniform random)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=center_fraction, + pattern="uniform_random", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (uniform random, zero ACS)": + return CartesianUndersampling( 
+ keep_fraction=1.0 / acceleration, + center_fraction=0.0, + pattern="uniform_random", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (equispaced)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=center_fraction, + pattern="equispaced", + axis=-1, + seed=42, + ) + case "Cartesian undersampling (equispaced, zero ACS)": + return CartesianUndersampling( + keep_fraction=1.0 / acceleration, + center_fraction=0.0, + pattern="equispaced", + axis=-1, + seed=42, + ) + case "Phase-encode ghosting": + return PhaseEncodeGhostingDistortion( + line_period=2, + line_offset=1, + phase_error_radians=torch.pi / 2, + corrupted_line_scale=1.0, + ) + case "Anisotropic LP": + return AnisotropicResolutionReduction( + kx_radius_fraction=1.0, + ky_radius_fraction=0.25, + ) + case "Hann taper LP": + return HannTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + ) + case "Kaiser taper LP": + return KaiserTaperResolutionReduction( + radius_fraction=0.35, + transition_fraction=0.4, + beta=8.6, + ) + case "Radial high-pass emphasis": + return RadialHighPassEmphasisDistortion(alpha=0.4) + case "Isotropic LP": + return IsotropicResolutionReduction(radius_fraction=0.1) + case "Off-center anisotropic Gaussian bias field": + return OffCenterAnisotropicGaussianKspaceBiasField( + width_x_fraction=0.2, + width_y_fraction=0.35, + center_x_fraction=0.15, + center_y_fraction=-0.1, + edge_gain=0.3, + ) + case "Translation motion": + return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) + case "Rotational motion": + return RotationalMotionDistortion(angle_radians=torch.pi / 6) + case "Segmented translation motion": + return SegmentedTranslationMotionDistortion( + shift_x_pixels=(0.0, 20.0, 50.0, -50.0), + shift_y_pixels=(0.0, 10.0, -20.0, 20.0), + ) + case "Gaussian bias field": + return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) + case "Gaussian noise": + return 
GaussianNoiseDistortion(sigma=0.00001) + case _: + raise ValueError(f"Unknown distortion {name!r}") + + +def choose_metric(name: str) -> dinv.metric.Metric: + """Construct a DeepInverse metric by selector name.""" + + match name: + case "PSNR": + return dinv.metric.PSNR(max_pixel=None, complex_abs=True) + case "NMSE": + return dinv.metric.NMSE(complex_abs=True) + case "SSIM": + return dinv.metric.SSIM(max_pixel=None, complex_abs=True) + case "HaarPSI": + return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) + case "BlurStrength": + return dinv.metric.BlurStrength(complex_abs=True) + case "SharpnessIndex": + return dinv.metric.SharpnessIndex(complex_abs=True) + case _: + raise ValueError(f"Unknown metric {name!r}") + + +def build_parser() -> argparse.ArgumentParser: + """Build the command-line parser.""" + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--source", + type=Path, + required=True, + help="OASIS root directory containing subject folders.", + ) + parser.add_argument( + "--split_csv", + type=Path, + default=DEFAULT_SPLIT_CSV, + help="CSV listing OASIS subjects and slice counts.", + ) + parser.add_argument( + "--manifest", + type=Path, + default=DEFAULT_MANIFEST_PATH, + help="Checkpoint manifest JSON.", + ) + parser.add_argument( + "--checkpoint", + type=Path, + default=None, + help="Explicit OASIS U-Net checkpoint. 
Overrides --acceleration.", + ) + parser.add_argument( + "--acceleration", + type=int, + default=4, + help="Packaged OASIS checkpoint acceleration factor.", + ) + parser.add_argument( + "--center_fraction", + type=float, + default=0.08, + help="Center fraction used by the Cartesian undersampling distortion.", + ) + parser.add_argument( + "--distortion", + type=str, + default="Cartesian undersampling (uniform random)", + choices=DISTORTIONS, + ) + parser.add_argument( + "--algorithm", + type=str, + default="unet", + choices=ALGORITHMS, + help="Reconstruction algorithm applied to distorted OASIS k-space.", + ) + parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") + parser.add_argument( + "--sample_rate", + type=float, + default=0.6, + help="Fraction of slices per volume to include from the split CSV.", + ) + parser.add_argument("--volume_cache_size", type=int, default=2) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose output for reconstructors that support it.", + ) + return parser + + +def main() -> None: + """Run OASIS inference plots.""" + + args = build_parser().parse_args() + args.source = args.source.expanduser().resolve() + args.split_csv = args.split_csv.expanduser().resolve() + args.manifest = args.manifest.expanduser().resolve() + checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) + + device = dinv.utils.get_device() + dataset = OasisSliceDataset( + split_csv=args.split_csv, + data_path=args.source, + sample_rate=args.sample_rate, + cache_size=args.volume_cache_size, + ) + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + metrics = [choose_metric(m) for m in METRICS] + + for i, batch in enumerate(iter(dataloader)): + if i >= args.num_samples: + break + + x = batch["x"].to(device) + subject_id = batch["subject_id"][0] + slice_num = int(batch["slice_num"][0]) + + for distortion_name in [args.distortion]: + distortion = 
choose_distortion( + distortion_name, + acceleration=args.acceleration, + center_fraction=args.center_fraction, + ) + + physics_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + img_size=(1, 2, *x.shape[-2:]), + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, + img_size=(1, 2, *x.shape[-2:]), + device=device, + ) + oasis_physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) + oasis_physics = OasisCenteredFFTPhysics(distortion) + + y = image_to_kspace(x) + y_distorted = distortion.A(y) + y_physics_distorted = physics(x) + x_distorted = kspace_to_image(y_distorted) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", + distortion_name, + ) + + for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: + print( + f"Evaluating algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}..." + ) + + algo = choose_algorithm( + algo_name, + checkpoint_file=checkpoint_file, + img_size=x.shape[-2:], + device=device, + verbose=args.verbose, + ).to(device) + + if algo_name in {"oasis-unet", "unet"}: + y_eval = y_distorted + eval_physics_clean = oasis_physics_clean + eval_physics = oasis_physics + else: + y_eval = y_physics_distorted + eval_physics_clean = physics_clean + eval_physics = physics + + x_uncorrected = algo(y_eval, eval_physics_clean) + x_corrected = algo(y_eval, eval_physics) + uncorrected_scores = [ + f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" for m in metrics + ] + corrected_scores = [ + f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" for m in metrics + ] + print(f" uncorrected: {', '.join(uncorrected_scores)}") + print(f" corrected: {', '.join(corrected_scores)}") + + dinv.utils.plot( + { + "Ground truth OASIS slice": x, + "Distorted ksp, zero-filled": x_distorted, + f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, + f"Distorted ksp, {algo_name} recon, 
corrected": x_corrected, + }, + subtitles=[ + "", + "", + "\n".join(uncorrected_scores), + "\n".join(corrected_scores), + ], + show=False, + close=True, + suptitle=( + f"Algo {algo_name}, distortion {distortion_name}, " + f"subject {subject_id}, slice {slice_num}" + ), + save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", + fontsize=3, + ) + + print("done!") + + +if __name__ == "__main__": + main() diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index 4a41674..afa6c39 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -1,274 +1,275 @@ -"""Cartesian k-space undersampling distortions.""" - -from __future__ import annotations - -import torch - -from mri_recon.distortions.base import SelfAdjointMultiplicativeMaskDistortion - - -PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} - -# Lorentzian offset that controls how steeply the variable-density weights -# fall off with normalized distance from k-space center. -# Smaller values → steeper density gradient; 0.05 gives ~20x weight ratio -# between the ACS boundary and the k-space edge. -_VD_WEIGHT_OFFSET: float = 0.05 - - -class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): - """Cartesian k-space undersampling along phase-encode direction. - - This distortion simulates true MRI acquisition undersampling by applying a - binary sampling mask along the phase-encode direction (default). The mask - keeps a contiguous center region (ACS - Auto-Calibration Signal) fully - sampled and randomly undersamples the periphery. - - The mask is deterministic given a shape and seed, ensuring reproducibility. - - :param float keep_fraction: Fraction of phase-encode lines to keep in - ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. - :param float center_fraction: Fraction of phase-encode lines reserved for - the contiguous, fully-sampled ACS region in ``[0, 1]``. 
Defaults to - ``0.5 * keep_fraction``, leaving part of the acquisition budget for - randomized peripheral line sampling. Set to ``0`` to sample without a - guaranteed ACS block. - :param str pattern: Peripheral sampling pattern. Supported values are - ``"uniform_random"``, ``"variable_density_random"``, and - ``"equispaced"``. Defaults to ``"variable_density_random"``. - :param int axis: Axis along which to apply undersampling. Default is -2 - (phase-encode for 4D tensors), can also be -3 for 5D tensors. - :param int | None seed: Random seed for reproducible mask generation. - If None, uses unseeded randomness (not recommended for reproducibility). - Has no effect when ``pattern="equispaced"`` because that pattern is - fully deterministic and uses no randomness. - """ - - def __init__( - self, - keep_fraction: float = 0.25, - center_fraction: float | None = None, - pattern: str = "variable_density_random", - axis: int = -2, - seed: int | None = None, - ) -> None: - super().__init__() - - if not 0.0 < keep_fraction <= 1.0: - raise ValueError(f"keep_fraction must be in (0, 1], got {keep_fraction}") - - if axis not in (-2, -3): - raise ValueError(f"axis must be -2 or -3, got {axis}") - - if pattern not in PATTERNS: - raise ValueError(f"pattern must be one of {sorted(PATTERNS)}, got {pattern!r}") - - if center_fraction is None: - # Reserve half of the kept lines for a contiguous ACS block and - # leave the remainder for peripheral sampling. 
- center_fraction = 0.5 * keep_fraction - elif not 0.0 <= center_fraction <= 1.0: - raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") - - if center_fraction > keep_fraction: - raise ValueError( - f"center_fraction ({center_fraction}) must not exceed " - f"keep_fraction ({keep_fraction})" - ) - - self.keep_fraction = keep_fraction - self.center_fraction = center_fraction - self.pattern = pattern - self.axis = axis - self.seed = seed - self._cached_mask = None - self._cached_shape = None - self._cached_device: torch.device | None = None - - def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: - """Generate a binary Cartesian undersampling mask. - - The mask is applied along the specified axis (phase-encode by default). - It keeps a contiguous center region fully sampled when requested and - randomly undersamples the periphery to achieve the desired keep_fraction. - - :param tuple[int, ...] shape: k-space tensor shape. - :param torch.device device: Device for the mask. - :returns: Binary mask broadcastable to shape. - :rtype: torch.Tensor - """ - # Cache the mask keyed on both shape and device so that repeated GPU - # forward passes do not perform a CPU→GPU copy on every call. 
- if ( - self._cached_mask is not None - and self._cached_shape == shape - and self._cached_device == device - ): - return self._cached_mask - - # Get the size along the undersampling axis - axis_size = shape[self.axis] - - # Set random seed for reproducibility - rng_state = None - if self.seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.seed) - - try: - # Generate the 1D mask along the specified axis - mask_1d = self._generate_1d_mask(axis_size) - - # Expand the 1D mask to match the full shape - mask = self._expand_mask_to_shape(mask_1d, shape) - finally: - # Restore random state if we set a seed - if self.seed is not None and rng_state is not None: - torch.set_rng_state(rng_state) - - # Move to the target device and cache, including the device. - mask = mask.to(device) - self._cached_mask = mask - self._cached_shape = shape - self._cached_device = device - - return mask - - def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: - """Generate a 1D binary mask along the undersampling axis. - - :param int axis_size: Size of the axis along which to undersample. - :returns: 1D binary mask of shape (axis_size,). - :rtype: torch.Tensor - """ - # Calculate number of lines to keep and center region size. - num_keep = max(1, int(round(axis_size * self.keep_fraction))) - num_center = int(round(axis_size * self.center_fraction)) - - # Ensure center is not larger than total kept lines - num_center = min(num_center, num_keep) - - # Initialize mask (all zeros) - mask = torch.zeros(axis_size, dtype=torch.float32) - - # Keep the contiguous center region (ACS) when requested. - center_start = (axis_size - num_center) // 2 - center_end = center_start + num_center - if num_center > 0: - mask[center_start:center_end] = 1.0 - - peripheral_indices = self._peripheral_indices(center_start, center_end, axis_size) - - # Sample from the peripheral region with the requested pattern. 
- if num_keep > num_center: - num_peripheral = num_keep - num_center - selected_indices = self._select_peripheral_indices( - peripheral_indices=peripheral_indices, - num_peripheral=num_peripheral, - axis_size=axis_size, - ) - mask[selected_indices] = 1.0 - - return mask - - def _peripheral_indices( - self, - center_start: int, - center_end: int, - axis_size: int, - ) -> torch.Tensor: - """Return line indices outside the contiguous ACS region.""" - - return torch.cat( - [ - torch.arange(0, center_start, dtype=torch.long), - torch.arange(center_end, axis_size, dtype=torch.long), - ] - ) - - def _select_peripheral_indices( - self, - peripheral_indices: torch.Tensor, - num_peripheral: int, - axis_size: int, - ) -> torch.Tensor: - """Select peripheral lines according to the configured sampling pattern.""" - - if self.pattern == "uniform_random": - permutation = torch.randperm(len(peripheral_indices))[:num_peripheral] - return peripheral_indices[permutation] - - if self.pattern == "variable_density_random": - center = 0.5 * (axis_size - 1) - distances = torch.abs(peripheral_indices.to(torch.float32) - center) - normalized_distances = distances / max(center, 1.0) - - # Favor low-frequency lines near the ACS boundary while preserving - # non-zero probability across the periphery. - weights = 1.0 / (_VD_WEIGHT_OFFSET + normalized_distances.square()) - selected_positions = torch.multinomial( - weights, - num_samples=num_peripheral, - replacement=False, - ) - return peripheral_indices[selected_positions] - - if self.pattern == "equispaced": - return self._select_equispaced_indices(peripheral_indices, num_peripheral) - - raise RuntimeError(f"Unsupported pattern {self.pattern!r}") - - def _select_equispaced_indices( - self, - peripheral_indices: torch.Tensor, - num_peripheral: int, - ) -> torch.Tensor: - """Select evenly spaced peripheral indices. - - Lines are spaced uniformly within the peripheral index array. 
- Because the peripheral array has a gap at the ACS region, the spacing - between selected k-space lines will be approximately doubled near the - ACS boundary compared to the spacing in the outer periphery. - This pattern is fully deterministic and unaffected by ``seed``. - """ - - if num_peripheral >= len(peripheral_indices): - return peripheral_indices - - # The early-return above guarantees step > 1.0, which in turn means - # floor((idx + 0.5) * step) is strictly increasing, so no collision - # resolution is needed. - step = len(peripheral_indices) / num_peripheral - positions = ( - (torch.arange(num_peripheral, dtype=torch.float32) * step + 0.5 * step).floor().long() - ) - - return peripheral_indices[positions] - - def _expand_mask_to_shape( - self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] - ) -> torch.Tensor: - """Expand a 1D mask to the full k-space shape. - - The 1D mask is applied along the specified axis and broadcast to all - other dimensions. - - :param torch.Tensor mask_1d: 1D binary mask. - :param tuple[int, ...] target_shape: Target shape for expansion. - :returns: Mask broadcastable to target_shape. - :rtype: torch.Tensor - """ - # Convert negative axis to positive - ndim = len(target_shape) - axis = self.axis if self.axis >= 0 else ndim + self.axis - - # Create the full shape with 1s for broadcast dimensions - expand_shape = list(target_shape) - for i in range(len(expand_shape)): - if i != axis: - expand_shape[i] = 1 - - # Reshape the 1D mask to the expand shape - mask = mask_1d.reshape(expand_shape) - - return mask +"""Cartesian k-space undersampling distortions.""" + +from __future__ import annotations + +import torch + +from mri_recon.distortions.base import SelfAdjointMultiplicativeMaskDistortion + + +PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} + +# Lorentzian offset that controls how steeply the variable-density weights +# fall off with normalized distance from k-space center. 
+# Smaller values → steeper density gradient; 0.05 gives ~20x weight ratio +# between the ACS boundary and the k-space edge. +_VD_WEIGHT_OFFSET: float = 0.05 + + +class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): + """Cartesian k-space undersampling along phase-encode direction. + + This distortion simulates true MRI acquisition undersampling by applying a + binary sampling mask along the phase-encode direction (default). The mask + keeps a contiguous center region (ACS - Auto-Calibration Signal) fully + sampled and randomly undersamples the periphery. + + The mask is deterministic given a shape and seed, ensuring reproducibility. + + :param float keep_fraction: Fraction of phase-encode lines to keep in + ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. + :param float center_fraction: Fraction of phase-encode lines reserved for + the contiguous, fully-sampled ACS region in ``[0, 1]``. Defaults to + ``0.5 * keep_fraction``, leaving part of the acquisition budget for + randomized peripheral line sampling. Set to ``0`` to sample without a + guaranteed ACS block. + :param str pattern: Peripheral sampling pattern. Supported values are + ``"uniform_random"``, ``"variable_density_random"``, and + ``"equispaced"``. Defaults to ``"variable_density_random"``. + :param int axis: Axis along which to apply undersampling. Default is -2 + (phase-encode for 4D tensors), can also be -1 for readout/column + masking or -3 for the slice/depth axis in 5D tensors. + :param int | None seed: Random seed for reproducible mask generation. + If None, uses unseeded randomness (not recommended for reproducibility). + Has no effect when ``pattern="equispaced"`` because that pattern is + fully deterministic and uses no randomness. 
+ """ + + def __init__( + self, + keep_fraction: float = 0.25, + center_fraction: float | None = None, + pattern: str = "variable_density_random", + axis: int = -2, + seed: int | None = None, + ) -> None: + super().__init__() + + if not 0.0 < keep_fraction <= 1.0: + raise ValueError(f"keep_fraction must be in (0, 1], got {keep_fraction}") + + if axis not in (-1, -2, -3): + raise ValueError(f"axis must be -1, -2, or -3, got {axis}") + + if pattern not in PATTERNS: + raise ValueError(f"pattern must be one of {sorted(PATTERNS)}, got {pattern!r}") + + if center_fraction is None: + # Reserve half of the kept lines for a contiguous ACS block and + # leave the remainder for peripheral sampling. + center_fraction = 0.5 * keep_fraction + elif not 0.0 <= center_fraction <= 1.0: + raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") + + if center_fraction > keep_fraction: + raise ValueError( + f"center_fraction ({center_fraction}) must not exceed " + f"keep_fraction ({keep_fraction})" + ) + + self.keep_fraction = keep_fraction + self.center_fraction = center_fraction + self.pattern = pattern + self.axis = axis + self.seed = seed + self._cached_mask = None + self._cached_shape = None + self._cached_device: torch.device | None = None + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + """Generate a binary Cartesian undersampling mask. + + The mask is applied along the specified axis (phase-encode by default). + It keeps a contiguous center region fully sampled when requested and + randomly undersamples the periphery to achieve the desired keep_fraction. + + :param tuple[int, ...] shape: k-space tensor shape. + :param torch.device device: Device for the mask. + :returns: Binary mask broadcastable to shape. + :rtype: torch.Tensor + """ + # Cache the mask keyed on both shape and device so that repeated GPU + # forward passes do not perform a CPU→GPU copy on every call. 
+ if ( + self._cached_mask is not None + and self._cached_shape == shape + and self._cached_device == device + ): + return self._cached_mask + + # Get the size along the undersampling axis + axis_size = shape[self.axis] + + # Set random seed for reproducibility + rng_state = None + if self.seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.seed) + + try: + # Generate the 1D mask along the specified axis + mask_1d = self._generate_1d_mask(axis_size) + + # Expand the 1D mask to match the full shape + mask = self._expand_mask_to_shape(mask_1d, shape) + finally: + # Restore random state if we set a seed + if self.seed is not None and rng_state is not None: + torch.set_rng_state(rng_state) + + # Move to the target device and cache, including the device. + mask = mask.to(device) + self._cached_mask = mask + self._cached_shape = shape + self._cached_device = device + + return mask + + def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: + """Generate a 1D binary mask along the undersampling axis. + + :param int axis_size: Size of the axis along which to undersample. + :returns: 1D binary mask of shape (axis_size,). + :rtype: torch.Tensor + """ + # Calculate number of lines to keep and center region size. + num_keep = max(1, int(round(axis_size * self.keep_fraction))) + num_center = int(round(axis_size * self.center_fraction)) + + # Ensure center is not larger than total kept lines + num_center = min(num_center, num_keep) + + # Initialize mask (all zeros) + mask = torch.zeros(axis_size, dtype=torch.float32) + + # Keep the contiguous center region (ACS) when requested. + center_start = (axis_size - num_center) // 2 + center_end = center_start + num_center + if num_center > 0: + mask[center_start:center_end] = 1.0 + + peripheral_indices = self._peripheral_indices(center_start, center_end, axis_size) + + # Sample from the peripheral region with the requested pattern. 
+ if num_keep > num_center: + num_peripheral = num_keep - num_center + selected_indices = self._select_peripheral_indices( + peripheral_indices=peripheral_indices, + num_peripheral=num_peripheral, + axis_size=axis_size, + ) + mask[selected_indices] = 1.0 + + return mask + + def _peripheral_indices( + self, + center_start: int, + center_end: int, + axis_size: int, + ) -> torch.Tensor: + """Return line indices outside the contiguous ACS region.""" + + return torch.cat( + [ + torch.arange(0, center_start, dtype=torch.long), + torch.arange(center_end, axis_size, dtype=torch.long), + ] + ) + + def _select_peripheral_indices( + self, + peripheral_indices: torch.Tensor, + num_peripheral: int, + axis_size: int, + ) -> torch.Tensor: + """Select peripheral lines according to the configured sampling pattern.""" + + if self.pattern == "uniform_random": + permutation = torch.randperm(len(peripheral_indices))[:num_peripheral] + return peripheral_indices[permutation] + + if self.pattern == "variable_density_random": + center = 0.5 * (axis_size - 1) + distances = torch.abs(peripheral_indices.to(torch.float32) - center) + normalized_distances = distances / max(center, 1.0) + + # Favor low-frequency lines near the ACS boundary while preserving + # non-zero probability across the periphery. + weights = 1.0 / (_VD_WEIGHT_OFFSET + normalized_distances.square()) + selected_positions = torch.multinomial( + weights, + num_samples=num_peripheral, + replacement=False, + ) + return peripheral_indices[selected_positions] + + if self.pattern == "equispaced": + return self._select_equispaced_indices(peripheral_indices, num_peripheral) + + raise RuntimeError(f"Unsupported pattern {self.pattern!r}") + + def _select_equispaced_indices( + self, + peripheral_indices: torch.Tensor, + num_peripheral: int, + ) -> torch.Tensor: + """Select evenly spaced peripheral indices. + + Lines are spaced uniformly within the peripheral index array. 
+ Because the peripheral array has a gap at the ACS region, the spacing + between selected k-space lines will be approximately doubled near the + ACS boundary compared to the spacing in the outer periphery. + This pattern is fully deterministic and unaffected by ``seed``. + """ + + if num_peripheral >= len(peripheral_indices): + return peripheral_indices + + # The early-return above guarantees step > 1.0, which in turn means + # floor((idx + 0.5) * step) is strictly increasing, so no collision + # resolution is needed. + step = len(peripheral_indices) / num_peripheral + positions = ( + (torch.arange(num_peripheral, dtype=torch.float32) * step + 0.5 * step).floor().long() + ) + + return peripheral_indices[positions] + + def _expand_mask_to_shape( + self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] + ) -> torch.Tensor: + """Expand a 1D mask to the full k-space shape. + + The 1D mask is applied along the specified axis and broadcast to all + other dimensions. + + :param torch.Tensor mask_1d: 1D binary mask. + :param tuple[int, ...] target_shape: Target shape for expansion. + :returns: Mask broadcastable to target_shape. 
+ :rtype: torch.Tensor + """ + # Convert negative axis to positive + ndim = len(target_shape) + axis = self.axis if self.axis >= 0 else ndim + self.axis + + # Create the full shape with 1s for broadcast dimensions + expand_shape = list(target_shape) + for i in range(len(expand_shape)): + if i != axis: + expand_shape[i] = 1 + + # Reshape the 1D mask to the expand shape + mask = mask_1d.reshape(expand_shape) + + return mask diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index a11a909..eafddf3 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -1,14 +1,14 @@ -from .deep import ( - RAMReconstructor, - DeepImagePriorReconstructor, - FastMRISinglecoilUnetReconstructor, - OASISSinglecoilUnetReconstructor, -) -from .classic import ( - ZeroFilledReconstructor, - ConjugateGradientReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, -) +from .deep import ( + RAMReconstructor, + DeepImagePriorReconstructor, + FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, +) +from .classic import ( + ZeroFilledReconstructor, + ConjugateGradientReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, +) diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index ba520f9..6da8ac1 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -1,268 +1,268 @@ -from pathlib import Path - -import deepinv as dinv -import torch - -from ._fastmri_unet import Unet -from ..utils import download_file_with_sha256, matches_sha256 - - -class RAMReconstructor(dinv.models.Reconstructor): - """ - Wrapper for RAM from DeepInverse. - Normalises input by magnitude of adjoint. - - :param float default_sigma: default sigma for RAM input. Overriden if physics already has a sigma (e.g. in a Gaussian noise model) at inference time. 
- """ - - def __init__(self, default_sigma=0.05, device: torch.device = None) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - self.model = dinv.models.RAM(device=device) - self.default_sigma = default_sigma - - def forward(self, y, physics): - _x_adj = physics.A_adjoint(y) - scale = torch.quantile(_x_adj, 0.99) - - physics_norm = physics.compute_norm(torch.randn_like(_x_adj)).item() - physics_adjointness = physics.adjointness_test(torch.randn_like(_x_adj)).item() - - if physics_norm > 1.2 or physics_norm < 0.8: - raise ValueError( - f"RAM reconstructor requires physics norm = 1 but got {physics_norm:.4f}" - ) - if physics_adjointness > 0.1 or physics_adjointness < -0.1: - raise ValueError( - f"RAM reconstructor requires physics adjointness = 0 but got {physics_adjointness:.4f}" - ) - - sigma = ( - None - if hasattr(physics, "noise_model") and hasattr(physics.noise_model, "sigma") - else self.default_sigma - ) - - with torch.no_grad(): - return self.model(y / scale, physics, sigma=sigma) * scale - - -class DeepImagePriorReconstructor(dinv.models.Reconstructor): - """ - Wrapper for Deep Image Prior from DeepInverse. - - :param tuple img_size: image size of the output. Defaults to (640, 368) - :param int n_iter: number of iterations to fit the DIP. Defaults to 100. - """ - - def __init__( - self, - img_size: tuple = (640, 368), - n_iter: int = 100, - verbose: bool = True, - ) -> None: - super().__init__() - - lr = 1e-2 # learning rate for the optimizer. - channels = 64 # number of channels per layer in the decoder. - in_size = [2, 2] # size of the input to the decoder. 
- - self.model = dinv.models.DeepImagePrior( - dinv.models.ConvDecoder( - img_size=(2, *img_size[-2:]), in_size=in_size, channels=channels - ), - learning_rate=lr, - iterations=n_iter, - verbose=verbose, - input_size=[channels] + in_size, - ) - - def forward(self, y, physics): - return self.model(y, physics) - - -class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): - """ - Wrapper for pretrained UNet from FastMRI singlecoil knee challenge. - - Note: this model was trained for accelerated MRI reconstruction and may not have good performance on other degradations. - - Note: this model discards complex information and only returns the magnitude image. - - NOTE: this model was trained on both train+val splits of the challenge (i.e. trained on singlecoil_train, singlecoil_val). - - The pretrained fastMRI model expects magnitude images that are normalized per slice, - so this wrapper matches that preprocessing and rescales the output back to the - original adjoint-image intensity range. - - See https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples for more details. 
- """ - - MODEL_URL = ( - "https://dl.fbaipublicfiles.com/fastMRI/trained_models/unet/" - "knee_sc_leaderboard_state_dict.pt" - ) - MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" - MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" - UNET_KWARGS = { - "in_chans": 1, - "out_chans": 1, - "chans": 256, - "num_pool_layers": 4, - "drop_prob": 0.0, - } - - def __init__(self, device: torch.device = None, state_dict_file: str = None) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - self.model = Unet(**self.UNET_KWARGS) - - state_dict_path = ( - Path(state_dict_file).expanduser() - if state_dict_file is not None - else Path(__file__).resolve().parents[2] / self.MODEL_FILENAME - ) - - if state_dict_file is None: - if not matches_sha256(state_dict_path, self.MODEL_SHA256): - download_file_with_sha256( - self.MODEL_URL, - state_dict_path, - self.MODEL_SHA256, - label="FastMRI UNet checkpoint", - ) - elif not state_dict_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {state_dict_path}") - - self.model.load_state_dict( - torch.load(state_dict_path, map_location=device, weights_only=True) - ) - self.model.eval() - self.model.to(device) - - def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: - x_in = physics.A_adjoint(y) - - x_in = dinv.utils.complex_abs(x_in, keepdim=True) - - # Match the fastMRI normalization used for training, then rescale the - # predicted magnitude image back to the original adjoint-image intensity range. 
- mu = x_in.mean(dim=(-2, -1), keepdim=True) - std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-8 - x_in = (x_in - mu) / std - - with torch.no_grad(): - out = self.model(x_in) * std + mu # (B, 1, H, W) - - return torch.cat([out, torch.zeros_like(out)], dim=1) - - -def _load_unet_checkpoint_state( - checkpoint_path: Path, - device: torch.device, -) -> dict[str, torch.Tensor]: - """Load U-Net weights from a plain or Lightning-style checkpoint.""" - - try: - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - except TypeError: - checkpoint = torch.load(checkpoint_path, map_location=device) - - if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - elif isinstance(checkpoint, dict) and "model_state_dict" in checkpoint: - state_dict = checkpoint["model_state_dict"] - elif isinstance(checkpoint, dict): - state_dict = checkpoint - else: - raise ValueError(f"Unsupported checkpoint format in {checkpoint_path}.") - - if any(key.startswith("unet.") for key in state_dict): - return { - key[len("unet.") :]: value - for key, value in state_dict.items() - if key.startswith("unet.") - } - - if any(key.startswith("model.unet.") for key in state_dict): - return { - key[len("model.unet.") :]: value - for key, value in state_dict.items() - if key.startswith("model.unet.") - } - - if any(key.startswith("module.") for key in state_dict): - return { - key[len("module.") :]: value - for key, value in state_dict.items() - if key.startswith("module.") - } - - return state_dict - - -class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): - """ - Wrapper for a trained OASIS single-coil U-Net reconstruction model. - - The model reuses the repository's fastMRI-derived :class:`Unet` module, but - loads an OASIS checkpoint supplied by the caller. 
The forward pass converts - k-space to a zero-filled magnitude image, applies per-slice instance - normalization, runs the U-Net, then rescales the prediction back to the - adjoint-image intensity range. - - :param str checkpoint_file: Path to the trained OASIS U-Net checkpoint. - :param torch.device device: Device on which to run inference. - """ - - UNET_KWARGS = { - "in_chans": 1, - "out_chans": 1, - "chans": 32, - "num_pool_layers": 4, - "drop_prob": 0.0, - } - - def __init__( - self, - checkpoint_file: str, - device: torch.device = None, - ) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - checkpoint_path = Path(checkpoint_file).expanduser() - if not checkpoint_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - - self.model = Unet(**self.UNET_KWARGS) - self.model.load_state_dict( - _load_unet_checkpoint_state(checkpoint_path, device), - strict=True, - ) - self.model.eval() - self.model.to(device) - - def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: - """Reconstruct a magnitude image from measured k-space.""" - - x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) - mu = x_in.mean(dim=(-2, -1), keepdim=True) - std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 - x_in = (x_in - mu) / std - - with torch.no_grad(): - out = self.model(x_in) * std + mu - - return torch.cat([out, torch.zeros_like(out)], dim=1) +from pathlib import Path + +import deepinv as dinv +import torch + +from ._fastmri_unet import Unet +from ..utils import download_file_with_sha256, matches_sha256 + + +class RAMReconstructor(dinv.models.Reconstructor): + """ + Wrapper for RAM from DeepInverse. + Normalises input by magnitude of adjoint. + + :param float default_sigma: default sigma for RAM input. Overriden if physics already has a sigma (e.g. in a Gaussian noise model) at inference time. 
+ """ + + def __init__(self, default_sigma=0.05, device: torch.device = None) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + self.model = dinv.models.RAM(device=device) + self.default_sigma = default_sigma + + def forward(self, y, physics): + _x_adj = physics.A_adjoint(y) + scale = torch.quantile(_x_adj, 0.99) + + physics_norm = physics.compute_norm(torch.randn_like(_x_adj)).item() + physics_adjointness = physics.adjointness_test(torch.randn_like(_x_adj)).item() + + if physics_norm > 1.2 or physics_norm < 0.8: + raise ValueError( + f"RAM reconstructor requires physics norm = 1 but got {physics_norm:.4f}" + ) + if physics_adjointness > 0.1 or physics_adjointness < -0.1: + raise ValueError( + f"RAM reconstructor requires physics adjointness = 0 but got {physics_adjointness:.4f}" + ) + + sigma = ( + None + if hasattr(physics, "noise_model") and hasattr(physics.noise_model, "sigma") + else self.default_sigma + ) + + with torch.no_grad(): + return self.model(y / scale, physics, sigma=sigma) * scale + + +class DeepImagePriorReconstructor(dinv.models.Reconstructor): + """ + Wrapper for Deep Image Prior from DeepInverse. + + :param tuple img_size: image size of the output. Defaults to (640, 368) + :param int n_iter: number of iterations to fit the DIP. Defaults to 100. + """ + + def __init__( + self, + img_size: tuple = (640, 368), + n_iter: int = 100, + verbose: bool = True, + ) -> None: + super().__init__() + + lr = 1e-2 # learning rate for the optimizer. + channels = 64 # number of channels per layer in the decoder. + in_size = [2, 2] # size of the input to the decoder. 
+ + self.model = dinv.models.DeepImagePrior( + dinv.models.ConvDecoder( + img_size=(2, *img_size[-2:]), in_size=in_size, channels=channels + ), + learning_rate=lr, + iterations=n_iter, + verbose=verbose, + input_size=[channels] + in_size, + ) + + def forward(self, y, physics): + return self.model(y, physics) + + +class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): + """ + Wrapper for pretrained UNet from FastMRI singlecoil knee challenge. + + Note: this model was trained for accelerated MRI reconstruction and may not have good performance on other degradations. + + Note: this model discards complex information and only returns the magnitude image. + + NOTE: this model was trained on both train+val splits of the challenge (i.e. trained on singlecoil_train, singlecoil_val). + + The pretrained fastMRI model expects magnitude images that are normalized per slice, + so this wrapper matches that preprocessing and rescales the output back to the + original adjoint-image intensity range. + + See https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples for more details. 
+ """ + + MODEL_URL = ( + "https://dl.fbaipublicfiles.com/fastMRI/trained_models/unet/" + "knee_sc_leaderboard_state_dict.pt" + ) + MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" + MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 256, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + + def __init__(self, device: torch.device = None, state_dict_file: str = None) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + self.model = Unet(**self.UNET_KWARGS) + + state_dict_path = ( + Path(state_dict_file).expanduser() + if state_dict_file is not None + else Path(__file__).resolve().parents[2] / self.MODEL_FILENAME + ) + + if state_dict_file is None: + if not matches_sha256(state_dict_path, self.MODEL_SHA256): + download_file_with_sha256( + self.MODEL_URL, + state_dict_path, + self.MODEL_SHA256, + label="FastMRI UNet checkpoint", + ) + elif not state_dict_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {state_dict_path}") + + self.model.load_state_dict( + torch.load(state_dict_path, map_location=device, weights_only=True) + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + x_in = physics.A_adjoint(y) + + x_in = dinv.utils.complex_abs(x_in, keepdim=True) + + # Match the fastMRI normalization used for training, then rescale the + # predicted magnitude image back to the original adjoint-image intensity range. 
+ mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-8 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu # (B, 1, H, W) + + return torch.cat([out, torch.zeros_like(out)], dim=1) + + +def _load_unet_checkpoint_state( + checkpoint_path: Path, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Load U-Net weights from a plain or Lightning-style checkpoint.""" + + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + # Accept either a checkpoint with "state_dict" or a plain state_dict + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + + # If Lightning-style keys like 'unet.*' exist, strip the 'unet.' prefix + if any(key.startswith("unet.") for key in state_dict): + return { + key[len("unet.") :]: value + for key, value in state_dict.items() + if key.startswith("unet.") + } + + return state_dict + + +class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): + """Wrapper for a trained OASIS single-coil U-Net model. + + The model reuses the repository's fastMRI-derived :class:`Unet` module, but + loads an OASIS checkpoint supplied by the caller. The forward pass converts + k-space to a zero-filled magnitude image, applies per-slice instance + normalization, runs the U-Net, then rescales the prediction back to the + adjoint-image intensity range. + + Parameters + ---------- + checkpoint_file : str + Path to the trained OASIS U-Net checkpoint. + device : torch.device, optional + Device on which to run inference. 
+ """ + + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 32, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + + def __init__( + self, + checkpoint_file: str, + device: torch.device = None, + ) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Use the helper to obtain a normalized state_dict (handles plain or Lightning) + state_dict = _load_unet_checkpoint_state(checkpoint_path, device) + + self.model = Unet(**self.UNET_KWARGS) + self.model.load_state_dict( + state_dict, + strict=True, + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + """Reconstruct a magnitude image from measured k-space. + + Parameters + ---------- + y : torch.Tensor + Measured k-space tensor. + physics : dinv.physics.Physics + Physics operator used to form the adjoint input image. + + Returns + ------- + torch.Tensor + Complex-valued reconstruction with zero imaginary channel. 
+ """ + + x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) + mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu + + return torch.cat([out, torch.zeros_like(out)], dim=1) From 52fa4edb6bb1803dd19266f6d06f6ada5604c173 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Fri, 8 May 2026 13:29:05 +0200 Subject: [PATCH 11/17] Automated SHA256 certified model download for OASIS --- examples/OASIS_inference_plot.py | 20 ++---- mri_recon/reconstruction/deep.py | 112 ++++++++++++++++++++++++++++--- mri_recon/utils/__init__.py | 8 ++- mri_recon/utils/io.py | 85 +++++++++++++++++++++-- tests/test_reconstructions.py | 102 ++++++++++++++++++++++++++++ tests/test_utils_io.py | 29 +++++++- uv.lock | 16 +++++ 7 files changed, 342 insertions(+), 30 deletions(-) diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index 1523d10..ef1424a 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -7,7 +7,6 @@ from __future__ import annotations import argparse -import json import os import sys from collections import OrderedDict @@ -366,21 +365,10 @@ def resolve_oasis_checkpoint( if checkpoint is not None: return checkpoint.expanduser().resolve() - with manifest_path.open("r", encoding="utf-8") as handle: - manifest = json.load(handle) - - checkpoints = manifest.get("checkpoints", {}) - key = str(acceleration) - if key not in checkpoints: - available = ", ".join(sorted(checkpoints)) - raise ValueError( - f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." 
- ) - - filename = Path(checkpoints[key]["filename"]) - if filename.is_absolute(): - return filename - return (manifest_path.parent.parent / filename).resolve() + return OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( + acceleration=acceleration, + manifest_path=manifest_path, + ) def choose_algorithm( diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 6da8ac1..52823cb 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -1,10 +1,16 @@ +import json from pathlib import Path +from typing import Optional import deepinv as dinv import torch from ._fastmri_unet import Unet -from ..utils import download_file_with_sha256, matches_sha256 +from ..utils import ( + download_file_with_sha256, + download_google_drive_file_with_sha256, + matches_sha256, +) class RAMReconstructor(dinv.models.Reconstructor): @@ -194,15 +200,21 @@ class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): """Wrapper for a trained OASIS single-coil U-Net model. The model reuses the repository's fastMRI-derived :class:`Unet` module, but - loads an OASIS checkpoint supplied by the caller. The forward pass converts + can also download the packaged checkpoint manifest and checkpoint on demand + when no explicit checkpoint path is supplied. The forward pass converts k-space to a zero-filled magnitude image, applies per-slice instance normalization, runs the U-Net, then rescales the prediction back to the adjoint-image intensity range. Parameters ---------- - checkpoint_file : str - Path to the trained OASIS U-Net checkpoint. + checkpoint_file : str, optional + Path to the trained OASIS U-Net checkpoint. If omitted, the reconstructor + downloads the packaged checkpoint for ``acceleration``. + acceleration : int, optional + Packaged checkpoint acceleration factor used when ``checkpoint_file`` is omitted. + manifest_path : str, optional + Override path for the downloaded or cached packaged checkpoint manifest. 
device : torch.device, optional Device on which to run inference. """ @@ -214,10 +226,85 @@ class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): "num_pool_layers": 4, "drop_prob": 0.0, } + ASSET_ROOT = Path(__file__).resolve().parents[2] / "reconstruction_only" + CHECKPOINTS_DIR = ASSET_ROOT / "checkpoints" + MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" + MANIFEST_FILE_ID = "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" + MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" + CHECKPOINT_FILE_IDS = { + "4": "11s6YeM6_YJeD4wcrn24jyMjyj_vX2ANU", + "8": "1w8PDiYpr2xBPXahzRllhZjQT1yoMGXg-", + "10": "1djJ2i0uYP4PT070CS0xx9nNJ41JmSFhh", + } + CHECKPOINT_SHA256 = { + "4": "4fcefa9860cb7895e581a0de8f90bd7f188ae1c0b5e428a4a07519dd2561ac29", + "8": "2cd4c44e3c7a3870adbe5090b2bfaae044f5e3f0b4bcaf2b1fc29969e5e6b9ca", + "10": "90e3d9b17aa0f9fd43aaf090c152edcbdead1b9be41a076594e41098db7befa8", + } + + @classmethod + def ensure_manifest(cls, manifest_path: Optional[Path] = None) -> Path: + """Ensure the packaged OASIS checkpoint manifest exists locally and is verified.""" + + resolved_manifest_path = ( + manifest_path.expanduser().resolve() if manifest_path is not None else cls.MANIFEST_PATH + ) + if not matches_sha256(resolved_manifest_path, cls.MANIFEST_SHA256): + download_google_drive_file_with_sha256( + cls.MANIFEST_FILE_ID, + resolved_manifest_path, + cls.MANIFEST_SHA256, + label="OASIS checkpoint manifest", + ) + return resolved_manifest_path + + @classmethod + def resolve_default_checkpoint( + cls, + acceleration: int, + manifest_path: Optional[Path] = None, + ) -> Path: + """Resolve and download the packaged OASIS checkpoint for a given acceleration.""" + + resolved_manifest_path = cls.ensure_manifest(manifest_path) + with resolved_manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + key = str(acceleration) + checkpoints = manifest.get("checkpoints", {}) + if key not in checkpoints: + available = ", 
".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." + ) + + if key not in cls.CHECKPOINT_FILE_IDS or key not in cls.CHECKPOINT_SHA256: + raise ValueError( + f"No automated download metadata is configured for acceleration {acceleration}." + ) + + filename = Path(checkpoints[key]["filename"]) + checkpoint_path = ( + filename + if filename.is_absolute() + else (resolved_manifest_path.parent.parent / filename) + ).resolve() + + if not matches_sha256(checkpoint_path, cls.CHECKPOINT_SHA256[key]): + download_google_drive_file_with_sha256( + cls.CHECKPOINT_FILE_IDS[key], + checkpoint_path, + cls.CHECKPOINT_SHA256[key], + label=f"OASIS checkpoint x{acceleration}", + ) + + return checkpoint_path def __init__( self, - checkpoint_file: str, + checkpoint_file: str | None = None, + acceleration: int = 4, + manifest_path: str | None = None, device: torch.device = None, ) -> None: super().__init__() @@ -226,9 +313,18 @@ def __init__( device = torch.device("cpu") self.device = device - checkpoint_path = Path(checkpoint_file).expanduser() - if not checkpoint_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + if checkpoint_file is None: + resolved_manifest_path = ( + Path(manifest_path).expanduser() if manifest_path is not None else None + ) + checkpoint_path = self.resolve_default_checkpoint( + acceleration=acceleration, + manifest_path=resolved_manifest_path, + ) + else: + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") # Use the helper to obtain a normalized state_dict (handles plain or Lightning) state_dict = _load_unet_checkpoint_state(checkpoint_path, device) diff --git a/mri_recon/utils/__init__.py b/mri_recon/utils/__init__.py index 61ae0e6..3d17127 100644 --- a/mri_recon/utils/__init__.py +++ b/mri_recon/utils/__init__.py @@ -1,5 +1,11 @@ from .io 
import download_file_with_sha256 as download_file_with_sha256 +from .io import download_google_drive_file_with_sha256 as download_google_drive_file_with_sha256 from .io import format_megabytes as format_megabytes from .io import matches_sha256 as matches_sha256 -__all__ = ["download_file_with_sha256", "format_megabytes", "matches_sha256"] +__all__ = [ + "download_file_with_sha256", + "download_google_drive_file_with_sha256", + "format_megabytes", + "matches_sha256", +] diff --git a/mri_recon/utils/io.py b/mri_recon/utils/io.py index 5aaa962..bde5d0e 100644 --- a/mri_recon/utils/io.py +++ b/mri_recon/utils/io.py @@ -1,6 +1,9 @@ import hashlib +import re import tempfile +from io import BytesIO from pathlib import Path +from urllib.parse import urlencode from urllib.request import urlopen from tqdm.auto import tqdm @@ -21,8 +24,24 @@ def format_megabytes(num_bytes: int) -> str: return f"{num_bytes / (1024 * 1024):.1f} MB" -def download_file_with_sha256( - url: str, +class _BytesResponse: + def __init__(self, payload: bytes): + self._buffer = BytesIO(payload) + self.headers = {"Content-Length": str(len(payload))} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self, size: int = -1) -> bytes: + return self._buffer.read(size) + + +def _download_response_with_sha256( + response_factory, + source: str, destination: Path, expected_sha256: str, *, @@ -30,7 +49,7 @@ def download_file_with_sha256( report_interval_mb: int = 25, ) -> None: destination.parent.mkdir(parents=True, exist_ok=True) - print(f"Downloading {label} from {url} to {destination}. This may take a moment.") + print(f"Downloading {label} from {source} to {destination}. 
This may take a moment.") chunk_size = 1024 * 1024 report_interval = report_interval_mb * 1024 * 1024 @@ -42,7 +61,7 @@ def download_file_with_sha256( ) as handle: tmp_path = Path(handle.name) - with urlopen(url, timeout=30) as response: + with response_factory() as response: total_size = response.headers.get("Content-Length") total_size = int(total_size) if total_size is not None else None with tqdm( @@ -67,3 +86,61 @@ def download_file_with_sha256( if tmp_path is not None: tmp_path.unlink(missing_ok=True) raise + + +def download_file_with_sha256( + url: str, + destination: Path, + expected_sha256: str, + *, + label: str = "file", + report_interval_mb: int = 25, +) -> None: + _download_response_with_sha256( + lambda: urlopen(url, timeout=30), + url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) + + +def download_google_drive_file_with_sha256( + file_id: str, + destination: Path, + expected_sha256: str, + *, + label: str = "file", + report_interval_mb: int = 25, +) -> None: + base_url = "https://drive.usercontent.google.com/download" + request_url = f"{base_url}?{urlencode({'id': file_id, 'export': 'download'})}" + + with urlopen(request_url, timeout=30) as response: + payload = response.read() + + if b"Google Drive - Virus scan warning" not in payload: + _download_response_with_sha256( + lambda: _BytesResponse(payload), + request_url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) + return + + html = payload.decode("utf-8", errors="replace") + uuid_match = re.search(r'name="uuid" value="([^"]+)"', html) + if uuid_match is None: + raise ValueError(f"Could not parse Google Drive confirmation token for file {file_id}.") + + confirmed_url = f"{base_url}?{urlencode({'id': file_id, 'export': 'download', 'confirm': 't', 'uuid': uuid_match.group(1)})}" + download_file_with_sha256( + confirmed_url, + destination, + expected_sha256, + label=label, + 
report_interval_mb=report_interval_mb, + ) diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index c2689ed..9e55c41 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -150,3 +150,105 @@ def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, assert x_hat.shape == x.shape assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) + + +def test_oasis_resolve_default_checkpoint_downloads_manifest_and_checkpoint(tmp_path, monkeypatch): + asset_root = tmp_path / "reconstruction_only" + manifest_path = asset_root / "checkpoints" / "manifest.json" + checkpoint_path = asset_root / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" + manifest_bytes = ( + b'{"checkpoints": {"4": {"filename": "checkpoints/oasis_balanced_seed24_accel4.ckpt"}}}' + ) + checkpoint_bytes = b"fake-oasis-checkpoint" + downloads = [] + + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "MANIFEST_SHA256", + __import__("hashlib").sha256(manifest_bytes).hexdigest(), + ) + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "CHECKPOINT_SHA256", + {"4": __import__("hashlib").sha256(checkpoint_bytes).hexdigest()}, + ) + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "CHECKPOINT_FILE_IDS", + {"4": "checkpoint-file-id"}, + ) + + def fake_download(file_id, destination, expected_sha256, **kwargs): + downloads.append((file_id, destination)) + destination.parent.mkdir(parents=True, exist_ok=True) + if file_id == "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I": + destination.write_bytes(manifest_bytes) + elif file_id == "checkpoint-file-id": + destination.write_bytes(checkpoint_bytes) + else: + raise AssertionError(f"Unexpected file_id: {file_id}") + + monkeypatch.setattr( + "mri_recon.reconstruction.deep.download_google_drive_file_with_sha256", + fake_download, + ) + + resolved = OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( + 4, manifest_path=manifest_path + ) + + assert resolved == 
checkpoint_path.resolve() + assert manifest_path.read_bytes() == manifest_bytes + assert checkpoint_path.read_bytes() == checkpoint_bytes + assert downloads == [ + ("1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I", manifest_path.resolve()), + ("checkpoint-file-id", checkpoint_path.resolve()), + ] + + +def test_oasis_singlecoil_unet_reconstructor_uses_packaged_checkpoint_defaults( + device, tmp_path, monkeypatch +): + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "UNET_KWARGS", + { + "in_chans": 1, + "out_chans": 1, + "chans": 8, + "num_pool_layers": 2, + "drop_prob": 0.0, + }, + ) + + weights_path = tmp_path / "oasis_unet_packaged_weights.ckpt" + state_dict = Unet(**OASISSinglecoilUnetReconstructor.UNET_KWARGS).state_dict() + torch.save( + {"state_dict": {f"unet.{key}": value for key, value in state_dict.items()}}, + weights_path, + ) + + captured = {} + + def fake_resolve(cls, acceleration, manifest_path=None): + captured["acceleration"] = acceleration + captured["manifest_path"] = manifest_path + return weights_path + + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "resolve_default_checkpoint", + classmethod(fake_resolve), + ) + + model = OASISSinglecoilUnetReconstructor( + acceleration=8, + manifest_path=str(tmp_path / "manifest.json"), + device=device, + ) + + assert captured == { + "acceleration": 8, + "manifest_path": tmp_path / "manifest.json", + } + assert isinstance(model.model, Unet) diff --git a/tests/test_utils_io.py b/tests/test_utils_io.py index b984872..3d2c434 100644 --- a/tests/test_utils_io.py +++ b/tests/test_utils_io.py @@ -1,7 +1,7 @@ import hashlib from io import BytesIO -from mri_recon.utils.io import download_file_with_sha256 +from mri_recon.utils.io import download_file_with_sha256, download_google_drive_file_with_sha256 class FakeResponse: @@ -39,3 +39,30 @@ def test_download_file_with_sha256_moves_temp_file_after_close(tmp_path, monkeyp assert destination.read_bytes() == payload assert list(tmp_path.glob("*.tmp")) == [] + + 
+def test_download_google_drive_file_with_sha256_confirms_large_download(tmp_path, monkeypatch): + warning_html = b"""
Google Drive - Virus scan warning""" + payload = b"oasis-checkpoint" + expected_sha256 = hashlib.sha256(payload).hexdigest() + destination = tmp_path / "oasis.ckpt" + requested_urls = [] + + def fake_urlopen(url, timeout=30): + requested_urls.append(url) + if "confirm=t" in url: + return FakeResponse(payload) + return FakeResponse(warning_html) + + monkeypatch.setattr("mri_recon.utils.io.urlopen", fake_urlopen) + + download_google_drive_file_with_sha256( + "file-123", + destination, + expected_sha256, + label="OASIS checkpoint", + report_interval_mb=1, + ) + + assert destination.read_bytes() == payload + assert any("confirm=t" in url and "uuid=uuid-456" in url for url in requested_urls) diff --git a/uv.lock b/uv.lock index 7034d30..d25083a 100644 --- a/uv.lock +++ b/uv.lock @@ -947,6 +947,7 @@ dependencies = [ { name = "fastmri" }, { name = "h5py" }, { name = "matplotlib" }, + { name = "nibabel" }, { name = "numpy" }, { name = "ptwt" }, { name = "pydicom" }, @@ -976,6 +977,7 @@ requires-dist = [ { name = "fastmri", specifier = ">=0.3.0" }, { name = "h5py", specifier = ">=3.16.0" }, { name = "matplotlib", specifier = ">=3.9.0" }, + { name = "nibabel", specifier = ">=5.3.2" }, { name = "numpy", specifier = ">=2.4.3" }, { name = "ptwt", specifier = ">=1.0.1" }, { name = "pydicom", specifier = ">=3.0.1" }, @@ -1117,6 +1119,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] +[[package]] +name = "nibabel" +version = "5.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/a3/01/3d2cc510c616bc8e27be17a063070d9126f69407961594a9ae734ea51121/nibabel-5.4.2.tar.gz", hash = "sha256:d5f4b9076a13178ae7f7acf18c8dbd503ee1c4d5c0c23b85df7be87efcbb49da", size = 4663132, upload-time = "2026-03-11T13:31:52.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/d7/601b6396b33536811668935faa790112266c70661be94555999be431f86f/nibabel-5.4.2-py3-none-any.whl", hash = "sha256:553482c5f1e1034fc312edf6fb7f32236c0056439845d1c29293b7e8c98d4854", size = 3300985, upload-time = "2026-03-11T13:31:50.028Z" }, +] + [[package]] name = "nodeenv" version = "1.10.0" From 220c26e1705b2894c3a4fe6fec45369cbe4ed2d7 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Fri, 8 May 2026 14:24:25 +0200 Subject: [PATCH 12/17] added common download storage path for all models --- examples/OASIS_inference_plot.py | 4 +++- mri_recon/reconstruction/deep.py | 7 ++++--- tests/test_reconstructions.py | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index ef1424a..ac806ba 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -61,7 +61,9 @@ REPO_ROOT = Path(__file__).resolve().parents[1] REPORT_DIR = Path("reports") / "oasis_inference_plot" DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" -DEFAULT_MANIFEST_PATH = REPO_ROOT / "reconstruction_only" / "checkpoints" / "manifest.json" +DEFAULT_MANIFEST_PATH = ( + REPO_ROOT / "downloads" / "oasis_singlecoil_unet" / "checkpoints" / "manifest.json" +) REPORT_DIR.mkdir(parents=True, exist_ok=True) diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 52823cb..cb286c3 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -114,6 +114,7 @@ class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): ) MODEL_SHA256 = 
"8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "fastmri_singlecoil_unet" UNET_KWARGS = { "in_chans": 1, "out_chans": 1, @@ -134,7 +135,7 @@ def __init__(self, device: torch.device = None, state_dict_file: str = None) -> state_dict_path = ( Path(state_dict_file).expanduser() if state_dict_file is not None - else Path(__file__).resolve().parents[2] / self.MODEL_FILENAME + else self.MODEL_DIR / self.MODEL_FILENAME ) if state_dict_file is None: @@ -226,8 +227,8 @@ class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): "num_pool_layers": 4, "drop_prob": 0.0, } - ASSET_ROOT = Path(__file__).resolve().parents[2] / "reconstruction_only" - CHECKPOINTS_DIR = ASSET_ROOT / "checkpoints" + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "oasis_singlecoil_unet" + CHECKPOINTS_DIR = MODEL_DIR / "checkpoints" MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" MANIFEST_FILE_ID = "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index 9e55c41..36117ff 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -153,9 +153,9 @@ def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, def test_oasis_resolve_default_checkpoint_downloads_manifest_and_checkpoint(tmp_path, monkeypatch): - asset_root = tmp_path / "reconstruction_only" - manifest_path = asset_root / "checkpoints" / "manifest.json" - checkpoint_path = asset_root / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" + model_dir = tmp_path / "downloads" / "oasis_singlecoil_unet" + manifest_path = model_dir / "checkpoints" / "manifest.json" + checkpoint_path = model_dir / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" manifest_bytes = ( b'{"checkpoints": {"4": {"filename": 
"checkpoints/oasis_balanced_seed24_accel4.ckpt"}}}' ) From b7d0063c418c3520eeb154120869f3299f118ff5 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 08:48:17 +0200 Subject: [PATCH 13/17] added detailed explanations related to resolution reduction --- mri_recon/distortions/resolution.py | 165 +++++++++++++++++++------ mri_recon/distortions/undersampling.py | 62 ++++++---- 2 files changed, 165 insertions(+), 62 deletions(-) diff --git a/mri_recon/distortions/resolution.py b/mri_recon/distortions/resolution.py index 1011abc..399a037 100644 --- a/mri_recon/distortions/resolution.py +++ b/mri_recon/distortions/resolution.py @@ -73,14 +73,33 @@ def _smooth_radial_low_pass_mask( class IsotropicResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Low-pass truncation with a circular mask. - - This applies - ``M_out(kx, ky) = M(kx, ky) * 1[r(kx, ky) <= K]`` - where ``K`` is ``radius_fraction`` on the normalized radial grid. - - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are set to zero. + """Isotropic in-plane resolution reduction with a circular hard cutoff. + + This distortion keeps only a centered circular region of k-space: + ``M_out(kx, ky) = M(kx, ky) * 1[r(kx, ky) <= K]``, where ``K`` is the + retained radial support on the normalized frequency grid. + + In MRI terms, this models reduced in-plane spatial resolution at fixed + field of view by removing high-frequency content equally in all in-plane + directions. The reconstructed image keeps the same matrix size, but fine + detail is lost because the maximum sampled k-space extent is smaller. The + resulting effect is isotropic blur from limited k-space support. + + The parameter ``radius_fraction`` controls how much of the centered + low-frequency region is retained. Smaller values preserve only the k-space + core and therefore produce stronger blur and a broader point-spread + function. 
A value of ``1.0`` keeps the full sampled support and recovers + the identity operator. + + Compared with :class:`AnisotropicResolutionReduction`, this class applies + the same reduction in all directions rather than separately along readout + and phase encode. Compared with :class:`CartesianUndersampling`, it models + resolution loss by shrinking the sampled support, not by skipping lines + within the original support. + + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Smaller values correspond to + stronger isotropic resolution reduction. """ def __init__(self, radius_fraction: float = 0.6) -> None: @@ -96,17 +115,40 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class AnisotropicResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Axis-aligned low-pass truncation with independent cutoffs. - - This applies a rectangular mask - ``M_out(kx, ky) = M(kx, ky) * 1[|kx| <= Kx] * 1[|ky| <= Ky]`` - where ``Kx`` and ``Ky`` are the normalized cutoffs along the readout and - phase-encode frequency axes respectively. - - :param float kx_radius_fraction: Normalized cutoff along the horizontal - frequency axis in ``(0, 1]``. - :param float ky_radius_fraction: Normalized cutoff along the vertical - frequency axis in ``(0, 1]``. + """Axis-aligned Cartesian resolution reduction with independent cutoffs. + + This distortion keeps only a centered rectangular region of Cartesian + k-space and zeros the remaining outer frequencies: + ``M_out(kx, ky) = M(kx, ky) * 1[|kx| <= Kx] * 1[|ky| <= Ky]``. + + In MRI terms, this models reduced in-plane acquisition resolution at fixed + field of view. The reconstructed image keeps the same matrix size, but its + effective spatial resolution decreases because less high-frequency k-space + support is retained. The resulting artifact is blur from limited k-space + extent, not aliasing from undersampling. 
+ + The two parameters control the retained centered k-space support along the + Cartesian encoding axes. ``kx_radius_fraction`` corresponds to the retained + readout-direction frequency extent, while ``ky_radius_fraction`` + corresponds to the retained phase-encode-direction frequency extent. + Reducing either parameter broadens the point-spread function along the + corresponding image direction. A typical MRI-like setting keeps + ``kx_radius_fraction`` close to ``1.0`` and reduces + ``ky_radius_fraction``, reflecting that protocols often sacrifice more + phase-encode resolution than readout resolution. + + This is a hard rectangular cutoff. If a softer edge is desired, use one of + the taper-based resolution distortions instead. + + :param float kx_radius_fraction: Fraction of the original centered + k-space extent retained along the horizontal frequency axis in + ``(0, 1]``. This corresponds to the retained readout-direction + resolution support. ``1.0`` keeps the full sampled readout extent. + :param float ky_radius_fraction: Fraction of the original centered + k-space extent retained along the vertical frequency axis in + ``(0, 1]``. This corresponds to the retained phase-encode-direction + resolution support. Smaller values produce stronger blur along the + corresponding image direction. """ def __init__( @@ -126,8 +168,8 @@ def __init__( def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: normalized_kx, normalized_ky = _normalized_axis_frequencies(shape) - # A rectangular passband models direction-dependent resolution loss, - # such as stronger truncation along phase encode than along readout. + # Centered rectangular support models reduced acquired Cartesian + # resolution, often stronger along phase encode than along readout. 
mask = (normalized_kx <= self.kx_radius_fraction) & ( normalized_ky <= self.ky_radius_fraction ) @@ -135,18 +177,41 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class HannTaperResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Radial low-pass reduction with a Hann transition band. - - The mask equals ``1`` in the low-frequency passband, tapers smoothly to - ``0`` with a raised-cosine profile near the cutoff, and is exactly ``0`` - beyond ``radius_fraction``. + """Isotropic resolution reduction with a Hann-tapered radial cutoff. + + This distortion keeps a centered circular low-frequency region and tapers + the mask smoothly to zero near the cutoff using a raised-cosine Hann + profile. Inside the passband the mask equals ``1``; inside the transition + band it decreases smoothly from ``1`` to ``0``; beyond the cutoff it is + exactly ``0``. + + In MRI terms, this is still a resolution-reduction operator: it suppresses + high spatial frequencies and therefore lowers effective in-plane spatial + resolution at fixed field of view. Relative to + :class:`IsotropicResolutionReduction`, the main difference is not the type + of resolution loss but the edge behavior of the k-space support. The smooth + transition reduces ringing that a hard truncation can introduce, at the + cost of making the support edge less sharp. + + ``radius_fraction`` sets the outer radial extent of the retained support. + ``transition_fraction`` sets how much of that outer region is devoted to + the smooth taper. Setting ``transition_fraction=0`` recovers the hard + circular cutoff. Larger transition fractions make the cutoff gentler and + behave more like k-space apodization. + + Compared with :class:`CartesianUndersampling`, this class does not skip + phase-encode lines within the original support. It reduces resolution by + attenuating and removing high frequencies, leading primarily to blur rather + than undersampling aliasing. 
See https://en.wikipedia.org/wiki/Hann_function for details on the Hann window. - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are fully suppressed. + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Frequencies outside this radius + are fully suppressed. :param float transition_fraction: Fraction of the cutoff radius occupied by - the smooth transition in ``[0, 1]``. ``0`` recovers the hard cutoff. + the smooth Hann transition in ``[0, 1]``. ``0`` recovers the hard + circular cutoff. """ def __init__( @@ -173,20 +238,42 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class KaiserTaperResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Radial low-pass reduction with a Kaiser transition band. - - The mask equals ``1`` in the low-frequency passband, tapers smoothly to - ``0`` with a Kaiser-profile transition near the cutoff, and is exactly - ``0`` beyond ``radius_fraction``. + """Isotropic resolution reduction with a Kaiser-tapered radial cutoff. + + This distortion keeps a centered circular low-frequency region and tapers + the mask smoothly to zero near the cutoff using a Kaiser-window profile. + Inside the passband the mask equals ``1``; inside the transition band it + decreases smoothly from ``1`` to ``0``; beyond the cutoff it is exactly + ``0``. + + In MRI terms, this lowers effective in-plane spatial resolution at fixed + field of view by reducing the retained high-frequency k-space extent. As + with :class:`HannTaperResolutionReduction`, the purpose of the taper is to + soften the hard support edge and reduce ringing, while still producing blur + from lost high-frequency information. + + ``radius_fraction`` sets the outer radial extent of the retained support. + ``transition_fraction`` controls the width of the taper band. 
``beta`` + controls the shape of the Kaiser taper within that band: larger values + produce a steeper transition, while smaller positive values produce a more + gradual roll-off. Setting ``transition_fraction=0`` recovers the hard + circular cutoff regardless of ``beta``. + + Compared with :class:`CartesianUndersampling`, this class reduces + resolution by shrinking and tapering the effective k-space support, rather + than by skipping lines from the original grid. The dominant image effect is + blur and apodization, not aliasing from sub-Nyquist sampling. See https://en.wikipedia.org/wiki/Kaiser_window for details on the Kaiser window. - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are fully suppressed. + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Frequencies outside this radius + are fully suppressed. :param float transition_fraction: Fraction of the cutoff radius occupied by - the smooth transition in ``[0, 1]``. ``0`` recovers the hard cutoff. - :param float beta: Positive Kaiser shape parameter. Larger values create a - steeper transition inside the taper band. + the smooth Kaiser transition in ``[0, 1]``. ``0`` recovers the hard + circular cutoff. + :param float beta: Positive Kaiser shape parameter controlling how steeply + the taper falls inside the transition band. """ def __init__( diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index afa6c39..e92316e 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -17,32 +17,48 @@ class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): - """Cartesian k-space undersampling along phase-encode direction. - - This distortion simulates true MRI acquisition undersampling by applying a - binary sampling mask along the phase-encode direction (default). 
The mask - keeps a contiguous center region (ACS - Auto-Calibration Signal) fully - sampled and randomly undersamples the periphery. - - The mask is deterministic given a shape and seed, ensuring reproducibility. - - :param float keep_fraction: Fraction of phase-encode lines to keep in - ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. - :param float center_fraction: Fraction of phase-encode lines reserved for - the contiguous, fully-sampled ACS region in ``[0, 1]``. Defaults to - ``0.5 * keep_fraction``, leaving part of the acquisition budget for - randomized peripheral line sampling. Set to ``0`` to sample without a - guaranteed ACS block. + """Cartesian k-space undersampling along one encoding direction. + + This distortion simulates sub-Nyquist MRI acquisition by keeping only a + subset of Cartesian k-space lines along a chosen axis, phase encode by + default. A contiguous low-frequency center region (ACS) may be retained + fully, while the peripheral lines are sampled using a configurable random + or equispaced pattern. + + This is fundamentally different from resolution reduction. In + resolution-reduction distortions, the maximum retained k-space extent is + reduced, which primarily causes blur because high spatial frequencies are + absent. In Cartesian undersampling, the original k-space extent is still + targeted, but many lines inside that extent are skipped. The dominant + consequence is aliasing or incoherent undersampling artifact, not a simple + broader point-spread function. + + ``keep_fraction`` controls the total sampling budget along the chosen axis. + ``center_fraction`` reserves part of that budget for a fully sampled ACS + block near k-space center, which is important for many parallel imaging and + learned reconstruction pipelines. ``pattern`` controls how the remaining + peripheral lines are selected. ``axis`` determines which Cartesian encoding + direction is undersampled. ``seed`` makes the random patterns reproducible. 
+ + The mask is deterministic for a given shape, device, and seed. + + :param float keep_fraction: Fraction of lines kept along the undersampled + axis in ``(0, 1]``. For example, ``0.25`` keeps 25% of the lines and + corresponds to approximately 4x acceleration when applied to a single + encoding direction. + :param float center_fraction: Fraction of lines along the undersampled axis + reserved for a contiguous, fully sampled ACS region in ``[0, 1]``. + Defaults to ``0.5 * keep_fraction`` so that part of the sampling budget + remains available for peripheral sampling. ``0`` disables the ACS block. :param str pattern: Peripheral sampling pattern. Supported values are ``"uniform_random"``, ``"variable_density_random"``, and - ``"equispaced"``. Defaults to ``"variable_density_random"``. - :param int axis: Axis along which to apply undersampling. Default is -2 - (phase-encode for 4D tensors), can also be -1 for readout/column - masking or -3 for the slice/depth axis in 5D tensors. + ``"equispaced"``. Variable-density sampling favors low-frequency lines + near the ACS boundary; equispaced sampling is deterministic. + :param int axis: Axis along which to undersample. The default ``-2`` + corresponds to phase encode for the repository's standard 2D k-space + convention. Other values allow undersampling along readout or depth. :param int | None seed: Random seed for reproducible mask generation. - If None, uses unseeded randomness (not recommended for reproducibility). - Has no effect when ``pattern="equispaced"`` because that pattern is - fully deterministic and uses no randomness. + Ignored by the deterministic ``"equispaced"`` pattern. 
""" def __init__( From 75d29ef6f54ca2fc1e5a9cf27d01be801381540c Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 09:01:43 +0200 Subject: [PATCH 14/17] added detailed explanations related to bias field --- mri_recon/distortions/biasfield.py | 36 +++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/mri_recon/distortions/biasfield.py b/mri_recon/distortions/biasfield.py index 26ca4f5..a9d7910 100644 --- a/mri_recon/distortions/biasfield.py +++ b/mri_recon/distortions/biasfield.py @@ -80,11 +80,37 @@ class OffCenterAnisotropicGaussianKspaceBiasField(BaseDistortion): The gain peaks at an offset location in k-space, decays with different widths along ``kx`` and ``ky``, and is normalized to unit maximum on the sampled grid. - :param float width_x_fraction: Gaussian width along the normalized ``kx`` direction. - :param float width_y_fraction: Gaussian width along the normalized ``ky`` direction. - :param float center_x_fraction: Center offset along normalized ``kx`` in ``[-1, 1]``. - :param float center_y_fraction: Center offset along normalized ``ky`` in ``[-1, 1]``. - :param float edge_gain: Baseline gain far from the Gaussian peak. Must lie in ``(0, 1]``. + Note: This class can also approximate a readout-decay-like blur when used in a + centered anisotropic configuration. To mimic stronger attenuation along the + readout direction, keep the Gaussian centered at DC with + ``center_x_fraction=0.0`` and ``center_y_fraction=0.0``, choose a narrower + width along the readout axis than along the orthogonal axis, and reduce + ``edge_gain`` below ``1.0``. For example, a setting such as + ``width_x_fraction < width_y_fraction`` with a moderate ``edge_gain`` + produces a smooth directional loss of high-frequency content that can + resemble readout-decay blur. + + This remains a phenomenological approximation rather than an explicit + time-ordered readout-decay model. 
It applies a smooth multiplicative + k-space weighting, not a physically parameterized echo-train decay. + + :param float width_x_fraction: Gaussian width along the normalized ``kx`` + direction. Smaller values produce stronger attenuation away from the + center along ``kx``. When ``kx`` is the readout axis, choosing + ``width_x_fraction < width_y_fraction`` approximates readout-direction + blur. + :param float width_y_fraction: Gaussian width along the normalized ``ky`` + direction. Larger values preserve more support along ``ky`` relative + to ``kx``. + :param float center_x_fraction: Center offset along normalized ``kx`` in + ``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like + approximation. + :param float center_y_fraction: Center offset along normalized ``ky`` in + ``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like + approximation. + :param float edge_gain: Baseline gain far from the Gaussian peak. Must lie + in ``(0, 1]``. Smaller values strengthen the peripheral attenuation + and therefore the resulting directional blur. 
""" def __init__( From 1a458d2066e7964d48a6fc763b9cfdf85432b7b4 Mon Sep 17 00:00:00 2001 From: Matthias Lenga Date: Thu, 14 May 2026 13:25:36 +0200 Subject: [PATCH 15/17] partial fourier --- examples/OASIS_inference_plot.py | 9 ++ examples/fastmri_inference_plot.py | 8 ++ mri_recon/distortions/__init__.py | 2 +- mri_recon/distortions/undersampling.py | 124 +++++++++++++++++++++++++ tests/test_distortions.py | 56 +++++++++++ 5 files changed, 198 insertions(+), 1 deletion(-) diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py index ac806ba..512dc33 100644 --- a/examples/OASIS_inference_plot.py +++ b/examples/OASIS_inference_plot.py @@ -40,6 +40,7 @@ IsotropicResolutionReduction, KaiserTaperResolutionReduction, OffCenterAnisotropicGaussianKspaceBiasField, + PartialFourierDistortion, PhaseEncodeGhostingDistortion, RadialHighPassEmphasisDistortion, RotationalMotionDistortion, @@ -86,6 +87,7 @@ "Cartesian undersampling (uniform random, zero ACS)", "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", + "Partial Fourier", "Phase-encode ghosting", "Segmented translation motion", "Translation motion", @@ -458,6 +460,13 @@ def choose_distortion( axis=-1, seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-1, + side="high", + ) case "Phase-encode ghosting": return PhaseEncodeGhostingDistortion( line_period=2, diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index e8f23ef..fe52c07 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -37,6 +37,7 @@ "Cartesian undersampling (uniform random, zero ACS)", "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", + "Partial Fourier", "Phase-encode ghosting", "Segmented translation motion", "Segmented rotational motion", @@ -181,6 +182,13 @@ def choose_distortion(name: str) -> 
BaseDistortion: pattern="equispaced", seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) case "Anisotropic LP": return AnisotropicResolutionReduction( kx_radius_fraction=1.0, diff --git a/mri_recon/distortions/__init__.py b/mri_recon/distortions/__init__.py index 80bbad4..af22fd1 100644 --- a/mri_recon/distortions/__init__.py +++ b/mri_recon/distortions/__init__.py @@ -19,4 +19,4 @@ KaiserTaperResolutionReduction, RadialHighPassEmphasisDistortion, ) -from .undersampling import CartesianUndersampling +from .undersampling import CartesianUndersampling, PartialFourierDistortion diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index e92316e..6c0be94 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -8,6 +8,7 @@ PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} +PARTIAL_FOURIER_SIDES = {"low", "high"} # Lorentzian offset that controls how steeply the variable-density weights # fall off with normalized distance from k-space center. @@ -289,3 +290,126 @@ def _expand_mask_to_shape( mask = mask_1d.reshape(expand_shape) return mask + + +class PartialFourierDistortion(SelfAdjointMultiplicativeMaskDistortion): + """Asymmetric contiguous Cartesian mask for partial Fourier acquisition. + + This distortion simulates partial Fourier MRI acquisition by keeping a + contiguous asymmetric region of k-space along one encoding axis while + preserving a centered low-frequency block. Unlike symmetric resolution + reduction, the retained support extends farther on one side of k-space than + the other. Unlike Cartesian undersampling, the retained region is + contiguous rather than sparse throughout the original support. + + The distortion models the acquired k-space mask only. 
It does not attempt + to reconstruct or infer the missing region with homodyne, POCS, or any + other partial-Fourier-specific reconstruction method. + + :param float partial_fraction: Fraction of lines retained along the chosen + axis in ``[0.5, 1]``. ``1.0`` recovers the identity operator. + :param float center_fraction: Fraction of lines reserved for a centered, + fully retained low-frequency block in ``[0, 1]``. This block must not + exceed ``partial_fraction``. + :param int axis: Axis along which to apply the asymmetric truncation. The + default ``-2`` matches the repository's standard phase-encode axis. + :param str side: Side that retains more support outside the centered block. + Supported values are ``"low"`` and ``"high"``. + """ + + def __init__( + self, + partial_fraction: float = 0.7, + center_fraction: float = 0.1, + axis: int = -2, + side: str = "high", + ) -> None: + super().__init__() + + if not 0.5 <= partial_fraction <= 1.0: + raise ValueError(f"partial_fraction must be in [0.5, 1], got {partial_fraction}") + if not 0.0 <= center_fraction <= 1.0: + raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") + if center_fraction > partial_fraction: + raise ValueError( + f"center_fraction ({center_fraction}) must not exceed " + f"partial_fraction ({partial_fraction})" + ) + if axis not in (-1, -2, -3): + raise ValueError(f"axis must be -1, -2, or -3, got {axis}") + if side not in PARTIAL_FOURIER_SIDES: + raise ValueError(f"side must be one of {sorted(PARTIAL_FOURIER_SIDES)}, got {side!r}") + + self.partial_fraction = partial_fraction + self.center_fraction = center_fraction + self.axis = axis + self.side = side + self._cached_mask = None + self._cached_shape = None + self._cached_device: torch.device | None = None + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + """Generate a deterministic partial Fourier mask. + + :param tuple[int, ...] shape: k-space tensor shape. 
+ :param torch.device device: Device for the mask. + :returns: Binary mask broadcastable to ``shape``. + :rtype: torch.Tensor + """ + if ( + self._cached_mask is not None + and self._cached_shape == shape + and self._cached_device == device + ): + return self._cached_mask + + axis_size = shape[self.axis] + mask_1d = self._generate_1d_mask(axis_size) + mask = self._expand_mask_to_shape(mask_1d, shape).to(device) + + self._cached_mask = mask + self._cached_shape = shape + self._cached_device = device + return mask + + def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: + """Generate a 1D contiguous asymmetric partial Fourier mask.""" + num_keep = max(1, int(round(axis_size * self.partial_fraction))) + num_center = int(round(axis_size * self.center_fraction)) + num_center = min(num_center, num_keep) + + mask = torch.zeros(axis_size, dtype=torch.float32) + + center_start = (axis_size - num_center) // 2 + center_end = center_start + num_center + remaining = num_keep - num_center + + low_available = center_start + high_available = axis_size - center_end + + if self.side == "high": + extra_high = min(remaining, high_available) + extra_low = remaining - extra_high + else: + extra_low = min(remaining, low_available) + extra_high = remaining - extra_low + + start = center_start - extra_low + end = center_end + extra_high + + mask[start:end] = 1.0 + return mask + + def _expand_mask_to_shape( + self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] 
+ ) -> torch.Tensor: + """Expand a 1D mask to the full k-space shape.""" + ndim = len(target_shape) + axis = self.axis if self.axis >= 0 else ndim + self.axis + + expand_shape = list(target_shape) + for i in range(len(expand_shape)): + if i != axis: + expand_shape[i] = 1 + + return mask_1d.reshape(expand_shape) diff --git a/tests/test_distortions.py b/tests/test_distortions.py index 563f4e8..7dd4690 100644 --- a/tests/test_distortions.py +++ b/tests/test_distortions.py @@ -18,6 +18,7 @@ IsotropicResolutionReduction, KaiserTaperResolutionReduction, OffCenterAnisotropicGaussianKspaceBiasField, + PartialFourierDistortion, PhaseEncodeGhostingDistortion, RadialHighPassEmphasisDistortion, RotationalMotionDistortion, @@ -34,6 +35,7 @@ "Hann taper LP", "Kaiser taper LP", "Cartesian undersampling", + "Partial Fourier", "Radial high-pass emphasis", "Gaussian bias field", "Off-center anisotropic Gaussian bias field", @@ -49,6 +51,7 @@ "Isotropic LP", "Anisotropic LP", "Cartesian undersampling", + "Partial Fourier", "Phase-encode ghosting", "Translation motion", "Segmented translation motion", @@ -91,6 +94,13 @@ def choose_distortion(name): center_fraction=0.2, seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) case "Radial high-pass emphasis": return RadialHighPassEmphasisDistortion(alpha=0.4) case "Gaussian bias field": @@ -229,6 +239,52 @@ def test_anisotropic_resolution_reduction_identity_at_full_cutoffs(device): assert torch.equal(y_distorted, y) +def test_partial_fourier_distortion_identity_at_full_fraction(device): + distortion = PartialFourierDistortion( + partial_fraction=1.0, + center_fraction=0.1, + axis=-2, + side="high", + ) + y = torch.randn((1, 2, 32, 32), device=device) + + assert torch.equal(distortion.A(y), y) + + +def test_partial_fourier_distortion_retains_contiguous_asymmetric_region(device): + distortion = PartialFourierDistortion( + partial_fraction=0.7, + 
center_fraction=0.1, + axis=-2, + side="high", + ) + + retained_indices = distortion._generate_1d_mask(16).nonzero().flatten().tolist() + + assert retained_indices == list(range(5, 16)) + + +def test_partial_fourier_distortion_side_changes_retained_half(device): + high = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) + low = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="low", + ) + + high_indices = high._generate_1d_mask(16).nonzero().flatten().tolist() + low_indices = low._generate_1d_mask(16).nonzero().flatten().tolist() + + assert high_indices == list(range(5, 16)) + assert low_indices == list(range(0, 11)) + + def test_hann_taper_resolution_reduction_has_smooth_transition(device): distortion = HannTaperResolutionReduction( radius_fraction=0.8, From 9fb0527bb4bd35696a7f35a184b2eff699d04ac6 Mon Sep 17 00:00:00 2001 From: Yuning Du Date: Thu, 14 May 2026 19:52:10 +0100 Subject: [PATCH 16/17] Add OASIS support to FastMRI reconstruction example --- README.md | 14 +- examples/OASIS_inference_plot.py | 731 ---------------------------- examples/fastmri_inference_plot.py | 147 ++++-- mri_recon/reconstruction/deep.py | 754 +++++++++++++++-------------- mri_recon/utils/__init__.py | 8 + mri_recon/utils/oasis_adapter.py | 212 ++++++++ 6 files changed, 735 insertions(+), 1131 deletions(-) delete mode 100644 examples/OASIS_inference_plot.py create mode 100644 mri_recon/utils/oasis_adapter.py diff --git a/README.md b/README.md index 1590e9a..aa7b16e 100644 --- a/README.md +++ b/README.md @@ -57,15 +57,21 @@ uv sync uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" ``` -## OASIS Inference Example +## Inference Examples -Run the OASIS plotting example with a local OASIS root folder. By default, it uses the packaged split and checkpoint manifest under `reconstruction_only/`. 
+Run the FastMRI plotting example with local FastMRI k-space files: ```bash -python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data --acceleration 4 +python examples/fastmri_inference_plot.py --source /path/to/fastmri/singlecoil_val --dataset fastmri --algorithm unet ``` -Pass `--checkpoint /path/to/checkpoint.ckpt` to use a different trained OASIS U-Net checkpoint. +Run the same lightweight example on OASIS data. The packaged OASIS split CSV and U-Net checkpoint are downloaded automatically when missing: + +```bash +python examples/fastmri_inference_plot.py --source /path/to/oasis_cross_sectional_data --dataset oasis --algorithm unet +``` + +For OASIS, `--oasis_checkpoint_acceleration` only selects the packaged U-Net weights by their training acceleration. Distortion undersampling is still controlled by `--keep_fraction` and `--center_fraction`. ## Pre-commit diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py deleted file mode 100644 index 512dc33..0000000 --- a/examples/OASIS_inference_plot.py +++ /dev/null @@ -1,731 +0,0 @@ -"""Inference OASIS reconstructors for k-space distortion operators. - -Usage: - python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data -""" - -from __future__ import annotations - -import argparse -import os -import sys -from collections import OrderedDict -from pathlib import Path -from typing import Optional - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import matplotlib.pyplot as plt -import numpy as np -import torch -import deepinv as dinv -from torch.utils.data import DataLoader, Dataset - -try: - import nibabel as nib -except ImportError as exc: - raise ImportError( - "The OASIS example requires nibabel. Install the project dependencies " - "or add nibabel to your environment before running this script." 
- ) from exc - -from mri_recon.distortions import ( - AnisotropicResolutionReduction, - BaseDistortion, - CartesianUndersampling, - DistortedKspaceMultiCoilMRI, - GaussianKspaceBiasField, - GaussianNoiseDistortion, - HannTaperResolutionReduction, - IsotropicResolutionReduction, - KaiserTaperResolutionReduction, - OffCenterAnisotropicGaussianKspaceBiasField, - PartialFourierDistortion, - PhaseEncodeGhostingDistortion, - RadialHighPassEmphasisDistortion, - RotationalMotionDistortion, - SegmentedTranslationMotionDistortion, - TranslationMotionDistortion, -) -from mri_recon.reconstruction import ( - ConjugateGradientReconstructor, - DeepImagePriorReconstructor, - OASISSinglecoilUnetReconstructor, - RAMReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - ZeroFilledReconstructor, -) - -REPO_ROOT = Path(__file__).resolve().parents[1] -REPORT_DIR = Path("reports") / "oasis_inference_plot" -DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" -DEFAULT_MANIFEST_PATH = ( - REPO_ROOT / "downloads" / "oasis_singlecoil_unet" / "checkpoints" / "manifest.json" -) - -REPORT_DIR.mkdir(parents=True, exist_ok=True) - -ALGORITHMS = [ - # "zero-filled", - # "conjugate-gradient", - # "ram", - # "dip", - "tv-pgd", - # "wavelet-fista", - # "tv-fista", - # "tv-pdhg", - "oasis-unet", - "unet", -] -DISTORTIONS = [ - "None", - "Cartesian undersampling (variable density)", - "Cartesian undersampling (uniform random)", - "Cartesian undersampling (uniform random, zero ACS)", - "Cartesian undersampling (equispaced)", - "Cartesian undersampling (equispaced, zero ACS)", - "Partial Fourier", - "Phase-encode ghosting", - "Segmented translation motion", - "Translation motion", - "Rotational motion", - "Off-center anisotropic Gaussian bias field", - "Gaussian bias field", - "Anisotropic LP", - "Hann taper LP", - "Kaiser taper LP", - "Radial high-pass emphasis", - "Gaussian noise", - "Isotropic LP", -] 
-METRICS = [ - "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", -] - - -class OasisSliceDataset(Dataset): - """Load 2D OASIS slices from Analyze/NIfTI volumes. - - Parameters - ---------- - split_csv : Path - CSV file listing OASIS subjects and slice counts. - data_path : Path - Root directory containing OASIS subject folders. - sample_rate : float, optional - Fraction of slices to include from each volume. - cache_size : int, optional - Number of loaded volumes to keep in memory. - """ - - def __init__( - self, - split_csv: Path, - data_path: Path, - sample_rate: float = 1.0, - cache_size: int = 2, - ) -> None: - self.split_csv = Path(split_csv) - self.data_path = Path(data_path) - if not 0 < sample_rate <= 1.0: - raise ValueError("sample_rate must be in the range (0, 1].") - self.sample_rate = sample_rate - self.cache_size = max(0, cache_size) - self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() - self.raw_samples = self._create_sample_list() - - def __len__(self) -> int: - """Return the number of available slices.""" - - return len(self.raw_samples) - - def __getitem__(self, index: int) -> dict[str, object]: - """Return one complex-valued OASIS slice in repo tensor convention.""" - - subject_id, slice_num = self.raw_samples[index] - target_np = self._read_raw_slice(subject_id, slice_num) - real = torch.from_numpy(target_np) - x = torch.stack([real, torch.zeros_like(real)], dim=0) - return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} - - def _create_sample_list(self) -> list[tuple[str, int]]: - samples: list[tuple[str, int]] = [] - with self.split_csv.open("r", encoding="utf-8") as handle: - for line in handle: - row = [item.strip() for item in line.split(",")] - if not row or not row[0]: - continue - try: - total_slices = int(row[-1]) - except ValueError: - continue - - subject_id = row[0] - if self.sample_rate >= 1.0: - start = 0 - stop = total_slices - else: - mid = round(total_slices / 2) - 
half_span = round(total_slices * self.sample_rate / 2) - start = max(0, mid - half_span) - stop = min(total_slices, mid + half_span) - - for slice_num in range(start, stop): - samples.append((subject_id, slice_num)) - return samples - - def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: - volume = self._get_volume(subject_id) - return np.ascontiguousarray(volume[slice_num], dtype=np.float32) - - def _get_volume(self, subject_id: str) -> np.ndarray: - if self.cache_size > 0 and subject_id in self._volume_cache: - self._volume_cache.move_to_end(subject_id) - return self._volume_cache[subject_id] - - image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" - matches = sorted(image_glob.glob("*t88_gfc.img")) - if not matches: - raise FileNotFoundError( - f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." - ) - - image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) - volume = np.ascontiguousarray( - np.transpose(np.squeeze(image_data), (1, 0, 2)), - dtype=np.float32, - ) - - if self.cache_size > 0: - self._volume_cache[subject_id] = volume - if len(self._volume_cache) > self.cache_size: - self._volume_cache.popitem(last=False) - - return volume - - -def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to log-magnitude image for visualization.""" - - kspace = kspace.detach().cpu() - if kspace.ndim == 5: - kspace = kspace[0] - if kspace.ndim == 4: - if kspace.shape[0] == 1 and kspace.shape[1] == 2: - kspace = kspace[0] - elif kspace.shape[0] != 2: - raise ValueError( - "Expected k-space with shape (2, H, W), (1, 2, H, W), " - f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" - ) - if kspace.ndim != 3 and kspace.ndim != 4: - raise ValueError( - "Expected k-space with shape (2, H, W), (1, 2, H, W), " - f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" - ) - if kspace.shape[0] != 2: - raise ValueError(f"Expected real/imaginary channel first, got 
{tuple(kspace.shape)}") - - kspace_complex = torch.view_as_complex(torch.movedim(kspace, 0, -1).contiguous()) - magnitude = torch.abs(kspace_complex) - if magnitude.ndim == 3: - magnitude = torch.sqrt(torch.sum(magnitude.square(), dim=0)) - magnitude = torch.log1p(magnitude) - - lower = torch.quantile(magnitude, 0.05) - upper = torch.quantile(magnitude, 0.995) - if float(upper) > float(lower): - magnitude = magnitude.clamp(lower, upper) - magnitude = (magnitude - lower) / (upper - lower) - else: - mag_max = float(magnitude.max()) - if mag_max > 0.0: - magnitude = magnitude / mag_max - - return torch.sqrt(magnitude) - - -def save_kspace_plot( - y_clean: torch.Tensor, - y_distorted: torch.Tensor, - save_fn: Path, - distortion_name: str, -) -> None: - """Save clean and distorted k-space magnitude plots.""" - - images = [ - ("Original k-space", _kspace_to_log_magnitude(y_clean)), - ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), - ] - - fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) - fig.suptitle(f"Distortion: {distortion_name}") - for ax, (title, image) in zip(axes, images, strict=True): - ax.imshow(image.numpy(), cmap="magma") - ax.set_title(title) - ax.axis("off") - fig.savefig(save_fn, dpi=200, bbox_inches="tight") - plt.close(fig) - - -def image_to_kspace(x: torch.Tensor) -> torch.Tensor: - """Convert channel-first complex images to centered k-space.""" - - x_complex = torch.view_as_complex(x.movedim(1, -1).contiguous()) - y_complex = torch.fft.fftshift( - torch.fft.fft2(x_complex, dim=(-2, -1), norm="ortho"), - dim=(-2, -1), - ) - return torch.view_as_real(y_complex).movedim(-1, 1).contiguous() - - -def kspace_to_image(y: torch.Tensor) -> torch.Tensor: - """Convert centered channel-first k-space to complex images.""" - - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - x_complex = torch.fft.ifft2( - torch.fft.ifftshift(y_complex, dim=(-2, -1)), - dim=(-2, -1), - norm="ortho", - ) - return 
torch.view_as_real(x_complex).movedim(-1, 1).contiguous() - - -class OasisCenteredFFTPhysics: - """Physics adapter matching the OASIS U-Net FFT convention. - - Parameters - ---------- - distortion : BaseDistortion - K-space distortion applied after the centered FFT. - """ - - def __init__(self, distortion: BaseDistortion) -> None: - self.distortion = distortion - - def A(self, x: torch.Tensor) -> torch.Tensor: - """Apply centered FFT and k-space distortion. - - Parameters - ---------- - x : torch.Tensor - Complex image tensor with shape ``(B, 2, H, W)``. - - Returns - ------- - torch.Tensor - Distorted centered k-space tensor. - """ - - return self.distortion.A(image_to_kspace(x)) - - def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: - """Apply adjoint distortion and centered inverse FFT. - - Parameters - ---------- - y : torch.Tensor - Distorted centered k-space tensor. - - Returns - ------- - torch.Tensor - Complex image tensor with shape ``(B, 2, H, W)``. - """ - - return kspace_to_image(self.distortion.A_adjoint(y)) - - -def resolve_oasis_checkpoint( - checkpoint: Optional[Path], - acceleration: int, - manifest_path: Path, -) -> Path: - """Resolve an explicit or packaged OASIS checkpoint path. - - Parameters - ---------- - checkpoint : Path or None - User-provided checkpoint path. - acceleration : int - Acceleration key used when loading from the manifest. - manifest_path : Path - JSON manifest with packaged checkpoint metadata. - - Returns - ------- - Path - Resolved checkpoint path. 
- """ - - if checkpoint is not None: - return checkpoint.expanduser().resolve() - - return OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( - acceleration=acceleration, - manifest_path=manifest_path, - ) - - -def choose_algorithm( - name: str, - checkpoint_file: Path, - img_size: tuple = (640, 368), - device: torch.device = "cpu", - verbose: bool = False, -) -> dinv.models.Reconstructor: - """Construct a reconstructor by selector name.""" - - match name: - case "zero-filled": - return ZeroFilledReconstructor() - case "conjugate-gradient": - return ConjugateGradientReconstructor(max_iter=20) - case "ram": - return RAMReconstructor(default_sigma=0.05, device=device) - case "dip": - return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) - case "tv-pgd": - return TVPGDReconstructor(n_iter=100, verbose=verbose) - case "tv-fista": - return TVFISTAReconstructor(n_iter=200, verbose=verbose) - case "tv-pdhg": - return TVPDHGReconstructor(n_iter=100, verbose=verbose) - case "wavelet-fista": - return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) - case "oasis-unet" | "unet": - return OASISSinglecoilUnetReconstructor( - checkpoint_file=str(checkpoint_file), - device=device, - ) - case _: - raise ValueError(f"Unknown algorithm {name!r}") - - -def choose_distortion( - name: str, - acceleration: int, - center_fraction: float, -) -> BaseDistortion: - """Construct a k-space distortion by display name.""" - - match name: - case "None": - return BaseDistortion() - case "Cartesian undersampling (variable density)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=center_fraction, - pattern="variable_density_random", - axis=-1, - seed=42, - ) - case "Cartesian undersampling (uniform random)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=center_fraction, - pattern="uniform_random", - axis=-1, - seed=42, - ) - case "Cartesian undersampling 
(uniform random, zero ACS)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=0.0, - pattern="uniform_random", - axis=-1, - seed=42, - ) - case "Cartesian undersampling (equispaced)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=center_fraction, - pattern="equispaced", - axis=-1, - seed=42, - ) - case "Cartesian undersampling (equispaced, zero ACS)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=0.0, - pattern="equispaced", - axis=-1, - seed=42, - ) - case "Partial Fourier": - return PartialFourierDistortion( - partial_fraction=0.7, - center_fraction=0.1, - axis=-1, - side="high", - ) - case "Phase-encode ghosting": - return PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=1.0, - ) - case "Anisotropic LP": - return AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=0.25, - ) - case "Hann taper LP": - return HannTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - ) - case "Kaiser taper LP": - return KaiserTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - beta=8.6, - ) - case "Radial high-pass emphasis": - return RadialHighPassEmphasisDistortion(alpha=0.4) - case "Isotropic LP": - return IsotropicResolutionReduction(radius_fraction=0.1) - case "Off-center anisotropic Gaussian bias field": - return OffCenterAnisotropicGaussianKspaceBiasField( - width_x_fraction=0.2, - width_y_fraction=0.35, - center_x_fraction=0.15, - center_y_fraction=-0.1, - edge_gain=0.3, - ) - case "Translation motion": - return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) - case "Rotational motion": - return RotationalMotionDistortion(angle_radians=torch.pi / 6) - case "Segmented translation motion": - return SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 20.0, 50.0, -50.0), - 
shift_y_pixels=(0.0, 10.0, -20.0, 20.0), - ) - case "Gaussian bias field": - return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) - case "Gaussian noise": - return GaussianNoiseDistortion(sigma=0.00001) - case _: - raise ValueError(f"Unknown distortion {name!r}") - - -def choose_metric(name: str) -> dinv.metric.Metric: - """Construct a DeepInverse metric by selector name.""" - - match name: - case "PSNR": - return dinv.metric.PSNR(max_pixel=None, complex_abs=True) - case "NMSE": - return dinv.metric.NMSE(complex_abs=True) - case "SSIM": - return dinv.metric.SSIM(max_pixel=None, complex_abs=True) - case "HaarPSI": - return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) - case "BlurStrength": - return dinv.metric.BlurStrength(complex_abs=True) - case "SharpnessIndex": - return dinv.metric.SharpnessIndex(complex_abs=True) - case _: - raise ValueError(f"Unknown metric {name!r}") - - -def build_parser() -> argparse.ArgumentParser: - """Build the command-line parser.""" - - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--source", - type=Path, - required=True, - help="OASIS root directory containing subject folders.", - ) - parser.add_argument( - "--split_csv", - type=Path, - default=DEFAULT_SPLIT_CSV, - help="CSV listing OASIS subjects and slice counts.", - ) - parser.add_argument( - "--manifest", - type=Path, - default=DEFAULT_MANIFEST_PATH, - help="Checkpoint manifest JSON.", - ) - parser.add_argument( - "--checkpoint", - type=Path, - default=None, - help="Explicit OASIS U-Net checkpoint. 
Overrides --acceleration.", - ) - parser.add_argument( - "--acceleration", - type=int, - default=4, - help="Packaged OASIS checkpoint acceleration factor.", - ) - parser.add_argument( - "--center_fraction", - type=float, - default=0.08, - help="Center fraction used by the Cartesian undersampling distortion.", - ) - parser.add_argument( - "--distortion", - type=str, - default="Cartesian undersampling (uniform random)", - choices=DISTORTIONS, - ) - parser.add_argument( - "--algorithm", - type=str, - default="unet", - choices=ALGORITHMS, - help="Reconstruction algorithm applied to distorted OASIS k-space.", - ) - parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") - parser.add_argument( - "--sample_rate", - type=float, - default=0.6, - help="Fraction of slices per volume to include from the split CSV.", - ) - parser.add_argument("--volume_cache_size", type=int, default=2) - parser.add_argument( - "--verbose", - action="store_true", - help="Enable verbose output for reconstructors that support it.", - ) - return parser - - -def main() -> None: - """Run OASIS inference plots.""" - - args = build_parser().parse_args() - args.source = args.source.expanduser().resolve() - args.split_csv = args.split_csv.expanduser().resolve() - args.manifest = args.manifest.expanduser().resolve() - checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) - - device = dinv.utils.get_device() - dataset = OasisSliceDataset( - split_csv=args.split_csv, - data_path=args.source, - sample_rate=args.sample_rate, - cache_size=args.volume_cache_size, - ) - dataloader = DataLoader(dataset, batch_size=1, shuffle=True) - metrics = [choose_metric(m) for m in METRICS] - - for i, batch in enumerate(iter(dataloader)): - if i >= args.num_samples: - break - - x = batch["x"].to(device) - subject_id = batch["subject_id"][0] - slice_num = int(batch["slice_num"][0]) - - for distortion_name in [args.distortion]: - distortion = 
choose_distortion( - distortion_name, - acceleration=args.acceleration, - center_fraction=args.center_fraction, - ) - - physics_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), - img_size=(1, 2, *x.shape[-2:]), - device=device, - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, - img_size=(1, 2, *x.shape[-2:]), - device=device, - ) - oasis_physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) - oasis_physics = OasisCenteredFFTPhysics(distortion) - - y = image_to_kspace(x) - y_distorted = distortion.A(y) - y_physics_distorted = physics(x) - x_distorted = kspace_to_image(y_distorted) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", - distortion_name, - ) - - for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: - print( - f"Evaluating algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}..." - ) - - algo = choose_algorithm( - algo_name, - checkpoint_file=checkpoint_file, - img_size=x.shape[-2:], - device=device, - verbose=args.verbose, - ).to(device) - - if algo_name in {"oasis-unet", "unet"}: - y_eval = y_distorted - eval_physics_clean = oasis_physics_clean - eval_physics = oasis_physics - else: - y_eval = y_physics_distorted - eval_physics_clean = physics_clean - eval_physics = physics - - x_uncorrected = algo(y_eval, eval_physics_clean) - x_corrected = algo(y_eval, eval_physics) - uncorrected_scores = [ - f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" for m in metrics - ] - corrected_scores = [ - f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" for m in metrics - ] - print(f" uncorrected: {', '.join(uncorrected_scores)}") - print(f" corrected: {', '.join(corrected_scores)}") - - dinv.utils.plot( - { - "Ground truth OASIS slice": x, - "Distorted ksp, zero-filled": x_distorted, - f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, - f"Distorted ksp, {algo_name} recon, 
corrected": x_corrected, - }, - subtitles=[ - "", - "", - "\n".join(uncorrected_scores), - "\n".join(corrected_scores), - ], - show=False, - close=True, - suptitle=( - f"Algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}" - ), - save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", - fontsize=3, - ) - - print("done!") - - -if __name__ == "__main__": - main() diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index fe52c07..5a28a9d 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -17,9 +17,17 @@ from mri_recon.distortions import * from mri_recon.reconstruction import * - -REPORT_DIR = Path("reports") / "fastmri_inference_plot" -REPORT_DIR.mkdir(parents=True, exist_ok=True) +from mri_recon.utils import ( + OasisCenteredFFTPhysics, + OasisSliceDataset, + image_to_kspace, + kspace_to_image, +) + +FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" +OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot" +FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) +OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ # "zero-filled", # "conjugate-gradient", @@ -29,7 +37,7 @@ # "wavelet-fista", # "tv-fista", # "tv-pdhg", - "unet", # will trigger download of pretrained weights if not already present + "unet", # dataset-specific U-Net; downloads weights if not already present ] DISTORTIONS = [ "Cartesian undersampling (variable density)", @@ -114,6 +122,8 @@ def choose_algorithm( img_size: tuple = (640, 368), device: torch.device = "cpu", verbose: bool = False, + dataset: str = "fastmri", + oasis_checkpoint_acceleration: int = 4, ) -> dinv.models.Reconstructor: match name: case "zero-filled": @@ -133,12 +143,22 @@ def choose_algorithm( case "wavelet-fista": return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) case "unet": + if dataset == "oasis": + return 
OASISSinglecoilUnetReconstructor( + acceleration=oasis_checkpoint_acceleration, + device=device, + ) return FastMRISinglecoilUnetReconstructor(device=device) case _: raise ValueError(f"Unknown algorithm {name!r}") -def choose_distortion(name: str) -> BaseDistortion: +def choose_distortion( + name: str, + keep_fraction: float = 0.25, + center_fraction: float = 0.125, + cartesian_axis: int = -2, +) -> BaseDistortion: match name: case "Phase-encode ghosting": return PhaseEncodeGhostingDistortion( @@ -149,44 +169,49 @@ def choose_distortion(name: str) -> BaseDistortion: ) case "Cartesian undersampling (variable density)": return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, + keep_fraction=keep_fraction, + center_fraction=center_fraction, pattern="variable_density_random", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (uniform random)": return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, + keep_fraction=keep_fraction, + center_fraction=center_fraction, pattern="uniform_random", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (uniform random, zero ACS)": return CartesianUndersampling( - keep_fraction=0.5, + keep_fraction=keep_fraction, center_fraction=0.0, pattern="uniform_random", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (equispaced)": return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, + keep_fraction=keep_fraction, + center_fraction=center_fraction, pattern="equispaced", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (equispaced, zero ACS)": return CartesianUndersampling( - keep_fraction=0.5, + keep_fraction=keep_fraction, center_fraction=0.0, pattern="equispaced", + axis=cartesian_axis, seed=42, ) case "Partial Fourier": return PartialFourierDistortion( partial_fraction=0.7, - center_fraction=0.1, - axis=-2, + center_fraction=center_fraction, + axis=cartesian_axis, side="high", ) case "Anisotropic LP": @@ -256,10 
+281,30 @@ def choose_metric(name: str) -> dinv.metric.Metric: if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) + + # data related arguments parser.add_argument( - "--source", type=str, help="Local FastMRI directory with raw k-space .h5 files." + "--source", + type=Path, + help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) + parser.add_argument("--dataset", choices=("fastmri", "oasis"), default="fastmri") + parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) + parser.add_argument( + "--keep_fraction", + type=float, + default=0.25, + help="Fraction of k-space lines to keep for undersampling distortions.", + ) + parser.add_argument( + "--center_fraction", + type=float, + default=0.125, + help="Fraction of low-frequency k-space lines to keep fully for undersampling distortions.", + ) + + # algo related arguments parser.add_argument( "--algorithm", type=str, @@ -267,6 +312,18 @@ def choose_metric(name: str) -> dinv.metric.Metric: choices=ALGORITHMS, help="Reconstruction algorithm applied to undistorted and distorted k-space.", ) + parser.add_argument( + "--oasis_checkpoint_acceleration", + type=int, + default=4, + choices=[4, 8, 10], + help=( + "Training acceleration of the packaged OASIS U-Net checkpoint. " + "This only selects pretrained weights; distortion undersampling is " + "controlled by --keep_fraction and --center_fraction." 
+ ), + ) + # inference related arguments parser.add_argument("--num_samples", type=int, default=1, help="How many samples to process.") parser.add_argument( "--verbose", @@ -275,9 +332,20 @@ def choose_metric(name: str) -> dinv.metric.Metric: ) args = parser.parse_args() + # set up report dir + REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR + # set up device, dataset, metrics device = dinv.utils.get_device() - dataset = dinv.datasets.FastMRISliceDataset(args.source, slice_index="middle") + if args.dataset == "oasis": + split_csv = OASISSinglecoilUnetReconstructor.resolve_default_split_csv() + dataset = OasisSliceDataset( + data_path=args.source, + split_csv=split_csv, + sample_rate=0.6, + ) + else: + dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") metrics = [choose_metric(m) for m in METRICS] for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))): @@ -285,30 +353,45 @@ def choose_metric(name: str) -> dinv.metric.Metric: if i >= args.num_samples: break - # batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan), - # y is kspace, and params is a dict containing mask (if test set) - y = batch[1] + if args.dataset == "oasis": + x = batch["x"].to(device) + y = image_to_kspace(x) + else: + # batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan), + # y is kspace, and params is a dict containing mask (if test set) + y = batch[1].to(device) for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: - distortion = choose_distortion(distortion_name) + distortion = choose_distortion( + distortion_name, + keep_fraction=args.keep_fraction, + center_fraction=args.center_fraction, + cartesian_axis=-1 if args.dataset == "oasis" else -2, + ) # create physics objects for both clean and distorted k-space # the distortion is applied to the k-space measurements (not the image) # TODO: allow loading multicoil data - physics_clean 
= DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device - ) - - y = y.to(device) + if args.dataset == "oasis": + physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) + physics = OasisCenteredFFTPhysics(distortion) + else: + physics_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device + ) y_distorted = distortion.A(y) # generate reference reconstructions (CG) for both clean and distorted k-space # without correction for the distortion, i.e. using physics_clean in both cases - x_clean = ConjugateGradientReconstructor()(y, physics_clean) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + if args.dataset == "oasis": + x_clean = x + x_distorted = kspace_to_image(y_distorted) + else: + x_clean = ConjugateGradientReconstructor()(y, physics_clean) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) # plot and save the k-space magnitude for both clean and distorted k-space save_kspace_plot( @@ -327,6 +410,8 @@ def choose_metric(name: str) -> dinv.metric.Metric: img_size=y.shape[-2:], device=device, verbose=args.verbose, + dataset=args.dataset, + oasis_checkpoint_acceleration=args.oasis_checkpoint_acceleration, ).to(device) # actual reconstruction with the algo being evaluated diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index cb286c3..18ed4bd 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -1,365 +1,389 @@ -import json -from pathlib import Path -from typing import Optional - -import deepinv as dinv -import torch - -from ._fastmri_unet import Unet -from ..utils import ( - download_file_with_sha256, - 
download_google_drive_file_with_sha256, - matches_sha256, -) - - -class RAMReconstructor(dinv.models.Reconstructor): - """ - Wrapper for RAM from DeepInverse. - Normalises input by magnitude of adjoint. - - :param float default_sigma: default sigma for RAM input. Overriden if physics already has a sigma (e.g. in a Gaussian noise model) at inference time. - """ - - def __init__(self, default_sigma=0.05, device: torch.device = None) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - self.model = dinv.models.RAM(device=device) - self.default_sigma = default_sigma - - def forward(self, y, physics): - _x_adj = physics.A_adjoint(y) - scale = torch.quantile(_x_adj, 0.99) - - physics_norm = physics.compute_norm(torch.randn_like(_x_adj)).item() - physics_adjointness = physics.adjointness_test(torch.randn_like(_x_adj)).item() - - if physics_norm > 1.2 or physics_norm < 0.8: - raise ValueError( - f"RAM reconstructor requires physics norm = 1 but got {physics_norm:.4f}" - ) - if physics_adjointness > 0.1 or physics_adjointness < -0.1: - raise ValueError( - f"RAM reconstructor requires physics adjointness = 0 but got {physics_adjointness:.4f}" - ) - - sigma = ( - None - if hasattr(physics, "noise_model") and hasattr(physics.noise_model, "sigma") - else self.default_sigma - ) - - with torch.no_grad(): - return self.model(y / scale, physics, sigma=sigma) * scale - - -class DeepImagePriorReconstructor(dinv.models.Reconstructor): - """ - Wrapper for Deep Image Prior from DeepInverse. - - :param tuple img_size: image size of the output. Defaults to (640, 368) - :param int n_iter: number of iterations to fit the DIP. Defaults to 100. - """ - - def __init__( - self, - img_size: tuple = (640, 368), - n_iter: int = 100, - verbose: bool = True, - ) -> None: - super().__init__() - - lr = 1e-2 # learning rate for the optimizer. - channels = 64 # number of channels per layer in the decoder. 
- in_size = [2, 2] # size of the input to the decoder. - - self.model = dinv.models.DeepImagePrior( - dinv.models.ConvDecoder( - img_size=(2, *img_size[-2:]), in_size=in_size, channels=channels - ), - learning_rate=lr, - iterations=n_iter, - verbose=verbose, - input_size=[channels] + in_size, - ) - - def forward(self, y, physics): - return self.model(y, physics) - - -class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): - """ - Wrapper for pretrained UNet from FastMRI singlecoil knee challenge. - - Note: this model was trained for accelerated MRI reconstruction and may not have good performance on other degradations. - - Note: this model discards complex information and only returns the magnitude image. - - NOTE: this model was trained on both train+val splits of the challenge (i.e. trained on singlecoil_train, singlecoil_val). - - The pretrained fastMRI model expects magnitude images that are normalized per slice, - so this wrapper matches that preprocessing and rescales the output back to the - original adjoint-image intensity range. - - See https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples for more details. 
- """ - - MODEL_URL = ( - "https://dl.fbaipublicfiles.com/fastMRI/trained_models/unet/" - "knee_sc_leaderboard_state_dict.pt" - ) - MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" - MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" - MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "fastmri_singlecoil_unet" - UNET_KWARGS = { - "in_chans": 1, - "out_chans": 1, - "chans": 256, - "num_pool_layers": 4, - "drop_prob": 0.0, - } - - def __init__(self, device: torch.device = None, state_dict_file: str = None) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - self.model = Unet(**self.UNET_KWARGS) - - state_dict_path = ( - Path(state_dict_file).expanduser() - if state_dict_file is not None - else self.MODEL_DIR / self.MODEL_FILENAME - ) - - if state_dict_file is None: - if not matches_sha256(state_dict_path, self.MODEL_SHA256): - download_file_with_sha256( - self.MODEL_URL, - state_dict_path, - self.MODEL_SHA256, - label="FastMRI UNet checkpoint", - ) - elif not state_dict_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {state_dict_path}") - - self.model.load_state_dict( - torch.load(state_dict_path, map_location=device, weights_only=True) - ) - self.model.eval() - self.model.to(device) - - def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: - x_in = physics.A_adjoint(y) - - x_in = dinv.utils.complex_abs(x_in, keepdim=True) - - # Match the fastMRI normalization used for training, then rescale the - # predicted magnitude image back to the original adjoint-image intensity range. 
- mu = x_in.mean(dim=(-2, -1), keepdim=True) - std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-8 - x_in = (x_in - mu) / std - - with torch.no_grad(): - out = self.model(x_in) * std + mu # (B, 1, H, W) - - return torch.cat([out, torch.zeros_like(out)], dim=1) - - -def _load_unet_checkpoint_state( - checkpoint_path: Path, - device: torch.device, -) -> dict[str, torch.Tensor]: - """Load U-Net weights from a plain or Lightning-style checkpoint.""" - - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - - # Accept either a checkpoint with "state_dict" or a plain state_dict - if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - - # If Lightning-style keys like 'unet.*' exist, strip the 'unet.' prefix - if any(key.startswith("unet.") for key in state_dict): - return { - key[len("unet.") :]: value - for key, value in state_dict.items() - if key.startswith("unet.") - } - - return state_dict - - -class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): - """Wrapper for a trained OASIS single-coil U-Net model. - - The model reuses the repository's fastMRI-derived :class:`Unet` module, but - can also download the packaged checkpoint manifest and checkpoint on demand - when no explicit checkpoint path is supplied. The forward pass converts - k-space to a zero-filled magnitude image, applies per-slice instance - normalization, runs the U-Net, then rescales the prediction back to the - adjoint-image intensity range. - - Parameters - ---------- - checkpoint_file : str, optional - Path to the trained OASIS U-Net checkpoint. If omitted, the reconstructor - downloads the packaged checkpoint for ``acceleration``. - acceleration : int, optional - Packaged checkpoint acceleration factor used when ``checkpoint_file`` is omitted. - manifest_path : str, optional - Override path for the downloaded or cached packaged checkpoint manifest. 
- device : torch.device, optional - Device on which to run inference. - """ - - UNET_KWARGS = { - "in_chans": 1, - "out_chans": 1, - "chans": 32, - "num_pool_layers": 4, - "drop_prob": 0.0, - } - MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "oasis_singlecoil_unet" - CHECKPOINTS_DIR = MODEL_DIR / "checkpoints" - MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" - MANIFEST_FILE_ID = "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" - MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" - CHECKPOINT_FILE_IDS = { - "4": "11s6YeM6_YJeD4wcrn24jyMjyj_vX2ANU", - "8": "1w8PDiYpr2xBPXahzRllhZjQT1yoMGXg-", - "10": "1djJ2i0uYP4PT070CS0xx9nNJ41JmSFhh", - } - CHECKPOINT_SHA256 = { - "4": "4fcefa9860cb7895e581a0de8f90bd7f188ae1c0b5e428a4a07519dd2561ac29", - "8": "2cd4c44e3c7a3870adbe5090b2bfaae044f5e3f0b4bcaf2b1fc29969e5e6b9ca", - "10": "90e3d9b17aa0f9fd43aaf090c152edcbdead1b9be41a076594e41098db7befa8", - } - - @classmethod - def ensure_manifest(cls, manifest_path: Optional[Path] = None) -> Path: - """Ensure the packaged OASIS checkpoint manifest exists locally and is verified.""" - - resolved_manifest_path = ( - manifest_path.expanduser().resolve() if manifest_path is not None else cls.MANIFEST_PATH - ) - if not matches_sha256(resolved_manifest_path, cls.MANIFEST_SHA256): - download_google_drive_file_with_sha256( - cls.MANIFEST_FILE_ID, - resolved_manifest_path, - cls.MANIFEST_SHA256, - label="OASIS checkpoint manifest", - ) - return resolved_manifest_path - - @classmethod - def resolve_default_checkpoint( - cls, - acceleration: int, - manifest_path: Optional[Path] = None, - ) -> Path: - """Resolve and download the packaged OASIS checkpoint for a given acceleration.""" - - resolved_manifest_path = cls.ensure_manifest(manifest_path) - with resolved_manifest_path.open("r", encoding="utf-8") as handle: - manifest = json.load(handle) - - key = str(acceleration) - checkpoints = manifest.get("checkpoints", {}) - if key not in checkpoints: - 
available = ", ".join(sorted(checkpoints)) - raise ValueError( - f"No packaged checkpoint for acceleration {acceleration}. Available: {available}." - ) - - if key not in cls.CHECKPOINT_FILE_IDS or key not in cls.CHECKPOINT_SHA256: - raise ValueError( - f"No automated download metadata is configured for acceleration {acceleration}." - ) - - filename = Path(checkpoints[key]["filename"]) - checkpoint_path = ( - filename - if filename.is_absolute() - else (resolved_manifest_path.parent.parent / filename) - ).resolve() - - if not matches_sha256(checkpoint_path, cls.CHECKPOINT_SHA256[key]): - download_google_drive_file_with_sha256( - cls.CHECKPOINT_FILE_IDS[key], - checkpoint_path, - cls.CHECKPOINT_SHA256[key], - label=f"OASIS checkpoint x{acceleration}", - ) - - return checkpoint_path - - def __init__( - self, - checkpoint_file: str | None = None, - acceleration: int = 4, - manifest_path: str | None = None, - device: torch.device = None, - ) -> None: - super().__init__() - - if device is None: - device = torch.device("cpu") - self.device = device - - if checkpoint_file is None: - resolved_manifest_path = ( - Path(manifest_path).expanduser() if manifest_path is not None else None - ) - checkpoint_path = self.resolve_default_checkpoint( - acceleration=acceleration, - manifest_path=resolved_manifest_path, - ) - else: - checkpoint_path = Path(checkpoint_file).expanduser() - if not checkpoint_path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - - # Use the helper to obtain a normalized state_dict (handles plain or Lightning) - state_dict = _load_unet_checkpoint_state(checkpoint_path, device) - - self.model = Unet(**self.UNET_KWARGS) - self.model.load_state_dict( - state_dict, - strict=True, - ) - self.model.eval() - self.model.to(device) - - def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: - """Reconstruct a magnitude image from measured k-space. 
-
-        Parameters
-        ----------
-        y : torch.Tensor
-            Measured k-space tensor.
-        physics : dinv.physics.Physics
-            Physics operator used to form the adjoint input image.
-
-        Returns
-        -------
-        torch.Tensor
-            Complex-valued reconstruction with zero imaginary channel.
-        """
-
-        x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True)
-        mu = x_in.mean(dim=(-2, -1), keepdim=True)
-        std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11
-        x_in = (x_in - mu) / std
-
-        with torch.no_grad():
-            out = self.model(x_in) * std + mu
-
-        return torch.cat([out, torch.zeros_like(out)], dim=1)
+import json
+from pathlib import Path
+from typing import Optional
+
+import deepinv as dinv
+import torch
+
+from ._fastmri_unet import Unet
+from ..utils import (
+    download_file_with_sha256,
+    download_google_drive_file_with_sha256,
+    matches_sha256,
+)
+
+
+class RAMReconstructor(dinv.models.Reconstructor):
+    """
+    Wrapper for RAM from DeepInverse.
+    Normalises input by magnitude of adjoint.
+
+    :param float default_sigma: default sigma for RAM input. Overridden if physics already has a sigma (e.g. in a Gaussian noise model) at inference time.
+ """ + + def __init__(self, default_sigma=0.05, device: torch.device = None) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + self.model = dinv.models.RAM(device=device) + self.default_sigma = default_sigma + + def forward(self, y, physics): + _x_adj = physics.A_adjoint(y) + scale = torch.quantile(_x_adj, 0.99) + + physics_norm = physics.compute_norm(torch.randn_like(_x_adj)).item() + physics_adjointness = physics.adjointness_test(torch.randn_like(_x_adj)).item() + + if physics_norm > 1.2 or physics_norm < 0.8: + raise ValueError( + f"RAM reconstructor requires physics norm = 1 but got {physics_norm:.4f}" + ) + if physics_adjointness > 0.1 or physics_adjointness < -0.1: + raise ValueError( + f"RAM reconstructor requires physics adjointness = 0 but got {physics_adjointness:.4f}" + ) + + sigma = ( + None + if hasattr(physics, "noise_model") and hasattr(physics.noise_model, "sigma") + else self.default_sigma + ) + + with torch.no_grad(): + return self.model(y / scale, physics, sigma=sigma) * scale + + +class DeepImagePriorReconstructor(dinv.models.Reconstructor): + """ + Wrapper for Deep Image Prior from DeepInverse. + + :param tuple img_size: image size of the output. Defaults to (640, 368) + :param int n_iter: number of iterations to fit the DIP. Defaults to 100. + """ + + def __init__( + self, + img_size: tuple = (640, 368), + n_iter: int = 100, + verbose: bool = True, + ) -> None: + super().__init__() + + lr = 1e-2 # learning rate for the optimizer. + channels = 64 # number of channels per layer in the decoder. + in_size = [2, 2] # size of the input to the decoder. 
+ + self.model = dinv.models.DeepImagePrior( + dinv.models.ConvDecoder( + img_size=(2, *img_size[-2:]), in_size=in_size, channels=channels + ), + learning_rate=lr, + iterations=n_iter, + verbose=verbose, + input_size=[channels] + in_size, + ) + + def forward(self, y, physics): + return self.model(y, physics) + + +class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): + """ + Wrapper for pretrained UNet from FastMRI singlecoil knee challenge. + + Note: this model was trained for accelerated MRI reconstruction and may not have good performance on other degradations. + + Note: this model discards complex information and only returns the magnitude image. + + NOTE: this model was trained on both train+val splits of the challenge (i.e. trained on singlecoil_train, singlecoil_val). + + The pretrained fastMRI model expects magnitude images that are normalized per slice, + so this wrapper matches that preprocessing and rescales the output back to the + original adjoint-image intensity range. + + See https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples for more details. 
+ """ + + MODEL_URL = ( + "https://dl.fbaipublicfiles.com/fastMRI/trained_models/unet/" + "knee_sc_leaderboard_state_dict.pt" + ) + MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" + MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "fastmri_singlecoil_unet" + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 256, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + + def __init__(self, device: torch.device = None, state_dict_file: str = None) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + self.model = Unet(**self.UNET_KWARGS) + + state_dict_path = ( + Path(state_dict_file).expanduser() + if state_dict_file is not None + else self.MODEL_DIR / self.MODEL_FILENAME + ) + + if state_dict_file is None: + if not matches_sha256(state_dict_path, self.MODEL_SHA256): + download_file_with_sha256( + self.MODEL_URL, + state_dict_path, + self.MODEL_SHA256, + label="FastMRI UNet checkpoint", + ) + elif not state_dict_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {state_dict_path}") + + self.model.load_state_dict( + torch.load(state_dict_path, map_location=device, weights_only=True) + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + x_in = physics.A_adjoint(y) + + x_in = dinv.utils.complex_abs(x_in, keepdim=True) + + # Match the fastMRI normalization used for training, then rescale the + # predicted magnitude image back to the original adjoint-image intensity range. 
+ mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-8 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu # (B, 1, H, W) + + return torch.cat([out, torch.zeros_like(out)], dim=1) + + +def _load_unet_checkpoint_state( + checkpoint_path: Path, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Load U-Net weights from a plain or Lightning-style checkpoint.""" + + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + # Accept either a checkpoint with "state_dict" or a plain state_dict + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + + # If Lightning-style keys like 'unet.*' exist, strip the 'unet.' prefix + if any(key.startswith("unet.") for key in state_dict): + return { + key[len("unet.") :]: value + for key, value in state_dict.items() + if key.startswith("unet.") + } + + return state_dict + + +class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): + """Wrapper for a trained OASIS single-coil U-Net model. + + The model reuses the repository's fastMRI-derived :class:`Unet` module, but + can also download the packaged checkpoint manifest and checkpoint on demand + when no explicit checkpoint path is supplied. The forward pass converts + k-space to a zero-filled magnitude image, applies per-slice instance + normalization, runs the U-Net, then rescales the prediction back to the + adjoint-image intensity range. + + Parameters + ---------- + checkpoint_file : str, optional + Path to the trained OASIS U-Net checkpoint. If omitted, the reconstructor + downloads the packaged checkpoint for ``acceleration``. + acceleration : int, optional + Training acceleration of the packaged checkpoint used when ``checkpoint_file`` + is omitted. This selects pretrained weights only; it does not configure the + measurement distortion used during inference. 
+ manifest_path : str, optional + Override path for the downloaded or cached packaged checkpoint manifest. + device : torch.device, optional + Device on which to run inference. + """ + + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 32, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "oasis_singlecoil_unet" + CHECKPOINTS_DIR = MODEL_DIR / "checkpoints" + SPLITS_DIR = MODEL_DIR / "splits" + MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" + SPLIT_CSV_PATH = SPLITS_DIR / "oasis_balanced_test.csv" + MANIFEST_FILE_ID = "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" + MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" + SPLIT_CSV_FILE_ID = "16UoZ6sYzOwADv4KLRR9m0kjueXoxENrD" + SPLIT_CSV_SHA256 = "8627cf9781e6d5f94c2a0f08a7a75b386fb9b737cdce94fb1558f95fa2e62dd5" + CHECKPOINT_FILE_IDS = { + "4": "11s6YeM6_YJeD4wcrn24jyMjyj_vX2ANU", + "8": "1w8PDiYpr2xBPXahzRllhZjQT1yoMGXg-", + "10": "1djJ2i0uYP4PT070CS0xx9nNJ41JmSFhh", + } + CHECKPOINT_SHA256 = { + "4": "4fcefa9860cb7895e581a0de8f90bd7f188ae1c0b5e428a4a07519dd2561ac29", + "8": "2cd4c44e3c7a3870adbe5090b2bfaae044f5e3f0b4bcaf2b1fc29969e5e6b9ca", + "10": "90e3d9b17aa0f9fd43aaf090c152edcbdead1b9be41a076594e41098db7befa8", + } + + @classmethod + def ensure_manifest(cls, manifest_path: Optional[Path] = None) -> Path: + """Ensure the packaged OASIS checkpoint manifest exists locally and is verified.""" + + resolved_manifest_path = ( + manifest_path.expanduser().resolve() if manifest_path is not None else cls.MANIFEST_PATH + ) + if not matches_sha256(resolved_manifest_path, cls.MANIFEST_SHA256): + download_google_drive_file_with_sha256( + cls.MANIFEST_FILE_ID, + resolved_manifest_path, + cls.MANIFEST_SHA256, + label="OASIS checkpoint manifest", + ) + return resolved_manifest_path + + @classmethod + def resolve_default_split_csv(cls) -> Path: + """Resolve and download the packaged OASIS split CSV.""" + + 
resolved_split_csv_path = cls.SPLIT_CSV_PATH.resolve() + if matches_sha256(resolved_split_csv_path, cls.SPLIT_CSV_SHA256): + return resolved_split_csv_path + + download_google_drive_file_with_sha256( + cls.SPLIT_CSV_FILE_ID, + resolved_split_csv_path, + cls.SPLIT_CSV_SHA256, + label="OASIS split CSV", + ) + return resolved_split_csv_path + + @classmethod + def resolve_default_checkpoint( + cls, + acceleration: int, + manifest_path: Optional[Path] = None, + ) -> Path: + """Resolve and download the OASIS checkpoint for a training acceleration.""" + + resolved_manifest_path = cls.ensure_manifest(manifest_path) + with resolved_manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + key = str(acceleration) + checkpoints = manifest.get("checkpoints", {}) + if key not in checkpoints: + available = ", ".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for training acceleration {acceleration}. " + f"Available: {available}." + ) + + if key not in cls.CHECKPOINT_FILE_IDS or key not in cls.CHECKPOINT_SHA256: + raise ValueError( + "No automated download metadata is configured for OASIS checkpoint " + f"training acceleration {acceleration}." 
+ ) + + filename = Path(checkpoints[key]["filename"]) + checkpoint_path = ( + filename + if filename.is_absolute() + else (resolved_manifest_path.parent.parent / filename) + ).resolve() + + if not matches_sha256(checkpoint_path, cls.CHECKPOINT_SHA256[key]): + download_google_drive_file_with_sha256( + cls.CHECKPOINT_FILE_IDS[key], + checkpoint_path, + cls.CHECKPOINT_SHA256[key], + label=f"OASIS checkpoint x{acceleration}", + ) + + return checkpoint_path + + def __init__( + self, + checkpoint_file: str | None = None, + acceleration: int = 4, + manifest_path: str | None = None, + device: torch.device = None, + ) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + if checkpoint_file is None: + resolved_manifest_path = ( + Path(manifest_path).expanduser() if manifest_path is not None else None + ) + checkpoint_path = self.resolve_default_checkpoint( + acceleration=acceleration, + manifest_path=resolved_manifest_path, + ) + else: + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Use the helper to obtain a normalized state_dict (handles plain or Lightning) + state_dict = _load_unet_checkpoint_state(checkpoint_path, device) + + self.model = Unet(**self.UNET_KWARGS) + self.model.load_state_dict( + state_dict, + strict=True, + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + """Reconstruct a magnitude image from measured k-space. + + Parameters + ---------- + y : torch.Tensor + Measured k-space tensor. + physics : dinv.physics.Physics + Physics operator used to form the adjoint input image. + + Returns + ------- + torch.Tensor + Complex-valued reconstruction with zero imaginary channel. 
+ """ + + x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) + mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu + + return torch.cat([out, torch.zeros_like(out)], dim=1) diff --git a/mri_recon/utils/__init__.py b/mri_recon/utils/__init__.py index 3d17127..1e27df9 100644 --- a/mri_recon/utils/__init__.py +++ b/mri_recon/utils/__init__.py @@ -2,10 +2,18 @@ from .io import download_google_drive_file_with_sha256 as download_google_drive_file_with_sha256 from .io import format_megabytes as format_megabytes from .io import matches_sha256 as matches_sha256 +from .oasis_adapter import OasisCenteredFFTPhysics as OasisCenteredFFTPhysics +from .oasis_adapter import OasisSliceDataset as OasisSliceDataset +from .oasis_adapter import image_to_kspace as image_to_kspace +from .oasis_adapter import kspace_to_image as kspace_to_image __all__ = [ "download_file_with_sha256", "download_google_drive_file_with_sha256", "format_megabytes", + "image_to_kspace", + "kspace_to_image", "matches_sha256", + "OasisCenteredFFTPhysics", + "OasisSliceDataset", ] diff --git a/mri_recon/utils/oasis_adapter.py b/mri_recon/utils/oasis_adapter.py new file mode 100644 index 0000000..4b2750f --- /dev/null +++ b/mri_recon/utils/oasis_adapter.py @@ -0,0 +1,212 @@ +"""OASIS dataset and centered FFT adapters.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import Dataset + +from mri_recon.distortions import BaseDistortion + + +class OasisSliceDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes. + + Parameters + ---------- + data_path : Path + Root directory containing OASIS subject folders. + split_csv : Path + CSV file listing OASIS subjects and slice counts. + sample_rate : float, optional + Fraction of slices to keep from each volume. 
Values below ``1`` keep a + centered slice range. Defaults to ``1``. + """ + + def __init__( + self, + data_path: Path, + split_csv: Path, + sample_rate: float = 1.0, + ) -> None: + try: + import nibabel as nib + except ImportError as exc: + raise ImportError( + "OASIS loading requires nibabel. Install the project dependencies " + "or add nibabel to your environment before using OasisSliceDataset." + ) from exc + + self._nib = nib + self.data_path = Path(data_path) + self.split_csv = Path(split_csv) + if not 0 < sample_rate <= 1.0: + raise ValueError("sample_rate must be in the range (0, 1].") + self.sample_rate = sample_rate + self.subject_paths = self._discover_subject_paths() + self.raw_samples = self._create_sample_list() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.raw_samples) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + subject_id, slice_num = self.raw_samples[index] + volume = self._get_volume(subject_id) + target_np = np.ascontiguousarray(volume[slice_num], dtype=np.float32) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _discover_subject_paths(self) -> dict[str, Path]: + subject_paths = {} + for subject_dir in sorted(self.data_path.iterdir()): + if not subject_dir.is_dir(): + continue + image_glob = subject_dir / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if matches: + subject_paths[subject_dir.name] = matches[0] + + if not subject_paths: + raise FileNotFoundError( + "Could not find OASIS subject folders under " + f"{self.data_path} matching PROCESSED/MPRAGE/T88_111/*t88_gfc.img." 
+ ) + return subject_paths + + def _create_sample_list(self) -> list[tuple[str, int]]: + samples = [] + rows = [] + with self.split_csv.open("r", encoding="utf-8") as handle: + for line in handle: + row = [item.strip() for item in line.split(",")] + if not row or not row[0]: + continue + try: + rows.append((row[0], int(row[-1]))) + except ValueError: + continue + + for subject_id, num_slices in rows: + if subject_id not in self.subject_paths: + raise FileNotFoundError( + f"Could not find OASIS subject {subject_id!r} from split CSV under " + f"{self.data_path}." + ) + mid = round(num_slices / 2) + half_span = round(num_slices * self.sample_rate / 2) + start = 0 if self.sample_rate >= 1.0 else max(0, mid - half_span) + stop = num_slices if self.sample_rate >= 1.0 else min(num_slices, mid + half_span) + for slice_num in range(start, stop): + samples.append((subject_id, slice_num)) + return samples + + def _num_slices(self, image_path: Path) -> int: + shape = tuple(dim for dim in self._nib.load(str(image_path)).shape if dim != 1) + if len(shape) < 2: + raise ValueError(f"Expected at least 2D OASIS image, got shape {shape}.") + return shape[1] + + def _get_volume(self, subject_id: str) -> np.ndarray: + image_data = self._nib.load(str(self.subject_paths[subject_id])).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + return volume + + +def image_to_kspace(x: torch.Tensor) -> torch.Tensor: + """Convert channel-first complex images to centered k-space. + + Parameters + ---------- + x : torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Centered k-space tensor with shape ``(B, 2, H, W)``. 
+ """ + + x_complex = torch.view_as_complex(x.movedim(1, -1).contiguous()) + y_complex = torch.fft.fftshift( + torch.fft.fft2(x_complex, dim=(-2, -1), norm="ortho"), + dim=(-2, -1), + ) + return torch.view_as_real(y_complex).movedim(-1, 1).contiguous() + + +def kspace_to_image(y: torch.Tensor) -> torch.Tensor: + """Convert centered channel-first k-space to complex images. + + Parameters + ---------- + y : torch.Tensor + Centered k-space tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + """ + + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + x_complex = torch.fft.ifft2( + torch.fft.ifftshift(y_complex, dim=(-2, -1)), + dim=(-2, -1), + norm="ortho", + ) + return torch.view_as_real(x_complex).movedim(-1, 1).contiguous() + + +class OasisCenteredFFTPhysics: + """Physics adapter matching the OASIS U-Net FFT convention. + + Parameters + ---------- + distortion : BaseDistortion + K-space distortion applied after the centered FFT. + """ + + def __init__(self, distortion: BaseDistortion) -> None: + self.distortion = distortion + + def A(self, x: torch.Tensor) -> torch.Tensor: + """Apply centered FFT and k-space distortion. + + Parameters + ---------- + x : torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Distorted centered k-space tensor. + """ + + return self.distortion.A(image_to_kspace(x)) + + def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: + """Apply adjoint distortion and centered inverse FFT. + + Parameters + ---------- + y : torch.Tensor + Distorted centered k-space tensor. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. 
+ """ + + return kspace_to_image(self.distortion.A_adjoint(y)) From 945e60646d20e41980e369058af3cc3dacac87d0 Mon Sep 17 00:00:00 2001 From: Yuning Du Date: Thu, 14 May 2026 20:13:16 +0100 Subject: [PATCH 17/17] remove the OASIS_inference_plot.py --- examples/OASIS_inference_plot.py | 731 ------------------------------- 1 file changed, 731 deletions(-) delete mode 100644 examples/OASIS_inference_plot.py diff --git a/examples/OASIS_inference_plot.py b/examples/OASIS_inference_plot.py deleted file mode 100644 index 512dc33..0000000 --- a/examples/OASIS_inference_plot.py +++ /dev/null @@ -1,731 +0,0 @@ -"""Inference OASIS reconstructors for k-space distortion operators. - -Usage: - python examples/OASIS_inference_plot.py --source /path/to/oasis_cross_sectional_data -""" - -from __future__ import annotations - -import argparse -import os -import sys -from collections import OrderedDict -from pathlib import Path -from typing import Optional - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import matplotlib.pyplot as plt -import numpy as np -import torch -import deepinv as dinv -from torch.utils.data import DataLoader, Dataset - -try: - import nibabel as nib -except ImportError as exc: - raise ImportError( - "The OASIS example requires nibabel. Install the project dependencies " - "or add nibabel to your environment before running this script." 
- ) from exc - -from mri_recon.distortions import ( - AnisotropicResolutionReduction, - BaseDistortion, - CartesianUndersampling, - DistortedKspaceMultiCoilMRI, - GaussianKspaceBiasField, - GaussianNoiseDistortion, - HannTaperResolutionReduction, - IsotropicResolutionReduction, - KaiserTaperResolutionReduction, - OffCenterAnisotropicGaussianKspaceBiasField, - PartialFourierDistortion, - PhaseEncodeGhostingDistortion, - RadialHighPassEmphasisDistortion, - RotationalMotionDistortion, - SegmentedTranslationMotionDistortion, - TranslationMotionDistortion, -) -from mri_recon.reconstruction import ( - ConjugateGradientReconstructor, - DeepImagePriorReconstructor, - OASISSinglecoilUnetReconstructor, - RAMReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - ZeroFilledReconstructor, -) - -REPO_ROOT = Path(__file__).resolve().parents[1] -REPORT_DIR = Path("reports") / "oasis_inference_plot" -DEFAULT_SPLIT_CSV = REPO_ROOT / "reconstruction_only" / "splits" / "oasis_balanced_test.csv" -DEFAULT_MANIFEST_PATH = ( - REPO_ROOT / "downloads" / "oasis_singlecoil_unet" / "checkpoints" / "manifest.json" -) - -REPORT_DIR.mkdir(parents=True, exist_ok=True) - -ALGORITHMS = [ - # "zero-filled", - # "conjugate-gradient", - # "ram", - # "dip", - "tv-pgd", - # "wavelet-fista", - # "tv-fista", - # "tv-pdhg", - "oasis-unet", - "unet", -] -DISTORTIONS = [ - "None", - "Cartesian undersampling (variable density)", - "Cartesian undersampling (uniform random)", - "Cartesian undersampling (uniform random, zero ACS)", - "Cartesian undersampling (equispaced)", - "Cartesian undersampling (equispaced, zero ACS)", - "Partial Fourier", - "Phase-encode ghosting", - "Segmented translation motion", - "Translation motion", - "Rotational motion", - "Off-center anisotropic Gaussian bias field", - "Gaussian bias field", - "Anisotropic LP", - "Hann taper LP", - "Kaiser taper LP", - "Radial high-pass emphasis", - "Gaussian noise", - "Isotropic LP", -] 
-METRICS = [ - "PSNR", - "NMSE", - "SSIM", - "HaarPSI", - "SharpnessIndex", - "BlurStrength", -] - - -class OasisSliceDataset(Dataset): - """Load 2D OASIS slices from Analyze/NIfTI volumes. - - Parameters - ---------- - split_csv : Path - CSV file listing OASIS subjects and slice counts. - data_path : Path - Root directory containing OASIS subject folders. - sample_rate : float, optional - Fraction of slices to include from each volume. - cache_size : int, optional - Number of loaded volumes to keep in memory. - """ - - def __init__( - self, - split_csv: Path, - data_path: Path, - sample_rate: float = 1.0, - cache_size: int = 2, - ) -> None: - self.split_csv = Path(split_csv) - self.data_path = Path(data_path) - if not 0 < sample_rate <= 1.0: - raise ValueError("sample_rate must be in the range (0, 1].") - self.sample_rate = sample_rate - self.cache_size = max(0, cache_size) - self._volume_cache: OrderedDict[str, np.ndarray] = OrderedDict() - self.raw_samples = self._create_sample_list() - - def __len__(self) -> int: - """Return the number of available slices.""" - - return len(self.raw_samples) - - def __getitem__(self, index: int) -> dict[str, object]: - """Return one complex-valued OASIS slice in repo tensor convention.""" - - subject_id, slice_num = self.raw_samples[index] - target_np = self._read_raw_slice(subject_id, slice_num) - real = torch.from_numpy(target_np) - x = torch.stack([real, torch.zeros_like(real)], dim=0) - return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} - - def _create_sample_list(self) -> list[tuple[str, int]]: - samples: list[tuple[str, int]] = [] - with self.split_csv.open("r", encoding="utf-8") as handle: - for line in handle: - row = [item.strip() for item in line.split(",")] - if not row or not row[0]: - continue - try: - total_slices = int(row[-1]) - except ValueError: - continue - - subject_id = row[0] - if self.sample_rate >= 1.0: - start = 0 - stop = total_slices - else: - mid = round(total_slices / 2) - 
half_span = round(total_slices * self.sample_rate / 2) - start = max(0, mid - half_span) - stop = min(total_slices, mid + half_span) - - for slice_num in range(start, stop): - samples.append((subject_id, slice_num)) - return samples - - def _read_raw_slice(self, subject_id: str, slice_num: int) -> np.ndarray: - volume = self._get_volume(subject_id) - return np.ascontiguousarray(volume[slice_num], dtype=np.float32) - - def _get_volume(self, subject_id: str) -> np.ndarray: - if self.cache_size > 0 and subject_id in self._volume_cache: - self._volume_cache.move_to_end(subject_id) - return self._volume_cache[subject_id] - - image_glob = self.data_path / subject_id / "PROCESSED" / "MPRAGE" / "T88_111" - matches = sorted(image_glob.glob("*t88_gfc.img")) - if not matches: - raise FileNotFoundError( - f"Could not find OASIS image for subject {subject_id!r} under {image_glob}." - ) - - image_data = nib.load(str(matches[0])).get_fdata(dtype=np.float32) - volume = np.ascontiguousarray( - np.transpose(np.squeeze(image_data), (1, 0, 2)), - dtype=np.float32, - ) - - if self.cache_size > 0: - self._volume_cache[subject_id] = volume - if len(self._volume_cache) > self.cache_size: - self._volume_cache.popitem(last=False) - - return volume - - -def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to log-magnitude image for visualization.""" - - kspace = kspace.detach().cpu() - if kspace.ndim == 5: - kspace = kspace[0] - if kspace.ndim == 4: - if kspace.shape[0] == 1 and kspace.shape[1] == 2: - kspace = kspace[0] - elif kspace.shape[0] != 2: - raise ValueError( - "Expected k-space with shape (2, H, W), (1, 2, H, W), " - f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" - ) - if kspace.ndim != 3 and kspace.ndim != 4: - raise ValueError( - "Expected k-space with shape (2, H, W), (1, 2, H, W), " - f"or (1, 2, N, H, W), got {tuple(kspace.shape)}" - ) - if kspace.shape[0] != 2: - raise ValueError(f"Expected real/imaginary channel first, got 
{tuple(kspace.shape)}") - - kspace_complex = torch.view_as_complex(torch.movedim(kspace, 0, -1).contiguous()) - magnitude = torch.abs(kspace_complex) - if magnitude.ndim == 3: - magnitude = torch.sqrt(torch.sum(magnitude.square(), dim=0)) - magnitude = torch.log1p(magnitude) - - lower = torch.quantile(magnitude, 0.05) - upper = torch.quantile(magnitude, 0.995) - if float(upper) > float(lower): - magnitude = magnitude.clamp(lower, upper) - magnitude = (magnitude - lower) / (upper - lower) - else: - mag_max = float(magnitude.max()) - if mag_max > 0.0: - magnitude = magnitude / mag_max - - return torch.sqrt(magnitude) - - -def save_kspace_plot( - y_clean: torch.Tensor, - y_distorted: torch.Tensor, - save_fn: Path, - distortion_name: str, -) -> None: - """Save clean and distorted k-space magnitude plots.""" - - images = [ - ("Original k-space", _kspace_to_log_magnitude(y_clean)), - ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), - ] - - fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) - fig.suptitle(f"Distortion: {distortion_name}") - for ax, (title, image) in zip(axes, images, strict=True): - ax.imshow(image.numpy(), cmap="magma") - ax.set_title(title) - ax.axis("off") - fig.savefig(save_fn, dpi=200, bbox_inches="tight") - plt.close(fig) - - -def image_to_kspace(x: torch.Tensor) -> torch.Tensor: - """Convert channel-first complex images to centered k-space.""" - - x_complex = torch.view_as_complex(x.movedim(1, -1).contiguous()) - y_complex = torch.fft.fftshift( - torch.fft.fft2(x_complex, dim=(-2, -1), norm="ortho"), - dim=(-2, -1), - ) - return torch.view_as_real(y_complex).movedim(-1, 1).contiguous() - - -def kspace_to_image(y: torch.Tensor) -> torch.Tensor: - """Convert centered channel-first k-space to complex images.""" - - y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) - x_complex = torch.fft.ifft2( - torch.fft.ifftshift(y_complex, dim=(-2, -1)), - dim=(-2, -1), - norm="ortho", - ) - return 
torch.view_as_real(x_complex).movedim(-1, 1).contiguous() - - -class OasisCenteredFFTPhysics: - """Physics adapter matching the OASIS U-Net FFT convention. - - Parameters - ---------- - distortion : BaseDistortion - K-space distortion applied after the centered FFT. - """ - - def __init__(self, distortion: BaseDistortion) -> None: - self.distortion = distortion - - def A(self, x: torch.Tensor) -> torch.Tensor: - """Apply centered FFT and k-space distortion. - - Parameters - ---------- - x : torch.Tensor - Complex image tensor with shape ``(B, 2, H, W)``. - - Returns - ------- - torch.Tensor - Distorted centered k-space tensor. - """ - - return self.distortion.A(image_to_kspace(x)) - - def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: - """Apply adjoint distortion and centered inverse FFT. - - Parameters - ---------- - y : torch.Tensor - Distorted centered k-space tensor. - - Returns - ------- - torch.Tensor - Complex image tensor with shape ``(B, 2, H, W)``. - """ - - return kspace_to_image(self.distortion.A_adjoint(y)) - - -def resolve_oasis_checkpoint( - checkpoint: Optional[Path], - acceleration: int, - manifest_path: Path, -) -> Path: - """Resolve an explicit or packaged OASIS checkpoint path. - - Parameters - ---------- - checkpoint : Path or None - User-provided checkpoint path. - acceleration : int - Acceleration key used when loading from the manifest. - manifest_path : Path - JSON manifest with packaged checkpoint metadata. - - Returns - ------- - Path - Resolved checkpoint path. 
- """ - - if checkpoint is not None: - return checkpoint.expanduser().resolve() - - return OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( - acceleration=acceleration, - manifest_path=manifest_path, - ) - - -def choose_algorithm( - name: str, - checkpoint_file: Path, - img_size: tuple = (640, 368), - device: torch.device = "cpu", - verbose: bool = False, -) -> dinv.models.Reconstructor: - """Construct a reconstructor by selector name.""" - - match name: - case "zero-filled": - return ZeroFilledReconstructor() - case "conjugate-gradient": - return ConjugateGradientReconstructor(max_iter=20) - case "ram": - return RAMReconstructor(default_sigma=0.05, device=device) - case "dip": - return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) - case "tv-pgd": - return TVPGDReconstructor(n_iter=100, verbose=verbose) - case "tv-fista": - return TVFISTAReconstructor(n_iter=200, verbose=verbose) - case "tv-pdhg": - return TVPDHGReconstructor(n_iter=100, verbose=verbose) - case "wavelet-fista": - return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) - case "oasis-unet" | "unet": - return OASISSinglecoilUnetReconstructor( - checkpoint_file=str(checkpoint_file), - device=device, - ) - case _: - raise ValueError(f"Unknown algorithm {name!r}") - - -def choose_distortion( - name: str, - acceleration: int, - center_fraction: float, -) -> BaseDistortion: - """Construct a k-space distortion by display name.""" - - match name: - case "None": - return BaseDistortion() - case "Cartesian undersampling (variable density)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=center_fraction, - pattern="variable_density_random", - axis=-1, - seed=42, - ) - case "Cartesian undersampling (uniform random)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=center_fraction, - pattern="uniform_random", - axis=-1, - seed=42, - ) - case "Cartesian undersampling 
(uniform random, zero ACS)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=0.0, - pattern="uniform_random", - axis=-1, - seed=42, - ) - case "Cartesian undersampling (equispaced)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=center_fraction, - pattern="equispaced", - axis=-1, - seed=42, - ) - case "Cartesian undersampling (equispaced, zero ACS)": - return CartesianUndersampling( - keep_fraction=1.0 / acceleration, - center_fraction=0.0, - pattern="equispaced", - axis=-1, - seed=42, - ) - case "Partial Fourier": - return PartialFourierDistortion( - partial_fraction=0.7, - center_fraction=0.1, - axis=-1, - side="high", - ) - case "Phase-encode ghosting": - return PhaseEncodeGhostingDistortion( - line_period=2, - line_offset=1, - phase_error_radians=torch.pi / 2, - corrupted_line_scale=1.0, - ) - case "Anisotropic LP": - return AnisotropicResolutionReduction( - kx_radius_fraction=1.0, - ky_radius_fraction=0.25, - ) - case "Hann taper LP": - return HannTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - ) - case "Kaiser taper LP": - return KaiserTaperResolutionReduction( - radius_fraction=0.35, - transition_fraction=0.4, - beta=8.6, - ) - case "Radial high-pass emphasis": - return RadialHighPassEmphasisDistortion(alpha=0.4) - case "Isotropic LP": - return IsotropicResolutionReduction(radius_fraction=0.1) - case "Off-center anisotropic Gaussian bias field": - return OffCenterAnisotropicGaussianKspaceBiasField( - width_x_fraction=0.2, - width_y_fraction=0.35, - center_x_fraction=0.15, - center_y_fraction=-0.1, - edge_gain=0.3, - ) - case "Translation motion": - return TranslationMotionDistortion(shift_x_pixels=60, shift_y_pixels=10) - case "Rotational motion": - return RotationalMotionDistortion(angle_radians=torch.pi / 6) - case "Segmented translation motion": - return SegmentedTranslationMotionDistortion( - shift_x_pixels=(0.0, 20.0, 50.0, -50.0), - 
shift_y_pixels=(0.0, 10.0, -20.0, 20.0), - ) - case "Gaussian bias field": - return GaussianKspaceBiasField(width_fraction=0.35, edge_gain=0.4) - case "Gaussian noise": - return GaussianNoiseDistortion(sigma=0.00001) - case _: - raise ValueError(f"Unknown distortion {name!r}") - - -def choose_metric(name: str) -> dinv.metric.Metric: - """Construct a DeepInverse metric by selector name.""" - - match name: - case "PSNR": - return dinv.metric.PSNR(max_pixel=None, complex_abs=True) - case "NMSE": - return dinv.metric.NMSE(complex_abs=True) - case "SSIM": - return dinv.metric.SSIM(max_pixel=None, complex_abs=True) - case "HaarPSI": - return dinv.metric.HaarPSI(norm_inputs="min_max", complex_abs=True) - case "BlurStrength": - return dinv.metric.BlurStrength(complex_abs=True) - case "SharpnessIndex": - return dinv.metric.SharpnessIndex(complex_abs=True) - case _: - raise ValueError(f"Unknown metric {name!r}") - - -def build_parser() -> argparse.ArgumentParser: - """Build the command-line parser.""" - - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--source", - type=Path, - required=True, - help="OASIS root directory containing subject folders.", - ) - parser.add_argument( - "--split_csv", - type=Path, - default=DEFAULT_SPLIT_CSV, - help="CSV listing OASIS subjects and slice counts.", - ) - parser.add_argument( - "--manifest", - type=Path, - default=DEFAULT_MANIFEST_PATH, - help="Checkpoint manifest JSON.", - ) - parser.add_argument( - "--checkpoint", - type=Path, - default=None, - help="Explicit OASIS U-Net checkpoint. 
Overrides --acceleration.", - ) - parser.add_argument( - "--acceleration", - type=int, - default=4, - help="Packaged OASIS checkpoint acceleration factor.", - ) - parser.add_argument( - "--center_fraction", - type=float, - default=0.08, - help="Center fraction used by the Cartesian undersampling distortion.", - ) - parser.add_argument( - "--distortion", - type=str, - default="Cartesian undersampling (uniform random)", - choices=DISTORTIONS, - ) - parser.add_argument( - "--algorithm", - type=str, - default="unet", - choices=ALGORITHMS, - help="Reconstruction algorithm applied to distorted OASIS k-space.", - ) - parser.add_argument("--num_samples", type=int, default=1, help="How many slices to process.") - parser.add_argument( - "--sample_rate", - type=float, - default=0.6, - help="Fraction of slices per volume to include from the split CSV.", - ) - parser.add_argument("--volume_cache_size", type=int, default=2) - parser.add_argument( - "--verbose", - action="store_true", - help="Enable verbose output for reconstructors that support it.", - ) - return parser - - -def main() -> None: - """Run OASIS inference plots.""" - - args = build_parser().parse_args() - args.source = args.source.expanduser().resolve() - args.split_csv = args.split_csv.expanduser().resolve() - args.manifest = args.manifest.expanduser().resolve() - checkpoint_file = resolve_oasis_checkpoint(args.checkpoint, args.acceleration, args.manifest) - - device = dinv.utils.get_device() - dataset = OasisSliceDataset( - split_csv=args.split_csv, - data_path=args.source, - sample_rate=args.sample_rate, - cache_size=args.volume_cache_size, - ) - dataloader = DataLoader(dataset, batch_size=1, shuffle=True) - metrics = [choose_metric(m) for m in METRICS] - - for i, batch in enumerate(iter(dataloader)): - if i >= args.num_samples: - break - - x = batch["x"].to(device) - subject_id = batch["subject_id"][0] - slice_num = int(batch["slice_num"][0]) - - for distortion_name in [args.distortion]: - distortion = 
choose_distortion( - distortion_name, - acceleration=args.acceleration, - center_fraction=args.center_fraction, - ) - - physics_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), - img_size=(1, 2, *x.shape[-2:]), - device=device, - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, - img_size=(1, 2, *x.shape[-2:]), - device=device, - ) - oasis_physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) - oasis_physics = OasisCenteredFFTPhysics(distortion) - - y = image_to_kspace(x) - y_distorted = distortion.A(y) - y_physics_distorted = physics(x) - x_distorted = kspace_to_image(y_distorted) - - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", - distortion_name, - ) - - for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: - print( - f"Evaluating algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}..." - ) - - algo = choose_algorithm( - algo_name, - checkpoint_file=checkpoint_file, - img_size=x.shape[-2:], - device=device, - verbose=args.verbose, - ).to(device) - - if algo_name in {"oasis-unet", "unet"}: - y_eval = y_distorted - eval_physics_clean = oasis_physics_clean - eval_physics = oasis_physics - else: - y_eval = y_physics_distorted - eval_physics_clean = physics_clean - eval_physics = physics - - x_uncorrected = algo(y_eval, eval_physics_clean) - x_corrected = algo(y_eval, eval_physics) - uncorrected_scores = [ - f"{m.__class__.__name__} {m(x_uncorrected, x).item():.2f}" for m in metrics - ] - corrected_scores = [ - f"{m.__class__.__name__} {m(x_corrected, x).item():.2f}" for m in metrics - ] - print(f" uncorrected: {', '.join(uncorrected_scores)}") - print(f" corrected: {', '.join(corrected_scores)}") - - dinv.utils.plot( - { - "Ground truth OASIS slice": x, - "Distorted ksp, zero-filled": x_distorted, - f"Distorted ksp, {algo_name} recon, uncorrected": x_uncorrected, - f"Distorted ksp, {algo_name} recon, 
corrected": x_corrected, - }, - subtitles=[ - "", - "", - "\n".join(uncorrected_scores), - "\n".join(corrected_scores), - ], - show=False, - close=True, - suptitle=( - f"Algo {algo_name}, distortion {distortion_name}, " - f"subject {subject_id}, slice {slice_num}" - ), - save_fn=REPORT_DIR / f"ALGO_{algo_name}_{distortion_name}_sample_{i}.png", - fontsize=3, - ) - - print("done!") - - -if __name__ == "__main__": - main()