Skip to content

Commit 8d15c5c

Browse files
committed
wip
1 parent 726c9c3 commit 8d15c5c

5 files changed

Lines changed: 114 additions & 77 deletions

File tree

FastSurferCNN/data_loader/conform.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from numpy import typing as npt
2929

3030
from FastSurferCNN.utils import AffineMatrix4x4, ScalarType, Shape1d, logging, nibabelHeader, nibabelImage
31-
from FastSurferCNN.utils.affines import aff2axcodes, OrntArrayType
31+
from FastSurferCNN.utils.affines import OrntArrayType, aff2axcodes, io_orientation
3232
from FastSurferCNN.utils.arg_types import ImageSizeOption, OrientationType, StrictOrientationType, VoxSizeOption
3333
from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one_mm
3434
from FastSurferCNN.utils.arg_types import img_size as __img_size
@@ -55,7 +55,7 @@
5555
Date: May-12-2025
5656
"""
5757
FIX_MGH_AFFINE_CALCULATION = False
58-
FIX_CENTER_NOT_CENTER = False
58+
FIX_CENTER_NOT_CENTER = True
5959

6060
LOGGER = logging.getLogger(__name__)
6161

@@ -290,7 +290,7 @@ def do_nothing(data: _TB) -> _TB:
290290
return image_data, do_nothing
291291

292292
# vox2vox should always be a "soft" transform
293-
vox2vox = affine_for_target_orientation(source_affine, "soft " + _target_orientation, image_data.shape).round()
293+
vox2vox = vox2vox_for_target_orientation(source_affine, "soft " + _target_orientation, image_data.shape).round()
294294
if np.allclose(vox2vox, np.eye(4)): # is already target_orientation
295295
return image_data, do_nothing
296296
else: # is not target_affine yet
@@ -300,7 +300,7 @@ def do_nothing(data: _TB) -> _TB:
300300
return apply_vox2vox(image_data, vox2vox, out_shape), inverse_apply_vox2vox
301301

302302

303-
def affine_for_target_orientation(
303+
def vox2vox_for_target_orientation(
304304
source_affine: AffineMatrix4x4,
305305
target_orientation: OrientationType,
306306
shape: npt.ArrayLike,
@@ -956,11 +956,12 @@ def _validate(param_name, param_type, param_value):
956956
source_vox_size = img.header.get_zooms()
957957

958958
source_affine = get_affine_from_any(img)
959-
_vox2vox = affine_for_target_orientation(source_affine, orientation, (0,) * 3)[:3, :3]
959+
_vox2vox = vox2vox_for_target_orientation(source_affine, orientation, (0,) * 3)[:3, :3]
960960
_target_affine = source_affine[:3, :3] @ _vox2vox
961961
h1["Mdc"] = np.linalg.inv(_target_affine / np.linalg.norm(_target_affine, axis=0, keepdims=True))
962-
re_order_axes = np.argmax(np.abs(_vox2vox), axis=1).tolist()
963-
# re_order_axes = nib.io_orientation(_vox2vox)[:, 0].tolist()
962+
_vox2vox_hom = np.eye(4, dtype=_vox2vox.dtype)
963+
_vox2vox_hom[:3, :3] = _vox2vox
964+
re_order_axes = io_orientation(_vox2vox_hom)[:, 0].astype(int).tolist()
964965

965966
shape: list[int] = [(source_img_shape if target_img_size is None else target_img_size)[i] for i in re_order_axes]
966967
h1.set_data_shape(shape + [1])

FastSurferCNN/run_prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
import FastSurferCNN.reduce_to_aseg as rta
4141
from FastSurferCNN.data_loader import data_utils as du
42-
from FastSurferCNN.data_loader.conform import affine_for_target_orientation, conform, is_conform, to_target_orientation
42+
from FastSurferCNN.data_loader.conform import vox2vox_for_target_orientation, conform, is_conform, to_target_orientation
4343
from FastSurferCNN.inference import Inference
4444
from FastSurferCNN.quick_qc import check_volume
4545
from FastSurferCNN.utils import PLANES, Plane, logging, nibabelImage, parser_defaults
@@ -388,7 +388,7 @@ def get_prediction(
388388

389389
orig_in_lia, back_to_native = to_target_orientation(orig_data, affine, target_orientation="LIA")
390390
shape = orig_in_lia.shape + (self.get_num_classes(),)
391-
_ornt_transform, _ = affine_for_target_orientation(affine, target_orientation="LIA")
391+
_ornt_transform, _ = vox2vox_for_target_orientation(affine, target_orientation="LIA")
392392
_zoom = _zoom[_ornt_transform[:, 0]]
393393

394394
pred_prob = torch.zeros(shape, **kwargs)

test/image/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def vox_size(request) -> float:
2727
return request.param
2828

2929

30-
@pytest.fixture(scope="session", params=[4, 6]) # [128, 256])
30+
@pytest.fixture(scope="session", params=[8, 15]) # [128, 256])
3131
def img_size(request) -> float:
3232
return request.param
3333

test/image/test_conform_reorient.py

Lines changed: 100 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -7,55 +7,57 @@
77
from pytest import approx
88

99
from FastSurferCNN.data_loader.conform import OrientationType, conform, prepare_mgh_header
10-
from FastSurferCNN.utils import AffineMatrix4x4, Image3d, nibabelHeader
10+
from FastSurferCNN.utils import AffineMatrix4x4, Image3d, nibabelHeader, nibabelImage
11+
from FastSurferCNN.utils.affines import aff2axcodes
1112
from FastSurferCNN.utils.arg_types import StrictOrientationType
1213

1314
logger = getLogger(__name__)
1415

1516
class MultiCoordImages(TypedDict):
16-
X: nib.Nifti1Image
17-
Y: nib.Nifti1Image
18-
Z: nib.Nifti1Image
17+
X: nibabelImage
18+
Y: nibabelImage
19+
Z: nibabelImage
1920

2021
conform_reorient = {"rescale": None, "dtype": np.float32}
2122

2223

23-
def circle_data(img_size: int) -> Image3d[np.float32]:
24+
def circle_data(img_size: int, radius: float, center: float) -> Image3d[np.float32]:
2425
"""Generates a 3D image with a centered sphere of radius img_size/2."""
25-
data = np.mgrid[0:img_size, 0:img_size, 0:img_size].astype(np.float32) - (img_size - 1) / 2.0
26-
return (np.sum(data * data, axis=0) < img_size * img_size / 4.0).astype(np.float32)
26+
data = np.mgrid[0:img_size, 0:img_size, 0:img_size].astype(np.float32) - center
27+
return (np.sum(data * data, axis=0) < radius * radius).astype(np.float32)
2728

2829

2930
@pytest.fixture(scope="session")
30-
def circle_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image:
31-
return nib.Nifti1Image(circle_data(img_size), random_affine)
32-
31+
def radius(img_size: int) -> float:
32+
return img_size / 2.0 - 2.0
3333

3434
@pytest.fixture(scope="session")
35-
def random_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image:
36-
return nib.Nifti1Image(np.random.randn(img_size, img_size, img_size), random_affine)
35+
def center(img_size: int) -> float:
36+
return (img_size - 1) / 2.0
3737

3838

3939
@pytest.fixture(scope="session")
40-
def empty_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image:
41-
return nib.Nifti1Image(np.empty((img_size,) * 3), random_affine)
40+
def circle_image(random_affine: AffineMatrix4x4, img_size: int, radius: float, center: float) -> nib.Nifti1Image:
41+
return nib.Nifti1Image(circle_data(img_size, radius, center), random_affine)
42+
43+
44+
@pytest.fixture(scope="session", params=[1.0, 0.23])
45+
def resample_factor(request) -> float:
46+
return request.param
4247

4348

4449
@pytest.fixture(scope="session")
45-
def worldcoord_images(random_affine: AffineMatrix4x4, img_size: int) -> MultiCoordImages:
46-
data = worldcoords_data(random_affine, img_size)
47-
return MultiCoordImages(**{c: nib.Nifti1Image(data[..., i], random_affine) for i, c in enumerate("XYZ")})
50+
def random_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image:
51+
return nib.Nifti1Image(np.random.randn(img_size, img_size, img_size), random_affine)
4852

4953

50-
def worldcoords_data(affine: AffineMatrix4x4, img_size: int) -> np.ndarray:
51-
xi = np.moveaxis(np.mgrid[0:img_size, 0:img_size, 0:img_size], 0, -1)
52-
return nib.affines.apply_affine(affine, xi.reshape((-1, 3)).astype(float)).reshape(xi.shape).astype(np.float32)
54+
@pytest.fixture(scope="session")
55+
def empty_image(random_affine: AffineMatrix4x4, img_size: int) -> nib.Nifti1Image:
56+
return nib.Nifti1Image(np.empty((img_size,) * 3), random_affine)
5357

5458

5559
def affine2orientation(affine: AffineMatrix4x4) -> OrientationType:
5660
"""Generates the orientation type string from an affine matrix."""
57-
from FastSurferCNN.utils.affines import aff2axcodes
58-
5961
orientation: StrictOrientationType = "".join(aff2axcodes(affine, ("LR", "PA", "IS")))
6062
# make sure the affine is normalized for vox_sizes not 1
6163
norm_affine = affine[:3, :3] / np.linalg.norm(affine[:3, :3], keepdims=True, axis=0)
@@ -121,17 +123,20 @@ def image(
121123
soft_orientation: OrientationType,
122124
img_size: int,
123125
vox_size: float,
126+
resample_factor: float,
124127
) -> nib.MGHImage:
125128
"""
126129
Conform `circle_image` to `soft_orientation` and back to the original orientation.
127130
"""
128-
from FastSurferCNN.utils.affines import aff2axcodes
129-
130-
there = conform(circle_image, orientation=soft_orientation, **conform_reorient)
131+
there = conform(
132+
circle_image,
133+
orientation=soft_orientation, vox_size=vox_size * resample_factor, img_size=64, order=1,
134+
**conform_reorient,
135+
)
131136
in_orientation: OrientationType = "soft " + "".join(aff2axcodes(circle_image.affine, ("LR", "PA", "IS")))
132-
return conform(there, orientation=in_orientation, **conform_reorient, img_size=img_size, vox_size=vox_size)
137+
return conform(there, orientation=in_orientation, img_size=img_size, vox_size=vox_size, **conform_reorient)
133138

134-
def test_affine(self, image: nib.MGHImage, circle_image: nib.Nifti1Image) -> None:
139+
def test_affine(self, image: nib.MGHImage, circle_image: nibabelImage) -> None:
135140
"""
136141
Tests whether the affines of the original and the re-oriented images are the same.
137142
"""
@@ -140,57 +145,88 @@ def test_affine(self, image: nib.MGHImage, circle_image: nib.Nifti1Image) -> Non
140145
# currently, the translation parts of the affines differ
141146
assert actual == approx(expected, abs=1e-5), "The affines of original and re-reoriented images differ!"
142147

143-
def test_image(self, image: nib.MGHImage, circle_image: nib.Nifti1Image):
148+
def test_image(
149+
self,
150+
image: nib.MGHImage,
151+
circle_image: nibabelImage,
152+
radius: float,
153+
center: float,
154+
soft_orientation: AffineMatrix4x4,
155+
) -> None:
144156
"""
145157
Tests whether the content of the original and the re-oriented images are the same in the "center circle".
146158
"""
147159

148160
# this has to be filtered by the region of the image that is the same
149-
inner_circle = circle_data(circle_image.shape[0] - 2)
150-
outer_circle = circle_data(circle_image.shape[0] + 2)[1:-1, 1:-1, 1:-1]
151-
mask = np.pad(inner_circle, 1, constant_values=0) + (1 - outer_circle)
152-
expected = np.where(mask, circle_image.dataobj, np.nan)
161+
inner_circle = circle_data(circle_image.shape[0], radius - 1, center)
162+
outer_circle = circle_data(circle_image.shape[0], radius + 1, center)
163+
# the expected image is 0 outside the outer circle, 1 inside the inner circle and ignored (NaN) in between,
164+
# where interpolation can cause differences
165+
expected = np.where(outer_circle - inner_circle, np.nan, circle_image.dataobj)
153166
actual = np.asarray(image.dataobj)
154-
assert actual == approx(expected, abs=0.3, nan_ok=True), "The data differs from the re-oriented image!"
155-
167+
logger.info(f"Difference is {np.max(np.abs(actual - circle_image.get_fdata()))} - {np.nanmax(np.abs(actual - expected))}")
156168

157-
def test_reorient_worldcoords(
158-
worldcoord_images: MultiCoordImages,
159-
soft_orientation: OrientationType,
160-
vox_size: float,
161-
img_size: int,
162-
):
163-
"""
164-
This test checks, whether the world coordinates are consistent.
165-
166-
In effect this means generating world coordinates for an original image (xyz) and
169+
assert actual == approx(expected, abs=0.3, nan_ok=True), "The data differs from the re-oriented image!"
167170

168-
Parameters
169-
----------
170-
worldcoord_images : MultiCoordImages
171-
A dictionary of three world coordinate images (X, Y, Z) generated based on random_affine.
172-
soft_orientation : OrientationType
173-
The soft orientation to which the world coordinate images will be conformed.
171+
class TestReorientWorldCoords:
174172

173+
@staticmethod
174+
def worldcoords_data(affine: AffineMatrix4x4, img_size: int) -> np.ndarray:
175+
xi = np.moveaxis(np.mgrid[0:img_size, 0:img_size, 0:img_size], 0, -1)
176+
return nib.affines.apply_affine(affine, xi.reshape((-1, 3)).astype(float)).reshape(xi.shape).astype(np.float32)
175177

176-
Returns
177-
-------
178+
@pytest.fixture(scope="class")
179+
def worldcoord_images(self, random_affine: AffineMatrix4x4, img_size: int) -> MultiCoordImages:
180+
data = self.worldcoords_data(random_affine, img_size)
181+
return MultiCoordImages(**{c: nib.Nifti1Image(data[..., i], random_affine) for i, c in enumerate("XYZ")})
178182

179-
"""
180-
from pytest import approx, xfail
183+
@pytest.fixture(scope="class")
184+
def conf_img_size(self) -> int:
185+
return 16
181186

182-
from FastSurferCNN.data_loader.conform import conform
187+
@pytest.fixture(scope="class")
188+
def conf_images(
189+
self,
190+
conf_img_size: int,
191+
worldcoord_images: MultiCoordImages,
192+
orientation: OrientationType,
193+
) -> MultiCoordImages:
194+
return MultiCoordImages(
195+
**{k: conform(img, **conform_reorient, img_size=conf_img_size) for k, img in worldcoord_images.items()},
196+
)
197+
198+
@pytest.mark.parametrize(argnames=["dim_name"], argvalues=[["X"], ["Y"], ["Z"]])
199+
def test_reorient_worldcoords_affine(
200+
self,
201+
conf_images: MultiCoordImages,
202+
orientation: OrientationType,
203+
dim_name: str,
204+
):
205+
"""
206+
This test checks, whether the affines of the world coordinate images are consistent with the original affine.
207+
"""
208+
actual = "".join(aff2axcodes(conf_images[dim_name], ("LR", "PA", "IS")))
209+
expected = orientation
210+
assert actual == expected
183211

184-
xfail("This test is currently broken!")
185212

186-
conf_img_size = 8
213+
def test_reorient_worldcoords_image(
214+
self,
215+
conf_images: MultiCoordImages,
216+
soft_orientation: OrientationType,
217+
vox_size: float,
218+
img_size: int,
219+
):
220+
"""
221+
This test checks, whether the world coordinates are consistent.
222+
"""
223+
from pytest import approx, xfail
187224

188-
ximg, yimg, zimg = [conform(img, **conform_reorient, img_size=conf_img_size) for img in worldcoord_images.values()]
189-
xyz = np.stack([ximg.get_fdata(), yimg.get_fdata(), zimg.get_fdata()], axis=-1)
190-
logger.info("Checking affines of world images:")
225+
xyz = np.stack([img.get_fdata() for img in conf_images], axis=-1)
226+
logger.info("Checking affines of world images:")
191227

192-
worldcoords = worldcoords_data(ximg.affine, ximg.shape[0])
193-
center_mask = np.pad(circle_data(ximg.shape[0] - 2), 1, constant_values=0)
194-
expected = np.where(center_mask[..., None].astype(np.bool_), worldcoords, np.NaN)
228+
worldcoords = self.worldcoords_data(conf_images[0].affine, conf_images[0].shape[0])
229+
# center_mask = np.pad(circle_data(conf_images[0] - 2), 1, constant_values=0)
230+
expected = np.where(center_mask[..., None].astype(np.bool_), worldcoords, np.NaN)
195231

196-
assert xyz == approx(expected, abs=0.1), "The world images "
232+
assert xyz == approx(expected, abs=0.1), "The world images "

test/image/test_orientation_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytest import approx
88

99
from FastSurferCNN.data_loader.conform import (
10-
affine_for_target_orientation,
10+
vox2vox_for_target_orientation,
1111
does_vox2vox_rot_require_interpolation,
1212
ornt2affine,
1313
)
@@ -87,7 +87,7 @@ def test_ornt2affine_data(strict_orientation: OrientationType, img_size: int):
8787

8888
def test_affine_for_target_orientation(random_affine: AffineMatrix4x4, img_size: int, orientation: OrientationType):
8989
"""Test whether affine_for_target_orientation works as expected."""
90-
vox2vox = affine_for_target_orientation(random_affine, orientation, (img_size,) * 3)
90+
vox2vox = vox2vox_for_target_orientation(random_affine, orientation, (img_size,) * 3)
9191
combined = random_affine @ vox2vox
9292
actual = "".join(aff2axcodes(combined, ("lr", "pa", "is")))
9393
expected = orientation.lower().removeprefix("soft").lstrip("-_ ")
@@ -116,7 +116,7 @@ def test_affine_for_target_orientation2(
116116
strict_orientation: StrictOrientationType,
117117
):
118118
"""Test whether affine_for_target_orientation works as expected."""
119-
vox2vox = affine_for_target_orientation(random_affine, "soft " + strict_orientation, (img_size,) * 3)
119+
vox2vox = vox2vox_for_target_orientation(random_affine, "soft " + strict_orientation, (img_size,) * 3)
120120
is_reorder_flip = not does_vox2vox_rot_require_interpolation(vox2vox)
121121
assert is_reorder_flip, "affine_for_target_orientation did not yield a \"soft\" vox2vox transformation!"
122122

0 commit comments

Comments
 (0)