Skip to content

Commit 6182167

Browse files
authored
Implementing iterator over array instead of numpy fallback (#1671)
1 parent 4f90b05 commit 6182167

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

cupynumeric/_array/array.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import operator
1818
from functools import reduce
1919
from math import prod as builtin_prod
20-
from typing import TYPE_CHECKING, Any, Sequence, cast
20+
from typing import TYPE_CHECKING, Any, Iterator, Sequence, cast
2121

2222
import legate.core.types as ty
2323
import numpy as np
@@ -1124,7 +1124,12 @@ def __iter__(self) -> Any:
11241124
"""a.__iter__(/)"""
11251125
if settings.doctor():
11261126
doctor.diagnose("__iter__", (self,), {})
1127-
return self.__array__().__iter__()
1127+
1128+
def iter_array(arr: ndarray) -> Iterator[Any]:
1129+
for i in range(len(arr)):
1130+
yield arr[i]
1131+
1132+
return iter_array(self)
11281133

11291134
def __isub__(self, rhs: Any) -> ndarray:
11301135
"""a.__isub__(/)

tests/integration/test_index_routines.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,9 @@ def test_axes_none(self) -> None:
602602
@pytest.mark.diff
603603
def test_extra_axes(self):
604604
# NumPy does not have axes arg
605-
axes = num.arange(self.a.ndim + 1, dtype=int)
605+
# Use NumPy vector axes because indexing it returns a hashable scalar
606+
# which is necessary for _diag_helper
607+
axes = np.arange(self.a.ndim + 1, dtype=int)
606608
with pytest.raises(ValueError):
607609
self.a._diag_helper(self.a, axes=axes)
608610

0 commit comments

Comments
 (0)