Skip to content

Commit 77a81aa

Browse files
zhusy54 authored
feat(runtime): Replace PTOParam assert with orchestration error handling (#306)
PTOParam validation previously used assert(), which is stripped in release builds. Replace all assert() calls with an error-flag mechanism that integrates with the orchestration error path (LOG_ERROR + orch_error_code + fatal + emergency_shutdown), ensuring validation works in all builds.

- Add has_error/error_msg fields and set_error() helper to PTOParam
- Merge ordering and capacity checks into check_add_tensor_valid()
- Convert all assert() in add_input/add_output/add_inout/add_scalar to set_error() with descriptive messages
- Add PTO2_ERROR_INVALID_PARAM (5) error code
- Validate params.has_error at pto2_submit_mixed_task entry
- Update error handling documentation

Co-authored-by: zhusy54 <zhusiyu1@hisilicon.com>
1 parent 15e6034 commit 77a81aa

3 files changed

Lines changed: 65 additions & 21 deletions

File tree

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -336,6 +336,21 @@ void pto2_submit_mixed_task(
336336
// Fast path after fatal error — all subsequent submits are no-ops
337337
if (orch->fatal) { return; }
338338

339+
// Validate PTOParam construction (errors recorded by add_input/add_output/etc.)
340+
if (params.has_error) {
341+
LOG_ERROR("========================================");
342+
LOG_ERROR("FATAL: Invalid PTOParam Detected!");
343+
LOG_ERROR("========================================");
344+
LOG_ERROR("Error: %s", params.error_msg ? params.error_msg : "(unknown)");
345+
LOG_ERROR(" tensor_count: %d, scalar_count: %d", params.tensor_count, params.scalar_count);
346+
LOG_ERROR("This is a bug in the orchestration code.");
347+
LOG_ERROR("========================================");
348+
orch->sm_handle->header->orch_error_code.store(
349+
PTO2_ERROR_INVALID_PARAM, std::memory_order_release);
350+
orch->fatal = true;
351+
return;
352+
}
353+
339354
CYCLE_COUNT_START();
340355

341356
// === Validate submit inputs ===

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@
6464
#define PTO2_ERROR_HEAP_RING_DEADLOCK 2
6565
#define PTO2_ERROR_FLOW_CONTROL_DEADLOCK 3
6666
#define PTO2_ERROR_DEP_POOL_OVERFLOW 4
67+
#define PTO2_ERROR_INVALID_PARAM 5 // PTOParam construction error (invalid params)
6768

6869
// Scheduler errors (100+): detected in scheduler threads
6970
#define PTO2_ERROR_SCHEDULER_TIMEOUT 100

src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h

Lines changed: 49 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
#define ORCH_BUILD_GRAPH_PTO_TYPES_H
1616

1717
#include <stdint.h>
18-
#include <assert.h>
1918
#include <string.h>
2019

2120
#if defined(__aarch64__)
@@ -68,56 +67,78 @@ struct PTOParam {
6867
uint64_t scalars[PTO2_MAX_SCALAR_PARAMS];
6968
int32_t tensor_count{0};
7069
int32_t scalar_count{0};
70+
bool has_error{false};
71+
const char* error_msg{nullptr};
7172

7273
void reset() {
7374
tensor_count = 0;
7475
scalar_count = 0;
76+
has_error = false;
77+
error_msg = nullptr;
7578
}
7679

77-
bool check_add_tensor_valid() const {
78-
assert(scalar_count == 0 && "scalar must add after all tensor added");
80+
void set_error(const char* msg) {
81+
if (!has_error) {
82+
has_error = true;
83+
error_msg = msg;
84+
}
85+
}
86+
87+
bool check_add_tensor_valid() {
88+
if (scalar_count != 0) {
89+
set_error("add_input/add_output/add_inout called after add_scalar: "
90+
"all tensors must be added before any scalars");
91+
return false;
92+
}
93+
if (tensor_count >= PTO2_MAX_TENSOR_PARAMS) {
94+
set_error("Too many tensor params (exceeds PTO2_MAX_TENSOR_PARAMS=32)");
95+
return false;
96+
}
7997
return true;
8098
}
8199

82100
void add_input(Tensor& t) {
83-
if (!check_add_tensor_valid()) {
101+
if (!check_add_tensor_valid()) { return; }
102+
if (t.buffer.addr == 0) {
103+
set_error("INPUT tensor must have a non-NULL buffer address");
84104
return;
85105
}
86-
assert(t.buffer.addr != 0 && "INPUT param must have a non-NULL buffer address");
87-
assert(tensor_count < PTO2_MAX_TENSOR_PARAMS && "Too many tensor params");
88106
tensors[tensor_count] = &t;
89107
tensor_types[tensor_count] = PTOParamType::INPUT;
90108
tensor_count++;
91109
}
92110

93111
void add_output(Tensor& t) {
94-
if (!check_add_tensor_valid()) {
95-
return;
96-
}
97-
assert(tensor_count < PTO2_MAX_TENSOR_PARAMS && "Too many tensor params");
112+
if (!check_add_tensor_valid()) { return; }
98113
tensors[tensor_count] = &t;
99114
tensor_types[tensor_count] = PTOParamType::OUTPUT;
100115
tensor_count++;
101116
}
102117

103118
void add_inout(Tensor& t) {
104-
if (!check_add_tensor_valid()) {
119+
if (!check_add_tensor_valid()) { return; }
120+
if (t.buffer.addr == 0) {
121+
set_error("INOUT tensor must have a non-NULL buffer address");
105122
return;
106123
}
107-
assert(t.buffer.addr != 0 && "INOUT param must have a non-NULL buffer address");
108-
assert(tensor_count < PTO2_MAX_TENSOR_PARAMS && "Too many tensor params");
109124
tensors[tensor_count] = &t;
110125
tensor_types[tensor_count] = PTOParamType::INOUT;
111126
tensor_count++;
112127
}
113128

114129
void add_scalar(uint64_t v) {
115-
assert(scalar_count < PTO2_MAX_SCALAR_PARAMS && "Too many scalar params");
130+
if (scalar_count >= PTO2_MAX_SCALAR_PARAMS) {
131+
set_error("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)");
132+
return;
133+
}
116134
scalars[scalar_count++] = v;
117135
}
118136

119137
void add_scalars(const uint64_t* values, int count) {
120-
assert(scalar_count + count <= PTO2_MAX_SCALAR_PARAMS && "Too many scalar params");
138+
if (scalar_count + count > PTO2_MAX_SCALAR_PARAMS) {
139+
set_error("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)");
140+
return;
141+
}
121142
memcpy(&scalars[scalar_count], values, count * sizeof(uint64_t));
122143
scalar_count += count;
123144
}
@@ -129,7 +150,10 @@ struct PTOParam {
129150
* Uses NEON to process 4 elements per iteration on aarch64.
130151
*/
131152
void add_scalars_i32(const int32_t* values, int count) {
132-
assert(scalar_count + count <= PTO2_MAX_SCALAR_PARAMS && "Too many scalar params");
153+
if (scalar_count + count > PTO2_MAX_SCALAR_PARAMS) {
154+
set_error("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)");
155+
return;
156+
}
133157
uint64_t* dst = &scalars[scalar_count];
134158
#if defined(__aarch64__)
135159
int i = 0;
@@ -154,13 +178,17 @@ struct PTOParam {
154178
/**
155179
* Copy scalars from another PTOParam's scalar array.
156180
* Useful when multiple tasks share the same scalar data (e.g., block indices).
157-
* Rounds up to cache line boundary — both arrays are 1024B so no overrun.
158181
*/
159182
void copy_scalars_from(const PTOParam& src, int src_offset, int count) {
160-
assert(src_offset + count <= src.scalar_count && "Source scalar range out of bounds");
161-
assert(scalar_count + count <= PTO2_MAX_SCALAR_PARAMS && "Too many scalar params");
162-
size_t bytes = (count * sizeof(uint64_t) + 63) & ~size_t(63);
163-
memcpy(&scalars[scalar_count], &src.scalars[src_offset], bytes);
183+
if (src_offset + count > src.scalar_count) {
184+
set_error("Source scalar range out of bounds in copy_scalars_from");
185+
return;
186+
}
187+
if (scalar_count + count > PTO2_MAX_SCALAR_PARAMS) {
188+
set_error("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)");
189+
return;
190+
}
191+
memcpy(&scalars[scalar_count], &src.scalars[src_offset], count * sizeof(uint64_t));
164192
scalar_count += count;
165193
}
166194
};

0 commit comments

Comments
 (0)