Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/atoms/affine/left_matmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ static void refresh_param_values(left_matmul_expr *lnode)
static void forward(expr *node, const double *u)
{
left_matmul_expr *lnode = (left_matmul_expr *) node;

/* Always call forward on param_source if it exists */
/* Should we also adopt a convention that left_matmul always
points to a param_source, even if its constant? */
if (lnode->param_source != NULL)
{
lnode->param_source->forward(lnode->param_source, NULL);
}

refresh_param_values(lnode);

expr *x = node->left;
Expand Down
11 changes: 7 additions & 4 deletions src/atoms/affine/parameter.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "atoms/affine.h"
#include "subexpr.h"
#include "utils/tracked_alloc.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

Expand Down Expand Up @@ -65,11 +66,13 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);

pnode->param_id = param_id;

if (values != NULL)
if (values == NULL)
{
memcpy(node->value, values, node->size * sizeof(double));
fprintf(stderr, "Parameter values should always be set, this is a bug and"
" should be reported\n");
exit(1);
}

memcpy(node->value, values, node->size * sizeof(double));
return node;
}
9 changes: 8 additions & 1 deletion src/atoms/affine/scalar_mult.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
static void forward(expr *node, const double *u)
{
expr *child = node->left;
double a = ((scalar_mult_expr *) node)->param_source->value[0];
scalar_mult_expr *snode = (scalar_mult_expr *) node;

/* call forward for param_source expr tree
ex: broadcast(param) or promote(const)*/
snode->param_source->forward(snode->param_source, NULL);
Comment thread
dance858 marked this conversation as resolved.
Outdated

double a = snode->param_source->value[0];

/* child's forward pass */
child->forward(child, u);
Expand Down Expand Up @@ -118,6 +124,7 @@ expr *new_scalar_mult(expr *param_node, expr *child)

mult_node->param_source = param_node;
expr_retain(param_node);
//node->needs_parameter_refresh = true;

return node;
}
9 changes: 8 additions & 1 deletion src/atoms/affine/vector_mult.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
static void forward(expr *node, const double *u)
{
expr *child = node->left;
const double *a = ((vector_mult_expr *) node)->param_source->value;
vector_mult_expr *vnode = (vector_mult_expr *) node;

/* call forward for param_source expr tree
ex: broadcast(param) or promote(const)*/
vnode->param_source->forward(vnode->param_source, NULL);

const double *a = vnode->param_source->value;

/* child's forward pass */
child->forward(child, u);
Expand Down Expand Up @@ -128,6 +134,7 @@ expr *new_vector_mult(expr *param_node, expr *child)

vnode->param_source = param_node;
expr_retain(param_node);
//node->needs_parameter_refresh = true;

return node;
}
3 changes: 3 additions & 0 deletions src/problem.c
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ void problem_register_params(problem *prob, expr **param_nodes, int n_param_node
prob->total_parameter_size = 0;
for (int i = 0; i < n_param_nodes; i++)
{
// TODO do we need to skip fixed params? maybe we adopt the convention
// that we don't ever register fixed params?
if (((parameter_expr *) param_nodes[i])->param_id == PARAM_FIXED) continue;
Comment thread
dance858 marked this conversation as resolved.
prob->total_parameter_size += param_nodes[i]->size;
}
}
Expand Down
7 changes: 7 additions & 0 deletions tests/all_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include "jacobian_tests/other/test_prod_axis_zero.h"
#include "jacobian_tests/other/test_quad_form.h"
#include "numerical_diff/test_numerical_diff.h"
#include "problem/test_param_broadcast.h"
#include "problem/test_param_prob.h"
#include "problem/test_problem.h"
#include "utils/test_cblas.h"
Expand Down Expand Up @@ -358,6 +359,12 @@ int main(void)
mu_run_test(test_param_right_matmul_rectangular, tests_run);
mu_run_test(test_param_shared_left_matmul_problem, tests_run);
mu_run_test(test_param_fixed_skip_in_update, tests_run);

printf("\n--- Parameter + Broadcast Tests ---\n");
mu_run_test(test_constant_broadcast_vector_mult, tests_run);
mu_run_test(test_constant_promote_vector_mult, tests_run);
mu_run_test(test_param_broadcast_vector_mult, tests_run);
mu_run_test(test_param_promote_vector_mult, tests_run);
#endif /* PROFILE_ONLY */

#ifdef PROFILE_ONLY
Expand Down
157 changes: 157 additions & 0 deletions tests/problem/test_param_broadcast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#ifndef TEST_PARAM_BROADCAST_H
#define TEST_PARAM_BROADCAST_H

#include <math.h>
#include <stdio.h>
#include <string.h>

#include "atoms/affine.h"
#include "expr.h"
#include "minunit.h"
#include "problem.h"
#include "subexpr.h"
#include "test_helpers.h"

const char *test_constant_broadcast_vector_mult(void)
{
int n = 6;

/* minimize sum(x) subject to broadcast(c) ∘ x, with c constant */
expr *x = new_variable(2, 3, 0, n);
expr *objective = new_sum(x, -1);
double c_vals[3] = {1.0, 2.0, 3.0};
expr *c = new_parameter(1, 3, PARAM_FIXED, n, c_vals);
expr *c_bcast = new_broadcast(c, 2, 3);
expr *constraint = new_vector_mult(c_bcast, x);
expr *constraints[1] = {constraint};
problem *prob = new_problem(objective, constraints, 1, false);
problem_init_derivatives(prob);

/* point for evaluating */
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

problem_constraint_forward(prob, x_vals);
double constrs[6] = {1.0, 2.0, 6.0, 8.0, 15.0, 18.0};
problem_jacobian(prob);
double jac_x[6] = {1.0, 1.0, 2.0, 2.0, 3.0, 3.0};
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));

free_problem(prob);
return 0;
}

const char *test_constant_promote_vector_mult(void)
{
int n = 6;

/* minimize sum(x) subject to promote(c) ∘ x, with c constant */
expr *x = new_variable(2, 3, 0, n);
expr *objective = new_sum(x, -1);
double c_vals = 3.0;
expr *c = new_parameter(1, 1, PARAM_FIXED, n, &c_vals);
expr *c_bcast = new_promote(c, 2, 3);
expr *constraint = new_vector_mult(c_bcast, x);
expr *constraints[1] = {constraint};
problem *prob = new_problem(objective, constraints, 1, false);

problem_init_derivatives(prob);

/* point for evaluating */
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

problem_constraint_forward(prob, x_vals);
double constrs[6] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0};
problem_jacobian(prob);
double jac_x[6] = {3.0, 3.0, 3.0, 3.0, 3.0, 3.0};
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));

free_problem(prob);
return 0;
}

const char *test_param_broadcast_vector_mult(void)
{
int n = 6;

/* minimize sum(x) subject to broadcast(p) ∘ x, with p parameter */
expr *x = new_variable(2, 3, 0, n);
expr *objective = new_sum(x, -1);
double c_vals[3] = {1.0, 2.0, 3.0};
expr *c = new_parameter(1, 3, 0, n, c_vals);
expr *c_bcast = new_broadcast(c, 2, 3);
expr *constraint = new_vector_mult(c_bcast, x);
expr *constraints[1] = {constraint};
problem *prob = new_problem(objective, constraints, 1, false);

expr *param_nodes[1] = {c};
problem_register_params(prob, param_nodes, 1);
problem_init_derivatives(prob);

/* point for evaluating */
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

problem_constraint_forward(prob, x_vals);
double constrs[6] = {1.0, 2.0, 6.0, 8.0, 15.0, 18.0};
problem_jacobian(prob);
double jac_x[6] = {1.0, 1.0, 2.0, 2.0, 3.0, 3.0};
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));

/* second iteration after updating parameter */
double theta[3] = {5.0, 4.0, 3.0};
problem_update_params(prob, theta);
problem_constraint_forward(prob, x_vals);
problem_jacobian(prob);
double updated_constrs[6] = {5.0, 10.0, 12.0, 16.0, 15.0, 18.0};
double updated_jac_x[6] = {5.0, 5.0, 4.0, 4.0, 3.0, 3.0};
mu_assert("vals fail", cmp_double_array(prob->constraint_values, updated_constrs, 6));
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, updated_jac_x, 6));

free_problem(prob);
return 0;
}

const char *test_param_promote_vector_mult(void)
{
int n = 6;

/* minimize sum(x) subject to promote(p) ∘ x, with p parameter */
expr *x = new_variable(2, 3, 0, n);
expr *objective = new_sum(x, -1);
double c_vals = 3.0;
expr *c = new_parameter(1, 1, 0, n, &c_vals);
expr *c_bcast = new_promote(c, 2, 3);
expr *constraint = new_vector_mult(c_bcast, x);
expr *constraints[1] = {constraint};
problem *prob = new_problem(objective, constraints, 1, false);

expr *param_nodes[1] = {c};
problem_register_params(prob, param_nodes, 1);
problem_init_derivatives(prob);

/* point for evaluating */
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

problem_constraint_forward(prob, x_vals);
double constrs[6] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0};
problem_jacobian(prob);
double jac_x[6] = {3.0, 3.0, 3.0, 3.0, 3.0, 3.0};
mu_assert("vals fail", cmp_double_array(prob->constraint_values, constrs, 6));
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, jac_x, 6));

/* second iteration after updating parameter */
double theta = 5.0;
problem_update_params(prob, &theta);
problem_constraint_forward(prob, x_vals);
problem_jacobian(prob);
double updated_constrs[6] = {5.0, 10.0, 15.0, 20.0, 25.0, 30.0};
double updated_jac_x[6] = {5.0, 5.0, 5.0, 5.0, 5.0, 5.0};
mu_assert("vals fail", cmp_double_array(prob->constraint_values, updated_constrs, 6));
mu_assert("vals fail", cmp_double_array(prob->jacobian->x, updated_jac_x, 6));

free_problem(prob);
return 0;
}
#endif /* TEST_PARAM_BROADCAST_H */
Loading
Loading