Skip to content

Commit 9153a88

Browse files
committed
add some parameter broadcast tests as well
1 parent 6176be6 commit 9153a88

2 files changed

Lines changed: 87 additions & 0 deletions

File tree

tests/all_tests.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,9 @@ int main(void)
362362

363363
printf("\n--- Parameter + Broadcast Tests ---\n");
364364
mu_run_test(test_constant_broadcast_vector_mult, tests_run);
365+
mu_run_test(test_constant_promote_vector_mult, tests_run);
366+
mu_run_test(test_param_broadcast_vector_mult, tests_run);
367+
mu_run_test(test_param_promote_vector_mult, tests_run);
365368
#endif /* PROFILE_ONLY */
366369

367370
#ifdef PROFILE_ONLY

tests/problem/test_param_broadcast.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,88 @@ const char *test_constant_promote_vector_mult(void)
7070
free_problem(prob);
7171
return 0;
7272
}
73+
74+
const char *test_param_broadcast_vector_mult(void)
75+
{
76+
int n = 6;
77+
78+
/* minimize sum(x) subject to broadcast(p) ∘ x, with p parameter */
79+
expr *x = new_variable(2, 3, 0, n);
80+
expr *objective = new_sum(x, -1);
81+
double c_vals[3] = {1.0, 2.0, 3.0};
82+
expr *c = new_parameter(1, 3, 0, n, c_vals);
83+
expr *c_bcast = new_broadcast(c, 2, 3);
84+
expr *constraint = new_vector_mult(c_bcast, x);
85+
expr *constraints[1] = {constraint};
86+
problem *prob = new_problem(objective, constraints, 1, false);
87+
88+
expr *param_nodes[1] = {c};
89+
problem_register_params(prob, param_nodes, 1);
90+
problem_init_derivatives(prob);
91+
92+
/* point for evaluating */
93+
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
94+
95+
problem_constraint_forward(prob, x_vals);
96+
double constrs[6] = {1.0, 2.0, 6.0, 8.0, 15.0, 18.0};
97+
problem_jacobian(prob);
98+
double jac_x[6] = {1.0, 1.0, 2.0, 2.0, 3.0, 3.0};
99+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
100+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));
101+
102+
/* second iteration after updating parameter */
103+
double theta[3] = {5.0, 4.0, 3.0};
104+
problem_update_params(prob, theta);
105+
problem_constraint_forward(prob, x_vals);
106+
problem_jacobian(prob);
107+
double updated_constrs[6] = {5.0, 10.0, 12.0, 16.0, 15.0, 18.0};
108+
double updated_jac_x[6] = {5.0, 5.0, 4.0, 4.0, 3.0, 3.0};
109+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, updated_constrs, 6));
110+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, updated_jac_x, 6));
111+
112+
free_problem(prob);
113+
return 0;
114+
}
115+
116+
const char *test_param_promote_vector_mult(void)
117+
{
118+
int n = 6;
119+
120+
/* minimize sum(x) subject to promote(p) ∘ x, with p parameter */
121+
expr *x = new_variable(2, 3, 0, n);
122+
expr *objective = new_sum(x, -1);
123+
double c_vals = 3.0;
124+
expr *c = new_parameter(1, 1, 0, n, &c_vals);
125+
expr *c_bcast = new_promote(c, 2, 3);
126+
expr *constraint = new_vector_mult(c_bcast, x);
127+
expr *constraints[1] = {constraint};
128+
problem *prob = new_problem(objective, constraints, 1, false);
129+
130+
expr *param_nodes[1] = {c};
131+
problem_register_params(prob, param_nodes, 1);
132+
problem_init_derivatives(prob);
133+
134+
/* point for evaluating */
135+
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
136+
137+
problem_constraint_forward(prob, x_vals);
138+
double constrs[6] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0};
139+
problem_jacobian(prob);
140+
double jac_x[6] = {3.0, 3.0, 3.0, 3.0, 3.0, 3.0};
141+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
142+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));
143+
144+
/* second iteration after updating parameter */
145+
double theta = 5.0;
146+
problem_update_params(prob, &theta);
147+
problem_constraint_forward(prob, x_vals);
148+
problem_jacobian(prob);
149+
double updated_constrs[6] = {5.0, 10.0, 15.0, 20.0, 25.0, 30.0};
150+
double updated_jac_x[6] = {5.0, 5.0, 5.0, 5.0, 5.0, 5.0};
151+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, updated_constrs, 6));
152+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, updated_jac_x, 6));
153+
154+
free_problem(prob);
155+
return 0;
156+
}
73157
#endif /* TEST_PARAM_BROADCAST_H */

0 commit comments

Comments
 (0)