@@ -53,6 +53,92 @@ static void read_vector_int_json(
5353
5454/* *****************************************************************************/
5555
56+ #if defined(__GNUC__) || defined(__clang__)
57+ #pragma GCC diagnostic push
58+ #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
59+ #endif
60+
61+ ModelOutputHolder::ModelOutputHolder () = default;
62+
63+ ModelOutputHolder::ModelOutputHolder (
64+ std::string quantity,
65+ std::string unit,
66+ std::string sample_kind,
67+ std::vector<std::string> explicit_gradients_,
68+ std::string description_
69+ ):
70+ description(std::move(description_)),
71+ explicit_gradients(std::move(explicit_gradients_))
72+ {
73+ this ->set_quantity (std::move (quantity));
74+ this ->set_unit (std::move (unit));
75+ this ->set_sample_kind (std::move (sample_kind));
76+ }
77+
78+ ModelOutputHolder::ModelOutputHolder (
79+ std::string quantity,
80+ std::string unit,
81+ bool per_atom_,
82+ std::vector<std::string> explicit_gradients_,
83+ std::string description_
84+ ):
85+ description(std::move(description_)),
86+ explicit_gradients(std::move(explicit_gradients_))
87+ {
88+ this ->set_quantity (std::move (quantity));
89+ this ->set_unit (std::move (unit));
90+ this ->set_per_atom (per_atom_);
91+ }
92+
93+ ModelOutputHolder::ModelOutputHolder (
94+ std::string quantity,
95+ std::string unit,
96+ torch::IValue per_atom_or_sample_kind,
97+ std::vector<std::string> explicit_gradients_,
98+ std::string description_,
99+ torch::optional<bool > per_atom,
100+ torch::optional<std::string> sample_kind
101+ ):
102+ description(std::move(description_)),
103+ explicit_gradients(std::move(explicit_gradients_))
104+ {
105+ this ->set_quantity (std::move (quantity));
106+ this ->set_unit (std::move (unit));
107+
108+ if (per_atom_or_sample_kind.isNone ()) {
109+ // check the kwargs for backward compatibility
110+ if (sample_kind.has_value () && per_atom.has_value ()) {
111+ C10_THROW_ERROR (ValueError, " cannot specify both `per_atom` and `sample_kind`" );
112+ } else if (sample_kind.has_value ()) {
113+ this ->set_sample_kind (sample_kind.value ());
114+ } else if (per_atom.has_value ()) {
115+ this ->set_per_atom (per_atom.value ());
116+ }
117+ } else if (per_atom_or_sample_kind.isBool ()) {
118+ if (per_atom.has_value ()) {
119+ C10_THROW_ERROR (ValueError,
120+ " cannot specify `per_atom` both as a positional and keyword argument"
121+ );
122+ }
123+ this ->set_per_atom (per_atom_or_sample_kind.toBool ());
124+ } else if (per_atom_or_sample_kind.isString ()) {
125+ if (sample_kind.has_value ()) {
126+ C10_THROW_ERROR (ValueError,
127+ " cannot specify `sample_kind` both as a positional and keyword argument"
128+ );
129+ }
130+ this ->set_sample_kind (per_atom_or_sample_kind.toStringRef ());
131+ } else {
132+ C10_THROW_ERROR (ValueError,
133+ " positional argument for `per_atom`/`sample_kind` must be either a boolean or a string"
134+ );
135+ }
136+ }
137+
138+ #if defined(__GNUC__) || defined(__clang__)
139+ #pragma GCC diagnostic pop
140+ #endif
141+
56142void ModelOutputHolder::set_quantity (std::string quantity) {
57143 if (valid_quantity (quantity)) {
58144 validate_unit (quantity, unit_);
@@ -72,7 +158,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
72158 result[" class" ] = " ModelOutput" ;
73159 result[" quantity" ] = self.quantity ();
74160 result[" unit" ] = self.unit ();
75- result[" per_atom " ] = self.per_atom ;
161+ result[" sample_kind " ] = self.sample_kind () ;
76162 result[" explicit_gradients" ] = self.explicit_gradients ;
77163 result[" description" ] = self.description ;
78164
@@ -112,11 +198,18 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
112198 result->set_unit (data[" unit" ]);
113199 }
114200
115- if (data.contains (" per_atom" )) {
201+ if (data.contains (" sample_kind" )) {
202+ if (!data[" sample_kind" ].is_string ()) {
203+ throw std::runtime_error (" 'sample_kind' in JSON for ModelOutput must be a string" );
204+ }
205+ result->set_sample_kind (data[" sample_kind" ]);
206+ } else if (data.contains (" per_atom" )) {
116207 if (!data[" per_atom" ].is_boolean ()) {
117208 throw std::runtime_error (" 'per_atom' in JSON for ModelOutput must be a boolean" );
118209 }
119- result->per_atom = data[" per_atom" ];
210+ result->set_per_atom (data[" per_atom" ]);
211+ } else {
212+ result->set_sample_kind (" system" );
120213 }
121214
122215 if (data.contains (" explicit_gradients" )) {
@@ -145,6 +238,87 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
145238 return model_output_from_json (data);
146239}
147240
241+ static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
242+ " system" ,
243+ " atom" ,
244+ " atom_pair" ,
245+ };
246+
247+ void ModelOutputHolder::set_sample_kind (std::string sample_kind) {
248+ if (sample_kind == " atom" ) {
249+ this ->set_per_atom_no_deprecation (true );
250+ } else if (sample_kind == " system" ) {
251+ this ->set_per_atom_no_deprecation (false );
252+ } else {
253+ if (SUPPORTED_SAMPLE_KINDS.find (sample_kind) == SUPPORTED_SAMPLE_KINDS.end ()) {
254+ C10_THROW_ERROR (ValueError,
255+ " invalid sample_kind '" + sample_kind + " ': supported values are [" +
256+ torch::str (SUPPORTED_SAMPLE_KINDS) + " ]"
257+ );
258+ }
259+
260+ this ->sample_kind_ = std::move (sample_kind);
261+ }
262+ }
263+
264+ std::string ModelOutputHolder::sample_kind () const {
265+ if (sample_kind_.has_value ()) {
266+ return sample_kind_.value ();
267+ } else if (this ->get_per_atom_no_deprecation ()) {
268+ return " atom" ;
269+ } else {
270+ return " system" ;
271+ }
272+ }
273+
274+ void ModelOutputHolder::set_per_atom (bool per_atom_) {
275+ TORCH_WARN_DEPRECATION (
276+ " `per_atom` is deprecated, please use `sample_kind` instead"
277+ );
278+
279+ this ->set_per_atom_no_deprecation (per_atom_);
280+ }
281+
282+ bool ModelOutputHolder::get_per_atom () const {
283+ TORCH_WARN_DEPRECATION (
284+ " `per_atom` is deprecated, please use `sample_kind` instead"
285+ );
286+
287+ return this ->get_per_atom_no_deprecation ();
288+ }
289+
290+ #if defined(__GNUC__) || defined(__clang__)
291+ #pragma GCC diagnostic push
292+ #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
293+ #endif
294+
295+ void ModelOutputHolder::set_per_atom_no_deprecation (bool per_atom) {
296+ this ->per_atom = per_atom;
297+
298+ this ->sample_kind_ = torch::nullopt ;
299+ }
300+
301+ bool ModelOutputHolder::get_per_atom_no_deprecation () const {
302+ if (sample_kind_.has_value ()) {
303+ if (sample_kind_.value () == " atom" ) {
304+ return true ;
305+ } else if (sample_kind_.value () == " system" ) {
306+ return false ;
307+ } else {
308+ C10_THROW_ERROR (
309+ ValueError,
310+ " Can't infer `per_atom` from `sample_kind` '" + this ->sample_kind () + " '. "
311+ " `per_atom` only makes sense for `sample_kind` 'atom' and 'system'."
312+ );
313+ }
314+ }
315+ return per_atom;
316+ }
317+
318+ #if defined(__GNUC__) || defined(__clang__)
319+ #pragma GCC diagnostic pop
320+ #endif
321+
148322/* *****************************************************************************/
149323
150324
0 commit comments