2727
2828from .. import _ufunc
2929from .._utils .array import max_identity , min_identity , to_core_type
30- from .._utils .coverage import clone_class , is_implemented
3130from .._utils .linalg import dot_modes
3231from ..config import (
3332 FFTDirection ,
4948 broadcast_where ,
5049 check_writeable ,
5150 convert_to_cupynumeric_ndarray ,
52- maybe_convert_to_np_ndarray ,
5351 sanitize_shape ,
5452)
5553
8987
9088from math import prod
9189
92- NDARRAY_INTERNAL = {
93- "__array_finalize__" ,
94- "__array_function__" ,
95- "__array_interface__" ,
96- "__array_prepare__" ,
97- "__array_priority__" ,
98- "__array_struct__" ,
99- "__array_ufunc__" ,
100- "__array_wrap__" ,
101- # Avoid auto-wrapping Array API specifics:
102- "__array_namespace__" ,
103- "device" ,
104- "to_device" ,
105- }
106-
10790
10891def _warn_and_convert (array : ndarray , dtype : np .dtype [Any ]) -> ndarray :
10992 if array .dtype != dtype :
@@ -115,7 +98,6 @@ def _warn_and_convert(array: ndarray, dtype: np.dtype[Any]) -> ndarray:
11598 return array
11699
117100
118- @clone_class (np .ndarray , NDARRAY_INTERNAL , maybe_convert_to_np_ndarray )
119101class ndarray :
120102 _thunk : NumPyThunk
121103 _legate_data : dict [str , Any ] | None
@@ -312,8 +294,6 @@ def __array_function__(
312294 ) -> Any :
313295 import cupynumeric as cn
314296
315- what = func .__name__
316-
317297 for t in types :
318298 # Be strict about which types we support. Accept superclasses
319299 # (for basic subclassing support) and NumPy.
@@ -323,29 +303,13 @@ def __array_function__(
323303 # We are wrapping all NumPy modules, so we can expect to find the implemented
324304 # NumPy API call in cuPyNumeric.
325305 module = reduce (getattr , func .__module__ .split ("." )[1 :], cn )
326- cn_func = getattr (module , func .__name__ )
327-
328- # We can't immediately forward to the corresponding cuPyNumeric
329- # entrypoint. Say that we reached this point because the user code
330- # invoked `np.foo(x, bar=True)` where `x` is a `cupynumeric.ndarray`.
331- # If our implementation of `foo` is not complete, and cannot handle
332- # `bar=True`, then forwarding this call to `cn.foo` would fail. This
333- # goes against the semantics of `__array_function__`, which shouldn't
334- # fail if the custom implementation cannot handle the provided
335- # arguments. Conversely, if the user calls `cn.foo(x, bar=True)`
336- # directly, that means they requested the cuPyNumeric implementation
337- # specifically, and the `NotImplementedError` should not be hidden.
338- if is_implemented (cn_func ):
339- try :
340- return cn_func (* args , ** kwargs )
341- except NotImplementedError :
342- # Inform the user that we support the requested API in general,
343- # but not this specific combination of arguments.
344- what = f"the requested combination of arguments to { what } "
345-
346- # We cannot handle this call - raise an error instead of falling back
306+ cn_func = getattr (module , func .__name__ , None )
307+
308+ if cn_func is not None :
309+ return cn_func (* args , ** kwargs )
310+
347311 raise NotImplementedError (
348- f"cuPyNumeric has not implemented { what } . "
312+ f"cuPyNumeric has not implemented { func . __name__ } . "
349313 f"This function is not available."
350314 )
351315
0 commit comments