Skip to content

Commit a4f58a2

Browse files
Fix type analysis
1 parent e391e1e commit a4f58a2

2 files changed

Lines changed: 38 additions & 4 deletions

File tree

python/egglog/conversion.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,14 @@ def resolve_literal(
239239
# args first based on the existing type constraint solver
240240
if tcs:
241241
try:
242-
tp_just = tcs.substitute_typevars(tp)
242+
tp_just = tcs.substitute_typevars_try_function(tp, arg, decls)
243243
# If we can't resolve the type var yet, then just assume it is the right value
244-
except TypeConstraintError:
245-
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {type(arg)}"
244+
except TypeConstraintError as e:
245+
if not isinstance(arg, RuntimeExpr):
246+
raise ConvertError(f"Cannot convert {arg} of type {arg_type} to {tp}") from e
246247
tp_just = arg.__egg_typed_expr__.tp
247248
else:
248-
# If this is a var, it has to be a runtime expession
249+
# If this is a var, it has to be a runtime expression
249250
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
250251
return arg
251252
if tcs:

python/egglog/type_constraint_solver.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,36 @@ def substitute_typevars(self, tp: TypeOrVarRef) -> JustTypeRef:
9696
case TypeRefWithVars(name, args):
9797
return JustTypeRef(name, tuple(self.substitute_typevars(arg) for arg in args))
9898
assert_never(tp)
99+
100+
def substitute_typevars_try_function(
101+
self, tp: TypeOrVarRef, value: Callable, decls: Callable[[], Declarations]
102+
) -> JustTypeRef:
103+
"""
104+
Try to substitute typevars in a type with their inferred types.
105+
106+
If this fails and we have an UnstableFn type and a function value, we can try to infer the typevars by calling
107+
it with the input types, if we can resolve those
108+
"""
109+
from .runtime import RuntimeExpr # noqa: PLC0415
110+
111+
try:
112+
return self.substitute_typevars(tp)
113+
except TypeConstraintError:
114+
if isinstance(tp, TypeVarRef) or tp.ident != Ident.builtin("UnstableFn") or not callable(value):
115+
raise
116+
dummy_args = [
117+
RuntimeExpr.__from_values__(decls(), TypedExprDecl(self.substitute_typevars(arg_tp), DummyDecl()))
118+
for arg_tp in tp.args[1:]
119+
]
120+
try:
121+
result = value(*dummy_args)
122+
except Exception as e:
123+
raise TypeConstraintError(
124+
f"Function {value} raised an exception when called with dummy args to infer return type: {e}"
125+
) from e
126+
if not isinstance(result, RuntimeExpr):
127+
raise TypeConstraintError(
128+
f"Function {value} did not return a RuntimeExpr, got {type(result)}, so cannot infer return type"
129+
)
130+
self.infer_typevars(tp.args[0], result.__egg_typed_expr__.tp)
131+
return self.substitute_typevars(tp)

0 commit comments

Comments
 (0)