Skip to content

Commit 4ec36c8

Browse files
committed
task: review + refactors
1 parent c9e6eef commit 4ec36c8

4 files changed

Lines changed: 330 additions & 92 deletions

File tree

dpctl/_sycl_queue.pxd

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ cdef public api class SyclQueue (_SyclQueue) [
103103
cdef DPCTLSyclQueueRef get_queue_ref(self)
104104
cpdef memcpy(self, dest, src, size_t count)
105105
cpdef SyclEvent memcpy_async(self, dest, src, size_t count, list dEvents=*)
106-
cpdef copy(self, dest, src, size_t count)
107-
cpdef SyclEvent copy_async(self, dest, src, size_t count, list dEvents=*)
106+
cpdef copy(self, dest, src, size_t count, str dtype=*)
107+
cpdef SyclEvent copy_async(
108+
self, dest, src, size_t count, list dEvents=*, str dtype=*
109+
)
108110
cpdef prefetch(self, ptr, size_t count=*)
109111
cpdef mem_advise(self, ptr, size_t count, int mem)
110112
cpdef SyclEvent submit_barrier(self, dependent_events=*)

dpctl/_sycl_queue.pyx

Lines changed: 103 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -461,13 +461,46 @@ cdef bint _is_buffer(object o):
461461
return PyObject_CheckBuffer(o)
462462

463463

464-
cdef DPCTLSyclEventRef _memcpy_impl(
464+
# Function pointer typedefs for the C API queue copy functions
465+
ctypedef DPCTLSyclEventRef (*queue_copy_fn)(
466+
const DPCTLSyclQueueRef, void*, const void*, size_t
467+
)
468+
469+
ctypedef DPCTLSyclEventRef (*queue_copy_with_events_fn)(
470+
const DPCTLSyclQueueRef, void*, const void*, size_t,
471+
const DPCTLSyclEventRef*, size_t
472+
)
473+
474+
475+
cdef size_t _get_dtype_size(str dtype) except *:
476+
"""
477+
Parse numpy-style dtype string and return element size in bytes.
478+
Supports: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8
479+
"""
480+
if dtype == "i1" or dtype == "u1":
481+
return 1
482+
elif dtype == "i2" or dtype == "u2":
483+
return 2
484+
elif dtype == "i4" or dtype == "u4" or dtype == "f4":
485+
return 4
486+
elif dtype == "i8" or dtype == "u8" or dtype == "f8":
487+
return 8
488+
else:
489+
raise ValueError(
490+
f"Unrecognized dtype '{dtype}'. "
491+
"Expected one of: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8"
492+
)
493+
494+
495+
cdef DPCTLSyclEventRef _copy_memcpy_impl(
465496
SyclQueue q,
466497
object dst,
467498
object src,
468499
size_t byte_count,
469500
DPCTLSyclEventRef *dep_events,
470-
size_t dep_events_count
501+
size_t dep_events_count,
502+
queue_copy_fn copy_fn,
503+
queue_copy_with_events_fn copy_with_events_fn
471504
) except *:
472505
cdef void *c_dst_ptr = NULL
473506
cdef void *c_src_ptr = NULL
@@ -514,9 +547,9 @@ cdef DPCTLSyclEventRef _memcpy_impl(
514547
)
515548

516549
if dep_events_count == 0 or dep_events is NULL:
517-
ERef = DPCTLQueue_Memcpy(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
550+
ERef = copy_fn(q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
518551
else:
519-
ERef = DPCTLQueue_MemcpyWithEvents(
552+
ERef = copy_with_events_fn(
520553
q._queue_ref,
521554
c_dst_ptr,
522555
c_src_ptr,
@@ -533,80 +566,37 @@ cdef DPCTLSyclEventRef _memcpy_impl(
533566
return ERef
534567

535568

536-
cdef DPCTLSyclEventRef _copy_impl(
569+
cdef DPCTLSyclEventRef _memcpy_impl(
537570
SyclQueue q,
538571
object dst,
539572
object src,
540573
size_t byte_count,
541574
DPCTLSyclEventRef *dep_events,
542575
size_t dep_events_count
543576
) except *:
544-
cdef void *c_dst_ptr = NULL
545-
cdef void *c_src_ptr = NULL
546-
cdef DPCTLSyclEventRef ERef = NULL
547-
cdef Py_buffer src_buf_view
548-
cdef Py_buffer dst_buf_view
549-
cdef bint src_is_buf = False
550-
cdef bint dst_is_buf = False
551-
cdef int ret_code = 0
577+
return _copy_memcpy_impl(
578+
q, dst, src, byte_count, dep_events, dep_events_count,
579+
DPCTLQueue_Memcpy, DPCTLQueue_MemcpyWithEvents
580+
)
552581

553-
if isinstance(src, _Memory):
554-
c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
555-
elif _is_buffer(src):
556-
ret_code = PyObject_GetBuffer(
557-
src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
558-
)
559-
if ret_code != 0: # pragma: no cover
560-
raise RuntimeError("Could not access buffer")
561-
c_src_ptr = src_buf_view.buf
562-
src_is_buf = True
563-
else:
564-
raise TypeError(
565-
"Parameter `src` should have either type "
566-
"`dpctl.memory._Memory` or a type that "
567-
"supports Python buffer protocol"
568-
)
569582

570-
if isinstance(dst, _Memory):
571-
c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
572-
elif _is_buffer(dst):
573-
ret_code = PyObject_GetBuffer(
574-
dst, &dst_buf_view,
575-
PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE
576-
)
577-
if ret_code != 0: # pragma: no cover
578-
if src_is_buf:
579-
PyBuffer_Release(&src_buf_view)
580-
raise RuntimeError("Could not access buffer")
581-
c_dst_ptr = dst_buf_view.buf
582-
dst_is_buf = True
583-
else:
584-
raise TypeError(
585-
"Parameter `dst` should have either type "
586-
"`dpctl.memory._Memory` or a type that "
587-
"supports Python buffer protocol"
588-
)
589-
590-
if dep_events_count == 0 or dep_events is NULL:
591-
ERef = DPCTLQueue_CopyData(
592-
q._queue_ref, c_dst_ptr, c_src_ptr, byte_count
593-
)
594-
else:
595-
ERef = DPCTLQueue_CopyDataWithEvents(
596-
q._queue_ref,
597-
c_dst_ptr,
598-
c_src_ptr,
599-
byte_count,
600-
dep_events,
601-
dep_events_count
602-
)
603-
604-
if src_is_buf:
605-
PyBuffer_Release(&src_buf_view)
606-
if dst_is_buf:
607-
PyBuffer_Release(&dst_buf_view)
583+
cdef DPCTLSyclEventRef _copy_impl(
584+
SyclQueue q,
585+
object dst,
586+
object src,
587+
size_t count,
588+
DPCTLSyclEventRef *dep_events,
589+
size_t dep_events_count,
590+
str dtype="u1"
591+
) except *:
592+
# ``count`` is in elements of ``dtype`` (default "u1" => bytes).
593+
cdef size_t element_size = _get_dtype_size(dtype)
594+
cdef size_t byte_count = count * element_size
608595

609-
return ERef
596+
return _copy_memcpy_impl(
597+
q, dst, src, byte_count, dep_events, dep_events_count,
598+
DPCTLQueue_CopyData, DPCTLQueue_CopyDataWithEvents
599+
)
610600

611601

612602
cdef class _SyclQueue:
@@ -1480,17 +1470,17 @@ cdef class SyclQueue(_SyclQueue):
14801470
)
14811471
if depEvents is NULL:
14821472
raise MemoryError()
1483-
else:
1473+
try:
14841474
for idx, de in enumerate(dEvents):
14851475
if isinstance(de, SyclEvent):
14861476
depEvents[idx] = (<SyclEvent>de).get_event_ref()
14871477
else:
1488-
free(depEvents)
14891478
raise TypeError(
14901479
"A sequence of dpctl.SyclEvent is expected"
14911480
)
1492-
ERef = _memcpy_impl(self, dest, src, count, depEvents, nDE)
1493-
free(depEvents)
1481+
ERef = _memcpy_impl(self, dest, src, count, depEvents, nDE)
1482+
finally:
1483+
free(depEvents)
14941484

14951485
if (ERef is NULL):
14961486
raise RuntimeError(
@@ -1499,18 +1489,36 @@ cdef class SyclQueue(_SyclQueue):
14991489

15001490
return SyclEvent._create(ERef)
15011491

1502-
cpdef copy(self, dest, src, size_t count):
1503-
"""Copy ``count`` bytes from ``src`` to ``dest`` and wait.
1492+
cpdef copy(self, dest, src, size_t count, str dtype="u1"):
1493+
"""Copy ``count`` elements of type ``dtype`` from ``src`` to
1494+
``dest`` and wait.
15041495
1505-
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1506-
byte-sized elements.
1496+
Internally, this dispatches ``sycl::queue::copy``. The number of
1497+
bytes transferred is ``count`` multiplied by the size of ``dtype``.
1498+
The default ``dtype`` of ``"u1"`` (a single byte) makes the default
1499+
a byte-wise copy.
15071500
15081501
This is a synchronizing variant corresponding to
15091502
:meth:`dpctl.SyclQueue.copy_async`.
1503+
1504+
Args:
1505+
dest:
1506+
Destination USM object or Python object supporting
1507+
writable buffer protocol.
1508+
src:
1509+
Source USM object or Python object supporting buffer
1510+
protocol.
1511+
count (int):
1512+
Number of elements to copy.
1513+
dtype (str, optional):
1514+
Data type string of the elements to copy. Determines the
1515+
element size used to convert ``count`` into a byte count.
1516+
Defaults to ``"u1"`` (one byte per element).
1517+
Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.
15101518
"""
15111519
cdef DPCTLSyclEventRef ERef = NULL
15121520

1513-
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1521+
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0, dtype)
15141522
if (ERef is NULL):
15151523
raise RuntimeError(
15161524
"SyclQueue.copy operation encountered an error"
@@ -1520,12 +1528,15 @@ cdef class SyclQueue(_SyclQueue):
15201528
DPCTLEvent_Delete(ERef)
15211529

15221530
cpdef SyclEvent copy_async(
1523-
self, dest, src, size_t count, list dEvents=None
1531+
self, dest, src, size_t count, list dEvents=None, str dtype="u1"
15241532
):
1525-
"""Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.
1533+
"""Copy ``count`` elements of type ``dtype`` from ``src`` to
1534+
``dest`` asynchronously.
15261535
1527-
Internally, this dispatches ``sycl::queue::copy`` instantiated for
1528-
byte-sized elements.
1536+
Internally, this dispatches ``sycl::queue::copy``. The number of
1537+
bytes transferred is ``count`` multiplied by the size of ``dtype``.
1538+
The default ``dtype`` of ``"u1"`` (a single byte) makes the default
1539+
a byte-wise copy.
15291540
15301541
Args:
15311542
dest:
@@ -1535,9 +1546,14 @@ cdef class SyclQueue(_SyclQueue):
15351546
Source USM object or Python object supporting buffer
15361547
protocol.
15371548
count (int):
1538-
Number of bytes to copy.
1549+
Number of elements to copy.
15391550
dEvents (List[dpctl.SyclEvent], optional):
15401551
Events that this copy depends on.
1552+
dtype (str, optional):
1553+
Data type string of the elements to copy. Determines the
1554+
element size used to convert ``count`` into a byte count.
1555+
Defaults to ``"u1"`` (one byte per element).
1556+
Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.
15411557
15421558
Returns:
15431559
dpctl.SyclEvent:
@@ -1548,25 +1564,25 @@ cdef class SyclQueue(_SyclQueue):
15481564
cdef size_t nDE = 0
15491565

15501566
if dEvents is None:
1551-
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0)
1567+
ERef = _copy_impl(<SyclQueue>self, dest, src, count, NULL, 0, dtype)
15521568
else:
15531569
nDE = len(dEvents)
15541570
depEvents = (
15551571
<DPCTLSyclEventRef*>malloc(nDE*sizeof(DPCTLSyclEventRef))
15561572
)
15571573
if depEvents is NULL:
15581574
raise MemoryError()
1559-
else:
1575+
try:
15601576
for idx, de in enumerate(dEvents):
15611577
if isinstance(de, SyclEvent):
15621578
depEvents[idx] = (<SyclEvent>de).get_event_ref()
15631579
else:
1564-
free(depEvents)
15651580
raise TypeError(
15661581
"A sequence of dpctl.SyclEvent is expected"
15671582
)
1568-
ERef = _copy_impl(self, dest, src, count, depEvents, nDE)
1569-
free(depEvents)
1583+
ERef = _copy_impl(self, dest, src, count, depEvents, nDE, dtype)
1584+
finally:
1585+
free(depEvents)
15701586

15711587
if (ERef is NULL):
15721588
raise RuntimeError(

0 commit comments

Comments
 (0)