diff --git a/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc index ddda3cc602888..4ee19723d8d4e 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc @@ -2119,6 +2119,13 @@ auto node = graph.add([&](sycl::handler& cgh){ Host-tasks can be updated using <>. +When using `property::graph::enable_native_recording`, host tasks submitted via +`sycl::handler::host_task` are not supported and will throw an exception. Host +tasks submitted via +`sycl::ext::oneapi::experimental::host_task` from the +link:../experimental/sycl_ext_oneapi_enqueue_functions.asciidoc[sycl_ext_oneapi_enqueue_functions] +extension are supported. + === Queue Behavior In Recording Mode @@ -2648,7 +2655,11 @@ if used in application code. . Using reductions in a graph node. . Using sycl streams in a graph node. -. Using host tasks via `sycl::handler::host_task` in a graph node when `property::graph::enable_native_recording` is set. +. Using host tasks via `sycl::handler::host_task` in a graph node when + `property::graph::enable_native_recording` is set. Host tasks submitted via + `sycl::ext::oneapi::experimental::host_task` from the + link:../experimental/sycl_ext_oneapi_enqueue_functions.asciidoc[sycl_ext_oneapi_enqueue_functions] + extension are supported in native recording mode. . Calling `update()` on an executable graph created from a graph with `property::graph::enable_native_recording`. . Using an out-of-order queue with `property::graph::enable_native_recording`. 
diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 3e9a4125987f9..c1af708d90f49 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -3018,6 +3018,8 @@ class HandlerAccess { Handler.internalProfilingTagImpl(); } + static std::function getHostTaskFunc(detail::HostTask &HT); + template static std::enable_if_t< detail::check_fn_signature, void()>::value> diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 4ea11db8c1d64..315750c6924e0 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -11,7 +11,8 @@ #include "graph_impl.hpp" #include "dynamic_impl.hpp" // for dynamic classes #include "node_impl.hpp" // for node_impl -#include // for CG, CGExecKernel, CGHostTask, ArgDesc, NDRDescT +#include // for CG, CGExecKernel, CGHostTask, ArgDesc, NDRDescT +#include // for EnqueueHostTaskData #include // for event_impl #include // for handler_impl #include // for KernelArgMask @@ -646,6 +647,12 @@ bool graph_impl::isQueueRecording(sycl::detail::queue_impl &Queue) { return MRecordingQueues.count(Queue.weak_from_this()) > 0; } +sycl::detail::EnqueueHostTaskData *graph_impl::addNativeHostTaskCallback( + std::unique_ptr Data) { + MNativeHostTaskCallbacks.push_back(std::move(Data)); + return MNativeHostTaskCallbacks.back().get(); +} + void graph_impl::clearQueues(bool NeedsLock) { graph_impl::RecQueuesStorage SwappedQueues; { diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 8151285368fe1..af73a5a29a3ab 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -44,6 +44,7 @@ class queue_impl; class NDRDescT; class ArgDesc; class CG; +struct EnqueueHostTaskData; } // namespace detail namespace ext { @@ -550,6 +551,11 @@ class graph_impl : public std::enable_shared_from_this { /// @return True if the queue is recording to this graph, false 
otherwise. bool isQueueRecording(sycl::detail::queue_impl &Queue); + /// Takes ownership of callback data for a native-recorded host task and + /// returns a non-owning pointer for passing to UR. + detail::EnqueueHostTaskData * + addNativeHostTaskCallback(std::unique_ptr Data); + private: /// Common implementation for beginRecording and beginRecordingUnlockedQueue. /// @param[in] Queue The queue to be recorded from. @@ -628,6 +634,10 @@ class graph_impl : public std::enable_shared_from_this { /// @note Native recording requires immediate command lists. ur_exp_graph_handle_t MNativeGraphHandle = nullptr; + /// Callback data for host tasks recorded in native recording mode. + std::vector> + MNativeHostTaskCallbacks; + /// Mapping from queues to barrier nodes. For each queue the last barrier /// node recorded to the graph from the queue is stored. std::map, node_impl *, diff --git a/sycl/source/detail/host_task.hpp b/sycl/source/detail/host_task.hpp index 83c9c503a5d3e..4d0c2bb48c45e 100644 --- a/sycl/source/detail/host_task.hpp +++ b/sycl/source/detail/host_task.hpp @@ -75,8 +75,27 @@ class HostTask { friend class DispatchHostTask; friend class ExecCGCommand; + friend class sycl::detail::HandlerAccess; }; +inline std::function HandlerAccess::getHostTaskFunc(HostTask &HT) { + return std::move(HT.MHostTask); +} + +struct EnqueueHostTaskData { + explicit EnqueueHostTaskData(std::function HostTask) + : Func(std::move(HostTask)) {} + + std::function Func; +}; + +template inline void NativeHostTask(void *Data) { + auto *HostTaskData = static_cast(Data); + HostTaskData->Func(); + if constexpr (OwnsData) + delete HostTaskData; +} + } // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 447ff0225eb03..32894d97c51e1 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -7,6 +7,7 @@ #include "unified-runtime/ur_api.h" 
#include +#include #include #include @@ -249,20 +250,6 @@ void InteropFreeFunc(ur_queue_handle_t, void *InteropData) { return Data->func(Data->ih); } -struct EnqueueHostTaskData { - explicit EnqueueHostTaskData(std::function HostTask) - : Func(std::move(HostTask)) {} - - std::function Func; -}; - -void NativeHostTask(void *Data) { - // Callback data is heap-allocated at enqueue time and released here once - // the backend invokes the host task callback. - auto HostTaskData = std::unique_ptr( - static_cast(Data)); - HostTaskData->Func(); -} } // namespace class DispatchHostTask { @@ -392,12 +379,13 @@ class DispatchHostTask { UR_DEVICE_INFO_ENQUEUE_HOST_TASK_SUPPORT_EXP, sizeof(NativeHostTaskSupport), &NativeHostTaskSupport, nullptr); if (NativeHostTaskSupport) { - auto NativeHostTaskData = std::make_unique( - std::move(HostTask.MHostTask->MHostTask)); + auto NativeHostTaskData = + std::make_unique( + std::move(HostTask.MHostTask->MHostTask)); ur_event_handle_t HostTaskEvent{}; Queue->getAdapter().call( - Queue->getHandleRef(), NativeHostTask, NativeHostTaskData.get(), - nullptr, 0, nullptr, &HostTaskEvent); + Queue->getHandleRef(), detail::NativeHostTask, + NativeHostTaskData.get(), nullptr, 0, nullptr, &HostTaskEvent); // Ownership is transferred to NativeHostTask callback on success. 
(void)NativeHostTaskData.release(); diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 7b7ba020cc07f..63db7e50a6185 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -365,6 +365,23 @@ fill_copy_args(detail::handler_impl *impl, DestOffset, DestExtent, CopyExtent); } +static bool checkDeviceSupports(device_impl &DeviceImpl, + ur_device_info_t InfoQuery) { + ur_bool_t SupportsOp = false; + DeviceImpl.getAdapter().call( + DeviceImpl.getHandleRef(), InfoQuery, sizeof(ur_bool_t), &SupportsOp, + nullptr); + return SupportsOp; +} + +static std::shared_ptr +getNativeGraphImpl(queue_impl &Queue) { + ur_exp_graph_handle_t UrGraphHandle = nullptr; + Queue.getAdapter().call(Queue.getHandleRef(), + &UrGraphHandle); + return Queue.getContextImpl().getNativeGraph(UrGraphHandle); +} + } // namespace detail handler::handler(detail::handler_impl &HandlerImpl) : impl(&HandlerImpl) {} @@ -766,10 +783,32 @@ detail::EventImplPtr handler::finalize() { // Native graph recording limitation if (type == detail::CGType::CodeplayHostTask && Queue->isNativeRecording()) { - throw sycl::exception( - make_error_code(errc::feature_not_supported), - "SYCL host_task is not supported in native recording mode. Use " - "zeCommandListAppendHostFunction as a workaround."); + auto *HT = static_cast(CommandGroup.get()); + if (!HT->MHostTask->isCreatedFromEnqueueFunction()) { + throw sycl::exception(make_error_code(errc::feature_not_supported), + "Only restricted host tasks may be captured in " + "native recording mode."); + } + + if (!checkDeviceSupports(*detail::getSyclObjImpl(Queue->get_device()), + UR_DEVICE_INFO_ENQUEUE_HOST_TASK_SUPPORT_EXP)) { + throw sycl::exception(make_error_code(errc::feature_not_supported), + "Recording host tasks in native recording mode " + "requires backend support " + "not available on this device."); + } + + // Store callback in the graph so it is available during replays. 
+ auto GraphImpl = detail::getNativeGraphImpl(*Queue); + auto *CallbackData = GraphImpl->addNativeHostTaskCallback( + std::make_unique( + detail::HandlerAccess::getHostTaskFunc(*HT->MHostTask))); + + Queue->getAdapter().call( + Queue->getHandleRef(), detail::NativeHostTask, CallbackData, + nullptr, 0, nullptr, nullptr); + + return detail::event_impl::create_completed_host_event(); } if (!CommandGroup->getRequirements().empty() && Queue->isNativeRecording()) { throw sycl::exception( diff --git a/sycl/test-e2e/Graph/RecordReplay/NativeRecording/enqueue_func_host_task.cpp b/sycl/test-e2e/Graph/RecordReplay/NativeRecording/enqueue_func_host_task.cpp new file mode 100644 index 0000000000000..7e367319942ce --- /dev/null +++ b/sycl/test-e2e/Graph/RecordReplay/NativeRecording/enqueue_func_host_task.cpp @@ -0,0 +1,86 @@ +// REQUIRES: level_zero_v2_adapter +// REQUIRES: level_zero_dev_kit +// REQUIRES: arch-intel_gpu_bmg_g21 +// UNSUPPORTED: windows && gpu-intel-gen12 +// UNSUPPORTED-INTENDED: UR_DEVICE_INFO_ENQUEUE_HOST_TASK_SUPPORT_EXP is not +// supported on win&gen12. + +// RUN: %{build} %level_zero_options -o %t.out +// RUN: %{run} %t.out +// RUN: %if level_zero %{%{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} + +// Tests that syclex::host_task() can be recorded into a native-recording SYCL +// Graph and executes correctly between two SYCL kernels. 
+ +#include "../../graph_common.hpp" +#include "../../ze_common.hpp" + +#include +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +constexpr size_t N = 1024; + +int main() { + queue Queue{property::queue::in_order{}}; + + const sycl::context Context = Queue.get_context(); + const sycl::device Device = Queue.get_device(); + + uint32_t *Data = malloc_shared(N, Queue); + std::fill(Data, Data + N, 0); + + ze_command_list_handle_t ZeCommandList; + bool success = getCommandListFromQueue(Queue, ZeCommandList); + assert(success); + + exp_ext::command_graph Graph{ + Context, Device, {exp_ext::property::graph::enable_native_recording{}}}; + + CommandListStateVerifier verifier(ZeCommandList); + verifier.verify(EXECUTING); + + Graph.begin_recording(Queue); + verifier.verify(RECORDING); + + Queue.submit([&](handler &CGH) { + CGH.parallel_for(range<1>{N}, [=](id<1> idx) { + Data[idx] += static_cast(idx[0]) + 1; + }); + }); + + syclex::host_task(Queue, [=] { + for (size_t i = 0; i < N; i++) { + Data[i] *= 2; + } + }); + + Queue.submit([&](handler &CGH) { + CGH.parallel_for(range<1>{N}, [=](id<1> idx) { Data[idx] += 10; }); + }); + + Graph.end_recording(Queue); + verifier.verify(EXECUTING); + + auto ExecutableGraph = Graph.finalize(); + + Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecutableGraph); }); + Queue.wait(); + + for (size_t i = 0; i < N; i++) { + uint32_t Expected = static_cast((i + 1) * 2 + 10); + assert(check_value(i, Expected, Data[i], "Data")); + } + + Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecutableGraph); }); + Queue.wait(); + + for (size_t i = 0; i < N; i++) { + uint32_t Expected = static_cast((i + 1) * 6 + 30); + assert(check_value(i, Expected, Data[i], "Data")); + } + + free(Data, Queue); + return 0; +} diff --git a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp index 82571666081e4..f6cfffc2b942a 100644 --- 
a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp @@ -1418,9 +1418,7 @@ ur_result_t ur_command_list_manager::isGraphCaptureActive(bool *pResult) { } ur_result_t ur_command_list_manager::getGraph(ur_exp_graph_handle_t *phGraph) { - auto zeGetGraph = - hContext.get()->getPlatform()->ZeGraphExt.zeCommandListGetGraphExp; - if (!checkGraphExtensionSupport(hContext.get()) || !zeGetGraph) { + if (!checkGraphExtensionSupport(hContext.get())) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -1431,6 +1429,12 @@ ur_result_t ur_command_list_manager::getGraph(ur_exp_graph_handle_t *phGraph) { } // Fork-join and implicit capture scenarios + auto zeGetGraph = + hContext.get()->getPlatform()->ZeGraphExt.zeCommandListGetGraphExp; + if (!zeGetGraph) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + ze_graph_handle_t hZeGraph = nullptr; ze_result_t ZeResult = ZE_CALL_NOCHECK(zeGetGraph, (getZeCommandList(), &hZeGraph));