|
134 | 134 | " volume,\n", |
135 | 135 | " source,\n", |
136 | 136 | " target,\n", |
| 137 | + " img,\n", |
137 | 138 | " align_corners=False,\n", |
138 | 139 | " mask=None,\n", |
139 | 140 | " ):\n", |
|
158 | 159 | " # Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel\n", |
159 | 160 | " if self.stop_gradients_through_grid_sample:\n", |
160 | 161 | " with torch.no_grad():\n", |
161 | | - " img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n", |
| 162 | + " img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n", |
162 | 163 | " else:\n", |
163 | | - " img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n", |
| 164 | + " img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n", |
164 | 165 | "\n", |
165 | 166 | " # Weight each intersected voxel by the length of the ray's intersection with the voxel\n", |
166 | 167 | " intersection_length = torch.diff(alphas, dim=-1)\n", |
|
176 | 177 | " B, D, _ = img.shape\n", |
177 | 178 | " C = int(mask.max().item() + 1)\n", |
178 | 179 | " channels = _get_voxel(\n", |
179 | | - " mask, xyzs, self.mode, align_corners=align_corners\n", |
| 180 | + " mask, xyzs, img=None, mode=self.mode, align_corners=align_corners\n", |
180 | 181 | " ).long()\n", |
181 | 182 | " img = (\n", |
182 | 183 | " torch.zeros(B, C, D)\n", |
|
253 | 254 | " return xyzs\n", |
254 | 255 | "\n", |
255 | 256 | "\n", |
256 | | - "def _get_voxel(volume, xyzs, mode, align_corners):\n", |
| 257 | + "def _get_voxel(volume, xyzs, img, mode, align_corners):\n", |
257 | 258 | " \"\"\"Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates.\"\"\"\n", |
258 | 259 | " batch_size = len(xyzs)\n", |
259 | 260 | " voxels = grid_sample(\n", |
|
262 | 263 | " mode=mode,\n", |
263 | 264 | " align_corners=align_corners,\n", |
264 | 265 | " )[:, 0, 0]\n", |
265 | | - " return voxels" |
| 266 | + " if img is not None:\n", |
| 267 | + " img = torch.einsum(\"bcn, bnj -> bnj\", img, voxels)\n", |
| 268 | + " else:\n", |
| 269 | + " img = voxels\n", |
| 270 | + " return img" |
266 | 271 | ] |
267 | 272 | }, |
268 | 273 | { |
|
307 | 312 | " volume,\n", |
308 | 313 | " source,\n", |
309 | 314 | " target,\n", |
| 315 | + " img,\n", |
310 | 316 | " n_points=500,\n", |
311 | 317 | " align_corners=False,\n", |
312 | 318 | " mask=None,\n", |
|
328 | 334 | " xyzs = _get_xyzs(alphas, source, target, dims, self.eps)\n", |
329 | 335 | "\n", |
330 | 336 | " # Sample the volume with trilinear interpolation\n", |
331 | | - " img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n", |
| 337 | + " img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n", |
332 | 338 | " \n", |
333 | 339 | " # Multiply by the step size to compute the rectangular rule for integration\n", |
334 | 340 | " step_size = (alphamax - alphamin) / (n_points - 1)\n", |
|
341 | 347 | " B, D, _ = img.shape\n", |
342 | 348 | " C = int(mask.max().item() + 1)\n", |
343 | 349 | " channels = _get_voxel(\n", |
344 | | - " mask, xyzs, self.mode, align_corners=align_corners\n", |
| 350 | + " mask, xyzs, img=None, mode=self.mode, align_corners=align_corners\n", |
345 | 351 | " ).long()\n", |
346 | 352 | " img = (\n", |
347 | 353 | " torch.zeros(B, C, D)\n", |
|
0 commit comments