Skip to content

Commit cdb2307

Browse files
committed
Fix flag-based callback detection in replay path
The previous commit added flag-based detection in process_request but missed the replay path in nif_resume_callback_dirty. This caused SuspensionRequired exceptions to still leak when ASGI/WSGI middleware caught and re-raised exceptions during replay. Changes: - Update nif_resume_callback_dirty to check tl_pending_callback flag first, for both CALL and EVAL replay paths - Remove unused is_suspension_exception() and get_suspension_args() - Add test_callback_with_try_except test with py_test_middleware.py module that tests try/except, BaseException, nested try/except, try/finally, and multi-layer middleware patterns
1 parent 4384fe1 commit cdb2307

4 files changed

Lines changed: 182 additions & 81 deletions

File tree

c_src/py_callback.c

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -87,46 +87,6 @@
8787
* Suspended state management
8888
* ============================================================================ */
8989

90-
/**
91-
* Check if a SuspensionRequired exception is pending.
92-
* Returns true if the exception is set and matches SuspensionRequiredException.
93-
*/
94-
static bool is_suspension_exception(void) {
95-
if (!PyErr_Occurred()) {
96-
return false;
97-
}
98-
return PyErr_ExceptionMatches(SuspensionRequiredException);
99-
}
100-
101-
/**
102-
* Extract suspension info from the pending SuspensionRequired exception.
103-
* Returns the exception args tuple (callback_id, func_name, args) or NULL.
104-
* Clears the exception if successful.
105-
*/
106-
static PyObject *get_suspension_args(void) {
107-
PyObject *exc_type, *exc_value, *exc_tb;
108-
PyErr_Fetch(&exc_type, &exc_value, &exc_tb);
109-
110-
if (exc_value == NULL) {
111-
Py_XDECREF(exc_type);
112-
Py_XDECREF(exc_tb);
113-
return NULL;
114-
}
115-
116-
/* Get the args from the exception - it's a tuple (callback_id, func_name, args) */
117-
PyObject *args = PyObject_GetAttrString(exc_value, "args");
118-
Py_DECREF(exc_type);
119-
Py_DECREF(exc_value);
120-
Py_XDECREF(exc_tb);
121-
122-
if (args == NULL || !PyTuple_Check(args) || PyTuple_Size(args) != 3) {
123-
Py_XDECREF(args);
124-
return NULL;
125-
}
126-
127-
return args; /* Caller owns this reference */
128-
}
129-
13090
/**
13191
* Create a suspended state resource from exception args.
13292
* Args tuple format: (callback_id, func_name, args)
@@ -1193,23 +1153,46 @@ static ERL_NIF_TERM nif_resume_callback_dirty(ErlNifEnv *env, int argc, const ER
11931153
Py_XDECREF(kwargs);
11941154

11951155
if (py_result == NULL) {
1196-
if (is_suspension_exception()) {
1156+
if (tl_pending_callback) {
11971157
/*
1198-
* Another suspension during replay - Python made a second erlang.call().
1199-
* Create a new suspended state and return {suspended, ...} so Erlang
1200-
* can handle this callback and resume again.
1158+
* Flag-based callback detection during replay.
1159+
* Check flag FIRST, not exception type - this works even if
1160+
* Python code caught and re-raised the exception.
12011161
*/
1202-
PyObject *exc_args = get_suspension_args(); /* Clears exception */
1162+
PyErr_Clear(); /* Clear whatever exception is set */
1163+
1164+
/* Build exc_args tuple from thread-local storage */
1165+
PyObject *exc_args = PyTuple_New(3);
12031166
if (exc_args == NULL) {
1204-
result = make_error(env, "get_suspension_args_failed");
1167+
tl_pending_callback = false;
1168+
result = make_error(env, "alloc_exc_args_failed");
12051169
} else {
1206-
suspended_state_t *new_suspended = create_suspended_state_from_existing(env, exc_args, state);
1207-
if (new_suspended == NULL) {
1170+
PyObject *callback_id_obj = PyLong_FromUnsignedLongLong(tl_pending_callback_id);
1171+
PyObject *func_name_obj = PyUnicode_FromStringAndSize(
1172+
tl_pending_func_name, tl_pending_func_name_len);
1173+
1174+
if (callback_id_obj == NULL || func_name_obj == NULL) {
1175+
Py_XDECREF(callback_id_obj);
1176+
Py_XDECREF(func_name_obj);
12081177
Py_DECREF(exc_args);
1209-
result = make_error(env, "create_nested_suspended_state_failed");
1178+
tl_pending_callback = false;
1179+
result = make_error(env, "build_exc_args_failed");
12101180
} else {
1211-
result = make_suspended_term(env, new_suspended, exc_args);
1212-
Py_DECREF(exc_args);
1181+
PyTuple_SET_ITEM(exc_args, 0, callback_id_obj);
1182+
PyTuple_SET_ITEM(exc_args, 1, func_name_obj);
1183+
Py_INCREF(tl_pending_args);
1184+
PyTuple_SET_ITEM(exc_args, 2, tl_pending_args);
1185+
1186+
suspended_state_t *new_suspended = create_suspended_state_from_existing(env, exc_args, state);
1187+
if (new_suspended == NULL) {
1188+
Py_DECREF(exc_args);
1189+
tl_pending_callback = false;
1190+
result = make_error(env, "create_nested_suspended_state_failed");
1191+
} else {
1192+
result = make_suspended_term(env, new_suspended, exc_args);
1193+
Py_DECREF(exc_args);
1194+
tl_pending_callback = false;
1195+
}
12131196
}
12141197
}
12151198
} else {
@@ -1258,23 +1241,46 @@ static ERL_NIF_TERM nif_resume_callback_dirty(ErlNifEnv *env, int argc, const ER
12581241
Py_DECREF(compiled);
12591242

12601243
if (py_result == NULL) {
1261-
if (is_suspension_exception()) {
1244+
if (tl_pending_callback) {
12621245
/*
1263-
* Another suspension during replay - Python made a second erlang.call().
1264-
* Create a new suspended state and return {suspended, ...} so Erlang
1265-
* can handle this callback and resume again.
1246+
* Flag-based callback detection during eval replay.
1247+
* Check flag FIRST, not exception type - this works even if
1248+
* Python code caught and re-raised the exception.
12661249
*/
1267-
PyObject *exc_args = get_suspension_args(); /* Clears exception */
1250+
PyErr_Clear(); /* Clear whatever exception is set */
1251+
1252+
/* Build exc_args tuple from thread-local storage */
1253+
PyObject *exc_args = PyTuple_New(3);
12681254
if (exc_args == NULL) {
1269-
result = make_error(env, "get_suspension_args_failed");
1255+
tl_pending_callback = false;
1256+
result = make_error(env, "alloc_exc_args_failed");
12701257
} else {
1271-
suspended_state_t *new_suspended = create_suspended_state_from_existing(env, exc_args, state);
1272-
if (new_suspended == NULL) {
1258+
PyObject *callback_id_obj = PyLong_FromUnsignedLongLong(tl_pending_callback_id);
1259+
PyObject *func_name_obj = PyUnicode_FromStringAndSize(
1260+
tl_pending_func_name, tl_pending_func_name_len);
1261+
1262+
if (callback_id_obj == NULL || func_name_obj == NULL) {
1263+
Py_XDECREF(callback_id_obj);
1264+
Py_XDECREF(func_name_obj);
12731265
Py_DECREF(exc_args);
1274-
result = make_error(env, "create_nested_suspended_state_failed");
1266+
tl_pending_callback = false;
1267+
result = make_error(env, "build_exc_args_failed");
12751268
} else {
1276-
result = make_suspended_term(env, new_suspended, exc_args);
1277-
Py_DECREF(exc_args);
1269+
PyTuple_SET_ITEM(exc_args, 0, callback_id_obj);
1270+
PyTuple_SET_ITEM(exc_args, 1, func_name_obj);
1271+
Py_INCREF(tl_pending_args);
1272+
PyTuple_SET_ITEM(exc_args, 2, tl_pending_args);
1273+
1274+
suspended_state_t *new_suspended = create_suspended_state_from_existing(env, exc_args, state);
1275+
if (new_suspended == NULL) {
1276+
Py_DECREF(exc_args);
1277+
tl_pending_callback = false;
1278+
result = make_error(env, "create_nested_suspended_state_failed");
1279+
} else {
1280+
result = make_suspended_term(env, new_suspended, exc_args);
1281+
Py_DECREF(exc_args);
1282+
tl_pending_callback = false;
1283+
}
12781284
}
12791285
}
12801286
} else {

c_src/py_nif.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,23 +1107,6 @@ static PyObject *erlang_module_getattr(PyObject *module, PyObject *name);
11071107
*/
11081108
static void *async_event_loop_thread(void *arg);
11091109

1110-
/**
1111-
* @brief Check if SuspensionRequired exception is pending
1112-
*
1113-
* @return true if exception matches SuspensionRequiredException
1114-
*/
1115-
static bool is_suspension_exception(void);
1116-
1117-
/**
1118-
* @brief Extract args from pending SuspensionRequired exception
1119-
*
1120-
* Gets (callback_id, func_name, args) tuple from exception and
1121-
* clears the exception state.
1122-
*
1123-
* @return New reference to args tuple, or NULL on error
1124-
*/
1125-
static PyObject *get_suspension_args(void);
1126-
11271110
/**
11281111
* @brief Create suspended state for callback handling
11291112
*

test/py_reentrant_SUITE.erl

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
test_concurrent_reentrant/1,
2222
test_callback_with_complex_types/1,
2323
test_multiple_sequential_callbacks/1,
24-
test_call_from_non_worker_thread/1
24+
test_call_from_non_worker_thread/1,
25+
test_callback_with_try_except/1
2526
]).
2627

2728
all() ->
@@ -32,7 +33,8 @@ all() ->
3233
test_concurrent_reentrant,
3334
test_callback_with_complex_types,
3435
test_multiple_sequential_callbacks,
35-
test_call_from_non_worker_thread
36+
test_call_from_non_worker_thread,
37+
test_callback_with_try_except
3638
].
3739

3840
init_per_suite(Config) ->
@@ -259,3 +261,50 @@ test_call_from_non_worker_thread(_Config) ->
259261
%% Cleanup
260262
py:unregister_function(simple_add),
261263
ok.
264+
265+
%% @doc Test that erlang.call() works even when wrapped in try/except blocks.
266+
%% This simulates ASGI/WSGI middleware that catches all exceptions.
267+
%% The flag-based detection should work even when the SuspensionRequired
268+
%% exception is caught and re-raised by Python code.
269+
%%
270+
%% Uses py:call on a test module with try/except blocks to properly test
271+
%% the suspension mechanism through middleware-like exception handling.
272+
test_callback_with_try_except(_Config) ->
273+
%% Register a simple Erlang function
274+
py:register_function(get_value, fun([Key]) ->
275+
case Key of
276+
<<"a">> -> 1;
277+
<<"b">> -> 2;
278+
<<"c">> -> 3;
279+
_ -> 0
280+
end
281+
end),
282+
283+
%% Add test directory to Python path so we can import the test module
284+
TestDir = code:lib_dir(erlang_python, test),
285+
ok = py:exec(iolist_to_binary(io_lib:format(
286+
"import sys; sys.path.insert(0, '~s')", [TestDir]))),
287+
288+
%% Test 1: Try/except that catches Exception and re-raises
289+
{ok, Result1} = py:call(py_test_middleware, call_with_try_except, [<<"a">>]),
290+
1 = Result1,
291+
292+
%% Test 2: Try/except that catches BaseException (catches everything)
293+
{ok, Result2} = py:call(py_test_middleware, call_with_base_exception, [<<"b">>]),
294+
2 = Result2,
295+
296+
%% Test 3: Nested try/except blocks
297+
{ok, Result3} = py:call(py_test_middleware, call_with_nested_try, [<<"c">>]),
298+
3 = Result3,
299+
300+
%% Test 4: Try/finally pattern
301+
{ok, Result4} = py:call(py_test_middleware, call_with_finally, [<<"a">>]),
302+
1 = Result4,
303+
304+
%% Test 5: Multiple middleware layers
305+
{ok, Result5} = py:call(py_test_middleware, call_through_layers, [<<"b">>]),
306+
2 = Result5,
307+
308+
%% Cleanup
309+
py:unregister_function(get_value),
310+
ok.

test/py_test_middleware.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Test module for middleware-style exception handling.
2+
3+
This module tests that erlang.call() works correctly even when wrapped
4+
in try/except blocks that catch and re-raise exceptions, simulating
5+
ASGI/WSGI middleware behavior.
6+
"""
7+
8+
import erlang
9+
10+
11+
def call_with_try_except(key):
12+
"""Wraps erlang.call in try/except that catches and re-raises."""
13+
try:
14+
return erlang.call('get_value', key)
15+
except Exception as e:
16+
# This simulates middleware logging
17+
raise
18+
19+
20+
def call_with_base_exception(key):
21+
"""Wraps erlang.call in try/except that catches BaseException."""
22+
try:
23+
return erlang.call('get_value', key)
24+
except BaseException as e:
25+
# Catches everything including SuspensionRequired
26+
raise
27+
28+
29+
def call_with_nested_try(key):
30+
"""Multiple nested try/except blocks."""
31+
try:
32+
try:
33+
return erlang.call('get_value', key)
34+
except Exception:
35+
raise
36+
except Exception:
37+
raise
38+
39+
40+
def call_with_finally(key):
41+
"""Try/finally pattern common in cleanup code."""
42+
cleanup_ran = False
43+
try:
44+
return erlang.call('get_value', key)
45+
finally:
46+
cleanup_ran = True
47+
48+
49+
def call_through_layers(key):
50+
"""Simulates calling through multiple middleware layers."""
51+
def inner():
52+
try:
53+
return erlang.call('get_value', key)
54+
except Exception:
55+
raise
56+
57+
def outer():
58+
try:
59+
return inner()
60+
except Exception:
61+
raise
62+
63+
return outer()

0 commit comments

Comments
 (0)