55
66from __future__ import annotations
77
8+ from fractions import Fraction
89from functools import partial , reduce
9- from types import FunctionType
10+ from types import FunctionType , MethodType
1011from typing import TYPE_CHECKING , Generic , Protocol , TypeAlias , TypeVar , Union , cast , overload
1112
1213from typing_extensions import TypeVarTuple , Unpack
1314
15+ from . import bindings
1416from .conversion import convert , converter , get_type_args
15- from .egraph import BaseExpr , BuiltinExpr , Unit , function , get_current_ruleset , method
17+ from .declarations import *
18+ from .egraph import BaseExpr , BuiltinExpr , EGraph , expr_fact , function , get_current_ruleset , method
19+ from .egraph_state import GLOBAL_PY_OBJECT_SORT
1620from .functionalize import functionalize
1721from .runtime import RuntimeClass , RuntimeExpr , RuntimeFunction
1822from .thunk import Thunk
1923
2024if TYPE_CHECKING :
21- from collections .abc import Callable
25+ from collections .abc import Callable , Iterator
2226
2327
2428__all__ = [
3236 "SetLike" ,
3337 "String" ,
3438 "StringLike" ,
39+ "Unit" ,
3540 "UnstableFn" ,
3641 "Vec" ,
3742 "VecLike" ,
4651]
4752
4853
54+ class Unit (BuiltinExpr , egg_sort = "Unit" ):
55+ """
56+ The unit type. This is used to reprsent if a value exists in the e-graph or not.
57+ """
58+
59+ def __init__ (self ) -> None : ...
60+
61+ @method (preserve = True )
62+ def __bool__ (self ) -> bool :
63+ return bool (expr_fact (self ))
64+
65+
4966class String (BuiltinExpr ):
67+ @method (preserve = True )
68+ def eval (self ) -> str :
69+ value = _extract_lit (self )
70+ assert isinstance (value , bindings .String )
71+ return value .value
72+
5073 def __init__ (self , value : str ) -> None : ...
5174
5275 @method (egg_fn = "replace" )
@@ -62,10 +85,20 @@ def join(*strings: StringLike) -> String: ...
6285
6386converter (str , String , String )
6487
65- BoolLike = Union ["Bool" , bool ]
88+ BoolLike : TypeAlias = Union ["Bool" , bool ]
6689
6790
6891class Bool (BuiltinExpr , egg_sort = "bool" ):
92+ @method (preserve = True )
93+ def eval (self ) -> bool :
94+ value = _extract_lit (self )
95+ assert isinstance (value , bindings .Bool )
96+ return value .value
97+
98+ @method (preserve = True )
99+ def __bool__ (self ) -> bool :
100+ return self .eval ()
101+
69102 def __init__ (self , value : bool ) -> None : ...
70103
71104 @method (egg_fn = "not" )
@@ -91,6 +124,20 @@ def implies(self, other: BoolLike) -> Bool: ...
91124
92125
93126class i64 (BuiltinExpr ): # noqa: N801
127+ @method (preserve = True )
128+ def eval (self ) -> int :
129+ value = _extract_lit (self )
130+ assert isinstance (value , bindings .Int )
131+ return value .value
132+
133+ @method (preserve = True )
134+ def __index__ (self ) -> int :
135+ return self .eval ()
136+
137+ @method (preserve = True )
138+ def __int__ (self ) -> int :
139+ return self .eval ()
140+
94141 def __init__ (self , value : int ) -> None : ...
95142
96143 @method (egg_fn = "+" )
@@ -193,6 +240,20 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ...
193240
194241
195242class f64 (BuiltinExpr ): # noqa: N801
243+ @method (preserve = True )
244+ def eval (self ) -> float :
245+ value = _extract_lit (self )
246+ assert isinstance (value , bindings .Float )
247+ return value .value
248+
249+ @method (preserve = True )
250+ def __float__ (self ) -> float :
251+ return self .eval ()
252+
253+ @method (preserve = True )
254+ def __int__ (self ) -> int :
255+ return int (self .eval ())
256+
196257 def __init__ (self , value : float ) -> None : ...
197258
198259 @method (egg_fn = "neg" )
@@ -265,6 +326,33 @@ def to_string(self) -> String: ...
265326
266327
267328class Map (BuiltinExpr , Generic [T , V ]):
329+ @method (preserve = True )
330+ def eval (self ) -> dict [T , V ]:
331+ call = _extract_call (self )
332+ expr = cast (RuntimeExpr , self )
333+ d = {}
334+ while call .callable != ClassMethodRef ("Map" , "empty" ):
335+ assert call .callable == MethodRef ("Map" , "insert" )
336+ call_typed , k_typed , v_typed = call .args
337+ assert isinstance (call_typed .expr , CallDecl )
338+ k = cast (T , expr .__with_expr__ (k_typed ))
339+ v = cast (V , expr .__with_expr__ (v_typed ))
340+ d [k ] = v
341+ call = call_typed .expr
342+ return d
343+
344+ @method (preserve = True )
345+ def __iter__ (self ) -> Iterator [T ]:
346+ return iter (self .eval ())
347+
348+ @method (preserve = True )
349+ def __len__ (self ) -> int :
350+ return len (self .eval ())
351+
352+ @method (preserve = True )
353+ def __contains__ (self , key : T ) -> bool :
354+ return key in self .eval ()
355+
268356 @method (egg_fn = "map-empty" )
269357 @classmethod
270358 def empty (cls ) -> Map [T , V ]: ...
@@ -305,6 +393,24 @@ def rebuild(self) -> Map[T, V]: ...
305393
306394
307395class Set (BuiltinExpr , Generic [T ]):
396+ @method (preserve = True )
397+ def eval (self ) -> set [T ]:
398+ call = _extract_call (self )
399+ assert call .callable == InitRef ("Set" )
400+ return {cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args }
401+
402+ @method (preserve = True )
403+ def __iter__ (self ) -> Iterator [T ]:
404+ return iter (self .eval ())
405+
406+ @method (preserve = True )
407+ def __len__ (self ) -> int :
408+ return len (self .eval ())
409+
410+ @method (preserve = True )
411+ def __contains__ (self , key : T ) -> bool :
412+ return key in self .eval ()
413+
308414 @method (egg_fn = "set-of" )
309415 def __init__ (self , * args : T ) -> None : ...
310416
@@ -349,6 +455,28 @@ def rebuild(self) -> Set[T]: ...
349455
350456
351457class Rational (BuiltinExpr ):
458+ @method (preserve = True )
459+ def eval (self ) -> Fraction :
460+ call = _extract_call (self )
461+ assert call .callable == InitRef ("Rational" )
462+
463+ def _to_int (e : TypedExprDecl ) -> int :
464+ expr = e .expr
465+ assert isinstance (expr , LitDecl )
466+ assert isinstance (expr .value , int )
467+ return expr .value
468+
469+ num , den = call .args
470+ return Fraction (_to_int (num ), _to_int (den ))
471+
472+ @method (preserve = True )
473+ def __float__ (self ) -> float :
474+ return float (self .eval ())
475+
476+ @method (preserve = True )
477+ def __int__ (self ) -> int :
478+ return int (self .eval ())
479+
352480 @method (egg_fn = "rational" )
353481 def __init__ (self , num : i64Like , den : i64Like ) -> None : ...
354482
@@ -410,6 +538,26 @@ def denom(self) -> i64: ...
410538
411539
412540class Vec (BuiltinExpr , Generic [T ]):
541+ @method (preserve = True )
542+ def eval (self ) -> tuple [T , ...]:
543+ call = _extract_call (self )
544+ if call .callable == ClassMethodRef ("Vec" , "empty" ):
545+ return ()
546+ assert call .callable == InitRef ("Vec" )
547+ return tuple (cast (T , cast (RuntimeExpr , self ).__with_expr__ (x )) for x in call .args )
548+
549+ @method (preserve = True )
550+ def __iter__ (self ) -> Iterator [T ]:
551+ return iter (self .eval ())
552+
553+ @method (preserve = True )
554+ def __len__ (self ) -> int :
555+ return len (self .eval ())
556+
557+ @method (preserve = True )
558+ def __contains__ (self , key : T ) -> bool :
559+ return key in self .eval ()
560+
413561 @method (egg_fn = "vec-of" )
414562 def __init__ (self , * args : T ) -> None : ...
415563
@@ -461,6 +609,13 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
461609
462610
463611class PyObject (BuiltinExpr ):
612+ @method (preserve = True )
613+ def eval (self ) -> object :
614+ report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , self ), 0 )
615+ assert isinstance (report , bindings .Best )
616+ expr = report .termdag .term_to_expr (report .term , bindings .PanicSpan ())
617+ return GLOBAL_PY_OBJECT_SORT .load (expr )
618+
464619 def __init__ (self , value : object ) -> None : ...
465620
466621 @method (egg_fn = "py-from-string" )
@@ -554,6 +709,8 @@ def __init__(self, f, *partial) -> None: ...
554709 def __call__ (self , * args : Unpack [TS ]) -> T : ...
555710
556711
712+ # Method Type is for builtins like __getitem__
713+ converter (MethodType , UnstableFn , lambda m : UnstableFn (m .__func__ , m .__self__ ))
557714converter (RuntimeFunction , UnstableFn , UnstableFn )
558715converter (partial , UnstableFn , lambda p : UnstableFn (p .func , * p .args ))
559716
@@ -590,3 +747,24 @@ def value_to_annotation(a: object) -> type | None:
590747
591748
592749converter (FunctionType , UnstableFn , _convert_function )
750+
751+
752+ def _extract_lit (e : BaseExpr ) -> bindings ._Literal :
753+ """
754+ Special case extracting literals to make this faster by using termdag directly.
755+ """
756+ report = (EGraph .current or EGraph ())._run_extract (cast (RuntimeExpr , e ), 0 )
757+ assert isinstance (report , bindings .Best )
758+ term = report .term
759+ assert isinstance (term , bindings .TermLit )
760+ return term .value
761+
762+
763+ def _extract_call (e : BaseExpr ) -> CallDecl :
764+ """
765+ Extracts the call form of an expression
766+ """
767+ extracted = cast (RuntimeExpr , (EGraph .current or EGraph ()).extract (e ))
768+ expr = extracted .__egg_typed_expr__ .expr
769+ assert isinstance (expr , CallDecl )
770+ return expr
0 commit comments