diff --git a/README.md b/README.md index 091d989..aa7b16e 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Reconstruction playground for the MRI Recon Metrics Reloaded workgroup. | `RAMReconstructor` | `ram` | Deep learning | Wrapper around the DeepInverse RAM model, with input normalization based on the adjoint reconstruction. | | `DeepImagePriorReconstructor` | `dip` | Deep learning | Deep Image Prior reconstruction using an untrained convolutional decoder optimized at inference time. | | `FastMRISinglecoilUnetReconstructor` | `unet` | Deep learning | Wrapper around the pretrained fastMRI single-coil U-Net, returning a magnitude-based reconstruction with a zero imaginary channel. | +| `OASISSinglecoilUnetReconstructor` | `oasis-unet` | Deep learning | Wrapper around a trained OASIS single-coil U-Net checkpoint, reusing the shared fastMRI-derived U-Net module. | ## Implemented Distortions @@ -56,6 +57,22 @@ uv sync uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)" ``` +## Inference Examples + +Run the FastMRI plotting example with local FastMRI k-space files: + +```bash +python examples/fastmri_inference_plot.py --source /path/to/fastmri/singlecoil_val --dataset fastmri --algorithm unet +``` + +Run the same lightweight example on OASIS data. The packaged OASIS split CSV and U-Net checkpoint are downloaded automatically when missing: + +```bash +python examples/fastmri_inference_plot.py --source /path/to/oasis_cross_sectional_data --dataset oasis --algorithm unet +``` + +For OASIS, `--oasis_checkpoint_acceleration` only selects the packaged U-Net weights by their training acceleration. Distortion undersampling is still controlled by `--keep_fraction` and `--center_fraction`. + ## Pre-commit Install the local tooling and register the git hook: diff --git a/examples/fastmri_inference_plot.py b/examples/fastmri_inference_plot.py index e8f23ef..5a28a9d 100644 --- a/examples/fastmri_inference_plot.py +++ b/examples/fastmri_inference_plot.py @@ -17,9 +17,17 @@ from mri_recon.distortions import * from mri_recon.reconstruction import * - -REPORT_DIR = Path("reports") / "fastmri_inference_plot" -REPORT_DIR.mkdir(parents=True, exist_ok=True) +from mri_recon.utils import ( + OasisCenteredFFTPhysics, + OasisSliceDataset, + image_to_kspace, + kspace_to_image, +) + +FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot" +OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot" +FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True) +OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True) ALGORITHMS = [ # "zero-filled", # "conjugate-gradient", @@ -29,7 +37,7 @@ # "wavelet-fista", # "tv-fista", # "tv-pdhg", - "unet", # will trigger download of pretrained weights if not already present + "unet", # dataset-specific U-Net; downloads weights if not already present ] DISTORTIONS = [ "Cartesian undersampling (variable density)", @@ -37,6 +45,7 @@ "Cartesian undersampling (uniform random, zero ACS)", "Cartesian undersampling (equispaced)", "Cartesian undersampling (equispaced, zero ACS)", + "Partial Fourier", "Phase-encode ghosting", "Segmented translation motion", "Segmented rotational motion", @@ -113,6 +122,8 @@ def choose_algorithm( img_size: tuple = (640, 368), device: torch.device = "cpu", verbose: bool = False, + dataset: str = "fastmri", + oasis_checkpoint_acceleration: int = 4, ) -> dinv.models.Reconstructor: match name: case "zero-filled": @@ -132,12 +143,22 @@ def choose_algorithm( case "wavelet-fista": return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose) case "unet": + if dataset == "oasis": + return OASISSinglecoilUnetReconstructor( + acceleration=oasis_checkpoint_acceleration, + device=device, + ) return FastMRISinglecoilUnetReconstructor(device=device) case _: raise ValueError(f"Unknown algorithm {name!r}") -def choose_distortion(name: str) -> BaseDistortion: +def choose_distortion( + name: str, + keep_fraction: float = 0.25, + center_fraction: float = 0.125, + cartesian_axis: int = -2, +) -> BaseDistortion: match name: case "Phase-encode ghosting": return PhaseEncodeGhostingDistortion( @@ -148,39 +169,51 @@ def choose_distortion(name: str) -> BaseDistortion: ) case "Cartesian undersampling (variable density)": return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, + keep_fraction=keep_fraction, + center_fraction=center_fraction, pattern="variable_density_random", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (uniform random)": return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, + keep_fraction=keep_fraction, + center_fraction=center_fraction, pattern="uniform_random", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (uniform random, zero ACS)": return CartesianUndersampling( - keep_fraction=0.5, + keep_fraction=keep_fraction, center_fraction=0.0, pattern="uniform_random", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (equispaced)": return CartesianUndersampling( - keep_fraction=0.25, - center_fraction=0.125, + keep_fraction=keep_fraction, + center_fraction=center_fraction, pattern="equispaced", + axis=cartesian_axis, seed=42, ) case "Cartesian undersampling (equispaced, zero ACS)": return CartesianUndersampling( - keep_fraction=0.5, + keep_fraction=keep_fraction, center_fraction=0.0, pattern="equispaced", + axis=cartesian_axis, seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=center_fraction, + axis=cartesian_axis, + side="high", + ) case "Anisotropic LP": return AnisotropicResolutionReduction( kx_radius_fraction=1.0, @@ -248,10 +281,30 @@ def choose_metric(name: str) -> dinv.metric.Metric: if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) + + # data related arguments parser.add_argument( - "--source", type=str, help="Local FastMRI directory with raw k-space .h5 files." + "--source", + type=Path, + help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.", ) + parser.add_argument("--dataset", choices=("fastmri", "oasis"), default="fastmri") + parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS) + parser.add_argument( + "--keep_fraction", + type=float, + default=0.25, + help="Fraction of k-space lines to keep for undersampling distortions.", + ) + parser.add_argument( + "--center_fraction", + type=float, + default=0.125, + help="Fraction of low-frequency k-space lines to keep fully for undersampling distortions.", + ) + + # algo related arguments parser.add_argument( "--algorithm", type=str, @@ -259,6 +312,18 @@ def choose_metric(name: str) -> dinv.metric.Metric: choices=ALGORITHMS, help="Reconstruction algorithm applied to undistorted and distorted k-space.", ) + parser.add_argument( + "--oasis_checkpoint_acceleration", + type=int, + default=4, + choices=[4, 8, 10], + help=( + "Training acceleration of the packaged OASIS U-Net checkpoint. " + "This only selects pretrained weights; distortion undersampling is " + "controlled by --keep_fraction and --center_fraction." + ), + ) + # inference related arguments parser.add_argument("--num_samples", type=int, default=1, help="How many samples to process.") parser.add_argument( "--verbose", @@ -267,9 +332,20 @@ def choose_metric(name: str) -> dinv.metric.Metric: ) args = parser.parse_args() + # set up report dir + REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR + # set up device, dataset, metrics device = dinv.utils.get_device() - dataset = dinv.datasets.FastMRISliceDataset(args.source, slice_index="middle") + if args.dataset == "oasis": + split_csv = OASISSinglecoilUnetReconstructor.resolve_default_split_csv() + dataset = OasisSliceDataset( + data_path=args.source, + split_csv=split_csv, + sample_rate=0.6, + ) + else: + dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle") metrics = [choose_metric(m) for m in METRICS] for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))): @@ -277,30 +353,45 @@ def choose_metric(name: str) -> dinv.metric.Metric: if i >= args.num_samples: break - # batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan), - # y is kspace, and params is a dict containing mask (if test set) - y = batch[1] + if args.dataset == "oasis": + x = batch["x"].to(device) + y = image_to_kspace(x) + else: + # batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan), + # y is kspace, and params is a dict containing mask (if test set) + y = batch[1].to(device) for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]: - distortion = choose_distortion(distortion_name) + distortion = choose_distortion( + distortion_name, + keep_fraction=args.keep_fraction, + center_fraction=args.center_fraction, + cartesian_axis=-1 if args.dataset == "oasis" else -2, + ) # create physics objects for both clean and distorted k-space # the distortion is applied to the k-space measurements (not the image) # TODO: allow loading multicoil data - physics_clean = DistortedKspaceMultiCoilMRI( - distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device - ) - physics = DistortedKspaceMultiCoilMRI( - distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device - ) - - y = y.to(device) + if args.dataset == "oasis": + physics_clean = OasisCenteredFFTPhysics(BaseDistortion()) + physics = OasisCenteredFFTPhysics(distortion) + else: + physics_clean = DistortedKspaceMultiCoilMRI( + distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device + ) + physics = DistortedKspaceMultiCoilMRI( + distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device + ) y_distorted = distortion.A(y) # generate reference reconstructions (CG) for both clean and distorted k-space # without correction for the distortion, i.e. using physics_clean in both cases - x_clean = ConjugateGradientReconstructor()(y, physics_clean) - x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) + if args.dataset == "oasis": + x_clean = x + x_distorted = kspace_to_image(y_distorted) + else: + x_clean = ConjugateGradientReconstructor()(y, physics_clean) + x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean) # plot and save the k-space magnitude for both clean and distorted k-space save_kspace_plot( @@ -319,6 +410,8 @@ def choose_metric(name: str) -> dinv.metric.Metric: img_size=y.shape[-2:], device=device, verbose=args.verbose, + dataset=args.dataset, + oasis_checkpoint_acceleration=args.oasis_checkpoint_acceleration, ).to(device) # actual reconstruction with the algo being evaluated diff --git a/mri_recon/distortions/__init__.py b/mri_recon/distortions/__init__.py index 80bbad4..af22fd1 100644 --- a/mri_recon/distortions/__init__.py +++ b/mri_recon/distortions/__init__.py @@ -19,4 +19,4 @@ KaiserTaperResolutionReduction, RadialHighPassEmphasisDistortion, ) -from .undersampling import CartesianUndersampling +from .undersampling import CartesianUndersampling, PartialFourierDistortion diff --git a/mri_recon/distortions/biasfield.py b/mri_recon/distortions/biasfield.py index 26ca4f5..a9d7910 100644 --- a/mri_recon/distortions/biasfield.py +++ b/mri_recon/distortions/biasfield.py @@ -80,11 +80,37 @@ class OffCenterAnisotropicGaussianKspaceBiasField(BaseDistortion): The gain peaks at an offset location in k-space, decays with different widths along ``kx`` and ``ky``, and is normalized to unit maximum on the sampled grid. - :param float width_x_fraction: Gaussian width along the normalized ``kx`` direction. - :param float width_y_fraction: Gaussian width along the normalized ``ky`` direction. - :param float center_x_fraction: Center offset along normalized ``kx`` in ``[-1, 1]``. - :param float center_y_fraction: Center offset along normalized ``ky`` in ``[-1, 1]``. - :param float edge_gain: Baseline gain far from the Gaussian peak. Must lie in ``(0, 1]``. + Note: This class can also approximate a readout-decay-like blur when used in a + centered anisotropic configuration. To mimic stronger attenuation along the + readout direction, keep the Gaussian centered at DC with + ``center_x_fraction=0.0`` and ``center_y_fraction=0.0``, choose a narrower + width along the readout axis than along the orthogonal axis, and reduce + ``edge_gain`` below ``1.0``. For example, a setting such as + ``width_x_fraction < width_y_fraction`` with a moderate ``edge_gain`` + produces a smooth directional loss of high-frequency content that can + resemble readout-decay blur. + + This remains a phenomenological approximation rather than an explicit + time-ordered readout-decay model. It applies a smooth multiplicative + k-space weighting, not a physically parameterized echo-train decay. + + :param float width_x_fraction: Gaussian width along the normalized ``kx`` + direction. Smaller values produce stronger attenuation away from the + center along ``kx``. When ``kx`` is the readout axis, choosing + ``width_x_fraction < width_y_fraction`` approximates readout-direction + blur. + :param float width_y_fraction: Gaussian width along the normalized ``ky`` + direction. Larger values preserve more support along ``ky`` relative + to ``kx``. + :param float center_x_fraction: Center offset along normalized ``kx`` in + ``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like + approximation. + :param float center_y_fraction: Center offset along normalized ``ky`` in + ``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like + approximation. + :param float edge_gain: Baseline gain far from the Gaussian peak. Must lie + in ``(0, 1]``. Smaller values strengthen the peripheral attenuation + and therefore the resulting directional blur. """ def __init__( diff --git a/mri_recon/distortions/resolution.py b/mri_recon/distortions/resolution.py index 1011abc..399a037 100644 --- a/mri_recon/distortions/resolution.py +++ b/mri_recon/distortions/resolution.py @@ -73,14 +73,33 @@ def _smooth_radial_low_pass_mask( class IsotropicResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Low-pass truncation with a circular mask. - - This applies - ``M_out(kx, ky) = M(kx, ky) * 1[r(kx, ky) <= K]`` - where ``K`` is ``radius_fraction`` on the normalized radial grid. - - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are set to zero. + """Isotropic in-plane resolution reduction with a circular hard cutoff. + + This distortion keeps only a centered circular region of k-space: + ``M_out(kx, ky) = M(kx, ky) * 1[r(kx, ky) <= K]``, where ``K`` is the + retained radial support on the normalized frequency grid. + + In MRI terms, this models reduced in-plane spatial resolution at fixed + field of view by removing high-frequency content equally in all in-plane + directions. The reconstructed image keeps the same matrix size, but fine + detail is lost because the maximum sampled k-space extent is smaller. The + resulting effect is isotropic blur from limited k-space support. + + The parameter ``radius_fraction`` controls how much of the centered + low-frequency region is retained. Smaller values preserve only the k-space + core and therefore produce stronger blur and a broader point-spread + function. A value of ``1.0`` keeps the full sampled support and recovers + the identity operator. + + Compared with :class:`AnisotropicResolutionReduction`, this class applies + the same reduction in all directions rather than separately along readout + and phase encode. Compared with :class:`CartesianUndersampling`, it models + resolution loss by shrinking the sampled support, not by skipping lines + within the original support. + + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Smaller values correspond to + stronger isotropic resolution reduction. """ def __init__(self, radius_fraction: float = 0.6) -> None: @@ -96,17 +115,40 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class AnisotropicResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Axis-aligned low-pass truncation with independent cutoffs. - - This applies a rectangular mask - ``M_out(kx, ky) = M(kx, ky) * 1[|kx| <= Kx] * 1[|ky| <= Ky]`` - where ``Kx`` and ``Ky`` are the normalized cutoffs along the readout and - phase-encode frequency axes respectively. - - :param float kx_radius_fraction: Normalized cutoff along the horizontal - frequency axis in ``(0, 1]``. - :param float ky_radius_fraction: Normalized cutoff along the vertical - frequency axis in ``(0, 1]``. + """Axis-aligned Cartesian resolution reduction with independent cutoffs. + + This distortion keeps only a centered rectangular region of Cartesian + k-space and zeros the remaining outer frequencies: + ``M_out(kx, ky) = M(kx, ky) * 1[|kx| <= Kx] * 1[|ky| <= Ky]``. + + In MRI terms, this models reduced in-plane acquisition resolution at fixed + field of view. The reconstructed image keeps the same matrix size, but its + effective spatial resolution decreases because less high-frequency k-space + support is retained. The resulting artifact is blur from limited k-space + extent, not aliasing from undersampling. + + The two parameters control the retained centered k-space support along the + Cartesian encoding axes. ``kx_radius_fraction`` corresponds to the retained + readout-direction frequency extent, while ``ky_radius_fraction`` + corresponds to the retained phase-encode-direction frequency extent. + Reducing either parameter broadens the point-spread function along the + corresponding image direction. A typical MRI-like setting keeps + ``kx_radius_fraction`` close to ``1.0`` and reduces + ``ky_radius_fraction``, reflecting that protocols often sacrifice more + phase-encode resolution than readout resolution. + + This is a hard rectangular cutoff. If a softer edge is desired, use one of + the taper-based resolution distortions instead. + + :param float kx_radius_fraction: Fraction of the original centered + k-space extent retained along the horizontal frequency axis in + ``(0, 1]``. This corresponds to the retained readout-direction + resolution support. ``1.0`` keeps the full sampled readout extent. + :param float ky_radius_fraction: Fraction of the original centered + k-space extent retained along the vertical frequency axis in + ``(0, 1]``. This corresponds to the retained phase-encode-direction + resolution support. Smaller values produce stronger blur along the + corresponding image direction. """ def __init__( @@ -126,8 +168,8 @@ def __init__( def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: normalized_kx, normalized_ky = _normalized_axis_frequencies(shape) - # A rectangular passband models direction-dependent resolution loss, - # such as stronger truncation along phase encode than along readout. + # Centered rectangular support models reduced acquired Cartesian + # resolution, often stronger along phase encode than along readout. mask = (normalized_kx <= self.kx_radius_fraction) & ( normalized_ky <= self.ky_radius_fraction ) @@ -135,18 +177,41 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class HannTaperResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Radial low-pass reduction with a Hann transition band. - - The mask equals ``1`` in the low-frequency passband, tapers smoothly to - ``0`` with a raised-cosine profile near the cutoff, and is exactly ``0`` - beyond ``radius_fraction``. + """Isotropic resolution reduction with a Hann-tapered radial cutoff. + + This distortion keeps a centered circular low-frequency region and tapers + the mask smoothly to zero near the cutoff using a raised-cosine Hann + profile. Inside the passband the mask equals ``1``; inside the transition + band it decreases smoothly from ``1`` to ``0``; beyond the cutoff it is + exactly ``0``. + + In MRI terms, this is still a resolution-reduction operator: it suppresses + high spatial frequencies and therefore lowers effective in-plane spatial + resolution at fixed field of view. Relative to + :class:`IsotropicResolutionReduction`, the main difference is not the type + of resolution loss but the edge behavior of the k-space support. The smooth + transition reduces ringing that a hard truncation can introduce, at the + cost of making the support edge less sharp. + + ``radius_fraction`` sets the outer radial extent of the retained support. + ``transition_fraction`` sets how much of that outer region is devoted to + the smooth taper. Setting ``transition_fraction=0`` recovers the hard + circular cutoff. Larger transition fractions make the cutoff gentler and + behave more like k-space apodization. + + Compared with :class:`CartesianUndersampling`, this class does not skip + phase-encode lines within the original support. It reduces resolution by + attenuating and removing high frequencies, leading primarily to blur rather + than undersampling aliasing. See https://en.wikipedia.org/wiki/Hann_function for details on the Hann window. - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are fully suppressed. + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Frequencies outside this radius + are fully suppressed. :param float transition_fraction: Fraction of the cutoff radius occupied by - the smooth transition in ``[0, 1]``. ``0`` recovers the hard cutoff. + the smooth Hann transition in ``[0, 1]``. ``0`` recovers the hard + circular cutoff. """ def __init__( @@ -173,20 +238,42 @@ def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: class KaiserTaperResolutionReduction(SelfAdjointMultiplicativeMaskDistortion): - """Radial low-pass reduction with a Kaiser transition band. - - The mask equals ``1`` in the low-frequency passband, tapers smoothly to - ``0`` with a Kaiser-profile transition near the cutoff, and is exactly - ``0`` beyond ``radius_fraction``. + """Isotropic resolution reduction with a Kaiser-tapered radial cutoff. + + This distortion keeps a centered circular low-frequency region and tapers + the mask smoothly to zero near the cutoff using a Kaiser-window profile. + Inside the passband the mask equals ``1``; inside the transition band it + decreases smoothly from ``1`` to ``0``; beyond the cutoff it is exactly + ``0``. + + In MRI terms, this lowers effective in-plane spatial resolution at fixed + field of view by reducing the retained high-frequency k-space extent. As + with :class:`HannTaperResolutionReduction`, the purpose of the taper is to + soften the hard support edge and reduce ringing, while still producing blur + from lost high-frequency information. + + ``radius_fraction`` sets the outer radial extent of the retained support. + ``transition_fraction`` controls the width of the taper band. ``beta`` + controls the shape of the Kaiser taper within that band: larger values + produce a steeper transition, while smaller positive values produce a more + gradual roll-off. Setting ``transition_fraction=0`` recovers the hard + circular cutoff regardless of ``beta``. + + Compared with :class:`CartesianUndersampling`, this class reduces + resolution by shrinking and tapering the effective k-space support, rather + than by skipping lines from the original grid. The dominant image effect is + blur and apodization, not aliasing from sub-Nyquist sampling. See https://en.wikipedia.org/wiki/Kaiser_window for details on the Kaiser window. - :param float radius_fraction: Normalized cutoff radius in ``(0, 1]``. - Frequencies outside this radius are fully suppressed. + :param float radius_fraction: Fraction of the original centered radial + k-space support retained in ``(0, 1]``. Frequencies outside this radius + are fully suppressed. :param float transition_fraction: Fraction of the cutoff radius occupied by - the smooth transition in ``[0, 1]``. ``0`` recovers the hard cutoff. - :param float beta: Positive Kaiser shape parameter. Larger values create a - steeper transition inside the taper band. + the smooth Kaiser transition in ``[0, 1]``. ``0`` recovers the hard + circular cutoff. + :param float beta: Positive Kaiser shape parameter controlling how steeply + the taper falls inside the transition band. """ def __init__( diff --git a/mri_recon/distortions/undersampling.py b/mri_recon/distortions/undersampling.py index 4a41674..6c0be94 100644 --- a/mri_recon/distortions/undersampling.py +++ b/mri_recon/distortions/undersampling.py @@ -1,274 +1,415 @@ -"""Cartesian k-space undersampling distortions.""" - -from __future__ import annotations - -import torch - -from mri_recon.distortions.base import SelfAdjointMultiplicativeMaskDistortion - - -PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} - -# Lorentzian offset that controls how steeply the variable-density weights -# fall off with normalized distance from k-space center. -# Smaller values → steeper density gradient; 0.05 gives ~20x weight ratio -# between the ACS boundary and the k-space edge. -_VD_WEIGHT_OFFSET: float = 0.05 - - -class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): - """Cartesian k-space undersampling along phase-encode direction. - - This distortion simulates true MRI acquisition undersampling by applying a - binary sampling mask along the phase-encode direction (default). The mask - keeps a contiguous center region (ACS - Auto-Calibration Signal) fully - sampled and randomly undersamples the periphery. - - The mask is deterministic given a shape and seed, ensuring reproducibility. - - :param float keep_fraction: Fraction of phase-encode lines to keep in - ``(0, 1]``. For example, 0.25 keeps 25% of phase-encode lines. - :param float center_fraction: Fraction of phase-encode lines reserved for - the contiguous, fully-sampled ACS region in ``[0, 1]``. Defaults to - ``0.5 * keep_fraction``, leaving part of the acquisition budget for - randomized peripheral line sampling. Set to ``0`` to sample without a - guaranteed ACS block. - :param str pattern: Peripheral sampling pattern. Supported values are - ``"uniform_random"``, ``"variable_density_random"``, and - ``"equispaced"``. Defaults to ``"variable_density_random"``. - :param int axis: Axis along which to apply undersampling. Default is -2 - (phase-encode for 4D tensors), can also be -3 for 5D tensors. - :param int | None seed: Random seed for reproducible mask generation. - If None, uses unseeded randomness (not recommended for reproducibility). - Has no effect when ``pattern="equispaced"`` because that pattern is - fully deterministic and uses no randomness. - """ - - def __init__( - self, - keep_fraction: float = 0.25, - center_fraction: float | None = None, - pattern: str = "variable_density_random", - axis: int = -2, - seed: int | None = None, - ) -> None: - super().__init__() - - if not 0.0 < keep_fraction <= 1.0: - raise ValueError(f"keep_fraction must be in (0, 1], got {keep_fraction}") - - if axis not in (-2, -3): - raise ValueError(f"axis must be -2 or -3, got {axis}") - - if pattern not in PATTERNS: - raise ValueError(f"pattern must be one of {sorted(PATTERNS)}, got {pattern!r}") - - if center_fraction is None: - # Reserve half of the kept lines for a contiguous ACS block and - # leave the remainder for peripheral sampling. - center_fraction = 0.5 * keep_fraction - elif not 0.0 <= center_fraction <= 1.0: - raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") - - if center_fraction > keep_fraction: - raise ValueError( - f"center_fraction ({center_fraction}) must not exceed " - f"keep_fraction ({keep_fraction})" - ) - - self.keep_fraction = keep_fraction - self.center_fraction = center_fraction - self.pattern = pattern - self.axis = axis - self.seed = seed - self._cached_mask = None - self._cached_shape = None - self._cached_device: torch.device | None = None - - def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: - """Generate a binary Cartesian undersampling mask. - - The mask is applied along the specified axis (phase-encode by default). - It keeps a contiguous center region fully sampled when requested and - randomly undersamples the periphery to achieve the desired keep_fraction. - - :param tuple[int, ...] shape: k-space tensor shape. - :param torch.device device: Device for the mask. - :returns: Binary mask broadcastable to shape. - :rtype: torch.Tensor - """ - # Cache the mask keyed on both shape and device so that repeated GPU - # forward passes do not perform a CPU→GPU copy on every call. - if ( - self._cached_mask is not None - and self._cached_shape == shape - and self._cached_device == device - ): - return self._cached_mask - - # Get the size along the undersampling axis - axis_size = shape[self.axis] - - # Set random seed for reproducibility - rng_state = None - if self.seed is not None: - rng_state = torch.get_rng_state() - torch.manual_seed(self.seed) - - try: - # Generate the 1D mask along the specified axis - mask_1d = self._generate_1d_mask(axis_size) - - # Expand the 1D mask to match the full shape - mask = self._expand_mask_to_shape(mask_1d, shape) - finally: - # Restore random state if we set a seed - if self.seed is not None and rng_state is not None: - torch.set_rng_state(rng_state) - - # Move to the target device and cache, including the device. - mask = mask.to(device) - self._cached_mask = mask - self._cached_shape = shape - self._cached_device = device - - return mask - - def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: - """Generate a 1D binary mask along the undersampling axis. - - :param int axis_size: Size of the axis along which to undersample. - :returns: 1D binary mask of shape (axis_size,). - :rtype: torch.Tensor - """ - # Calculate number of lines to keep and center region size. - num_keep = max(1, int(round(axis_size * self.keep_fraction))) - num_center = int(round(axis_size * self.center_fraction)) - - # Ensure center is not larger than total kept lines - num_center = min(num_center, num_keep) - - # Initialize mask (all zeros) - mask = torch.zeros(axis_size, dtype=torch.float32) - - # Keep the contiguous center region (ACS) when requested. - center_start = (axis_size - num_center) // 2 - center_end = center_start + num_center - if num_center > 0: - mask[center_start:center_end] = 1.0 - - peripheral_indices = self._peripheral_indices(center_start, center_end, axis_size) - - # Sample from the peripheral region with the requested pattern. - if num_keep > num_center: - num_peripheral = num_keep - num_center - selected_indices = self._select_peripheral_indices( - peripheral_indices=peripheral_indices, - num_peripheral=num_peripheral, - axis_size=axis_size, - ) - mask[selected_indices] = 1.0 - - return mask - - def _peripheral_indices( - self, - center_start: int, - center_end: int, - axis_size: int, - ) -> torch.Tensor: - """Return line indices outside the contiguous ACS region.""" - - return torch.cat( - [ - torch.arange(0, center_start, dtype=torch.long), - torch.arange(center_end, axis_size, dtype=torch.long), - ] - ) - - def _select_peripheral_indices( - self, - peripheral_indices: torch.Tensor, - num_peripheral: int, - axis_size: int, - ) -> torch.Tensor: - """Select peripheral lines according to the configured sampling pattern.""" - - if self.pattern == "uniform_random": - permutation = torch.randperm(len(peripheral_indices))[:num_peripheral] - return peripheral_indices[permutation] - - if self.pattern == "variable_density_random": - center = 0.5 * (axis_size - 1) - distances = torch.abs(peripheral_indices.to(torch.float32) - center) - normalized_distances = distances / max(center, 1.0) - - # Favor low-frequency lines near the ACS boundary while preserving - # non-zero probability across the periphery. - weights = 1.0 / (_VD_WEIGHT_OFFSET + normalized_distances.square()) - selected_positions = torch.multinomial( - weights, - num_samples=num_peripheral, - replacement=False, - ) - return peripheral_indices[selected_positions] - - if self.pattern == "equispaced": - return self._select_equispaced_indices(peripheral_indices, num_peripheral) - - raise RuntimeError(f"Unsupported pattern {self.pattern!r}") - - def _select_equispaced_indices( - self, - peripheral_indices: torch.Tensor, - num_peripheral: int, - ) -> torch.Tensor: - """Select evenly spaced peripheral indices. - - Lines are spaced uniformly within the peripheral index array. - Because the peripheral array has a gap at the ACS region, the spacing - between selected k-space lines will be approximately doubled near the - ACS boundary compared to the spacing in the outer periphery. - This pattern is fully deterministic and unaffected by ``seed``. - """ - - if num_peripheral >= len(peripheral_indices): - return peripheral_indices - - # The early-return above guarantees step > 1.0, which in turn means - # floor((idx + 0.5) * step) is strictly increasing, so no collision - # resolution is needed. - step = len(peripheral_indices) / num_peripheral - positions = ( - (torch.arange(num_peripheral, dtype=torch.float32) * step + 0.5 * step).floor().long() - ) - - return peripheral_indices[positions] - - def _expand_mask_to_shape( - self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] - ) -> torch.Tensor: - """Expand a 1D mask to the full k-space shape. - - The 1D mask is applied along the specified axis and broadcast to all - other dimensions. - - :param torch.Tensor mask_1d: 1D binary mask. - :param tuple[int, ...] target_shape: Target shape for expansion. - :returns: Mask broadcastable to target_shape. - :rtype: torch.Tensor - """ - # Convert negative axis to positive - ndim = len(target_shape) - axis = self.axis if self.axis >= 0 else ndim + self.axis - - # Create the full shape with 1s for broadcast dimensions - expand_shape = list(target_shape) - for i in range(len(expand_shape)): - if i != axis: - expand_shape[i] = 1 - - # Reshape the 1D mask to the expand shape - mask = mask_1d.reshape(expand_shape) - - return mask +"""Cartesian k-space undersampling distortions.""" + +from __future__ import annotations + +import torch + +from mri_recon.distortions.base import SelfAdjointMultiplicativeMaskDistortion + + +PATTERNS = {"uniform_random", "variable_density_random", "equispaced"} +PARTIAL_FOURIER_SIDES = {"low", "high"} + +# Lorentzian offset that controls how steeply the variable-density weights +# fall off with normalized distance from k-space center. +# Smaller values → steeper density gradient; 0.05 gives ~20x weight ratio +# between the ACS boundary and the k-space edge. +_VD_WEIGHT_OFFSET: float = 0.05 + + +class CartesianUndersampling(SelfAdjointMultiplicativeMaskDistortion): + """Cartesian k-space undersampling along one encoding direction. + + This distortion simulates sub-Nyquist MRI acquisition by keeping only a + subset of Cartesian k-space lines along a chosen axis, phase encode by + default. A contiguous low-frequency center region (ACS) may be retained + fully, while the peripheral lines are sampled using a configurable random + or equispaced pattern. + + This is fundamentally different from resolution reduction. In + resolution-reduction distortions, the maximum retained k-space extent is + reduced, which primarily causes blur because high spatial frequencies are + absent. In Cartesian undersampling, the original k-space extent is still + targeted, but many lines inside that extent are skipped. The dominant + consequence is aliasing or incoherent undersampling artifact, not a simple + broader point-spread function. + + ``keep_fraction`` controls the total sampling budget along the chosen axis. + ``center_fraction`` reserves part of that budget for a fully sampled ACS + block near k-space center, which is important for many parallel imaging and + learned reconstruction pipelines. ``pattern`` controls how the remaining + peripheral lines are selected. ``axis`` determines which Cartesian encoding + direction is undersampled. ``seed`` makes the random patterns reproducible. + + The mask is deterministic for a given shape, device, and seed. + + :param float keep_fraction: Fraction of lines kept along the undersampled + axis in ``(0, 1]``. For example, ``0.25`` keeps 25% of the lines and + corresponds to approximately 4x acceleration when applied to a single + encoding direction. + :param float center_fraction: Fraction of lines along the undersampled axis + reserved for a contiguous, fully sampled ACS region in ``[0, 1]``. + Defaults to ``0.5 * keep_fraction`` so that part of the sampling budget + remains available for peripheral sampling. ``0`` disables the ACS block. + :param str pattern: Peripheral sampling pattern. Supported values are + ``"uniform_random"``, ``"variable_density_random"``, and + ``"equispaced"``. Variable-density sampling favors low-frequency lines + near the ACS boundary; equispaced sampling is deterministic. + :param int axis: Axis along which to undersample. The default ``-2`` + corresponds to phase encode for the repository's standard 2D k-space + convention. Other values allow undersampling along readout or depth. + :param int | None seed: Random seed for reproducible mask generation. + Ignored by the deterministic ``"equispaced"`` pattern. + """ + + def __init__( + self, + keep_fraction: float = 0.25, + center_fraction: float | None = None, + pattern: str = "variable_density_random", + axis: int = -2, + seed: int | None = None, + ) -> None: + super().__init__() + + if not 0.0 < keep_fraction <= 1.0: + raise ValueError(f"keep_fraction must be in (0, 1], got {keep_fraction}") + + if axis not in (-1, -2, -3): + raise ValueError(f"axis must be -1, -2, or -3, got {axis}") + + if pattern not in PATTERNS: + raise ValueError(f"pattern must be one of {sorted(PATTERNS)}, got {pattern!r}") + + if center_fraction is None: + # Reserve half of the kept lines for a contiguous ACS block and + # leave the remainder for peripheral sampling. + center_fraction = 0.5 * keep_fraction + elif not 0.0 <= center_fraction <= 1.0: + raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") + + if center_fraction > keep_fraction: + raise ValueError( + f"center_fraction ({center_fraction}) must not exceed " + f"keep_fraction ({keep_fraction})" + ) + + self.keep_fraction = keep_fraction + self.center_fraction = center_fraction + self.pattern = pattern + self.axis = axis + self.seed = seed + self._cached_mask = None + self._cached_shape = None + self._cached_device: torch.device | None = None + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + """Generate a binary Cartesian undersampling mask. + + The mask is applied along the specified axis (phase-encode by default). + It keeps a contiguous center region fully sampled when requested and + randomly undersamples the periphery to achieve the desired keep_fraction. + + :param tuple[int, ...] shape: k-space tensor shape. + :param torch.device device: Device for the mask. + :returns: Binary mask broadcastable to shape. + :rtype: torch.Tensor + """ + # Cache the mask keyed on both shape and device so that repeated GPU + # forward passes do not perform a CPU→GPU copy on every call. + if ( + self._cached_mask is not None + and self._cached_shape == shape + and self._cached_device == device + ): + return self._cached_mask + + # Get the size along the undersampling axis + axis_size = shape[self.axis] + + # Set random seed for reproducibility + rng_state = None + if self.seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(self.seed) + + try: + # Generate the 1D mask along the specified axis + mask_1d = self._generate_1d_mask(axis_size) + + # Expand the 1D mask to match the full shape + mask = self._expand_mask_to_shape(mask_1d, shape) + finally: + # Restore random state if we set a seed + if self.seed is not None and rng_state is not None: + torch.set_rng_state(rng_state) + + # Move to the target device and cache, including the device. + mask = mask.to(device) + self._cached_mask = mask + self._cached_shape = shape + self._cached_device = device + + return mask + + def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: + """Generate a 1D binary mask along the undersampling axis. + + :param int axis_size: Size of the axis along which to undersample. + :returns: 1D binary mask of shape (axis_size,). + :rtype: torch.Tensor + """ + # Calculate number of lines to keep and center region size. + num_keep = max(1, int(round(axis_size * self.keep_fraction))) + num_center = int(round(axis_size * self.center_fraction)) + + # Ensure center is not larger than total kept lines + num_center = min(num_center, num_keep) + + # Initialize mask (all zeros) + mask = torch.zeros(axis_size, dtype=torch.float32) + + # Keep the contiguous center region (ACS) when requested. + center_start = (axis_size - num_center) // 2 + center_end = center_start + num_center + if num_center > 0: + mask[center_start:center_end] = 1.0 + + peripheral_indices = self._peripheral_indices(center_start, center_end, axis_size) + + # Sample from the peripheral region with the requested pattern. + if num_keep > num_center: + num_peripheral = num_keep - num_center + selected_indices = self._select_peripheral_indices( + peripheral_indices=peripheral_indices, + num_peripheral=num_peripheral, + axis_size=axis_size, + ) + mask[selected_indices] = 1.0 + + return mask + + def _peripheral_indices( + self, + center_start: int, + center_end: int, + axis_size: int, + ) -> torch.Tensor: + """Return line indices outside the contiguous ACS region.""" + + return torch.cat( + [ + torch.arange(0, center_start, dtype=torch.long), + torch.arange(center_end, axis_size, dtype=torch.long), + ] + ) + + def _select_peripheral_indices( + self, + peripheral_indices: torch.Tensor, + num_peripheral: int, + axis_size: int, + ) -> torch.Tensor: + """Select peripheral lines according to the configured sampling pattern.""" + + if self.pattern == "uniform_random": + permutation = torch.randperm(len(peripheral_indices))[:num_peripheral] + return peripheral_indices[permutation] + + if self.pattern == "variable_density_random": + center = 0.5 * (axis_size - 1) + distances = torch.abs(peripheral_indices.to(torch.float32) - center) + normalized_distances = distances / max(center, 1.0) + + # Favor low-frequency lines near the ACS boundary while preserving + # non-zero probability across the periphery. + weights = 1.0 / (_VD_WEIGHT_OFFSET + normalized_distances.square()) + selected_positions = torch.multinomial( + weights, + num_samples=num_peripheral, + replacement=False, + ) + return peripheral_indices[selected_positions] + + if self.pattern == "equispaced": + return self._select_equispaced_indices(peripheral_indices, num_peripheral) + + raise RuntimeError(f"Unsupported pattern {self.pattern!r}") + + def _select_equispaced_indices( + self, + peripheral_indices: torch.Tensor, + num_peripheral: int, + ) -> torch.Tensor: + """Select evenly spaced peripheral indices. + + Lines are spaced uniformly within the peripheral index array. + Because the peripheral array has a gap at the ACS region, the spacing + between selected k-space lines will be approximately doubled near the + ACS boundary compared to the spacing in the outer periphery. + This pattern is fully deterministic and unaffected by ``seed``. + """ + + if num_peripheral >= len(peripheral_indices): + return peripheral_indices + + # The early-return above guarantees step > 1.0, which in turn means + # floor((idx + 0.5) * step) is strictly increasing, so no collision + # resolution is needed. + step = len(peripheral_indices) / num_peripheral + positions = ( + (torch.arange(num_peripheral, dtype=torch.float32) * step + 0.5 * step).floor().long() + ) + + return peripheral_indices[positions] + + def _expand_mask_to_shape( + self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] + ) -> torch.Tensor: + """Expand a 1D mask to the full k-space shape. + + The 1D mask is applied along the specified axis and broadcast to all + other dimensions. + + :param torch.Tensor mask_1d: 1D binary mask. + :param tuple[int, ...] target_shape: Target shape for expansion. + :returns: Mask broadcastable to target_shape. + :rtype: torch.Tensor + """ + # Convert negative axis to positive + ndim = len(target_shape) + axis = self.axis if self.axis >= 0 else ndim + self.axis + + # Create the full shape with 1s for broadcast dimensions + expand_shape = list(target_shape) + for i in range(len(expand_shape)): + if i != axis: + expand_shape[i] = 1 + + # Reshape the 1D mask to the expand shape + mask = mask_1d.reshape(expand_shape) + + return mask + + +class PartialFourierDistortion(SelfAdjointMultiplicativeMaskDistortion): + """Asymmetric contiguous Cartesian mask for partial Fourier acquisition. + + This distortion simulates partial Fourier MRI acquisition by keeping a + contiguous asymmetric region of k-space along one encoding axis while + preserving a centered low-frequency block. Unlike symmetric resolution + reduction, the retained support extends farther on one side of k-space than + the other. Unlike Cartesian undersampling, the retained region is + contiguous rather than sparse throughout the original support. + + The distortion models the acquired k-space mask only. It does not attempt + to reconstruct or infer the missing region with homodyne, POCS, or any + other partial-Fourier-specific reconstruction method. + + :param float partial_fraction: Fraction of lines retained along the chosen + axis in ``[0.5, 1]``. ``1.0`` recovers the identity operator. + :param float center_fraction: Fraction of lines reserved for a centered, + fully retained low-frequency block in ``[0, 1]``. This block must not + exceed ``partial_fraction``. + :param int axis: Axis along which to apply the asymmetric truncation. The + default ``-2`` matches the repository's standard phase-encode axis. + :param str side: Side that retains more support outside the centered block. + Supported values are ``"low"`` and ``"high"``. + """ + + def __init__( + self, + partial_fraction: float = 0.7, + center_fraction: float = 0.1, + axis: int = -2, + side: str = "high", + ) -> None: + super().__init__() + + if not 0.5 <= partial_fraction <= 1.0: + raise ValueError(f"partial_fraction must be in [0.5, 1], got {partial_fraction}") + if not 0.0 <= center_fraction <= 1.0: + raise ValueError(f"center_fraction must be in [0, 1], got {center_fraction}") + if center_fraction > partial_fraction: + raise ValueError( + f"center_fraction ({center_fraction}) must not exceed " + f"partial_fraction ({partial_fraction})" + ) + if axis not in (-1, -2, -3): + raise ValueError(f"axis must be -1, -2, or -3, got {axis}") + if side not in PARTIAL_FOURIER_SIDES: + raise ValueError(f"side must be one of {sorted(PARTIAL_FOURIER_SIDES)}, got {side!r}") + + self.partial_fraction = partial_fraction + self.center_fraction = center_fraction + self.axis = axis + self.side = side + self._cached_mask = None + self._cached_shape = None + self._cached_device: torch.device | None = None + + def _mask(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor: + """Generate a deterministic partial Fourier mask. + + :param tuple[int, ...] shape: k-space tensor shape. + :param torch.device device: Device for the mask. + :returns: Binary mask broadcastable to ``shape``. + :rtype: torch.Tensor + """ + if ( + self._cached_mask is not None + and self._cached_shape == shape + and self._cached_device == device + ): + return self._cached_mask + + axis_size = shape[self.axis] + mask_1d = self._generate_1d_mask(axis_size) + mask = self._expand_mask_to_shape(mask_1d, shape).to(device) + + self._cached_mask = mask + self._cached_shape = shape + self._cached_device = device + return mask + + def _generate_1d_mask(self, axis_size: int) -> torch.Tensor: + """Generate a 1D contiguous asymmetric partial Fourier mask.""" + num_keep = max(1, int(round(axis_size * self.partial_fraction))) + num_center = int(round(axis_size * self.center_fraction)) + num_center = min(num_center, num_keep) + + mask = torch.zeros(axis_size, dtype=torch.float32) + + center_start = (axis_size - num_center) // 2 + center_end = center_start + num_center + remaining = num_keep - num_center + + low_available = center_start + high_available = axis_size - center_end + + if self.side == "high": + extra_high = min(remaining, high_available) + extra_low = remaining - extra_high + else: + extra_low = min(remaining, low_available) + extra_high = remaining - extra_low + + start = center_start - extra_low + end = center_end + extra_high + + mask[start:end] = 1.0 + return mask + + def _expand_mask_to_shape( + self, mask_1d: torch.Tensor, target_shape: tuple[int, ...] + ) -> torch.Tensor: + """Expand a 1D mask to the full k-space shape.""" + ndim = len(target_shape) + axis = self.axis if self.axis >= 0 else ndim + self.axis + + expand_shape = list(target_shape) + for i in range(len(expand_shape)): + if i != axis: + expand_shape[i] = 1 + + return mask_1d.reshape(expand_shape) diff --git a/mri_recon/reconstruction/__init__.py b/mri_recon/reconstruction/__init__.py index 07dec7d..eafddf3 100644 --- a/mri_recon/reconstruction/__init__.py +++ b/mri_recon/reconstruction/__init__.py @@ -1,9 +1,14 @@ -from .deep import RAMReconstructor, DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor -from .classic import ( - ZeroFilledReconstructor, - ConjugateGradientReconstructor, - TVPGDReconstructor, - WaveletFISTAReconstructor, - TVFISTAReconstructor, - TVPDHGReconstructor, -) +from .deep import ( + RAMReconstructor, + DeepImagePriorReconstructor, + FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, +) +from .classic import ( + ZeroFilledReconstructor, + ConjugateGradientReconstructor, + TVPGDReconstructor, + WaveletFISTAReconstructor, + TVFISTAReconstructor, + TVPDHGReconstructor, +) diff --git a/mri_recon/reconstruction/deep.py b/mri_recon/reconstruction/deep.py index 5d507c4..18ed4bd 100644 --- a/mri_recon/reconstruction/deep.py +++ b/mri_recon/reconstruction/deep.py @@ -1,10 +1,16 @@ +import json from pathlib import Path +from typing import Optional import deepinv as dinv import torch from ._fastmri_unet import Unet -from ..utils import download_file_with_sha256, matches_sha256 +from ..utils import ( + download_file_with_sha256, + download_google_drive_file_with_sha256, + matches_sha256, +) class RAMReconstructor(dinv.models.Reconstructor): @@ -108,6 +114,7 @@ class FastMRISinglecoilUnetReconstructor(dinv.models.Reconstructor): ) MODEL_SHA256 = "8f41f67d8eab2cca31ffff632a733a8712b1171c11f13e95b6f90fdf63399f9e" MODEL_FILENAME = "knee_sc_leaderboard_state_dict.pt" + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "fastmri_singlecoil_unet" UNET_KWARGS = { "in_chans": 1, "out_chans": 1, @@ -128,7 +135,7 @@ def __init__(self, device: torch.device = None, state_dict_file: str = None) -> state_dict_path = ( Path(state_dict_file).expanduser() if state_dict_file is not None - else Path(__file__).resolve().parents[2] / self.MODEL_FILENAME + else self.MODEL_DIR / self.MODEL_FILENAME ) if state_dict_file is None: @@ -163,3 +170,220 @@ def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tenso out = self.model(x_in) * std + mu # (B, 1, H, W) return torch.cat([out, torch.zeros_like(out)], dim=1) + + +def _load_unet_checkpoint_state( + checkpoint_path: Path, + device: torch.device, +) -> dict[str, torch.Tensor]: + """Load U-Net weights from a plain or Lightning-style checkpoint.""" + + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + # Accept either a checkpoint with "state_dict" or a plain state_dict + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + + # If Lightning-style keys like 'unet.*' exist, strip the 'unet.' prefix + if any(key.startswith("unet.") for key in state_dict): + return { + key[len("unet.") :]: value + for key, value in state_dict.items() + if key.startswith("unet.") + } + + return state_dict + + +class OASISSinglecoilUnetReconstructor(dinv.models.Reconstructor): + """Wrapper for a trained OASIS single-coil U-Net model. + + The model reuses the repository's fastMRI-derived :class:`Unet` module, but + can also download the packaged checkpoint manifest and checkpoint on demand + when no explicit checkpoint path is supplied. The forward pass converts + k-space to a zero-filled magnitude image, applies per-slice instance + normalization, runs the U-Net, then rescales the prediction back to the + adjoint-image intensity range. + + Parameters + ---------- + checkpoint_file : str, optional + Path to the trained OASIS U-Net checkpoint. If omitted, the reconstructor + downloads the packaged checkpoint for ``acceleration``. + acceleration : int, optional + Training acceleration of the packaged checkpoint used when ``checkpoint_file`` + is omitted. This selects pretrained weights only; it does not configure the + measurement distortion used during inference. + manifest_path : str, optional + Override path for the downloaded or cached packaged checkpoint manifest. + device : torch.device, optional + Device on which to run inference. + """ + + UNET_KWARGS = { + "in_chans": 1, + "out_chans": 1, + "chans": 32, + "num_pool_layers": 4, + "drop_prob": 0.0, + } + MODEL_DIR = Path(__file__).resolve().parents[2] / "downloads" / "oasis_singlecoil_unet" + CHECKPOINTS_DIR = MODEL_DIR / "checkpoints" + SPLITS_DIR = MODEL_DIR / "splits" + MANIFEST_PATH = CHECKPOINTS_DIR / "manifest.json" + SPLIT_CSV_PATH = SPLITS_DIR / "oasis_balanced_test.csv" + MANIFEST_FILE_ID = "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I" + MANIFEST_SHA256 = "d5180c49fcaafe7ba439319dcf4afe4d7489473bea437418d836070ecd506952" + SPLIT_CSV_FILE_ID = "16UoZ6sYzOwADv4KLRR9m0kjueXoxENrD" + SPLIT_CSV_SHA256 = "8627cf9781e6d5f94c2a0f08a7a75b386fb9b737cdce94fb1558f95fa2e62dd5" + CHECKPOINT_FILE_IDS = { + "4": "11s6YeM6_YJeD4wcrn24jyMjyj_vX2ANU", + "8": "1w8PDiYpr2xBPXahzRllhZjQT1yoMGXg-", + "10": "1djJ2i0uYP4PT070CS0xx9nNJ41JmSFhh", + } + CHECKPOINT_SHA256 = { + "4": "4fcefa9860cb7895e581a0de8f90bd7f188ae1c0b5e428a4a07519dd2561ac29", + "8": "2cd4c44e3c7a3870adbe5090b2bfaae044f5e3f0b4bcaf2b1fc29969e5e6b9ca", + "10": "90e3d9b17aa0f9fd43aaf090c152edcbdead1b9be41a076594e41098db7befa8", + } + + @classmethod + def ensure_manifest(cls, manifest_path: Optional[Path] = None) -> Path: + """Ensure the packaged OASIS checkpoint manifest exists locally and is verified.""" + + resolved_manifest_path = ( + manifest_path.expanduser().resolve() if manifest_path is not None else cls.MANIFEST_PATH + ) + if not matches_sha256(resolved_manifest_path, cls.MANIFEST_SHA256): + download_google_drive_file_with_sha256( + cls.MANIFEST_FILE_ID, + resolved_manifest_path, + cls.MANIFEST_SHA256, + label="OASIS checkpoint manifest", + ) + return resolved_manifest_path + + @classmethod + def resolve_default_split_csv(cls) -> Path: + """Resolve and download the packaged OASIS split CSV.""" + + resolved_split_csv_path = cls.SPLIT_CSV_PATH.resolve() + if matches_sha256(resolved_split_csv_path, cls.SPLIT_CSV_SHA256): + return resolved_split_csv_path + + download_google_drive_file_with_sha256( + cls.SPLIT_CSV_FILE_ID, + resolved_split_csv_path, + cls.SPLIT_CSV_SHA256, + label="OASIS split CSV", + ) + return resolved_split_csv_path + + @classmethod + def resolve_default_checkpoint( + cls, + acceleration: int, + manifest_path: Optional[Path] = None, + ) -> Path: + """Resolve and download the OASIS checkpoint for a training acceleration.""" + + resolved_manifest_path = cls.ensure_manifest(manifest_path) + with resolved_manifest_path.open("r", encoding="utf-8") as handle: + manifest = json.load(handle) + + key = str(acceleration) + checkpoints = manifest.get("checkpoints", {}) + if key not in checkpoints: + available = ", ".join(sorted(checkpoints)) + raise ValueError( + f"No packaged checkpoint for training acceleration {acceleration}. " + f"Available: {available}." + ) + + if key not in cls.CHECKPOINT_FILE_IDS or key not in cls.CHECKPOINT_SHA256: + raise ValueError( + "No automated download metadata is configured for OASIS checkpoint " + f"training acceleration {acceleration}." + ) + + filename = Path(checkpoints[key]["filename"]) + checkpoint_path = ( + filename + if filename.is_absolute() + else (resolved_manifest_path.parent.parent / filename) + ).resolve() + + if not matches_sha256(checkpoint_path, cls.CHECKPOINT_SHA256[key]): + download_google_drive_file_with_sha256( + cls.CHECKPOINT_FILE_IDS[key], + checkpoint_path, + cls.CHECKPOINT_SHA256[key], + label=f"OASIS checkpoint x{acceleration}", + ) + + return checkpoint_path + + def __init__( + self, + checkpoint_file: str | None = None, + acceleration: int = 4, + manifest_path: str | None = None, + device: torch.device = None, + ) -> None: + super().__init__() + + if device is None: + device = torch.device("cpu") + self.device = device + + if checkpoint_file is None: + resolved_manifest_path = ( + Path(manifest_path).expanduser() if manifest_path is not None else None + ) + checkpoint_path = self.resolve_default_checkpoint( + acceleration=acceleration, + manifest_path=resolved_manifest_path, + ) + else: + checkpoint_path = Path(checkpoint_file).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + # Use the helper to obtain a normalized state_dict (handles plain or Lightning) + state_dict = _load_unet_checkpoint_state(checkpoint_path, device) + + self.model = Unet(**self.UNET_KWARGS) + self.model.load_state_dict( + state_dict, + strict=True, + ) + self.model.eval() + self.model.to(device) + + def forward(self, y: torch.Tensor, physics: dinv.physics.Physics) -> torch.Tensor: + """Reconstruct a magnitude image from measured k-space. + + Parameters + ---------- + y : torch.Tensor + Measured k-space tensor. + physics : dinv.physics.Physics + Physics operator used to form the adjoint input image. + + Returns + ------- + torch.Tensor + Complex-valued reconstruction with zero imaginary channel. + """ + + x_in = dinv.utils.complex_abs(physics.A_adjoint(y), keepdim=True) + mu = x_in.mean(dim=(-2, -1), keepdim=True) + std = x_in.std(dim=(-2, -1), keepdim=True) + 1e-11 + x_in = (x_in - mu) / std + + with torch.no_grad(): + out = self.model(x_in) * std + mu + + return torch.cat([out, torch.zeros_like(out)], dim=1) diff --git a/mri_recon/utils/__init__.py b/mri_recon/utils/__init__.py index 61ae0e6..1e27df9 100644 --- a/mri_recon/utils/__init__.py +++ b/mri_recon/utils/__init__.py @@ -1,5 +1,19 @@ from .io import download_file_with_sha256 as download_file_with_sha256 +from .io import download_google_drive_file_with_sha256 as download_google_drive_file_with_sha256 from .io import format_megabytes as format_megabytes from .io import matches_sha256 as matches_sha256 +from .oasis_adapter import OasisCenteredFFTPhysics as OasisCenteredFFTPhysics +from .oasis_adapter import OasisSliceDataset as OasisSliceDataset +from .oasis_adapter import image_to_kspace as image_to_kspace +from .oasis_adapter import kspace_to_image as kspace_to_image -__all__ = ["download_file_with_sha256", "format_megabytes", "matches_sha256"] +__all__ = [ + "download_file_with_sha256", + "download_google_drive_file_with_sha256", + "format_megabytes", + "image_to_kspace", + "kspace_to_image", + "matches_sha256", + "OasisCenteredFFTPhysics", + "OasisSliceDataset", +] diff --git a/mri_recon/utils/io.py b/mri_recon/utils/io.py index 5aaa962..bde5d0e 100644 --- a/mri_recon/utils/io.py +++ b/mri_recon/utils/io.py @@ -1,6 +1,9 @@ import hashlib +import re import tempfile +from io import BytesIO from pathlib import Path +from urllib.parse import urlencode from urllib.request import urlopen from tqdm.auto import tqdm @@ -21,8 +24,24 @@ def format_megabytes(num_bytes: int) -> str: return f"{num_bytes / (1024 * 1024):.1f} MB" -def download_file_with_sha256( - url: str, +class _BytesResponse: + def __init__(self, payload: bytes): + self._buffer = BytesIO(payload) + self.headers = {"Content-Length": str(len(payload))} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self, size: int = -1) -> bytes: + return self._buffer.read(size) + + +def _download_response_with_sha256( + response_factory, + source: str, destination: Path, expected_sha256: str, *, @@ -30,7 +49,7 @@ def download_file_with_sha256( report_interval_mb: int = 25, ) -> None: destination.parent.mkdir(parents=True, exist_ok=True) - print(f"Downloading {label} from {url} to {destination}. This may take a moment.") + print(f"Downloading {label} from {source} to {destination}. This may take a moment.") chunk_size = 1024 * 1024 report_interval = report_interval_mb * 1024 * 1024 @@ -42,7 +61,7 @@ def download_file_with_sha256( ) as handle: tmp_path = Path(handle.name) - with urlopen(url, timeout=30) as response: + with response_factory() as response: total_size = response.headers.get("Content-Length") total_size = int(total_size) if total_size is not None else None with tqdm( @@ -67,3 +86,61 @@ def download_file_with_sha256( if tmp_path is not None: tmp_path.unlink(missing_ok=True) raise + + +def download_file_with_sha256( + url: str, + destination: Path, + expected_sha256: str, + *, + label: str = "file", + report_interval_mb: int = 25, +) -> None: + _download_response_with_sha256( + lambda: urlopen(url, timeout=30), + url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) + + +def download_google_drive_file_with_sha256( + file_id: str, + destination: Path, + expected_sha256: str, + *, + label: str = "file", + report_interval_mb: int = 25, +) -> None: + base_url = "https://drive.usercontent.google.com/download" + request_url = f"{base_url}?{urlencode({'id': file_id, 'export': 'download'})}" + + with urlopen(request_url, timeout=30) as response: + payload = response.read() + + if b"Google Drive - Virus scan warning" not in payload: + _download_response_with_sha256( + lambda: _BytesResponse(payload), + request_url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) + return + + html = payload.decode("utf-8", errors="replace") + uuid_match = re.search(r'name="uuid" value="([^"]+)"', html) + if uuid_match is None: + raise ValueError(f"Could not parse Google Drive confirmation token for file {file_id}.") + + confirmed_url = f"{base_url}?{urlencode({'id': file_id, 'export': 'download', 'confirm': 't', 'uuid': uuid_match.group(1)})}" + download_file_with_sha256( + confirmed_url, + destination, + expected_sha256, + label=label, + report_interval_mb=report_interval_mb, + ) diff --git a/mri_recon/utils/oasis_adapter.py b/mri_recon/utils/oasis_adapter.py new file mode 100644 index 0000000..4b2750f --- /dev/null +++ b/mri_recon/utils/oasis_adapter.py @@ -0,0 +1,212 @@ +"""OASIS dataset and centered FFT adapters.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import Dataset + +from mri_recon.distortions import BaseDistortion + + +class OasisSliceDataset(Dataset): + """Load 2D OASIS slices from Analyze/NIfTI volumes. + + Parameters + ---------- + data_path : Path + Root directory containing OASIS subject folders. + split_csv : Path + CSV file listing OASIS subjects and slice counts. + sample_rate : float, optional + Fraction of slices to keep from each volume. Values below ``1`` keep a + centered slice range. Defaults to ``1``. + """ + + def __init__( + self, + data_path: Path, + split_csv: Path, + sample_rate: float = 1.0, + ) -> None: + try: + import nibabel as nib + except ImportError as exc: + raise ImportError( + "OASIS loading requires nibabel. Install the project dependencies " + "or add nibabel to your environment before using OasisSliceDataset." + ) from exc + + self._nib = nib + self.data_path = Path(data_path) + self.split_csv = Path(split_csv) + if not 0 < sample_rate <= 1.0: + raise ValueError("sample_rate must be in the range (0, 1].") + self.sample_rate = sample_rate + self.subject_paths = self._discover_subject_paths() + self.raw_samples = self._create_sample_list() + + def __len__(self) -> int: + """Return the number of available slices.""" + + return len(self.raw_samples) + + def __getitem__(self, index: int) -> dict[str, object]: + """Return one complex-valued OASIS slice in repo tensor convention.""" + + subject_id, slice_num = self.raw_samples[index] + volume = self._get_volume(subject_id) + target_np = np.ascontiguousarray(volume[slice_num], dtype=np.float32) + real = torch.from_numpy(target_np) + x = torch.stack([real, torch.zeros_like(real)], dim=0) + return {"x": x.float(), "subject_id": subject_id, "slice_num": slice_num} + + def _discover_subject_paths(self) -> dict[str, Path]: + subject_paths = {} + for subject_dir in sorted(self.data_path.iterdir()): + if not subject_dir.is_dir(): + continue + image_glob = subject_dir / "PROCESSED" / "MPRAGE" / "T88_111" + matches = sorted(image_glob.glob("*t88_gfc.img")) + if matches: + subject_paths[subject_dir.name] = matches[0] + + if not subject_paths: + raise FileNotFoundError( + "Could not find OASIS subject folders under " + f"{self.data_path} matching PROCESSED/MPRAGE/T88_111/*t88_gfc.img." + ) + return subject_paths + + def _create_sample_list(self) -> list[tuple[str, int]]: + samples = [] + rows = [] + with self.split_csv.open("r", encoding="utf-8") as handle: + for line in handle: + row = [item.strip() for item in line.split(",")] + if not row or not row[0]: + continue + try: + rows.append((row[0], int(row[-1]))) + except ValueError: + continue + + for subject_id, num_slices in rows: + if subject_id not in self.subject_paths: + raise FileNotFoundError( + f"Could not find OASIS subject {subject_id!r} from split CSV under " + f"{self.data_path}." + ) + mid = round(num_slices / 2) + half_span = round(num_slices * self.sample_rate / 2) + start = 0 if self.sample_rate >= 1.0 else max(0, mid - half_span) + stop = num_slices if self.sample_rate >= 1.0 else min(num_slices, mid + half_span) + for slice_num in range(start, stop): + samples.append((subject_id, slice_num)) + return samples + + def _num_slices(self, image_path: Path) -> int: + shape = tuple(dim for dim in self._nib.load(str(image_path)).shape if dim != 1) + if len(shape) < 2: + raise ValueError(f"Expected at least 2D OASIS image, got shape {shape}.") + return shape[1] + + def _get_volume(self, subject_id: str) -> np.ndarray: + image_data = self._nib.load(str(self.subject_paths[subject_id])).get_fdata(dtype=np.float32) + volume = np.ascontiguousarray( + np.transpose(np.squeeze(image_data), (1, 0, 2)), + dtype=np.float32, + ) + return volume + + +def image_to_kspace(x: torch.Tensor) -> torch.Tensor: + """Convert channel-first complex images to centered k-space. + + Parameters + ---------- + x : torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Centered k-space tensor with shape ``(B, 2, H, W)``. + """ + + x_complex = torch.view_as_complex(x.movedim(1, -1).contiguous()) + y_complex = torch.fft.fftshift( + torch.fft.fft2(x_complex, dim=(-2, -1), norm="ortho"), + dim=(-2, -1), + ) + return torch.view_as_real(y_complex).movedim(-1, 1).contiguous() + + +def kspace_to_image(y: torch.Tensor) -> torch.Tensor: + """Convert centered channel-first k-space to complex images. + + Parameters + ---------- + y : torch.Tensor + Centered k-space tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + """ + + y_complex = torch.view_as_complex(y.movedim(1, -1).contiguous()) + x_complex = torch.fft.ifft2( + torch.fft.ifftshift(y_complex, dim=(-2, -1)), + dim=(-2, -1), + norm="ortho", + ) + return torch.view_as_real(x_complex).movedim(-1, 1).contiguous() + + +class OasisCenteredFFTPhysics: + """Physics adapter matching the OASIS U-Net FFT convention. + + Parameters + ---------- + distortion : BaseDistortion + K-space distortion applied after the centered FFT. + """ + + def __init__(self, distortion: BaseDistortion) -> None: + self.distortion = distortion + + def A(self, x: torch.Tensor) -> torch.Tensor: + """Apply centered FFT and k-space distortion. + + Parameters + ---------- + x : torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + + Returns + ------- + torch.Tensor + Distorted centered k-space tensor. + """ + + return self.distortion.A(image_to_kspace(x)) + + def A_adjoint(self, y: torch.Tensor) -> torch.Tensor: + """Apply adjoint distortion and centered inverse FFT. + + Parameters + ---------- + y : torch.Tensor + Distorted centered k-space tensor. + + Returns + ------- + torch.Tensor + Complex image tensor with shape ``(B, 2, H, W)``. + """ + + return kspace_to_image(self.distortion.A_adjoint(y)) diff --git a/pyproject.toml b/pyproject.toml index b12f61c..611d63c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "fastmri>=0.3.0", "h5py>=3.16.0", "matplotlib>=3.9.0", + "nibabel>=5.3.2", "numpy>=2.4.3", "ptwt>=1.0.1", "pydicom>=3.0.1", diff --git a/tests/test_distortions.py b/tests/test_distortions.py index 563f4e8..7e884f6 100644 --- a/tests/test_distortions.py +++ b/tests/test_distortions.py @@ -18,6 +18,7 @@ IsotropicResolutionReduction, KaiserTaperResolutionReduction, OffCenterAnisotropicGaussianKspaceBiasField, + PartialFourierDistortion, PhaseEncodeGhostingDistortion, RadialHighPassEmphasisDistortion, RotationalMotionDistortion, @@ -34,6 +35,7 @@ "Hann taper LP", "Kaiser taper LP", "Cartesian undersampling", + "Partial Fourier", "Radial high-pass emphasis", "Gaussian bias field", "Off-center anisotropic Gaussian bias field", @@ -49,6 +51,7 @@ "Isotropic LP", "Anisotropic LP", "Cartesian undersampling", + "Partial Fourier", "Phase-encode ghosting", "Translation motion", "Segmented translation motion", @@ -91,6 +94,13 @@ def choose_distortion(name): center_fraction=0.2, seed=42, ) + case "Partial Fourier": + return PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) case "Radial high-pass emphasis": return RadialHighPassEmphasisDistortion(alpha=0.4) case "Gaussian bias field": @@ -229,6 +239,52 @@ def test_anisotropic_resolution_reduction_identity_at_full_cutoffs(device): assert torch.equal(y_distorted, y) +def test_partial_fourier_distortion_identity_at_full_fraction(device): + distortion = PartialFourierDistortion( + partial_fraction=1.0, + center_fraction=0.1, + axis=-2, + side="high", + ) + y = torch.randn((1, 2, 32, 32), device=device) + + assert torch.equal(distortion.A(y), y) + + +def test_partial_fourier_distortion_retains_contiguous_asymmetric_region(device): + distortion = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) + + retained_indices = distortion._generate_1d_mask(16).nonzero().flatten().tolist() + + assert retained_indices == list(range(5, 16)) + + +def test_partial_fourier_distortion_side_changes_retained_half(device): + high = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="high", + ) + low = PartialFourierDistortion( + partial_fraction=0.7, + center_fraction=0.1, + axis=-2, + side="low", + ) + + high_indices = high._generate_1d_mask(16).nonzero().flatten().tolist() + low_indices = low._generate_1d_mask(16).nonzero().flatten().tolist() + + assert high_indices == list(range(5, 16)) + assert low_indices == list(range(0, 11)) + + def test_hann_taper_resolution_reduction_has_smooth_transition(device): distortion = HannTaperResolutionReduction( radius_fraction=0.8, @@ -766,13 +822,13 @@ def test_cartesian_undersampling_rejects_center_fraction_exceeding_keep_fraction def test_cartesian_undersampling_rejects_invalid_axis(device): - """Verify that axis must be -2 or -3.""" - with pytest.raises(ValueError, match="axis must be -2 or -3"): - CartesianUndersampling(axis=-1) - - with pytest.raises(ValueError, match="axis must be -2 or -3"): + """Verify that axis must be one of the supported trailing k-space axes.""" + with pytest.raises(ValueError, match="axis must be -1, -2, or -3"): CartesianUndersampling(axis=0) + with pytest.raises(ValueError, match="axis must be -1, -2, or -3"): + CartesianUndersampling(axis=-4) + def test_cartesian_undersampling_rejects_invalid_pattern(device): """Verify that pattern must be one of the supported sampling strategies.""" @@ -832,6 +888,22 @@ def test_cartesian_undersampling_preserves_center_acs_region(device): assert torch.all(mask_1d[center_start:center_end] == 1.0) +def test_cartesian_undersampling_can_mask_readout_axis(device): + """Verify that axis=-1 masks columns rather than phase-encode rows.""" + distortion = CartesianUndersampling( + keep_fraction=0.25, + center_fraction=0.125, + pattern="equispaced", + axis=-1, + ) + shape = (1, 2, 32, 64) + mask = distortion._mask(shape, torch.device(device)) + + assert mask.shape == (1, 1, 1, 64) + assert torch.sum(mask[0, 0, 0]).item() == 16 + assert torch.all(mask.expand(shape)[:, :, 0, :] == mask.expand(shape)[:, :, -1, :]) + + def test_cartesian_undersampling_zero_center_fraction_has_no_forced_acs(device): """Verify that center_fraction=0 does not force a contiguous ACS block.""" distortion = CartesianUndersampling( diff --git a/tests/test_reconstructions.py b/tests/test_reconstructions.py index b331a6b..36117ff 100644 --- a/tests/test_reconstructions.py +++ b/tests/test_reconstructions.py @@ -12,6 +12,7 @@ RAMReconstructor, DeepImagePriorReconstructor, FastMRISinglecoilUnetReconstructor, + OASISSinglecoilUnetReconstructor, ) from mri_recon.reconstruction.classic import ( ZeroFilledReconstructor, @@ -118,3 +119,136 @@ def test_fastmri_singlecoil_unet_reconstructor(device, tmp_path, monkeypatch): assert x_hat.shape == x.shape assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) + + +def test_oasis_singlecoil_unet_reconstructor_loads_lightning_checkpoint(device, tmp_path): + """ + Test the OASIS UNet reconstructor with a Lightning-style checkpoint. + """ + weights_path = tmp_path / "oasis_unet_test_weights.ckpt" + state_dict = Unet(**OASISSinglecoilUnetReconstructor.UNET_KWARGS).state_dict() + torch.save( + {"state_dict": {f"unet.{key}": value for key, value in state_dict.items()}}, + weights_path, + ) + + x = dinv.utils.load_example( + "butterfly.png", img_size=(32, 32), grayscale=True, resize_mode="resize", device=device + ) + x = torch.cat([x, torch.zeros_like(x)], dim=1) + + model = OASISSinglecoilUnetReconstructor( + checkpoint_file=str(weights_path), + device=device, + ) + physics = DistortedKspaceMultiCoilMRI( + img_size=(1, 2, *x.shape[-2:]), coil_maps=1, device=device + ) + + y = physics(x) + x_hat = model(y, physics) + + assert x_hat.shape == x.shape + assert torch.allclose(x_hat[:, 1], torch.zeros_like(x_hat[:, 1])) + + +def test_oasis_resolve_default_checkpoint_downloads_manifest_and_checkpoint(tmp_path, monkeypatch): + model_dir = tmp_path / "downloads" / "oasis_singlecoil_unet" + manifest_path = model_dir / "checkpoints" / "manifest.json" + checkpoint_path = model_dir / "checkpoints" / "oasis_balanced_seed24_accel4.ckpt" + manifest_bytes = ( + b'{"checkpoints": {"4": {"filename": "checkpoints/oasis_balanced_seed24_accel4.ckpt"}}}' + ) + checkpoint_bytes = b"fake-oasis-checkpoint" + downloads = [] + + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "MANIFEST_SHA256", + __import__("hashlib").sha256(manifest_bytes).hexdigest(), + ) + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "CHECKPOINT_SHA256", + {"4": __import__("hashlib").sha256(checkpoint_bytes).hexdigest()}, + ) + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "CHECKPOINT_FILE_IDS", + {"4": "checkpoint-file-id"}, + ) + + def fake_download(file_id, destination, expected_sha256, **kwargs): + downloads.append((file_id, destination)) + destination.parent.mkdir(parents=True, exist_ok=True) + if file_id == "1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I": + destination.write_bytes(manifest_bytes) + elif file_id == "checkpoint-file-id": + destination.write_bytes(checkpoint_bytes) + else: + raise AssertionError(f"Unexpected file_id: {file_id}") + + monkeypatch.setattr( + "mri_recon.reconstruction.deep.download_google_drive_file_with_sha256", + fake_download, + ) + + resolved = OASISSinglecoilUnetReconstructor.resolve_default_checkpoint( + 4, manifest_path=manifest_path + ) + + assert resolved == checkpoint_path.resolve() + assert manifest_path.read_bytes() == manifest_bytes + assert checkpoint_path.read_bytes() == checkpoint_bytes + assert downloads == [ + ("1zefZh7Vh5k2ssXKpLxV3Xnwf3S6dqu6I", manifest_path.resolve()), + ("checkpoint-file-id", checkpoint_path.resolve()), + ] + + +def test_oasis_singlecoil_unet_reconstructor_uses_packaged_checkpoint_defaults( + device, tmp_path, monkeypatch +): + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "UNET_KWARGS", + { + "in_chans": 1, + "out_chans": 1, + "chans": 8, + "num_pool_layers": 2, + "drop_prob": 0.0, + }, + ) + + weights_path = tmp_path / "oasis_unet_packaged_weights.ckpt" + state_dict = Unet(**OASISSinglecoilUnetReconstructor.UNET_KWARGS).state_dict() + torch.save( + {"state_dict": {f"unet.{key}": value for key, value in state_dict.items()}}, + weights_path, + ) + + captured = {} + + def fake_resolve(cls, acceleration, manifest_path=None): + captured["acceleration"] = acceleration + captured["manifest_path"] = manifest_path + return weights_path + + monkeypatch.setattr( + OASISSinglecoilUnetReconstructor, + "resolve_default_checkpoint", + classmethod(fake_resolve), + ) + + model = OASISSinglecoilUnetReconstructor( + acceleration=8, + manifest_path=str(tmp_path / "manifest.json"), + device=device, + ) + + assert captured == { + "acceleration": 8, + "manifest_path": tmp_path / "manifest.json", + } + assert isinstance(model.model, Unet) diff --git a/tests/test_utils_io.py b/tests/test_utils_io.py index b984872..3d2c434 100644 --- a/tests/test_utils_io.py +++ b/tests/test_utils_io.py @@ -1,7 +1,7 @@ import hashlib from io import BytesIO -from mri_recon.utils.io import download_file_with_sha256 +from mri_recon.utils.io import download_file_with_sha256, download_google_drive_file_with_sha256 class FakeResponse: @@ -39,3 +39,30 @@ def test_download_file_with_sha256_moves_temp_file_after_close(tmp_path, monkeyp assert destination.read_bytes() == payload assert list(tmp_path.glob("*.tmp")) == [] + + +def test_download_google_drive_file_with_sha256_confirms_large_download(tmp_path, monkeypatch): + warning_html = b"""
Google Drive - Virus scan warning""" + payload = b"oasis-checkpoint" + expected_sha256 = hashlib.sha256(payload).hexdigest() + destination = tmp_path / "oasis.ckpt" + requested_urls = [] + + def fake_urlopen(url, timeout=30): + requested_urls.append(url) + if "confirm=t" in url: + return FakeResponse(payload) + return FakeResponse(warning_html) + + monkeypatch.setattr("mri_recon.utils.io.urlopen", fake_urlopen) + + download_google_drive_file_with_sha256( + "file-123", + destination, + expected_sha256, + label="OASIS checkpoint", + report_interval_mb=1, + ) + + assert destination.read_bytes() == payload + assert any("confirm=t" in url and "uuid=uuid-456" in url for url in requested_urls) diff --git a/uv.lock b/uv.lock index 7034d30..d25083a 100644 --- a/uv.lock +++ b/uv.lock @@ -947,6 +947,7 @@ dependencies = [ { name = "fastmri" }, { name = "h5py" }, { name = "matplotlib" }, + { name = "nibabel" }, { name = "numpy" }, { name = "ptwt" }, { name = "pydicom" }, @@ -976,6 +977,7 @@ requires-dist = [ { name = "fastmri", specifier = ">=0.3.0" }, { name = "h5py", specifier = ">=3.16.0" }, { name = "matplotlib", specifier = ">=3.9.0" }, + { name = "nibabel", specifier = ">=5.3.2" }, { name = "numpy", specifier = ">=2.4.3" }, { name = "ptwt", specifier = ">=1.0.1" }, { name = "pydicom", specifier = ">=3.0.1" }, @@ -1117,6 +1119,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] +[[package]] +name = "nibabel" +version = "5.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/01/3d2cc510c616bc8e27be17a063070d9126f69407961594a9ae734ea51121/nibabel-5.4.2.tar.gz", hash = "sha256:d5f4b9076a13178ae7f7acf18c8dbd503ee1c4d5c0c23b85df7be87efcbb49da", size = 4663132, upload-time = "2026-03-11T13:31:52.42Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/d7/601b6396b33536811668935faa790112266c70661be94555999be431f86f/nibabel-5.4.2-py3-none-any.whl", hash = "sha256:553482c5f1e1034fc312edf6fb7f32236c0056439845d1c29293b7e8c98d4854", size = 3300985, upload-time = "2026-03-11T13:31:50.028Z" }, +] + [[package]] name = "nodeenv" version = "1.10.0"