|
102 | 102 | " inversion, and conversions to various representations of SE(3).\n", |
103 | 103 | " \"\"\"\n", |
104 | 104 | "\n", |
| 105 | + " def __new__(cls, matrix, eps=1e-6):\n", |
| 106 | + " if isinstance(matrix, cls):\n", |
| 107 | + " return matrix\n", |
| 108 | + " return super().__new__(cls)\n", |
| 109 | + "\n", |
105 | 110 | " def __init__(self, matrix, eps=1e-6):\n", |
| 111 | + " if isinstance(matrix, type(self)):\n", |
| 112 | + " return \n", |
| 113 | + "\n", |
106 | 114 | " super().__init__()\n", |
107 | 115 | " if matrix.dim() == 2:\n", |
108 | 116 | " matrix = matrix.unsqueeze(0)\n", |
|
113 | 121 | " return len(self.matrix)\n", |
114 | 122 | "\n", |
115 | 123 | " def __getitem__(self, idx):\n", |
116 | | - " return self.matrix[idx]\n", |
| 124 | + " return type(self)(self.matrix[idx])\n", |
| 125 | + "\n", |
| 126 | + " def __matmul__(self, T):\n", |
| 127 | + " return T.compose(self)\n", |
117 | 128 | "\n", |
118 | 129 | " def forward(self, x):\n", |
119 | 130 | " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", |
|
137 | 148 | " matrix = make_matrix(Rinv, tinv)\n", |
138 | 149 | " else:\n", |
139 | 150 | " matrix = self.matrix.inverse()\n", |
140 | | - " return RigidTransform(matrix)\n", |
| 151 | + " return type(self)(matrix)\n", |
141 | 152 | "\n", |
142 | 153 | " def compose(self, T):\n", |
143 | 154 | " matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n", |
144 | | - " return RigidTransform(matrix)\n", |
| 155 | + " return type(self)(matrix)\n", |
145 | 156 | "\n", |
146 | 157 | " def convert(self, parameterization, convention=None, degrees=False):\n", |
147 | 158 | " translation = -self.inverse().translation\n", |
|
0 commit comments