Skip to content

Commit acaa68c

Browse files
authored
Add renderer compilation and gradient checkpointing (#384)
Thanks to @etienne87 for the idea!
1 parent e1cacdb commit acaa68c

2 files changed

Lines changed: 41 additions & 3 deletions

File tree

diffdrr/drr.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def __init__(
3939
patch_size: int | None = None, # Render patches of the DRR in series
4040
renderer: str = "siddon", # Rendering backend, either "siddon" or "trilinear"
4141
persistent: bool = True, # Set persistent value in `torch.nn.Module.register_buffer`
42+
compile_renderer: bool = False, # Compile the renderer for performance boost
43+
checkpoint_gradients: bool = False, # Checkpoint gradients to improve memory usage
4244
**renderer_kwargs, # Kwargs for the renderer
4345
):
4446
super().__init__()
@@ -96,8 +98,11 @@ def __init__(
9698
raise ValueError(
9799
f"renderer must be 'siddon' or 'trilinear', not {renderer}"
98100
)
101+
if compile_renderer:
102+
self.renderer = torch.compile(self.renderer, mode="default")
99103
self.reshape = reshape
100104
self.patch_size = patch_size
105+
self.checkpoint_gradients = checkpoint_gradients
101106

102107
def reshape_transform(self, img, batch_size):
103108
if self.reshape:
@@ -141,6 +146,8 @@ def reshape_subsampled_drr(img: torch.Tensor, detector: Detector, batch_size: in
141146
return drr
142147

143148
# %% ../notebooks/api/00_drr.ipynb 10
149+
from torch.utils.checkpoint import checkpoint
150+
144151
from .pose import RigidTransform, convert
145152

146153

@@ -163,7 +170,19 @@ def forward(
163170

164171
# Create the source / target points and render the image
165172
source, target = self.detector(pose, calibration)
166-
img = self.render(self.density, source, target, mask_to_channels, **kwargs)
173+
174+
if self.checkpoint_gradients:
175+
img = checkpoint(
176+
self.render,
177+
self.density,
178+
source,
179+
target,
180+
mask_to_channels,
181+
**kwargs,
182+
use_reentrant=False,
183+
)
184+
else:
185+
img = self.render(self.density, source, target, mask_to_channels, **kwargs)
167186
return self.reshape_transform(img, batch_size=len(pose))
168187

169188

notebooks/api/00_drr.ipynb

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@
134134
" patch_size: int | None = None, # Render patches of the DRR in series\n",
135135
" renderer: str = \"siddon\", # Rendering backend, either \"siddon\" or \"trilinear\"\n",
136136
" persistent: bool = True, # Set persistent value in `torch.nn.Module.register_buffer`\n",
137+
" compile_renderer: bool = False, # Compile the renderer for performance boost\n",
138+
" checkpoint_gradients: bool = False, # Checkpoint gradients to improve memory usage\n",
137139
" **renderer_kwargs, # Kwargs for the renderer\n",
138140
" ):\n",
139141
" super().__init__()\n",
@@ -191,8 +193,11 @@
191193
" raise ValueError(\n",
192194
" f\"renderer must be 'siddon' or 'trilinear', not {renderer}\"\n",
193195
" )\n",
196+
" if compile_renderer:\n",
197+
" self.renderer = torch.compile(self.renderer, mode=\"default\")\n",
194198
" self.reshape = reshape\n",
195199
" self.patch_size = patch_size\n",
200+
" self.checkpoint_gradients = checkpoint_gradients\n",
196201
"\n",
197202
" def reshape_transform(self, img, batch_size):\n",
198203
" if self.reshape:\n",
@@ -260,6 +265,8 @@
260265
"outputs": [],
261266
"source": [
262267
"#| export\n",
268+
"from torch.utils.checkpoint import checkpoint\n",
269+
"\n",
263270
"from diffdrr.pose import RigidTransform, convert\n",
264271
"\n",
265272
"\n",
@@ -282,7 +289,19 @@
282289
"\n",
283290
" # Create the source / target points and render the image\n",
284291
" source, target = self.detector(pose, calibration)\n",
285-
" img = self.render(self.density, source, target, mask_to_channels, **kwargs)\n",
292+
"\n",
293+
" if self.checkpoint_gradients:\n",
294+
" img = checkpoint(\n",
295+
" self.render,\n",
296+
" self.density,\n",
297+
" source,\n",
298+
" target,\n",
299+
" mask_to_channels,\n",
300+
" **kwargs,\n",
301+
" use_reentrant=False,\n",
302+
" )\n",
303+
" else:\n",
304+
" img = self.render(self.density, source, target, mask_to_channels, **kwargs)\n",
286305
" return self.reshape_transform(img, batch_size=len(pose))\n",
287306
"\n",
288307
"\n",
@@ -408,7 +427,7 @@
408427
" x[..., 1] = self.detector.height - x[..., 1]\n",
409428
" if self.detector.reverse_x_axis:\n",
410429
" x[..., 0] = self.detector.width - x[..., 0]\n",
411-
" \n",
430+
"\n",
412431
" return x[..., :2]"
413432
]
414433
},

0 commit comments

Comments
 (0)