Skip to content

Commit 55543a3

Browse files
committed
fix: avoid tensor copy warning
1 parent d6e423c commit 55543a3

1 file changed

Lines changed: 2 additions & 4 deletions

File tree

aviary/cgcnn/data.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,16 @@ def __init__(
249249

250250
self.var = var
251251

252-
def expand(self, distances: np.ndarray) -> np.ndarray:
252+
def expand(self, distances: Tensor) -> Tensor:
253253
"""Apply Gaussian distance filter to a numpy distance array.
254254
255255
Args:
256256
distances (ArrayLike): A distance matrix of any shape.
257257
258258
Returns:
259-
np.ndarray: Expanded distance matrix with the last dimension of length
259+
Tensor: Expanded distance matrix with the last dimension of length
260260
len(self.filter)
261261
"""
262-
distances = torch.tensor(distances)
263-
264262
return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
265263

266264

0 commit comments

Comments
 (0)