Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
55f9045
Add OASIS U-Net inference example
ydu0117 May 6, 2026
5275ffd
Merge remote-tracking branch 'origin/main' into codex-oasis-inference
ydu0117 May 6, 2026
cc62b05
Add OASIS inference example
May 7, 2026
06baac4
Automated SHA256 certified model download for OASIS
MatthiasLen May 8, 2026
6848e2a
Merge remote-tracking branch 'origin/main' into oasis_unet_reconstruc…
MatthiasLen May 8, 2026
06de9fa
added common download storage path for all models
MatthiasLen May 8, 2026
4ea5aca
added detailed explanations related to resolution reduction
MatthiasLen May 14, 2026
997cccc
added detailed explanations related to bias field
MatthiasLen May 14, 2026
5e5b69c
partial fourier
MatthiasLen May 14, 2026
d9f6ebf
Merge branch 'main' into oasis_unet_reconstruction
MatthiasLen May 14, 2026
a753ce3
updaed readme
MatthiasLen May 14, 2026
c56665b
Add OASIS U-Net inference example
ydu0117 May 6, 2026
913a162
Add OASIS inference example
May 7, 2026
52fa4ed
Automated SHA256 certified model download for OASIS
MatthiasLen May 8, 2026
220c26e
added common download storage path for all models
MatthiasLen May 8, 2026
b7d0063
added detailed explanations related to resolution reduction
MatthiasLen May 14, 2026
75d29ef
added detailed explanations related to bias field
MatthiasLen May 14, 2026
1a458d2
partial fourier
MatthiasLen May 14, 2026
9fb0527
Add OASIS support to FastMRI reconstruction example
ydu0117 May 14, 2026
5c47163
Merge remote-tracking branch 'origin/oasis_unet_reconstruction' into …
ydu0117 May 14, 2026
945e606
remove the OASIS_inference_plot.py
ydu0117 May 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Reconstruction playground for the MRI Recon Metrics Reloaded workgroup.
| `RAMReconstructor` | `ram` | Deep learning | Wrapper around the DeepInverse RAM model, with input normalization based on the adjoint reconstruction. |
| `DeepImagePriorReconstructor` | `dip` | Deep learning | Deep Image Prior reconstruction using an untrained convolutional decoder optimized at inference time. |
| `FastMRISinglecoilUnetReconstructor` | `unet` | Deep learning | Wrapper around the pretrained fastMRI single-coil U-Net, returning a magnitude-based reconstruction with a zero imaginary channel. |
| `OASISSinglecoilUnetReconstructor` | `oasis-unet` | Deep learning | Wrapper around a trained OASIS single-coil U-Net checkpoint, reusing the shared fastMRI-derived U-Net module. |

## Implemented Distortions

Expand Down Expand Up @@ -56,6 +57,22 @@ uv sync
uv run python -c "import torch; print(torch.__version__, torch.cuda.is_available(), torch.version.cuda)"
```

## Inference Examples

Run the FastMRI plotting example with local FastMRI k-space files:

```bash
python examples/fastmri_inference_plot.py --source /path/to/fastmri/singlecoil_val --dataset fastmri --algorithm unet
```

Run the same lightweight example on OASIS data. The packaged OASIS split CSV and U-Net checkpoint are downloaded automatically when missing:

```bash
python examples/fastmri_inference_plot.py --source /path/to/oasis_cross_sectional_data --dataset oasis --algorithm unet
```

For OASIS, `--oasis_checkpoint_acceleration` only selects the packaged U-Net weights by their training acceleration. Distortion undersampling is still controlled by `--keep_fraction` and `--center_fraction`.

## Pre-commit

Install the local tooling and register the git hook:
Expand Down
151 changes: 122 additions & 29 deletions examples/fastmri_inference_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@

from mri_recon.distortions import *
from mri_recon.reconstruction import *

REPORT_DIR = Path("reports") / "fastmri_inference_plot"
REPORT_DIR.mkdir(parents=True, exist_ok=True)
from mri_recon.utils import (
OasisCenteredFFTPhysics,
OasisSliceDataset,
image_to_kspace,
kspace_to_image,
)

FASTMRI_REPORT_DIR = Path("reports") / "fastmri_inference_plot"
OASIS_REPORT_DIR = Path("reports") / "oasis_inference_plot"
FASTMRI_REPORT_DIR.mkdir(parents=True, exist_ok=True)
OASIS_REPORT_DIR.mkdir(parents=True, exist_ok=True)
ALGORITHMS = [
# "zero-filled",
# "conjugate-gradient",
Expand All @@ -29,14 +37,15 @@
# "wavelet-fista",
# "tv-fista",
# "tv-pdhg",
"unet", # will trigger download of pretrained weights if not already present
"unet", # dataset-specific U-Net; downloads weights if not already present
]
DISTORTIONS = [
"Cartesian undersampling (variable density)",
"Cartesian undersampling (uniform random)",
"Cartesian undersampling (uniform random, zero ACS)",
"Cartesian undersampling (equispaced)",
"Cartesian undersampling (equispaced, zero ACS)",
"Partial Fourier",
"Phase-encode ghosting",
"Segmented translation motion",
"Segmented rotational motion",
Expand Down Expand Up @@ -113,6 +122,8 @@ def choose_algorithm(
img_size: tuple = (640, 368),
device: torch.device = "cpu",
verbose: bool = False,
dataset: str = "fastmri",
oasis_checkpoint_acceleration: int = 4,
) -> dinv.models.Reconstructor:
match name:
case "zero-filled":
Expand All @@ -132,12 +143,22 @@ def choose_algorithm(
case "wavelet-fista":
return WaveletFISTAReconstructor(n_iter=100, device=device, verbose=verbose)
case "unet":
if dataset == "oasis":
return OASISSinglecoilUnetReconstructor(
acceleration=oasis_checkpoint_acceleration,
device=device,
)
return FastMRISinglecoilUnetReconstructor(device=device)
case _:
raise ValueError(f"Unknown algorithm {name!r}")


def choose_distortion(name: str) -> BaseDistortion:
def choose_distortion(
name: str,
keep_fraction: float = 0.25,
center_fraction: float = 0.125,
cartesian_axis: int = -2,
) -> BaseDistortion:
match name:
case "Phase-encode ghosting":
return PhaseEncodeGhostingDistortion(
Expand All @@ -148,39 +169,51 @@ def choose_distortion(name: str) -> BaseDistortion:
)
case "Cartesian undersampling (variable density)":
return CartesianUndersampling(
keep_fraction=0.25,
center_fraction=0.125,
keep_fraction=keep_fraction,
center_fraction=center_fraction,
pattern="variable_density_random",
axis=cartesian_axis,
seed=42,
)
case "Cartesian undersampling (uniform random)":
return CartesianUndersampling(
keep_fraction=0.25,
center_fraction=0.125,
keep_fraction=keep_fraction,
center_fraction=center_fraction,
pattern="uniform_random",
axis=cartesian_axis,
seed=42,
)
case "Cartesian undersampling (uniform random, zero ACS)":
return CartesianUndersampling(
keep_fraction=0.5,
keep_fraction=keep_fraction,
center_fraction=0.0,
pattern="uniform_random",
axis=cartesian_axis,
seed=42,
)
case "Cartesian undersampling (equispaced)":
return CartesianUndersampling(
keep_fraction=0.25,
center_fraction=0.125,
keep_fraction=keep_fraction,
center_fraction=center_fraction,
pattern="equispaced",
axis=cartesian_axis,
seed=42,
)
case "Cartesian undersampling (equispaced, zero ACS)":
return CartesianUndersampling(
keep_fraction=0.5,
keep_fraction=keep_fraction,
center_fraction=0.0,
pattern="equispaced",
axis=cartesian_axis,
seed=42,
)
case "Partial Fourier":
return PartialFourierDistortion(
partial_fraction=0.7,
center_fraction=center_fraction,
axis=cartesian_axis,
side="high",
)
case "Anisotropic LP":
return AnisotropicResolutionReduction(
kx_radius_fraction=1.0,
Expand Down Expand Up @@ -248,17 +281,49 @@ def choose_metric(name: str) -> dinv.metric.Metric:

if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)

# data related arguments
parser.add_argument(
"--source", type=str, help="Local FastMRI directory with raw k-space .h5 files."
"--source",
type=Path,
help="Local FastMRI directory with raw k-space .h5 files or OASIS root directory.",
)
parser.add_argument("--dataset", choices=("fastmri", "oasis"), default="fastmri")

parser.add_argument("--distortion", type=str, default="", choices=DISTORTIONS)
parser.add_argument(
"--keep_fraction",
type=float,
default=0.25,
help="Fraction of k-space lines to keep for undersampling distortions.",
)
parser.add_argument(
"--center_fraction",
type=float,
default=0.125,
help="Fraction of low-frequency k-space lines to keep fully for undersampling distortions.",
)

# algo related arguments
parser.add_argument(
"--algorithm",
type=str,
default="",
choices=ALGORITHMS,
help="Reconstruction algorithm applied to undistorted and distorted k-space.",
)
parser.add_argument(
"--oasis_checkpoint_acceleration",
type=int,
default=4,
choices=[4, 8, 10],
help=(
"Training acceleration of the packaged OASIS U-Net checkpoint. "
"This only selects pretrained weights; distortion undersampling is "
"controlled by --keep_fraction and --center_fraction."
),
)
# inference related arguments
parser.add_argument("--num_samples", type=int, default=1, help="How many samples to process.")
parser.add_argument(
"--verbose",
Expand All @@ -267,40 +332,66 @@ def choose_metric(name: str) -> dinv.metric.Metric:
)
args = parser.parse_args()

# set up report dir
REPORT_DIR = OASIS_REPORT_DIR if args.dataset == "oasis" else FASTMRI_REPORT_DIR

# set up device, dataset, metrics
device = dinv.utils.get_device()
dataset = dinv.datasets.FastMRISliceDataset(args.source, slice_index="middle")
if args.dataset == "oasis":
split_csv = OASISSinglecoilUnetReconstructor.resolve_default_split_csv()
dataset = OasisSliceDataset(
data_path=args.source,
split_csv=split_csv,
sample_rate=0.6,
)
else:
dataset = dinv.datasets.FastMRISliceDataset(str(args.source), slice_index="middle")
metrics = [choose_metric(m) for m in METRICS]

for i, batch in enumerate(iter(torch.utils.data.DataLoader(dataset))):
# exit loop if we have processed the specified number of samples
if i >= args.num_samples:
break

# batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan),
# y is kspace, and params is a dict containing mask (if test set)
y = batch[1]
if args.dataset == "oasis":
x = batch["x"].to(device)
y = image_to_kspace(x)
else:
# batch is a tuple of (x, y) or (x, y, params) where x is GT (could be torch.nan),
# y is kspace, and params is a dict containing mask (if test set)
y = batch[1].to(device)

for distortion_name in DISTORTIONS if args.distortion == "" else [args.distortion]:
distortion = choose_distortion(distortion_name)
distortion = choose_distortion(
distortion_name,
keep_fraction=args.keep_fraction,
center_fraction=args.center_fraction,
cartesian_axis=-1 if args.dataset == "oasis" else -2,
)

# create physics objects for both clean and distorted k-space
# the distortion is applied to the k-space measurements (not the image)
# TODO: allow loading multicoil data
physics_clean = DistortedKspaceMultiCoilMRI(
distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device
)
physics = DistortedKspaceMultiCoilMRI(
distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device
)

y = y.to(device)
if args.dataset == "oasis":
physics_clean = OasisCenteredFFTPhysics(BaseDistortion())
physics = OasisCenteredFFTPhysics(distortion)
else:
physics_clean = DistortedKspaceMultiCoilMRI(
distortion=BaseDistortion(), img_size=(1, 2, *y.shape[-2:]), device=device
)
physics = DistortedKspaceMultiCoilMRI(
distortion=distortion, img_size=(1, 2, *y.shape[-2:]), device=device
)
y_distorted = distortion.A(y)

# generate reference reconstructions (CG) for both clean and distorted k-space
# without correction for the distortion, i.e. using physics_clean in both cases
x_clean = ConjugateGradientReconstructor()(y, physics_clean)
x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean)
if args.dataset == "oasis":
x_clean = x
x_distorted = kspace_to_image(y_distorted)
else:
x_clean = ConjugateGradientReconstructor()(y, physics_clean)
x_distorted = ConjugateGradientReconstructor()(y_distorted, physics_clean)

# plot and save the k-space magnitude for both clean and distorted k-space
save_kspace_plot(
Expand All @@ -319,6 +410,8 @@ def choose_metric(name: str) -> dinv.metric.Metric:
img_size=y.shape[-2:],
device=device,
verbose=args.verbose,
dataset=args.dataset,
oasis_checkpoint_acceleration=args.oasis_checkpoint_acceleration,
).to(device)

# actual reconstruction with the algo being evaluated
Expand Down
2 changes: 1 addition & 1 deletion mri_recon/distortions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
KaiserTaperResolutionReduction,
RadialHighPassEmphasisDistortion,
)
from .undersampling import CartesianUndersampling
from .undersampling import CartesianUndersampling, PartialFourierDistortion
36 changes: 31 additions & 5 deletions mri_recon/distortions/biasfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,37 @@ class OffCenterAnisotropicGaussianKspaceBiasField(BaseDistortion):
The gain peaks at an offset location in k-space, decays with different widths
along ``kx`` and ``ky``, and is normalized to unit maximum on the sampled grid.

:param float width_x_fraction: Gaussian width along the normalized ``kx`` direction.
:param float width_y_fraction: Gaussian width along the normalized ``ky`` direction.
:param float center_x_fraction: Center offset along normalized ``kx`` in ``[-1, 1]``.
:param float center_y_fraction: Center offset along normalized ``ky`` in ``[-1, 1]``.
:param float edge_gain: Baseline gain far from the Gaussian peak. Must lie in ``(0, 1]``.
Note: This class can also approximate a readout-decay-like blur when used in a
centered anisotropic configuration. To mimic stronger attenuation along the
readout direction, keep the Gaussian centered at DC with
``center_x_fraction=0.0`` and ``center_y_fraction=0.0``, choose a narrower
width along the readout axis than along the orthogonal axis, and reduce
``edge_gain`` below ``1.0``. For example, a setting such as
``width_x_fraction < width_y_fraction`` with a moderate ``edge_gain``
produces a smooth directional loss of high-frequency content that can
resemble readout-decay blur.

This remains a phenomenological approximation rather than an explicit
time-ordered readout-decay model. It applies a smooth multiplicative
k-space weighting, not a physically parameterized echo-train decay.

:param float width_x_fraction: Gaussian width along the normalized ``kx``
direction. Smaller values produce stronger attenuation away from the
center along ``kx``. When ``kx`` is the readout axis, choosing
``width_x_fraction < width_y_fraction`` approximates readout-direction
blur.
:param float width_y_fraction: Gaussian width along the normalized ``ky``
direction. Larger values preserve more support along ``ky`` relative
to ``kx``.
:param float center_x_fraction: Center offset along normalized ``kx`` in
``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like
approximation.
:param float center_y_fraction: Center offset along normalized ``ky`` in
``[-1, 1]``. Use ``0.0`` for a centered readout-decay-like
approximation.
:param float edge_gain: Baseline gain far from the Gaussian peak. Must lie
in ``(0, 1]``. Smaller values strengthen the peripheral attenuation
and therefore the resulting directional blur.
"""

def __init__(
Expand Down
Loading
Loading