
Commit e4348eb

Refactor: move reclamation state into owning data structures (#315)
- Move dep_pool_mark from PTO2TaskPayload (GM) to PTO2TaskSlotState (local memory) to avoid GM cache line pollution
- Move last_reclaimed into PTO2DepListPool and last_cleanup into PTO2TensorMap, eliminating parallel arrays in orchestrator state
- Consolidate the per-submit sm_last_task_alive read: a single atomic load is now shared by TensorMap sync and dep pool reclaim
- Simplify sync_tensormap to a per-ring interface, removing the multi-ring loop and the MIN_FREE_NUM pressure heuristic
- Defer task descriptor GM writes until after tensor insertion to batch cache line stores and reduce eviction pressure
- Narrow the ring_id type from int32_t to uint8_t throughout
1 parent 77a81aa commit e4348eb

7 files changed

Lines changed: 80 additions & 81 deletions
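
The core of the change is the reworked per-submit synchronization in pto2_submit_mixed_task: the ring's last_task_alive is read once with acquire semantics, and the same value feeds both the TensorMap sync and the dep pool reclaim. Condensed from the pto_orchestrator.cpp diff below (all names exactly as they appear there):

    uint8_t ring_id = orch->current_ring_id();
    int32_t sm_last_task_alive =
        orch->sm_handle->header->rings[ring_id].fc.last_task_alive.load(std::memory_order_acquire);

    orch->tensor_map.sync_tensormap(ring_id, sm_last_task_alive);  // validity refresh + interval-gated cleanup
    if (orch->scheduler) {
      orch->rings[ring_id].dep_pool.reclaim(orch->scheduler, ring_id, sm_last_task_alive);
    }

Previously each consumer issued its own atomic load (and sync_tensormap looped over every ring), so a single submit could touch the shared flow-control cache line several times.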


src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp

Lines changed: 37 additions & 53 deletions
@@ -143,7 +143,6 @@ bool pto2_orchestrator_init(
     pto2_dep_pool_init(&orch->rings[r].dep_pool, dep_entries, dep_pool_capacity);
     orch->rings[r].dep_pool.error_code_ptr = &sm_handle->header->orch_error_code;
     orch->dep_pool_cur_entries[r] = nullptr;
-    orch->dep_pool_last_reclaimed[r] = 0;
   }

   // Initialize TensorMap with per-ring task window sizes
@@ -158,9 +157,6 @@ bool pto2_orchestrator_init(
     return false;
   }
   orch->tensor_map.orch = orch;
-  for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
-    orch->tensormap_last_cleanup[r] = 0;
-  }

   // Initialize scope stack: one flat buffer for task IDs + one array for begin offsets
   uint64_t max_depth = PTO2_MAX_SCOPE_DEPTH;
@@ -203,40 +199,18 @@ void pto2_orchestrator_set_scheduler(PTO2OrchestratorState* orch, PTO2SchedulerS
 }


-// =============================================================================
-// Dep Pool Reclamation
-// =============================================================================
-
-/**
- * Reclaim dead dep pool entries for a specific ring based on scheduler's last_task_alive.
- * Safe to call multiple times — only advances tail forward.
- */
-static void pto2_dep_pool_reclaim(PTO2OrchestratorState* orch, int32_t ring_id) {
-  int32_t last_alive =
-      orch->sm_handle->header->rings[ring_id].fc.last_task_alive.load(std::memory_order_acquire);
-  if (last_alive > orch->dep_pool_last_reclaimed[ring_id] && last_alive > 0) {
-    int32_t newest_consumed = last_alive - 1;
-    int32_t slot_rc = orch->rings[ring_id].task_ring.get_task_slot(newest_consumed);
-    int32_t mark = orch->sm_handle->task_payloads[ring_id][slot_rc].dep_pool_mark;
-    if (mark > 0) {
-      orch->rings[ring_id].dep_pool.advance_tail(mark);
-    }
-    orch->dep_pool_last_reclaimed[ring_id] = last_alive;
-  }
-}
-
 /**
  * Ensure dep pool for a specific ring has at least `needed` entries available.
  * Spin-waits for reclamation if under pressure. Detects deadlock if no progress.
  */
-static void pto2_dep_pool_ensure_space(PTO2OrchestratorState* orch, int32_t ring_id, int32_t needed) {
+static void pto2_dep_pool_ensure_space(PTO2OrchestratorState* orch, uint8_t ring_id, int32_t needed) {
   if (pto2_dep_pool_available(&orch->rings[ring_id].dep_pool) >= needed) return;

   int spin_count = 0;
   int32_t prev_last_alive =
       orch->sm_handle->header->rings[ring_id].fc.last_task_alive.load(std::memory_order_acquire);
   while (pto2_dep_pool_available(&orch->rings[ring_id].dep_pool) < needed) {
-    pto2_dep_pool_reclaim(orch, ring_id);
+    orch->rings[ring_id].dep_pool.reclaim(orch->scheduler, ring_id, prev_last_alive);
     if (pto2_dep_pool_available(&orch->rings[ring_id].dep_pool) >= needed) return;

     spin_count++;
@@ -334,7 +308,11 @@ void pto2_scope_end(PTO2OrchestratorState* orch) {
 void pto2_submit_mixed_task(
     PTO2OrchestratorState* orch, const MixedKernels& mixed_kernels, const PTOParam& params) {
   // Fast path after fatal error — all subsequent submits are no-ops
-  if (orch->fatal) { return; }
+  if (orch->fatal) {
+    return;
+  }
+
+  PTO2SchedulerState* sched = orch->scheduler;

   // Validate PTOParam construction (errors recorded by add_input/add_output/etc.)
   if (params.has_error) {
@@ -370,14 +348,20 @@
   }

   // === STEP 0: Sync TensorMap validity and optional cleanup ===
-  orch->tensor_map.sync_tensormap();

   // Determine which ring this task belongs to
-  int32_t ring_id = orch->current_ring_id();
+  uint8_t ring_id = orch->current_ring_id();
   auto& task_ring = orch->rings[ring_id].task_ring;

-  // Reclaim dead dep pool entries based on scheduler's last_task_alive
-  pto2_dep_pool_reclaim(orch, ring_id);
+  // Read current last_task_alive from shared memory for this ring
+  int32_t sm_last_task_alive =
+      orch->sm_handle->header->rings[ring_id].fc.last_task_alive.load(std::memory_order_acquire);
+
+  orch->tensor_map.sync_tensormap(ring_id, sm_last_task_alive);
+
+  if (sched) {
+    orch->rings[ring_id].dep_pool.reclaim(sched, ring_id, sm_last_task_alive);
+  }

   CYCLE_COUNT_LAP_RECORD(g_orch_sync_cycle, AicpuPhaseId::ORCH_SYNC, -1);

@@ -427,8 +411,7 @@
   int32_t local_id = task_ring.pto2_task_ring_alloc();
   if (local_id < 0) { orch->fatal = true; return; }
   int32_t slot = task_ring.get_task_slot(local_id);
-  PTO2TaskId mixed_task_id =
-      pto2_make_task_id(static_cast<uint8_t>(ring_id), static_cast<uint32_t>(local_id));
+  PTO2TaskId mixed_task_id = pto2_make_task_id(ring_id, static_cast<uint32_t>(local_id));

   PTO2TaskDescriptor& task = task_ring.get_task_by_slot(slot);
   PTO2TaskPayload* payload = &orch->sm_handle->task_payloads[ring_id][slot];
@@ -444,21 +427,11 @@
   for (int32_t j = 0; j < params.scalar_count; j += 8) {
     __builtin_prefetch(&payload->scalars[j], 1, 3);
   }
-  // Metadata area: tensor_count, scalar_count, fanin_slot_states[] — all in first 3 CLs
   __builtin_prefetch(payload, 1, 3);
   __builtin_prefetch(reinterpret_cast<char*>(payload) + 64, 1, 3);
   __builtin_prefetch(reinterpret_cast<char*>(payload) + 128, 1, 3);

-  // Initialize mixed-task descriptor
-  task.mixed_task_id = mixed_task_id;
-  task.kernel_id[static_cast<int>(PTO2SubtaskSlot::AIC)] = normalized.aic_kernel_id;
-  task.kernel_id[static_cast<int>(PTO2SubtaskSlot::AIV0)] = normalized.aiv0_kernel_id;
-  task.kernel_id[static_cast<int>(PTO2SubtaskSlot::AIV1)] = normalized.aiv1_kernel_id;
-  task.packed_buffer_base = NULL;
-  task.packed_buffer_end = NULL;
-
   // Initialize slot state (scheduler-private)
-  PTO2SchedulerState* sched = orch->scheduler;
   if (sched) {
     auto& rs = sched->ring_sched_states[ring_id];
     PTO2TaskSlotState& slot_state = rs.get_slot_state_by_slot(slot);
@@ -473,7 +446,7 @@
     slot_state.task = &task;
     slot_state.active_mask = active_mask;
     slot_state.subtask_done_mask.store(0, std::memory_order_relaxed);
-    slot_state.ring_id = static_cast<uint8_t>(ring_id);
+    slot_state.ring_id = ring_id;
     scope_tasks_push(orch, &slot_state);
   } else {
     scope_tasks_push(orch, nullptr);
@@ -496,10 +469,12 @@
     }
   }

+  void* local_packed_base = nullptr;
+  void* local_packed_end = nullptr;
   if (total_output_size > 0) {
-    task.packed_buffer_base = orch->pto2_alloc_packed_buffer(total_output_size);
-    if (!task.packed_buffer_base) { orch->fatal = true; return; }
-    task.packed_buffer_end = (char*)task.packed_buffer_base + total_output_size;
+    local_packed_base = orch->pto2_alloc_packed_buffer(total_output_size);
+    if (!local_packed_base) { orch->fatal = true; return; }
+    local_packed_end = (char*)local_packed_base + total_output_size;
   }
   CYCLE_COUNT_LAP_RECORD(g_orch_heap_cycle, AicpuPhaseId::ORCH_HEAP, local_id);
 #if PTO2_ORCH_PROFILING
@@ -559,7 +534,7 @@
       case PTOParamType::OUTPUT: {
         Tensor& tensor = *params.tensors[i];
         if (tensor.buffer.addr == 0) {
-          uint64_t alloc_addr = reinterpret_cast<uint64_t>((char*)task.packed_buffer_base + offset);
+          uint64_t alloc_addr = reinterpret_cast<uint64_t>((char*)local_packed_base + offset);
           tensor.buffer.addr = alloc_addr;
           offset += PTO2_ALIGN_UP(tensor.buffer.size, PTO2_PACKED_OUTPUT_ALIGN);
         }
@@ -582,6 +557,16 @@

   CYCLE_COUNT_LAP_RECORD(g_orch_insert_cycle, AicpuPhaseId::ORCH_INSERT, local_id);

+  // === Batch-write task descriptor to GM (single cache line burst) ===
+  // Deferred from allocation phase to avoid scattered GM writes that get
+  // evicted by TensorMap lookup/insert cache pressure.
+  __builtin_prefetch(&task, 1, 1);
+  task.mixed_task_id = mixed_task_id;
+  task.kernel_id[static_cast<int>(PTO2SubtaskSlot::AIC)] = normalized.aic_kernel_id;
+  task.kernel_id[static_cast<int>(PTO2SubtaskSlot::AIV0)] = normalized.aiv0_kernel_id;
+  task.kernel_id[static_cast<int>(PTO2SubtaskSlot::AIV1)] = normalized.aiv1_kernel_id;
+  task.packed_buffer_base = local_packed_base;
+  task.packed_buffer_end = local_packed_end;

   // Prefetch producer slot_states and cur_slot_state (written at init but likely
   // evicted by lookup/insert/heap). param_copy below provides hide time.
@@ -657,6 +642,8 @@
     PTO2ResourceShape shape = pto2_active_mask_to_shape(active_mask);
     sched->ready_queues[static_cast<int32_t>(shape)].push(&cur_slot_state);
   }
+  // Record dep pool watermark in local slot state (used by tail reclamation)
+  cur_slot_state.dep_pool_mark = orch->rings[ring_id].dep_pool.top;
 #if PTO2_ORCH_PROFILING
   // Per producer: fetch_add(fanout_count) + load(task_state) + store(unlock) = 3 atomics
   // Lock atomics (loads + CAS) are counted inside pto2_fanout_lock
@@ -667,9 +654,6 @@
 #endif
   }

-  // Record dep pool watermark for this task (used by tail reclamation)
-  payload->dep_pool_mark = orch->rings[ring_id].dep_pool.top;
-
   CYCLE_COUNT_LAP_RECORD(g_orch_fanin_cycle, AicpuPhaseId::ORCH_FANIN, local_id);

 #if PTO2_PROFILING
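
The batch-write hunk above is a stage-then-burst pattern: descriptor fields accumulate in locals while the cache-heavy TensorMap and heap work runs, then land in GM back to back on a freshly prefetched line. A standalone sketch of the same idea, with a hypothetical Descriptor type and publish helper that are not from this codebase:

    #include <cstdint>

    // Hypothetical 64-byte descriptor in a slow, contended region (e.g. GM).
    struct alignas(64) Descriptor {
      uint64_t id;
      uint32_t kernel[3];
      void* buf_base;
      void* buf_end;
    };

    // Stage values in locals/registers first, then write the descriptor once,
    // so all stores hit a single prefetched cache line back to back.
    inline void publish(Descriptor& d, uint64_t id, const uint32_t (&k)[3],
                        void* base, void* end) {
      __builtin_prefetch(&d, 1, 1);  // rw=1 (write), moderate temporal locality
      d.id = id;
      d.kernel[0] = k[0];
      d.kernel[1] = k[1];
      d.kernel[2] = k[2];
      d.buf_base = base;
      d.buf_end = end;
    }

If the stores were interleaved with the TensorMap lookups instead, the target line could be evicted between writes and refetched from GM repeatedly, which is exactly what the deferral avoids.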

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.h

Lines changed: 3 additions & 5 deletions
@@ -42,11 +42,9 @@ struct PTO2OrchestratorState {
   // === PER-RING RESOURCES ===
   PTO2RingSet rings[PTO2_MAX_RING_DEPTH];
   PTO2DepListEntry* dep_pool_cur_entries[PTO2_MAX_RING_DEPTH];
-  int32_t dep_pool_last_reclaimed[PTO2_MAX_RING_DEPTH];

   // === TENSOR MAP (Private) ===
   PTO2TensorMap tensor_map;  // Producer lookup
-  int32_t tensormap_last_cleanup[PTO2_MAX_RING_DEPTH];

   // === SCOPE STACK (Private) ===
   // Single contiguous buffer of task IDs, partitioned by scope level.
@@ -88,10 +86,10 @@
    * Get current ring index from scope depth.
    * Maps scope depth to ring_id: min(scope_depth, PTO2_MAX_RING_DEPTH - 1)
    */
-  int32_t current_ring_id() const {
+  uint8_t current_ring_id() const {
     int32_t depth = scope_stack_top;
     if (depth < 0) depth = 0;
-    return depth < PTO2_MAX_RING_DEPTH ? depth : PTO2_MAX_RING_DEPTH - 1;
+    return depth < PTO2_MAX_RING_DEPTH ? static_cast<uint8_t>(depth) : PTO2_MAX_RING_DEPTH - 1;
   }

   /**
@@ -102,7 +100,7 @@ struct PTO2OrchestratorState {
       return NULL;
     }

-    int32_t rid = current_ring_id();
+    uint8_t rid = current_ring_id();
     void* buffer = rings[rid].heap_ring.pto2_heap_ring_alloc(total_size);

 #if PTO2_PROFILING
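
Since current_ring_id now returns uint8_t, the clamp it implements is worth spelling out. A self-contained check of the same mapping; kMaxRingDepth is an assumed stand-in for PTO2_MAX_RING_DEPTH, whose actual value is not shown in this diff:

    #include <cassert>
    #include <cstdint>

    constexpr int32_t kMaxRingDepth = 4;  // assumed; real constant is PTO2_MAX_RING_DEPTH

    // Mirrors current_ring_id(): ring = min(max(depth, 0), kMaxRingDepth - 1).
    uint8_t ring_for_depth(int32_t depth) {
      if (depth < 0) depth = 0;
      return depth < kMaxRingDepth ? static_cast<uint8_t>(depth)
                                   : static_cast<uint8_t>(kMaxRingDepth - 1);
    }

    int main() {
      assert(ring_for_depth(-1) == 0);  // empty scope stack -> ring 0
      assert(ring_for_depth(2) == 2);   // in range -> identity
      assert(ring_for_depth(7) == 3);   // deep nesting clamps to the last ring
      return 0;
    }

The uint8_t return type also removes the static_cast at every assignment site, since the slot state already stored ring_id as a single byte.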

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_ring_buffer.cpp

Lines changed: 12 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include <string.h>
 #include <stdlib.h>  // for exit()
 #include "common/unified_log.h"
+#include "pto_scheduler.h"

 // =============================================================================
 // Heap Ring Buffer Implementation
@@ -49,12 +50,23 @@ void pto2_dep_pool_init(PTO2DepListPool* pool, PTO2DepListEntry* base, int32_t c
   pool->top = 1;   // Start from 1, 0 means NULL/empty
   pool->tail = 1;  // Match initial top (no reclaimable entries yet)
   pool->high_water = 0;
+  pool->last_reclaimed = 0;

   // Initialize entry 0 as NULL marker
   pool->base[0].slot_state = nullptr;
   pool->base[0].next = nullptr;
 }

+void PTO2DepListPool::reclaim(PTO2SchedulerState* sched, uint8_t ring_id, int32_t sm_last_task_alive) {
+  if (sm_last_task_alive >= last_reclaimed + PTO2_DEP_POOL_CLEANUP_INTERVAL && sm_last_task_alive > 0) {
+    int32_t mark = sched->ring_sched_states[ring_id].get_slot_state_by_task_id(sm_last_task_alive - 1).dep_pool_mark;
+    if (mark > 0) {
+      advance_tail(mark);
+    }
+    last_reclaimed = sm_last_task_alive;
+  }
+}
+
 int32_t pto2_dep_pool_used(PTO2DepListPool* pool) {
   return pool->top - pool->tail;
 }
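
The watermark scheme behind reclaim is easiest to see with concrete numbers: dep_pool_mark records pool->top right after a task's dep-list entries were allocated, so every entry below the mark of the newest retired task (sm_last_task_alive - 1) is unreachable. An illustrative trace (entry counts invented, interval constant real):

    // Assuming PTO2_DEP_POOL_CLEANUP_INTERVAL == 64 and last_reclaimed == 0:
    //
    //   submit task 0:  allocates entries 1..3  -> slot_state[0].dep_pool_mark = 4
    //   submit task 1:  allocates entries 4..5  -> slot_state[1].dep_pool_mark = 6
    //   ...
    //   submit task 63: allocates entry 200     -> slot_state[63].dep_pool_mark = 201
    //
    //   scheduler publishes last_task_alive = 64 (tasks 0..63 retired)
    //   reclaim(sched, ring_id, 64):
    //     64 >= 0 + 64                 -> interval gate passes
    //     mark = slot_state[63].dep_pool_mark == 201
    //     advance_tail(201)            -> entries 1..200 become reusable
    //     last_reclaimed = 64
    //
    // Marks only grow with top, so repeated calls never move tail backwards.

Compared with the removed pto2_dep_pool_reclaim, the mark now comes from scheduler-local slot state instead of the GM task payload, and reclamation is batched by the new interval rather than attempted on every advance of last_task_alive.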

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_ring_buffer.h

Lines changed: 13 additions & 0 deletions
@@ -32,6 +32,8 @@
 #include "pto_shared_memory.h"
 #include "common/unified_log.h"

+struct PTO2SchedulerState;  // Forward declaration for dep_pool reclaim
+
 // Set to 1 to enable periodic BLOCKED/Unblocked messages during spin-wait.
 #ifndef PTO2_SPIN_VERBOSE_LOGGING
 #define PTO2_SPIN_VERBOSE_LOGGING 1
@@ -468,10 +470,21 @@ struct PTO2DepListPool {
   int32_t top;         // Linear next-allocation counter (starts from 1)
   int32_t tail;        // Linear first-alive counter (entries before this are dead)
   int32_t high_water;  // Peak concurrent usage (top - tail)
+  int32_t last_reclaimed{0};  // last_task_alive at last successful reclamation

   // Error code pointer for fatal error reporting (→ sm_header->orch_error_code)
   std::atomic<int32_t>* error_code_ptr = nullptr;

+  /**
+   * Reclaim dead entries based on scheduler's slot state dep_pool_mark.
+   * Safe to call multiple times — only advances tail forward.
+   *
+   * @param sched Scheduler state (for reading slot dep_pool_mark)
+   * @param ring_id Ring layer index
+   * @param sm_last_task_alive Current last_task_alive from shared memory
+   */
+  void reclaim(PTO2SchedulerState* sched, uint8_t ring_id, int32_t sm_last_task_alive);
+
   /**
    * Allocate a single entry from the pool (single-thread per pool instance)
    *

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h

Lines changed: 3 additions & 1 deletion
@@ -103,6 +103,7 @@

 // TensorMap cleanup interval
 #define PTO2_TENSORMAP_CLEANUP_INTERVAL 64  // Cleanup every N retired tasks
+#define PTO2_DEP_POOL_CLEANUP_INTERVAL 64   // Cleanup every N retired tasks

 // =============================================================================
 // Multi-Ring task_id Encoding
@@ -366,7 +367,7 @@ struct PTO2TaskPayload {
   int32_t tensor_count{0};
   int32_t scalar_count{0};
   int32_t fanin_actual_count{0};  // Actual fanin count (without the +1 redundance)
-  int32_t dep_pool_mark{0};  // Dep pool top after this task's submission (for reclamation)
+  int32_t _reserved{0};  // Reserved (dep_pool_mark moved to SlotState for local access)
   PTO2TaskSlotState* fanin_slot_states[PTO2_MAX_INPUTS];  // Producer slot states (used by on_task_release)
   // === Cache lines 3-34 (2048B) — tensors (alignas(64) forces alignment) ===
   Tensor tensors[PTO2_MAX_TENSOR_PARAMS];
@@ -425,6 +426,7 @@ struct alignas(64) PTO2TaskSlotState {
   uint8_t active_mask;                     // Bitmask of active subtask slots (set once)
   std::atomic<uint8_t> subtask_done_mask;  // Each subtask sets its done bit on completion
   uint8_t ring_id;                         // Ring layer this task belongs to (for per-ring reclamation)
+  int32_t dep_pool_mark{0};                // Dep pool top after this task's submission (orchestrator-only, local memory)
 };

 static_assert(sizeof(PTO2TaskSlotState) == 64);
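
Note the static_assert kept at the end of this hunk: PTO2TaskSlotState is alignas(64) and pinned to exactly one cache line, so the relocated dep_pool_mark must fit into previously unused padding rather than growing the struct. A compilable illustration of that guard (field list abbreviated; the real struct has more members than the hunk shows):

    #include <atomic>
    #include <cstdint>

    // Abbreviated stand-in for PTO2TaskSlotState; earlier members are elided.
    struct alignas(64) SlotStateSketch {
      void* task;                              // placeholder for preceding fields
      uint8_t active_mask;
      std::atomic<uint8_t> subtask_done_mask;
      uint8_t ring_id;
      int32_t dep_pool_mark{0};                // new field, absorbed by padding
    };

    // Compile-time guard: a field that overflows the cache line fails here.
    static_assert(sizeof(SlotStateSketch) == 64, "slot state must stay one cache line");

Keeping the mark here rather than in the GM-resident PTO2TaskPayload means the orchestrator reads it from local memory during reclamation instead of pulling in a payload cache line.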

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.cpp

Lines changed: 8 additions & 21 deletions
@@ -115,6 +115,7 @@ bool PTO2TensorMap::init(int32_t new_num_buckets, int32_t new_pool_size, const i

   for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
     last_task_alives[r] = 0;
+    last_cleanup[r] = 0;
   }

   return true;
@@ -220,27 +221,13 @@ int32_t PTO2TensorMap::valid_count() {
   return count;
 }

-void PTO2TensorMap::sync_tensormap() {
-  constexpr int MIN_FREE_NUM = 1024;
-  always_assert(orch != nullptr);
-  while(true) {
-    bool did_cleanup = false;
-    for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
-      // Read current last_task_alive from shared memory for this ring
-      int32_t new_last_task_alive =
-          orch->sm_handle->header->rings[r].fc.last_task_alive.load(std::memory_order_acquire);
-      sync_validity(r, new_last_task_alive);
-      // Only attempt cleanup when last_task_alive has actually advanced;
-      // otherwise cleanup_retired would empty-loop and we'd spin forever.
-      if (new_last_task_alive <= orch->tensormap_last_cleanup[r]) continue;
-      if ((pool_size - next_entry_idx + free_num < MIN_FREE_NUM) ||
-          new_last_task_alive - orch->tensormap_last_cleanup[r] >= PTO2_TENSORMAP_CLEANUP_INTERVAL) {
-        cleanup_retired(r, orch->tensormap_last_cleanup[r], new_last_task_alive);
-        orch->tensormap_last_cleanup[r] = new_last_task_alive;
-        did_cleanup = true;
-      }
-    }
-    if (!did_cleanup) break;
+void PTO2TensorMap::sync_tensormap(uint8_t ring_id, int32_t sm_last_task_alive) {
+  sync_validity(ring_id, sm_last_task_alive);
+  // Only attempt cleanup when last_task_alive has actually advanced;
+  // otherwise cleanup_retired would empty-loop and we'd spin forever.
+  if (sm_last_task_alive - last_cleanup[ring_id] >= PTO2_TENSORMAP_CLEANUP_INTERVAL) {
+    cleanup_retired(ring_id, last_cleanup[ring_id], sm_last_task_alive);
+    last_cleanup[ring_id] = sm_last_task_alive;
   }
 }
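
With the MIN_FREE_NUM branch gone, the cleanup cadence is now purely interval driven: cleanup_retired runs at most once per PTO2_TENSORMAP_CLEANUP_INTERVAL retired tasks per ring. A trace of the gate using the interval's defined value of 64 (task counts invented):

    // last_cleanup[r] starts at 0:
    //
    //   submit sees sm_last_task_alive = 10   ->  10 - 0  < 64   sync_validity only
    //   submit sees sm_last_task_alive = 63   ->  63 - 0  < 64   sync_validity only
    //   submit sees sm_last_task_alive = 70   ->  70 - 0 >= 64   cleanup_retired(r, 0, 70),
    //                                                            last_cleanup[r] = 70
    //   submit sees sm_last_task_alive = 100  -> 100 - 70 < 64   sync_validity only

One consequence worth noting: the old code could also force a cleanup when the entry pool ran low; after this change, pressure relief relies on the fixed interval alone.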

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h

Lines changed: 4 additions & 1 deletion
@@ -198,6 +198,9 @@ struct PTO2TensorMap {
   // Per-ring validity threshold (for lazy invalidation)
   int32_t last_task_alives[PTO2_MAX_RING_DEPTH];  // Cached from shared memory per ring

+  // Per-ring cleanup progress (for periodic cleanup_retired)
+  int32_t last_cleanup[PTO2_MAX_RING_DEPTH]{};
+
   PTO2OrchestratorState* orch{nullptr};

   // new_entry currently only allocates memory; it does not yet assign attributes
@@ -500,7 +503,7 @@
    * Called periodically to refresh the lazy invalidation threshold.
    * Also triggers cleanup if threshold has advanced significantly.
    */
-  void sync_tensormap();
+  void sync_tensormap(uint8_t ring_id, int32_t sm_last_task_alive);
 };

 #if PTO2_TENSORMAP_PROFILING
