diff --git a/README.md b/README.md index b0c475b..3ec9ae7 100644 --- a/README.md +++ b/README.md @@ -63,16 +63,22 @@ uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available Run the FastMRI plotting example with local FastMRI k-space files: ```bash -python examples/fastmri_inference_plot.py --source /path/to/fastmri/singlecoil_val --dataset fastmri --algorithm unet +python examples/fastmri_inference_plot.py --source /path/to/fastmri/singlecoil_val --dataset fastmri --algorithm unet-fastmri +``` + +Run the same FastMRI data through the packaged OASIS U-Net by choosing an explicit OASIS checkpoint variant. The example adapts the FastMRI measurements to the centered OASIS FFT convention automatically: + +```bash +python examples/fastmri_inference_plot.py --source /path/to/fastmri/singlecoil_val --dataset fastmri --algorithm unet-oasis-acceleration8 ``` Run the same lightweight example on OASIS data. The packaged OASIS split CSV and U-Net checkpoint are downloaded automatically when missing: ```bash -python examples/fastmri_inference_plot.py --source /path/to/oasis_cross_sectional_data --dataset oasis --algorithm unet +python examples/fastmri_inference_plot.py --source /path/to/oasis_cross_sectional_data --dataset oasis --algorithm unet-oasis-acceleration4 ``` -For OASIS, `--oasis_checkpoint_acceleration` only selects the packaged U-Net weights by their training acceleration. Distortion undersampling is still controlled by `--keep_fraction` and `--center_fraction`. +Supported explicit U-Net algorithms are `unet-fastmri`, `unet-oasis-acceleration4`, `unet-oasis-acceleration8`, and `unet-oasis-acceleration10`. `unet-fastmri` on the OASIS dataset is intentionally rejected. ## Pre-commit diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index 5a28a9d..1379db0 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -11,17 +11,44 @@ import argparse from pathlib import Path -import matplotlib.pyplot as plt -import torch import deepinv as dinv +import torch -from mri_recon.distortions import * -from mri_recon.reconstruction import * +from mri_recon.distortions import ( + AnisotropicResolutionReduction, + BaseDistortion, + CartesianUndersampling, + DistortedKspaceMultiCoilMRI, + GaussianKspaceBiasField, + GaussianNoiseDistortion, + HannTaperResolutionReduction, + IsotropicResolutionReduction, + KaiserTaperResolutionReduction, + OffCenterAnisotropicGaussianKspaceBiasField, + PartialFourierDistortion, + PhaseEncodeGhostingDistortion, + RadialHighPassEmphasisDistortion, + RotationalMotionDistortion, + SegmentedRotationalMotionDistortion, + SegmentedTranslationMotionDistortion, + TranslationMotionDistortion, +) +from mri_recon.reconstruction import ( + ConjugateGradientReconstructor, + EXPLICIT_UNET_ALGORITHMS, + OASISSinglecoilUnetReconstructor, + choose_reconstructor, + uses_oasis_centered_path, + validate_algorithm_dataset_compatibility, +) from mri_recon.utils import ( OasisCenteredFFTPhysics, OasisSliceDataset, + fastmri_measurement_to_image, + fastmri_measurement_to_oasis_kspace, image_to_kspace, kspace_to_image, + save_kspace_plot, ) FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" @@ -29,16 +56,17 @@ FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ - # "zero-filled", + "zero-filled", # "conjugate-gradient", # "ram", # "dip", - # "tv-pgd", + "tv-pgd", # "wavelet-fista", - # "tv-fista", + "tv-fista", # "tv-pdhg", - "unet", # dataset-specific U-Net; downloads weights if not already present + *list(EXPLICIT_UNET_ALGORITHMS), ] + DISTORTIONS = [ "Cartesian undersampling (variable density)", "Cartesian undersampling (uniform random)", @@ -70,95 +98,19 @@ ] -def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: - """Convert k-space tensor to log-magnitude image for visualization.""" - if kspace.ndim == 4: - kspace = kspace[0] - if kspace.ndim != 3 or kspace.shape[0] != 2: - raise ValueError( - f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" - ) - - kspace = kspace.detach().cpu() - kspace_complex = torch.view_as_complex(kspace.permute(1, 2, 0).contiguous()) - magnitude = torch.log1p(torch.abs(kspace_complex)) - - lower = torch.quantile(magnitude, 0.05) - upper = torch.quantile(magnitude, 0.995) - if float(upper) > float(lower): - magnitude = magnitude.clamp(lower, upper) - magnitude = (magnitude - lower) / (upper - lower) - else: - mag_max = float(magnitude.max()) - if mag_max > 0.0: - magnitude = magnitude / mag_max - - return torch.sqrt(magnitude) - - -def save_kspace_plot( - y_clean: torch.Tensor, - y_distorted: torch.Tensor, - save_fn: Path, - distortion_name: str, -) -> None: - images = [ - ("Original k-space", _kspace_to_log_magnitude(y_clean)), - ("Distorted k-space", _kspace_to_log_magnitude(y_distorted)), - ] - - fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) - fig.suptitle(f"Distortion: {distortion_name}") - for ax, (title, image) in zip(axes, images, strict=True): - ax.imshow(image.numpy(), cmap="magma") - ax.set_title(title) - ax.axis("off") - fig.savefig(save_fn, dpi=200, bbox_inches="tight") - plt.close(fig) - - -def choose_algorithm( - name: str, - img_size: tuple = (640, 368), - device: torch.device = "cpu", - verbose: bool = False, - dataset: str = "fastmri", - oasis_checkpoint_acceleration: int = 4, -) -> dinv.models.Reconstructor: - match name: - case "zero-filled": - return ZeroFilledReconstructor() - case "conjugate-gradient": - return ConjugateGradientReconstructor(max_iter=20) - case "ram": - return RAMReconstructor(default_sigma=0.05, device=device) - case "dip": - return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) - case "tv-pgd": - return TVPGDReconstructor(n_iter=100, verbose=verbose) - case "tv-fista": - return TVFISTAReconstructor(n_iter=200, verbose=verbose) - case "tv-pdhg": - return TVPDHGReconstructor(n_iter=100, verbose=verbose) - case "wavelet-fista": - return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) - case "unet": - if dataset == "oasis": - return OASISSinglecoilUnetReconstructor( - acceleration=oasis_checkpoint_acceleration, - device=device, - ) - return FastMRISinglecoilUnetReconstructor(device=device) - case _: - raise ValueError(f"Unknown algorithm {name!r}") - - def choose_distortion( name: str, keep_fraction: float = 0.25, center_fraction: float = 0.125, cartesian_axis: int = -2, ) -> BaseDistortion: + """Build one distortion operator for the inference comparison script. + + The ``cartesian_axis`` is supplied by the active measurement convention: + FastMRI-native runs use the repository's existing axis, while OASIS-native + and FastMRI-to-OASIS runs use the centered OASIS axis. + """ + match name: case "Phase-encode ghosting": return PhaseEncodeGhostingDistortion( @@ -264,6 +216,8 @@ def choose_distortion( def choose_metric(name: str) -> dinv.metric.Metric: + """Build one evaluation metric used in the saved comparison plots.""" + match name: case "PSNR": return dinv.metric.PSNR(max_pixel=None, complex_abs=True) @@ -279,6 +233,59 @@ def choose_metric(name: str) -> dinv.metric.Metric: return dinv.metric.SharpnessIndex(complex_abs=True) +def prepare_measurement_sample( + sample_batch: object, + dataset_name: str, + use_oasis_fft_path: bool, + run_device: torch.device | str, +) -> tuple[torch.Tensor | None, torch.Tensor]: + """Prepare one input measurement and its clean image reference. + + FastMRI samples are loaded as native measurements. When the OASIS U-Net is + selected on FastMRI data, the helper converts those measurements into the + centered OASIS k-space convention while preserving the native adjoint image + as the clean reference. + """ + + if dataset_name == "oasis": + reference_image = sample_batch["x"].to(run_device) + return reference_image, image_to_kspace(reference_image) + + # FastMRI batches are tuples such as (x, y) or (x, y, params). + y_fastmri = sample_batch[1].to(run_device) + if use_oasis_fft_path: + reference_image = fastmri_measurement_to_image(y_fastmri, device=run_device) + return reference_image, fastmri_measurement_to_oasis_kspace(y_fastmri, device=run_device) + + return None, y_fastmri + + +def build_physics_pair( + image_shape: tuple[int, int], + distortion_operator: BaseDistortion, + run_device: torch.device | str, + use_oasis_fft_path: bool, +) -> tuple[object, object]: + """Build clean and distorted physics operators for the active path.""" + + if use_oasis_fft_path: + return OasisCenteredFFTPhysics(BaseDistortion()), OasisCenteredFFTPhysics( + distortion_operator + ) + + clean_physics = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + img_size=(1, 2, *image_shape), + device=run_device, + ) + distorted_physics = DistortedKspaceMultiCoilMRI( + distortion=distortion_operator, + img_size=(1, 2, *image_shape), + device=run_device, + ) + return clean_physics, distorted_physics + + if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) @@ -312,17 +319,6 @@ def choose_metric(name: str) -> dinv.metric.Metric: choices=ALGORITHMS, help="Reconstruction algorithm applied to undistorted and distorted k-space.", ) - parser.add_argument( - "--oasis_checkpoint_acceleration", - type=int, - default=4, - choices=[4, 8, 10], - help=( - "Training acceleration of the packaged OASIS U-Net checkpoint. " - "This only selects pretrained weights; distortion undersampling is " - "controlled by --keep_fraction and --center_fraction." - ), - ) # inference related arguments parser.add_argument("--num_samples", type=int, default=1, help="How many samples to process.") parser.add_argument( @@ -332,6 +328,11 @@ def choose_metric(name: str) -> dinv.metric.Metric: ) args = parser.parse_args() + selected_algorithms = ALGORITHMS if args.algorithm == "" else [args.algorithm] + selected_distortions = DISTORTIONS if args.distortion == "" else [args.distortion] + for algo_name in selected_algorithms: + validate_algorithm_dataset_compatibility(args.dataset, algo_name) + # set up report dir REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR @@ -353,67 +354,56 @@ def choose_metric(name: str) -> dinv.metric.Metric: if i >= args.num_samples: break - if args.dataset == "oasis": - x = batch["x"].to(device) - y = image_to_kspace(x) - else: - # batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan), - # y is kspace, and params is a dict containing mask (if test set) - y = batch[1].to(device) - - for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: - distortion = choose_distortion( - distortion_name, - keep_fraction=args.keep_fraction, - center_fraction=args.center_fraction, - cartesian_axis=-1 if args.dataset == "oasis" else -2, + for algo_name in selected_algorithms: + use_oasis_path = uses_oasis_centered_path(args.dataset, algo_name) + x_reference, y = prepare_measurement_sample( + sample_batch=batch, + dataset_name=args.dataset, + use_oasis_fft_path=use_oasis_path, + run_device=device, ) + algo = choose_reconstructor( + algo_name, + img_size=y.shape[-2:], + device=device, + verbose=args.verbose, + dataset=args.dataset, + ).to(device) + + for distortion_name in selected_distortions: + distortion = choose_distortion( + distortion_name, + keep_fraction=args.keep_fraction, + center_fraction=args.center_fraction, + cartesian_axis=-1 if use_oasis_path else -2, + ) - # create physics objects for both clean and distorted k-space - # the distortion is applied to the k-space measurements (not the image) - # TODO: allow loading multicoil data - if args.dataset == "oasis": - physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) - physics = OasisCenteredFFTPhysics(distortion) - else: - physics_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device + physics_clean, physics = build_physics_pair( + image_shape=y.shape[-2:], + distortion_operator=distortion, + run_device=device, + use_oasis_fft_path=use_oasis_path, ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device + y_distorted = distortion.A(y) + + # generate reference reconstructions (CG) for both clean and distorted k-space + # without correction for the distortion, i.e. using physics_clean in both cases + if use_oasis_path: + x_clean = x_reference + x_distorted = kspace_to_image(y_distorted) + else: + x_clean = ConjugateGradientReconstructor()(y, physics_clean) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + + save_kspace_plot( + y, + y_distorted, + REPORT_DIR / f"DISTORTION_{algo_name}_{distortion_name}_sample_{i}.png", + distortion_name, ) - y_distorted = distortion.A(y) - - # generate reference reconstructions (CG) for both clean and distorted k-space - # without correction for the distortion, i.e. using physics_clean in both cases - if args.dataset == "oasis": - x_clean = x - x_distorted = kspace_to_image(y_distorted) - else: - x_clean = ConjugateGradientReconstructor()(y, physics_clean) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) - - # plot and save the k-space magnitude for both clean and distorted k-space - save_kspace_plot( - y, - y_distorted, - REPORT_DIR / f"DISTORTION_{distortion_name}_sample_{i}.png", - distortion_name, - ) - # loop through algorithms to evaluate on the distorted k-space, with and without correction for the distortion - for algo_name in ALGORITHMS if args.algorithm == "" else [args.algorithm]: print(f"Evaluating algo {algo_name}, distortion {distortion_name}, sample {i}...") - algo = choose_algorithm( - algo_name, - img_size=y.shape[-2:], - device=device, - verbose=args.verbose, - dataset=args.dataset, - oasis_checkpoint_acceleration=args.oasis_checkpoint_acceleration, - ).to(device) - # actual reconstruction with the algo being evaluated x_uncorrected = algo(y_distorted, physics_clean) x_corrected = algo(y_distorted, physics) diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index eafddf3..5eb75ea 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -4,6 +4,14 @@ FastMRISinglecoilUnetReconstructor, OASISSinglecoilUnetReconstructor, ) +from .inference import ( + EXPLICIT_UNET_ALGORITHMS, + FASTMRI_UNET_ALGORITHM, + OASIS_UNET_ALGORITHMS, + choose_reconstructor, + uses_oasis_centered_path, + validate_algorithm_dataset_compatibility, +) from .classic import ( ZeroFilledReconstructor, ConjugateGradientReconstructor, diff --git a/mri_recon/reconstruction/inference.py b/mri_recon/reconstruction/inference.py new file mode 100644 index 0000000..b0fe1d3 --- /dev/null +++ b/mri_recon/reconstruction/inference.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import deepinv as dinv +import torch + +from .classic import ( + ConjugateGradientReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + ZeroFilledReconstructor, +) +from .deep import ( + DeepImagePriorReconstructor, + FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, + RAMReconstructor, +) + +FASTMRI_UNET_ALGORITHM = "unet-fastmri" +OASIS_UNET_ALGORITHMS = { + "unet-oasis-acceleration4": 4, + "unet-oasis-acceleration8": 8, + "unet-oasis-acceleration10": 10, +} +EXPLICIT_UNET_ALGORITHMS = (FASTMRI_UNET_ALGORITHM, *tuple(OASIS_UNET_ALGORITHMS)) + + +def uses_oasis_centered_path( + dataset: str, + algorithm: str, +) -> bool: + """Return whether inference should use the centered OASIS k-space path. + + OASIS samples always use the centered FFT convention. FastMRI only switches + to that path when the selected algorithm is one of the explicit OASIS U-Net + variants. + """ + + if dataset == "oasis": + return True + return algorithm in OASIS_UNET_ALGORITHMS + + +def validate_algorithm_dataset_compatibility(dataset: str, algorithm: str) -> None: + """Raise a clear error when an explicit algorithm is incompatible with a dataset.""" + + if dataset == "oasis" and algorithm == FASTMRI_UNET_ALGORITHM: + raise ValueError( + "The algorithm 'unet-fastmri' is not supported on the OASIS dataset. " + "Use one of the explicit OASIS U-Net algorithms instead." + ) + + +def choose_reconstructor( + name: str, + img_size: tuple = (640, 368), + device: torch.device | str = "cpu", + verbose: bool = False, + dataset: str = "fastmri", +) -> dinv.models.Reconstructor: + """Create a reconstructor while enforcing the supported dataset/model matrix. + + Parameters + ---------- + name : str + High-level algorithm identifier used by the example entry point. + img_size : tuple, optional + Spatial image size used by reconstructors that need an explicit image + shape, such as DIP. + device : torch.device | str, optional + Device on which to instantiate the reconstructor. + verbose : bool, optional + Forwarded to reconstructors that expose a verbose mode. + dataset : str, optional + Dataset being evaluated. This only affects compatibility checks for + explicit algorithm names that are dataset-specific. + """ + + validate_algorithm_dataset_compatibility(dataset, name) + + match name: + case "zero-filled": + return ZeroFilledReconstructor() + case "conjugate-gradient": + return ConjugateGradientReconstructor(max_iter=20) + case "ram": + return RAMReconstructor(default_sigma=0.05, device=device) + case "dip": + return DeepImagePriorReconstructor(img_size=img_size, n_iter=100, verbose=verbose) + case "tv-pgd": + return TVPGDReconstructor(n_iter=100, verbose=verbose) + case "tv-fista": + return TVFISTAReconstructor(n_iter=200, verbose=verbose) + case "tv-pdhg": + return TVPDHGReconstructor(n_iter=100, verbose=verbose) + case "wavelet-fista": + return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) + case _ if name == FASTMRI_UNET_ALGORITHM: + return FastMRISinglecoilUnetReconstructor(device=device) + case _ if name in OASIS_UNET_ALGORITHMS: + return OASISSinglecoilUnetReconstructor( + acceleration=OASIS_UNET_ALGORITHMS[name], + device=device, + ) + case _: + raise ValueError(f"Unknown algorithm {name!r}") diff --git a/mri_recon/utils/__init__.py b/mri_recon/utils/__init__.py index 1e27df9..744c860 100644 --- a/mri_recon/utils/__init__.py +++ b/mri_recon/utils/__init__.py @@ -4,16 +4,24 @@ from .io import matches_sha256 as matches_sha256 from .oasis_adapter import OasisCenteredFFTPhysics as OasisCenteredFFTPhysics from .oasis_adapter import OasisSliceDataset as OasisSliceDataset +from .oasis_adapter import fastmri_measurement_to_image as fastmri_measurement_to_image +from .oasis_adapter import ( + fastmri_measurement_to_oasis_kspace as fastmri_measurement_to_oasis_kspace, +) from .oasis_adapter import image_to_kspace as image_to_kspace from .oasis_adapter import kspace_to_image as kspace_to_image +from .plot import save_kspace_plot as save_kspace_plot __all__ = [ "download_file_with_sha256", "download_google_drive_file_with_sha256", "format_megabytes", + "fastmri_measurement_to_image", + "fastmri_measurement_to_oasis_kspace", "image_to_kspace", "kspace_to_image", "matches_sha256", "OasisCenteredFFTPhysics", "OasisSliceDataset", + "save_kspace_plot", ] diff --git a/mri_recon/utils/oasis_adapter.py b/mri_recon/utils/oasis_adapter.py index 4b2750f..0f833aa 100644 --- a/mri_recon/utils/oasis_adapter.py +++ b/mri_recon/utils/oasis_adapter.py @@ -8,7 +8,7 @@ import torch from torch.utils.data import Dataset -from mri_recon.distortions import BaseDistortion +from mri_recon.distortions import BaseDistortion, DistortedKspaceMultiCoilMRI class OasisSliceDataset(Dataset): @@ -167,6 +167,57 @@ def kspace_to_image(y: torch.Tensor) -> torch.Tensor: return torch.view_as_real(x_complex).movedim(-1, 1).contiguous() +def fastmri_measurement_to_image( + y: torch.Tensor, + device: torch.device | str | None = None, +) -> torch.Tensor: + """Convert FastMRI measurements to image space using the repo's native physics. + + Parameters + ---------- + y : torch.Tensor + FastMRI measurement tensor with shape ``(B, 2, H, W)``. + device : torch.device | str, optional + Device on which to instantiate the temporary native physics operator. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + """ + + if device is None: + device = y.device + physics = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + img_size=(1, 2, *y.shape[-2:]), + device=device, + ) + return physics.A_adjoint(y) + + +def fastmri_measurement_to_oasis_kspace( + y: torch.Tensor, + device: torch.device | str | None = None, +) -> torch.Tensor: + """Adapt FastMRI measurements to the centered OASIS k-space convention. + + Parameters + ---------- + y : torch.Tensor + FastMRI measurement tensor with shape ``(B, 2, H, W)``. + device : torch.device | str, optional + Device on which to instantiate the temporary native physics operator. + + Returns + ------- + torch.Tensor + Centered OASIS-convention k-space tensor with shape ``(B, 2, H, W)``. + """ + + return image_to_kspace(fastmri_measurement_to_image(y, device=device)) + + class OasisCenteredFFTPhysics: """Physics adapter matching the OASIS U-Net FFT convention. diff --git a/mri_recon/utils/plot.py b/mri_recon/utils/plot.py new file mode 100644 index 0000000..068c222 --- /dev/null +++ b/mri_recon/utils/plot.py @@ -0,0 +1,58 @@ +"""Plotting helpers shared across examples and utilities.""" + +from __future__ import annotations + +from pathlib import Path + +import matplotlib.pyplot as plt +import torch + + +def _kspace_to_log_magnitude(kspace: torch.Tensor) -> torch.Tensor: + """Convert k-space tensor to a log-magnitude image for visualization.""" + + if kspace.ndim == 4: + kspace = kspace[0] + if kspace.ndim != 3 or kspace.shape[0] != 2: + raise ValueError( + f"Expected k-space with shape (2, H, W) or (1, 2, H, W), got {tuple(kspace.shape)}" + ) + + kspace = kspace.detach().cpu() + kspace_complex = torch.view_as_complex(kspace.permute(1, 2, 0).contiguous()) + magnitude = torch.log1p(torch.abs(kspace_complex)) + + lower = torch.quantile(magnitude, 0.05) + upper = torch.quantile(magnitude, 0.995) + if float(upper) > float(lower): + magnitude = magnitude.clamp(lower, upper) + magnitude = (magnitude - lower) / (upper - lower) + else: + mag_max = float(magnitude.max()) + if mag_max > 0.0: + magnitude = magnitude / mag_max + + return torch.sqrt(magnitude) + + +def save_kspace_plot( + clean_kspace: torch.Tensor, + distorted_kspace: torch.Tensor, + save_fn: Path, + distortion_label: str, +) -> None: + """Save side-by-side log-magnitude visualizations of clean and distorted k-space.""" + + images = [ + ("Original k-space", _kspace_to_log_magnitude(clean_kspace)), + ("Distorted k-space", _kspace_to_log_magnitude(distorted_kspace)), + ] + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + fig.suptitle(f"Distortion: {distortion_label}") + for ax, (title, image) in zip(axes, images, strict=True): + ax.imshow(image.numpy(), cmap="magma") + ax.set_title(title) + ax.axis("off") + fig.savefig(save_fn, dpi=200, bbox_inches="tight") + plt.close(fig) diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index 36117ff..3c95455 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -9,17 +9,15 @@ from mri_recon.reconstruction._fastmri_unet import Unet from mri_recon.reconstruction.deep import ( - RAMReconstructor, - DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor, OASISSinglecoilUnetReconstructor, ) -from mri_recon.reconstruction.classic import ( - ZeroFilledReconstructor, - ConjugateGradientReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - TVFISTAReconstructor, +from mri_recon.reconstruction.inference import ( + FASTMRI_UNET_ALGORITHM, + OASIS_UNET_ALGORITHMS, + choose_reconstructor, + uses_oasis_centered_path, + validate_algorithm_dataset_compatibility, ) from mri_recon.distortions import DistortedKspaceMultiCoilMRI @@ -34,45 +32,29 @@ ] -def choose_algorithm(name, img_size, device): - match name: - case "zero-filled": - return ZeroFilledReconstructor() - case "conjugate-gradient": - return ConjugateGradientReconstructor(max_iter=20) - case "ram": - return RAMReconstructor(default_sigma=0.05, device=device) - case "dip": - return DeepImagePriorReconstructor(img_size=img_size[-2:], n_iter=100) - case "tv-pgd": - return TVPGDReconstructor(n_iter=100, verbose=False) - case "tv-fista": - return TVFISTAReconstructor(n_iter=100, verbose=False) - case "wavelet-fista": - return WaveletFISTAReconstructor(n_iter=100, verbose=False, device=device) - case _: - raise ValueError(f"Unknown algorithm {name!r}") - - -@pytest.fixture -def device(): +@pytest.fixture(name="runtime_device") +def fixture_runtime_device(): return "cpu" @pytest.mark.parametrize("name", ALGORITHMS) -def test_reconstructors(name, device): +def test_reconstructors(name, runtime_device): """ Test that reconstruction algorithms work end to end on a dummy example. """ x = dinv.utils.load_example( - "butterfly.png", img_size=(32, 32), grayscale=True, resize_mode="resize", device=device + "butterfly.png", + img_size=(32, 32), + grayscale=True, + resize_mode="resize", + device=runtime_device, ) x = torch.cat([x, torch.zeros_like(x)], dim=1) # dummy complex data - model = choose_algorithm(name, img_size=x.shape[1:], device=device) + model = choose_reconstructor(name, img_size=x.shape[1:], device=runtime_device) physics = DistortedKspaceMultiCoilMRI( - img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=device + img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=runtime_device ) y = physics(x) @@ -82,7 +64,7 @@ def test_reconstructors(name, device): assert x_hat.shape == x.shape -def test_fastmri_singlecoil_unet_reconstructor(device, tmp_path, monkeypatch): +def test_fastmri_singlecoil_unet_reconstructor(runtime_device, tmp_path, monkeypatch): """ Test the FastMRI UNet reconstructor end to end using local fixture weights. """ @@ -102,16 +84,20 @@ def test_fastmri_singlecoil_unet_reconstructor(device, tmp_path, monkeypatch): torch.save(Unet(**FastMRISinglecoilUnetReconstructor.UNET_KWARGS).state_dict(), weights_path) x = dinv.utils.load_example( - "butterfly.png", img_size=(32, 32), grayscale=True, resize_mode="resize", device=device + "butterfly.png", + img_size=(32, 32), + grayscale=True, + resize_mode="resize", + device=runtime_device, ) x = torch.cat([x, torch.zeros_like(x)], dim=1) model = FastMRISinglecoilUnetReconstructor( - device=device, + device=runtime_device, state_dict_file=str(weights_path), ) physics = DistortedKspaceMultiCoilMRI( - img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=device + img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=runtime_device ) y = physics(x) @@ -121,7 +107,7 @@ def test_fastmri_singlecoil_unet_reconstructor(device, tmp_path, monkeypatch): assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) -def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, tmp_path): +def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(runtime_device, tmp_path): """ Test the OASIS UNet reconstructor with a Lightning-style checkpoint. """ @@ -133,16 +119,20 @@ def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, ) x = dinv.utils.load_example( - "butterfly.png", img_size=(32, 32), grayscale=True, resize_mode="resize", device=device + "butterfly.png", + img_size=(32, 32), + grayscale=True, + resize_mode="resize", + device=runtime_device, ) x = torch.cat([x, torch.zeros_like(x)], dim=1) model = OASISSinglecoilUnetReconstructor( checkpoint_file=str(weights_path), - device=device, + device=runtime_device, ) physics = DistortedKspaceMultiCoilMRI( - img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=device + img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=runtime_device ) y = physics(x) @@ -178,7 +168,7 @@ def test_oasis_resolve_default_checkpoint_downloads_manifest_and_checkpoint(tmp_ {"4": "checkpoint-file-id"}, ) - def fake_download(file_id, destination, expected_sha256, **kwargs): + def fake_download(file_id, destination, _expected_sha256, **_kwargs): downloads.append((file_id, destination)) destination.parent.mkdir(parents=True, exist_ok=True) if file_id == "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I": @@ -207,7 +197,7 @@ def fake_download(file_id, destination, expected_sha256, **kwargs): def test_oasis_singlecoil_unet_reconstructor_uses_packaged_checkpoint_defaults( - device, tmp_path, monkeypatch + runtime_device, tmp_path, monkeypatch ): monkeypatch.setattr( OASISSinglecoilUnetReconstructor, @@ -230,7 +220,7 @@ def test_oasis_singlecoil_unet_reconstructor_uses_packaged_checkpoint_defaults( captured = {} - def fake_resolve(cls, acceleration, manifest_path=None): + def fake_resolve(_cls, acceleration, manifest_path=None): captured["acceleration"] = acceleration captured["manifest_path"] = manifest_path return weights_path @@ -244,7 +234,7 @@ def fake_resolve(cls, acceleration, manifest_path=None): model = OASISSinglecoilUnetReconstructor( acceleration=8, manifest_path=str(tmp_path / "manifest.json"), - device=device, + device=runtime_device, ) assert captured == { @@ -252,3 +242,94 @@ def fake_resolve(cls, acceleration, manifest_path=None): "manifest_path": tmp_path / "manifest.json", } assert isinstance(model.model, Unet) + + +def test_validate_algorithm_dataset_compatibility_accepts_supported_explicit_unets(): + validate_algorithm_dataset_compatibility("fastmri", FASTMRI_UNET_ALGORITHM) + validate_algorithm_dataset_compatibility("fastmri", "unet-oasis-acceleration8") + validate_algorithm_dataset_compatibility("oasis", "unet-oasis-acceleration4") + + +def test_validate_algorithm_dataset_compatibility_rejects_unsupported_oasis_fastmri_combo(): + with pytest.raises(ValueError, match="unet-fastmri"): + validate_algorithm_dataset_compatibility("oasis", FASTMRI_UNET_ALGORITHM) + + +def test_uses_oasis_centered_path_tracks_dataset_and_explicit_algorithm(): + assert uses_oasis_centered_path("oasis", FASTMRI_UNET_ALGORITHM) is True + assert uses_oasis_centered_path("fastmri", "unet-oasis-acceleration8") is True + assert uses_oasis_centered_path("fastmri", FASTMRI_UNET_ALGORITHM) is False + assert uses_oasis_centered_path("fastmri", "tv-pgd") is False + + +def test_choose_reconstructor_selects_oasis_unet_for_fastmri_when_requested(monkeypatch): + captured = {} + + class Marker: + pass + + def fake_oasis(*, acceleration, device): + captured["acceleration"] = acceleration + captured["device"] = device + return Marker() + + monkeypatch.setattr( + "mri_recon.reconstruction.inference.OASISSinglecoilUnetReconstructor", + fake_oasis, + ) + + reconstructor = choose_reconstructor( + "unet-oasis-acceleration8", + dataset="fastmri", + device="cpu", + ) + + assert isinstance(reconstructor, Marker) + assert captured == {"acceleration": 8, "device": "cpu"} + + +def test_choose_reconstructor_uses_fastmri_unet_by_default(monkeypatch): + marker = object() + + def fake_fastmri(*, device): + assert device == "cpu" + return marker + + monkeypatch.setattr( + "mri_recon.reconstruction.inference.FastMRISinglecoilUnetReconstructor", + fake_fastmri, + ) + + reconstructor = choose_reconstructor( + FASTMRI_UNET_ALGORITHM, + dataset="fastmri", + device="cpu", + ) + + assert reconstructor is marker + + +def test_choose_reconstructor_supports_all_explicit_oasis_algorithms(monkeypatch): + captured = [] + + class Marker: + pass + + def fake_oasis(*, acceleration, device): + captured.append((acceleration, device)) + return Marker() + + monkeypatch.setattr( + "mri_recon.reconstruction.inference.OASISSinglecoilUnetReconstructor", + fake_oasis, + ) + + for algorithm_name in OASIS_UNET_ALGORITHMS: + reconstructor = choose_reconstructor( + algorithm_name, + dataset="fastmri", + device="cpu", + ) + assert isinstance(reconstructor, Marker) + + assert captured == [(4, "cpu"), (8, "cpu"), (10, "cpu")] diff --git a/tests/test_utils_io.py b/tests/test_utils_io.py index 3d2c434..304446d 100644 --- a/tests/test_utils_io.py +++ b/tests/test_utils_io.py @@ -1,6 +1,14 @@ import hashlib from io import BytesIO +import torch + +from mri_recon.distortions import BaseDistortion, DistortedKspaceMultiCoilMRI +from mri_recon.utils.oasis_adapter import ( + fastmri_measurement_to_image, + fastmri_measurement_to_oasis_kspace, + kspace_to_image, +) from mri_recon.utils.io import download_file_with_sha256, download_google_drive_file_with_sha256 @@ -66,3 +74,28 @@ def fake_urlopen(url, timeout=30): assert destination.read_bytes() == payload assert any("confirm=t" in url and "uuid=uuid-456" in url for url in requested_urls) + + +def test_fastmri_measurement_helpers_match_centered_oasis_path(): + x = torch.randn(1, 2, 16, 12) + physics = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), + img_size=(1, 2, *x.shape[-2:]), + device="cpu", + ) + y_fastmri = physics.A(x) + x_native = physics.A_adjoint(y_fastmri) + y_oasis = fastmri_measurement_to_oasis_kspace(y_fastmri, device="cpu") + + assert torch.allclose( + fastmri_measurement_to_image(y_fastmri, device="cpu"), + x_native, + atol=1e-6, + rtol=1e-6, + ) + assert torch.allclose( + kspace_to_image(y_oasis), + x_native, + atol=1e-6, + rtol=1e-6, + )