Skip to content

Commit 11dc229

Browse files
author
echo_stone
committed
implement async notification
1 parent e334778 commit 11dc229

20 files changed

Lines changed: 959 additions & 144 deletions

File tree

examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
"""
2-
Golden script for async_completion_demo (dual-mode).
2+
Golden script for async_completion_demo.
33
4-
Computation:
5-
producer: out[i] = in[i] * 2.0 (with deferred completion)
4+
Single-card / sim path keeps the original producer-consumer pipeline:
5+
producer: out[i] = in[i] * 2.0
66
consumer: result[i] = out[i] + 1.0
77
8-
So: result[i] = in[i] * 2.0 + 1.0
9-
With in = 3.0: result = 7.0
10-
11-
Args layout: [ptr_in, ptr_out, ptr_result, ptr_event_handle_output,
12-
size_in, size_out, size_result, size_event_handle_output, SIZE]
13-
14-
event_handle_output: 16 bytes — used by the kernel and scheduler for async
15-
completion signaling. Not compared as test output.
8+
Hardware 2-card path validates `out` and `result`:
9+
each rank TGET_ASYNCs the peer rank's `in` into local `out`, then the
10+
normal consumer computes `result = out + 1`.
1611
"""
1712

1813
import ctypes
@@ -45,7 +40,34 @@ def generate_inputs(params: dict) -> list:
4540
]
4641

4742

43+
def generate_distributed_inputs(rank: int, nranks: int, root: int,
44+
comm_ctx=None) -> list:
45+
del comm_ctx
46+
del nranks
47+
del root
48+
49+
size = 128 * 128
50+
inp = [float(i % 251) / 10.0 for i in range(size)]
51+
out = [0.0] * size
52+
result = [0.0] * size
53+
54+
return [
55+
("in", inp),
56+
("out", out),
57+
("result", result),
58+
]
59+
60+
4861
def compute_golden(tensors: dict, params: dict) -> None:
49-
inp = torch.as_tensor(tensors["in"])
50-
tensors["result"][:] = inp * 2.0 + 1.0
51-
tensors["out"][:] = inp * 2.0
62+
if "in" in tensors:
63+
inp = torch.as_tensor(tensors["in"])
64+
tensors["result"][:] = inp * 2.0 + 1.0
65+
tensors["out"][:] = inp * 2.0
66+
return
67+
68+
out = tensors["out"]
69+
result = tensors["result"]
70+
for i in range(len(out)):
71+
value = float(i % 251) / 10.0
72+
out[i] = value
73+
result[i] = value + 1.0

examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in
8585
#endif
8686

8787
volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr);
88+
pto2_cq_reset(cq);
8889
pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq,
8990
PTO2_CQ_COMPLETION_EVENT_FLAG,
9091
event_flag_addr, 0);
91-
pto2_cq_flush();
92+
pto2_cq_flush(cq);
9293
}
Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,89 @@
11
/**
2-
* Async Completion Demo - Hardware SDMA Producer Kernel (func_id=2)
2+
* Async Completion Demo - Hardware 2P SDMA TGET Producer Kernel (func_id=2)
33
*
4-
* Implements: out[i] = in[i] * 2.0 via TLOAD/TADD/TSTORE, then issues
5-
* an async SDMA request via pto2_send_request_entry().
4+
* Implements:
5+
* 1. Read peer rank's input buffer via TGET_ASYNC into local out
6+
* 2. Register the async event in the CQ
7+
* 3. Return immediately so the runtime completes the task asynchronously
68
*
79
* This kernel is only compiled for real hardware (a2a3), not for simulation.
810
*
911
* Kernel args layout (packed by scheduler):
1012
* args[0] = &Tensor(in) — input tensor struct pointer
1113
* args[1] = &Tensor(out) — output tensor struct pointer
12-
* args[2] = sdma_context_addr — SDMA async context
13-
* args[3] = cq_addr — completion queue (appended by submit_deferred)
14+
* args[2] = CommDeviceContext* — distributed communication context
15+
* args[3] = sdma_context_addr — SDMA async context
16+
* args[4] = cq_addr — completion queue (appended by submit_deferred)
1417
*/
1518

1619
#include <cstdint>
20+
#ifndef __gm__
21+
#define __gm__
22+
#endif
23+
24+
#ifndef __aicore__
25+
#define __aicore__ [aicore]
26+
#endif
27+
1728
#include <pto/pto-inst.hpp>
1829
#include "pto/comm/pto_comm_inst.hpp"
1930
#include "pto/npu/comm/async/sdma/sdma_types.hpp"
2031
#include "pto/common/pto_tile.hpp"
2132

33+
#include "common/comm_context.h"
2234
#include "tensor.h"
2335

2436
using namespace pto;
2537

26-
#ifndef __gm__
27-
#define __gm__
28-
#endif
29-
30-
#ifndef __aicore__
31-
#define __aicore__ [aicore]
32-
#endif
33-
3438
#include "pto_rq_kernel_api.h"
3539

40+
template <typename T>
41+
AICORE inline __gm__ T* CommRemotePtr(__gm__ CommDeviceContext* ctx, __gm__ T* local_ptr,
42+
int peer_rank) {
43+
uint64_t local_base = ctx->windowsIn[ctx->rankId];
44+
uint64_t offset = (uint64_t)local_ptr - local_base;
45+
return (__gm__ T*)(ctx->windowsIn[peer_rank] + offset);
46+
}
47+
3648
extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) {
3749
__gm__ Tensor* in_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]);
3850
__gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]);
39-
uint64_t sdma_context = static_cast<uint64_t>(args[2]);
40-
uint64_t cq_addr = static_cast<uint64_t>(args[3]);
51+
__gm__ CommDeviceContext* comm_ctx =
52+
reinterpret_cast<__gm__ CommDeviceContext*>(args[2]);
53+
uint64_t sdma_context = static_cast<uint64_t>(args[3]);
54+
uint64_t cq_addr = static_cast<uint64_t>(args[4]);
4155

4256
__gm__ float* in_data = reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset;
4357
__gm__ float* out_data = reinterpret_cast<__gm__ float*>(out_tensor->buffer.addr) + out_tensor->start_offset;
58+
volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr);
59+
pto2_cq_reset(cq);
4460

45-
constexpr int kTRows = 128;
46-
constexpr int kTCols = 128;
47-
constexpr int kTotalElems = kTRows * kTCols;
61+
int my_rank = static_cast<int>(comm_ctx->rankId);
62+
int nranks = static_cast<int>(comm_ctx->rankNum);
63+
if (nranks != 2) {
64+
pipe_barrier(PIPE_ALL);
65+
return;
66+
}
67+
int peer_rank = 1 - my_rank;
4868

49-
using DynShapeDim5 = Shape<1, 1, 1, kTRows, kTCols>;
50-
using DynStridDim5 = Stride<1, 1, 1, kTCols, 1>;
51-
using GlobalData = GlobalTensor<float, DynShapeDim5, DynStridDim5>;
52-
using TileData = Tile<TileType::Vec, float, kTRows, kTCols, BLayout::RowMajor, -1, -1>;
69+
constexpr int kTotalElems = 128 * 128;
5370

5471
using FlatShape = Shape<1, 1, 1, 1, kTotalElems>;
5572
using FlatStride = Stride<kTotalElems, kTotalElems, kTotalElems, kTotalElems, 1>;
5673
using FlatGlobalData = GlobalTensor<float, FlatShape, FlatStride>;
57-
58-
TileData inTile(kTRows, kTCols);
59-
TileData outTile(kTRows, kTCols);
60-
TASSIGN(inTile, 0x0);
61-
TASSIGN(outTile, 0x10000);
62-
63-
GlobalData inGlobal(in_data);
64-
GlobalData outGlobal(out_data);
6574
FlatGlobalData outGlobalFlat(out_data);
66-
67-
// Compute out = in + in = in * 2.0
68-
TLOAD(inTile, inGlobal);
69-
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
70-
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
71-
72-
TADD(outTile, inTile, inTile);
73-
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
74-
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
75-
76-
TSTORE(outGlobal, outTile);
77-
set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
78-
wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7);
75+
__gm__ float* remote_in_data = CommRemotePtr(comm_ctx, in_data, peer_rank);
76+
FlatGlobalData remoteInGlobalFlat(remote_in_data);
7977

8078
using ScratchTile = pto::Tile<pto::TileType::Vec, uint8_t, 1, pto::comm::sdma::UB_ALIGN_SIZE>;
8179
ScratchTile scratchTile;
8280
TASSIGN(scratchTile, 0x20000);
8381

8482
__gm__ uint8_t* context = reinterpret_cast<__gm__ uint8_t*>(static_cast<uintptr_t>(sdma_context));
85-
volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr);
8683

87-
auto desc = pto2_sdma_descriptor(outGlobalFlat, outGlobalFlat, scratchTile, context);
84+
auto desc = pto2_sdma_tget_descriptor(outGlobalFlat, remoteInGlobalFlat, scratchTile, context);
8885
uint64_t tag = pto2_send_request_entry(PTO2_ENGINE_SDMA, PTO2_RQ_ID_AUTO, desc);
8986
pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag);
9087

91-
pto2_cq_flush();
88+
pto2_cq_flush(cq);
9289
}

examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
"""
22
Async Completion Demo - Kernel and Orchestration Configuration
33
4-
Dual-mode demonstration:
5-
Sim mode (a2a3sim): func_id=0 (simulated producer, direct EVENT_FLAG completion)
6-
HW mode (a2a3): func_id=2 (TPUT_ASYNC producer, EVENT_HANDLE_SLOT completion)
7-
8-
Both modes share func_id=1 (consumer, run-to-completion).
9-
Orchestration dynamically selects mode based on SDMA context availability.
4+
Two hardware cards use the existing deferred-completion producer API to
5+
demonstrate a real 2P TGET_ASYNC remote read. The legacy single-card / sim
6+
path stays available for local debugging.
107
"""
118

129
import os
@@ -34,6 +31,26 @@
3431
RUNTIME_CONFIG = {
3532
"runtime": "tensormap_and_ringbuffer",
3633
"aicpu_thread_num": 4,
34+
"orch_thread_num": 1,
3735
"block_dim": 3,
3836
"rounds": 1,
3937
}
38+
39+
if _platform == "a2a3":
40+
RUNTIME_ENV = {
41+
"PTO2_ENABLE_SDMA": "1",
42+
}
43+
44+
DISTRIBUTED_CONFIG = {
45+
"nranks": 2,
46+
"root": 0,
47+
"win_sync_prefix": 256,
48+
"buffers": [
49+
{"name": "in", "dtype": "float32", "count": 128 * 128, "placement": "window"},
50+
{"name": "out", "dtype": "float32", "count": 128 * 128, "placement": "window"},
51+
{"name": "result", "dtype": "float32", "count": 128 * 128, "placement": "device"},
52+
],
53+
"inputs": ["in"],
54+
"outputs": ["out", "result"],
55+
"args": ["in", "out", "result", "deviceCtx"],
56+
}

examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,27 @@
11
/**
22
* Async Completion Demo - Device-side orchestration (CQ model)
33
*
4-
* DAG structure:
5-
* t0 (producer): out = in * 2.0 [deferred completion via CQ]
6-
* t1 (consumer): result = out + 1.0 [run-to-completion]
7-
* Dependency: t0 -> t1 (consumer reads producer's output tensor)
4+
* Two execution modes share this file:
5+
*
6+
* 1. Single-card / sim mode (legacy demo):
7+
* t0 (producer): out = in * 2.0 [deferred completion via CQ]
8+
* t1 (consumer): result = out + 1.0 [run-to-completion]
9+
*
10+
* 2. Two-card hardware mode:
11+
* each rank submits one deferred producer task that TGET_ASYNCs the peer
12+
* rank's input buffer into its local out, then runs the normal consumer on out.
813
*
914
* CQ model:
1015
* Orchestration marks t0 as complete_in_future and passes a CQ address.
1116
* The producer kernel decides at runtime what completions it needs and writes
1217
* them into the completion queue. The scheduler reads the CQ after the kernel
1318
* returns and registers completions dynamically.
14-
*
15-
* Dual-mode dispatch:
16-
* - Sim mode (no SDMA context): func_id=0
17-
* The sim producer kernel writes 1 to a GM flag, then registers an
18-
* EVENT_FLAG CQ entry pointing to that flag.
19-
* - HW mode (SDMA available): func_id=2
20-
* The HW producer kernel issues TPUT_ASYNC, writes the handle to GM,
21-
* then registers an EVENT_HANDLE_SLOT CQ entry.
22-
*
23-
* Args layout (from golden.py):
24-
* [ptr_in, ptr_out, ptr_result, ptr_event_handle_output,
25-
* size_in, size_out, size_result, size_event_handle_output, SIZE]
26-
* + [gm_heap, heap_size] appended by runtime_maker.cpp
2719
*/
2820

2921
#include <stddef.h>
3022
#include <stdint.h>
3123

24+
#include "common/comm_context.h"
3225
#include "pto_orchestration_api.h"
3326

3427
#define ARG_PTR_IN 0
@@ -48,9 +41,8 @@ extern "C" {
4841
__attribute__((visibility("default")))
4942
PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) {
5043
(void)args;
51-
(void)arg_count;
5244
return PTO2OrchestrationConfig{
53-
.expected_arg_count = 9,
45+
.expected_arg_count = (arg_count >= 9) ? 9 : 4,
5446
};
5547
}
5648

@@ -59,7 +51,43 @@ void aicpu_orchestration_entry(uint64_t* args, int arg_count,
5951
int orch_thread_num, int orch_thread_index) {
6052
(void)arg_count;
6153
(void)orch_thread_num;
62-
(void)orch_thread_index;
54+
if (orch_thread_index != 0) return;
55+
56+
if (arg_count == 4) {
57+
void* in_ptr = (void*)(uintptr_t)args[0];
58+
void* out_ptr = (void*)(uintptr_t)args[1];
59+
void* result_ptr = (void*)(uintptr_t)args[2];
60+
auto* comm_ctx = reinterpret_cast<CommDeviceContext*>((uintptr_t)args[3]);
61+
int my_rank = (int)comm_ctx->rankId;
62+
63+
uint32_t shapes[1] = {128 * 128};
64+
Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32);
65+
Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32);
66+
Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32);
67+
68+
uint64_t sdma_context = pto2_rt_get_sdma_context();
69+
uint64_t cq = pto2_rt_alloc_cq();
70+
if (sdma_context == 0 || cq == 0) {
71+
LOG_ERROR("async_demo 2P: rank %d failed to get SDMA context or CQ (sdma=0x%lx, cq=0x%lx)",
72+
my_rank, sdma_context, cq);
73+
return;
74+
}
75+
76+
PTOParam params_producer;
77+
params_producer.add_input(ext_in);
78+
params_producer.add_output(ext_out);
79+
params_producer.add_scalar((uint64_t)(uintptr_t)comm_ctx);
80+
params_producer.add_scalar(sdma_context);
81+
pto2_rt_submit_aiv_task_deferred(2, params_producer, cq);
82+
83+
PTOParam params_consumer;
84+
params_consumer.add_input(ext_out);
85+
params_consumer.add_output(ext_result);
86+
pto2_rt_submit_aiv_task(1, params_consumer);
87+
88+
LOG_INFO("async_demo 2P: rank %d submitted TGET_ASYNC producer with CQ", my_rank);
89+
return;
90+
}
6391

6492
void* in_ptr = (void*)(uintptr_t)args[ARG_PTR_IN];
6593
void* out_ptr = (void*)(uintptr_t)args[ARG_PTR_OUT];
@@ -79,14 +107,14 @@ void aicpu_orchestration_entry(uint64_t* args, int arg_count,
79107
Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32);
80108

81109
if (sdma_context != 0) {
82-
// HW mode: kernel issues TPUT_ASYNC, puts event.handle directly in CQ entry.
110+
// HW mode: kernel issues an async SDMA request and puts event.handle directly in the CQ entry.
83111
PTOParam params_producer;
84112
params_producer.add_input(ext_in);
85113
params_producer.add_output(ext_out);
86114
params_producer.add_scalar(sdma_context);
87115
pto2_rt_submit_aiv_task_deferred(2, params_producer, cq);
88116

89-
LOG_INFO("async_demo: HW mode - submitted TPUT_ASYNC producer (func_id=2) with CQ");
117+
LOG_INFO("async_demo: HW mode - submitted async SDMA producer (func_id=2) with CQ");
90118
} else {
91119
PTOParam params_producer;
92120
params_producer.add_input(ext_in);

0 commit comments

Comments
 (0)