Skip to content

Commit b1f3e52

Browse files
committed
Now working with the conda lammps
1 parent 641998c commit b1f3e52

7 files changed

Lines changed: 110 additions & 64 deletions

File tree

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,23 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
4949
std::string quantity,
5050
std::string unit,
5151
torch::optional<bool> per_atom_,
52-
torch::optional<std::string> sample_kind_,
5352
std::vector<std::string> explicit_gradients_,
54-
std::string description_
53+
std::string description_,
54+
torch::optional<std::string> sample_kind = torch::nullopt
5555
):
5656
description(std::move(description_)),
5757
explicit_gradients(std::move(explicit_gradients_))
5858
{
5959
this->set_quantity(std::move(quantity));
6060
this->set_unit(std::move(unit));
6161

62-
if (per_atom_.has_value() && sample_kind_.has_value()) {
62+
if (per_atom_.has_value() && sample_kind.has_value()) {
6363
C10_THROW_ERROR(ValueError, "Cannot set both `per_atom` and `sample_kind` for a ModelOutput");
64-
} else if (per_atom_.has_value()) {
65-
this->set_per_atom(std::move(per_atom_.value()));
64+
} else if (sample_kind.has_value()) {
65+
this->set_sample_kind(sample_kind.value());
6666
} else {
67-
this->set_sample_kind(std::move(sample_kind_.value_or("system")));
68-
}
67+
this->set_per_atom(per_atom_.value_or(false));
68+
}
6969

7070
}
7171

@@ -91,16 +91,43 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
9191
/// set the unit of the output
9292
void set_unit(std::string unit);
9393

94-
/// Which kind of sample this output is. E.g. "system", "atom", "atom_pair"...
95-
std::string sample_kind;
96-
const std::string& get_sample_kind() const {
97-
return sample_kind;
94+
/// Although we are moving to using `sample_kind` instead of `per_atom`,
95+
/// for now we need to keep the `per_atom` data member for backward
96+
/// compatibility in the C++ API and ABI.
97+
bool per_atom;
98+
99+
/// The setter and getter for `per_atom` that are used in python, which
100+
/// allow us to raise an error if `sample_kind` can't be mapped to
101+
/// a boolean value for `per_atom`.
102+
void set_per_atom(bool per_atom_) {
103+
this->per_atom = per_atom_;
104+
/// Unset sample_kind
105+
this->sample_kind_ = torch::nullopt;
106+
}
107+
bool get_per_atom() const {
108+
if (sample_kind_.has_value()) {
109+
C10_THROW_ERROR(
110+
ValueError,
111+
"Can't infer `per_atom` from `sample_kind` '" + this->sample_kind() + "'. "
112+
"`per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
113+
);
114+
}
115+
return per_atom;
116+
}
117+
118+
/// Sample kind of the output. For now it relies on the `per_atom` boolean for
119+
/// backward compatibility, but in the future we will simply have a variable
120+
/// storing the sample kind.
121+
const std::string sample_kind() const {
122+
if (sample_kind_.has_value()) {
123+
return sample_kind_.value();
124+
} else if (per_atom) {
125+
return "atom";
126+
} else {
127+
return "system";
128+
}
98129
}
99130
void set_sample_kind(std::string sample_kind);
100-
101-
// For backward compatibility.
102-
void set_per_atom(bool per_atom);
103-
bool get_per_atom() const;
104131

105132
/// Which gradients should be computed eagerly and stored inside the output
106133
/// `TensorMap`
@@ -114,6 +141,12 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
114141
private:
115142
std::string quantity_;
116143
std::string unit_;
144+
torch::optional<std::string> sample_kind_;
145+
std::vector<std::string> supported_sample_kinds_ = {
146+
"system",
147+
"atom",
148+
"atom_pair",
149+
};
117150
};
118151

119152

metatomic-torch/src/model.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -66,49 +66,48 @@ void ModelOutputHolder::set_unit(std::string unit) {
6666
this->unit_ = std::move(unit);
6767
}
6868

69-
void ModelOutputHolder::set_sample_kind(std::string sample_kind) {
70-
/// Sample kind has to be one of "system", "atom", "atom_pair"
71-
if (sample_kind != "system" && sample_kind != "atom" && sample_kind != "atom_pair") {
72-
C10_THROW_ERROR(ValueError,
73-
"invalid sample kind '" + sample_kind + "'. Only 'system', 'atom' and 'atom_pair' are supported"
74-
);
69+
std::string join_strings(const std::vector<std::string>& strings, const std::string& separator) {
70+
std::ostringstream oss;
71+
for (size_t i = 0; i < strings.size(); ++i) {
72+
if (i > 0) {
73+
oss << separator;
74+
}
75+
oss << strings[i];
7576
}
76-
77-
this->sample_kind = std::move(sample_kind);
77+
return oss.str();
7878
}
7979

80-
/// For backward compatibility, we keep the `per_atom` property, which
81-
/// is not really stored, it is simply derived from `sample_kind`.
82-
bool ModelOutputHolder::get_per_atom() const {
80+
void ModelOutputHolder::set_sample_kind(std::string sample_kind) {
8381
if (sample_kind == "atom") {
84-
return true;
82+
this->set_per_atom(true);
8583
} else if (sample_kind == "system") {
86-
return false;
84+
this->set_per_atom(false);
8785
} else {
88-
C10_THROW_ERROR( ValueError,
89-
"Cannot determine `per_atom` from sample_kind '" + sample_kind + "'. "
90-
);
91-
}
92-
}
86+
/// If sample_kind is not a value that can be mapped to per_atom,
87+
/// we just store the value in the sample_kind_ private field.
9388

94-
/// If the user sets `per_atom`, we update the `sample_kind` accordingly.
95-
void ModelOutputHolder::set_per_atom(bool per_atom) {
96-
if (per_atom) {
97-
this->sample_kind = "atom";
98-
} else {
99-
this->sample_kind = "system";
89+
// Warn if the sample_kind is not one of the supported ones.
90+
if (std::find(supported_sample_kinds_.begin(), supported_sample_kinds_.end(), sample_kind) == supported_sample_kinds_.end()) {
91+
TORCH_WARN(
92+
"Sample_kind '", sample_kind, "' is not officially supported. ",
93+
"This means that metatomic doesn't natively understand how to deal ",
94+
"with such outputs. If this is a mistake, pass one of the supported ",
95+
"sample kinds instead: [", join_strings(supported_sample_kinds_, ", "), "]. "
96+
);
97+
}
98+
99+
this->sample_kind_ = std::move(sample_kind);
100100
}
101101
}
102102

103-
104-
105103
static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
106104
nlohmann::json result;
107105

108106
result["class"] = "ModelOutput";
109107
result["quantity"] = self.quantity();
110108
result["unit"] = self.unit();
111-
result["sample_kind"] = self.sample_kind;
109+
result["per_atom"] = self.per_atom;
110+
result["sample_kind"] = self.sample_kind();
112111
result["explicit_gradients"] = self.explicit_gradients;
113112
result["description"] = self.description;
114113

metatomic-torch/src/outputs.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,18 @@ static void validate_atomic_samples(
8686

8787
// Check if the samples names are as expected based on the sample_kind
8888
std::vector<std::string> expected_samples_names;
89-
if (request->sample_kind == "atom") {
89+
if (request->sample_kind() == "atom") {
9090
expected_samples_names = {"system", "atom"};
91-
} else if (request->sample_kind == "atom_pair") {
92-
expected_samples_names = {"system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"};
93-
} else {
91+
} else if (request->sample_kind() == "system") {
9492
expected_samples_names = {"system"};
93+
} else if (request->sample_kind() == "atom_pair") {
94+
expected_samples_names = {"system", "first_atom", "second_atom"};
95+
} else {
96+
C10_THROW_ERROR(ValueError,
97+
"Metatomic does not support validating samples for sample_kind"
98+
"other than 'system', 'atom' or 'atom_pair' at the moment."
99+
" Received sample_kind '" + request->sample_kind()
100+
);
95101
}
96102

97103
if (block->samples()->names() != expected_samples_names) {
@@ -104,7 +110,7 @@ static void validate_atomic_samples(
104110

105111
// Check if the samples match the systems and selected_atoms
106112
Labels expected_samples;
107-
if (request->sample_kind == "atom") {
113+
if (request->sample_kind() == "atom") {
108114
std::vector<int64_t> expected_values_flat;
109115
for (size_t s; s < systems.size(); s++) {
110116
for (size_t a; a < systems[s]->size(); a++) {
@@ -123,7 +129,7 @@ static void validate_atomic_samples(
123129
if (selected_atoms) {
124130
expected_samples = expected_samples->set_intersection(selected_atoms.value());
125131
}
126-
} else if (request->sample_kind == "system") {
132+
} else if (request->sample_kind() == "system") {
127133
expected_samples = torch::make_intrusive<LabelsHolder>(
128134
"system",
129135
torch::arange(static_cast<int64_t>(systems.size()), tensor_options).reshape({-1, 1}),
@@ -139,6 +145,9 @@ static void validate_atomic_samples(
139145
);
140146
expected_samples = expected_samples->set_intersection(selected_systems);
141147
}
148+
} else {
149+
/// We don't validate values for other cases for now
150+
return;
142151
}
143152

144153
if (expected_samples->set_union(block->samples())->size() != expected_samples->size()) {

metatomic-torch/src/register.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,24 +154,24 @@ TORCH_LIBRARY(metatomic, m) {
154154
std::string,
155155
std::string,
156156
torch::optional<bool>,
157-
torch::optional<std::string>,
158157
std::vector<std::string>,
159-
std::string
158+
std::string,
159+
torch::optional<std::string>
160160
>(),
161161
DOCSTRING, {
162162
torch::arg("quantity") = "",
163163
torch::arg("unit") = "",
164164
torch::arg("per_atom") = std::nullopt,
165-
torch::arg("sample_kind") = std::nullopt,
166165
torch::arg("explicit_gradients") = std::vector<std::string>(),
167166
torch::arg("description") = "",
167+
torch::arg("sample_kind") = std::nullopt,
168168
}
169169
)
170170
.def_readwrite("description", &ModelOutputHolder::description)
171171
.def_property("quantity", &ModelOutputHolder::quantity, &ModelOutputHolder::set_quantity)
172172
.def_property("unit", &ModelOutputHolder::unit, &ModelOutputHolder::set_unit)
173173
.def_property("per_atom", &ModelOutputHolder::get_per_atom, &ModelOutputHolder::set_per_atom)
174-
.def_property("sample_kind", &ModelOutputHolder::get_sample_kind, &ModelOutputHolder::set_sample_kind)
174+
.def_property("sample_kind", &ModelOutputHolder::sample_kind, &ModelOutputHolder::set_sample_kind)
175175
.def_readwrite("explicit_gradients", &ModelOutputHolder::explicit_gradients)
176176
.def_pickle(
177177
[](const ModelOutput& self) -> std::string {

metatomic-torch/tests/models.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ TEST_CASE("Models metadata") {
6565
output->description = "my awesome energy";
6666
output->set_quantity("energy");
6767
output->set_unit("kJ / mol");
68-
output->set_per_atom(false);
68+
output->per_atom = false;
6969
output->explicit_gradients = {"baz", "not.this-one_"};
7070

7171
const auto* expected = R"({
@@ -75,6 +75,7 @@ TEST_CASE("Models metadata") {
7575
"baz",
7676
"not.this-one_"
7777
],
78+
"per_atom": false,
7879
"quantity": "energy",
7980
"sample_kind": "system",
8081
"unit": "kJ / mol"
@@ -90,7 +91,7 @@ TEST_CASE("Models metadata") {
9091
output = ModelOutputHolder::from_json(json);
9192
CHECK(output->quantity() == "length");
9293
CHECK(output->unit().empty());
93-
CHECK(output->get_per_atom() == false);
94+
CHECK(output->per_atom == false);
9495
CHECK(output->explicit_gradients.empty());
9596

9697
CHECK_THROWS_WITH(
@@ -132,11 +133,11 @@ TEST_CASE("Models metadata") {
132133
options->set_length_unit("nanometer");
133134

134135
auto output1 = torch::make_intrusive<ModelOutputHolder>();
135-
output1->set_per_atom(false);
136+
output1->per_atom = false;
136137
options->outputs.insert("output_1", output1);
137138

138139
auto output2 = torch::make_intrusive<ModelOutputHolder>();
139-
output2->set_per_atom(true);
140+
output2->per_atom = true;
140141
output2->set_quantity("something");
141142
output2->set_unit("something");
142143
options->outputs.insert("output_2", output2);
@@ -149,6 +150,7 @@ TEST_CASE("Models metadata") {
149150
"class": "ModelOutput",
150151
"description": "",
151152
"explicit_gradients": [],
153+
"per_atom": false,
152154
"quantity": "",
153155
"sample_kind": "system",
154156
"unit": ""
@@ -157,6 +159,7 @@ TEST_CASE("Models metadata") {
157159
"class": "ModelOutput",
158160
"description": "",
159161
"explicit_gradients": [],
162+
"per_atom": true,
160163
"quantity": "something",
161164
"sample_kind": "atom",
162165
"unit": "something"
@@ -194,7 +197,7 @@ TEST_CASE("Models metadata") {
194197
auto output = options->outputs.at("foo");
195198
CHECK(output->quantity().empty());
196199
CHECK(output->unit().empty());
197-
CHECK(output->get_per_atom() == false);
200+
CHECK(output->per_atom == false);
198201
CHECK(output->explicit_gradients == std::vector<std::string>{"test"});
199202

200203
CHECK_THROWS_WITH(
@@ -221,7 +224,7 @@ TEST_CASE("Models metadata") {
221224
capabilities->supported_devices = {"cuda", "xla", "cpu"};
222225

223226
auto output = torch::make_intrusive<ModelOutputHolder>();
224-
output->set_per_atom(true);
227+
output->per_atom = true;
225228
output->set_quantity("length");
226229
output->explicit_gradients.emplace_back("µ-λ");
227230

@@ -246,6 +249,7 @@ TEST_CASE("Models metadata") {
246249
"explicit_gradients": [
247250
"\u00b5-\u03bb"
248251
],
252+
"per_atom": true,
249253
"quantity": "length",
250254
"sample_kind": "atom",
251255
"unit": ""
@@ -284,7 +288,7 @@ TEST_CASE("Models metadata") {
284288
output = capabilities->outputs().at("tests::foo");
285289
CHECK(output->quantity().empty());
286290
CHECK(output->unit().empty());
287-
CHECK(output->get_per_atom() == false);
291+
CHECK(output->per_atom == false);
288292
CHECK(output->explicit_gradients == std::vector<std::string>{"µ-test"});
289293

290294
CHECK_THROWS_WITH(
@@ -302,7 +306,7 @@ TEST_CASE("Models metadata") {
302306

303307
auto capabilities_variants = torch::make_intrusive<ModelCapabilitiesHolder>();
304308
auto output_variant = torch::make_intrusive<ModelOutputHolder>();
305-
output_variant->set_per_atom(true);
309+
output_variant->per_atom = true;
306310
output_variant->description = "variant output";
307311

308312
auto outputs_variant = torch::Dict<std::string, ModelOutput>();

python/metatomic_torch/metatomic/torch/documentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ def __init__(
254254
quantity: str = "",
255255
unit: str = "",
256256
per_atom: Optional[bool] = None,
257-
sample_kind: Optional[Literal["system", "atom", "atom_pair"]] = None,
258257
explicit_gradients: List[str] = [], # noqa B006
259258
description: str = "",
259+
sample_kind: Optional[Literal["system", "atom", "atom_pair"]] = None,
260260
):
261261
pass
262262

python/metatomic_torch/tests/outputs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ def test_sample_kind():
5353
with pytest.raises(ValueError):
5454
ModelOutput(per_atom=True, sample_kind="system")
5555

56-
# Check that setting sample_kind to an invalid value raises an error
57-
with pytest.raises(ValueError):
58-
ModelOutput(sample_kind="invalid_value")
56+
# Arbitrary sample_kind values are allowed, although they will not be
57+
# supported by metatomic interfaces to engines. Setting sample_kind
58+
# to an arbitrary value will issue a warning.
59+
ModelOutput(sample_kind="arbitrary_value")
5960

6061
# Initialize model output with sample_kind="atom_pair"
6162
# and check that per_atom can not be retrieved

0 commit comments

Comments
 (0)