Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2119,6 +2119,13 @@ auto node = graph.add([&](sycl::handler& cgh){

Host-tasks can be updated using <<executable-graph-update, Executable Graph Update>>.

When a graph is created with `property::graph::enable_native_recording`, host
tasks submitted via `sycl::handler::host_task` are not supported: submitting
one throws an exception with the `errc::feature_not_supported` error code.
Host tasks submitted via `sycl::ext::oneapi::experimental::host_task` from the
link:../experimental/sycl_ext_oneapi_enqueue_functions.asciidoc[sycl_ext_oneapi_enqueue_functions]
extension are supported, provided the device's backend supports enqueuing host
tasks.


=== Queue Behavior In Recording Mode

Expand Down Expand Up @@ -2648,7 +2655,11 @@ if used in application code.

. Using reductions in a graph node.
. Using sycl streams in a graph node.
. Using host tasks via `sycl::handler::host_task` in a graph node when `property::graph::enable_native_recording` is set.
. Using host tasks via `sycl::handler::host_task` in a graph node when
`property::graph::enable_native_recording` is set. Host tasks submitted via
`sycl::ext::oneapi::experimental::host_task` from
link:../experimental/sycl_ext_oneapi_enqueue_functions.asciidoc[sycl_ext_oneapi_enqueue_functions]
are supported in native recording mode.
. Calling `update()` on an executable graph created from a graph with
`property::graph::enable_native_recording`.
. Using an out-of-order queue with `property::graph::enable_native_recording`.
Expand Down
2 changes: 2 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3018,6 +3018,8 @@ class HandlerAccess {
Handler.internalProfilingTagImpl();
}

/// Moves the callable held by \p HT out of the host task and returns it.
/// After the call, the callable stored in \p HT is left in a moved-from state.
/// The definition lives next to detail::HostTask, which is only
/// forward-declared at this point.
static std::function<void()> getHostTaskFunc(detail::HostTask &HT);

template <typename FuncT>
static std::enable_if_t<
detail::check_fn_signature<std::remove_reference_t<FuncT>, void()>::value>
Expand Down
9 changes: 8 additions & 1 deletion sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
#include "graph_impl.hpp"
#include "dynamic_impl.hpp" // for dynamic classes
#include "node_impl.hpp" // for node_impl
#include <detail/cg.hpp> // for CG, CGExecKernel, CGHostTask, ArgDesc, NDRDescT
#include <detail/cg.hpp> // for CG, CGExecKernel, CGHostTask, ArgDesc, NDRDescT
#include <detail/host_task.hpp> // for EnqueueHostTaskData
#include <detail/event_impl.hpp> // for event_impl
#include <detail/handler_impl.hpp> // for handler_impl
#include <detail/kernel_arg_mask.hpp> // for KernelArgMask
Expand Down Expand Up @@ -646,6 +647,12 @@ bool graph_impl::isQueueRecording(sycl::detail::queue_impl &Queue) {
return MRecordingQueues.count(Queue.weak_from_this()) > 0;
}

// Stores the host task callback payload in this graph and hands back a
// non-owning pointer to it. The payload lives on the heap behind a
// unique_ptr, so its address does not change when the vector grows.
sycl::detail::EnqueueHostTaskData *graph_impl::addNativeHostTaskCallback(
    std::unique_ptr<sycl::detail::EnqueueHostTaskData> Data) {
  // Remember the address before ownership moves into the container.
  sycl::detail::EnqueueHostTaskData *const Borrowed = Data.get();
  MNativeHostTaskCallbacks.emplace_back(std::move(Data));
  return Borrowed;
}

void graph_impl::clearQueues(bool NeedsLock) {
graph_impl::RecQueuesStorage SwappedQueues;
{
Expand Down
10 changes: 10 additions & 0 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class queue_impl;
class NDRDescT;
class ArgDesc;
class CG;
struct EnqueueHostTaskData;
} // namespace detail

namespace ext {
Expand Down Expand Up @@ -550,6 +551,11 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @return True if the queue is recording to this graph, false otherwise.
bool isQueueRecording(sycl::detail::queue_impl &Queue);

/// Takes ownership of the callback data for a native-recorded host task and
/// returns a non-owning pointer suitable for passing to UR.
/// @param[in] Data Callback data to be stored by this graph.
/// @return Non-owning pointer to the stored callback data.
detail::EnqueueHostTaskData *
addNativeHostTaskCallback(std::unique_ptr<detail::EnqueueHostTaskData> Data);

private:
/// Common implementation for beginRecording and beginRecordingUnlockedQueue.
/// @param[in] Queue The queue to be recorded from.
Expand Down Expand Up @@ -628,6 +634,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @note Native recording requires immediate command lists.
ur_exp_graph_handle_t MNativeGraphHandle = nullptr;

/// Callback data for host tasks recorded in native recording mode. The
/// entries are owned by the graph; addNativeHostTaskCallback() appends to
/// this list and hands out non-owning pointers to the stored entries.
std::vector<std::unique_ptr<detail::EnqueueHostTaskData>>
MNativeHostTaskCallbacks;

/// Mapping from queues to barrier nodes. For each queue the last barrier
/// node recorded to the graph from the queue is stored.
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
Expand Down
19 changes: 19 additions & 0 deletions sycl/source/detail/host_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,27 @@ class HostTask {

friend class DispatchHostTask;
friend class ExecCGCommand;
friend class sycl::detail::HandlerAccess;
};

/// Extracts the callable stored in \p HT by moving it out of the host task.
/// The host task's own copy is left in a moved-from state afterwards.
inline std::function<void()> HandlerAccess::getHostTaskFunc(HostTask &HT) {
  std::function<void()> Extracted = std::move(HT.MHostTask);
  return Extracted;
}

/// Payload passed through the backend's `void *` user-data pointer for a
/// natively enqueued host task. It wraps the callable that NativeHostTask()
/// invokes.
struct EnqueueHostTaskData {
  explicit EnqueueHostTaskData(std::function<void()> HostTask)
      : Func(std::move(HostTask)) {}

  /// The host task body to run.
  std::function<void()> Func;
};

/// Callback that runs the host task stored in \p Data.
///
/// \tparam OwnsData When true, the callback takes ownership of \p Data (which
///         must have been allocated with `new`) and releases it once the host
///         task has run, including when the task exits by throwing. When
///         false, \p Data is only borrowed and is left untouched, so the same
///         payload can be invoked again later.
/// \param Data Pointer to an EnqueueHostTaskData.
template <bool OwnsData> inline void NativeHostTask(void *Data) {
  auto *HostTaskData = static_cast<EnqueueHostTaskData *>(Data);
  if constexpr (OwnsData) {
    // Take ownership before running the task so the payload is released on
    // every exit path, not only when Func returns normally.
    const std::unique_ptr<EnqueueHostTaskData> Owner(HostTaskData);
    Owner->Func();
  } else {
    HostTaskData->Func();
  }
}

} // namespace detail
} // namespace _V1
} // namespace sycl
24 changes: 6 additions & 18 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "unified-runtime/ur_api.h"
#include <detail/error_handling/error_handling.hpp>
#include <detail/host_task.hpp>

#include <detail/context_impl.hpp>
#include <detail/event_impl.hpp>
Expand Down Expand Up @@ -249,20 +250,6 @@ void InteropFreeFunc(ur_queue_handle_t, void *InteropData) {
return Data->func(Data->ih);
}

struct EnqueueHostTaskData {
explicit EnqueueHostTaskData(std::function<void()> HostTask)
: Func(std::move(HostTask)) {}

std::function<void()> Func;
};

void NativeHostTask(void *Data) {
// Callback data is heap-allocated at enqueue time and released here once
// the backend invokes the host task callback.
auto HostTaskData = std::unique_ptr<EnqueueHostTaskData>(
static_cast<EnqueueHostTaskData *>(Data));
HostTaskData->Func();
}
} // namespace

class DispatchHostTask {
Expand Down Expand Up @@ -392,12 +379,13 @@ class DispatchHostTask {
UR_DEVICE_INFO_ENQUEUE_HOST_TASK_SUPPORT_EXP,
sizeof(NativeHostTaskSupport), &NativeHostTaskSupport, nullptr);
if (NativeHostTaskSupport) {
auto NativeHostTaskData = std::make_unique<EnqueueHostTaskData>(
std::move(HostTask.MHostTask->MHostTask));
auto NativeHostTaskData =
std::make_unique<detail::EnqueueHostTaskData>(
std::move(HostTask.MHostTask->MHostTask));
ur_event_handle_t HostTaskEvent{};
Queue->getAdapter().call<UrApiKind::urEnqueueHostTaskExp>(
Queue->getHandleRef(), NativeHostTask, NativeHostTaskData.get(),
nullptr, 0, nullptr, &HostTaskEvent);
Queue->getHandleRef(), detail::NativeHostTask<true>,
NativeHostTaskData.get(), nullptr, 0, nullptr, &HostTaskEvent);
// Ownership is transferred to NativeHostTask callback on success.
(void)NativeHostTaskData.release();

Expand Down
47 changes: 43 additions & 4 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,23 @@ fill_copy_args(detail::handler_impl *impl,
DestOffset, DestExtent, CopyExtent);
}

// Queries a boolean device-info property through urDeviceGetInfo and returns
// its value. InfoQuery selects which property is read.
static bool checkDeviceSupports(device_impl &DeviceImpl,
                                ur_device_info_t InfoQuery) {
  // Starts out as "unsupported"; the query result is written into it.
  ur_bool_t Supported = false;
  auto &&Adapter = DeviceImpl.getAdapter();
  Adapter.call<UrApiKind::urDeviceGetInfo>(DeviceImpl.getHandleRef(),
                                           InfoQuery, sizeof(Supported),
                                           &Supported, nullptr);
  return Supported;
}

// Looks up the graph_impl for the UR graph handle reported by Queue.
// The handle is obtained via urQueueGetGraphExp and then resolved through the
// queue's context.
static std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
getNativeGraphImpl(queue_impl &Queue) {
  ur_exp_graph_handle_t NativeHandle = nullptr;
  auto &&Adapter = Queue.getAdapter();
  Adapter.call<UrApiKind::urQueueGetGraphExp>(Queue.getHandleRef(),
                                              &NativeHandle);
  auto &&Context = Queue.getContextImpl();
  return Context.getNativeGraph(NativeHandle);
}

} // namespace detail

handler::handler(detail::handler_impl &HandlerImpl) : impl(&HandlerImpl) {}
Expand Down Expand Up @@ -766,10 +783,32 @@ detail::EventImplPtr handler::finalize() {

// Native graph recording limitation
if (type == detail::CGType::CodeplayHostTask && Queue->isNativeRecording()) {
throw sycl::exception(
make_error_code(errc::feature_not_supported),
"SYCL host_task is not supported in native recording mode. Use "
"zeCommandListAppendHostFunction as a workaround.");
auto *HT = static_cast<detail::CGHostTask *>(CommandGroup.get());
if (!HT->MHostTask->isCreatedFromEnqueueFunction()) {
throw sycl::exception(make_error_code(errc::feature_not_supported),
"Only restricted host tasks may be captured in "
"native recording mode.");
}

if (!checkDeviceSupports(*detail::getSyclObjImpl(Queue->get_device()),
UR_DEVICE_INFO_ENQUEUE_HOST_TASK_SUPPORT_EXP)) {
throw sycl::exception(make_error_code(errc::feature_not_supported),
"Recording host tasks in native recording mode "
"requires backend support "
"not available on this device.");
}

// Store callback in the graph so it is available during replays.
auto GraphImpl = detail::getNativeGraphImpl(*Queue);
auto *CallbackData = GraphImpl->addNativeHostTaskCallback(
std::make_unique<detail::EnqueueHostTaskData>(
detail::HandlerAccess::getHostTaskFunc(*HT->MHostTask)));

Queue->getAdapter().call<detail::UrApiKind::urEnqueueHostTaskExp>(
Queue->getHandleRef(), detail::NativeHostTask<false>, CallbackData,
nullptr, 0, nullptr, nullptr);

return detail::event_impl::create_completed_host_event();
}
if (!CommandGroup->getRequirements().empty() && Queue->isNativeRecording()) {
throw sycl::exception(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// REQUIRES: level_zero_v2_adapter
// REQUIRES: level_zero_dev_kit
// REQUIRES: arch-intel_gpu_bmg_g21
// UNSUPPORTED: windows && gpu-intel-gen12
// UNSUPPORTED-INTENDED: UR_DEVICE_INFO_ENQUEUE_HOST_TASK_SUPPORT_EXP is not
// supported on win&gen12.

// RUN: %{build} %level_zero_options -o %t.out
// RUN: %{run} %t.out
// RUN: %if level_zero %{%{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}

// Tests that syclex::host_task() can be recorded into a native-recording SYCL
// Graph and executes correctly between two SYCL kernels.

#include "../../graph_common.hpp"
#include "../../ze_common.hpp"

#include <sycl/ext/oneapi/experimental/enqueue_functions.hpp>
#include <sycl/properties/all_properties.hpp>

namespace syclex = sycl::ext::oneapi::experimental;

constexpr size_t N = 1024;

// Records kernel -> host task -> kernel into a native-recording graph on an
// in-order queue, replays the graph twice and checks the accumulated values
// in the shared allocation after each replay.
int main() {
queue Queue{property::queue::in_order{}};

const sycl::context Context = Queue.get_context();
const sycl::device Device = Queue.get_device();

// Shared USM buffer touched by both the kernels and the host task; starts at
// zero so the expected values below can be computed in closed form.
uint32_t *Data = malloc_shared<uint32_t>(N, Queue);
std::fill(Data, Data + N, 0);

// NOTE(review): getCommandListFromQueue and CommandListStateVerifier come
// from ze_common.hpp; presumably they expose and check the state of the
// Level Zero command list behind Queue -- confirm against that header.
ze_command_list_handle_t ZeCommandList;
bool success = getCommandListFromQueue(Queue, ZeCommandList);
assert(success);

exp_ext::command_graph Graph{
Context, Device, {exp_ext::property::graph::enable_native_recording{}}};

CommandListStateVerifier verifier(ZeCommandList);
verifier.verify(EXECUTING);

Graph.begin_recording(Queue);
verifier.verify(RECORDING);

// Node 1: Data[i] += i + 1.
Queue.submit([&](handler &CGH) {
CGH.parallel_for(range<1>{N}, [=](id<1> idx) {
Data[idx] += static_cast<uint32_t>(idx[0]) + 1;
});
});

// Node 2: host task submitted through the enqueue-functions extension;
// doubles every element on the host.
syclex::host_task(Queue, [=] {
for (size_t i = 0; i < N; i++) {
Data[i] *= 2;
}
});

// Node 3: Data[i] += 10.
Queue.submit([&](handler &CGH) {
CGH.parallel_for(range<1>{N}, [=](id<1> idx) { Data[idx] += 10; });
});

Graph.end_recording(Queue);
verifier.verify(EXECUTING);

auto ExecutableGraph = Graph.finalize();

// First replay, starting from zero: (0 + (i + 1)) * 2 + 10.
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecutableGraph); });
Queue.wait();

for (size_t i = 0; i < N; i++) {
uint32_t Expected = static_cast<uint32_t>((i + 1) * 2 + 10);
assert(check_value(i, Expected, Data[i], "Data"));
}

// Second replay, starting from the previous result:
// ((i + 1) * 2 + 10 + (i + 1)) * 2 + 10 == (i + 1) * 6 + 30.
// This also checks that the recorded host task runs again on replay.
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecutableGraph); });
Queue.wait();

for (size_t i = 0; i < N; i++) {
uint32_t Expected = static_cast<uint32_t>((i + 1) * 6 + 30);
assert(check_value(i, Expected, Data[i], "Data"));
}

free(Data, Queue);
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -1418,9 +1418,7 @@ ur_result_t ur_command_list_manager::isGraphCaptureActive(bool *pResult) {
}

ur_result_t ur_command_list_manager::getGraph(ur_exp_graph_handle_t *phGraph) {
auto zeGetGraph =
hContext.get()->getPlatform()->ZeGraphExt.zeCommandListGetGraphExp;
if (!checkGraphExtensionSupport(hContext.get()) || !zeGetGraph) {
if (!checkGraphExtensionSupport(hContext.get())) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand All @@ -1431,6 +1429,12 @@ ur_result_t ur_command_list_manager::getGraph(ur_exp_graph_handle_t *phGraph) {
}

// Fork-join and implicit capture scenarios
auto zeGetGraph =
hContext.get()->getPlatform()->ZeGraphExt.zeCommandListGetGraphExp;
if (!zeGetGraph) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

ze_graph_handle_t hZeGraph = nullptr;
ze_result_t ZeResult =
ZE_CALL_NOCHECK(zeGetGraph, (getZeCommandList(), &hZeGraph));
Expand Down