11# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22
3- from typing import Optional , Tuple
3+ from typing import Callable , Optional , Tuple
44from .types import FormatInfo , RoundMode , Domain
55
66import numpy .typing as npt
77import array_api_compat
88
99
10- def _ifloor (x , int_type ) :
10+ def _ifloor (x : npt . NDArray , int_type : npt . DTypeLike ) -> npt . NDArray :
1111 xp = array_api_compat .array_namespace (x )
1212 floored = xp .floor (x )
1313 return xp .astype (floored , int_type )
@@ -21,7 +21,7 @@ def _iseven(v: npt.NDArray) -> npt.NDArray:
2121 return v & 0x1 == 0
2222
2323
24- def _rnitp (x , pred , int_type ) :
24+ def _rnitp (x : npt . NDArray , pred : Callable [[ npt . NDArray ], npt . NDArray ], int_type : npt . DTypeLike ) -> npt . NDArray :
2525 """Round to nearest integer, ties to predicate"""
2626 xp = array_api_compat .array_namespace (x )
2727 floored = xp .floor (x )
@@ -31,12 +31,12 @@ def _rnitp(x, pred, int_type):
3131 return ifloored + xp .astype (should_round_away , int_type )
3232
3333
34- def _rnite (x , int_type ) :
34+ def _rnite (x : npt . NDArray , int_type : npt . DTypeLike ) -> npt . NDArray :
3535 """Round to nearest integer, ties to even"""
3636 return _rnitp (x , _iseven , int_type )
3737
3838
39- def _rnito (x , int_type ) :
39+ def _rnito (x : npt . NDArray , int_type : npt . DTypeLike ) -> npt . NDArray :
4040 """Round to nearest integer, ties to odd"""
4141 return _rnitp (x , _isodd , int_type )
4242
0 commit comments