Skip to content

Commit fca4e05

Browse files
pfebrerLuthaf
andcommitted
Change per_atom into sample_kind while keeping backward compatibility
Co-Authored-By: Guillaume Fraux <guillaume.fraux@epfl.ch>
1 parent e726b80 commit fca4e05

22 files changed

Lines changed: 507 additions & 122 deletions

File tree

docs/src/engines/plumed-model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def forward(
2525
if "features" not in outputs:
2626
return {}
2727

28-
if outputs["features"].per_atom:
28+
if outputs["features"].sample_kind == "atom":
2929
raise ValueError("per-atoms features are not supported in this model")
3030

3131
# PLUMED will first call the model with 0 atoms to get the size of the
@@ -94,7 +94,7 @@ def forward(
9494
# metadata about what the model can do
9595
capabilities = mta.ModelCapabilities(
9696
length_unit="Angstrom",
97-
outputs={"features": mta.ModelOutput(per_atom=False)},
97+
outputs={"features": mta.ModelOutput(sample_kind="system")},
9898
atomic_types=[0],
9999
interaction_range=torch.inf,
100100
supported_devices=["cpu", "cuda"],

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,36 @@ using ModelMetadata = torch::intrusive_ptr<ModelMetadataHolder>;
3232
/// Information about one of the quantities a model can compute
3333
class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder {
3434
public:
35-
ModelOutputHolder() = default;
35+
ModelOutputHolder();
3636

3737
/// Initialize `ModelOutput` with the given data
38+
ModelOutputHolder(
39+
std::string quantity,
40+
std::string unit,
41+
std::string sample_kind,
42+
std::vector<std::string> explicit_gradients_,
43+
std::string description_
44+
);
45+
46+
/// For backward compatibility in the C++ API (per_atom argument)
3847
ModelOutputHolder(
3948
std::string quantity,
4049
std::string unit,
4150
bool per_atom_,
4251
std::vector<std::string> explicit_gradients_,
4352
std::string description_
44-
):
45-
description(std::move(description_)),
46-
per_atom(per_atom_),
47-
explicit_gradients(std::move(explicit_gradients_))
48-
{
49-
this->set_quantity(std::move(quantity));
50-
this->set_unit(std::move(unit));
51-
}
53+
);
54+
55+
/// For backward compatibility in the Python API
56+
ModelOutputHolder(
57+
std::string quantity,
58+
std::string unit,
59+
torch::IValue per_atom_or_sample_kind,
60+
std::vector<std::string> explicit_gradients_,
61+
std::string description_,
62+
torch::optional<bool> per_atom = torch::nullopt,
63+
torch::optional<std::string> sample_kind = torch::nullopt
64+
);
5265

5366
~ModelOutputHolder() override = default;
5467

@@ -72,9 +85,22 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
7285
/// set the unit of the output
7386
void set_unit(std::string unit);
7487

75-
/// is the output defined per-atom or for the overall structure
88+
/// The setter and getter for `per_atom` that are used in TorchBind, which
89+
/// allow us to raise an error if `sample_kind` can't be mapped to a boolean
90+
/// value for `per_atom`.
91+
void set_per_atom(bool per_atom);
92+
bool get_per_atom() const;
93+
94+
/// This is deprecated in favor of `sample_kind`, and kept for backward compatibility reasons only.
95+
[[deprecated("use sample_kind instead")]]
7696
bool per_atom = false;
7797

98+
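/// Get the sample kind of the output, one of `"system"`, `"atom"` or
/// `"atom_pair"`. `"atom"` and `"system"` correspond to the deprecated
/// `per_atom = true` and `per_atom = false` respectively.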
99+
std::string sample_kind() const;
100+
101+
/// Set the `sample_kind` of the output.
102+
void set_sample_kind(std::string sample_kind);
103+
78104
/// Which gradients should be computed eagerly and stored inside the output
79105
/// `TensorMap`
80106
std::vector<std::string> explicit_gradients;
@@ -85,8 +111,12 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
85111
static ModelOutput from_json(std::string_view json);
86112

87113
private:
114+
void set_per_atom_no_deprecation(bool per_atom);
115+
bool get_per_atom_no_deprecation() const;
116+
88117
std::string quantity_;
89118
std::string unit_;
119+
torch::optional<std::string> sample_kind_;
90120
};
91121

92122

metatomic-torch/src/model.cpp

Lines changed: 177 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,92 @@ static void read_vector_int_json(
5353

5454
/******************************************************************************/
5555

56+
#if defined(__GNUC__) || defined(__clang__)
57+
#pragma GCC diagnostic push
58+
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
59+
#endif
60+
61+
// Defined out-of-line so that the implicit initialization of the deprecated
// `per_atom` member happens in this translation unit.
ModelOutputHolder::ModelOutputHolder() = default;

/// Build a `ModelOutput` from a `sample_kind` string.
///
/// `description_` and `explicit_gradients_` are moved into the corresponding
/// members; `quantity`, `unit` and `sample_kind` go through their setters
/// (`set_quantity`, `set_unit`, `set_sample_kind`) so any validation done
/// there also applies at construction time.
ModelOutputHolder::ModelOutputHolder(
    std::string quantity,
    std::string unit,
    std::string sample_kind,
    std::vector<std::string> explicit_gradients_,
    std::string description_
):
    description(std::move(description_)),
    explicit_gradients(std::move(explicit_gradients_))
{
    // the quantity is set before the unit, same order as the other constructors
    set_quantity(std::move(quantity));
    set_unit(std::move(unit));
    set_sample_kind(std::move(sample_kind));
}
77+
78+
/// Backward-compatible constructor taking the legacy boolean `per_atom_`
/// instead of a `sample_kind` string.
///
/// The flag is forwarded to `set_per_atom`, which emits the deprecation
/// warning for `per_atom`.
ModelOutputHolder::ModelOutputHolder(
    std::string quantity,
    std::string unit,
    bool per_atom_,
    std::vector<std::string> explicit_gradients_,
    std::string description_
):
    description(std::move(description_)),
    explicit_gradients(std::move(explicit_gradients_))
{
    // the quantity is set before the unit, same order as the other constructors
    set_quantity(std::move(quantity));
    set_unit(std::move(unit));
    set_per_atom(per_atom_);
}
92+
93+
/// Constructor backing the Python API, accepting both the old and the new way
/// of describing the samples of an output.
///
/// The third argument, `per_atom_or_sample_kind`, can be:
///   - `None`: the value is taken from the `per_atom` or `sample_kind`
///     keyword arguments (at most one of them may be given; if neither is
///     given the defaults of the class are kept);
///   - a boolean: interpreted as the deprecated `per_atom`;
///   - a string: interpreted as `sample_kind`.
///
/// A `ValueError` is raised if the same information is given twice (once as
/// positional and once as keyword argument), if `per_atom` and `sample_kind`
/// are both given in any combination of positional/keyword arguments, or if
/// the positional argument is neither a boolean nor a string.
ModelOutputHolder::ModelOutputHolder(
    std::string quantity,
    std::string unit,
    torch::IValue per_atom_or_sample_kind,
    std::vector<std::string> explicit_gradients_,
    std::string description_,
    torch::optional<bool> per_atom,
    torch::optional<std::string> sample_kind
):
    description(std::move(description_)),
    explicit_gradients(std::move(explicit_gradients_))
{
    this->set_quantity(std::move(quantity));
    this->set_unit(std::move(unit));

    if (per_atom_or_sample_kind.isNone()) {
        // check the kwargs for backward compatibility
        if (sample_kind.has_value() && per_atom.has_value()) {
            C10_THROW_ERROR(ValueError, "cannot specify both `per_atom` and `sample_kind`");
        } else if (sample_kind.has_value()) {
            this->set_sample_kind(sample_kind.value());
        } else if (per_atom.has_value()) {
            this->set_per_atom(per_atom.value());
        }
    } else if (per_atom_or_sample_kind.isBool()) {
        if (per_atom.has_value()) {
            C10_THROW_ERROR(ValueError,
                "cannot specify `per_atom` both as a positional and keyword argument"
            );
        }
        // a positional `per_atom` conflicts with a `sample_kind` keyword just
        // like two keywords would; do not silently drop the keyword
        if (sample_kind.has_value()) {
            C10_THROW_ERROR(ValueError, "cannot specify both `per_atom` and `sample_kind`");
        }
        this->set_per_atom(per_atom_or_sample_kind.toBool());
    } else if (per_atom_or_sample_kind.isString()) {
        if (sample_kind.has_value()) {
            C10_THROW_ERROR(ValueError,
                "cannot specify `sample_kind` both as a positional and keyword argument"
            );
        }
        // same as above, with the roles of the two arguments exchanged
        if (per_atom.has_value()) {
            C10_THROW_ERROR(ValueError, "cannot specify both `per_atom` and `sample_kind`");
        }
        this->set_sample_kind(per_atom_or_sample_kind.toStringRef());
    } else {
        C10_THROW_ERROR(ValueError,
            "positional argument for `per_atom`/`sample_kind` must be either a boolean or a string"
        );
    }
}
137+
138+
#if defined(__GNUC__) || defined(__clang__)
139+
#pragma GCC diagnostic pop
140+
#endif
141+
56142
void ModelOutputHolder::set_quantity(std::string quantity) {
57143
if (valid_quantity(quantity)) {
58144
validate_unit(quantity, unit_);
@@ -72,7 +158,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
72158
result["class"] = "ModelOutput";
73159
result["quantity"] = self.quantity();
74160
result["unit"] = self.unit();
75-
result["per_atom"] = self.per_atom;
161+
result["sample_kind"] = self.sample_kind();
76162
result["explicit_gradients"] = self.explicit_gradients;
77163
result["description"] = self.description;
78164

@@ -112,11 +198,18 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
112198
result->set_unit(data["unit"]);
113199
}
114200

115-
if (data.contains("per_atom")) {
201+
if (data.contains("sample_kind")) {
202+
if (!data["sample_kind"].is_string()) {
203+
throw std::runtime_error("'sample_kind' in JSON for ModelOutput must be a string");
204+
}
205+
result->set_sample_kind(data["sample_kind"]);
206+
} else if (data.contains("per_atom")) {
116207
if (!data["per_atom"].is_boolean()) {
117208
throw std::runtime_error("'per_atom' in JSON for ModelOutput must be a boolean");
118209
}
119-
result->per_atom = data["per_atom"];
210+
result->set_per_atom(data["per_atom"]);
211+
} else {
212+
result->set_sample_kind("system");
120213
}
121214

122215
if (data.contains("explicit_gradients")) {
@@ -145,6 +238,87 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
145238
return model_output_from_json(data);
146239
}
147240

241+
static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
242+
"system",
243+
"atom",
244+
"atom_pair",
245+
};
246+
247+
void ModelOutputHolder::set_sample_kind(std::string sample_kind) {
248+
if (sample_kind == "atom") {
249+
this->set_per_atom_no_deprecation(true);
250+
} else if (sample_kind == "system") {
251+
this->set_per_atom_no_deprecation(false);
252+
} else {
253+
if (SUPPORTED_SAMPLE_KINDS.find(sample_kind) == SUPPORTED_SAMPLE_KINDS.end()) {
254+
C10_THROW_ERROR(ValueError,
255+
"invalid sample_kind '" + sample_kind + "': supported values are [" +
256+
torch::str(SUPPORTED_SAMPLE_KINDS) + "]"
257+
);
258+
}
259+
260+
this->sample_kind_ = std::move(sample_kind);
261+
}
262+
}
263+
264+
std::string ModelOutputHolder::sample_kind() const {
265+
if (sample_kind_.has_value()) {
266+
return sample_kind_.value();
267+
} else if (this->get_per_atom_no_deprecation()) {
268+
return "atom";
269+
} else {
270+
return "system";
271+
}
272+
}
273+
274+
void ModelOutputHolder::set_per_atom(bool per_atom_) {
275+
TORCH_WARN_DEPRECATION(
276+
"`per_atom` is deprecated, please use `sample_kind` instead"
277+
);
278+
279+
this->set_per_atom_no_deprecation(per_atom_);
280+
}
281+
282+
bool ModelOutputHolder::get_per_atom() const {
283+
TORCH_WARN_DEPRECATION(
284+
"`per_atom` is deprecated, please use `sample_kind` instead"
285+
);
286+
287+
return this->get_per_atom_no_deprecation();
288+
}
289+
290+
#if defined(__GNUC__) || defined(__clang__)
291+
#pragma GCC diagnostic push
292+
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
293+
#endif
294+
295+
/// Store the legacy `per_atom` flag without emitting a deprecation warning,
/// and clear any explicitly stored `sample_kind_` so the flag becomes the
/// source of truth.
void ModelOutputHolder::set_per_atom_no_deprecation(bool per_atom) {
    this->sample_kind_.reset();
    // `this->` is required here: the parameter shadows the member
    this->per_atom = per_atom;
}
300+
301+
bool ModelOutputHolder::get_per_atom_no_deprecation() const {
302+
if (sample_kind_.has_value()) {
303+
if (sample_kind_.value() == "atom") {
304+
return true;
305+
} else if (sample_kind_.value() == "system") {
306+
return false;
307+
} else {
308+
C10_THROW_ERROR(
309+
ValueError,
310+
"Can't infer `per_atom` from `sample_kind` '" + this->sample_kind() + "'. "
311+
"`per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
312+
);
313+
}
314+
}
315+
return per_atom;
316+
}
317+
318+
#if defined(__GNUC__) || defined(__clang__)
319+
#pragma GCC diagnostic pop
320+
#endif
321+
148322
/******************************************************************************/
149323

150324

0 commit comments

Comments
 (0)