Skip to content

Commit e613c15

Browse files
authored
chain rule wsumm_hess (#59)
* some infrastructure * first draft * test with matrix arguments * remove commentt that;'s wrong
1 parent 5f7931c commit e613c15

5 files changed

Lines changed: 205 additions & 10 deletions

File tree

include/expr.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,16 @@ typedef struct expr
5656
// ------------------------------------------------------------------------
5757
double *value;
5858
CSR_Matrix *jacobian;
59+
CSC_Matrix *jacobian_csc;
60+
int *csc_work; /* workspace for CSR-CSC conversion */
61+
62+
/* jacobian_csc_filled is only used for affine functions to avoid redundant
63+
conversions. Could become relevant for non-affine functions if we start
64+
supporting common subexpressions on the Python side. */
65+
bool jacobian_csc_filled;
5966
CSR_Matrix *wsum_hess;
67+
CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */
68+
CSR_Matrix *hess_term2; /* child wsum_hess workspace */
6069
forward_fn forward;
6170
jacobian_init_fn jacobian_init;
6271
wsum_hess_init_fn wsum_hess_init;
@@ -67,6 +76,7 @@ typedef struct expr
6776
// other things
6877
// ------------------------------------------------------------------------
6978
is_affine_fn is_affine;
79+
double *local_jac_diag; /* cached f'(g(x)) diagonal */
7080
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
7181
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
7282
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
@@ -83,6 +93,10 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
8393

8494
void free_expr(expr *node);
8595

96+
/* Initialize CSC form of the Jacobian from the CSR Jacobian.
97+
* Must be called after jacobian_init. */
98+
void jacobian_csc_init(expr *node);
99+
86100
/* Reference counting helpers */
87101
void expr_retain(expr *node);
88102

src/elementwise_full_dom/common.c

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "elementwise_full_dom.h"
22
#include "subexpr.h"
3+
#include "utils/CSC_Matrix.h"
34
#include "utils/CSR_Matrix.h"
5+
#include "utils/CSR_sum.h"
46
#include <stdio.h>
57
#include <stdlib.h>
68
#include <string.h>
@@ -20,14 +22,14 @@ void jacobian_init_elementwise(expr *node)
2022
}
2123
node->jacobian->p[node->size] = node->size;
2224
}
23-
/* otherwise it will be a linear operator */
2425
else
2526
{
2627
/* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
2728
child->jacobian_init(child);
2829
CSR_Matrix *Jg = child->jacobian;
2930
node->jacobian = new_csr_matrix(Jg->m, Jg->n, Jg->nnz);
3031
node->dwork = (double *) malloc(node->size * sizeof(double));
32+
node->local_jac_diag = (double *) malloc(node->size * sizeof(double));
3133

3234
/* copy sparsity pattern of child */
3335
memcpy(node->jacobian->p, Jg->p, sizeof(int) * (Jg->m + 1));
@@ -48,7 +50,8 @@ void eval_jacobian_elementwise(expr *node)
4850
/* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
4951
child->eval_jacobian(child);
5052
CSR_Matrix *Jg = child->jacobian;
51-
node->local_jacobian(node, node->dwork);
53+
node->local_jacobian(node, node->local_jac_diag);
54+
memcpy(node->dwork, node->local_jac_diag, node->size * sizeof(double));
5255
diag_csr_mult_fill_values(node->dwork, Jg, node->jacobian);
5356
}
5457
}
@@ -59,7 +62,7 @@ void wsum_hess_init_elementwise(expr *node)
5962
int id = child->var_id;
6063
int i;
6164

62-
/* if the variable is a child*/
65+
/* if the variable is a child */
6366
if (id != NOT_A_VARIABLE)
6467
{
6568
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size);
@@ -75,11 +78,38 @@ void wsum_hess_init_elementwise(expr *node)
7578
node->wsum_hess->p[i] = node->size;
7679
}
7780
}
78-
/* otherwise it will be a linear operator */
7981
else
8082
{
81-
linear_op_expr *lin_child = (linear_op_expr *) child;
82-
node->wsum_hess = ATA_alloc(lin_child->A_csc);
83+
/* Hessian of h(x) = w^T f(g(x) is term1 + term 2 where
84+
term1 = J_g^T @ D @ J_g with D = sum_i w_i Hf_i,
85+
term2 = sum_i (J_f^T w)_i^T Hg_i.
86+
87+
For elementwise functions, D is diagonal. */
88+
jacobian_csc_init(child);
89+
CSC_Matrix *Jg = child->jacobian_csc;
90+
91+
if (child->is_affine(child))
92+
{
93+
node->wsum_hess = ATA_alloc(Jg);
94+
}
95+
else
96+
{
97+
/* term1: Jg^T @ D @ Jg */
98+
node->hess_term1 = ATA_alloc(Jg);
99+
100+
/* term2: child's Hessian */
101+
child->wsum_hess_init(child);
102+
CSR_Matrix *Hg = child->wsum_hess;
103+
node->hess_term2 = new_csr_matrix(Hg->m, Hg->n, Hg->nnz);
104+
memcpy(node->hess_term2->p, Hg->p, (Hg->m + 1) * sizeof(int));
105+
memcpy(node->hess_term2->i, Hg->i, Hg->nnz * sizeof(int));
106+
107+
/* wsum_hess = term1 + term2 */
108+
int max_nnz = node->hess_term1->nnz + node->hess_term2->nnz;
109+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz);
110+
sum_csr_matrices_fill_sparsity(node->hess_term1, node->hess_term2,
111+
node->wsum_hess);
112+
}
83113
}
84114
}
85115

@@ -93,10 +123,43 @@ void eval_wsum_hess_elementwise(expr *node, const double *w)
93123
}
94124
else
95125
{
96-
/* Child will be a linear operator */
97-
linear_op_expr *lin_child = (linear_op_expr *) child;
98-
node->local_wsum_hess(node, node->dwork, w);
99-
ATDA_fill_values(lin_child->A_csc, node->dwork, node->wsum_hess);
126+
if (child->is_affine(child))
127+
{
128+
if (!child->jacobian_csc_filled)
129+
{
130+
csr_to_csc_fill_values(child->jacobian, child->jacobian_csc,
131+
child->csc_work);
132+
child->jacobian_csc_filled = true;
133+
}
134+
135+
node->local_wsum_hess(node, node->dwork, w);
136+
ATDA_fill_values(child->jacobian_csc, node->dwork, node->wsum_hess);
137+
}
138+
else
139+
{
140+
/* refresh CSC jacobian values */
141+
csr_to_csc_fill_values(child->jacobian, child->jacobian_csc,
142+
child->csc_work);
143+
144+
/* term1: Jg^T @ D @ Jg */
145+
node->local_wsum_hess(node, node->dwork, w);
146+
ATDA_fill_values(child->jacobian_csc, node->dwork, node->hess_term1);
147+
148+
/* term2: child Hessian with weight Jf^T w */
149+
memcpy(node->dwork, node->local_jac_diag, node->size * sizeof(double));
150+
for (int k = 0; k < node->size; k++)
151+
{
152+
node->dwork[k] *= w[k];
153+
}
154+
155+
child->eval_wsum_hess(child, node->dwork);
156+
memcpy(node->hess_term2->x, child->wsum_hess->x,
157+
child->wsum_hess->nnz * sizeof(double));
158+
159+
/* wsum_hess = term1 + term2 */
160+
sum_csr_matrices_fill_values(node->hess_term1, node->hess_term2,
161+
node->wsum_hess);
162+
}
100163
}
101164
}
102165

src/expr.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "expr.h"
19+
#include "utils/CSC_Matrix.h"
1920
#include "utils/int_double_pair.h"
2021
#include <stdlib.h>
2122
#include <string.h>
@@ -41,6 +42,12 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
4142
node->free_type_data = free_type_data;
4243
}
4344

45+
void jacobian_csc_init(expr *node)
46+
{
47+
node->csc_work = (int *) malloc(node->n_vars * sizeof(int));
48+
node->jacobian_csc = csr_to_csc_fill_sparsity(node->jacobian, node->csc_work);
49+
}
50+
4451
void free_expr(expr *node)
4552
{
4653
if (node == NULL) return;
@@ -63,8 +70,13 @@ void free_expr(expr *node)
6370
/* free value array and jacobian */
6471
free(node->value);
6572
free_csr_matrix(node->jacobian);
73+
free_csc_matrix(node->jacobian_csc);
74+
free(node->csc_work);
6675
free_csr_matrix(node->wsum_hess);
76+
free_csr_matrix(node->hess_term1);
77+
free_csr_matrix(node->hess_term2);
6778
free(node->dwork);
79+
free(node->local_jac_diag);
6880
free(node->iwork);
6981
node->value = NULL;
7082
node->jacobian = NULL;

tests/all_tests.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
#include "wsum_hess/elementwise/test_trig.h"
6666
#include "wsum_hess/elementwise/test_xexp.h"
6767
#include "wsum_hess/test_broadcast.h"
68+
#include "wsum_hess/test_chain_rule_wsum_hess.h"
6869
#include "wsum_hess/test_const_scalar_mult.h"
6970
#include "wsum_hess/test_const_vector_mult.h"
7071
#include "wsum_hess/test_hstack.h"
@@ -259,6 +260,11 @@ int main(void)
259260
mu_run_test(test_wsum_hess_trace_log_variable, tests_run);
260261
mu_run_test(test_wsum_hess_trace_composite, tests_run);
261262
mu_run_test(test_wsum_hess_transpose, tests_run);
263+
mu_run_test(test_wsum_hess_exp_sum, tests_run);
264+
mu_run_test(test_wsum_hess_exp_sum_mult, tests_run);
265+
mu_run_test(test_wsum_hess_exp_sum_matmul, tests_run);
266+
mu_run_test(test_wsum_hess_sin_sum_axis0_matmul, tests_run);
267+
mu_run_test(test_wsum_hess_logistic_sum_axis0_matmul, tests_run);
262268

263269
printf("\n--- Utility Tests ---\n");
264270
mu_run_test(test_cblas_ddot, tests_run);
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#include "affine.h"
2+
#include "bivariate.h"
3+
#include "elementwise_full_dom.h"
4+
#include "minunit.h"
5+
#include "numerical_diff.h"
6+
7+
const char *test_wsum_hess_exp_sum(void)
8+
{
9+
double u_vals[3] = {1.0, 2.0, 3.0};
10+
double w = 1.0;
11+
12+
expr *x = new_variable(3, 1, 0, 3);
13+
expr *sum_x = new_sum(x, -1);
14+
expr *exp_sum_x = new_exp(sum_x);
15+
16+
mu_assert("check_wsum_hess failed",
17+
check_wsum_hess(exp_sum_x, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H));
18+
19+
free_expr(exp_sum_x);
20+
return 0;
21+
}
22+
23+
const char *test_wsum_hess_exp_sum_mult(void)
24+
{
25+
double u_vals[4] = {1.0, 2.0, 3.0, 4.0};
26+
double w = 1.0;
27+
28+
expr *x = new_variable(2, 1, 0, 4);
29+
expr *y = new_variable(2, 1, 2, 4);
30+
expr *xy = new_elementwise_mult(x, y);
31+
expr *sum_xy = new_sum(xy, -1);
32+
expr *exp_sum_xy = new_exp(sum_xy);
33+
34+
mu_assert("check_wsum_hess failed",
35+
check_wsum_hess(exp_sum_xy, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H));
36+
37+
free_expr(exp_sum_xy);
38+
return 0;
39+
}
40+
41+
const char *test_wsum_hess_exp_sum_matmul(void)
42+
{
43+
/* exp(sum(X @ Y)) where X is 2x3, Y is 3x2
44+
* n_vars = 6 + 6 = 12 */
45+
double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 0.1, 0.2, 0.3};
46+
double w = 1.0;
47+
48+
expr *X = new_variable(2, 3, 0, 12);
49+
expr *Y = new_variable(3, 2, 6, 12);
50+
expr *XY = new_matmul(X, Y);
51+
expr *sum_XY = new_sum(XY, -1);
52+
expr *exp_sum_XY = new_exp(sum_XY);
53+
54+
mu_assert("check_wsum_hess failed",
55+
check_wsum_hess(exp_sum_XY, u_vals, &w, NUMERICAL_DIFF_DEFAULT_H));
56+
57+
free_expr(exp_sum_XY);
58+
return 0;
59+
}
60+
61+
const char *test_wsum_hess_sin_sum_axis0_matmul(void)
62+
{
63+
/* sin(sum(X @ Y, axis=0)) where X is 2x3, Y is 3x2
64+
* X@Y is 2x2, sum(axis=0) gives 1x2, sin gives 1x2
65+
* n_vars = 6 + 6 = 12 */
66+
double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 0.1, 0.2, 0.3};
67+
double w[2] = {1.0, 1.0};
68+
69+
expr *X = new_variable(2, 3, 0, 12);
70+
expr *Y = new_variable(3, 2, 6, 12);
71+
expr *XY = new_matmul(X, Y);
72+
expr *sum_XY = new_sum(XY, 0);
73+
expr *sin_sum_XY = new_sin(sum_XY);
74+
75+
mu_assert("check_wsum_hess failed",
76+
check_wsum_hess(sin_sum_XY, u_vals, w, NUMERICAL_DIFF_DEFAULT_H));
77+
78+
free_expr(sin_sum_XY);
79+
return 0;
80+
}
81+
82+
const char *test_wsum_hess_logistic_sum_axis0_matmul(void)
83+
{
84+
/* logistic(sum(X @ Y, axis=0)) where X is 2x3, Y is 3x2
85+
* n_vars = 6 + 6 = 12 */
86+
double u_vals[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, 2.5, 0.1, 0.2, 0.3};
87+
double w[2] = {1.0, 1.0};
88+
89+
expr *X = new_variable(2, 3, 0, 12);
90+
expr *Y = new_variable(3, 2, 6, 12);
91+
expr *XY = new_matmul(X, Y);
92+
expr *sum_XY = new_sum(XY, 0);
93+
expr *logistic_sum_XY = new_logistic(sum_XY);
94+
95+
mu_assert("check_wsum_hess failed",
96+
check_wsum_hess(logistic_sum_XY, u_vals, w, NUMERICAL_DIFF_DEFAULT_H));
97+
98+
free_expr(logistic_sum_XY);
99+
return 0;
100+
}

0 commit comments

Comments
 (0)