Skip to content

Commit 6176be6

Browse files
committed
some progress on supporting backwards compatible constants
1 parent 3310a91 commit 6176be6

7 files changed

Lines changed: 65 additions & 167 deletions

File tree

src/atoms/affine/parameter.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "atoms/affine.h"
1919
#include "subexpr.h"
2020
#include "utils/tracked_alloc.h"
21+
#include <stdio.h>
2122
#include <stdlib.h>
2223
#include <string.h>
2324

@@ -65,11 +66,13 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
6566
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
6667

6768
pnode->param_id = param_id;
68-
69-
if (values != NULL)
69+
70+
if (values == NULL)
7071
{
71-
memcpy(node->value, values, node->size * sizeof(double));
72+
fprintf(stderr, "Parameter values should always be set, this is a bug and"
73+
" should be reported\n");
74+
exit(1);
7275
}
73-
76+
memcpy(node->value, values, node->size * sizeof(double));
7477
return node;
7578
}

src/atoms/affine/scalar_mult.c

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@ static void forward(expr *node, const double *u)
3030
expr *child = node->left;
3131
scalar_mult_expr *snode = (scalar_mult_expr *) node;
3232

33-
/* Refresh param_source expression tree if parameters changed.*/
34-
if (node->needs_parameter_refresh)
35-
{
36-
/* pass NULL to forward: constant param_source never depends on u */
37-
snode->param_source->forward(snode->param_source, NULL);
38-
node->needs_parameter_refresh = false;
39-
}
33+
/* call forward for param_source expr tree
34+
ex: broadcast(param) or promote(const)*/
35+
snode->param_source->forward(snode->param_source, NULL);
4036

4137
double a = snode->param_source->value[0];
4238

@@ -128,6 +124,7 @@ expr *new_scalar_mult(expr *param_node, expr *child)
128124

129125
mult_node->param_source = param_node;
130126
expr_retain(param_node);
127+
//node->needs_parameter_refresh = true;
131128

132129
return node;
133130
}

src/atoms/affine/vector_mult.c

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@ static void forward(expr *node, const double *u)
3030
expr *child = node->left;
3131
vector_mult_expr *vnode = (vector_mult_expr *) node;
3232

33-
/* Refresh param_source expression tree if parameters changed.*/
34-
if (node->needs_parameter_refresh)
35-
{
36-
/* pass NULL to forward: constant param_source never depends on u */
37-
vnode->param_source->forward(vnode->param_source, NULL);
38-
node->needs_parameter_refresh = false;
39-
}
33+
/* call forward for param_source expr tree
34+
ex: broadcast(param) or promote(const)*/
35+
vnode->param_source->forward(vnode->param_source, NULL);
4036

4137
const double *a = vnode->param_source->value;
4238

@@ -138,6 +134,7 @@ expr *new_vector_mult(expr *param_node, expr *child)
138134

139135
vnode->param_source = param_node;
140136
expr_retain(param_node);
137+
//node->needs_parameter_refresh = true;
141138

142139
return node;
143140
}

src/problem.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ void problem_register_params(problem *prob, expr **param_nodes, int n_param_node
369369
prob->total_parameter_size = 0;
370370
for (int i = 0; i < n_param_nodes; i++)
371371
{
372+
// TODO do we need to skip fixed params? maybe we adopt the convention
373+
// that we don't ever register fixed params?
374+
if (((parameter_expr *) param_nodes[i])->param_id == PARAM_FIXED) continue;
372375
prob->total_parameter_size += param_nodes[i]->size;
373376
}
374377
}

tests/all_tests.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,7 @@ int main(void)
361361
mu_run_test(test_param_fixed_skip_in_update, tests_run);
362362

363363
printf("\n--- Parameter + Broadcast Tests ---\n");
364-
mu_run_test(test_param_broadcast_vector_mult, tests_run);
365-
mu_run_test(test_param_sum_scalar_mult, tests_run);
366-
mu_run_test(test_param_broadcast_left_matmul, tests_run);
364+
mu_run_test(test_constant_broadcast_vector_mult, tests_run);
367365
#endif /* PROFILE_ONLY */
368366

369367
#ifdef PROFILE_ONLY

tests/problem/test_param_broadcast.h

Lines changed: 18 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -12,171 +12,62 @@
1212
#include "subexpr.h"
1313
#include "test_helpers.h"
1414

15-
const char *test_param_broadcast_vector_mult(void)
15+
const char *test_constant_broadcast_vector_mult(void)
1616
{
1717
int n = 6;
1818

19-
/* minimize sum(x) subject to broadcast(p) ∘ x, with p parameter */
19+
/* minimize sum(x) subject to broadcast(c) ∘ x, with c constant */
2020
expr *x = new_variable(2, 3, 0, n);
2121
expr *objective = new_sum(x, -1);
22-
expr *p_param = new_parameter(1, 3, 0, n, NULL);
23-
expr *p_bcast = new_broadcast(p_param, 2, 3);
24-
expr *constraint = new_vector_mult(p_bcast, x);
22+
double c_vals[3] = {1.0, 2.0, 3.0};
23+
expr *c = new_parameter(1, 3, PARAM_FIXED, n, c_vals);
24+
expr *c_bcast = new_broadcast(c, 2, 3);
25+
expr *constraint = new_vector_mult(c_bcast, x);
2526
expr *constraints[1] = {constraint};
2627
problem *prob = new_problem(objective, constraints, 1, false);
27-
28-
/* register parameters and fill sparsity patterns */
29-
expr *param_nodes[1] = {p_param};
30-
problem_register_params(prob, param_nodes, 1);
3128
problem_init_derivatives(prob);
3229

3330
/* point for evaluating */
3431
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
3532

36-
/* test 1: p=[1,2,3] */
37-
double theta[3] = {1.0, 2.0, 3.0};
38-
problem_update_params(prob, theta);
3933
problem_constraint_forward(prob, x_vals);
4034
double constrs[6] = {1.0, 2.0, 6.0, 8.0, 15.0, 18.0};
4135
problem_jacobian(prob);
4236
double jac_x[6] = {1.0, 1.0, 2.0, 2.0, 3.0, 3.0};
4337
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
4438
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));
4539

46-
/* test 2: p=[10,20,30] */
47-
theta[0] = 10.0;
48-
theta[1] = 20.0;
49-
theta[2] = 30.0;
50-
problem_update_params(prob, theta);
51-
problem_constraint_forward(prob, x_vals);
52-
problem_jacobian(prob);
53-
constrs[0] = 10.0;
54-
constrs[1] = 20.0;
55-
constrs[2] = 60.0;
56-
constrs[3] = 80.0;
57-
constrs[4] = 150.0;
58-
constrs[5] = 180.0;
59-
jac_x[0] = 10.0;
60-
jac_x[1] = 10.0;
61-
jac_x[2] = 20.0;
62-
jac_x[3] = 20.0;
63-
jac_x[4] = 30.0;
64-
jac_x[5] = 30.0;
65-
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
66-
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));
67-
6840
free_problem(prob);
6941
return 0;
7042
}
7143

72-
const char *test_param_sum_scalar_mult(void)
44+
const char *test_constant_promote_vector_mult(void)
7345
{
74-
int n = 3;
46+
int n = 6;
7547

76-
/* minimize sum(x) subject to sum(p) * x, with p parameter */
77-
expr *x = new_variable(3, 1, 0, n);
48+
/* minimize sum(x) subject to promote(c) ∘ x, with c constant */
49+
expr *x = new_variable(2, 3, 0, n);
7850
expr *objective = new_sum(x, -1);
79-
expr *p_param = new_parameter(2, 1, 0, n, NULL);
80-
expr *p_sum = new_sum(p_param, -1);
81-
expr *constraint = new_scalar_mult(p_sum, x);
51+
double c_vals = 3.0;
52+
expr *c = new_parameter(1, 1, PARAM_FIXED, n, &c_vals);
53+
expr *c_bcast = new_promote(c, 2, 3);
54+
expr *constraint = new_vector_mult(c_bcast, x);
8255
expr *constraints[1] = {constraint};
8356
problem *prob = new_problem(objective, constraints, 1, false);
8457

85-
/* register parameters and fill sparsity patterns */
86-
expr *param_nodes[1] = {p_param};
87-
problem_register_params(prob, param_nodes, 1);
8858
problem_init_derivatives(prob);
8959

9060
/* point for evaluating */
91-
double x_vals[3] = {1.0, 2.0, 3.0};
92-
93-
/* test 1: p=[1,2], sum(p)=3 */
94-
double theta[2] = {1.0, 2.0};
95-
problem_update_params(prob, theta);
96-
problem_constraint_forward(prob, x_vals);
97-
double constrs[3] = {3.0, 6.0, 9.0};
98-
problem_jacobian(prob);
99-
double jac_x[3] = {3.0, 3.0, 3.0};
100-
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 3));
101-
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 3));
102-
103-
/* test 2: p=[5,10], sum(p)=15 */
104-
theta[0] = 5.0;
105-
theta[1] = 10.0;
106-
problem_update_params(prob, theta);
107-
problem_constraint_forward(prob, x_vals);
108-
problem_jacobian(prob);
109-
constrs[0] = 15.0;
110-
constrs[1] = 30.0;
111-
constrs[2] = 45.0;
112-
jac_x[0] = 15.0;
113-
jac_x[1] = 15.0;
114-
jac_x[2] = 15.0;
115-
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 3));
116-
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 3));
117-
118-
free_problem(prob);
119-
return 0;
120-
}
121-
122-
const char *test_param_broadcast_left_matmul(void)
123-
{
124-
int n = 2;
125-
126-
/* minimize sum(x) subject to broadcast(p)@x, with p parameter */
127-
expr *x = new_variable(2, 1, 0, n);
128-
expr *objective = new_sum(x, -1);
129-
expr *p_param = new_parameter(1, 2, 0, n, NULL);
130-
expr *p_bcast = new_broadcast(p_param, 3, 2);
131-
double Ax[6] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
132-
expr *constraint = new_left_matmul_dense(p_bcast, x, 3, 2, Ax);
133-
expr *constraints[1] = {constraint};
134-
problem *prob = new_problem(objective, constraints, 1, false);
135-
136-
/* register parameters and fill sparsity patterns */
137-
expr *param_nodes[1] = {p_param};
138-
problem_register_params(prob, param_nodes, 1);
139-
problem_init_derivatives(prob);
140-
141-
/* point for evaluating and utilities for test */
142-
double x_vals[2] = {3.0, 4.0};
143-
int Ap[4] = {0, 2, 4, 6};
144-
int Ai[6] = {0, 1, 0, 1, 0, 1};
145-
146-
/* test 1: p=[1,2] */
147-
double theta[2] = {1.0, 2.0};
148-
problem_update_params(prob, theta);
149-
problem_constraint_forward(prob, x_vals);
150-
double constrs[3] = {11.0, 11.0, 11.0};
151-
problem_jacobian(prob);
152-
double jac_x[6] = {1.0, 2.0, 1.0, 2.0, 1.0, 2.0};
153-
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 3));
154-
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));
155-
mu_assert("rows fail", cmp_int_array(prob->jacobian->p, Ap, 4));
156-
mu_assert("cols fail", cmp_int_array(prob->jacobian->i, Ai, 6));
61+
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
15762

158-
/* test 2: p=[5,10] */
159-
theta[0] = 5.0;
160-
theta[1] = 10.0;
161-
problem_update_params(prob, theta);
16263
problem_constraint_forward(prob, x_vals);
64+
double constrs[6] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0};
16365
problem_jacobian(prob);
164-
constrs[0] = 55.0;
165-
constrs[1] = 55.0;
166-
constrs[2] = 55.0;
167-
jac_x[0] = 5.0;
168-
jac_x[1] = 10.0;
169-
jac_x[2] = 5.0;
170-
jac_x[3] = 10.0;
171-
jac_x[4] = 5.0;
172-
jac_x[5] = 10.0;
173-
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 3));
66+
double jac_x[6] = {3.0, 3.0, 3.0, 3.0, 3.0, 3.0};
67+
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
17468
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));
175-
mu_assert("rows fail", cmp_int_array(prob->jacobian->p, Ap, 4));
176-
mu_assert("cols fail", cmp_int_array(prob->jacobian->i, Ai, 6));
17769

17870
free_problem(prob);
17971
return 0;
18072
}
181-
18273
#endif /* TEST_PARAM_BROADCAST_H */

0 commit comments

Comments
 (0)