1717from collections .abc import Callable
1818from dataclasses import InitVar , dataclass , replace
1919from inspect import Parameter , Signature
20- from itertools import zip_longest
2120from typing import TYPE_CHECKING , Any , TypeVar , Union , assert_never , cast , get_args , get_origin
2221
2322import cloudpickle
@@ -398,23 +397,18 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
398397 self .__egg_decls_thunk__ ,
399398 Thunk .value (TypedExprDecl (return_tp .type_ref , CallDecl (ClassVariableRef (self .__egg_tp__ .ident , name )))),
400399 )
400+ bound = self .__egg_tp__ .to_just () if self .__egg_tp__ .args else None
401401 if name in cls_decl .class_methods :
402402 return RuntimeFunction (
403- self .__egg_decls_thunk__ ,
404- Thunk .value (ClassMethodRef (self .__egg_tp__ .ident , name )),
405- self .__egg_tp__ .to_just (),
403+ self .__egg_decls_thunk__ , Thunk .value (ClassMethodRef (self .__egg_tp__ .ident , name )), bound
406404 )
407405 # allow referencing properties and methods as class variables as well
408406 if name in cls_decl .properties :
409407 return RuntimeFunction (
410- self .__egg_decls_thunk__ ,
411- Thunk .value (PropertyRef (self .__egg_tp__ .ident , name )),
412- self .__egg_tp__ .to_just (),
408+ self .__egg_decls_thunk__ , Thunk .value (PropertyRef (self .__egg_tp__ .ident , name )), bound
413409 )
414410 if name in cls_decl .methods :
415- return RuntimeFunction (
416- self .__egg_decls_thunk__ , Thunk .value (MethodRef (self .__egg_tp__ .ident , name )), self .__egg_tp__ .to_just ()
417- )
411+ return RuntimeFunction (self .__egg_decls_thunk__ , Thunk .value (MethodRef (self .__egg_tp__ .ident , name )), bound )
418412
419413 msg = f"Class { self .__egg_tp__ .ident } has no method { name } "
420414 raise AttributeError (msg ) from None
@@ -504,7 +498,7 @@ def __hash__(self) -> int:
504498 def __egg_ref__ (self ) -> CallableRef :
505499 return self .__egg_ref_thunk__ ()
506500
507- def __call__ ( # noqa: C901
501+ def __call__ ( # noqa: C901,PLR0912
508502 self , * args : object , _egg_function_types : tuple [TypeOrVarRef , ...] | None = None , ** kwargs : object
509503 ) -> RuntimeExpr | None :
510504 from .conversion import resolve_literal # noqa: PLC0415
@@ -551,20 +545,14 @@ def __call__( # noqa: C901
551545 args = bound .args
552546
553547 tcs = TypeConstraintSolver ()
554- bound_tp = (
555- None
556- if self .__egg_bound__ is None
557- else self .__egg_bound__ .__egg_typed_expr__ .tp
558- if isinstance (self .__egg_bound__ , RuntimeExpr )
559- else self .__egg_bound__
560- )
561- if (
562- bound_tp
563- and bound_tp .args
564- # Don't bind class if we have a first class function arg, b/c we don't support that yet
565- and not function_value
566- ):
567- tcs .bind_class (bound_tp , decls )
548+ if isinstance (self .__egg_bound__ , JustTypeRef ) and self .__egg_bound__ .args :
549+ if function_value :
550+ msg = "Cannot have both bound type params and function value"
551+ raise ValueError (msg )
552+ tcs .bind_class (self .__egg_bound__ , decls )
553+ bound_tp_params = self .__egg_bound__ .args
554+ else :
555+ bound_tp_params = ()
568556 assert (operator .ge if signature .var_arg_type else operator .eq )(len (args ), len (signature .arg_types ))
569557 # Hack to allow being explicit on function types when casting. # noqa: FIX004
570558 for _fn_tp in _egg_function_types or ():
@@ -575,20 +563,29 @@ def __call__( # noqa: C901
575563 tcs .bind_class (_fn_tp_just , decls )
576564 if _fn_tp_just .args :
577565 pass
566+ # Try using any runtime expressions passed in to help infer typevars
567+ for arg , tp in zip (args , signature .all_args , strict = False ):
568+ if not isinstance (arg , RuntimeExpr ):
569+ continue
570+ try :
571+ tcs .infer_typevars (tp , arg .__egg_typed_expr__ .tp )
572+ # If this leads to an incompatibility, just skip it, since it could need to be upcasted
573+ except TypeConstraintError :
574+ continue
575+ # Now at this point we should be able to resolve all the typevars
578576 upcasted_args = [
579- resolve_literal (cast ("TypeOrVarRef" , tp ), arg , Thunk .value (decls ), tcs )
580- for arg , tp in zip_longest (args , signature .arg_types , fillvalue = signature .var_arg_type )
577+ resolve_literal (
578+ tcs .substitute_typevars_try_function (tp , arg , Thunk .value (decls )).to_var (), arg , Thunk .value (decls )
579+ )
580+ for arg , tp in zip (args , signature .all_args , strict = False )
581581 ]
582582 decls .update (* upcasted_args )
583583 arg_exprs = tuple (arg .__egg_typed_expr__ for arg in upcasted_args )
584584 return_tp = tcs .substitute_typevars (signature .semantic_return_type )
585- bound_params = (
586- cast ("JustTypeRef" , bound_tp ).args if isinstance (self .__egg_ref__ , ClassMethodRef | InitRef ) else ()
587- )
588585 # If we were using unstable-app to call a function, add that function back as the first arg.
589586 if function_value :
590587 arg_exprs = (function_value , * arg_exprs )
591- expr_decl = CallDecl (self .__egg_ref__ , arg_exprs , bound_params )
588+ expr_decl = CallDecl (self .__egg_ref__ , arg_exprs , bound_tp_params )
592589 typed_expr_decl = TypedExprDecl (return_tp , expr_decl )
593590 # If there is not return type, we are mutating the first arg
594591 if not signature .return_type :
@@ -901,8 +898,7 @@ def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | Run
901898 case InitRef (name ):
902899 return RuntimeClass (Thunk .value (decls ), TypeRefWithVars (name ))
903900 case FunctionRef () | MethodRef () | ClassMethodRef () | PropertyRef () | UnnamedFunctionRef ():
904- bound = JustTypeRef (ref .ident ) if isinstance (ref , ClassMethodRef ) else None
905- return RuntimeFunction (Thunk .value (decls ), Thunk .value (ref ), bound )
901+ return RuntimeFunction (Thunk .value (decls ), Thunk .value (ref ), None )
906902 case ConstantRef (name ):
907903 tp = decls ._constants [name ].type_ref
908904 case ClassVariableRef (cls_name , var_name ):
0 commit comments