@@ -461,13 +461,46 @@ cdef bint _is_buffer(object o):
461461 return PyObject_CheckBuffer(o)
462462
463463
464- cdef DPCTLSyclEventRef _memcpy_impl(
464+ # Function pointer typedefs for the C API queue copy functions
465+ ctypedef DPCTLSyclEventRef (* queue_copy_fn)(
466+ const DPCTLSyclQueueRef, void * , const void * , size_t
467+ )
468+
469+ ctypedef DPCTLSyclEventRef (* queue_copy_with_events_fn)(
470+ const DPCTLSyclQueueRef, void * , const void * , size_t,
471+ const DPCTLSyclEventRef* , size_t
472+ )
473+
474+
475+ cdef size_t _get_dtype_size(str dtype) except * :
476+ """
477+ Parse numpy-style dtype string and return element size in bytes.
478+ Supports: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8
479+ """
480+ if dtype == " i1" or dtype == " u1" :
481+ return 1
482+ elif dtype == " i2" or dtype == " u2" :
483+ return 2
484+ elif dtype == " i4" or dtype == " u4" or dtype == " f4" :
485+ return 4
486+ elif dtype == " i8" or dtype == " u8" or dtype == " f8" :
487+ return 8
488+ else :
489+ raise ValueError (
490+ f" Unrecognized dtype '{dtype}'. "
491+ " Expected one of: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8"
492+ )
493+
494+
495+ cdef DPCTLSyclEventRef _copy_memcpy_impl(
465496 SyclQueue q,
466497 object dst,
467498 object src,
468499 size_t byte_count,
469500 DPCTLSyclEventRef * dep_events,
470- size_t dep_events_count
501+ size_t dep_events_count,
502+ queue_copy_fn copy_fn,
503+ queue_copy_with_events_fn copy_with_events_fn
471504) except * :
472505 cdef void * c_dst_ptr = NULL
473506 cdef void * c_src_ptr = NULL
@@ -514,9 +547,9 @@ cdef DPCTLSyclEventRef _memcpy_impl(
514547 )
515548
516549 if dep_events_count == 0 or dep_events is NULL :
517- ERef = DPCTLQueue_Memcpy (q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
550+ ERef = copy_fn (q._queue_ref, c_dst_ptr, c_src_ptr, byte_count)
518551 else :
519- ERef = DPCTLQueue_MemcpyWithEvents (
552+ ERef = copy_with_events_fn (
520553 q._queue_ref,
521554 c_dst_ptr,
522555 c_src_ptr,
@@ -533,80 +566,37 @@ cdef DPCTLSyclEventRef _memcpy_impl(
533566 return ERef
534567
535568
536- cdef DPCTLSyclEventRef _copy_impl (
569+ cdef DPCTLSyclEventRef _memcpy_impl (
537570 SyclQueue q,
538571 object dst,
539572 object src,
540573 size_t byte_count,
541574 DPCTLSyclEventRef * dep_events,
542575 size_t dep_events_count
543576) except * :
544- cdef void * c_dst_ptr = NULL
545- cdef void * c_src_ptr = NULL
546- cdef DPCTLSyclEventRef ERef = NULL
547- cdef Py_buffer src_buf_view
548- cdef Py_buffer dst_buf_view
549- cdef bint src_is_buf = False
550- cdef bint dst_is_buf = False
551- cdef int ret_code = 0
577+ return _copy_memcpy_impl(
578+ q, dst, src, byte_count, dep_events, dep_events_count,
579+ DPCTLQueue_Memcpy, DPCTLQueue_MemcpyWithEvents
580+ )
552581
553- if isinstance (src, _Memory):
554- c_src_ptr = < void * > (< _Memory> src).get_data_ptr()
555- elif _is_buffer(src):
556- ret_code = PyObject_GetBuffer(
557- src, & src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
558- )
559- if ret_code != 0 : # pragma: no cover
560- raise RuntimeError (" Could not access buffer" )
561- c_src_ptr = src_buf_view.buf
562- src_is_buf = True
563- else :
564- raise TypeError (
565- " Parameter `src` should have either type "
566- " `dpctl.memory._Memory` or a type that "
567- " supports Python buffer protocol"
568- )
569582
570- if isinstance (dst, _Memory):
571- c_dst_ptr = < void * > (< _Memory> dst).get_data_ptr()
572- elif _is_buffer(dst):
573- ret_code = PyObject_GetBuffer(
574- dst, & dst_buf_view,
575- PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE
576- )
577- if ret_code != 0 : # pragma: no cover
578- if src_is_buf:
579- PyBuffer_Release(& src_buf_view)
580- raise RuntimeError (" Could not access buffer" )
581- c_dst_ptr = dst_buf_view.buf
582- dst_is_buf = True
583- else :
584- raise TypeError (
585- " Parameter `dst` should have either type "
586- " `dpctl.memory._Memory` or a type that "
587- " supports Python buffer protocol"
588- )
589-
590- if dep_events_count == 0 or dep_events is NULL :
591- ERef = DPCTLQueue_CopyData(
592- q._queue_ref, c_dst_ptr, c_src_ptr, byte_count
593- )
594- else :
595- ERef = DPCTLQueue_CopyDataWithEvents(
596- q._queue_ref,
597- c_dst_ptr,
598- c_src_ptr,
599- byte_count,
600- dep_events,
601- dep_events_count
602- )
603-
604- if src_is_buf:
605- PyBuffer_Release(& src_buf_view)
606- if dst_is_buf:
607- PyBuffer_Release(& dst_buf_view)
583+ cdef DPCTLSyclEventRef _copy_impl(
584+ SyclQueue q,
585+ object dst,
586+ object src,
587+ size_t count,
588+ DPCTLSyclEventRef * dep_events,
589+ size_t dep_events_count,
590+ str dtype = " u1"
591+ ) except * :
592+ # ``count`` is in elements of ``dtype`` (default "u1" => bytes).
593+ cdef size_t element_size = _get_dtype_size(dtype)
594+ cdef size_t byte_count = count * element_size
608595
609- return ERef
596+ return _copy_memcpy_impl(
597+ q, dst, src, byte_count, dep_events, dep_events_count,
598+ DPCTLQueue_CopyData, DPCTLQueue_CopyDataWithEvents
599+ )
610600
611601
612602cdef class _SyclQueue:
@@ -1480,17 +1470,17 @@ cdef class SyclQueue(_SyclQueue):
14801470 )
14811471 if depEvents is NULL :
14821472 raise MemoryError ()
1483- else :
1473+ try :
14841474 for idx, de in enumerate (dEvents):
14851475 if isinstance (de, SyclEvent):
14861476 depEvents[idx] = (< SyclEvent> de).get_event_ref()
14871477 else :
1488- free(depEvents)
14891478 raise TypeError (
14901479 " A sequence of dpctl.SyclEvent is expected"
14911480 )
1492- ERef = _memcpy_impl(self , dest, src, count, depEvents, nDE)
1493- free(depEvents)
1481+ ERef = _memcpy_impl(self , dest, src, count, depEvents, nDE)
1482+ finally :
1483+ free(depEvents)
14941484
14951485 if (ERef is NULL ):
14961486 raise RuntimeError (
@@ -1499,18 +1489,36 @@ cdef class SyclQueue(_SyclQueue):
14991489
15001490 return SyclEvent._create(ERef)
15011491
1502- cpdef copy(self , dest, src, size_t count):
1503- """ Copy ``count`` bytes from ``src`` to ``dest`` and wait.
1492+ cpdef copy(self , dest, src, size_t count, str dtype = " u1" ):
1493+ """ Copy ``count`` elements of type ``dtype`` from ``src`` to
1494+ ``dest`` and wait.
15041495
1505- Internally, this dispatches ``sycl::queue::copy`` instantiated for
1506- byte-sized elements.
1496+ Internally, this dispatches ``sycl::queue::copy``. The number of
1497+ bytes transferred is ``count`` multiplied by the size of ``dtype``.
1498+ The default ``dtype`` of ``"u1"`` (a single byte) makes the default
1499+ a byte-wise copy.
15071500
15081501 This is a synchronizing variant corresponding to
15091502 :meth:`dpctl.SyclQueue.copy_async`.
1503+
1504+ Args:
1505+ dest:
1506+ Destination USM object or Python object supporting
1507+ writable buffer protocol.
1508+ src:
1509+ Source USM object or Python object supporting buffer
1510+ protocol.
1511+ count (int):
1512+ Number of elements to copy.
1513+ dtype (str, optional):
1514+ Data type string of the elements to copy. Determines the
1515+ element size used to convert ``count`` into a byte count.
1516+ Defaults to ``"u1"`` (one byte per element).
1517+ Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.
15101518 """
15111519 cdef DPCTLSyclEventRef ERef = NULL
15121520
1513- ERef = _copy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
1521+ ERef = _copy_impl(< SyclQueue> self , dest, src, count, NULL , 0 , dtype )
15141522 if (ERef is NULL ):
15151523 raise RuntimeError (
15161524 " SyclQueue.copy operation encountered an error"
@@ -1520,12 +1528,15 @@ cdef class SyclQueue(_SyclQueue):
15201528 DPCTLEvent_Delete(ERef)
15211529
15221530 cpdef SyclEvent copy_async(
1523- self , dest, src, size_t count, list dEvents = None
1531+ self , dest, src, size_t count, list dEvents = None , str dtype = " u1 "
15241532 ):
1525- """ Copy ``count`` bytes from ``src`` to ``dest`` asynchronously.
1533+ """ Copy ``count`` elements of type ``dtype`` from ``src`` to
1534+ ``dest`` asynchronously.
15261535
1527- Internally, this dispatches ``sycl::queue::copy`` instantiated for
1528- byte-sized elements.
1536+ Internally, this dispatches ``sycl::queue::copy``. The number of
1537+ bytes transferred is ``count`` multiplied by the size of ``dtype``.
1538+ The default ``dtype`` of ``"u1"`` (a single byte) makes the default
1539+ a byte-wise copy.
15291540
15301541 Args:
15311542 dest:
@@ -1535,9 +1546,14 @@ cdef class SyclQueue(_SyclQueue):
15351546 Source USM object or Python object supporting buffer
15361547 protocol.
15371548 count (int):
1538- Number of bytes to copy.
1549+ Number of elements to copy.
15391550 dEvents (List[dpctl.SyclEvent], optional):
15401551 Events that this copy depends on.
1552+ dtype (str, optional):
1553+ Data type string of the elements to copy. Determines the
1554+ element size used to convert ``count`` into a byte count.
1555+ Defaults to ``"u1"`` (one byte per element).
1556+ Supported types: i1, u1, i2, u2, i4, u4, i8, u8, f4, f8.
15411557
15421558 Returns:
15431559 dpctl.SyclEvent:
@@ -1548,25 +1564,25 @@ cdef class SyclQueue(_SyclQueue):
15481564 cdef size_t nDE = 0
15491565
15501566 if dEvents is None :
1551- ERef = _copy_impl(< SyclQueue> self , dest, src, count, NULL , 0 )
1567+ ERef = _copy_impl(< SyclQueue> self , dest, src, count, NULL , 0 , dtype )
15521568 else :
15531569 nDE = len (dEvents)
15541570 depEvents = (
15551571 < DPCTLSyclEventRef* > malloc(nDE* sizeof(DPCTLSyclEventRef))
15561572 )
15571573 if depEvents is NULL :
15581574 raise MemoryError ()
1559- else :
1575+ try :
15601576 for idx, de in enumerate (dEvents):
15611577 if isinstance (de, SyclEvent):
15621578 depEvents[idx] = (< SyclEvent> de).get_event_ref()
15631579 else :
1564- free(depEvents)
15651580 raise TypeError (
15661581 " A sequence of dpctl.SyclEvent is expected"
15671582 )
1568- ERef = _copy_impl(self , dest, src, count, depEvents, nDE)
1569- free(depEvents)
1583+ ERef = _copy_impl(self , dest, src, count, depEvents, nDE, dtype)
1584+ finally :
1585+ free(depEvents)
15701586
15711587 if (ERef is NULL ):
15721588 raise RuntimeError (
0 commit comments