Skip to content

Commit e391e1e

Browse files
tmp
1 parent 529fbf4 commit e391e1e

8 files changed

Lines changed: 67 additions & 73 deletions

File tree

python/egglog/builtins.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def eval(self) -> dict[T, V]:
409409
@property
410410
def value(self) -> dict[T, V]:
411411
d = {}
412-
while args := get_callable_args(self, Map[T, V].insert):
412+
while args := get_callable_args(self, Map.insert): # type: ignore[var-annotated]
413413
self, k, v = args # noqa: PLW0642
414414
d[k] = v
415415
if get_callable_args(self, Map.empty) is None:
@@ -992,7 +992,7 @@ def __init__(self, *args: T) -> None: ...
992992
def empty(cls) -> Vec[T]: ...
993993

994994
@method(egg_fn="vec-append")
995-
def append(self, *others: Vec[T]) -> Vec[T]: ...
995+
def append(self, *others: VecLike[T, T]) -> Vec[T]: ...
996996

997997
@method(egg_fn="vec-push")
998998
def push(self, value: T) -> Vec[T]: ...
@@ -1033,6 +1033,9 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...
10331033

10341034
VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
10351035

1036+
v = Vec(i64(10))
1037+
v.append([i64(10)])
1038+
10361039

10371040
class PyObject(BuiltinExpr, egg_sort="PyObject"):
10381041
@method(preserve=True)

python/egglog/conversion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def resolve_literal(
221221
arg: object,
222222
decls: Callable[[], Declarations] = retrieve_conversion_decls,
223223
tcs: TypeConstraintSolver | None = None,
224-
cls_ident: Ident | None = None,
225224
) -> RuntimeExpr:
226225
"""
227226
Try to convert an object to a type, raising a ConvertError if it is not possible.
@@ -240,7 +239,7 @@ def resolve_literal(
240239
# args first based on the existing type constraint solver
241240
if tcs:
242241
try:
243-
tp_just = tcs.substitute_typevars(tp, cls_ident)
242+
tp_just = tcs.substitute_typevars(tp)
244243
# If we can't resolve the type var yet, then just assume it is the right value
245244
except TypeConstraintError:
246245
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {type(arg)}"
@@ -250,7 +249,7 @@ def resolve_literal(
250249
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
251250
return arg
252251
if tcs:
253-
tcs.infer_typevars(tp, tp_just, cls_ident)
252+
tcs.infer_typevars(tp, tp_just)
254253
if arg_type == tp_just:
255254
# If the type is an egg type, it has to be a runtime expr
256255
assert isinstance(arg, RuntimeExpr)

python/egglog/egraph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def __new__( # type: ignore[misc]
328328
cls_ident = Ident(name, _get_module(prev_frame))
329329
# Pass in an instance of the class so that when we are generating the decls
330330
# we can update them eagerly so that we can access the methods in the class body
331+
# TODO: How should we normalize unparameterized classes? Should they have args of the typevars?
331332
runtime_cls = RuntimeClass(None, TypeRefWithVars(cls_ident)) # type: ignore[arg-type]
332333

333334
# Store frame so that we can get live access to updated locals/globals

python/egglog/egraph_state.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -796,23 +796,22 @@ def _unstable_fn_value_to_expr(
796796
continue
797797
if signature.semantic_return_type.ident != return_tp.ident:
798798
continue
799-
tcs = TypeConstraintSolver(self.__egg_decls__)
799+
tcs = TypeConstraintSolver()
800800

801-
arg_types, bound_tp_params = tcs.infer_arg_types(
802-
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp, None
801+
arg_types = tcs.infer_arg_types(
802+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, return_tp
803803
)
804804

805805
args = tuple(
806806
TypedExprDecl(tp, self.value_to_expr(tp, v)) for tp, v in zip(arg_types, partial_args, strict=False)
807807
)
808-
809-
call_decl = CallDecl(
810-
callable_ref,
811-
args,
812-
# Don't include bound type params if this is just a method, we only needed them for type resolution
813-
# but dont need to store them
814-
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else (),
815-
)
808+
if isinstance(callable_ref, ClassMethodRef | InitRef):
809+
bound_tp_params = tuple(
810+
map(tcs.substitute_typevars, self.__egg_decls__.get_class_decl(callable_ref.ident).type_vars)
811+
)
812+
else:
813+
bound_tp_params = ()
814+
call_decl = CallDecl(callable_ref, args, bound_tp_params)
816815
return PartialCallDecl(call_decl)
817816
raise ValueError(f"Function '{name}' not found")
818817

@@ -909,11 +908,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
909908
assert_never(term)
910909
return TypedExprDecl(tp, expr_decl)
911910

912-
def from_call(
913-
self,
914-
tp: JustTypeRef,
915-
term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...]
916-
) -> CallDecl:
911+
def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl:
917912
"""
918913
Convert a call to a CallDecl.
919914
@@ -931,33 +926,32 @@ def from_call(
931926
signature = self.decls.get_callable_decl(callable_ref).signature
932927
assert isinstance(signature, FunctionSignature)
933928
if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
934-
# Need OR in case we have class method whose class whas never added as a sort, which would happen
929+
# Need OR in case we have class method whose class was never added as a sort, which would happen
935930
# if the class method didn't return that type and no other function did. In this case, we don't need
936-
# to care about the type vars and we we don't need to bind any possible type.
931+
# to care about the type vars and we don't need to bind any possible type.
937932
possible_types = self.state._get_possible_types(callable_ref.ident) or [None]
938-
cls_name = callable_ref.ident
939933
else:
940934
possible_types = [None]
941-
cls_name = None
942935
for possible_type in possible_types:
943-
tcs = TypeConstraintSolver(self.decls)
936+
tcs = TypeConstraintSolver()
944937
if possible_type and possible_type.args:
945-
tcs.bind_class(possible_type)
938+
tcs.bind_class(possible_type, self.decls)
946939
try:
947-
arg_types, bound_tp_params = tcs.infer_arg_types(
948-
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
940+
arg_types = tcs.infer_arg_types(
941+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp
949942
)
943+
# Include this in try because of iterable
944+
a_tp = list(zip(term.args, arg_types, strict=False))
950945
except TypeConstraintError:
951946
continue
952-
args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
953-
954-
return CallDecl(
955-
callable_ref,
956-
args,
957-
# Don't include bound type params if this is just a method, we only needed them for type resolution
958-
# but dont need to store them
959-
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) and not args else (),
960-
)
947+
args = tuple(self.resolve_term(a, tp) for a, tp in a_tp)
948+
if not args and isinstance(callable_ref, ClassMethodRef | InitRef):
949+
bound_tp_params = tuple(
950+
map(tcs.substitute_typevars, self.decls.get_class_decl(callable_ref.ident).type_vars)
951+
)
952+
else:
953+
bound_tp_params = ()
954+
return CallDecl(callable_ref, args, bound_tp_params)
961955
raise ValueError(
962956
f"Could not find callable ref for call {term}. None of these refs matched the types: {self.state.egg_fn_to_callable_refs[term.name]}"
963957
)

python/egglog/exp/array_api_program_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Valu
144144
yield rewrite(value_program(Value.float(f))).to(float_program(f))
145145
# Could add .item() but we usually dont need it.
146146
yield rewrite(value_program(x.to_value())).to(ndarray_program(x))
147-
yield rewrite(value_program(v1 < v2)).to(Program("(") + value_program(v1) + " < " + value_program(v2) + ")")
147+
yield rewrite(bool_program(v1 < v2)).to(Program("(") + value_program(v1) + " < " + value_program(v2) + ")")
148148
yield rewrite(value_program(v1 / v2)).to(Program("(") + value_program(v1) + " / " + value_program(v2) + ")")
149149
yield rewrite(value_program(v1 + v2)).to(Program("(") + value_program(v1) + " + " + value_program(v2) + ")")
150150
yield rewrite(value_program(v1 * v2)).to(Program("(") + value_program(v1) + " * " + value_program(v2) + ")")

python/egglog/runtime.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def __getitem__(self, args: object) -> RuntimeClass:
350350
"tuple[tuple[DeclerationsLike, ...], tuple[TypeOrVarRef, ...]]",
351351
zip(*(resolve_type_annotation(arg) for arg in args), strict=False),
352352
)
353-
# if we already have some args bound and some not, then we shold replace all existing args of typevars with new
353+
# if we already have some args bound and some not, then we should replace all existing args of typevars with new
354354
# args
355355
if old_args := self.__egg_tp__.args:
356356
is_typevar = [isinstance(arg, TypeVarRef) for arg in old_args]
@@ -550,7 +550,7 @@ def __call__( # noqa: C901
550550
assert not bound.kwargs
551551
args = bound.args
552552

553-
tcs = TypeConstraintSolver(decls)
553+
tcs = TypeConstraintSolver()
554554
bound_tp = (
555555
None
556556
if self.__egg_bound__ is None
@@ -564,27 +564,24 @@ def __call__( # noqa: C901
564564
# Don't bind class if we have a first class function arg, b/c we don't support that yet
565565
and not function_value
566566
):
567-
tcs.bind_class(bound_tp)
567+
tcs.bind_class(bound_tp, decls)
568568
assert (operator.ge if signature.var_arg_type else operator.eq)(len(args), len(signature.arg_types))
569569
# Hack to allow being explicit on function types when casting. # noqa: FIX004
570-
# TODO: Replace type analysis class binding stuff
571-
# with instead binders of location per call origination
572-
cls_ident = bound_tp.ident if bound_tp else None
573570
for _fn_tp in _egg_function_types or ():
574571
try:
575572
_fn_tp_just = _fn_tp.to_just()
576573
except TypeVarError:
577574
continue
578-
tcs.bind_class(_fn_tp_just)
575+
tcs.bind_class(_fn_tp_just, decls)
579576
if _fn_tp_just.args:
580-
cls_ident = _fn_tp_just.ident
577+
pass
581578
upcasted_args = [
582-
resolve_literal(cast("TypeOrVarRef", tp), arg, Thunk.value(decls), tcs=tcs, cls_ident=cls_ident)
579+
resolve_literal(cast("TypeOrVarRef", tp), arg, Thunk.value(decls), tcs)
583580
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
584581
]
585582
decls.update(*upcasted_args)
586583
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
587-
return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_ident)
584+
return_tp = tcs.substitute_typevars(signature.semantic_return_type)
588585
bound_params = (
589586
cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else ()
590587
)

python/egglog/type_constraint_solver.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ def infer_arg_types(
5858
Given a return type, infer the argument types. If there is a variable arg, it returns an infinite iterable.
5959
"""
6060
self.infer_typevars(fn_return, return_)
61-
for fn_arg in fn_args:
62-
yield self.substitute_typevars(fn_arg)
63-
if fn_var_args is not None:
64-
var_arg_type = self.substitute_typevars(fn_var_args)
65-
yield from chain(repeat(var_arg_type))
61+
arg_types = [self.substitute_typevars(fn_arg) for fn_arg in fn_args]
62+
if fn_var_args is None:
63+
return arg_types
64+
var_arg_type = self.substitute_typevars(fn_var_args)
65+
return chain(arg_types, repeat(var_arg_type))
6666

6767
def infer_typevars(self, fn_arg: TypeOrVarRef, arg: JustTypeRef) -> None:
6868
"""
@@ -88,11 +88,11 @@ def substitute_typevars(self, tp: TypeOrVarRef) -> JustTypeRef:
8888
Substitute typevars in a type with their inferred types, raises TypeConstraintError if a typevar is unresolved.
8989
"""
9090
match tp:
91-
case TypeVarRef(tyepvar_ident):
91+
case TypeVarRef(typevar_ident):
9292
try:
93-
return self._typevar_to_type[tyepvar_ident]
93+
return self._typevar_to_type[typevar_ident]
9494
except KeyError as e:
95-
raise TypeConstraintError(f"Unresolved type variable: {tyepvar_ident}") from e
95+
raise TypeConstraintError(f"Unresolved type variable: {typevar_ident}") from e
9696
case TypeRefWithVars(name, args):
9797
return JustTypeRef(name, tuple(self.substitute_typevars(arg) for arg in args))
9898
assert_never(tp)

python/tests/test_type_constraint_solver.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,40 +14,40 @@
1414

1515

1616
def test_simple() -> None:
17-
tcs = TypeConstraintSolver(Declarations())
17+
tcs = TypeConstraintSolver()
1818
tcs.infer_typevars(i64.to_var(), i64)
1919
assert tcs.substitute_typevars(i64.to_var()) == i64
2020

2121

2222
def test_wrong_arg() -> None:
23-
tcs = TypeConstraintSolver(Declarations())
23+
tcs = TypeConstraintSolver()
2424
with pytest.raises(TypeConstraintError):
2525
tcs.infer_typevars(i64.to_var(), unit)
2626

2727

2828
def test_generic() -> None:
29-
tcs = TypeConstraintSolver(Declarations())
30-
tcs.infer_typevars(map, map_i64_unit, Ident("Map"))
31-
tcs.infer_typevars(K, i64, Ident("Map"))
32-
assert tcs.substitute_typevars(V, Ident("Map")) == unit
29+
tcs = TypeConstraintSolver()
30+
tcs.infer_typevars(map, map_i64_unit)
31+
tcs.infer_typevars(K, i64)
32+
assert tcs.substitute_typevars(V) == unit
3333

3434

3535
def test_generic_wrong() -> None:
36-
tcs = TypeConstraintSolver(Declarations())
37-
tcs.infer_typevars(map, map_i64_unit, Ident("Map"))
36+
tcs = TypeConstraintSolver()
37+
tcs.infer_typevars(map, map_i64_unit)
3838
with pytest.raises(TypeConstraintError):
39-
tcs.infer_typevars(K, unit, Ident("Map"))
39+
tcs.infer_typevars(K, unit)
4040

4141

4242
def test_bound() -> None:
43-
bound_cs = TypeConstraintSolver(decls)
44-
bound_cs.bind_class(map_i64_unit)
45-
bound_cs.infer_typevars(K, i64, Ident("Map"))
46-
assert bound_cs.substitute_typevars(V, Ident("Map")) == unit
43+
bound_cs = TypeConstraintSolver()
44+
bound_cs.bind_class(map_i64_unit, decls)
45+
bound_cs.infer_typevars(K, i64)
46+
assert bound_cs.substitute_typevars(V) == unit
4747

4848

4949
def test_bound_wrong():
50-
bound_cs = TypeConstraintSolver(decls)
51-
bound_cs.bind_class(map_i64_unit)
50+
bound_cs = TypeConstraintSolver()
51+
bound_cs.bind_class(map_i64_unit, decls)
5252
with pytest.raises(TypeConstraintError):
53-
bound_cs.infer_typevars(K, unit, Ident("Map"))
53+
bound_cs.infer_typevars(K, unit)

0 commit comments

Comments
 (0)