@@ -796,23 +796,22 @@ def _unstable_fn_value_to_expr(
796796 continue
797797 if signature .semantic_return_type .ident != return_tp .ident :
798798 continue
799- tcs = TypeConstraintSolver (self . __egg_decls__ )
799+ tcs = TypeConstraintSolver ()
800800
801- arg_types , bound_tp_params = tcs .infer_arg_types (
802- signature .arg_types , signature .semantic_return_type , signature .var_arg_type , return_tp , None
801+ arg_types = tcs .infer_arg_types (
802+ signature .arg_types , signature .semantic_return_type , signature .var_arg_type , return_tp
803803 )
804804
805805 args = tuple (
806806 TypedExprDecl (tp , self .value_to_expr (tp , v )) for tp , v in zip (arg_types , partial_args , strict = False )
807807 )
808-
809- call_decl = CallDecl (
810- callable_ref ,
811- args ,
812- # Don't include bound type params if this is just a method, we only needed them for type resolution
813- # but dont need to store them
814- bound_tp_params if isinstance (callable_ref , ClassMethodRef | InitRef ) else (),
815- )
808+ if isinstance (callable_ref , ClassMethodRef | InitRef ):
809+ bound_tp_params = tuple (
810+ map (tcs .substitute_typevars , self .__egg_decls__ .get_class_decl (callable_ref .ident ).type_vars )
811+ )
812+ else :
813+ bound_tp_params = ()
814+ call_decl = CallDecl (callable_ref , args , bound_tp_params )
816815 return PartialCallDecl (call_decl )
817816 raise ValueError (f"Function '{ name } ' not found" )
818817
@@ -909,11 +908,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
909908 assert_never (term )
910909 return TypedExprDecl (tp , expr_decl )
911910
912- def from_call (
913- self ,
914- tp : JustTypeRef ,
915- term : bindings .TermApp , # additional_arg_tps: tuple[JustTypeRef, ...]
916- ) -> CallDecl :
911+ def from_call (self , tp : JustTypeRef , term : bindings .TermApp ) -> CallDecl :
917912 """
918913 Convert a call to a CallDecl.
919914
@@ -931,33 +926,32 @@ def from_call(
931926 signature = self .decls .get_callable_decl (callable_ref ).signature
932927 assert isinstance (signature , FunctionSignature )
933928 if isinstance (callable_ref , ClassMethodRef | InitRef | MethodRef ):
934- # Need OR in case we have class method whose class whas never added as a sort, which would happen
929+ # Need OR in case we have class method whose class was never added as a sort, which would happen
935930 # if the class method didn't return that type and no other function did. In this case, we don't need
936- # to care about the type vars and we we don't need to bind any possible type.
931+ # to care about the type vars and we don't need to bind any possible type.
937932 possible_types = self .state ._get_possible_types (callable_ref .ident ) or [None ]
938- cls_name = callable_ref .ident
939933 else :
940934 possible_types = [None ]
941- cls_name = None
942935 for possible_type in possible_types :
943- tcs = TypeConstraintSolver (self . decls )
936+ tcs = TypeConstraintSolver ()
944937 if possible_type and possible_type .args :
945- tcs .bind_class (possible_type )
938+ tcs .bind_class (possible_type , self . decls )
946939 try :
947- arg_types , bound_tp_params = tcs .infer_arg_types (
948- signature .arg_types , signature .semantic_return_type , signature .var_arg_type , tp , cls_name
940+ arg_types = tcs .infer_arg_types (
941+ signature .arg_types , signature .semantic_return_type , signature .var_arg_type , tp
949942 )
943+ # Include this in try because of iterable
944+ a_tp = list (zip (term .args , arg_types , strict = False ))
950945 except TypeConstraintError :
951946 continue
952- args = tuple (self .resolve_term (a , tp ) for a , tp in zip (term .args , arg_types , strict = False ))
953-
954- return CallDecl (
955- callable_ref ,
956- args ,
957- # Don't include bound type params if this is just a method, we only needed them for type resolution
958- # but dont need to store them
959- bound_tp_params if isinstance (callable_ref , ClassMethodRef | InitRef ) and not args else (),
960- )
947+ args = tuple (self .resolve_term (a , tp ) for a , tp in a_tp )
948+ if not args and isinstance (callable_ref , ClassMethodRef | InitRef ):
949+ bound_tp_params = tuple (
950+ map (tcs .substitute_typevars , self .decls .get_class_decl (callable_ref .ident ).type_vars )
951+ )
952+ else :
953+ bound_tp_params = ()
954+ return CallDecl (callable_ref , args , bound_tp_params )
961955 raise ValueError (
962956 f"Could not find callable ref for call { term } . None of these refs matched the types: { self .state .egg_fn_to_callable_refs [term .name ]} "
963957 )
0 commit comments