Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ typedef struct PTO2RuntimeOps {
PTO2Runtime *rt, const Tensor &tensor, uint32_t ndims, const uint32_t indices[], uint64_t value
);
TaskOutputTensors (*alloc_tensors)(PTO2Runtime *rt, const Arg &args);

// Parallel for iteration isolation
void (*parallel_for_begin)(PTO2Runtime *rt);
void (*parallel_scope_begin)(PTO2Runtime *rt);
void (*parallel_scope_end)(PTO2Runtime *rt);
void (*parallel_for_end)(PTO2Runtime *rt);
} PTO2RuntimeOps;

/**
Expand Down Expand Up @@ -255,6 +261,38 @@ static inline void pto2_rt_scope_end() {
rt->ops->scope_end(rt);
}

// Enter a parallel-for region on the current runtime.
// Silently does nothing once the runtime has gone fatal.
static inline void pto2_rt_parallel_for_begin() {
    PTO2Runtime *rt = pto2_current_runtime();
    if (!rt->ops->is_fatal(rt)) {
        rt->ops->parallel_for_begin(rt);
    }
}

// Open one parallel-for iteration scope on the current runtime.
// Silently does nothing once the runtime has gone fatal.
static inline void pto2_rt_parallel_scope_begin() {
    PTO2Runtime *rt = pto2_current_runtime();
    if (!rt->ops->is_fatal(rt)) {
        rt->ops->parallel_scope_begin(rt);
    }
}

// Close one parallel-for iteration scope on the current runtime.
// Silently does nothing once the runtime has gone fatal.
static inline void pto2_rt_parallel_scope_end() {
    PTO2Runtime *rt = pto2_current_runtime();
    if (!rt->ops->is_fatal(rt)) {
        rt->ops->parallel_scope_end(rt);
    }
}

// Leave a parallel-for region on the current runtime.
// Silently does nothing once the runtime has gone fatal.
static inline void pto2_rt_parallel_for_end() {
    PTO2Runtime *rt = pto2_current_runtime();
    if (!rt->ops->is_fatal(rt)) {
        rt->ops->parallel_for_end(rt);
    }
}

static inline void pto2_rt_orchestration_done() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->orchestration_done(rt);
Expand Down Expand Up @@ -381,6 +419,68 @@ class PTO2ScopeGuard {
*/
#define PTO2_SCOPE() if (PTO2_SCOPE_GUARD(); true)

/**
 * RAII guard for a parallel for region: calls parallel_for_begin on
 * construction and parallel_for_end on destruction. Both calls are skipped
 * when the runtime reports fatal at that moment (mirrors the
 * pto2_rt_parallel_for_* free functions above).
 *
 * Non-copyable and non-movable: duplicating the guard would pair a single
 * parallel_for_begin with multiple parallel_for_end calls.
 */
class PTO2ParallelForGuard {
public: // NOLINT(whitespace/indent)
    PTO2ParallelForGuard() :
        rt_(pto2_current_runtime()) {
        if (!rt_->ops->is_fatal(rt_)) {
            rt_->ops->parallel_for_begin(rt_);
        }
    }
    ~PTO2ParallelForGuard() {
        if (!rt_->ops->is_fatal(rt_)) {
            rt_->ops->parallel_for_end(rt_);
        }
    }
    PTO2ParallelForGuard(const PTO2ParallelForGuard &) = delete;
    PTO2ParallelForGuard &operator=(const PTO2ParallelForGuard &) = delete;

private: // NOLINT(whitespace/indent)
    PTO2Runtime *rt_;   // Borrowed; owned by the runtime, not the guard.
};

/**
 * RAII guard for one parallel-for iteration: calls parallel_scope_begin on
 * construction and parallel_scope_end on destruction. Both calls are skipped
 * when the runtime reports fatal at that moment (mirrors the
 * pto2_rt_parallel_scope_* free functions above).
 *
 * Non-copyable and non-movable: duplicating the guard would pair a single
 * parallel_scope_begin with multiple parallel_scope_end calls.
 */
class PTO2ParallelScopeGuard {
public: // NOLINT(whitespace/indent)
    PTO2ParallelScopeGuard() :
        rt_(pto2_current_runtime()) {
        if (!rt_->ops->is_fatal(rt_)) {
            rt_->ops->parallel_scope_begin(rt_);
        }
    }
    ~PTO2ParallelScopeGuard() {
        if (!rt_->ops->is_fatal(rt_)) {
            rt_->ops->parallel_scope_end(rt_);
        }
    }
    PTO2ParallelScopeGuard(const PTO2ParallelScopeGuard &) = delete;
    PTO2ParallelScopeGuard &operator=(const PTO2ParallelScopeGuard &) = delete;

private: // NOLINT(whitespace/indent)
    PTO2Runtime *rt_;   // Borrowed; owned by the runtime, not the guard.
};

// Declare a uniquely named guard object. [[maybe_unused]] silences
// unused-variable warnings; __COUNTER__ lets several guards coexist in
// one scope without name collisions.
#define PTO2_PARALLEL_FOR_GUARD() [[maybe_unused]] PTO2ParallelForGuard _PTO2_CONCATENATE(pf_guard_, __COUNTER__)
#define PTO2_PARALLEL_SCOPE_GUARD() [[maybe_unused]] PTO2ParallelScopeGuard _PTO2_CONCATENATE(ps_guard_, __COUNTER__)

/**
 * Parallel for loop with automatic iteration isolation:
 *   PTO2_PARALLEL_FOR(i, N) {
 *     submit_iter_tasks(i);
 *   }
 *
 * Expands to an if-with-initializer (C++17) that owns the region guard, a
 * plain counting loop, and a per-iteration scope guard, so begin/end hooks
 * fire via RAII even when the body exits early. Notes: the loop variable is
 * declared `int`, and `(count)` is re-evaluated on every iteration — keep
 * it side-effect free.
 */
#define PTO2_PARALLEL_FOR(var, count) \
if (PTO2_PARALLEL_FOR_GUARD(); true) \
for (int var = 0; var < (count); ++var) \
if (PTO2_PARALLEL_SCOPE_GUARD(); true)

/**
 * Single parallel scope (for manual loop control):
 *   PTO2_PARALLEL_SCOPE() { submit_iter_tasks(); }
 * The guard lives for the attached statement only.
 */
#define PTO2_PARALLEL_SCOPE() if (PTO2_PARALLEL_SCOPE_GUARD(); true)

// =============================================================================
// Orchestration Config
// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,49 @@ void pto2_scope_end(PTO2OrchestratorState *orch) {
#endif
}

// =============================================================================
// Parallel For Iteration Isolation
// =============================================================================

void pto2_parallel_for_begin(PTO2OrchestratorState *orch) {
    if (orch->fatal) {
        return;
    }
    // Intentional no-op beyond the fatal guard: the per-iteration filter is
    // installed by pto2_parallel_scope_begin. This hook exists as the natural
    // place for future diagnostics/assertions about region nesting.
}

void pto2_parallel_scope_begin(PTO2OrchestratorState *orch) {
    if (orch->fatal) {
        return;
    }
    const uint8_t ring_before = orch->current_ring_id();
    pto2_scope_begin(orch);
    const uint8_t ring_after = orch->current_ring_id();
    if (ring_after == ring_before) {
        // Ring-depth overflow: scope_begin was clamped onto the same ring, so
        // there is no fresh ring to filter. Degrade to plain scope semantics.
        return;
    }
    // A new ring was allocated for this iteration. Record the first local
    // task id on it so lookups can skip entries from earlier iterations.
    orch->tensor_map.iter_start_local_ids[ring_after] =
        orch->rings[ring_after].task_allocator.next_local_id();
}

void pto2_parallel_scope_end(PTO2OrchestratorState *orch) {
    // Deliberately leaves iter_start_local_ids alone: the next iteration's
    // parallel_scope_begin overwrites it, and parallel_for_end resets it.
    pto2_scope_end(orch);
}

void pto2_parallel_for_end(PTO2OrchestratorState *orch) {
    if (orch->fatal) {
        return;
    }
    // Drop the iteration filter on the current ring so later lookups again
    // see every entry. In the depth-overflow case the slot already holds -1,
    // making this write idempotent.
    const uint8_t ring = orch->current_ring_id();
    orch->tensor_map.iter_start_local_ids[ring] = -1;
}

// =============================================================================
// Task Submission
// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,34 @@ void pto2_scope_begin(PTO2OrchestratorState *orch);
*/
void pto2_scope_end(PTO2OrchestratorState *orch);

// =============================================================================
// Parallel For Iteration Isolation
// =============================================================================

/**
 * Begin a parallel for region.
 * Currently a no-op marker (beyond the fatal check); reserved for future
 * diagnostics/assertions.
 */
void pto2_parallel_for_begin(PTO2OrchestratorState *orch);

/**
 * Begin a parallel scope (one iteration of a parallel for).
 * Combines scope_begin + setting the iteration filter boundary on the
 * newly allocated ring. No-op when orch->fatal is set; on ring-depth
 * overflow (no new ring) it degrades to a plain scope.
 */
void pto2_parallel_scope_begin(PTO2OrchestratorState *orch);

/**
 * End a parallel scope (one iteration of a parallel for).
 * Calls scope_end; does NOT clear the filter (next iteration overwrites it,
 * and pto2_parallel_for_end clears it).
 */
void pto2_parallel_scope_end(PTO2OrchestratorState *orch);

/**
 * End a parallel for region.
 * Clears the iteration filter (resets it to -1) so subsequent lookups see
 * all entries on the current ring. No-op when orch->fatal is set.
 */
void pto2_parallel_for_end(PTO2OrchestratorState *orch);

// =============================================================================
// Task Submission
// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class PTO2TaskAllocator {

uint64_t heap_top() const { return heap_top_; }
uint64_t heap_capacity() const { return heap_size_; }
int32_t next_local_id() const { return local_task_id_; }

private:
// --- Task Ring ---
Expand Down
12 changes: 12 additions & 0 deletions src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ void pto2_rt_scope_begin(PTO2Runtime *rt) { pto2_scope_begin(&rt->orchestrator);

void pto2_rt_scope_end(PTO2Runtime *rt) { pto2_scope_end(&rt->orchestrator); }

// Vtable adapters: bind the orchestrator-level parallel-for hooks to the
// PTO2RuntimeOps signature (PTO2Runtime * -> its embedded orchestrator state).
static void pto2_rt_parallel_for_begin(PTO2Runtime *rt) { pto2_parallel_for_begin(&rt->orchestrator); }

static void pto2_rt_parallel_scope_begin(PTO2Runtime *rt) { pto2_parallel_scope_begin(&rt->orchestrator); }

static void pto2_rt_parallel_scope_end(PTO2Runtime *rt) { pto2_parallel_scope_end(&rt->orchestrator); }

static void pto2_rt_parallel_for_end(PTO2Runtime *rt) { pto2_parallel_for_end(&rt->orchestrator); }

void pto2_rt_orchestration_done(PTO2Runtime *rt) { pto2_orchestrator_done(&rt->orchestrator); }

static bool is_fatal_impl(PTO2Runtime *rt) { return rt->orchestrator.fatal; }
Expand Down Expand Up @@ -206,6 +214,10 @@ static const PTO2RuntimeOps s_runtime_ops = {
.get_tensor_data = pto2_get_tensor_data,
.set_tensor_data = pto2_set_tensor_data,
.alloc_tensors = alloc_tensors_impl,
.parallel_for_begin = pto2_rt_parallel_for_begin,
.parallel_scope_begin = pto2_rt_parallel_scope_begin,
.parallel_scope_end = pto2_rt_parallel_scope_end,
.parallel_for_end = pto2_rt_parallel_for_end,
};

// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ struct PTO2RuntimeOps {
PTO2Runtime *rt, const Tensor &tensor, uint32_t ndims, const uint32_t indices[], uint64_t value
);
TaskOutputTensors (*alloc_tensors)(PTO2Runtime *rt, const Arg &args);

// Parallel for iteration isolation
void (*parallel_for_begin)(PTO2Runtime *rt);
void (*parallel_scope_begin)(PTO2Runtime *rt);
void (*parallel_scope_end)(PTO2Runtime *rt);
void (*parallel_for_end)(PTO2Runtime *rt);
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ bool PTO2TensorMap::init(
for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
last_task_alives[r] = 0;
last_cleanup[r] = 0;
iter_start_local_ids[r] = -1;
}

return true;
Expand Down
16 changes: 16 additions & 0 deletions src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ struct PTO2TensorMap {
// Per-ring validity threshold (for lazy invalidation)
int32_t last_task_alives[PTO2_MAX_RING_DEPTH]; // Cached from shared memory per ring

// Per-ring iteration isolation for parallel for.
// -1 = normal mode (no filtering); >= 0 = parallel for mode, entries with
// local_id < iter_start on the same ring are filtered out during lookup.
int32_t iter_start_local_ids[PTO2_MAX_RING_DEPTH];

// Per-ring cleanup progress (for periodic cleanup_retired)
int32_t last_cleanup[PTO2_MAX_RING_DEPTH]{};

Expand Down Expand Up @@ -328,6 +333,17 @@ struct PTO2TensorMap {
continue;
}

// Parallel for iteration isolation: skip entries from prior iterations
// on the same ring. Outer-ring entries have iter_start_local_ids == -1
// and pass through unconditionally.
{
int32_t iter_start = iter_start_local_ids[cur_entry->producer_task_id.ring()];
if (iter_start >= 0 && static_cast<int32_t>(cur_entry->producer_task_id.local()) < iter_start) {
cur_entry = next_entry;
continue;
}
}

// Entry is valid - check if regions OVERLAP (not just exact match)
// Since we hash only by base_ptr, all entries in this bucket have
// potential to overlap. We must check actual byte-range overlap.
Expand Down
Loading
Loading