Skip to content

Commit 3a9b8d9

Browse files
Upgrade support to new egglog version
1 parent 83e73c3 commit 3a9b8d9

14 files changed

Lines changed: 157 additions & 104 deletions

Cargo.lock

Lines changed: 21 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

python/egglog/bindings.pyi

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ __all__ = [
3131
"Fact",
3232
"Fail",
3333
"Float",
34+
"FrozenEGraph",
35+
"FrozenFunction",
36+
"FrozenRow",
3437
"Function",
3538
"FunctionCommand",
3639
"FusedIntersect",
@@ -55,6 +58,9 @@ __all__ = [
5558
"PrintFunctionSize",
5659
"PrintOverallStatistics",
5760
"PrintSize",
61+
"Prove",
62+
"ProveExists",
63+
"ProveExistsOutput",
5864
"Push",
5965
"Relation",
6066
"Repeat",
@@ -269,6 +275,7 @@ class TermApp:
269275
def __new__(cls, name: str, args: list[int]) -> TermApp: ...
270276

271277
_Term: TypeAlias = TermLit | TermVar | TermApp
278+
_TermId: TypeAlias = int
272279

273280
##
274281
# Facts
@@ -531,15 +538,20 @@ class PrintAllFunctionsSize:
531538
@final
532539
class ExtractVariants:
533540
termdag: TermDag
534-
terms: list[_Term]
535-
def __new__(cls, termdag: TermDag, terms: list[_Term]) -> ExtractVariants: ...
541+
terms: list[_TermId]
542+
def __new__(cls, termdag: TermDag, terms: list[_TermId]) -> ExtractVariants: ...
536543

537544
@final
538545
class ExtractBest:
539546
termdag: TermDag
540547
cost: int
541-
term: _Term
542-
def __new__(cls, termdag: TermDag, cost: int, term: _Term) -> ExtractBest: ...
548+
term: _TermId
549+
def __new__(cls, termdag: TermDag, cost: int, term: _TermId) -> ExtractBest: ...
550+
551+
@final
552+
class ProveExistsOutput:
553+
proof: str
554+
def __new__(cls, proof: str) -> ProveExistsOutput: ...
543555

544556
@final
545557
class OverallStatistics:
@@ -555,10 +567,10 @@ class RunScheduleOutput:
555567
class PrintFunctionOutput:
556568
function: Function
557569
termdag: TermDag
558-
terms: list[tuple[_Term, _Term]]
570+
terms: list[tuple[_TermId, _TermId]]
559571
mode: _PrintFunctionMode
560572
def __new__(
561-
cls, function: Function, termdag: TermDag, terms: list[tuple[_Term, _Term]], mode: _PrintFunctionMode
573+
cls, function: Function, termdag: TermDag, terms: list[tuple[_TermId, _TermId]], mode: _PrintFunctionMode
562574
) -> PrintFunctionOutput: ...
563575

564576
@final
@@ -571,6 +583,7 @@ _CommandOutput: TypeAlias = (
571583
| PrintAllFunctionsSize
572584
| ExtractVariants
573585
| ExtractBest
586+
| ProveExistsOutput
574587
| OverallStatistics
575588
| RunScheduleOutput
576589
| PrintFunctionOutput
@@ -718,6 +731,18 @@ class Check:
718731
facts: list[_Fact]
719732
def __new__(cls, span: _Span, facts: list[_Fact]) -> Check: ...
720733

734+
@final
735+
class Prove:
736+
span: _Span
737+
facts: list[_Fact]
738+
def __new__(cls, span: _Span, facts: list[_Fact]) -> Prove: ...
739+
740+
@final
741+
class ProveExists:
742+
span: _Span
743+
expr: str
744+
def __new__(cls, span: _Span, expr: str) -> ProveExists: ...
745+
721746
@final
722747
class PrintFunction:
723748
span: _Span
@@ -823,6 +848,8 @@ _Command: TypeAlias = (
823848
| RunSchedule
824849
| Extract
825850
| Check
851+
| Prove
852+
| ProveExists
826853
| PrintFunction
827854
| PrintSize
828855
| Output
@@ -845,14 +872,14 @@ _Command: TypeAlias = (
845872
@final
846873
class TermDag:
847874
def size(self) -> int: ...
848-
def lookup(self, node: _Term) -> int: ...
849-
def get(self, id: int) -> _Term: ...
850-
def app(self, sym: str, children: list[int]) -> _Term: ...
851-
def lit(self, lit: _Literal) -> _Term: ...
852-
def var(self, sym: str) -> _Term: ...
853-
def expr_to_term(self, expr: _Expr) -> _Term: ...
854-
def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ...
855-
def to_string(self, term: _Term) -> str: ...
875+
def lookup(self, node: _Term) -> _TermId: ...
876+
def get(self, id: _TermId) -> _Term: ...
877+
def app(self, sym: str, children: list[_TermId]) -> _TermId: ...
878+
def lit(self, lit: _Literal) -> _TermId: ...
879+
def var(self, sym: str) -> _TermId: ...
880+
def expr_to_term(self, expr: _Expr) -> _TermId: ...
881+
def term_to_expr(self, term: _TermId, span: _Span) -> _Expr: ...
882+
def to_string(self, term: _TermId) -> str: ...
856883

857884
##
858885
# Extraction
@@ -882,10 +909,10 @@ class Extractor(Generic[_COST]):
882909
def __new__(
883910
cls, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST, Any]
884911
) -> Extractor[_COST]: ...
885-
def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ...
912+
def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _TermId]: ...
886913
def extract_variants(
887914
self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str
888-
) -> list[tuple[_COST, _Term]]: ...
915+
) -> list[tuple[_COST, _TermId]]: ...
889916

890917
##
891918
# Frozen

python/egglog/builtins.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def multiset_not_contains_swapped(x: T, xs: MultiSet[T]) -> Unit: ...
656656
def multiset_contains_swapped(x: T, xs: MultiSet[T]) -> Unit: ...
657657

658658

659-
@function(egg_fn="unstable-multiset-fold", builtin=True)
659+
@function(egg_fn="unstable-multiset-reduce", builtin=True)
660660
def multiset_fold(f: Callable[[T, T], T], initial: T, xs: MultiSet[T]) -> T: ...
661661

662662

@@ -875,7 +875,7 @@ def bool_le(self, other: BigIntLike) -> Bool: ...
875875
def bool_ge(self, other: BigIntLike) -> Bool: ...
876876

877877

878-
converter(i64, BigInt, lambda i: BigInt(i))
878+
converter(i64, BigInt, BigInt)
879879

880880
BigIntLike: TypeAlias = BigInt | i64Like
881881

@@ -1113,6 +1113,7 @@ def __call__(self, *args: *TS) -> T: ...
11131113

11141114
# Method Type is for builtins like __getitem__
11151115
converter(MethodType, UnstableFn, lambda m: UnstableFn[*get_type_args()](m.__func__, m.__self__)) # type: ignore[operator, misc]
1116+
# Ignore PLW0108.
11161117
converter(RuntimeFunction, UnstableFn, lambda rf: UnstableFn[*get_type_args()](rf)) # type: ignore[operator, misc]
11171118
# converter(RuntimeClass, UnstableFn, lambda rc: UnstableFn[*get_type_args()](rc)) # type: ignore[operator, misc]
11181119
converter(partial, UnstableFn, lambda p: UnstableFn[*get_type_args()](p.func, *p.args)) # type: ignore[operator, misc]

python/egglog/egraph.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,7 @@ class EGraph:
868868
# For storing the global "current" egraph
869869
_token_stack: list[EGraph] = field(default_factory=list, repr=False)
870870

871+
# TODO: Add EGraph(...commands) constructor
871872
def __post_init__(self, seminaive: bool, save_egglog_string: bool) -> None:
872873
egraph = bindings.EGraph(seminaive=seminaive, record=save_egglog_string)
873874
self._state = EGraphState(egraph)
@@ -1037,7 +1038,7 @@ def extract(
10371038
res = self._from_termdag(termdag, term, tp)
10381039
return (res, cost) if include_cost else res
10391040

1040-
def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
1041+
def _from_termdag(self, termdag: bindings.TermDag, term: int, tp: JustTypeRef) -> Any:
10411042
(new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
10421043
return RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)
10431044

@@ -1179,13 +1180,14 @@ def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> N
11791180
serialized = self._serialize(**kwargs)
11801181
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()
11811182

1182-
def saturate(
1183+
def saturate( # noqa: C901
11831184
self,
11841185
schedule: Schedule | None = None,
11851186
*,
11861187
expr: Expr | None = None,
11871188
max: int = 1000,
11881189
visualize: bool = True,
1190+
print_frozen: bool = False,
11891191
**kwargs: Unpack[GraphvizKwargs],
11901192
) -> None:
11911193
"""
@@ -1210,9 +1212,14 @@ def to_json() -> str:
12101212
i += 1
12111213
if visualize:
12121214
egraphs.append(to_json())
1215+
if print_frozen:
1216+
print(f"After iteration {i}:")
1217+
print(self.freeze())
12131218
except:
12141219
if visualize:
12151220
egraphs.append(to_json())
1221+
if print_frozen:
1222+
print(self.freeze())
12161223
raise
12171224
finally:
12181225
if visualize:

python/egglog/egraph_state.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,14 +631,12 @@ def translate_call(self, expr: CallDecl | GetCostDecl) -> tuple[str, list[TypedE
631631
case _:
632632
assert_never(expr)
633633

634-
def exprs_from_egg(
635-
self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
636-
) -> Iterable[TypedExprDecl]:
634+
def exprs_from_egg(self, termdag: bindings.TermDag, terms: list[int], tp: JustTypeRef) -> Iterable[TypedExprDecl]:
637635
"""
638636
Create a function that can convert from an egg term to a typed expr.
639637
"""
640638
state = FromEggState(self, termdag)
641-
return [state.from_expr(tp, term) for term in terms]
639+
return [state.resolve_term(term_id, tp) for term_id in terms]
642640

643641
def _get_possible_types(self, cls_ident: Ident) -> frozenset[JustTypeRef]:
644642
"""

python/egglog/exp/array_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def __bool__(self) -> bool:
6363
not _CURRENT_EGRAPH
6464
and (
6565
args := get_callable_args(self, Int.__eq__)
66-
or get_callable_args(self, Boolean.__eq__)
67-
or get_callable_args(self, Value.__eq__)
66+
or get_callable_args(self, Boolean.__eq__) # type: ignore[arg-type]
67+
or get_callable_args(self, Value.__eq__) # type: ignore[arg-type]
6868
)
6969
is not None
7070
):

python/egglog/type_constraint_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def bind_class(self, ref: JustTypeRef, decls: Declarations) -> None:
4747
try:
4848
cls_typevars = decls.get_class_decl(ref.ident).type_vars
4949
except KeyError:
50-
cls_typevars = []
50+
cls_typevars = ()
5151
for typevar, arg in zip(cls_typevars, ref.args, strict=True):
5252
self.infer_typevars(typevar, arg)
5353

0 commit comments

Comments
 (0)