Skip to content

Commit 8200078

Browse files
Author: echo_stone (committed)
Fix(pr): address review comments for #334
- Remove preemptive flush_deferred_releases guard and unused lambda from executor loop; rely on existing inline flush-on-full and idle-batch-flush paths (reviewer: poursoul)
- Clarify cache_invalidate_range comment: all current counter writers (SDMA flags, TNOTIFY RDMA atomics) bypass AICPU cache, so invalidation is always required (reviewer: uv-xiao)
- Add pto2_rt_submit_notification_wait_task() helper API to pto_orchestration_api.h, reducing NotifyWait boilerplate in orchestration code (reviewer: uv-xiao)
- Simplify async_notify_demo and moe_dispatch orchestration to use the new helper API
- Remove unused PTO2LocalReadyBuffer forward declaration (reviewer: uv-xiao)

Made-with: Cursor
1 parent 5f4e78b commit 8200078

6 files changed

Lines changed: 50 additions & 57 deletions

File tree

examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp

Lines changed: 5 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -52,12 +52,6 @@ void aicpu_orchestration_entry(uint64_t* args, int arg_count,
5252
Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32);
5353
Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32);
5454

55-
uint64_t cq_notify = pto2_rt_alloc_cq();
56-
if (cq_notify == 0) {
57-
LOG_ERROR("async_notify_demo: rank %d failed CQ alloc", my_rank);
58-
return;
59-
}
60-
6155
// Producer: normal run-to-completion task (sends TNOTIFY to peer)
6256
PTOParam params_producer;
6357
params_producer.add_input(ext_in);
@@ -67,20 +61,13 @@ void aicpu_orchestration_entry(uint64_t* args, int arg_count,
6761
pto2_rt_submit_aiv_task(0, params_producer);
6862

6963
// NotifyWait: deferred task that waits for notification counter >= 1.
70-
// Produces dummy_notify so the consumer can depend on it via TensorMap.
71-
uint32_t dummy_shape[1] = { 1 };
72-
Tensor dummy_notify = make_tensor(dummy_shape, 1, DataType::INT32);
73-
74-
PTOParam params_wait;
75-
params_wait.add_output(dummy_notify);
76-
params_wait.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr);
77-
params_wait.add_scalar((uint64_t)1);
78-
pto2_rt_submit_aiv_task_deferred(2, params_wait, cq_notify);
64+
// Returns a dependency token tensor for downstream tasks.
65+
Tensor notify_token = pto2_rt_submit_notification_wait_task(
66+
2, (uint64_t)(uintptr_t)notify_counter_ptr, 1);
7967

80-
// Consumer: depends on producer (via ext_out) and notify_wait (via dummy_notify).
81-
// Guaranteed notify_counter >= 1 when this task runs.
68+
// Consumer: depends on producer (via ext_out) and notify_wait (via token).
8269
PTOParam params_consumer;
83-
params_consumer.add_input(dummy_notify);
70+
params_consumer.add_input(notify_token);
8471
params_consumer.add_input(ext_out);
8572
params_consumer.add_output(ext_result);
8673
params_consumer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr);

examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/orchestration/moe_dispatch_orchestration.cpp

Lines changed: 7 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -104,8 +104,7 @@ void aicpu_orchestration_entry(uint64_t* args, int arg_count,
104104

105105
uint64_t sdma_context = pto2_rt_get_sdma_context();
106106
uint64_t cq_send = pto2_rt_alloc_cq();
107-
uint64_t cq_notify = pto2_rt_alloc_cq();
108-
if (sdma_context == 0 || cq_send == 0 || cq_notify == 0) {
107+
if (sdma_context == 0 || cq_send == 0) {
109108
LOG_ERROR("moe_dispatch_v2: rank %d failed SDMA context or CQ alloc", my_rank);
110109
return;
111110
}
@@ -132,20 +131,14 @@ void aicpu_orchestration_entry(uint64_t* args, int arg_count,
132131
params_send.add_scalar(sdma_context);
133132
pto2_rt_submit_aiv_task_deferred(1, params_send, cq_send);
134133

135-
// Phase 1.5: NotifyWait — deferred task that waits for notification counter.
136-
// Produces a dummy_notify tensor so RecvAssemble can depend on it via TensorMap.
137-
uint32_t dummy_shape[1] = { 1 };
138-
Tensor dummy_notify = make_tensor(dummy_shape, 1, DataType::INT32);
134+
// Phase 1.5: NotifyWait — deferred wait for notification counter >= NUM_RANKS-1.
135+
// Returns a dependency token for RecvAssemble via TensorMap.
136+
Tensor notify_token = pto2_rt_submit_notification_wait_task(
137+
3, notify_counter_addr, NUM_RANKS - 1);
139138

140-
PTOParam params_wait;
141-
params_wait.add_output(dummy_notify);
142-
params_wait.add_scalar(notify_counter_addr);
143-
params_wait.add_scalar((uint64_t)(NUM_RANKS - 1));
144-
pto2_rt_submit_aiv_task_deferred(3, params_wait, cq_notify);
145-
146-
// Phase 2: RecvAssemble (depends on NotifyWait via dummy_notify)
139+
// Phase 2: RecvAssemble (depends on NotifyWait via notify_token)
147140
PTOParam params_recv;
148-
params_recv.add_input(dummy_notify);
141+
params_recv.add_input(notify_token);
149142
params_recv.add_input(ext_local_counts);
150143
params_recv.add_output(ext_expand_x);
151144
params_recv.add_output(ext_etn);

src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp

Lines changed: 0 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1083,25 +1083,6 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa
10831083

10841084
PTO2AsyncWaitList async_wait_list;
10851085

1086-
auto flush_deferred_releases = [&]() {
1087-
while (deferred_release_count > 0) {
1088-
#if PTO2_SCHED_PROFILING
1089-
int32_t fe = rt->scheduler.on_task_release(
1090-
*deferred_release_slot_states[--deferred_release_count], thread_idx);
1091-
#else
1092-
int32_t fe = rt->scheduler.on_task_release(
1093-
*deferred_release_slot_states[--deferred_release_count]);
1094-
#endif
1095-
(void)fe;
1096-
#if PTO2_SCHED_PROFILING
1097-
fanin_edges_total += fe;
1098-
if (fe > fanin_max_degree) {
1099-
fanin_max_degree = fe;
1100-
}
1101-
#endif
1102-
}
1103-
};
1104-
11051086
bool cores_released = false;
11061087

11071088
#if PTO2_PROFILING
@@ -1172,9 +1153,6 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa
11721153

11731154
// Phase 0: Poll async completion conditions (deferred-completion tasks)
11741155
int32_t async_completed_this_turn = 0;
1175-
if (deferred_release_count > MAX_DEFERRED_RELEASES - PTO2_MAX_ASYNC_WAITS) {
1176-
flush_deferred_releases();
1177-
}
11781156
if (async_wait_list.count > 0) {
11791157
PTO2AsyncPollResult poll_result = async_wait_list.poll_and_complete<false>(
11801158
&rt->scheduler, local_bufs,

src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -253,6 +253,40 @@ static inline void pto2_rt_submit_task_deferred(const MixedKernels& mixed_kernel
253253
rt->ops->submit_task(rt, mixed_kernels, params);
254254
}
255255

256+
/**
257+
* Submit a notification-wait deferred task and return a dependency token.
258+
*
259+
* Encapsulates the boilerplate for creating a NotifyWait task:
260+
* 1. Allocate a CQ
261+
* 2. Create a 1-element dummy output tensor (dependency token)
262+
* 3. Submit a deferred AIV task with (counter_addr, expected_value, cq_addr)
263+
*
264+
* The returned token tensor should be added as an input to any downstream
265+
* task that depends on the notification completing.
266+
*
267+
* @param kernel_id func_id of the NotifyWait kernel
268+
* @param counter_addr GM address of the notification counter (int32*)
269+
* @param expected_value threshold: task completes when *counter >= expected
270+
* @return dependency token tensor (add as input to downstream tasks)
271+
*/
272+
static inline Tensor pto2_rt_submit_notification_wait_task(
273+
int32_t kernel_id,
274+
uint64_t counter_addr,
275+
uint32_t expected_value) {
276+
uint64_t cq_addr = pto2_rt_alloc_cq();
277+
278+
uint32_t dummy_shape[1] = { 1 };
279+
Tensor token = make_tensor(dummy_shape, 1, DataType::INT32);
280+
281+
PTOParam params;
282+
params.add_output(token);
283+
params.add_scalar(counter_addr);
284+
params.add_scalar(static_cast<uint64_t>(expected_value));
285+
pto2_rt_submit_aiv_task_deferred(kernel_id, params, cq_addr);
286+
287+
return token;
288+
}
289+
256290
static inline void pto2_rt_scope_begin() {
257291
PTO2Runtime* rt = pto2_current_runtime();
258292
rt->ops->scope_begin(rt);

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -185,8 +185,10 @@ struct PTO2AsyncWaitList {
185185
for (int32_t c = 0; c < entry.condition_count; c++) {
186186
PTO2CompletionCondition& cond = entry.conditions[c];
187187
if (!cond.satisfied) {
188-
// RDMA-written counters (e.g. TNOTIFY) bypass AICPU data cache.
189-
// Invalidate before reading to see the true memory value.
188+
// All current counter writers (SDMA engine flags, TNOTIFY
189+
// RDMA atomics) bypass AICPU data cache. Invalidation is
190+
// needed so the poll reads the true GM value. For any
191+
// hypothetical CPU-written counter this is a harmless no-op.
190192
if (cond.counter_addr) {
191193
cache_invalidate_range(
192194
reinterpret_cast<const void*>(const_cast<const uint32_t*>(cond.counter_addr)),

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
#include "common/core_type.h"
2929

3030
struct PTO2SchedulerState;
31-
struct PTO2LocalReadyBuffer;
3231

3332
#if PTO2_SCHED_PROFILING
3433
#include "aicpu/device_time.h"

0 commit comments

Comments (0)