[WebGPU] Fix stale buffer bindings on first graph-capture replay#28325
[WebGPU] Fix stale buffer bindings on first graph-capture replay#28325hariharans29 wants to merge 3 commits into
Conversation
Lazily create the graph-mode BufferManager and gate routing through it on a graph_buffer_mgr_active_ flag set only between OnRunStart/OnRunEnd. Refactor GpuBufferAllocator to take a std::function<const BufferManager&()> getter so the allocator can resolve the active buffer manager dynamically. Pre-existing latent bug: when graph capture is enabled, captured commands held buffer pointers from the capture run. If the allocator freed and reused a slot before the first Replay(), the dispatch read stale memory and produced one wrong logits tensor. Mostly hidden in fixed-shape decode because the allocator typically re-handed out the same addresses; exposed by upcoming MatMulNBits MLP/QKV fusions which churn the allocation pattern. Also adds WebGpuContext::WaitForQueueIdle() and ComputeContextBase::FlushAndWait() helpers used by the buffer-mgr lifetime management.
There was a problem hiding this comment.
Pull request overview
This PR addresses a WebGPU graph-capture correctness issue where the first replay could read stale buffer bindings. It does this by routing allocations through a dedicated “graph-mode” BufferManager only during the capture window, and by letting the allocator resolve the currently-active buffer manager dynamically.
Changes:
- Lazily create the graph-mode
BufferManagerand gateWebGpuExecutionProvider::BufferManager()via a newgraph_buffer_mgr_active_flag (active only betweenOnRunStart/OnRunEnd). - Refactor
GpuBufferAllocatorto optionally take astd::function<const BufferManager&()>getter, enabling dynamic buffer-manager selection. - Add
WebGpuContext::WaitForQueueIdle()(implementation) and aComputeContextBase::FlushAndWait()convenience helper.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Adds graph_buffer_mgr_active_ flag to control graph buffer manager routing. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Lazily instantiates graph buffer manager, toggles active flag during capture, and updates allocator creation to use a getter. |
| onnxruntime/core/providers/webgpu/webgpu_context.cc | Implements WebGpuContext::WaitForQueueIdle(). |
| onnxruntime/core/providers/webgpu/compute_context.h | Adds FlushAndWait() helper and expands BufferManagerAccessor friendship. |
| onnxruntime/core/providers/webgpu/allocator.h | Adds <functional> and a new allocator constructor taking a buffer-manager getter; stores the getter. |
| onnxruntime/core/providers/webgpu/allocator.cc | Implements new constructor and switches alloc/free to use the getter. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Fixes build error introduced when ComputeContextBase::FlushAndWait() was added without the corresponding WebGpuContext public declaration. Spotted by Copilot review on the PR.
| return webgpu_context_.Run(*this, program); | ||
| } | ||
|
|
||
| inline Status FlushAndWait() { |
There was a problem hiding this comment.
What is FlushAndWait() used for?
|
I have a PR #28260 to support multi-captured graph which can be used when multi generators are used in GenAI. In that PR, I also need to dynamically adjust the buffer manager so that the correct buffer manager is used. I thought the incorrect only happens for multi generators. But I'm not clear that why you meet the incorrect issue if you are using one generator. You mentioned that ' If the allocator freed and reused a slot before the first Replay(), the dispatch read stale memory and produced one wrong logits tensor. ', I am confused that why it happens. Can you give me a concrete repro to help me understand the issue here? |
Please ignore this PR for now. I hit this issue when I was using a correctness benchmark harness for the QKV and MLP fused kernels and AI suggested this fix. |
Lazily create the graph-mode BufferManager and gate routing through it on a graph_buffer_mgr_active_ flag set only between OnRunStart/OnRunEnd. Refactor GpuBufferAllocator to take a std::function<const BufferManager&()> getter so the allocator can resolve the active buffer manager dynamically.
Pre-existing latent bug: when graph capture is enabled, captured commands held buffer pointers from the capture run. If the allocator freed and reused a slot before the first Replay(), the dispatch read stale memory and produced one wrong logits tensor. Mostly hidden in fixed-shape decode because the allocator typically re-handed out the same addresses; exposed by MatMulNBits MLP/QKV fusions which churn the allocation pattern (#28280)
Also adds WebGpuContext::WaitForQueueIdle() and ComputeContextBase::FlushAndWait() helpers used by the buffer-mgr lifetime management.
Motivation and Context
Fix correctness on first replay