Skip to content

Commit 75ed2d7

Browse files
Bubble up Python exceptions
1 parent 12a8975 commit 75ed2d7

3 files changed

Lines changed: 83 additions & 42 deletions

File tree

python/tests/test_high_level.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,3 +1494,17 @@ def __str__(self) -> str:
14941494
assert isinstance(A.__add__, RuntimeFunction)
14951495
assert A.__str__(A()) == "hi"
14961496
assert A.__str__.__doc__ == "Hi"
1497+
1498+
1499+
def test_py_object_raise_exception():
1500+
"""
1501+
Verify that PyObject can raise exceptions properly
1502+
"""
1503+
msg = "bad"
1504+
1505+
def raises(_val):
1506+
raise ValueError(msg)
1507+
1508+
egraph = EGraph()
1509+
with pytest.raises(ValueError, match=msg):
1510+
egraph.extract(py_eval_fn(raises)(PyObject(None)))

src/egraph.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::error::{EggResult, WrappedError};
55
use crate::py_object_sort::{PyObjectIdent, PyObjectSort};
66
use crate::serialize::SerializedEGraph;
77

8-
use egglog::prelude::{add_base_sort, RustSpan, Span};
8+
use egglog::prelude::{RustSpan, Span, add_base_sort};
99
use egglog::{SerializeConfig, span};
1010
use log::info;
1111
use num_bigint::BigInt;
@@ -84,6 +84,9 @@ impl EGraph {
8484
{
8585
cmds.push_str(&cmds_str);
8686
}
87+
if let Some(err) = PyErr::take(py) {
88+
return Err(WrappedError::Py(err));
89+
}
8790
res.map(|xs| xs.iter().map(|o| o.into()).collect())
8891
}
8992

@@ -134,12 +137,18 @@ impl EGraph {
134137
.map(Value)
135138
}
136139

137-
fn eval_expr(&mut self, expr: Expr) -> EggResult<(String, Value)> {
140+
fn eval_expr(&mut self, py: Python<'_>, expr: Expr) -> EggResult<(String, Value)> {
138141
let expr: egglog::ast::Expr = expr.into();
139-
self.egraph
140-
.eval_expr(&expr)
141-
.map(|(s, v)| (s.name().to_string(), Value(v)))
142-
.map_err(|e| WrappedError::Egglog(e, format!("\nWhen evaluating expr: {expr}")))
142+
let res = py.detach(|| {
143+
self.egraph
144+
.eval_expr(&expr)
145+
.map(|(s, v)| (s.name().to_string(), Value(v)))
146+
.map_err(|e| WrappedError::Egglog(e, format!("\nWhen evaluating expr: {expr}")))
147+
});
148+
if let Some(err) = PyErr::take(py) {
149+
return Err(WrappedError::Py(err));
150+
}
151+
res
143152
}
144153

145154
fn value_to_i64(&self, v: Value) -> i64 {

src/py_object_sort.rs

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -193,19 +193,18 @@ impl BaseSort for PyObjectSort {
193193
add_primitive!(
194194
eg,
195195
"py-eval" = {self.clone(): PyObjectSort}
196-
|code: S, globals: PyObjectIdent, locals: PyObjectIdent| -> PyObjectIdent {
196+
|code: S, globals: PyObjectIdent, locals: PyObjectIdent| -?> PyObjectIdent {
197197
{
198-
Python::attach(|py| {
198+
attach(|py| {
199199
let globals = self.ctx.load(py, globals);
200200
let locals = self.ctx.load(py, locals);
201201
let res = py
202202
.eval(
203203
CString::new(code.to_string()).unwrap().as_c_str(),
204-
Some(globals.downcast::<PyDict>().unwrap()),
205-
Some(locals.downcast::<PyDict>().unwrap()),
206-
)
207-
.unwrap();
208-
self.ctx.store(py, res.unbind()).unwrap()
204+
Some(globals.downcast::<PyDict>()?),
205+
Some(locals.downcast::<PyDict>()?),
206+
)?;
207+
self.ctx.store(py, res.unbind())
209208
})
210209
}
211210
}
@@ -215,13 +214,13 @@ impl BaseSort for PyObjectSort {
215214
add_primitive!(
216215
eg,
217216
"py-exec" = {self.clone(): PyObjectSort}
218-
|code: S, globals: PyObjectIdent, locals: PyObjectIdent| -> PyObjectIdent {
219-
Python::attach(|py| {
217+
|code: S, globals: PyObjectIdent, locals: PyObjectIdent| -?> PyObjectIdent {
218+
attach(|py| {
220219
let globals = self.ctx.load(py, globals);
221220
let locals = self.ctx.load(py, locals);
222221

223222
// Copy the locals so we can mutate them and return them
224-
let locals = locals.downcast::<PyDict>().unwrap().copy().unwrap();
223+
let locals = locals.downcast::<PyDict>()?.copy()?;
225224
// Copy code into temporary file
226225
// Keep it around so that if errors occur we can debug them after the program exits
227226
let mut path = temp_dir();
@@ -233,75 +232,74 @@ impl BaseSort for PyObjectSort {
233232
run_path(
234233
py,
235234
CString::new(code.into_inner()).unwrap().as_c_str(),
236-
Some(globals.downcast::<PyDict>().unwrap()),
235+
Some(globals.downcast::<PyDict>()?),
237236
Some(&locals),
238237
CString::new(path).unwrap().as_c_str(),
239-
)
240-
.unwrap();
241-
self.ctx.store(py, locals.unbind().into()).unwrap()
238+
)?;
239+
self.ctx.store(py, locals.unbind().into())
242240
})
243241
}
244242
);
245243

246244
// (py-dict [<key-object> <value-object>]*)
247-
add_primitive!(eg, "py-dict" = {self.clone(): PyObjectSort} [xs: PyObjectIdent] -> PyObjectIdent {
248-
Python::attach(|py| {
245+
add_primitive!(eg, "py-dict" = {self.clone(): PyObjectSort} [xs: PyObjectIdent] -?> PyObjectIdent {
246+
attach(|py| {
249247
let dict = PyDict::new(py);
250248
for i in xs.map(|x| self.ctx.load(py, x)).collect::<Vec<_>>().chunks_exact(2) {
251-
dict.set_item(i[0].clone(), i[1].clone()).unwrap();
249+
dict.set_item(i[0].clone(), i[1].clone())?;
252250
}
253-
self.ctx.store(py, dict.unbind().into()).unwrap()
251+
self.ctx.store(py, dict.unbind().into())
254252
})
255253
});
256254
// Supports calling (py-dict-update <dict-obj> [<key-object> <value-obj>]*)
257-
add_primitive!(eg, "py-dict-update" = {self.clone(): PyObjectSort} [xs: PyObjectIdent] -> PyObjectIdent {{
258-
Python::attach(|py| {
255+
add_primitive!(eg, "py-dict-update" = {self.clone(): PyObjectSort} [xs: PyObjectIdent] -?> PyObjectIdent {{
256+
attach(|py| {
259257
let xs = xs.map(|x| self.ctx.load(py, x)).collect::<Vec<_>>();
260258
// Copy the dict so we can mutate it and return it
261-
let dict = xs[0].downcast::<PyDict>().unwrap().copy().unwrap();
259+
let dict = xs[0].downcast::<PyDict>()?.copy()?;
262260
// Update the dict with the key-value pairs
263261
for i in xs[1..].chunks_exact(2) {
264-
dict.set_item(i[0].clone(), i[1].clone()).unwrap();
262+
dict.set_item(i[0].clone(), i[1].clone())?;
265263
}
266-
self.ctx.store(py, dict.unbind().into()).unwrap()
264+
self.ctx.store(py, dict.unbind().into())
267265
})
268266
}});
269267
// (py-to-string <obj>)
270268
add_primitive!(
271269
eg,
272-
"py-to-string" = {self.clone(): PyObjectSort} |x: PyObjectIdent| -> S {
270+
"py-to-string" = {self.clone(): PyObjectSort} |x: PyObjectIdent| -?> S {
273271
{
274-
let s: String = Python::attach(move |py| self.ctx.load(py, x).extract().unwrap());
275-
s.into()
272+
let s: String = attach(move |py| self.ctx.load(py, x).extract())?;
273+
Some(s.into())
276274
}
277275
}
278276
);
279277
// (py-to-bool <obj>)
280278
add_primitive!(
281279
eg,
282-
"py-to-bool" = {self.clone(): PyObjectSort} |x: PyObjectIdent| -> bool {
280+
"py-to-bool" = {self.clone(): PyObjectSort} |x: PyObjectIdent| -?> bool {
283281
{
284-
Python::attach(move |py| self.ctx.load(py, x).extract().unwrap())
282+
attach(move |py| self.ctx.load(py, x).extract())
285283
}
286284
}
287285
);
288286
// (py-from-string <str>)
289287
add_primitive!(
290288
eg,
291-
"py-from-string" = {self.clone(): PyObjectSort} |x: S| -> PyObjectIdent {
292-
Python::attach(|py| {
293-
let obj = x.to_string().into_pyobject(py).unwrap();
294-
self.ctx.store(py, obj.unbind().into()).unwrap()
289+
"py-from-string" = {self.clone(): PyObjectSort} |x: S| -?> PyObjectIdent {
290+
attach(|py| {
291+
let obj = x.to_string().into_pyobject(py)?;
292+
self.ctx.store(py, obj.unbind().into())
295293
})
296294
}
297295
);
298296
// (py-from-int <int>)
299297
add_primitive!(
300298
eg,
301-
"py-from-int" = {self.clone(): PyObjectSort} |x: i64| -> PyObjectIdent {
302-
Python::attach(|py| {
303-
let obj = x.into_pyobject(py).unwrap();
304-
self.ctx.store(py, obj.unbind().into()).unwrap()
299+
"py-from-int" = {self.clone(): PyObjectSort} |x: i64| -?> PyObjectIdent {
300+
attach(|py| {
301+
let obj = x.into_pyobject(py)?;
302+
self.ctx.store(py, obj.unbind().into())
305303
})
306304
}
307305
);
@@ -319,6 +317,26 @@ impl BaseSort for PyObjectSort {
319317
}
320318
}
321319

320+
/// Attaches to the Python interpreter and runs the given closure.
321+
///
322+
/// Also handles errors, by saving them on the interpreter and returning None.
323+
fn attach<F, R>(f: F) -> Option<R>
324+
where
325+
F: for<'py> FnOnce(Python<'py>) -> PyResult<R>,
326+
{
327+
Python::attach(|py| {
328+
if PyErr::occurred(py) {
329+
return None
330+
};
331+
match f(py) {
332+
Ok(val) => Some(val),
333+
Err(err) => {
334+
err.restore(py);
335+
None
336+
}
337+
}
338+
})
339+
}
322340
/// Runs the code in the given context with a certain path.
323341
/// Copied from `run`, but allows specifying the path.
324342
/// https://github.com/PyO3/pyo3/blob/55d379cff8e4157024ffe22215715bd04a5fb1a1/src/marker.rs#L667-L682

0 commit comments

Comments
 (0)