Skip to content

Commit 89cf95f

Browse files
authored
bug fixes (#17)
1 parent eee7022 commit 89cf95f

9 files changed

Lines changed: 255 additions & 90 deletions

File tree

examples/add/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using namespace std;
55
int main() {
66
pyscheduler::PyManager manager;
77
pyscheduler::PyManager::InvokeHandler add =
8-
manager.getPythonModule("examples.add.python_modules.add", "invoke");
8+
manager.loadPythonModule("examples.add.python_modules.add", "invoke");
99
std::cout << add.invoke<int64_t>(3000, -1234) << std::endl;
1010
return 0;
1111
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch
2+
import torch.utils.dlpack
3+
4+
5+
def invoke(a, b):
6+
a_tensor = torch.utils.dlpack.from_dlpack(a)
7+
b_tensor = torch.utils.dlpack.from_dlpack(b)
8+
c_tensor = a_tensor * b_tensor
9+
return torch.utils.dlpack.to_dlpack(c_tensor)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifdef __INTELLISENSE__
2+
# include "pyscheduler/move_only.hpp"
3+
#endif
4+
5+
#include <memory>
6+
7+
namespace pyscheduler {
8+
9+
template <typename R, typename... Args>
10+
template <typename F>
11+
MoveOnlyFunction<R(Args...)>::MoveOnlyFunction(F&& f)
12+
: _obj(new F(std::move(f)), [](void* p) { delete static_cast<F*>(p); })
13+
, _invoke([](void* p, Args... args) -> R {
14+
return (*static_cast<F*>(p))(std::forward<Args>(args)...);
15+
}) { }
16+
17+
template <typename R, typename... Args>
18+
R MoveOnlyFunction<R(Args...)>::operator()(Args... args) {
19+
return _invoke(_obj.get(), std::forward<Args>(args)...);
20+
}
21+
22+
template <typename R, typename... Args>
23+
MoveOnlyFunction<R(Args...)>::operator bool() const {
24+
return (bool)_obj;
25+
}
26+
} // namespace pyscheduler

include/pyscheduler/details/pyscheduler_impl.hpp

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
#include "pyscheduler/pyscheduler.hpp"
1+
#ifdef __INTELLISENSE__
2+
# include "pyscheduler/pyscheduler.hpp"
3+
#endif
4+
5+
#include "pyscheduler/move_only.hpp"
6+
#include <cassert>
27
#include <chrono>
38

49
namespace pyscheduler {
@@ -10,27 +15,90 @@ PyManager::InvokeHandler::InvokeHandler(size_t id, std::unique_ptr<PyManager> ma
1015
: _id(id)
1116
, _manager(std::move(manager)) { }
1217

13-
const std::shared_ptr<std::pair<pybind11::module_, pybind11::object>>
18+
const std::shared_ptr<std::pair<pybind11::module_, pybind11::object>>&
1419
PyManager::InvokeHandler::getModuleAndFunc() {
1520
// need to lock py_mutex because we don't want a vector resize
1621
// to happen during lookup
1722

23+
// no need to acquire gil because not incrementing python reference count
24+
1825
// should allow multiple reads concurrently which don't mutate state
1926
PyManager::SharedState& state = _manager->shared();
2027
std::shared_lock lock(state.py_mutex);
2128
return _manager->shared().py_modules.at(_id);
2229
}
2330

31+
template <typename ReturnType, typename... Args>
32+
ReturnType PyManager::InvokeHandler::invoke(Args&&... args) {
33+
auto mod_and_func = getModuleAndFunc();
34+
pybind11::gil_scoped_acquire gil;
35+
pybind11::object result = mod_and_func->second(std::forward<Args>(args)...);
36+
return result.cast<ReturnType>();
37+
}
38+
39+
template <typename Callback, typename... Args>
40+
auto PyManager::InvokeHandler::invoke(Callback&& callback, Args&&... args)
41+
-> std::invoke_result_t<Callback, pybind11::object> {
42+
auto mod_and_func = getModuleAndFunc();
43+
pybind11::gil_scoped_acquire gil;
44+
pybind11::object result = mod_and_func->second(std::forward<Args>(args)...);
45+
return callback(result);
46+
}
47+
template <typename Callback, typename... Args>
48+
auto PyManager::InvokeHandler::queue_invoke(Callback&& callback, Args&&... args)
49+
-> std::future<std::invoke_result_t<Callback, pybind11::object>> {
50+
// Need to wrap a promise inside a shared_ptr because Promises are not
51+
// copy constructable (requirement enforced by appending to task queue)
52+
//
53+
// solution was to wrap a promise inside a shared pointer, which is
54+
// copy constructable
55+
using ReturnType = std::invoke_result_t<Callback, pybind11::object>;
56+
using PromisePtr = std::shared_ptr<std::promise<ReturnType>>;
57+
58+
auto args_tuple = std::make_tuple(std::forward<Args>(args)...);
59+
PromisePtr promise_ptr = std::make_shared<std::promise<ReturnType>>();
60+
std::future<ReturnType> future = promise_ptr->get_future();
61+
62+
// Dear reader, I'm sorry
63+
// this section creates a closure that executes a python method with
64+
// the provided arguments
65+
//
66+
// the return result from the python function is processed using the
67+
// callback function, and the value from that is stored into the
68+
// promise.
69+
70+
auto mod_and_func = getModuleAndFunc();
71+
auto method = [this,
72+
callback = std::move(callback),
73+
mod_and_func = std::move(mod_and_func),
74+
args_tuple = std::move(args_tuple),
75+
promise_ptr]() mutable {
76+
pybind11::gil_scoped_acquire gil;
77+
pybind11::object result = std::apply(
78+
[&mod_and_func](auto&&... unpackedArgs) {
79+
return mod_and_func->second(std::forward<decltype(unpackedArgs)>(unpackedArgs)...);
80+
},
81+
args_tuple);
82+
83+
promise_ptr->set_value(callback(result));
84+
};
85+
PyManager::shared().task_queue.enqueue(std::move(method));
86+
return future;
87+
}
88+
2489
///////////////////////////////////////////////////////////////////////////////
2590
// Impl PyManager
2691
///////////////////////////////////////////////////////////////////////////////
2792

93+
PyManager::SharedState PyManager::_instance;
94+
2895
PyManager::PyManager() {
96+
// this lock should be dropped at return so that postcondition (Python Interpreter Initialized)
97+
// is guaranteed
2998
std::unique_lock lock(shared().py_mutex);
3099
if(shared().arc.fetch_add(1) == 0) {
31100
shared().main_worker = std::thread(&PyManager::mainLoop, this);
32101
}
33-
34102
// small cost paid to block until interpreter is initalized
35103
while(!shared().interpreter_initialized)
36104
continue;
@@ -49,8 +117,8 @@ PyManager::~PyManager() {
49117
}
50118
}
51119

52-
PyManager::InvokeHandler PyManager::getPythonModule(const std::string& module_name,
53-
const std::string& entry_point) {
120+
PyManager::InvokeHandler PyManager::loadPythonModule(const std::string& module_name,
121+
const std::string& entry_point) {
54122

55123
SharedState& state = shared();
56124

@@ -67,17 +135,17 @@ PyManager::InvokeHandler PyManager::getPythonModule(const std::string& module_na
67135

68136
pybind11::gil_scoped_acquire gil;
69137

70-
pybind11::module_ mod = pybind11::module_::import(module_name.c_str());
71-
72-
if(!mod) {
73-
PyErr_Print();
138+
pybind11::module_ mod;
139+
try {
140+
mod = pybind11::module_::import(module_name.c_str());
141+
} catch(pybind11::error_already_set& e) {
74142
throw std::invalid_argument("Could not import module: " + module_name);
75143
}
76144

77-
pybind11::object func = mod.attr(entry_point.c_str());
78-
79-
if(!func) {
80-
PyErr_Print();
145+
pybind11::object func;
146+
try {
147+
func = mod.attr(entry_point.c_str());
148+
} catch(pybind11::error_already_set& e) {
81149
throw std::invalid_argument("Could not find the '" + entry_point +
82150
"' method in module " + module_name);
83151
}
@@ -89,7 +157,6 @@ PyManager::InvokeHandler PyManager::getPythonModule(const std::string& module_na
89157
std::make_pair(mod, func)));
90158
state.py_invoke_handler_map[module_name] = id;
91159
lock.unlock();
92-
93160
return PyManager::InvokeHandler(id, std::make_unique<PyManager>());
94161
}
95162

@@ -130,8 +197,9 @@ void PyManager::mainLoop() {
130197
std::vector<std::thread> sub_workers;
131198
for(size_t i = 0; i < NUM_WORKERS; i++) {
132199
sub_workers.emplace_back(std::thread([i]() {
133-
while(shared().threads_active) {
134-
std::function<void()> task;
200+
// worker should only end if stop signal is set and queue is empty
201+
while(shared().threads_active || shared().task_queue.size_approx() > 0) {
202+
MoveOnlyFunction<void()> task;
135203

136204
// have a small timeout so threads can wake up and check if they
137205
// need to exit.
@@ -145,7 +213,6 @@ void PyManager::mainLoop() {
145213
continue;
146214
}
147215

148-
// std::cout << shared().task_queue.size_approx() << std::endl;
149216
task();
150217
}
151218
}));
@@ -167,10 +234,9 @@ void PyManager::mainLoop() {
167234
// we clear and drop all items in queue to safely free
168235
// memory.
169236
while(shared().task_queue.size_approx() > 0) {
170-
std::function<void()> black_box;
237+
MoveOnlyFunction<void()> black_box;
171238
shared().task_queue.try_dequeue(black_box);
172239
}
173-
174240
} // end python interpreter
175241
}
176242
} // namespace pyscheduler

include/pyscheduler/move_only.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#include <memory>
4+
5+
namespace pyscheduler {
6+
7+
template <typename Signature>
8+
class MoveOnlyFunction;
9+
10+
template <typename R, typename... Args>
11+
class MoveOnlyFunction<R(Args...)> {
12+
public:
13+
MoveOnlyFunction() = default;
14+
15+
template <typename F>
16+
MoveOnlyFunction(F&& f);
17+
18+
MoveOnlyFunction(MoveOnlyFunction&&) noexcept = default;
19+
MoveOnlyFunction& operator=(MoveOnlyFunction&&) noexcept = default;
20+
21+
MoveOnlyFunction(const MoveOnlyFunction&) = delete;
22+
MoveOnlyFunction& operator=(const MoveOnlyFunction&) = delete;
23+
24+
R operator()(Args... args);
25+
explicit operator bool() const;
26+
27+
private:
28+
using InvokeFn = R (*)(void*, Args&&...);
29+
using DestroyFn = void (*)(void*);
30+
31+
std::unique_ptr<void, DestroyFn> _obj{ nullptr, nullptr };
32+
InvokeFn _invoke{ nullptr };
33+
};
34+
35+
} // namespace pyscheduler
36+
37+
#include "pyscheduler/details/move_only_impl.hpp"

0 commit comments

Comments
 (0)