Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <functional> // for function
#include <memory> // for shared_ptr
#include <tuple> // for tuple, apply
#include <vector> // for vector

namespace sycl {
Expand Down Expand Up @@ -162,6 +163,18 @@ class __SYCL_EXPORT modifiable_command_graph
/// Returns true if the graph contains no nodes.
bool empty() const;

/// Register a callback to be invoked when the graph object is destroyed.
/// @param Callback Callable to invoke on destruction.
/// @param CbArgs Arguments to forward to the callback.
template <typename Func, typename... ArgTs>
void set_destruction_callback(Func &&Callback, ArgTs &&...CbArgs) {
setDestructionCallbackImpl(
[Cb = std::forward<Func>(Callback),
Args = std::tuple(std::forward<ArgTs>(CbArgs)...)]() mutable {
std::apply(Cb, Args);
});
}

/// Common Reference Semantics
friend bool operator==(const modifiable_command_graph &LHS,
const modifiable_command_graph &RHS) {
Expand Down Expand Up @@ -203,6 +216,8 @@ class __SYCL_EXPORT modifiable_command_graph

void print_graph(sycl::detail::string_view path, bool verbose = false) const;

void setDestructionCallbackImpl(std::function<void()> Callback);

std::shared_ptr<detail::graph_impl> impl;

static void checkNodePropertiesAndThrow(const property_list &Properties);
Expand Down
32 changes: 32 additions & 0 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ graph_impl::~graph_impl() {
}
MNativeGraphHandle = nullptr;
}
for (auto &Cb : MDestructionCallbacks) {
Cb();
}
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~graph_impl", e);
}
Expand Down Expand Up @@ -649,6 +652,30 @@ bool graph_impl::isQueueRecording(sycl::detail::queue_impl &Queue) {
return MRecordingQueues.count(Queue.weak_from_this()) > 0;
}

void graph_impl::setDestructionCallback(std::function<void()> Callback) {
if (MNativeGraphHandle) {
auto Data = std::make_unique<std::function<void()>>(std::move(Callback));
context_impl &ContextImpl = *sycl::detail::getSyclObjImpl(MContext);
sycl::detail::adapter_impl &Adapter = ContextImpl.getAdapter();
ur_result_t Result = Adapter.call_nocheck<
sycl::detail::UrApiKind::urGraphSetDestructionCallbackExp>(
MNativeGraphHandle,
[](void *UserData) {
auto *Fn = static_cast<std::function<void()> *>(UserData);
(*Fn)();
delete Fn;
},
Data.get());
if (Result != UR_RESULT_SUCCESS) {
throw sycl::exception(sycl::make_error_code(errc::runtime),
"Failed to register graph destruction callback");
}
Data.release();
} else {
MDestructionCallbacks.push_back(std::move(Callback));
}
}

void graph_impl::clearQueues(bool NeedsLock) {
graph_impl::RecQueuesStorage SwappedQueues;
{
Expand Down Expand Up @@ -2359,6 +2386,11 @@ void modifiable_command_graph::checkNodePropertiesAndThrow(
Properties, CheckDataLessProperties, CheckPropertiesWithData);
}

void modifiable_command_graph::setDestructionCallbackImpl(
std::function<void()> Callback) {
impl->setDestructionCallback(std::move(Callback));
}

executable_command_graph::executable_command_graph(
const std::shared_ptr<detail::graph_impl> &Graph, const sycl::context &Ctx,
const property_list &PropList)
Expand Down
10 changes: 10 additions & 0 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @return True if the queue is recording to this graph, false otherwise.
bool isQueueRecording(sycl::detail::queue_impl &Queue);

/// Register a destruction callback to be invoked when the graph is destroyed.
/// Uses the native UR callback if a native graph handle exists, otherwise
/// stores locally for invocation in ~graph_impl().
/// @param Callback Callable to invoke on graph destruction.
void setDestructionCallback(std::function<void()> Callback);

private:
/// Common implementation for beginRecording and beginRecordingUnlockedQueue.
/// @param[in] Queue The queue to be recorded from.
Expand Down Expand Up @@ -644,6 +650,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
// The number of live executable graphs that have been created from this
// modifiable graph
std::atomic<size_t> MExecGraphCount = 0;

/// Destruction callbacks registered for the command buffer path.
/// Invoked in ~graph_impl() when native recording is not enabled.
std::vector<std::function<void()>> MDestructionCallbacks;
};

/// Get whether native recording is enabled for this graph.
Expand Down
108 changes: 108 additions & 0 deletions sycl/test-e2e/Graph/Inputs/destruction_callback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Tests destruction callbacks: resource cleanup on graph destruction, and
// copy/move constraint validation (CopyConstructible required,
// MoveConstructible as rvalue optimization, assignment not required).

#include "../graph_common.hpp"

#include <sycl/properties/all_properties.hpp>

struct CopyMoveTracker {
bool *CopiedFlag;
bool *MovedFlag;
int Value;

CopyMoveTracker(int V, bool *Copied, bool *Moved)
: CopiedFlag(Copied), MovedFlag(Moved), Value(V) {}
CopyMoveTracker(const CopyMoveTracker &Other)
: CopiedFlag(Other.CopiedFlag), MovedFlag(Other.MovedFlag),
Value(Other.Value) {
*CopiedFlag = true;
}
CopyMoveTracker(CopyMoveTracker &&Other)
: CopiedFlag(Other.CopiedFlag), MovedFlag(Other.MovedFlag),
Value(Other.Value) {
*MovedFlag = true;
}

// Assignment operators not required by spec
CopyMoveTracker &operator=(const CopyMoveTracker &) = delete;
CopyMoveTracker &operator=(CopyMoveTracker &&) = delete;
};

int main() {
queue Queue{property::queue::in_order{}};

int ObservedLvalue = 0;
int ObservedRvalue = 0;
size_t ObservedSize = 0;
bool CleanupCallbackInvoked = false;

const size_t N = 64;
int *Data = malloc_device<int>(N, Queue);

{
#ifdef GRAPH_E2E_NATIVE_RECORDING
exp_ext::command_graph Graph{
Queue.get_context(),
Queue.get_device(),
{exp_ext::property::graph::enable_native_recording{}}};
#else
exp_ext::command_graph Graph{Queue.get_context(), Queue.get_device()};
#endif

add_node(Graph, Queue, [&](handler &CGH) {
CGH.parallel_for(range<1>{N},
[=](id<1> idx) { Data[idx] = static_cast<int>(idx); });
});

// Resource cleanup: free device memory on graph destruction.
// Also verifies lvalue reference arguments are copied into internal storage
// and not referenced after registration.
size_t SizeCopy = N;
Graph.set_destruction_callback(
[](int *Ptr, queue Q, size_t &Size, size_t *OutSize, bool *OutInvoked) {
sycl::free(Ptr, Q);
*OutSize = Size;
*OutInvoked = true;
},
Data, Queue, SizeCopy, &ObservedSize, &CleanupCallbackInvoked);
SizeCopy = 0;
assert(!CleanupCallbackInvoked && "Cleanup callback should not fire yet");

// Lvalue arg: must be copied into the graph's stored tuple.
bool LvalueCopied = false, LvalueMoved = false;
CopyMoveTracker LvalueTracker{42, &LvalueCopied, &LvalueMoved};
Graph.set_destruction_callback(
[](CopyMoveTracker T, int *Out) { *Out = T.Value; }, LvalueTracker,
&ObservedLvalue);
assert(LvalueCopied && "Lvalue arg should be copied");

// Rvalue arg: move-constructible optimization should kick in.
bool RvalueCopied = false, RvalueMoved = false;
CopyMoveTracker RvalueTracker{99, &RvalueCopied, &RvalueMoved};
Graph.set_destruction_callback(
[](CopyMoveTracker T, int *Out) { *Out = T.Value; },
std::move(RvalueTracker), &ObservedRvalue);
assert(RvalueMoved && !RvalueCopied &&
"Rvalue arg should be moved, not copied");

auto ExecGraph = Graph.finalize();

std::vector<int> HostData(N);
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(ExecGraph); });
Queue.memcpy(HostData.data(), Data, N * sizeof(int));
Queue.wait();
for (size_t i = 0; i < N; i++) {
assert(check_value(i, static_cast<int>(i), HostData[i], "Data"));
}
}

assert(CleanupCallbackInvoked && "Cleanup callback was not invoked");
assert(ObservedLvalue == 42 &&
"Lvalue callback should observe original value");
assert(ObservedRvalue == 99 &&
"Rvalue callback should observe original value");
assert(ObservedSize == N &&
"Cleanup callback should observe pre-mutation size");
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// REQUIRES: level_zero_v2_adapter && arch-intel_gpu_bmg_g21
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#define GRAPH_E2E_RECORD_REPLAY
#define GRAPH_E2E_NATIVE_RECORDING

#include "../../Inputs/destruction_callback.cpp"
6 changes: 6 additions & 0 deletions sycl/test-e2e/Graph/RecordReplay/destruction_callback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#define GRAPH_E2E_RECORD_REPLAY

#include "../Inputs/destruction_callback.cpp"
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3114,6 +3114,7 @@ _ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph13end_reco
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph15begin_recordingERKSt6vectorINS0_5queueESaIS7_EERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph15begin_recordingERNS0_5queueERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph24addGraphLeafDependenciesENS3_4nodeE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph26setDestructionCallbackImplESt8functionIFvvEE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph27checkNodePropertiesAndThrowERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph7addImplERKSt6vectorINS3_4nodeESaIS7_EE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph7addImplERNS3_21dynamic_command_groupERKSt6vectorINS3_4nodeESaIS9_EE
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -4318,6 +4318,7 @@
?setArgHelper@handler@_V1@sycl@@AEAAXHAEAVwork_group_memory_impl@detail@23@@Z
?setArgsHelper@handler@_V1@sycl@@AEAAXH@Z
?setArgsToAssociatedAccessors@handler@_V1@sycl@@AEAAXXZ
?setDestructionCallbackImpl@modifiable_command_graph@detail@experimental@oneapi@ext@_V1@sycl@@IEAAXV?$function@$$A6AXXZ@std@@@Z
?setDevice@HostProfilingInfo@detail@_V1@sycl@@QEAAXPEAVdevice_impl@234@@Z
?setDeviceKernelInfo@handler@_V1@sycl@@AEAAX$$QEAVkernel@23@@Z
?setDeviceKernelInfoPtr@handler@_V1@sycl@@AEAAXPEAVDeviceKernelInfo@detail@23@@Z
Expand Down
Loading