Skip to content

Commit 4577604

Browse files
committed
add test for scalar case to be consistent with vector mult
1 parent 0e7a58a commit 4577604

4 files changed

Lines changed: 97 additions & 4 deletions

File tree

src/atoms/affine/parameter.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
6464
expr *node = &pnode->base;
6565
init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian,
6666
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
67-
67+
68+
// TODO we should assert that the values array has the correct size.
6869
pnode->param_id = param_id;
6970

7071
if (values == NULL)

src/atoms/affine/scalar_mult.c

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ static void forward(expr *node, const double *u)
3030
expr *child = node->left;
3131
scalar_mult_expr *snode = (scalar_mult_expr *) node;
3232

33-
/* call forward for param_source expr tree
34-
ex: broadcast(param) or promote(const)*/
35-
// snode->param_source->forward(snode->param_source, NULL);
33+
/* call forward for param_source expr tree (this extra logic is needed
34+
in case the parameter is a broadcast or promote node which needs to refresh
35+
its values) */
36+
if (snode->base.needs_parameter_refresh)
37+
{
38+
snode->param_source->forward(snode->param_source, NULL);
39+
snode->base.needs_parameter_refresh = false;
40+
}
3641

3742
double a = snode->param_source->value[0];
3843

@@ -125,5 +130,8 @@ expr *new_scalar_mult(expr *param_node, expr *child)
125130
mult_node->param_source = param_node;
126131
expr_retain(param_node);
127132

133+
/* special case for handling broadcasting of constants correctly */
134+
mult_node->base.needs_parameter_refresh = true;
135+
128136
return node;
129137
}

tests/all_tests.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,8 @@ int main(void)
366366
mu_run_test(test_constant_promote_vector_mult, tests_run);
367367
mu_run_test(test_param_broadcast_vector_mult, tests_run);
368368
mu_run_test(test_param_promote_vector_mult, tests_run);
369+
mu_run_test(test_const_sum_scalar_mult, tests_run);
370+
mu_run_test(test_param_sum_scalar_mult, tests_run);
369371
#endif /* PROFILE_ONLY */
370372

371373
#ifdef PROFILE_ONLY

tests/problem/test_param_broadcast.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,4 +175,86 @@ const char *test_param_promote_vector_mult(void)
175175
free_problem(prob);
176176
return 0;
177177
}
178+
179+
const char *test_const_sum_scalar_mult(void)
180+
{
181+
int n = 6;
182+
183+
/* minimize sum(x) subject to sum(c) * x, with c constant */
184+
expr *x = new_variable(1, 1, 0, n);
185+
expr *objective = new_sum(x, -1);
186+
double c_vals[3] = {1.0, 2.0, 3.0};
187+
expr *c = new_parameter(1, 3, PARAM_FIXED, n, c_vals);
188+
expr *c_sum = new_sum(c, -1);
189+
expr *constraint = new_scalar_mult(c_sum, x);
190+
expr *constraints[1] = {constraint};
191+
problem *prob = new_problem(objective, constraints, 1, false);
192+
193+
problem_init_derivatives(prob);
194+
195+
/* point for evaluating */
196+
double x_vals[1] = {4.0};
197+
198+
problem_constraint_forward(prob, x_vals);
199+
double constrs[1] = {6.0 * 4.0};
200+
problem_jacobian(prob);
201+
double jac_x[1] = {6.0};
202+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 1));
203+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 1));
204+
205+
mu_assert("check_jacobian failed",
206+
check_jacobian_num(constraint, x_vals, NUMERICAL_DIFF_DEFAULT_H));
207+
free_problem(prob);
208+
return 0;
209+
}
210+
211+
const char *test_param_sum_scalar_mult(void)
212+
{
213+
int n = 6;
214+
215+
/* minimize sum(x) subject to sum(p) * x, with p parameter */
216+
expr *x = new_variable(1, 1, 0, n);
217+
expr *objective = new_sum(x, -1);
218+
double c_vals[3] = {1.0, 2.0, 3.0};
219+
expr *c = new_parameter(1, 3, 0, n, c_vals);
220+
expr *c_sum = new_sum(c, -1);
221+
expr *constraint = new_scalar_mult(c_sum, x);
222+
expr *constraints[1] = {constraint};
223+
problem *prob = new_problem(objective, constraints, 1, false);
224+
225+
expr *param_nodes[1] = {c};
226+
problem_register_params(prob, param_nodes, 1);
227+
problem_init_derivatives(prob);
228+
229+
/* point for evaluating */
230+
double x_vals[1] = {4.0};
231+
232+
problem_constraint_forward(prob, x_vals);
233+
double constrs[1] = {6.0 * 4.0};
234+
problem_jacobian(prob);
235+
double jac_x[1] = {6.0};
236+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 1));
237+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 1));
238+
239+
mu_assert("check_jacobian failed",
240+
check_jacobian_num(constraint, x_vals, NUMERICAL_DIFF_DEFAULT_H));
241+
242+
/* second iteration after updating parameter */
243+
double theta[3] = {5.0, 4.0, 3.0};
244+
problem_update_params(prob, theta);
245+
problem_constraint_forward(prob, x_vals);
246+
problem_jacobian(prob);
247+
double updated_constrs[1] = {12.0 * 4.0};
248+
double updated_jac_x[1] = {12.0};
249+
mu_assert("vals fail",
250+
cmp_double_array(prob->constraint_values, updated_constrs, 1));
251+
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, updated_jac_x, 1));
252+
253+
mu_assert("check_jacobian failed",
254+
check_jacobian_num(constraint, x_vals, NUMERICAL_DIFF_DEFAULT_H));
255+
256+
free_problem(prob);
257+
return 0;
258+
}
259+
178260
#endif /* TEST_PARAM_BROADCAST_H */

0 commit comments

Comments
 (0)