Skip to content

Commit de6b9aa

Browse files
authored
Refactor rendering computation (#329)
1 parent 7fdf286 commit de6b9aa

5 files changed

Lines changed: 96 additions & 41 deletions

File tree

diffdrr/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
3131
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
3232
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
33+
'diffdrr.drr.DRR.render': ('api/drr.html#drr.render', 'diffdrr/drr.py'),
3334
'diffdrr.drr.DRR.rescale_detector_': ('api/drr.html#drr.rescale_detector_', 'diffdrr/drr.py'),
3435
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
3536
'diffdrr.drr.DRR.set_intrinsics_': ('api/drr.html#drr.set_intrinsics_', 'diffdrr/drr.py'),

diffdrr/drr.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ def reshape_transform(self, img, batch_size):
105105
if self.reshape:
106106
if self.detector.n_subsample is None:
107107
img = img.view(
108-
batch_size, -1, self.detector.height, self.detector.width
108+
batch_size,
109+
-1,
110+
self.detector.height,
111+
self.detector.width,
109112
)
110113
else:
111114
img = reshape_subsampled_drr(img, self.detector, batch_size)
@@ -147,37 +150,54 @@ def forward(
147150
pose = args[0]
148151
else:
149152
pose = convert(*args, parameterization=parameterization, convention=convention)
153+
154+
# Create the source / target points and render the image
150155
source, target = self.detector(pose, calibration)
156+
img = self.render(self.density, source, target, mask_to_channels, **kwargs)
157+
return self.reshape_transform(img, batch_size=len(pose))
158+
159+
160+
@patch
161+
def render(
162+
self: DRR,
163+
density: torch.tensor,
164+
source: torch.tensor,
165+
target: torch.tensor,
166+
mask_to_channels: bool,
167+
**kwargs,
168+
):
169+
# Initialize the image with the length of each cast ray
170+
img = (target - source).norm(dim=-1).unsqueeze(1)
171+
172+
# Convert rays to voxelspace
151173
source = self.affine_inverse(source)
152174
target = self.affine_inverse(target)
153175

154-
# Render the DRR
176+
# Render the image
155177
kwargs["mask"] = self.mask if mask_to_channels else None
156178
if self.patch_size is None:
157179
img = self.renderer(
158-
self.density,
180+
density,
159181
source,
160182
target,
183+
img,
161184
**kwargs,
162185
)
163186
else:
164187
n_points = target.shape[1] // self.n_patches
165-
img = []
188+
partials = []
166189
for idx in range(self.n_patches):
167-
t = target[:, idx * n_points : (idx + 1) * n_points]
168190
partial = self.renderer(
169-
self.density,
191+
density,
170192
source,
171-
t,
193+
target[:, idx * n_points : (idx + 1) * n_points],
194+
img[:, idx * n_points : (idx + 1) * n_points],
172195
**kwargs,
173196
)
174-
img.append(partial)
175-
img = torch.cat(img, dim=-1)
197+
partials.append(partial)
198+
img = torch.cat(partials, dim=-1)
176199

177-
# Multiply by the raylength (in world coordinate units)
178-
img *= self.affine(target - source).norm(dim=-1).unsqueeze(1)
179-
180-
return self.reshape_transform(img, batch_size=len(pose))
200+
return img
181201

182202
# %% ../notebooks/api/00_drr.ipynb 11
183203
@patch

diffdrr/renderers.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def forward(
3232
volume,
3333
source,
3434
target,
35+
img,
3536
align_corners=False,
3637
mask=None,
3738
):
@@ -56,9 +57,11 @@ def forward(
5657
# Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel
5758
if self.stop_gradients_through_grid_sample:
5859
with torch.no_grad():
59-
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
60+
img = _get_voxel(
61+
volume, xyzs, img, self.mode, align_corners=align_corners
62+
)
6063
else:
61-
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
64+
img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)
6265

6366
# Weight each intersected voxel by the length of the ray's intersection with the voxel
6467
intersection_length = torch.diff(alphas, dim=-1)
@@ -74,7 +77,7 @@ def forward(
7477
B, D, _ = img.shape
7578
C = int(mask.max().item() + 1)
7679
channels = _get_voxel(
77-
mask, xyzs, self.mode, align_corners=align_corners
80+
mask, xyzs, img=None, mode=self.mode, align_corners=align_corners
7881
).long()
7982
img = (
8083
torch.zeros(B, C, D)
@@ -144,7 +147,7 @@ def _get_xyzs(alpha, source, target, dims, eps):
144147
return xyzs
145148

146149

147-
def _get_voxel(volume, xyzs, mode, align_corners):
150+
def _get_voxel(volume, xyzs, img, mode, align_corners):
148151
"""Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates."""
149152
batch_size = len(xyzs)
150153
voxels = grid_sample(
@@ -153,7 +156,11 @@ def _get_voxel(volume, xyzs, mode, align_corners):
153156
mode=mode,
154157
align_corners=align_corners,
155158
)[:, 0, 0]
156-
return voxels
159+
if img is not None:
160+
img = torch.einsum("bcn, bnj -> bnj", img, voxels)
161+
else:
162+
img = voxels
163+
return img
157164

158165
# %% ../notebooks/api/01_renderers.ipynb 10
159166
class Trilinear(torch.nn.Module):
@@ -176,6 +183,7 @@ def forward(
176183
volume,
177184
source,
178185
target,
186+
img,
179187
n_points=500,
180188
align_corners=False,
181189
mask=None,
@@ -197,7 +205,7 @@ def forward(
197205
xyzs = _get_xyzs(alphas, source, target, dims, self.eps)
198206

199207
# Sample the volume with trilinear interpolation
200-
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
208+
img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)
201209

202210
# Multiply by the step size to compute the rectangular rule for integration
203211
step_size = (alphamax - alphamin) / (n_points - 1)
@@ -210,7 +218,7 @@ def forward(
210218
B, D, _ = img.shape
211219
C = int(mask.max().item() + 1)
212220
channels = _get_voxel(
213-
mask, xyzs, self.mode, align_corners=align_corners
221+
mask, xyzs, img=None, mode=self.mode, align_corners=align_corners
214222
).long()
215223
img = (
216224
torch.zeros(B, C, D)

notebooks/api/00_drr.ipynb

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,15 @@
195195
" self.patch_size = patch_size\n",
196196
" if self.patch_size is not None:\n",
197197
" self.n_patches = (height * width) // (self.patch_size**2)\n",
198-
" \n",
198+
"\n",
199199
" def reshape_transform(self, img, batch_size):\n",
200200
" if self.reshape:\n",
201201
" if self.detector.n_subsample is None:\n",
202202
" img = img.view(\n",
203-
" batch_size, -1, self.detector.height, self.detector.width\n",
203+
" batch_size,\n",
204+
" -1,\n",
205+
" self.detector.height,\n",
206+
" self.detector.width,\n",
204207
" )\n",
205208
" else:\n",
206209
" img = reshape_subsampled_drr(img, self.detector, batch_size)\n",
@@ -266,37 +269,54 @@
266269
" pose = args[0]\n",
267270
" else:\n",
268271
" pose = convert(*args, parameterization=parameterization, convention=convention)\n",
272+
"\n",
273+
" # Create the source / target points and render the image\n",
269274
" source, target = self.detector(pose, calibration)\n",
275+
" img = self.render(self.density, source, target, mask_to_channels, **kwargs)\n",
276+
" return self.reshape_transform(img, batch_size=len(pose))\n",
277+
"\n",
278+
"\n",
279+
"@patch\n",
280+
"def render(\n",
281+
" self: DRR,\n",
282+
" density: torch.tensor,\n",
283+
" source: torch.tensor,\n",
284+
" target: torch.tensor,\n",
285+
" mask_to_channels: bool,\n",
286+
" **kwargs,\n",
287+
"):\n",
288+
" # Initialize the image with the length of each cast ray\n",
289+
" img = (target - source).norm(dim=-1).unsqueeze(1)\n",
290+
"\n",
291+
" # Convert rays to voxelspace\n",
270292
" source = self.affine_inverse(source)\n",
271293
" target = self.affine_inverse(target)\n",
272294
"\n",
273-
" # Render the DRR\n",
295+
" # Render the image\n",
274296
" kwargs[\"mask\"] = self.mask if mask_to_channels else None\n",
275297
" if self.patch_size is None:\n",
276298
" img = self.renderer(\n",
277-
" self.density,\n",
299+
" density,\n",
278300
" source,\n",
279301
" target,\n",
302+
" img,\n",
280303
" **kwargs,\n",
281304
" )\n",
282305
" else:\n",
283306
" n_points = target.shape[1] // self.n_patches\n",
284-
" img = []\n",
307+
" partials = []\n",
285308
" for idx in range(self.n_patches):\n",
286-
" t = target[:, idx * n_points : (idx + 1) * n_points]\n",
287309
" partial = self.renderer(\n",
288-
" self.density,\n",
310+
" density,\n",
289311
" source,\n",
290-
" t,\n",
312+
" target[:, idx * n_points : (idx + 1) * n_points],\n",
313+
" img[:, idx * n_points : (idx + 1) * n_points],\n",
291314
" **kwargs,\n",
292315
" )\n",
293-
" img.append(partial)\n",
294-
" img = torch.cat(img, dim=-1)\n",
295-
" \n",
296-
" # Multiply by the raylength (in world coordinate units)\n",
297-
" img *= self.affine(target - source).norm(dim=-1).unsqueeze(1)\n",
316+
" partials.append(partial)\n",
317+
" img = torch.cat(partials, dim=-1)\n",
298318
"\n",
299-
" return self.reshape_transform(img, batch_size=len(pose))"
319+
" return img"
300320
]
301321
},
302322
{

notebooks/api/01_renderers.ipynb

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
" volume,\n",
135135
" source,\n",
136136
" target,\n",
137+
" img,\n",
137138
" align_corners=False,\n",
138139
" mask=None,\n",
139140
" ):\n",
@@ -158,9 +159,9 @@
158159
" # Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel\n",
159160
" if self.stop_gradients_through_grid_sample:\n",
160161
" with torch.no_grad():\n",
161-
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
162+
" img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n",
162163
" else:\n",
163-
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
164+
" img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n",
164165
"\n",
165166
" # Weight each intersected voxel by the length of the ray's intersection with the voxel\n",
166167
" intersection_length = torch.diff(alphas, dim=-1)\n",
@@ -176,7 +177,7 @@
176177
" B, D, _ = img.shape\n",
177178
" C = int(mask.max().item() + 1)\n",
178179
" channels = _get_voxel(\n",
179-
" mask, xyzs, self.mode, align_corners=align_corners\n",
180+
" mask, xyzs, img=None, mode=self.mode, align_corners=align_corners\n",
180181
" ).long()\n",
181182
" img = (\n",
182183
" torch.zeros(B, C, D)\n",
@@ -253,7 +254,7 @@
253254
" return xyzs\n",
254255
"\n",
255256
"\n",
256-
"def _get_voxel(volume, xyzs, mode, align_corners):\n",
257+
"def _get_voxel(volume, xyzs, img, mode, align_corners):\n",
257258
" \"\"\"Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates.\"\"\"\n",
258259
" batch_size = len(xyzs)\n",
259260
" voxels = grid_sample(\n",
@@ -262,7 +263,11 @@
262263
" mode=mode,\n",
263264
" align_corners=align_corners,\n",
264265
" )[:, 0, 0]\n",
265-
" return voxels"
266+
" if img is not None:\n",
267+
" img = torch.einsum(\"bcn, bnj -> bnj\", img, voxels)\n",
268+
" else:\n",
269+
" img = voxels\n",
270+
" return img"
266271
]
267272
},
268273
{
@@ -307,6 +312,7 @@
307312
" volume,\n",
308313
" source,\n",
309314
" target,\n",
315+
" img,\n",
310316
" n_points=500,\n",
311317
" align_corners=False,\n",
312318
" mask=None,\n",
@@ -328,7 +334,7 @@
328334
" xyzs = _get_xyzs(alphas, source, target, dims, self.eps)\n",
329335
"\n",
330336
" # Sample the volume with trilinear interpolation\n",
331-
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
337+
" img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n",
332338
" \n",
333339
" # Multiply by the step size to compute the rectangular rule for integration\n",
334340
" step_size = (alphamax - alphamin) / (n_points - 1)\n",
@@ -341,7 +347,7 @@
341347
" B, D, _ = img.shape\n",
342348
" C = int(mask.max().item() + 1)\n",
343349
" channels = _get_voxel(\n",
344-
" mask, xyzs, self.mode, align_corners=align_corners\n",
350+
" mask, xyzs, img=None, mode=self.mode, align_corners=align_corners\n",
345351
" ).long()\n",
346352
" img = (\n",
347353
" torch.zeros(B, C, D)\n",

0 commit comments

Comments
 (0)