Skip to content

Commit 5341802

Browse files
committed
jacobian chain rule implementation
1 parent 0404d0d commit 5341802

5 files changed

Lines changed: 143 additions & 92 deletions

File tree

include/utils/CSC_Matrix.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,9 @@ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
4949
*/
5050
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);
5151

52+
/* Count the columns of a CSC matrix that hold at least one stored entry.
   Column j is counted when A->p[j + 1] > A->p[j]; the stored values
   themselves are not inspected, so explicitly stored zeros still count.
   Returns the number of such columns among the A->n columns. */
53+
int count_nonzero_cols_csc(const CSC_Matrix *A);
54+
5255
CSC_Matrix *csr_to_csc_fill_sparsity(const CSR_Matrix *A, int *iwork);
5356
void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork);
5457

src/other/quad_form.c

Lines changed: 68 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -27,17 +27,44 @@ static void forward(expr *node, const double *u)
2727

2828
/* Set up the quad_form node's work buffer and the sparsity pattern of its
   1 x n_vars Jacobian row (post-commit behavior):
   - argument is a variable: x->size entries at columns
     var_id .. var_id + x->size - 1;
   - otherwise (chain rule): jacobian_init(x) and jacobian_csc_init(x) are
     called, x->work->jacobian_csc is read, and the row gets one entry for
     every column of that CSC matrix that holds at least one stored entry.
   Only p and i of node->jacobian are written here; the values are written
   by eval_jacobian. */
static void jacobian_init_impl(expr *node)
2929
{
30-
assert(node->left->var_id != NOT_A_VARIABLE);
31-
3230
expr *x = node->left;
31+
32+
/* dwork (x->size doubles) stores the result of Q @ f(x) in the forward pass.
   NOTE(review): the malloc result is not checked before use. */
3333
node->work->dwork = (double *) malloc(x->size * sizeof(double));
34-
node->jacobian = new_csr_matrix(1, node->n_vars, x->size);
35-
node->jacobian->p[0] = 0;
36-
node->jacobian->p[1] = x->size;
3734

38-
for (int j = 0; j < x->size; j++)
35+
if (x->var_id != NOT_A_VARIABLE)
36+
{
37+
node->jacobian = new_csr_matrix(1, node->n_vars, x->size);
38+
node->jacobian->p[0] = 0;
39+
node->jacobian->p[1] = x->size;
40+
41+
for (int j = 0; j < x->size; j++)
42+
{
43+
node->jacobian->i[j] = x->var_id + j;
44+
}
45+
}
46+
else
3947
{
40-
node->jacobian->i[j] = x->var_id + j;
48+
/* chain rule: J = 2 * (Q @ f(x))^T * J_f, where J_f is the Jacobian of
   the argument x; only its column structure is needed at init time */
49+
jacobian_init(x);
50+
jacobian_csc_init(x);
51+
CSC_Matrix *J_csc = x->work->jacobian_csc;
52+
53+
/* one entry per column of J_f that holds at least one stored entry */
54+
int nnz = count_nonzero_cols_csc(J_csc);
55+
node->jacobian = new_csr_matrix(1, node->n_vars, nnz);
56+
node->jacobian->p[0] = 0;
57+
node->jacobian->p[1] = nnz;
58+
59+
/* fill sparsity pattern: indices of the nonempty columns of J_f, in
   increasing order (same test as count_nonzero_cols_csc, so idx ends
   at nnz) */
60+
int idx = 0;
61+
for (int j = 0; j < J_csc->n; j++)
62+
{
63+
if (J_csc->p[j + 1] > J_csc->p[j])
64+
{
65+
node->jacobian->i[idx++] = j;
66+
}
67+
}
4168
}
4269
}
4370

@@ -46,12 +73,41 @@ static void eval_jacobian(expr *node)
4673
expr *x = node->left;
4774
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;
4875

49-
// jacobian = 2 * Q * x
50-
csr_matvec(Q, x->value, node->jacobian->x, 0);
76+
if (x->var_id != NOT_A_VARIABLE)
77+
{
78+
/* jacobian = 2 * (Q @ x)^T */
79+
csr_matvec(Q, x->value, node->jacobian->x, 0);
5180

52-
for (int j = 0; j < x->size; j++)
81+
for (int j = 0; j < x->size; j++)
82+
{
83+
node->jacobian->x[j] *= 2.0;
84+
}
85+
}
86+
else
5387
{
54-
node->jacobian->x[j] *= 2.0;
88+
/* jacobian = 2 * (Q @ f(x))^T @ J_f */
89+
x->eval_jacobian(x);
90+
91+
if (!x->work->jacobian_csc_filled)
92+
{
93+
csr_to_csc_fill_values(x->jacobian, x->work->jacobian_csc,
94+
x->work->csc_work);
95+
96+
if (x->is_affine(x))
97+
{
98+
x->work->jacobian_csc_filled = true;
99+
}
100+
}
101+
102+
/* The jacobian has same values as the gradient, which is
103+
J_f^T (Q @ f(x)). Here, dwork stores Q @ f(x) from forward */
104+
csc_matvec_fill_values(x->work->jacobian_csc, node->work->dwork,
105+
node->jacobian);
106+
107+
for (int j = 0; j < node->jacobian->nnz; j++)
108+
{
109+
node->jacobian->x[j] *= 2.0;
110+
}
55111
}
56112
}
57113

@@ -89,87 +145,6 @@ static void eval_wsum_hess(expr *node, const double *w)
89145
}
90146
}
91147

92-
/*
93-
The following two functions are commented out. It supports the jacobian for
94-
quad_form(Ax, Q), but after reconsideration, I think we should treat this as
95-
quad_form(x, A.TQA) in the canonicalization so we don't need to support the chain
96-
rule here.
97-
static void jacobian_init_impl(expr *node)
98-
{
99-
expr *x = node->left;
100-
node->work->dwork = (double *) malloc(x->d1 * sizeof(double));
101-
102-
// if x is a variable
103-
if (x->var_id != NOT_A_VARIABLE)
104-
{
105-
node->jacobian = new_csr_matrix(1, node->n_vars, x->d1);
106-
node->jacobian->p[0] = 0;
107-
node->jacobian->p[1] = x->d1;
108-
109-
for (int j = 0; j < x->d1; j++)
110-
{
111-
node->jacobian->i[j] = x->var_id + j;
112-
}
113-
}
114-
else // x is not a variable
115-
{
116-
// compute required allocation and allocate jacobian
117-
bool *col_nz = (bool *) calloc(node->n_vars, sizeof(bool));
118-
int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz);
119-
node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1);
120-
121-
// precompute column indices
122-
node->jacobian->nnz = 0;
123-
for (int j = 0; j < node->n_vars; j++)
124-
{
125-
if (col_nz[j])
126-
{
127-
node->jacobian->i[node->jacobian->nnz] = j;
128-
node->jacobian->nnz++;
129-
}
130-
}
131-
assert(nonzero_cols == node->jacobian->nnz);
132-
free(col_nz);
133-
134-
node->jacobian->p[0] = 0;
135-
node->jacobian->p[1] = node->jacobian->nnz;
136-
}
137-
}
138-
139-
140-
static void eval_jacobian_old(expr *node)
141-
{
142-
expr *x = node->left;
143-
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;
144-
145-
// if x is a variable
146-
if (x->var_id != NOT_A_VARIABLE)
147-
{
148-
csr_matvec(Q, x->value, node->jacobian->x, 0);
149-
150-
for (int j = 0; j < x->d1; j++)
151-
{
152-
node->jacobian->x[j] *= 2.0;
153-
}
154-
}
155-
else // x is not a variable
156-
{
157-
linear_op_expr *lin_x = (linear_op_expr *) x;
158-
159-
// local jacobian
160-
csr_matvec(Q, x->value, node->work->dwork, 0);
161-
162-
for (int j = 0; j < x->d1; j++)
163-
{
164-
node->work->dwork[j] *= 2.0;
165-
}
166-
167-
// chain rule using CSC format
168-
csc_matvec_fill_values(lin_x->A_csc, node->work->dwork, node->jacobian);
169-
}
170-
}
171-
*/
172-
173148
static void free_type_data(expr *node)
174149
{
175150
quad_form_expr *qnode = (quad_form_expr *) node;
@@ -180,6 +155,7 @@ static void free_type_data(expr *node)
180155
/* Affinity query for a quad_form node: currently always answers false and
   does not look at the node. */
static bool is_affine(const expr *node)
181156
{
182157
(void) node;
158+
/* TODO: it is affine if both children are constant */
183159
return false;
184160
}
185161

src/utils/CSC_Matrix.c

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -451,3 +451,16 @@ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
451451
}
452452
}
453453
}
454+
455+
/* Return how many of the A->n columns of A hold at least one stored entry.
   A column j is nonempty when its column-pointer range is non-degenerate,
   i.e. A->p[j + 1] > A->p[j]; stored values are not read, so explicitly
   stored zeros are counted as well. */
int count_nonzero_cols_csc(const CSC_Matrix *A)
456+
{
457+
int count = 0;
458+
for (int j = 0; j < A->n; j++)
459+
{
460+
/* p[j] .. p[j + 1] delimits the entries of column j */
if (A->p[j + 1] > A->p[j])
461+
{
462+
count++;
463+
}
464+
}
465+
return count;
466+
}

tests/all_tests.c

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,8 @@ int main(void)
137137
mu_run_test(test_jacobian_cos_sin_multiply, tests_run);
138138
mu_run_test(test_jacobian_Ax_Bx_multiply, tests_run);
139139
mu_run_test(test_jacobian_AX_BX_multiply, tests_run);
140+
mu_run_test(test_jacobian_quad_form_Ax, tests_run);
141+
mu_run_test(test_jacobian_quad_form_exp, tests_run);
140142
mu_run_test(test_jacobian_composite_exp_add, tests_run);
141143
mu_run_test(test_jacobian_const_scalar_mult_log_vector, tests_run);
142144
mu_run_test(test_jacobian_const_scalar_mult_log_matrix, tests_run);

tests/jacobian_tests/composite/test_chain_rule_jacobian.h

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include "expr.h"
55
#include "minunit.h"
66
#include "numerical_diff.h"
7+
#include "other.h"
78
#include "test_helpers.h"
89
#include "utils/CSR_Matrix.h"
910

@@ -113,3 +114,59 @@ const char *test_jacobian_AX_BX_multiply(void)
113114
free_csr_matrix(B);
114115
return 0;
115116
}
117+
118+
/* Chain-rule Jacobian test for quad_form applied to a non-variable argument.
   Despite the test name, the argument is sin(Ax), not Ax itself.
   Returns 0 on success; mu_assert reports "check_jacobian failed" otherwise. */
const char *test_jacobian_quad_form_Ax(void)
119+
{
120+
/* sin(Ax)^T Q sin(Ax) where Q is symmetric */
121+
double u_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
122+
123+
CSR_Matrix *A = new_csr_random(3, 4, 1.0);
124+
125+
/* Q = [1 2 0; 2 3 0; 0 0 4], given as CSR arrays: values Qx, column
   indices Qi, row pointers Qp */
126+
CSR_Matrix *Q = new_csr_matrix(3, 3, 5);
127+
double Qx[5] = {1.0, 2.0, 2.0, 3.0, 4.0};
128+
int Qi[5] = {0, 1, 0, 1, 2};
129+
int Qp[4] = {0, 2, 4, 5};
130+
memcpy(Q->x, Qx, 5 * sizeof(double));
131+
memcpy(Q->i, Qi, 5 * sizeof(int));
132+
memcpy(Q->p, Qp, 4 * sizeof(int));
133+
134+
/* NOTE(review): presumably new_variable(d1, d2, var_id, n_vars), i.e. a
   4x1 variable at offset 1 among 6 variables, matching u_vals[6] -
   confirm against new_variable */
expr *x = new_variable(4, 1, 1, 6);
135+
expr *Ax = new_left_matmul(x, A);
136+
expr *sin_Ax = new_sin(Ax);
137+
expr *node = new_quad_form(sin_Ax, Q);
138+
139+
/* NOTE(review): check_jacobian presumably compares the analytic Jacobian
   with a finite-difference estimate using step NUMERICAL_DIFF_DEFAULT_H */
mu_assert("check_jacobian failed",
140+
check_jacobian(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
141+
142+
/* the matrices are released separately from the expression tree */
free_expr(node);
143+
free_csr_matrix(A);
144+
free_csr_matrix(Q);
145+
return 0;
146+
}
147+
148+
/* Chain-rule Jacobian test for quad_form applied to exp(x).
   Returns 0 on success; mu_assert reports "check_jacobian failed" otherwise. */
const char *test_jacobian_quad_form_exp(void)
149+
{
150+
/* exp(x)^T Q exp(x) where Q is symmetric */
151+
double u_vals[3] = {0.5, 1.0, 1.5};
152+
153+
/* Q = [1 2 0; 2 3 0; 0 0 4], given as CSR arrays: values Qx, column
   indices Qi, row pointers Qp */
154+
CSR_Matrix *Q = new_csr_matrix(3, 3, 5);
155+
double Qx[5] = {1.0, 2.0, 2.0, 3.0, 4.0};
156+
int Qi[5] = {0, 1, 0, 1, 2};
157+
int Qp[4] = {0, 2, 4, 5};
158+
memcpy(Q->x, Qx, 5 * sizeof(double));
159+
memcpy(Q->i, Qi, 5 * sizeof(int));
160+
memcpy(Q->p, Qp, 4 * sizeof(int));
161+
162+
/* NOTE(review): presumably new_variable(d1, d2, var_id, n_vars), i.e. a
   3x1 variable at offset 0 among 3 variables, matching u_vals[3] -
   confirm against new_variable */
expr *x = new_variable(3, 1, 0, 3);
163+
expr *exp_x = new_exp(x);
164+
expr *node = new_quad_form(exp_x, Q);
165+
166+
/* NOTE(review): check_jacobian presumably compares the analytic Jacobian
   with a finite-difference estimate using step NUMERICAL_DIFF_DEFAULT_H */
mu_assert("check_jacobian failed",
166+
check_jacobian(node, u_vals, NUMERICAL_DIFF_DEFAULT_H));
168+
169+
/* Q is released separately from the expression tree */
free_expr(node);
170+
free_csr_matrix(Q);
171+
return 0;
172+
}

0 commit comments

Comments
 (0)