@@ -96,3 +96,36 @@ def substitute_typevars(self, tp: TypeOrVarRef) -> JustTypeRef:
9696 case TypeRefWithVars (name , args ):
9797 return JustTypeRef (name , tuple (self .substitute_typevars (arg ) for arg in args ))
9898 assert_never (tp )
99+
100+ def substitute_typevars_try_function (
101+ self , tp : TypeOrVarRef , value : Callable , decls : Callable [[], Declarations ]
102+ ) -> JustTypeRef :
103+ """
104+ Try to substitute typevars in a type with their inferred types.
105+
106+ If this fails and we have an UnstableFn type and a function value, we can try to infer the typevars by calling
107+ it with the input types, if we can resolve those
108+ """
109+ from .runtime import RuntimeExpr # noqa: PLC0415
110+
111+ try :
112+ return self .substitute_typevars (tp )
113+ except TypeConstraintError :
114+ if isinstance (tp , TypeVarRef ) or tp .ident != Ident .builtin ("UnstableFn" ) or not callable (value ):
115+ raise
116+ dummy_args = [
117+ RuntimeExpr .__from_values__ (decls (), TypedExprDecl (self .substitute_typevars (arg_tp ), DummyDecl ()))
118+ for arg_tp in tp .args [1 :]
119+ ]
120+ try :
121+ result = value (* dummy_args )
122+ except Exception as e :
123+ raise TypeConstraintError (
124+ f"Function { value } raised an exception when called with dummy args to infer return type: { e } "
125+ ) from e
126+ if not isinstance (result , RuntimeExpr ):
127+ raise TypeConstraintError (
128+ f"Function { value } did not return a RuntimeExpr, got { type (result )} , so cannot infer return type"
129+ )
130+ self .infer_typevars (tp .args [0 ], result .__egg_typed_expr__ .tp )
131+ return self .substitute_typevars (tp )
0 commit comments