Skip to content

Commit aa010a2

Browse files
committed
mypy
1 parent 88b7470 commit aa010a2

3 files changed

Lines changed: 9 additions & 7 deletions

File tree

src/gfloat/round.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
2+
from typing import Callable
23

34
import math
45

@@ -16,7 +17,7 @@ def _iseven(v: int) -> bool:
1617
return v & 0x1 == 0
1718

1819

19-
def _rnitp(x: float, pred) -> int:
20+
def _rnitp(x: float, pred: Callable) -> int:
2021
"""Round to nearest integer, ties to predicate"""
2122
floored = math.floor(x)
2223
should_round_away = (x > floored + 0.5) | ((x == floored + 0.5) & ~pred(floored))

src/gfloat/round_ndarray.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
22

3-
from typing import Optional, Tuple
3+
from typing import Callable, Optional, Tuple
44
from .types import FormatInfo, RoundMode, Domain
55

66
import numpy.typing as npt
77
import array_api_compat
88

99

10-
def _ifloor(x, int_type):
10+
def _ifloor(x: npt.NDArray, int_type: npt.DTypeLike) -> npt.NDArray:
1111
xp = array_api_compat.array_namespace(x)
1212
floored = xp.floor(x)
1313
return xp.astype(floored, int_type)
@@ -21,7 +21,7 @@ def _iseven(v: npt.NDArray) -> npt.NDArray:
2121
return v & 0x1 == 0
2222

2323

24-
def _rnitp(x, pred, int_type):
24+
def _rnitp(x: npt.NDArray, pred: Callable[[npt.NDArray], npt.NDArray], int_type: npt.DTypeLike) -> npt.NDArray:
2525
"""Round to nearest integer, ties to predicate"""
2626
xp = array_api_compat.array_namespace(x)
2727
floored = xp.floor(x)
@@ -31,12 +31,12 @@ def _rnitp(x, pred, int_type):
3131
return ifloored + xp.astype(should_round_away, int_type)
3232

3333

34-
def _rnite(x, int_type):
34+
def _rnite(x: npt.NDArray, int_type: npt.DTypeLike) -> npt.NDArray:
3535
"""Round to nearest integer, ties to even"""
3636
return _rnitp(x, _iseven, int_type)
3737

3838

39-
def _rnito(x, int_type):
39+
def _rnito(x: npt.NDArray, int_type: npt.DTypeLike) -> npt.NDArray:
4040
"""Round to nearest integer, ties to odd"""
4141
return _rnitp(x, _isodd, int_type)
4242

test/test_round.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import array_api_compat
66
import ml_dtypes
77
import numpy as np
8+
import numpy.typing as npt
89
import pytest
910

1011
from gfloat import RoundMode, decode_float, decode_ndarray, round_float, round_ndarray
@@ -14,7 +15,7 @@
1415

1516

1617
@pytest.mark.parametrize("int_type", [np.int64, np.int16])
17-
def test_rnito_rnite(int_type):
18+
def test_rnito_rnite(int_type : npt.DTypeLike) -> None:
1819

1920
xp = array_api_compat.array_namespace(np.array(0.0))
2021
np.testing.assert_equal(_rnito(xp.array(3.5), int_type), 3.0)

0 commit comments

Comments
 (0)