Skip to content

Commit b64e381

Browse files
add initial attempt for fixing parameters and broadcasting (#73)
* add initial attempt for fixing parameters and broadcasting * add test for params with broadcast * cleanup test to fit more with other tests style * some progress on supporting backwards compatible constants * add some parameter broadcast tests as well * cleanup left matmul as well * some very minor cleanups * some error checks and numerical diff to tests * we don't always have to run forward of parameter in left matmul * we don't always have to call forward for parameter node in vector mult * comment out forward parameter pass in scalar mult because it is not needed, I think * add test for scalar case to be consistent with vector mult --------- Co-authored-by: dance858 <danielcederberg1@gmail.com>
1 parent 30a7868 commit b64e381

14 files changed

Lines changed: 439 additions & 43 deletions

src/atoms/affine/left_matmul.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ static void refresh_param_values(left_matmul_expr *lnode)
6363
static void forward(expr *node, const double *u)
6464
{
6565
left_matmul_expr *lnode = (left_matmul_expr *) node;
66+
67+
/* call forward on param_source if it exists and needs refresh */
68+
if (lnode->param_source != NULL && lnode->base.needs_parameter_refresh)
69+
{
70+
lnode->param_source->forward(lnode->param_source, NULL);
71+
}
72+
6673
refresh_param_values(lnode);
6774

6875
expr *x = node->left;

src/atoms/affine/parameter.c

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "atoms/affine.h"
1919
#include "subexpr.h"
2020
#include "utils/tracked_alloc.h"
21+
#include <stdio.h>
2122
#include <stdlib.h>
2223
#include <string.h>
2324

@@ -63,13 +64,16 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
6364
expr *node = &pnode->base;
6465
init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian,
6566
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
66-
67+
68+
// TODO we should assert that the values array has the correct size.
6769
pnode->param_id = param_id;
6870

69-
if (values != NULL)
71+
if (values == NULL)
7072
{
71-
memcpy(node->value, values, node->size * sizeof(double));
73+
fprintf(stderr, "Parameter values should always be set, this is a bug and"
74+
" should be reported\n");
75+
exit(1);
7276
}
73-
77+
memcpy(node->value, values, node->size * sizeof(double));
7478
return node;
7579
}

src/atoms/affine/scalar_mult.c

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,18 @@
2828
static void forward(expr *node, const double *u)
2929
{
3030
expr *child = node->left;
31-
double a = ((scalar_mult_expr *) node)->param_source->value[0];
31+
scalar_mult_expr *snode = (scalar_mult_expr *) node;
32+
33+
/* call forward for param_source expr tree (this extra logic is needed
34+
in case the parameter is a broadcast or promote node which needs to refresh
35+
its values) */
36+
if (snode->base.needs_parameter_refresh)
37+
{
38+
snode->param_source->forward(snode->param_source, NULL);
39+
snode->base.needs_parameter_refresh = false;
40+
}
41+
42+
double a = snode->param_source->value[0];
3243

3344
/* child's forward pass */
3445
child->forward(child, u);
@@ -119,5 +130,8 @@ expr *new_scalar_mult(expr *param_node, expr *child)
119130
mult_node->param_source = param_node;
120131
expr_retain(param_node);
121132

133+
/* special case for handling broadcasting of constants correctly */
134+
mult_node->base.needs_parameter_refresh = true;
135+
122136
return node;
123137
}

src/atoms/affine/vector_mult.c

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,18 @@
2828
static void forward(expr *node, const double *u)
2929
{
3030
expr *child = node->left;
31-
const double *a = ((vector_mult_expr *) node)->param_source->value;
31+
vector_mult_expr *vnode = (vector_mult_expr *) node;
32+
33+
/* call forward for param_source expr tree (this extra logic is needed
34+
in case the parameter is a broadcast or promote node which needs to refresh
35+
its values) */
36+
if (vnode->base.needs_parameter_refresh)
37+
{
38+
vnode->param_source->forward(vnode->param_source, NULL);
39+
vnode->base.needs_parameter_refresh = false;
40+
}
41+
42+
const double *a = vnode->param_source->value;
3243

3344
/* child's forward pass */
3445
child->forward(child, u);
@@ -129,5 +140,8 @@ expr *new_vector_mult(expr *param_node, expr *child)
129140
vnode->param_source = param_node;
130141
expr_retain(param_node);
131142

143+
/* special case for handling broadcasting of constants correctly */
144+
vnode->base.needs_parameter_refresh = true;
145+
132146
return node;
133147
}

src/problem.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,16 @@ void problem_register_params(problem *prob, expr **param_nodes, int n_param_node
369369
prob->total_parameter_size = 0;
370370
for (int i = 0; i < n_param_nodes; i++)
371371
{
372+
373+
if (((parameter_expr *) param_nodes[i])->param_id == PARAM_FIXED)
374+
{
375+
fprintf(stderr, "can this ever happen? \n");
376+
exit(1);
377+
}
378+
379+
// TODO do we need to skip fixed params? maybe we adopt the convention
380+
// that we don't ever register fixed params?
381+
if (((parameter_expr *) param_nodes[i])->param_id == PARAM_FIXED) continue;
372382
prob->total_parameter_size += param_nodes[i]->size;
373383
}
374384
}
@@ -387,6 +397,13 @@ void problem_update_params(problem *prob, const double *theta)
387397
{
388398
expr *pnode = prob->param_nodes[i];
389399
parameter_expr *param = (parameter_expr *) pnode;
400+
401+
if (param->param_id == PARAM_FIXED)
402+
{
403+
fprintf(stderr, "can this ever happen? \n");
404+
exit(1);
405+
}
406+
390407
if (param->param_id == PARAM_FIXED) continue;
391408
int offset = param->param_id;
392409
memcpy(pnode->value, theta + offset, pnode->size * sizeof(double));

tests/all_tests.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "jacobian_tests/other/test_prod_axis_zero.h"
5050
#include "jacobian_tests/other/test_quad_form.h"
5151
#include "numerical_diff/test_numerical_diff.h"
52+
#include "problem/test_param_broadcast.h"
5253
#include "problem/test_param_prob.h"
5354
#include "problem/test_problem.h"
5455
#include "utils/test_cblas.h"
@@ -358,6 +359,15 @@ int main(void)
358359
mu_run_test(test_param_right_matmul_rectangular, tests_run);
359360
mu_run_test(test_param_shared_left_matmul_problem, tests_run);
360361
mu_run_test(test_param_fixed_skip_in_update, tests_run);
362+
mu_run_test(test_param_scalar_mult_problem_with_constant, tests_run);
363+
364+
printf("\n--- Parameter + Broadcast Tests ---\n");
365+
mu_run_test(test_constant_broadcast_vector_mult, tests_run);
366+
mu_run_test(test_constant_promote_vector_mult, tests_run);
367+
mu_run_test(test_param_broadcast_vector_mult, tests_run);
368+
mu_run_test(test_param_promote_vector_mult, tests_run);
369+
mu_run_test(test_const_sum_scalar_mult, tests_run);
370+
mu_run_test(test_param_sum_scalar_mult, tests_run);
361371
#endif /* PROFILE_ONLY */
362372

363373
#ifdef PROFILE_ONLY

tests/jacobian_tests/affine/test_left_matmul.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ const char *test_jacobian_left_matmul_exp_composite(void)
138138
expr *A_exp_Bx = new_left_matmul(NULL, exp_Bx, A);
139139

140140
mu_assert("check_jacobian failed",
141-
check_jacobian(A_exp_Bx, x_vals, NUMERICAL_DIFF_DEFAULT_H));
141+
check_jacobian_num(A_exp_Bx, x_vals, NUMERICAL_DIFF_DEFAULT_H));
142142

143143
free_csr_matrix(A);
144144
free_csr_matrix(B);

tests/jacobian_tests/composite/test_chain_rule_jacobian.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ const char *test_jacobian_exp_sum(void)
1717
expr *exp_sum_x = new_exp(sum_x);
1818

1919
mu_assert("check_jacobian failed",
20-
check_jacobian(exp_sum_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
20+
check_jacobian_num(exp_sum_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
2121

2222
free_expr(exp_sum_x);
2323
return 0;
@@ -34,7 +34,7 @@ const char *test_jacobian_exp_sum_mult(void)
3434
expr *exp_sum_xy = new_exp(sum_xy);
3535

3636
mu_assert("check_jacobian failed",
37-
check_jacobian(exp_sum_xy, u_vals, NUMERICAL_DIFF_DEFAULT_H));
37+
check_jacobian_num(exp_sum_xy, u_vals, NUMERICAL_DIFF_DEFAULT_H));
3838

3939
free_expr(exp_sum_xy);
4040
return 0;
@@ -49,7 +49,7 @@ const char *test_jacobian_sin_cos(void)
4949
expr *sin_cos_x = new_sin(cos_x);
5050

5151
mu_assert("check_jacobian failed",
52-
check_jacobian(sin_cos_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
52+
check_jacobian_num(sin_cos_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
5353

5454
free_expr(sin_cos_x);
5555
return 0;
@@ -68,7 +68,7 @@ const char *test_jacobian_cos_sin_multiply(void)
6868
expr *multiply = new_elementwise_mult(sum, sin_y);
6969

7070
mu_assert("check_jacobian failed",
71-
check_jacobian(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
71+
check_jacobian_num(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
7272

7373
free_expr(multiply);
7474
return 0;
@@ -87,7 +87,7 @@ const char *test_jacobian_Ax_Bx_multiply(void)
8787
expr *multiply = new_elementwise_mult(Ax, Bx);
8888

8989
mu_assert("check_jacobian failed",
90-
check_jacobian(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
90+
check_jacobian_num(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
9191

9292
free_expr(multiply);
9393
free_csr_matrix(A);
@@ -107,7 +107,7 @@ const char *test_jacobian_AX_BX_multiply(void)
107107
expr *multiply = new_elementwise_mult(new_sin(AX), new_cos(BX));
108108

109109
mu_assert("check_jacobian failed",
110-
check_jacobian(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
110+
check_jacobian_num(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
111111

112112
free_expr(multiply);
113113
free_csr_matrix(A);
@@ -137,7 +137,7 @@ const char *test_jacobian_quad_form_Ax(void)
137137
expr *node = new_quad_form(sin_Ax, Q);
138138

139139
mu_assert("check_jacobian failed",
140-
check_jacobian(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
140+
check_jacobian_num(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
141141

142142
free_expr(node);
143143
free_csr_matrix(A);
@@ -164,7 +164,7 @@ const char *test_jacobian_quad_form_exp(void)
164164
expr *node = new_quad_form(exp_x, Q);
165165

166166
mu_assert("check_jacobian failed",
167-
check_jacobian(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
167+
check_jacobian_num(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
168168

169169
free_expr(node);
170170
free_csr_matrix(Q);
@@ -183,7 +183,7 @@ const char *test_jacobian_matmul_exp_exp(void)
183183
expr *Z = new_matmul(exp_X, exp_Y);
184184

185185
mu_assert("check_jacobian failed",
186-
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
186+
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
187187

188188
free_expr(Z);
189189
return 0;
@@ -201,7 +201,7 @@ const char *test_jacobian_matmul_sin_cos(void)
201201
expr *Z = new_matmul(sin_X, cos_Y);
202202

203203
mu_assert("check_jacobian failed",
204-
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
204+
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
205205

206206
free_expr(Z);
207207
return 0;
@@ -222,7 +222,7 @@ const char *test_jacobian_matmul_Ax_By(void)
222222
expr *Z = new_matmul(AX, BY); /* 3x2 */
223223

224224
mu_assert("check_jacobian failed",
225-
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
225+
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
226226

227227
free_expr(Z);
228228
free_csr_matrix(A);
@@ -246,7 +246,7 @@ const char *test_jacobian_matmul_sin_Ax_cos_Bx(void)
246246
expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */
247247

248248
mu_assert("check_jacobian failed",
249-
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
249+
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
250250

251251
free_expr(Z);
252252
free_csr_matrix(A);
@@ -263,7 +263,7 @@ const char *test_jacobian_matmul_X_X(void)
263263
expr *Z = new_matmul(X, X); /* 2x2 */
264264

265265
mu_assert("check_jacobian failed",
266-
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
266+
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
267267

268268
free_expr(Z);
269269
return 0;

tests/jacobian_tests/composite/test_composite_exp.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ const char *test_jacobian_composite_exp_add(void)
7171
expr *sum = new_add(exp_Ax, exp_By);
7272

7373
mu_assert("check_jacobian failed",
74-
check_jacobian(sum, u_vals, NUMERICAL_DIFF_DEFAULT_H));
74+
check_jacobian_num(sum, u_vals, NUMERICAL_DIFF_DEFAULT_H));
7575

7676
free_expr(sum);
7777
free_csr_matrix(A);

tests/numerical_diff.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ double *numerical_jacobian(expr *node, const double *u, double h)
6161
return J;
6262
}
6363

64-
int check_jacobian(expr *node, const double *u, double h)
64+
int check_jacobian_num(expr *node, const double *u, double h)
6565
{
6666
int m = node->size;
6767
int n = node->n_vars;

0 commit comments

Comments
 (0)