@@ -1126,3 +1126,140 @@ def id[T](base: T) -> T:
11261126 raise NotHandled
11271127
11281128 assert isinstance (id (A (0 )).x , Term )
1129+
1130+
1131+ # Forward references in types only work on module-level definitions.
1132+ @defop
1133+ def forward_ref_op () -> "A" :
1134+ raise NotHandled
1135+
1136+
1137+ class A : ...
1138+
1139+
1140+ def test_defop_forward_ref ():
1141+ term = forward_ref_op ()
1142+ assert term .op == forward_ref_op
1143+ assert typeof (term ) is A
1144+
1145+ @defop
1146+ def local_forward_ref_op () -> "B" :
1147+ raise NotHandled
1148+
1149+ class B : ...
1150+
1151+ with pytest .raises (NameError ):
1152+ local_forward_ref_op ()
1153+
1154+
1155+ # Forward ref in a parameter annotation.
1156+ @defop
1157+ def _forward_ref_param_op (x : "_ForwardRefParam" ) -> int :
1158+ raise NotHandled
1159+
1160+
1161+ class _ForwardRefParam :
1162+ pass
1163+
1164+
1165+ def test_defop_forward_ref_param ():
1166+ sig = inspect .signature (_forward_ref_param_op )
1167+ assert sig .parameters ["x" ].annotation is _ForwardRefParam
1168+ assert sig .return_annotation is int
1169+
1170+
1171+ # Forward ref through Operation.define on a type.
1172+ class _ForwardRefType :
1173+ pass
1174+
1175+
1176+ _forward_ref_type_op = Operation .define (_ForwardRefType )
1177+
1178+
1179+ def test_define_type_forward_ref ():
1180+ term = _forward_ref_type_op ()
1181+ assert term .op == _forward_ref_type_op
1182+ assert typeof (term ) is _ForwardRefType
1183+
1184+
1185+ # Forward ref on an instance method.
1186+ class _ForwardRefMethodHost :
1187+ @defop
1188+ def my_method (self , x : int ) -> "_ForwardRefMethodResult" :
1189+ raise NotHandled
1190+
1191+
1192+ class _ForwardRefMethodResult :
1193+ pass
1194+
1195+
1196+ def test_defop_forward_ref_method ():
1197+ instance = _ForwardRefMethodHost ()
1198+ term = instance .my_method (5 )
1199+ assert isinstance (term , Term )
1200+ sig = inspect .signature (_ForwardRefMethodHost .my_method )
1201+ assert sig .return_annotation is _ForwardRefMethodResult
1202+
1203+
1204+ # Forward ref on a staticmethod.
1205+ class _ForwardRefStaticHost :
1206+ @defop
1207+ @staticmethod
1208+ def my_static (x : int ) -> "_ForwardRefStaticResult" :
1209+ raise NotHandled
1210+
1211+
1212+ class _ForwardRefStaticResult :
1213+ pass
1214+
1215+
1216+ def test_defop_forward_ref_staticmethod ():
1217+ term = _ForwardRefStaticHost .my_static (5 )
1218+ assert isinstance (term , Term )
1219+ sig = inspect .signature (_ForwardRefStaticHost .my_static )
1220+ assert sig .return_annotation is _ForwardRefStaticResult
1221+
1222+
1223+ # Forward ref on a classmethod.
1224+ class _ForwardRefClassmethodHost :
1225+ @defop
1226+ @classmethod
1227+ def my_classmethod (cls , x : int ) -> "_ForwardRefClassmethodResult" :
1228+ raise NotHandled
1229+
1230+
1231+ class _ForwardRefClassmethodResult :
1232+ pass
1233+
1234+
1235+ def test_defop_forward_ref_classmethod ():
1236+ term = _ForwardRefClassmethodHost .my_classmethod (5 )
1237+ assert isinstance (term , Term )
1238+ sig = inspect .signature (_ForwardRefClassmethodHost .my_classmethod )
1239+ assert sig .return_annotation is _ForwardRefClassmethodResult
1240+
1241+
1242+ # Mutual recursion: two classes with forward refs to each other.
1243+ class _Coordinate :
1244+ @defop
1245+ def log (self ) -> "_CoordinateTangent" :
1246+ raise NotHandled
1247+
1248+
1249+ class _CoordinateTangent :
1250+ @defop
1251+ def exp (self ) -> "_Coordinate" :
1252+ raise NotHandled
1253+
1254+
1255+ def test_defop_forward_ref_mutual_recursion ():
1256+ coord = _Coordinate ()
1257+ tangent = _CoordinateTangent ()
1258+
1259+ log_term = coord .log ()
1260+ assert isinstance (log_term , Term )
1261+ assert typeof (log_term ) is _CoordinateTangent
1262+
1263+ exp_term = tangent .exp ()
1264+ assert isinstance (exp_term , Term )
1265+ assert typeof (exp_term ) is _Coordinate
0 commit comments