@@ -290,8 +290,8 @@ def check_index(length: IntLike, idx: IntLike) -> Int:
290290 """
291291 Returns the index if 0 <= idx < length, otherwise returns Int.NEVER
292292 """
293- length = cast (Int , length )
294- idx = cast (Int , idx )
293+ length = cast (" Int" , length )
294+ idx = cast (" Int" , idx )
295295 return Int .if_ (((idx >= 0 ) & (idx < length )), idx , Int .NEVER )
296296
297297
@@ -336,7 +336,7 @@ def abs(self) -> Float: ...
336336
337337 @method (cost = 2 )
338338 @classmethod
339- def rational (cls , r : Rational ) -> Float : ...
339+ def rational (cls , r : BigRat ) -> Float : ...
340340
341341 @classmethod
342342 def from_int (cls , i : IntLike ) -> Float : ...
@@ -362,15 +362,15 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
362362
363363
364364@array_api_ruleset .register
365- def _float (fl : Float , f : f64 , f2 : f64 , i : i64 , r : Rational , r1 : Rational ):
365+ def _float (fl : Float , f : f64 , f2 : f64 , i : i64 , r : BigRat , r1 : BigRat ):
366366 return [
367367 rule (eq (fl ).to (Float (f ))).then (set_ (fl .to_f64 ).to (f )),
368368 rewrite (Float (f ).abs ()).to (Float (f ), f >= 0.0 ),
369369 rewrite (Float (f ).abs ()).to (Float (- f ), f < 0.0 ),
370370 # Convert from float to rationl, if its a whole number i.e. can be converted to int
371- rewrite (Float (f )).to (Float .rational (Rational (f .to_i64 (), 1 )), eq (f64 .from_i64 (f .to_i64 ())).to (f )),
371+ rewrite (Float (f )).to (Float .rational (BigRat (f .to_i64 (), 1 )), eq (f64 .from_i64 (f .to_i64 ())).to (f )),
372372 # always convert from int to rational
373- rewrite (Float .from_int (Int (i ))).to (Float .rational (Rational (i , 1 ))),
373+ rewrite (Float .from_int (Int (i ))).to (Float .rational (BigRat (i , 1 ))),
374374 rewrite (Float (f ) + Float (f2 )).to (Float (f + f2 )),
375375 rewrite (Float (f ) - Float (f2 )).to (Float (f - f2 )),
376376 rewrite (Float (f ) * Float (f2 )).to (Float (f * f2 )),
@@ -417,7 +417,7 @@ def range(cls, stop: IntLike) -> TupleInt:
417417 def from_vec (cls , vec : VecLike [Int , IntLike ]) -> TupleInt : ...
418418
419419 def __add__ (self , other : TupleIntLike ) -> TupleInt :
420- other = cast (TupleInt , other )
420+ other = cast (" TupleInt" , other )
421421 return TupleInt (
422422 self .length () + other .length (), lambda i : Int .if_ (i < self .length (), self [i ], other [i - self .length ()])
423423 )
@@ -475,14 +475,14 @@ def select(self, indices: TupleIntLike) -> TupleInt:
475475 """
476476 Return a new tuple with the elements at the given indices
477477 """
478- indices = cast (TupleInt , indices )
478+ indices = cast (" TupleInt" , indices )
479479 return indices .map (lambda i : self [i ])
480480
481481 def deselect (self , indices : TupleIntLike ) -> TupleInt :
482482 """
483483 Return a new tuple with the elements not at the given indices
484484 """
485- indices = cast (TupleInt , indices )
485+ indices = cast (" TupleInt" , indices )
486486 return TupleInt .range (self .length ()).filter (lambda i : ~ indices .contains (i )).map (lambda i : self [i ])
487487
488488
@@ -554,7 +554,7 @@ def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None:
554554 @method (subsume = True )
555555 @classmethod
556556 def single (cls , i : TupleIntLike ) -> TupleTupleInt :
557- i = cast (TupleInt , i )
557+ i = cast (" TupleInt" , i )
558558 return TupleTupleInt (1 , lambda _ : i )
559559
560560 @method (subsume = True )
@@ -564,7 +564,7 @@ def from_vec(cls, vec: Vec[TupleInt]) -> TupleTupleInt: ...
564564 def append (self , i : TupleIntLike ) -> TupleTupleInt : ...
565565
566566 def __add__ (self , other : TupleTupleIntLike ) -> TupleTupleInt :
567- other = cast (TupleTupleInt , other )
567+ other = cast (" TupleTupleInt" , other )
568568 return TupleTupleInt (
569569 self .length () + other .length (),
570570 lambda i : TupleInt .if_ (i < self .length (), self [i ], other [i - self .length ()]),
@@ -840,7 +840,7 @@ def _value(i: Int, f: Float, b: Boolean, v: Value, v1: Value, i1: Int, f1: Float
840840
841841 yield rewrite (Value .float (f ).sqrt ()).to (Value .float (f ** (0.5 )))
842842
843- yield rewrite (Value .float (Float .rational (Rational (0 , 1 ))) + v ).to (v )
843+ yield rewrite (Value .float (Float .rational (BigRat (0 , 1 ))) + v ).to (v )
844844
845845 yield rewrite (Value .if_ (TRUE , v , v1 )).to (v )
846846 yield rewrite (Value .if_ (FALSE , v , v1 )).to (v1 )
@@ -862,7 +862,7 @@ def append(self, i: ValueLike) -> TupleValue: ...
862862 def from_vec (cls , vec : Vec [Value ]) -> TupleValue : ...
863863
864864 def __add__ (self , other : TupleValueLike ) -> TupleValue :
865- other = cast (TupleValue , other )
865+ other = cast (" TupleValue" , other )
866866 return TupleValue (
867867 self .length () + other .length (),
868868 lambda i : Value .if_ (i < self .length (), self [i ], other [i - self .length ()]),
@@ -875,13 +875,13 @@ def __getitem__(self, i: Int) -> Value: ...
875875 def foldl_boolean (self , f : Callable [[Boolean , Value ], Boolean ], init : BooleanLike ) -> Boolean : ...
876876
877877 def contains (self , value : ValueLike ) -> Boolean :
878- value = cast (Value , value )
878+ value = cast (" Value" , value )
879879 return self .foldl_boolean (lambda acc , j : acc | (value == j ), FALSE )
880880
881881 @method (subsume = True )
882882 @classmethod
883883 def from_tuple_int (cls , ti : TupleIntLike ) -> TupleValue :
884- ti = cast (TupleInt , ti )
884+ ti = cast (" TupleInt" , ti )
885885 return TupleValue (ti .length (), lambda i : Value .int (ti [i ]))
886886
887887
@@ -1259,7 +1259,7 @@ def append(self, i: NDArrayLike) -> TupleNDArray: ...
12591259 def from_vec (cls , vec : Vec [NDArray ]) -> TupleNDArray : ...
12601260
12611261 def __add__ (self , other : TupleNDArrayLike ) -> TupleNDArray :
1262- other = cast (TupleNDArray , other )
1262+ other = cast (" TupleNDArray" , other )
12631263 return TupleNDArray (
12641264 self .length () + other .length (),
12651265 lambda i : NDArray .if_ (i < self .length (), self [i ], other [i - self .length ()]),
@@ -1632,7 +1632,7 @@ def ndindex(shape: TupleIntLike) -> TupleTupleInt:
16321632 """
16331633 https://numpy.org/doc/stable/reference/generated/numpy.ndindex.html
16341634 """
1635- shape = cast (TupleInt , shape )
1635+ shape = cast (" TupleInt" , shape )
16361636 return shape .map_tuple_int (TupleInt .range ).product ()
16371637
16381638
0 commit comments