1515#define ORCH_BUILD_GRAPH_PTO_TYPES_H
1616
1717#include <stdint.h>
18- #include <assert.h>
1918#include <string.h>
2019
2120#if defined(__aarch64__ )
@@ -68,56 +67,78 @@ struct PTOParam {
6867 uint64_t scalars [PTO2_MAX_SCALAR_PARAMS ];
6968 int32_t tensor_count {0 };
7069 int32_t scalar_count {0 };
70+ bool has_error {false};
71+ const char * error_msg {nullptr };
7172
7273 void reset () {
7374 tensor_count = 0 ;
7475 scalar_count = 0 ;
76+ has_error = false;
77+ error_msg = nullptr ;
7578 }
7679
77- bool check_add_tensor_valid () const {
78- assert (scalar_count == 0 && "scalar must add after all tensor added" );
80+ void set_error (const char * msg ) {
81+ if (!has_error ) {
82+ has_error = true;
83+ error_msg = msg ;
84+ }
85+ }
86+
87+ bool check_add_tensor_valid () {
88+ if (scalar_count != 0 ) {
89+ set_error ("add_input/add_output/add_inout called after add_scalar: "
90+ "all tensors must be added before any scalars" );
91+ return false;
92+ }
93+ if (tensor_count >= PTO2_MAX_TENSOR_PARAMS ) {
94+ set_error ("Too many tensor params (exceeds PTO2_MAX_TENSOR_PARAMS=32)" );
95+ return false;
96+ }
7997 return true;
8098 }
8199
82100 void add_input (Tensor & t ) {
83- if (!check_add_tensor_valid ()) {
101+ if (!check_add_tensor_valid ()) { return ; }
102+ if (t .buffer .addr == 0 ) {
103+ set_error ("INPUT tensor must have a non-NULL buffer address" );
84104 return ;
85105 }
86- assert (t .buffer .addr != 0 && "INPUT param must have a non-NULL buffer address" );
87- assert (tensor_count < PTO2_MAX_TENSOR_PARAMS && "Too many tensor params" );
88106 tensors [tensor_count ] = & t ;
89107 tensor_types [tensor_count ] = PTOParamType ::INPUT ;
90108 tensor_count ++ ;
91109 }
92110
93111 void add_output (Tensor & t ) {
94- if (!check_add_tensor_valid ()) {
95- return ;
96- }
97- assert (tensor_count < PTO2_MAX_TENSOR_PARAMS && "Too many tensor params" );
112+ if (!check_add_tensor_valid ()) { return ; }
98113 tensors [tensor_count ] = & t ;
99114 tensor_types [tensor_count ] = PTOParamType ::OUTPUT ;
100115 tensor_count ++ ;
101116 }
102117
103118 void add_inout (Tensor & t ) {
104- if (!check_add_tensor_valid ()) {
119+ if (!check_add_tensor_valid ()) { return ; }
120+ if (t .buffer .addr == 0 ) {
121+ set_error ("INOUT tensor must have a non-NULL buffer address" );
105122 return ;
106123 }
107- assert (t .buffer .addr != 0 && "INOUT param must have a non-NULL buffer address" );
108- assert (tensor_count < PTO2_MAX_TENSOR_PARAMS && "Too many tensor params" );
109124 tensors [tensor_count ] = & t ;
110125 tensor_types [tensor_count ] = PTOParamType ::INOUT ;
111126 tensor_count ++ ;
112127 }
113128
114129 void add_scalar (uint64_t v ) {
115- assert (scalar_count < PTO2_MAX_SCALAR_PARAMS && "Too many scalar params" );
130+ if (scalar_count >= PTO2_MAX_SCALAR_PARAMS ) {
131+ set_error ("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)" );
132+ return ;
133+ }
116134 scalars [scalar_count ++ ] = v ;
117135 }
118136
119137 void add_scalars (const uint64_t * values , int count ) {
120- assert (scalar_count + count <= PTO2_MAX_SCALAR_PARAMS && "Too many scalar params" );
138+ if (scalar_count + count > PTO2_MAX_SCALAR_PARAMS ) {
139+ set_error ("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)" );
140+ return ;
141+ }
121142 memcpy (& scalars [scalar_count ], values , count * sizeof (uint64_t ));
122143 scalar_count += count ;
123144 }
@@ -129,7 +150,10 @@ struct PTOParam {
129150 * Uses NEON to process 4 elements per iteration on aarch64.
130151 */
131152 void add_scalars_i32 (const int32_t * values , int count ) {
132- assert (scalar_count + count <= PTO2_MAX_SCALAR_PARAMS && "Too many scalar params" );
153+ if (scalar_count + count > PTO2_MAX_SCALAR_PARAMS ) {
154+ set_error ("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)" );
155+ return ;
156+ }
133157 uint64_t * dst = & scalars [scalar_count ];
134158#if defined(__aarch64__ )
135159 int i = 0 ;
@@ -154,13 +178,17 @@ struct PTOParam {
154178 /**
155179 * Copy scalars from another PTOParam's scalar array.
156180 * Useful when multiple tasks share the same scalar data (e.g., block indices).
157- * Rounds up to cache line boundary — both arrays are 1024B so no overrun.
158181 */
159182 void copy_scalars_from (const PTOParam & src , int src_offset , int count ) {
160- assert (src_offset + count <= src .scalar_count && "Source scalar range out of bounds" );
161- assert (scalar_count + count <= PTO2_MAX_SCALAR_PARAMS && "Too many scalar params" );
162- size_t bytes = (count * sizeof (uint64_t ) + 63 ) & ~size_t (63 );
163- memcpy (& scalars [scalar_count ], & src .scalars [src_offset ], bytes );
183+ if (src_offset + count > src .scalar_count ) {
184+ set_error ("Source scalar range out of bounds in copy_scalars_from" );
185+ return ;
186+ }
187+ if (scalar_count + count > PTO2_MAX_SCALAR_PARAMS ) {
188+ set_error ("Too many scalar params (exceeds PTO2_MAX_SCALAR_PARAMS=128)" );
189+ return ;
190+ }
191+ memcpy (& scalars [scalar_count ], & src .scalars [src_offset ], count * sizeof (uint64_t ));
164192 scalar_count += count ;
165193 }
166194};
0 commit comments