Skip to content

Commit c1f79df

Browse files
committed
Make sure deprecation warnings don't reach users that can not act on them
1 parent 256a66e commit c1f79df

5 files changed

Lines changed: 42 additions & 29 deletions

File tree

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,22 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
4343
std::string description_
4444
);
4545

46+
/// Overload to use the `sample_kind` constructor with a `const char*`,
47+
/// otherwise this would default to calling the `per_atom` constructor.
48+
ModelOutputHolder(
49+
std::string quantity,
50+
std::string unit,
51+
const char* sample_kind,
52+
std::vector<std::string> explicit_gradients_,
53+
std::string description_
54+
): ModelOutputHolder(
55+
std::move(quantity),
56+
std::move(unit),
57+
std::string(sample_kind),
58+
std::move(explicit_gradients_),
59+
std::move(description_)
60+
) {}
61+
4662
/// For backward compatibility in the C++ API (per_atom argument)
4763
ModelOutputHolder(
4864
std::string quantity,

metatomic-torch/src/model.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ std::string ModelOutputHolder::to_json() const {
169169
return model_output_to_json(*this).dump(/*indent*/4, /*indent_char*/' ', /*ensure_ascii*/ true);
170170
}
171171

172-
static ModelOutput model_output_from_json(const nlohmann::json& data) {
172+
ModelOutput ModelOutputHolder::from_json(std::string_view json) {
173+
auto data = nlohmann::json::parse(json);
173174
if (!data.is_object()) {
174175
throw std::runtime_error("invalid JSON data for ModelOutput, expected an object");
175176
}
@@ -207,7 +208,7 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
207208
if (!data["per_atom"].is_boolean()) {
208209
throw std::runtime_error("'per_atom' in JSON for ModelOutput must be a boolean");
209210
}
210-
result->set_per_atom(data["per_atom"]);
211+
result->set_per_atom_no_deprecation(data["per_atom"]);
211212
} else {
212213
result->set_sample_kind("system");
213214
}
@@ -233,11 +234,6 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
233234
return result;
234235
}
235236

236-
ModelOutput ModelOutputHolder::from_json(std::string_view json) {
237-
auto data = nlohmann::json::parse(json);
238-
return model_output_from_json(data);
239-
}
240-
241237
static std::set<std::string> SUPPORTED_SAMPLE_KINDS = {
242238
"system",
243239
"atom",
@@ -426,7 +422,7 @@ ModelCapabilities ModelCapabilitiesHolder::from_json(std::string_view json) {
426422
}
427423

428424
for (const auto& output: data["outputs"].items()) {
429-
outputs.insert(output.key(), model_output_from_json(output.value()));
425+
outputs.insert(output.key(), ModelOutputHolder::from_json(output.value().dump()));
430426
}
431427

432428
result->set_outputs(outputs);
@@ -620,7 +616,7 @@ ModelEvaluationOptions ModelEvaluationOptionsHolder::from_json(std::string_view
620616
}
621617

622618
for (const auto& output: data["outputs"].items()) {
623-
result->outputs.insert(output.key(), model_output_from_json(output.value()));
619+
result->outputs.insert(output.key(), ModelOutputHolder::from_json(output.value().dump()));
624620
}
625621
}
626622

metatomic-torch/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
3030
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
3131
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-exit-time-destructors")
3232
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-global-constructors")
33+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-poison-system-directories")
3334
endif()
3435

3536
file(GLOB ALL_TESTS *.cpp)

metatomic-torch/tests/models.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ TEST_CASE("Models metadata") {
9090
output = ModelOutputHolder::from_json(json);
9191
CHECK(output->quantity() == "length");
9292
CHECK(output->unit().empty());
93-
CHECK(output->get_per_atom() == false);
9493
CHECK(output->sample_kind() == "system");
9594
CHECK(output->explicit_gradients.empty());
9695

@@ -135,7 +134,7 @@ TEST_CASE("Models metadata") {
135134
options->outputs.insert("output_1", torch::make_intrusive<ModelOutputHolder>());
136135

137136
auto output = torch::make_intrusive<ModelOutputHolder>();
138-
output->set_per_atom(true);
137+
output->set_sample_kind("atom");
139138
output->set_quantity("energy");
140139
output->set_unit("eV");
141140
options->outputs.insert("output_2", output);
@@ -193,7 +192,6 @@ TEST_CASE("Models metadata") {
193192
output = options->outputs.at("foo");
194193
CHECK(output->quantity().empty());
195194
CHECK(output->unit().empty());
196-
CHECK(output->get_per_atom() == false);
197195
CHECK(output->sample_kind() == "system");
198196
CHECK(output->explicit_gradients == std::vector<std::string>{"test"});
199197

@@ -286,7 +284,6 @@ TEST_CASE("Models metadata") {
286284
CHECK(output->quantity().empty());
287285
CHECK(output->unit().empty());
288286
// check that we can load JSON with `per_atom` and without `sample_kind`
289-
CHECK(output->get_per_atom() == true);
290287
CHECK(output->sample_kind() == "atom");
291288
CHECK(output->explicit_gradients == std::vector<std::string>{"µ-test"});
292289

@@ -305,7 +302,7 @@ TEST_CASE("Models metadata") {
305302

306303
auto capabilities_variants = torch::make_intrusive<ModelCapabilitiesHolder>();
307304
auto output_variant = torch::make_intrusive<ModelOutputHolder>();
308-
output_variant->set_per_atom(true);
305+
output_variant->set_sample_kind("atom");
309306
output_variant->description = "variant output";
310307

311308
auto outputs_variant = torch::Dict<std::string, ModelOutput>();

metatomic-torch/tests/units.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ TEST_CASE("ModelOutput rejects mismatched quantity and unit") {
284284
// energy quantity with a force unit
285285
CHECK_THROWS_WITH(
286286
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
287-
"energy", "eV/A", false, std::vector<std::string>{}, ""
287+
"energy", "eV/A", "system", std::vector<std::string>{}, ""
288288
),
289289
Contains(
290290
"unit 'eV/A' has dimension L T^-2 M which is incompatible "
@@ -295,43 +295,46 @@ TEST_CASE("ModelOutput rejects mismatched quantity and unit") {
295295
// force quantity with an energy unit
296296
CHECK_THROWS_WITH(
297297
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
298-
"force", "eV", false, std::vector<std::string>{}, ""
298+
"force", "eV", "system", std::vector<std::string>{}, ""
299299
),
300300
Contains(
301301
"unit 'eV' has dimension L^2 T^-2 M which is incompatible "
302302
"with quantity 'force' (expected L T^-2 M)"
303303
)
304304
);
305305

306-
// // length quantity with a pressure unit
307-
// CHECK_THROWS_WITH(
308-
// torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
309-
// "length", "eV/A^3", false, std::vector<std::string>{}, ""
310-
// ),
311-
// Contains("incompatible with qsfqk,uantity")
312-
// );
306+
// length quantity with a pressure unit
307+
CHECK_THROWS_WITH(
308+
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
309+
"length", "eV/A^3", "system", std::vector<std::string>{}, ""
310+
),
311+
Contains(
312+
"unit 'eV/A^3' has dimension L^-1 T^-2 M which is incompatible with "
313+
"quantity 'length' (expected L)"
314+
)
315+
);
313316
}
314317

315318

316319
TEST_CASE("ModelOutput accepts matching quantity and unit") {
317320
// These should not throw
318321
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
319-
"energy", "eV", false, std::vector<std::string>{}, ""
322+
"energy", "eV", "system", std::vector<std::string>{}, ""
320323
);
321324
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
322-
"force", "eV/A", false, std::vector<std::string>{}, ""
325+
"force", "eV/A", "system", std::vector<std::string>{}, ""
323326
);
324327
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
325-
"pressure", "eV/A^3", false, std::vector<std::string>{}, ""
328+
"pressure", "eV/A^3", "system", std::vector<std::string>{}, ""
326329
);
327330
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
328-
"length", "Angstrom", false, std::vector<std::string>{}, ""
331+
"length", "Angstrom", "system", std::vector<std::string>{}, ""
329332
);
330333
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
331-
"momentum", "u*A/fs", false, std::vector<std::string>{}, ""
334+
"momentum", "u*A/fs", "system", std::vector<std::string>{}, ""
332335
);
333336
torch::make_intrusive<metatomic_torch::ModelOutputHolder>(
334-
"velocity", "A/fs", false, std::vector<std::string>{}, ""
337+
"velocity", "A/fs", "system", std::vector<std::string>{}, ""
335338
);
336339
}
337340

0 commit comments

Comments
 (0)