Skip to content

Commit c73b0ad

Browse files
committed
forgot to undo some changes
1 parent 80191c2 commit c73b0ad

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
9191
/// set the unit of the output
9292
void set_unit(std::string unit);
9393

94-
/// bool get_per_atom() const;
95-
9694
/// Which kind of sample this output is. E.g. "system", "atom", "atom_pair"...
9795
std::string sample_kind;
9896
const std::string& get_sample_kind() const {
@@ -102,7 +100,7 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
102100

103101
// For backward compatibility.
104102
void set_per_atom(bool per_atom);
105-
bool per_atom() const;
103+
bool get_per_atom() const;
106104

107105
/// Which gradients should be computed eagerly and stored inside the output
108106
/// `TensorMap`

python/metatomic_torch/tests/outputs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
System,
1414
)
1515

16+
1617
def test_sample_kind():
1718
"""Checks that ``sample_kind`` and ``per_atom`` are always
1819
consistent with each other.
@@ -56,13 +57,14 @@ def test_sample_kind():
5657
with pytest.raises(ValueError):
5758
ModelOutput(sample_kind="invalid_value")
5859

59-
# Initialize model output with sample_kind="atom_pair"
60+
# Initialize model output with sample_kind="atom_pair"
6061
# and check that per_atom can not be retrieved
6162
output = ModelOutput(sample_kind="atom_pair")
6263
assert output.sample_kind == "atom_pair"
6364
with pytest.raises(ValueError):
6465
_ = output.per_atom
6566

67+
6668
@pytest.fixture
6769
def system():
6870
return System(

0 commit comments

Comments
 (0)