Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions src/shambindings/include/shambindings/pybindaliases.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,141 @@
*/

#include "shambase/call_lambda.hpp"
#include "shambase/exception.hpp"
#include "shambase/unique_name_macro.hpp"
#include <pybind11/pybind11.h>
#include <map>
#include <optional>
Comment on lines 27 to +29

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The header is missing several standard library includes required by the registry_t template and the build_all_modules function. Specifically, it needs , <string_view>, , , , , and to be self-contained and avoid compilation errors in different translation units.

Suggested change
#include <pybind11/pybind11.h>
#include <map>
#include <optional>
#include <pybind11/pybind11.h>
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>


/// alias to pybind11 namespace
namespace py = pybind11;

// ------------------------------------------------------------------
// Submodule registry
// ------------------------------------------------------------------

namespace shambindings::submodules {

// ------------------------------------------------------------------
// Generic registry
// ------------------------------------------------------------------

template<typename T>
struct registry_t {

// note here that std::less<> is for transparent lookup with string_view
using map_t = std::map<std::string, T, std::less<>>;
std::unique_ptr<map_t> storage = {};

using getter_fn = std::function<T &()>;
std::optional<std::map<std::string, getter_fn, std::less<>>> overrides;

map_t &data() {
if (!storage) {
storage = std::make_unique<map_t>();
}
return *storage;
}

void reset() {
storage.reset();
overrides = std::nullopt;
}

// enable override system lazily
auto &override_map() {
if (!overrides) {
overrides.emplace();
}
return *overrides;
}

void set_override(std::string_view key, getter_fn fn) {
override_map().insert_or_assign(std::string(key), std::move(fn));
}

void clear_overrides() { overrides = std::nullopt; }

// does not include overrides
std::vector<std::string> keys() const {
if (!storage) {
return {};
}

std::vector<std::string> out;
out.reserve(storage->size());

for (auto const &kv : *storage) {
out.push_back(kv.first);
}

return out;
}

T &get(std::string_view key) {

// global override first
if (overrides) {
if (auto it = overrides->find(key); it != overrides->end()) {
return it->second();
}
}

auto &map = data();

if (auto it = map.find(key); it != map.end()) {
return it->second;
}

throw shambase::make_except_with_loc<std::out_of_range>(
"registry entry not found: " + std::string(key));
}

void insert(std::string_view key, T &&value) {
auto [it, inserted] = data().emplace(std::string(key), std::move(value));

if (!inserted) {
throw shambase::make_except_with_loc<std::runtime_error>(
"registry entry already exists: " + std::string(key));
}
}
};

// ------------------------------------------------------------------
// Registry types
// ------------------------------------------------------------------

using module_factory_t = std::function<py::module()>;

// ------------------------------------------------------------------
// Global registries
// ------------------------------------------------------------------

registry_t<py::module> &modules();

registry_t<module_factory_t> &builders();

// ------------------------------------------------------------------
// Global build call
// ------------------------------------------------------------------

inline void build_all_modules() {
auto &module_map = modules().data();

// Get snapshot of builder keys
std::vector<std::string> keys = builders().keys();

// Lexicographic order
std::sort(keys.begin(), keys.end());

for (auto const &key : keys) {
auto &builder = builders().get(key);
module_map.emplace(key, builder());
}
}

} // namespace shambindings::submodules

/// function signature used to register python modules
using fct_sig = std::function<void(py::module &)>;

Expand Down Expand Up @@ -69,3 +198,15 @@ void register_pybind_init_func(fct_sig);
#define ON_PYTHON_INIT \
_internal_register_pybind_init( \
__shamrock_unique_name(pybind_), __shamrock_unique_name(pybind_class_obj_), root_module)


#define Register_pymodsubmodule_int(path, funcname, lambda_name) \
py::module funcname(); \
shambase::call_lambda lambda_name([]() { \
shambindings::submodules::builders().insert(path, funcname); \
}); \
py::module funcname()

#define Register_pymodsubmodule(path) \
Register_pymodsubmodule_int( \
path, __shamrock_unique_name(pybind_class_obj_), __shamrock_unique_name(pybind_))
23 changes: 23 additions & 0 deletions src/shambindings/src/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ void register_py_to_sham_print(py::module &m) {
});
}

namespace shambindings::submodules {

registry_t<py::module> &modules() {
static auto _reg = registry_t<py::module>{};
return _reg;
}

registry_t<module_factory_t> &builders() {
static auto _reg = registry_t<module_factory_t>{};
return _reg;
}

} // namespace shambindings::submodules

namespace shambindings {

enum { None = 0, Lib = 1, Embed = 2 } init_state = None;
Expand All @@ -130,12 +144,21 @@ namespace shambindings {
&py_func_printer_normal, &py_func_printer_ln, &py_func_flush_func);
}

submodules::modules().set_override("shamrock", [&]() -> py::module & {
return m;
});

shambindings::submodules::build_all_modules();

if (static_init_shamrock_pybind) {
for (auto fct : *static_init_shamrock_pybind) {
fct(m);
}
}

submodules::builders().reset();
submodules::modules().reset();

if (is_lib_mode) {
init_state = Lib;
} else {
Expand Down
6 changes: 6 additions & 0 deletions src/shampylib/src/pyShamphys.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
#include <complex>
#include <utility>

Register_pymodsubmodule("shamrock.phys") {
return shambindings::submodules::modules()
.get("shamrock")
.def_submodule("phys", "Physics Library");
}

ON_PYTHON_INIT {

py::module shamphys_module = root_module.def_submodule("phys", "Physics Library");
Expand Down
13 changes: 5 additions & 8 deletions src/tests/shambase/unique_name_macro_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@

#include "shambase/unique_name_macro.hpp"

namespace {
int __shamrock_unique_name(test_var) = 0;
int __shamrock_unique_name(test_var) = 0;
int __shamrock_unique_name(test_var) = 0;
int __shamrock_unique_name(test_var) = 0;
int __shamrock_unique_name(test_var) = 0;
int __shamrock_unique_name(test_var) = 0;
} // namespace
static int __shamrock_unique_name(test_var) = 0;
static int __shamrock_unique_name(test_var) = 0;

static void __shamrock_unique_name(test_func)(){};
static void __shamrock_unique_name(test_func)(){};
19 changes: 19 additions & 0 deletions src/tests/shambase/unique_name_macro_test2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#include "shambase/unique_name_macro.hpp"

static int __shamrock_unique_name(test_var) = 0;
static int __shamrock_unique_name(test_var) = 0;

static void __shamrock_unique_name(test_func)(){};
static void __shamrock_unique_name(test_func)(){};

// This file duplicates the content of unique_name_macro_test.cpp to test that using the same macro
// as the same spot does not provide linker errors
Loading