Skip to content

Commit c514425

Browse files
authored
fix eager logic when mixed array types are passed to binary operation (nv-legate#1114)
* fix eager logic when mixed array types are passed to binary operation * reverting previous change and adding fix only for binary/unary functions * addressing PR comments * another approach to fix mixed array types in eager.py * fixing errors with previous approach
1 parent 1122fca commit c514425

1 file changed

Lines changed: 51 additions & 11 deletions

File tree

cupynumeric/_thunk/eager.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,12 @@ def method(
226226
"""
227227
Helper method to apply unary ufunc operations.
228228
"""
229-
self.check_eager_args(out)
229+
from .._array.array import ndarray
230+
231+
# Check if out contains a deferred thunk and convert self if needed
232+
if isinstance(out, ndarray) and runtime.is_deferred_array(out._thunk):
233+
if self.deferred is None:
234+
self.to_deferred_array(read_only=False)
230235

231236
if self.deferred is not None:
232237
return deferred_ufunc._call_full(
@@ -239,9 +244,7 @@ def method(
239244
)
240245

241246
out_array = (
242-
out._thunk.__numpy_array__()
243-
if (out is not None and hasattr(out, "_thunk"))
244-
else out
247+
out._thunk.__numpy_array__() if isinstance(out, ndarray) else out
245248
)
246249
return np_ufunc(
247250
self.array,
@@ -270,7 +273,48 @@ def method(
270273
"""
271274
Helper method to apply binary ufunc operations.
272275
"""
273-
self.check_eager_args(rhs, out)
276+
from .._array.array import ndarray
277+
278+
# Check if rhs or out contain deferred thunks and convert self if needed
279+
# We check the ._thunk attribute directly without extracting to avoid issues
280+
# Only convert if the thunk is truly deferred (not an eager array with .deferred set)
281+
# Also avoid converting if arrays are 0-dimensional (ndim=0) as deferred doesn't support them
282+
if isinstance(rhs, ndarray):
283+
rhs_thunk = rhs._thunk
284+
# Only convert if both self and rhs have at least 1 dimension and non-zero size
285+
can_convert = (
286+
self.ndim > 0
287+
and self.array.size > 0
288+
and rhs.ndim > 0
289+
and rhs.size > 0
290+
)
291+
if runtime.is_deferred_array(rhs_thunk):
292+
if self.deferred is None and can_convert:
293+
self.to_deferred_array(read_only=False)
294+
elif (
295+
runtime.is_eager_array(rhs_thunk)
296+
and rhs_thunk.deferred is not None
297+
):
298+
if self.deferred is None and can_convert:
299+
self.to_deferred_array(read_only=False)
300+
if isinstance(out, ndarray):
301+
out_thunk = out._thunk
302+
# Only convert if both self and out have at least 1 dimension and non-zero size
303+
can_convert = (
304+
self.ndim > 0
305+
and self.array.size > 0
306+
and out.ndim > 0
307+
and out.size > 0
308+
)
309+
if runtime.is_deferred_array(out_thunk):
310+
if self.deferred is None and can_convert:
311+
self.to_deferred_array(read_only=False)
312+
elif (
313+
runtime.is_eager_array(out_thunk)
314+
and out_thunk.deferred is not None
315+
):
316+
if self.deferred is None and can_convert:
317+
self.to_deferred_array(read_only=False)
274318

275319
if self.deferred is not None:
276320
return deferred_ufunc._call_full(
@@ -284,15 +328,11 @@ def method(
284328
)
285329

286330
rhs_array = (
287-
rhs._thunk.__numpy_array__()
288-
if (rhs is not None and hasattr(rhs, "_thunk"))
289-
else rhs
331+
rhs._thunk.__numpy_array__() if isinstance(rhs, ndarray) else rhs
290332
)
291333

292334
out_array = (
293-
out._thunk.__numpy_array__()
294-
if (out is not None and hasattr(out, "_thunk"))
295-
else out
335+
out._thunk.__numpy_array__() if isinstance(out, ndarray) else out
296336
)
297337
return np_ufunc(
298338
self.array,

0 commit comments

Comments
 (0)