Skip to content

Commit c282b59

Browse files
Try fixing type resolution again
1 parent a4f58a2 commit c282b59

9 files changed

Lines changed: 101 additions & 101 deletions

File tree

python/egglog/conversion.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,11 @@
1111
from .pretty import *
1212
from .runtime import *
1313
from .thunk import *
14-
from .type_constraint_solver import TypeConstraintError
1514

1615
if TYPE_CHECKING:
1716
from collections.abc import Generator
1817

1918
from .egraph import BaseExpr
20-
from .type_constraint_solver import TypeConstraintSolver
2119

2220
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
2321
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
@@ -220,41 +218,17 @@ def resolve_literal(
220218
tp: TypeOrVarRef,
221219
arg: object,
222220
decls: Callable[[], Declarations] = retrieve_conversion_decls,
223-
tcs: TypeConstraintSolver | None = None,
224221
) -> RuntimeExpr:
225222
"""
226223
Try to convert an object to a type, raising a ConvertError if it is not possible.
227224
228-
If the type has vars in it, they will be tried to be resolved into concrete vars based on the type constraint solver.
229-
230225
If it cannot be resolved, we assume that the value passed in will resolve it.
231226
"""
232-
arg_type = resolve_type(arg)
233-
234-
# If we have any type variables, don't bother trying to resolve the literal, just return the arg
235-
try:
236-
tp_just = tp.to_just()
237-
except TypeVarError:
238-
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
239-
# args first based on the existing type constraint solver
240-
if tcs:
241-
try:
242-
tp_just = tcs.substitute_typevars_try_function(tp, arg, decls)
243-
# If we can't resolve the type var yet, then just assume it is the right value
244-
except TypeConstraintError as e:
245-
if not isinstance(arg, RuntimeExpr):
246-
raise ConvertError(f"Cannot convert {arg} of type {arg_type} to {tp}") from e
247-
tp_just = arg.__egg_typed_expr__.tp
248-
else:
249-
# If this is a var, it has to be a runtime expression
250-
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
251-
return arg
252-
if tcs:
253-
tcs.infer_typevars(tp, tp_just)
254-
if arg_type == tp_just:
255-
# If the type is an egg type, it has to be a runtime expr
256-
assert isinstance(arg, RuntimeExpr)
227+
# If this is a runtime expression that could match the type already, just return it
228+
if isinstance(arg, RuntimeExpr) and tp.matches_just({}, arg.__egg_typed_expr__.tp):
257229
return arg
230+
tp_just = tp.to_just()
231+
arg_type = resolve_type(arg)
258232
if arg is DUMMY_VALUE:
259233
return RuntimeExpr.__from_values__(decls(), TypedExprDecl(tp_just, DummyDecl()))
260234
if (conversion := _lookup_conversion(arg_type, tp_just)) is not None:

python/egglog/declarations.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from dataclasses import dataclass, field
1010
from functools import cached_property
11+
from itertools import chain, repeat
1112
from typing import (
1213
TYPE_CHECKING,
1314
ClassVar,
@@ -413,6 +414,13 @@ def matches_just(self, vars: dict[TypeVarRef, JustTypeRef], other: JustTypeRef)
413414
vars[self] = other
414415
return True
415416

417+
@property
418+
def vars(self) -> set[TypeVarRef]:
419+
"""
420+
Returns all type variables in this type reference.
421+
"""
422+
return {self}
423+
416424

417425
@dataclass(frozen=True)
418426
class TypeRefWithVars:
@@ -437,6 +445,16 @@ def matches_just(self, vars: dict[TypeVarRef, JustTypeRef], other: JustTypeRef)
437445
and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
438446
)
439447

448+
@property
449+
def vars(self) -> set[TypeVarRef]:
450+
"""
451+
Returns all type variables in this type reference.
452+
"""
453+
vars = set[TypeVarRef]()
454+
for arg in self.args:
455+
vars.update(arg.vars)
456+
return vars
457+
440458

441459
TypeOrVarRef: TypeAlias = TypeVarRef | TypeRefWithVars
442460

@@ -589,6 +607,25 @@ def semantic_return_type(self) -> TypeOrVarRef:
589607
def mutates(self) -> bool:
590608
return self.return_type is None
591609

610+
@property
611+
def arg_vars(self) -> set[TypeVarRef]:
612+
"""
613+
Returns all type variables in the argument types.
614+
"""
615+
vars = set[TypeVarRef]()
616+
for arg in self.arg_types:
617+
vars.update(arg.vars)
618+
if self.var_arg_type:
619+
vars.update(self.var_arg_type.vars)
620+
return vars
621+
622+
@property
623+
def all_args(self) -> Iterable[TypeOrVarRef]:
624+
"""
625+
Returns all argument types, including var args.
626+
"""
627+
return chain(self.arg_types, (repeat(self.var_arg_type) if self.var_arg_type else []))
628+
592629

593630
@dataclass(frozen=True)
594631
class FunctionDecl:
@@ -675,7 +712,7 @@ def __new__(cls, *args: object, **kwargs: object) -> Self:
675712
"""
676713
Pool CallDecls so that they can be compared by identity more quickly.
677714
678-
Neccessary bc we search for common parents when serializing CallDecl trees to egglog to
715+
Necessary bc we search for common parents when serializing CallDecl trees to egglog to
679716
only serialize each sub-tree once.
680717
"""
681718
# normalize the args/kwargs to a tuple so that they can be compared

python/egglog/deconstruct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ def _deconstruct_call_decl(
188188
TypeRefWithVars(call.callable.ident, tuple(tp.to_var() for tp in (call.bound_tp_params or []))),
189189
), arg_exprs
190190
egg_bound = (
191-
JustTypeRef(call.callable.ident, call.bound_tp_params or ())
192-
if isinstance(call.callable, (ClassMethodRef, MethodRef))
191+
JustTypeRef(call.callable.ident, call.bound_tp_params)
192+
if isinstance(call.callable, (ClassMethodRef, MethodRef, InitRef)) and call.bound_tp_params
193193
else None
194194
)
195195

python/egglog/egraph_state.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,9 @@ def type_ref_to_egg(self, ref: JustTypeRef) -> str:
461461
decl = self.__egg_decls__._classes[ref.ident]
462462
self.type_ref_to_egg_sort[ref] = egg_name = (not ref.args and decl.egg_name) or _generate_type_egg_name(ref)
463463
self.egg_sort_to_type_ref[egg_name] = ref
464-
if not decl.builtin or ref.args:
464+
465+
if decl.builtin:
466+
# If this has args, create a new parameterized version of the builtin class
465467
if ref.args:
466468
if ref.ident == Ident.builtin("UnstableFn"):
467469
# UnstableFn is a special case, where the rest of args are collected into a call
@@ -478,18 +480,18 @@ def type_ref_to_egg(self, ref: JustTypeRef) -> str:
478480
]
479481
else:
480482
type_args = [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args]
481-
args = (self.type_ref_to_egg(JustTypeRef(ref.ident)), type_args)
482-
else:
483-
args = None
484-
self.egraph.run_program(bindings.Sort(span(), egg_name, args))
485-
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
486-
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
487-
# even if you never use that function.
488-
if decl.builtin:
483+
assert decl.egg_name
484+
self.egraph.run_program(bindings.Sort(span(), egg_name, (decl.egg_name, type_args)))
485+
486+
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods.
487+
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
488+
# even if you never use that function.
489489
for method_name in decl.class_methods:
490490
self.callable_ref_to_egg(ClassMethodRef(ref.ident, method_name))
491491
if decl.init:
492492
self.callable_ref_to_egg(InitRef(ref.ident))
493+
else:
494+
self.egraph.run_program(bindings.Sort(span(), egg_name, None))
493495

494496
return egg_name
495497

@@ -760,23 +762,23 @@ def value_to_expr(self, tp: JustTypeRef, value: bindings.Value) -> ExprDecl: #
760762
return CallDecl(
761763
InitRef(Ident.builtin("Set")),
762764
tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs_),
763-
(v_tp,),
765+
(v_tp,) if not xs_ else (),
764766
)
765767
case "Vec":
766768
xs = self.egraph.value_to_vec(value)
767769
(v_tp,) = tp.args
768770
return CallDecl(
769771
InitRef(Ident.builtin("Vec")),
770772
tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs),
771-
(v_tp,),
773+
(v_tp,) if not xs else (),
772774
)
773775
case "MultiSet":
774776
xs = self.egraph.value_to_multiset(value)
775777
(v_tp,) = tp.args
776778
return CallDecl(
777779
InitRef(Ident.builtin("MultiSet")),
778780
tuple(TypedExprDecl(v_tp, self.value_to_expr(v_tp, x)) for x in xs),
779-
(v_tp,),
781+
(v_tp,) if not xs else (),
780782
)
781783
case "UnstableFn":
782784
_names, _args = self.egraph.value_to_function(value)
@@ -796,22 +798,13 @@ def _unstable_fn_value_to_expr(
796798
continue
797799
if signature.semantic_return_type.ident != return_tp.ident:
798800
continue
799-
tcs = TypeConstraintSolver()
800-
801-
arg_types = tcs.infer_arg_types(
801+
arg_types = TypeConstraintSolver().infer_arg_types(
802802
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp
803803
)
804-
805804
args = tuple(
806805
TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False)
807806
)
808-
if isinstance(callable_ref, ClassMethodRef | InitRef):
809-
bound_tp_params = tuple(
810-
map(tcs.substitute_typevars, self.__egg_decls__.get_class_decl(callable_ref.ident).type_vars)
811-
)
812-
else:
813-
bound_tp_params = ()
814-
call_decl = CallDecl(callable_ref, args, bound_tp_params)
807+
call_decl = CallDecl(callable_ref, args)
815808
return PartialCallDecl(call_decl)
816809
raise ValueError(f"Function '{name}' not found")
817810

@@ -936,6 +929,9 @@ def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl:
936929
tcs = TypeConstraintSolver()
937930
if possible_type and possible_type.args:
938931
tcs.bind_class(possible_type, self.decls)
932+
bound_args = possible_type.args
933+
else:
934+
bound_args = ()
939935
try:
940936
arg_types = tcs.infer_arg_types(
941937
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp
@@ -945,12 +941,9 @@ def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl:
945941
except TypeConstraintError:
946942
continue
947943
args = tuple(self.resolve_term(a, tp) for a, tp in a_tp)
948-
if not args and isinstance(callable_ref, ClassMethodRef | InitRef):
949-
bound_tp_params = tuple(
950-
map(tcs.substitute_typevars, self.decls.get_class_decl(callable_ref.ident).type_vars)
951-
)
952-
else:
953-
bound_tp_params = ()
944+
# Only save bound tp params if needed for inferring return type
945+
# this is true if the set of set of type vars in the return are not a subset of those in the args
946+
bound_tp_params = () if signature.semantic_return_type.vars.issubset(signature.arg_vars) else bound_args
954947
return CallDecl(callable_ref, args, bound_tp_params)
955948
raise ValueError(
956949
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"

python/egglog/exp/array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2500,7 +2500,7 @@ def factor_ruleset(
25002500
yield rule(
25012501
eq(n).to(polynomial(mss)),
25022502
# Find factor that shows up in most monomials, at least two of them
2503-
counts == MultiSet.sum_multisets(mss.map(MultiSet[Value].reset_counts)),
2503+
counts == MultiSet.sum_multisets(mss.map(MultiSet.reset_counts)),
25042504
eq(factor).to(counts.pick_max()),
25052505
# Only factor out if it term appears in more than one monomial
25062506
counts.count(factor) > 1,

python/egglog/exp/polynomials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def factor_ruleset(
253253
yield rule(
254254
n == polynomial(mss),
255255
# Find factor that shows up in most monomials, at least two of them
256-
counts == MultiSet.sum_multisets(mss.map(MultiSet[Number].reset_counts)),
256+
counts == MultiSet.sum_multisets(mss.map(MultiSet.reset_counts)),
257257
factor == counts.pick_max(),
258258
# Only factor out if it term appears in more than one monomial
259259
counts.count(factor) > 1,

python/egglog/runtime.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from collections.abc import Callable
1818
from dataclasses import InitVar, dataclass, replace
1919
from inspect import Parameter, Signature
20-
from itertools import zip_longest
2120
from typing import TYPE_CHECKING, Any, TypeVar, Union, assert_never, cast, get_args, get_origin
2221

2322
import cloudpickle
@@ -398,23 +397,18 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
398397
self.__egg_decls_thunk__,
399398
Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.ident, name)))),
400399
)
400+
bound = self.__egg_tp__.to_just() if self.__egg_tp__.args else None
401401
if name in cls_decl.class_methods:
402402
return RuntimeFunction(
403-
self.__egg_decls_thunk__,
404-
Thunk.value(ClassMethodRef(self.__egg_tp__.ident, name)),
405-
self.__egg_tp__.to_just(),
403+
self.__egg_decls_thunk__, Thunk.value(ClassMethodRef(self.__egg_tp__.ident, name)), bound
406404
)
407405
# allow referencing properties and methods as class variables as well
408406
if name in cls_decl.properties:
409407
return RuntimeFunction(
410-
self.__egg_decls_thunk__,
411-
Thunk.value(PropertyRef(self.__egg_tp__.ident, name)),
412-
self.__egg_tp__.to_just(),
408+
self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.ident, name)), bound
413409
)
414410
if name in cls_decl.methods:
415-
return RuntimeFunction(
416-
self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.ident, name)), self.__egg_tp__.to_just()
417-
)
411+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.ident, name)), bound)
418412

419413
msg = f"Class {self.__egg_tp__.ident} has no method {name}"
420414
raise AttributeError(msg) from None
@@ -504,7 +498,7 @@ def __hash__(self) -> int:
504498
def __egg_ref__(self) -> CallableRef:
505499
return self.__egg_ref_thunk__()
506500

507-
def __call__( # noqa: C901
501+
def __call__( # noqa: C901,PLR0912
508502
self, *args: object, _egg_function_types: tuple[TypeOrVarRef, ...] | None = None, **kwargs: object
509503
) -> RuntimeExpr | None:
510504
from .conversion import resolve_literal # noqa: PLC0415
@@ -551,20 +545,14 @@ def __call__( # noqa: C901
551545
args = bound.args
552546

553547
tcs = TypeConstraintSolver()
554-
bound_tp = (
555-
None
556-
if self.__egg_bound__ is None
557-
else self.__egg_bound__.__egg_typed_expr__.tp
558-
if isinstance(self.__egg_bound__, RuntimeExpr)
559-
else self.__egg_bound__
560-
)
561-
if (
562-
bound_tp
563-
and bound_tp.args
564-
# Don't bind class if we have a first class function arg, b/c we don't support that yet
565-
and not function_value
566-
):
567-
tcs.bind_class(bound_tp, decls)
548+
if isinstance(self.__egg_bound__, JustTypeRef) and self.__egg_bound__.args:
549+
if function_value:
550+
msg = "Cannot have both bound type params and function value"
551+
raise ValueError(msg)
552+
tcs.bind_class(self.__egg_bound__, decls)
553+
bound_tp_params = self.__egg_bound__.args
554+
else:
555+
bound_tp_params = ()
568556
assert (operator.ge if signature.var_arg_type else operator.eq)(len(args), len(signature.arg_types))
569557
# Hack to allow being explicit on function types when casting. # noqa: FIX004
570558
for _fn_tp in _egg_function_types or ():
@@ -575,20 +563,29 @@ def __call__( # noqa: C901
575563
tcs.bind_class(_fn_tp_just, decls)
576564
if _fn_tp_just.args:
577565
pass
566+
# Try using any runtime expressions passed in to help infer typevars
567+
for arg, tp in zip(args, signature.all_args, strict=False):
568+
if not isinstance(arg, RuntimeExpr):
569+
continue
570+
try:
571+
tcs.infer_typevars(tp, arg.__egg_typed_expr__.tp)
572+
# If this leads to an incompatibility, just skip it, since it could need to be upcasted
573+
except TypeConstraintError:
574+
continue
575+
# Now at this point we should be able to resolve all the typevars
578576
upcasted_args = [
579-
resolve_literal(cast("TypeOrVarRef", tp), arg, Thunk.value(decls), tcs)
580-
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
577+
resolve_literal(
578+
tcs.substitute_typevars_try_function(tp, arg, Thunk.value(decls)).to_var(), arg, Thunk.value(decls)
579+
)
580+
for arg, tp in zip(args, signature.all_args, strict=False)
581581
]
582582
decls.update(*upcasted_args)
583583
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
584584
return_tp = tcs.substitute_typevars(signature.semantic_return_type)
585-
bound_params = (
586-
cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else ()
587-
)
588585
# If we were using unstable-app to call a function, add that function back as the first arg.
589586
if function_value:
590587
arg_exprs = (function_value, *arg_exprs)
591-
expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
588+
expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_tp_params)
592589
typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
593590
# If there is not return type, we are mutating the first arg
594591
if not signature.return_type:
@@ -901,8 +898,7 @@ def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | Run
901898
case InitRef(name):
902899
return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name))
903900
case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef():
904-
bound = JustTypeRef(ref.ident) if isinstance(ref, ClassMethodRef) else None
905-
return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound)
901+
return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), None)
906902
case ConstantRef(name):
907903
tp = decls._constants[name].type_ref
908904
case ClassVariableRef(cls_name, var_name):

0 commit comments

Comments
 (0)