Skip to content

Commit 809b918

Browse files
committed
Add py:reload/1 to reload Python modules across all workers
Adds broadcast/1 function to py_pool for sending requests to all workers simultaneously. The reload/1 function uses importlib.reload() to refresh modules from disk, useful during development.
1 parent c0495e9 commit 809b918

3 files changed

Lines changed: 114 additions & 3 deletions

File tree

src/py.erl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@
8888
state_incr/1,
8989
state_incr/2,
9090
state_decr/1,
91-
state_decr/2
91+
state_decr/2,
92+
%% Module reload
93+
reload/1
9294
]).
9395

9496
-type py_result() :: {ok, term()} | {error, term()}.
@@ -574,3 +576,45 @@ state_decr(Key) ->
574576
-spec state_decr(term(), integer()) -> integer().
575577
state_decr(Key, Amount) ->
576578
py_state:decr(Key, Amount).
579+
580+
%%% ============================================================================
581+
%%% Module Reload
582+
%%% ============================================================================
583+
584+
%% @doc Reload a Python module across all workers.
585+
%% This uses importlib.reload() to refresh the module from disk.
586+
%% Useful during development when Python code changes.
587+
%%
588+
%% Note: This only affects already-imported modules. If the module
589+
%% hasn't been imported in a worker yet, the reload is a no-op for that worker.
590+
%%
591+
%% Example:
592+
%% ```
593+
%% %% After modifying mymodule.py on disk:
594+
%% ok = py:reload(mymodule).
595+
%% '''
596+
%%
597+
%% Returns ok if reload succeeded in all workers, or {error, Reasons}
598+
%% if any workers failed.
599+
-spec reload(py_module()) -> ok | {error, [{worker, term()}]}.
600+
reload(Module) ->
601+
ModuleBin = ensure_binary(Module),
602+
%% Build Python code that:
603+
%% 1. Checks if module is loaded in sys.modules
604+
%% 2. If yes, reloads it with importlib.reload()
605+
%% 3. Returns the module name or None if not loaded
606+
Code = <<"__import__('importlib').reload(__import__('sys').modules['",
607+
ModuleBin/binary,
608+
"']) if '", ModuleBin/binary, "' in __import__('sys').modules else None">>,
609+
%% Broadcast to all workers
610+
Request = {eval, undefined, undefined, Code, #{}},
611+
Results = py_pool:broadcast(Request),
612+
%% Check if any failed
613+
Errors = lists:filtermap(fun
614+
({ok, _}) -> false;
615+
({error, Reason}) -> {true, Reason}
616+
end, Results),
617+
case Errors of
618+
[] -> ok;
619+
_ -> {error, [{worker, E} || E <- Errors]}
620+
end.

src/py_pool.erl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
-export([
2525
start_link/1,
2626
request/1,
27+
broadcast/1,
2728
get_stats/0
2829
]).
2930

@@ -55,6 +56,12 @@ start_link(NumWorkers) ->
5556
request(Request) ->
5657
gen_server:cast(?MODULE, {request, Request}).
5758

59+
%% @doc Broadcast a request to all workers.
60+
%% Returns a list of results from each worker.
61+
-spec broadcast(term()) -> [{ok, term()} | {error, term()}].
62+
broadcast(Request) ->
63+
gen_server:call(?MODULE, {broadcast, Request}, infinity).
64+
5865
%% @doc Get pool statistics.
5966
-spec get_stats() -> map().
6067
get_stats() ->
@@ -94,6 +101,12 @@ handle_call(get_stats, _From, State) ->
94101
},
95102
{reply, Stats, State};
96103

104+
handle_call({broadcast, Request}, _From, State) ->
105+
%% Send request to all workers and collect results
106+
Workers = queue:to_list(State#state.workers),
107+
Results = broadcast_to_workers(Workers, Request),
108+
{reply, Results, State};
109+
97110
handle_call(_Request, _From, State) ->
98111
{reply, {error, unknown_request}, State}.
99112

@@ -150,3 +163,33 @@ extract_ref_caller({call, Ref, Caller, _, _, _, _}) -> {Ref, Caller, call};
150163
extract_ref_caller({eval, Ref, Caller, _, _}) -> {Ref, Caller, eval};
151164
extract_ref_caller({exec, Ref, Caller, _}) -> {Ref, Caller, exec};
152165
extract_ref_caller({stream, Ref, Caller, _, _, _, _}) -> {Ref, Caller, stream}.
166+
167+
%% @private
168+
%% Send a request to all workers and collect results
169+
broadcast_to_workers(Workers, RequestTemplate) ->
170+
Self = self(),
171+
%% Send requests to all workers in parallel
172+
Refs = lists:map(fun(Worker) ->
173+
Ref = make_ref(),
174+
Request = inject_ref_caller(RequestTemplate, Ref, Self),
175+
Worker ! {py_request, Request},
176+
Ref
177+
end, Workers),
178+
%% Collect all responses
179+
lists:map(fun(Ref) ->
180+
receive
181+
{py_response, Ref, Result} -> Result;
182+
{py_error, Ref, Error} -> {error, Error}
183+
after 30000 ->
184+
{error, timeout}
185+
end
186+
end, Refs).
187+
188+
%% @private
189+
%% Inject a reference and caller into a request template
190+
inject_ref_caller({exec, _Ref, _Caller, Code}, NewRef, NewCaller) ->
191+
{exec, NewRef, NewCaller, Code};
192+
inject_ref_caller({eval, _Ref, _Caller, Code, Locals}, NewRef, NewCaller) ->
193+
{eval, NewRef, NewCaller, Code, Locals};
194+
inject_ref_caller({eval, _Ref, _Caller, Code, Locals, Timeout}, NewRef, NewCaller) ->
195+
{eval, NewRef, NewCaller, Code, Locals, Timeout}.

test/py_SUITE.erl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
test_semaphore_timeout/1,
4747
test_semaphore_rate_limiting/1,
4848
test_overload_protection/1,
49-
test_shared_state/1
49+
test_shared_state/1,
50+
test_reload/1
5051
]).
5152

5253
all() ->
@@ -87,7 +88,8 @@ all() ->
8788
test_semaphore_timeout,
8889
test_semaphore_rate_limiting,
8990
test_overload_protection,
90-
test_shared_state
91+
test_shared_state,
92+
test_reload
9193
].
9294

9395
init_per_suite(Config) ->
@@ -891,3 +893,25 @@ assert val == 2, f'Expected 2, got {val}'
891893
ok = py:state_clear(),
892894

893895
ok.
896+
897+
%% Test module reload across all workers
898+
test_reload(_Config) ->
899+
%% First, ensure json module is imported in at least one worker
900+
{ok, _} = py:call(json, dumps, [[1, 2, 3]]),
901+
902+
%% Now reload it - should succeed across all workers
903+
ok = py:reload(json),
904+
905+
%% Verify the module still works after reload
906+
{ok, <<"[1, 2, 3]">>} = py:call(json, dumps, [[1, 2, 3]]),
907+
908+
%% Test reload of a module that might not be loaded (should not error)
909+
ok = py:reload(collections),
910+
911+
%% Test reload with binary module name
912+
ok = py:reload(<<"os">>),
913+
914+
%% Test reload with string module name
915+
ok = py:reload("sys"),
916+
917+
ok.

0 commit comments

Comments
 (0)