Skip to content

Commit c4c7227

Browse files
authored
Various utilities for RigidTransforms (#404)
* If input is RigidTransform, return input * Remove hardcoded class name * Make slicing return a RigidTransform * Add a matmul overload
1 parent 5587c76 commit c4c7227

3 files changed

Lines changed: 30 additions & 6 deletions

File tree

diffdrr/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@
8787
'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'),
8888
'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'),
8989
'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'),
90+
'diffdrr.pose.RigidTransform.__matmul__': ('api/pose.html#rigidtransform.__matmul__', 'diffdrr/pose.py'),
91+
'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'),
9092
'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'),
9193
'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'),
9294
'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'),

diffdrr/pose.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,15 @@ class RigidTransform(torch.nn.Module):
1818
inversion, and conversions to various representations of SE(3).
1919
"""
2020

21+
def __new__(cls, matrix, eps=1e-6):
22+
if isinstance(matrix, cls):
23+
return matrix
24+
return super().__new__(cls)
25+
2126
def __init__(self, matrix, eps=1e-6):
27+
if isinstance(matrix, type(self)):
28+
return
29+
2230
super().__init__()
2331
if matrix.dim() == 2:
2432
matrix = matrix.unsqueeze(0)
@@ -29,7 +37,10 @@ def __len__(self):
2937
return len(self.matrix)
3038

3139
def __getitem__(self, idx):
32-
return self.matrix[idx]
40+
return type(self)(self.matrix[idx])
41+
42+
def __matmul__(self, T):
43+
return T.compose(self)
3344

3445
def forward(self, x):
3546
"""Apply (a batch) of rigid transforms to a pointcloud."""
@@ -53,11 +64,11 @@ def inverse(self):
5364
matrix = make_matrix(Rinv, tinv)
5465
else:
5566
matrix = self.matrix.inverse()
56-
return RigidTransform(matrix)
67+
return type(self)(matrix)
5768

5869
def compose(self, T):
5970
matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix)
60-
return RigidTransform(matrix)
71+
return type(self)(matrix)
6172

6273
def convert(self, parameterization, convention=None, degrees=False):
6374
translation = -self.inverse().translation

notebooks/api/06_pose.ipynb

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,15 @@
102102
" inversion, and conversions to various representations of SE(3).\n",
103103
" \"\"\"\n",
104104
"\n",
105+
" def __new__(cls, matrix, eps=1e-6):\n",
106+
" if isinstance(matrix, cls):\n",
107+
" return matrix\n",
108+
" return super().__new__(cls)\n",
109+
"\n",
105110
" def __init__(self, matrix, eps=1e-6):\n",
111+
" if isinstance(matrix, type(self)):\n",
112+
" return \n",
113+
"\n",
106114
" super().__init__()\n",
107115
" if matrix.dim() == 2:\n",
108116
" matrix = matrix.unsqueeze(0)\n",
@@ -113,7 +121,10 @@
113121
" return len(self.matrix)\n",
114122
"\n",
115123
" def __getitem__(self, idx):\n",
116-
" return self.matrix[idx]\n",
124+
" return type(self)(self.matrix[idx])\n",
125+
"\n",
126+
" def __matmul__(self, T):\n",
127+
" return T.compose(self)\n",
117128
"\n",
118129
" def forward(self, x):\n",
119130
" \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n",
@@ -137,11 +148,11 @@
137148
" matrix = make_matrix(Rinv, tinv)\n",
138149
" else:\n",
139150
" matrix = self.matrix.inverse()\n",
140-
" return RigidTransform(matrix)\n",
151+
" return type(self)(matrix)\n",
141152
"\n",
142153
" def compose(self, T):\n",
143154
" matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n",
144-
" return RigidTransform(matrix)\n",
155+
" return type(self)(matrix)\n",
145156
"\n",
146157
" def convert(self, parameterization, convention=None, degrees=False):\n",
147158
" translation = -self.inverse().translation\n",

0 commit comments

Comments
 (0)