Skip to content

Commit 1fef16a

Browse files
Change PyObject sort to be pickled form instead of reference
1 parent 75ed2d7 commit 1fef16a

22 files changed

Lines changed: 296 additions & 377 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,4 @@ inlined
8585
visualizer.tgz
8686
package
8787
.mypy_cache/
88+
*.json

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ lalrpop-util = { version = "0.22", features = ["lexer"] }
2929
ordered-float = "5"
3030
uuid = { version = "1.18", features = ["v4"] }
3131
rayon = "1.11"
32+
base64 = "0.22.1"
3233

3334
# Use patched version of egglog in experimental
3435
[patch.'https://github.com/egraphs-good/egglog']

docs/reference/python-integration.md

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,43 +96,55 @@ We define a custom "primitive sort" (i.e. a builtin type) for `PyObject`s. This
9696

9797
### Saving Python Objects
9898

99-
To create an expression of type `PyObject`, we call the constructor with any Python object. It will
100-
save a reference to the object:
99+
To create an expression of type `PyObject`, call the constructor with any Python object. The value is immediately
100+
serialized with `cloudpickle.dumps`, and the serialized bytes (base64 encoded when printed) are what get stored
101+
inside the e-graph. This means the e-graph keeps a snapshot of the object rather than a live reference.
101102

102103
```{code-cell} python
103-
PyObject(1)
104-
```
105-
106-
We see that this is saved internally as a pointer to the Python object. For hashable objects like `int` we store two integers, a hash of the type and a hash of the value.
104+
from dataclasses import dataclass
107105
108-
We can also store unhashable objects in the e-graph like lists.
106+
@dataclass
107+
class MyObject:
108+
a: int = 10
109109
110-
```{code-cell} python
111-
lst = PyObject([1, 2, 3])
112-
lst
110+
PyObject(MyObject())
113111
```
114112

115-
We see that this is stored with one number, simply the `id` of the object.
113+
The new serialization approach works for both hashable and unhashable Python values, and no longer depends on
114+
their `id()`. Each insert serializes the value with `cloudpickle`, and values are compared by their serialized
115+
bytes, so separately inserted values that pickle to the same bytes are identified and merged by the e-graph.
116116

117-
```{admonition} Mutable Objects
118-
:class: warning
117+
```{admonition} Serialization requirements
118+
:class: note
119119
120-
While it is possible to store unhashable objects in the e-graph, you have to be careful defining any rules which create new unhashable objects. If each time a rule is run, it creates a new object, then the e-graph will never saturate.
121-
122-
Creating hashable objects is safer, since while the rule might create new Python objects each time it executes, they should have the same hash, i.e. be equal, so that the e-graph can saturate.
120+
`PyObject` relies on `cloudpickle`. Any object you store must be serializable by `cloudpickle.dumps`; objects such
121+
as open file handles, generators, or extension types that `cloudpickle` cannot handle will raise an error when you
122+
try to construct a `PyObject`.
123123
```
124124

125125
### Retrieving Python Objects
126126

127-
Like other primitives, we can retrieve the Python object from the e-graph by using the `.value` property:
127+
Like other primitives, we can retrieve a Python object by using the `.value` property. Deserialization happens on
128+
every access, so you receive a fresh copy each time rather than the original object.
128129

129130
```{code-cell} python
130-
assert lst.value == [1, 2, 3]
131+
original = {"count": 1}
132+
expr = PyObject(original)
133+
134+
restored = expr.value
135+
assert restored == original
136+
assert restored is not original
137+
138+
# Mutating the copy does not affect the stored value.
139+
restored["count"] = 2
140+
assert expr.value == {"count": 1}
131141
```
132142

133143
### Builtin methods
134144

135-
Currently, we only support a few methods on `PyObject`s, but we plan to add more in the future.
145+
Currently, we only support a few methods on `PyObject`s, but we plan to add more in the future. Each builtin
146+
deserializes its inputs, performs the operation in Python, and then serializes the result back into a new
147+
`PyObject`, so previously stored values remain unchanged.
136148

137149
Conversion to/from a string:
138150

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ array = [
3737
"numba>=0.59.1",
3838
"llvmlite>=0.42.0",
3939
"numpy>2",
40+
"cloudpickle>=3",
4041
]
4142
dev = [
4243
"ruff",

python/egglog/bindings.pyi

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ __all__ = [
5454
"PrintOverallStatistics",
5555
"PrintSize",
5656
"Push",
57-
"PyObjectSort",
5857
"Relation",
5958
"Repeat",
6059
"Rewrite",
@@ -105,17 +104,10 @@ class SerializedEGraph:
105104
def map_ops(self, map: dict[str, str]) -> None: ...
106105
def split_classes(self, egraph: EGraph, ops: set[str]) -> None: ...
107106

108-
@final
109-
class PyObjectSort:
110-
def __init__(self) -> None: ...
111-
def store(self, __o: object, /) -> _Expr: ...
112-
def load(self, __e: _Expr, /) -> object: ...
113-
114107
@final
115108
class EGraph:
116109
def __init__(
117110
self,
118-
py_object_sort: PyObjectSort | None = None,
119111
/,
120112
*,
121113
fact_directory: str | Path | None = None,
@@ -142,7 +134,7 @@ class EGraph:
142134
def value_to_rational(self, v: Value) -> Fraction: ...
143135
def value_to_bigint(self, v: Value) -> int: ...
144136
def value_to_bigrat(self, v: Value) -> Fraction: ...
145-
def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ...
137+
def value_to_pyobject(self, v: Value) -> object: ...
146138
def value_to_map(self, v: Value) -> dict[Value, Value]: ...
147139
def value_to_multiset(self, v: Value) -> list[Value]: ...
148140
def value_to_vec(self, v: Value) -> list[Value]: ...

python/egglog/builtins.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from types import FunctionType, MethodType
1414
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, cast, overload
1515

16+
import cloudpickle
1617
from typing_extensions import TypeVarTuple, Unpack, deprecated
1718

1819
from .conversion import convert, converter, get_type_args, resolve_literal
@@ -984,12 +985,15 @@ def value(self) -> object:
984985
expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
985986
if not isinstance(expr, PyObjectDecl):
986987
raise ExprValueError(self, "PyObject(x)")
987-
return expr.value
988+
return cloudpickle.loads(expr.pickled)
988989

989990
__match_args__ = ("value",)
990991

991992
def __init__(self, value: object) -> None: ...
992993

994+
@method(egg_fn="py-call")
995+
def __call__(self, *args: object) -> PyObject: ...
996+
993997
@method(egg_fn="py-from-string")
994998
@classmethod
995999
def from_string(cls, s: StringLike) -> PyObject: ...
@@ -1023,24 +1027,15 @@ class PyObjectFunction(Protocol):
10231027
def __call__(self, *__args: PyObject) -> PyObject: ...
10241028

10251029

1030+
@deprecated("use PyObject(fn) directly")
10261031
def py_eval_fn(fn: Callable) -> PyObjectFunction:
10271032
"""
10281033
Takes a python callable and maps it to a callable which takes and returns PyObjects.
10291034
10301035
It translates it to a call which uses `py_eval` to call the function, passing in the
10311036
args as locals, and using the globals from function.
10321037
"""
1033-
1034-
def inner(*__args: PyObject, __fn: Callable = fn) -> PyObject:
1035-
new_kvs: list[object] = []
1036-
eval_str = "__fn("
1037-
for i, arg in enumerate(__args):
1038-
new_kvs.extend((f"__arg_{i}", arg))
1039-
eval_str += f"__arg_{i}, "
1040-
eval_str += ")"
1041-
return py_eval(eval_str, PyObject({"__fn": __fn}).dict_update(*new_kvs), getattr(__fn, "__globals__", {}))
1042-
1043-
return inner
1038+
return PyObject(fn)
10441039

10451040

10461041
@function(builtin=True, egg_fn="py-exec")

python/egglog/declarations.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -617,23 +617,7 @@ class LetRefDecl:
617617

618618
@dataclass(frozen=True)
619619
class PyObjectDecl:
620-
value: object
621-
622-
def __hash__(self) -> int:
623-
"""Tries using the hash of the value, if unhashable use the ID."""
624-
try:
625-
return hash((type(self.value), self.value))
626-
except TypeError:
627-
return id(self.value)
628-
629-
def __eq__(self, other: object) -> bool:
630-
if not isinstance(other, PyObjectDecl):
631-
return False
632-
return self.parts == other.parts
633-
634-
@property
635-
def parts(self) -> tuple[type, object]:
636-
return (type(self.value), self.value)
620+
pickled: bytes
637621

638622

639623
LitType: TypeAlias = int | str | float | bool | None

python/egglog/deconstruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from functools import partial
99
from typing import TYPE_CHECKING, TypeVar, overload
1010

11+
import cloudpickle
1112
from typing_extensions import TypeVarTuple, Unpack
1213

1314
from .declarations import *
@@ -64,7 +65,7 @@ def get_literal_value(x: object) -> object:
6465
case LitDecl(v):
6566
return v
6667
case PyObjectDecl(obj):
67-
return obj
68+
return cloudpickle.loads(obj)
6869
case PartialCallDecl(call):
6970
fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
7071
if not args:

python/egglog/egraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,7 @@ class EGraph:
839839
_token_stack: list[EGraph] = field(default_factory=list, repr=False)
840840

841841
def __post_init__(self, seminaive: bool, save_egglog_string: bool) -> None:
842-
egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
842+
egraph = bindings.EGraph(seminaive=seminaive, record=save_egglog_string)
843843
self._state = EGraphState(egraph)
844844

845845
def _add_decls(self, *decls: DeclerationsLike) -> None:

0 commit comments

Comments
 (0)