File tree Expand file tree Collapse file tree
metatomic-torch/include/metatomic/torch
python/metatomic_torch/tests Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -91,8 +91,6 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
9191 // / set the unit of the output
9292 void set_unit (std::string unit);
9393
94- // / bool get_per_atom() const;
95-
9694 // / Which kind of sample this output is. E.g. "system", "atom", "atom_pair"...
9795 std::string sample_kind;
9896 const std::string& get_sample_kind () const {
@@ -102,7 +100,7 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
102100
103101 // For backward compatibility.
104102 void set_per_atom (bool per_atom);
105- bool per_atom () const ;
103+ bool get_per_atom () const ;
106104
107105 // / Which gradients should be computed eagerly and stored inside the output
108106 // / `TensorMap`
Original file line number Diff line number Diff line change 1313 System ,
1414)
1515
16+
1617def test_sample_kind ():
1718 """Checks that ``sample_kind`` and ``per_atom`` are always
1819 consistent with each other.
@@ -56,13 +57,14 @@ def test_sample_kind():
5657 with pytest .raises (ValueError ):
5758 ModelOutput (sample_kind = "invalid_value" )
5859
59- # Initialize model output with sample_kind="atom_pair"
60+ # Initialize model output with sample_kind="atom_pair"
6061 # and check that per_atom can not be retrieved
6162 output = ModelOutput (sample_kind = "atom_pair" )
6263 assert output .sample_kind == "atom_pair"
6364 with pytest .raises (ValueError ):
6465 _ = output .per_atom
6566
67+
6668@pytest .fixture
6769def system ():
6870 return System (
You can’t perform that action at this time.
0 commit comments