Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/atoms/affine/left_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ static void refresh_param_values(left_matmul_expr *lnode)
static void forward(expr *node, const double *u)
{
left_matmul_expr *lnode = (left_matmul_expr *) node;

/* call forward on param_source if it exists and needs refresh */
if (lnode->param_source != NULL && lnode->base.needs_parameter_refresh)
{
lnode->param_source->forward(lnode->param_source, NULL);
}

refresh_param_values(lnode);

expr *x = node->left;
Expand Down
12 changes: 8 additions & 4 deletions src/atoms/affine/parameter.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "atoms/affine.h"
#include "subexpr.h"
#include "utils/tracked_alloc.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

Expand Down Expand Up @@ -63,13 +64,16 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
expr *node = &pnode->base;
init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian,
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);


// TODO we should assert that the values array has the correct size.
pnode->param_id = param_id;

if (values != NULL)
if (values == NULL)
{
memcpy(node->value, values, node->size * sizeof(double));
fprintf(stderr, "Parameter values should always be set, this is a bug and"
" should be reported\n");
exit(1);
}

memcpy(node->value, values, node->size * sizeof(double));
return node;
}
16 changes: 15 additions & 1 deletion src/atoms/affine/scalar_mult.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,18 @@
static void forward(expr *node, const double *u)
{
expr *child = node->left;
double a = ((scalar_mult_expr *) node)->param_source->value[0];
scalar_mult_expr *snode = (scalar_mult_expr *) node;

/* call forward for param_source expr tree (this extra logic is needed
in case the parameter is a broadcast or promote node which needs to refresh
its values) */
if (snode->base.needs_parameter_refresh)
{
snode->param_source->forward(snode->param_source, NULL);
snode->base.needs_parameter_refresh = false;
}

double a = snode->param_source->value[0];

/* child's forward pass */
child->forward(child, u);
Expand Down Expand Up @@ -119,5 +130,8 @@ expr *new_scalar_mult(expr *param_node, expr *child)
mult_node->param_source = param_node;
expr_retain(param_node);

/* special case for handling broadcasting of constants correctly */
mult_node->base.needs_parameter_refresh = true;

return node;
}
16 changes: 15 additions & 1 deletion src/atoms/affine/vector_mult.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,18 @@
static void forward(expr *node, const double *u)
{
expr *child = node->left;
const double *a = ((vector_mult_expr *) node)->param_source->value;
vector_mult_expr *vnode = (vector_mult_expr *) node;

/* call forward for param_source expr tree (this extra logic is needed
in case the parameter is a broadcast or promote node which needs to refresh
its values) */
if (vnode->base.needs_parameter_refresh)
{
vnode->param_source->forward(vnode->param_source, NULL);
vnode->base.needs_parameter_refresh = false;
}

const double *a = vnode->param_source->value;

/* child's forward pass */
child->forward(child, u);
Expand Down Expand Up @@ -129,5 +140,8 @@ expr *new_vector_mult(expr *param_node, expr *child)
vnode->param_source = param_node;
expr_retain(param_node);

/* special case for handling broadcasting of constants correctly */
vnode->base.needs_parameter_refresh = true;

return node;
}
17 changes: 17 additions & 0 deletions src/problem.c
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,16 @@ void problem_register_params(problem *prob, expr **param_nodes, int n_param_node
prob->total_parameter_size = 0;
for (int i = 0; i < n_param_nodes; i++)
{

if (((parameter_expr *) param_nodes[i])->param_id == PARAM_FIXED)
{
fprintf(stderr, "can this ever happen? \n");
exit(1);
}

// TODO do we need to skip fixed params? maybe we adopt the convention
// that we don't ever register fixed params?
if (((parameter_expr *) param_nodes[i])->param_id == PARAM_FIXED) continue;
prob->total_parameter_size += param_nodes[i]->size;
}
}
Expand All @@ -387,6 +397,13 @@ void problem_update_params(problem *prob, const double *theta)
{
expr *pnode = prob->param_nodes[i];
parameter_expr *param = (parameter_expr *) pnode;

if (param->param_id == PARAM_FIXED)
{
fprintf(stderr, "can this ever happen? \n");
exit(1);
}

if (param->param_id == PARAM_FIXED) continue;
int offset = param->param_id;
memcpy(pnode->value, theta + offset, pnode->size * sizeof(double));
Expand Down
10 changes: 10 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "jacobian_tests/other/test_prod_axis_zero.h"
#include "jacobian_tests/other/test_quad_form.h"
#include "numerical_diff/test_numerical_diff.h"
#include "problem/test_param_broadcast.h"
#include "problem/test_param_prob.h"
#include "problem/test_problem.h"
#include "utils/test_cblas.h"
Expand Down Expand Up @@ -358,6 +359,15 @@ int main(void)
mu_run_test(test_param_right_matmul_rectangular, tests_run);
mu_run_test(test_param_shared_left_matmul_problem, tests_run);
mu_run_test(test_param_fixed_skip_in_update, tests_run);
mu_run_test(test_param_scalar_mult_problem_with_constant, tests_run);

printf("\n--- Parameter + Broadcast Tests ---\n");
mu_run_test(test_constant_broadcast_vector_mult, tests_run);
mu_run_test(test_constant_promote_vector_mult, tests_run);
mu_run_test(test_param_broadcast_vector_mult, tests_run);
mu_run_test(test_param_promote_vector_mult, tests_run);
mu_run_test(test_const_sum_scalar_mult, tests_run);
mu_run_test(test_param_sum_scalar_mult, tests_run);
#endif /* PROFILE_ONLY */

#ifdef PROFILE_ONLY
Expand Down
2 changes: 1 addition & 1 deletion tests/jacobian_tests/affine/test_left_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ const char *test_jacobian_left_matmul_exp_composite(void)
expr *A_exp_Bx = new_left_matmul(NULL, exp_Bx, A);

mu_assert("check_jacobian failed",
check_jacobian(A_exp_Bx, x_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(A_exp_Bx, x_vals, NUMERICAL_DIFF_DEFAULT_H));

free_csr_matrix(A);
free_csr_matrix(B);
Expand Down
26 changes: 13 additions & 13 deletions tests/jacobian_tests/composite/test_chain_rule_jacobian.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const char *test_jacobian_exp_sum(void)
expr *exp_sum_x = new_exp(sum_x);

mu_assert("check_jacobian failed",
check_jacobian(exp_sum_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(exp_sum_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(exp_sum_x);
return 0;
Expand All @@ -34,7 +34,7 @@ const char *test_jacobian_exp_sum_mult(void)
expr *exp_sum_xy = new_exp(sum_xy);

mu_assert("check_jacobian failed",
check_jacobian(exp_sum_xy, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(exp_sum_xy, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(exp_sum_xy);
return 0;
Expand All @@ -49,7 +49,7 @@ const char *test_jacobian_sin_cos(void)
expr *sin_cos_x = new_sin(cos_x);

mu_assert("check_jacobian failed",
check_jacobian(sin_cos_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(sin_cos_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(sin_cos_x);
return 0;
Expand All @@ -68,7 +68,7 @@ const char *test_jacobian_cos_sin_multiply(void)
expr *multiply = new_elementwise_mult(sum, sin_y);

mu_assert("check_jacobian failed",
check_jacobian(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(multiply);
return 0;
Expand All @@ -87,7 +87,7 @@ const char *test_jacobian_Ax_Bx_multiply(void)
expr *multiply = new_elementwise_mult(Ax, Bx);

mu_assert("check_jacobian failed",
check_jacobian(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(multiply);
free_csr_matrix(A);
Expand All @@ -107,7 +107,7 @@ const char *test_jacobian_AX_BX_multiply(void)
expr *multiply = new_elementwise_mult(new_sin(AX), new_cos(BX));

mu_assert("check_jacobian failed",
check_jacobian(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(multiply, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(multiply);
free_csr_matrix(A);
Expand Down Expand Up @@ -137,7 +137,7 @@ const char *test_jacobian_quad_form_Ax(void)
expr *node = new_quad_form(sin_Ax, Q);

mu_assert("check_jacobian failed",
check_jacobian(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(node);
free_csr_matrix(A);
Expand All @@ -164,7 +164,7 @@ const char *test_jacobian_quad_form_exp(void)
expr *node = new_quad_form(exp_x, Q);

mu_assert("check_jacobian failed",
check_jacobian(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(node);
free_csr_matrix(Q);
Expand All @@ -183,7 +183,7 @@ const char *test_jacobian_matmul_exp_exp(void)
expr *Z = new_matmul(exp_X, exp_Y);

mu_assert("check_jacobian failed",
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(Z);
return 0;
Expand All @@ -201,7 +201,7 @@ const char *test_jacobian_matmul_sin_cos(void)
expr *Z = new_matmul(sin_X, cos_Y);

mu_assert("check_jacobian failed",
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(Z);
return 0;
Expand All @@ -222,7 +222,7 @@ const char *test_jacobian_matmul_Ax_By(void)
expr *Z = new_matmul(AX, BY); /* 3x2 */

mu_assert("check_jacobian failed",
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(Z);
free_csr_matrix(A);
Expand All @@ -246,7 +246,7 @@ const char *test_jacobian_matmul_sin_Ax_cos_Bx(void)
expr *Z = new_matmul(sin_AX, cos_BX); /* 2x2 */

mu_assert("check_jacobian failed",
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(Z);
free_csr_matrix(A);
Expand All @@ -263,7 +263,7 @@ const char *test_jacobian_matmul_X_X(void)
expr *Z = new_matmul(X, X); /* 2x2 */

mu_assert("check_jacobian failed",
check_jacobian(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(Z, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(Z);
return 0;
Expand Down
2 changes: 1 addition & 1 deletion tests/jacobian_tests/composite/test_composite_exp.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ const char *test_jacobian_composite_exp_add(void)
expr *sum = new_add(exp_Ax, exp_By);

mu_assert("check_jacobian failed",
check_jacobian(sum, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(sum, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(sum);
free_csr_matrix(A);
Expand Down
2 changes: 1 addition & 1 deletion tests/numerical_diff.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ double *numerical_jacobian(expr *node, const double *u, double h)
return J;
}

int check_jacobian(expr *node, const double *u, double h)
int check_jacobian_num(expr *node, const double *u, double h)
{
int m = node->size;
int n = node->n_vars;
Expand Down
2 changes: 1 addition & 1 deletion tests/numerical_diff.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ double *numerical_jacobian(expr *node, const double *u, double h);
/* Evaluate analytical Jacobian, compute numerical Jacobian,
* and compare. Returns 1 on match, 0 on mismatch.
* Prints diagnostic on first failing entry. */
int check_jacobian(expr *node, const double *u, double h);
int check_jacobian_num(expr *node, const double *u, double h);

/* Compute dense numerical weighted-sum Hessian via central
* differences on the gradient g(u) = J(u)^T w.
Expand Down
2 changes: 1 addition & 1 deletion tests/numerical_diff/test_numerical_diff.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const char *test_check_jacobian_composite_exp(void)
expr *exp_node = new_exp(Au);

mu_assert("check_jacobian failed",
check_jacobian(exp_node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
check_jacobian_num(exp_node, u_vals, NUMERICAL_DIFF_DEFAULT_H));

free_expr(exp_node);
free_csr_matrix(A);
Expand Down
Loading
Loading