Skip to content

Commit 401fadc

Browse files
committed
fixed out of bounds access on test
1 parent 71e6021 commit 401fadc

7 files changed

Lines changed: 10 additions & 10 deletions

File tree

include/affine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ expr *new_linear(expr *u, const CSR_Matrix *A);
99
expr *new_add(expr *left, expr *right);
1010
expr *new_sum(expr *child, int axis);
1111

12-
expr *new_constant(int d1, int d2, const double *values);
12+
expr *new_constant(int d1, int d2, int n_vars, const double *values);
1313
expr *new_variable(int d1, int d2, int var_id, int n_vars);
1414

1515
#endif /* AFFINE_H */

src/affine/constant.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ static bool is_affine(expr *node)
1414
return true; /* constant is affine */
1515
}
1616

17-
expr *new_constant(int d1, int d2, const double *values)
17+
expr *new_constant(int d1, int d2, int n_vars, const double *values)
1818
{
19-
expr *node = new_expr(d1, d2, node->n_vars);
19+
expr *node = new_expr(d1, d2, n_vars);
2020
if (!node) return NULL;
2121

2222
/* Copy constant values */

tests/forward_pass/affine/test_add.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const char *test_addition()
1212
double u[2] = {3.0, 4.0};
1313
double c[2] = {1.0, 2.0};
1414
expr *var = new_variable(2, 1, 0, 2);
15-
expr *const_node = new_constant(2, 1, c);
15+
expr *const_node = new_constant(2, 1, 0, c);
1616
expr *sum = new_add(var, const_node);
1717
sum->forward(sum, u);
1818
double expected[2] = {4.0, 6.0};

tests/forward_pass/affine/test_sum.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ const char *test_sum_axis_neg1()
1717
Stored as: [1, 2, 3, 4, 5, 6]
1818
*/
1919
double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
20-
expr *const_node = new_constant(3, 2, values);
20+
expr *const_node = new_constant(3, 2, 0, values);
2121
expr *log_node = new_log(const_node);
2222
expr *sum_node = new_sum(log_node, -1);
2323
sum_node->forward(sum_node, NULL);
@@ -44,7 +44,7 @@ const char *test_sum_axis_0()
4444
Stored as: [1, 2, 3, 4, 5, 6]
4545
*/
4646
double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
47-
expr *const_node = new_constant(3, 2, values);
47+
expr *const_node = new_constant(3, 2, 0, values);
4848
expr *log_node = new_log(const_node);
4949
expr *sum_node = new_sum(log_node, 0);
5050
sum_node->forward(sum_node, NULL);
@@ -73,7 +73,7 @@ const char *test_sum_axis_1()
7373
Stored as: [1, 2, 3, 4, 5, 6]
7474
*/
7575
double values[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
76-
expr *const_node = new_constant(3, 2, values);
76+
expr *const_node = new_constant(3, 2, 0, values);
7777
expr *log_node = new_log(const_node);
7878
expr *sum_node = new_sum(log_node, 1);
7979
sum_node->forward(sum_node, NULL);

tests/forward_pass/affine/test_variable_constant.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const char *test_constant()
2121
{
2222
double c[2] = {5.0, 10.0};
2323
double u[2] = {0.0, 0.0};
24-
expr *const_node = new_constant(2, 1, c);
24+
expr *const_node = new_constant(2, 1, 0, c);
2525
const_node->forward(const_node, u);
2626
mu_assert("Constant test failed", cmp_double_array(const_node->value, c, 2));
2727
free_expr(const_node);

tests/forward_pass/composite/test_composite.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ const char *test_composite()
1616
/* Build tree: log(exp(x) + c) */
1717
expr *var = new_variable(2, 1, 0, 2);
1818
expr *exp_node = new_exp(var);
19-
expr *const_node = new_constant(2, 1, c);
19+
expr *const_node = new_constant(2, 1, 0, c);
2020
expr *sum = new_add(exp_node, const_node);
2121
expr *log_node = new_log(sum);
2222

tests/utils/test_csr_matrix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ const char *test_sum_all_rows_csr()
217217
memcpy(A->i, Ai, 7 * sizeof(int));
218218
memcpy(A->p, Ap, 4 * sizeof(int));
219219
CSR_Matrix *C = new_csr_matrix(1, 4, 4);
220-
int_double_pair *pairs = new_int_double_pair_array(4);
220+
int_double_pair *pairs = new_int_double_pair_array(7);
221221
sum_all_rows_csr(A, C, pairs);
222222
double Cx_correct[4] = {6.0, 5.0, 10.0, 7.0};
223223
int Ci_correct[4] = {0, 1, 2, 3};

0 commit comments

Comments
 (0)