Skip to content

Commit 5f7931c

Browse files
authored
jacobian chain rule (#58)
* jacobian chain rule * run formatter
1 parent b21e021 commit 5f7931c

4 files changed

Lines changed: 59 additions & 10 deletions

File tree

src/affine/linear_op.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,20 @@ static void jacobian_init(expr *node)
7474
node->jacobian = ((linear_op_expr *) node)->A_csr;
7575
}
7676

77+
static void eval_jacobian(expr *node)
78+
{
79+
/* Linear operator jacobian never changes - nothing to evaluate */
80+
(void) node;
81+
}
82+
7783
expr *new_linear(expr *u, const CSR_Matrix *A, const double *b)
7884
{
7985
assert(u->d2 == 1);
8086
/* Allocate the type-specific struct */
8187
linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr));
8288
expr *node = &lin_node->base;
83-
init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init, NULL, is_affine,
84-
NULL, NULL, free_type_data);
89+
init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init, eval_jacobian,
90+
is_affine, NULL, NULL, free_type_data);
8591
node->left = u;
8692
expr_retain(u);
8793

src/elementwise_full_dom/common.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "elementwise_full_dom.h"
22
#include "subexpr.h"
3+
#include "utils/CSR_Matrix.h"
34
#include <stdio.h>
45
#include <stdlib.h>
56
#include <string.h>
@@ -22,14 +23,15 @@ void jacobian_init_elementwise(expr *node)
2223
/* otherwise it will be a linear operator */
2324
else
2425
{
25-
CSR_Matrix *J = ((linear_op_expr *) child)->A_csr;
26-
node->jacobian = new_csr_matrix(J->m, J->n, J->nnz);
27-
26+
/* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
27+
child->jacobian_init(child);
28+
CSR_Matrix *Jg = child->jacobian;
29+
node->jacobian = new_csr_matrix(Jg->m, Jg->n, Jg->nnz);
2830
node->dwork = (double *) malloc(node->size * sizeof(double));
2931

3032
/* copy sparsity pattern of child */
31-
memcpy(node->jacobian->p, J->p, sizeof(int) * (J->m + 1));
32-
memcpy(node->jacobian->i, J->i, sizeof(int) * J->nnz);
33+
memcpy(node->jacobian->p, Jg->p, sizeof(int) * (Jg->m + 1));
34+
memcpy(node->jacobian->i, Jg->i, sizeof(int) * Jg->nnz);
3335
}
3436
}
3537

@@ -43,10 +45,11 @@ void eval_jacobian_elementwise(expr *node)
4345
}
4446
else
4547
{
46-
/* Child will be a linear operator */
47-
linear_op_expr *lin_child = (linear_op_expr *) child;
48+
/* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
49+
child->eval_jacobian(child);
50+
CSR_Matrix *Jg = child->jacobian;
4851
node->local_jacobian(node, node->dwork);
49-
diag_csr_mult_fill_values(node->dwork, lin_child->A_csr, node->jacobian);
52+
diag_csr_mult_fill_values(node->dwork, Jg, node->jacobian);
5053
}
5154
}
5255

tests/all_tests.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "forward_pass/test_prod_axis_one.h"
2323
#include "forward_pass/test_prod_axis_zero.h"
2424
#include "jacobian_tests/test_broadcast.h"
25+
#include "jacobian_tests/test_chain_rule_jacobian.h"
2526
#include "jacobian_tests/test_composite_exp.h"
2627
#include "jacobian_tests/test_const_scalar_mult.h"
2728
#include "jacobian_tests/test_const_vector_mult.h"
@@ -129,6 +130,8 @@ int main(void)
129130
mu_run_test(test_jacobian_log, tests_run);
130131
mu_run_test(test_jacobian_log_matrix, tests_run);
131132
mu_run_test(test_jacobian_composite_exp, tests_run);
133+
mu_run_test(test_jacobian_exp_sum, tests_run);
134+
mu_run_test(test_jacobian_exp_sum_mult, tests_run);
132135
mu_run_test(test_jacobian_composite_exp_add, tests_run);
133136
mu_run_test(test_jacobian_const_scalar_mult_log_vector, tests_run);
134137
mu_run_test(test_jacobian_const_scalar_mult_log_matrix, tests_run);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "affine.h"
2+
#include "bivariate.h"
3+
#include "elementwise_full_dom.h"
4+
#include "minunit.h"
5+
#include "numerical_diff.h"
6+
7+
const char *test_jacobian_exp_sum(void)
8+
{
9+
double u_vals[3] = {1.0, 2.0, 3.0};
10+
11+
expr *x = new_variable(3, 1, 0, 3);
12+
expr *sum_x = new_sum(x, -1);
13+
expr *exp_sum_x = new_exp(sum_x);
14+
15+
mu_assert("check_jacobian failed",
16+
check_jacobian(exp_sum_x, u_vals, NUMERICAL_DIFF_DEFAULT_H));
17+
18+
free_expr(exp_sum_x);
19+
return 0;
20+
}
21+
22+
const char *test_jacobian_exp_sum_mult(void)
23+
{
24+
double u_vals[4] = {1.0, 2.0, 3.0, 4.0};
25+
26+
expr *x = new_variable(2, 1, 0, 4);
27+
expr *y = new_variable(2, 1, 2, 4);
28+
expr *xy = new_elementwise_mult(x, y);
29+
expr *sum_xy = new_sum(xy, -1);
30+
expr *exp_sum_xy = new_exp(sum_xy);
31+
32+
mu_assert("check_jacobian failed",
33+
check_jacobian(exp_sum_xy, u_vals, NUMERICAL_DIFF_DEFAULT_H));
34+
35+
free_expr(exp_sum_xy);
36+
return 0;
37+
}

0 commit comments

Comments
 (0)