We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d6e423c commit 55543a3Copy full SHA for 55543a3
1 file changed
aviary/cgcnn/data.py
@@ -249,18 +249,16 @@ def __init__(
249
250
self.var = var
251
252
- def expand(self, distances: np.ndarray) -> np.ndarray:
+ def expand(self, distances: Tensor) -> Tensor:
253
"""Apply Gaussian distance filter to a numpy distance array.
254
255
Args:
256
distances (ArrayLike): A distance matrix of any shape.
257
258
Returns:
259
- np.ndarray: Expanded distance matrix with the last dimension of length
+ Tensor: Expanded distance matrix with the last dimension of length
260
len(self.filter)
261
"""
262
- distances = torch.tensor(distances)
263
-
264
return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
265
266
0 commit comments