Skip to content

Commit a7d8dfa

Browse files
committed
first draft of hessian chain rule
1 parent 5341802 commit a7d8dfa

6 files changed

Lines changed: 338 additions & 23 deletions

File tree

include/subexpr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ typedef struct quad_form_expr
4848
{
4949
expr base;
5050
CSR_Matrix *Q;
51+
CSC_Matrix *QJf; /* Q * J_f in CSC (for chain rule hessian) */
5152
} quad_form_expr;
5253

5354
/* Sum reduction along an axis */

include/utils/CSC_Matrix.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
3737
/* Allocate sparsity pattern for C = B^T D A for diagonal D */
3838
CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B);
3939

40-
/* Compute values for C = A^T D A. C must have precomputed sparsity pattern */
40+
/* Compute values for C = A^T D A. C must have precomputed sparsity pattern.
41+
* If d is NULL, D is treated as the identity (computes A^T A). */
4142
void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
4243

43-
/* Compute values for C = B^T D A. C must have precomputed sparsity pattern */
44+
/* Compute values for C = B^T D A. C must have precomputed sparsity pattern.
45+
* If d is NULL, D is treated as the identity (computes B^T A). */
4446
void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
4547
CSR_Matrix *C);
4648

@@ -49,6 +51,13 @@ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
4951
*/
5052
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);
5153

54+
/* Allocate B = Q * A (sparsity only). Q is CSR, A is CSC, B is CSC. */
55+
CSC_Matrix *csr_csc_multiply_fill_sparsity(const CSR_Matrix *Q, const CSC_Matrix *A);
56+
57+
/* Fill values of B = Q * A. B must have sparsity from above. */
58+
void csr_csc_multiply_fill_values(const CSR_Matrix *Q, const CSC_Matrix *A,
59+
CSC_Matrix *B);
60+
5261
/* Count nonzero columns of a CSC matrix */
5362
int count_nonzero_cols_csc(const CSC_Matrix *A);
5463

src/other/quad_form.c

Lines changed: 95 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "other.h"
22
#include "subexpr.h"
33
#include "utils/CSC_Matrix.h"
4+
#include "utils/CSR_sum.h"
5+
#include "utils/cblas_wrapper.h"
46
#include <assert.h>
57
#include <math.h>
68
#include <stdio.h>
@@ -115,33 +117,106 @@ static void wsum_hess_init_impl(expr *node)
115117
{
116118
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;
117119
expr *x = node->left;
118-
CSR_Matrix *H = new_csr_matrix(node->n_vars, node->n_vars, Q->nnz);
119120

120-
/* set global row pointers */
121-
memcpy(H->p + x->var_id, Q->p, (x->size + 1) * sizeof(int));
122-
for (int i = x->var_id + x->size + 1; i <= node->n_vars; i++)
121+
if (x->var_id != NOT_A_VARIABLE)
123122
{
123+
CSR_Matrix *H = new_csr_matrix(node->n_vars, node->n_vars, Q->nnz);
124124

125-
H->p[i] = Q->nnz;
126-
}
125+
/* set global row pointers */
126+
memcpy(H->p + x->var_id, Q->p, (x->size + 1) * sizeof(int));
127+
for (int i = x->var_id + x->size + 1; i <= node->n_vars; i++)
128+
{
129+
H->p[i] = Q->nnz;
130+
}
127131

128-
/* set global column indices */
129-
for (int i = 0; i < Q->nnz; i++)
130-
{
131-
H->i[i] = Q->i[i] + x->var_id;
132+
/* set global column indices */
133+
for (int i = 0; i < Q->nnz; i++)
134+
{
135+
H->i[i] = Q->i[i] + x->var_id;
136+
}
137+
138+
node->wsum_hess = H;
132139
}
140+
else
141+
{
142+
/* The hessian of h(x) = f(x)^T Q f(x) is 2 * (term1 + term2) where
133143
134-
node->wsum_hess = H;
144+
* term1 = J_f^T Q J_f
145+
* term2 = sum_i (Qf(x))_i nabla^2 f_i.
146+
147+
To compute term1, we first compute B = Q J_f and then compute term1
148+
= J_f^T B.
149+
*/
150+
151+
/* jacobian_csc_init(x) already called in jacobian_init */
152+
quad_form_expr *qnode = (quad_form_expr *) node;
153+
CSC_Matrix *Jf = x->work->jacobian_csc;
154+
155+
/* term1 = Jf^T Q Jf = Jf^T B */
156+
CSC_Matrix *B = csr_csc_multiply_fill_sparsity(Q, Jf);
157+
qnode->QJf = B;
158+
node->work->hess_term1 = BTA_alloc(Jf, B);
159+
160+
/* term2 = sum_i (Qf(x))_i nabla^2 f_i */
161+
wsum_hess_init(x);
162+
node->work->hess_term2 = new_csr_copy_sparsity(x->wsum_hess);
163+
164+
/* hess = term1 + term2 */
165+
int max_nnz = node->work->hess_term1->nnz + node->work->hess_term2->nnz;
166+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, max_nnz);
167+
sum_csr_matrices_fill_sparsity(node->work->hess_term1,
168+
node->work->hess_term2, node->wsum_hess);
169+
}
135170
}
136171

137172
static void eval_wsum_hess(expr *node, const double *w)
138173
{
139174
CSR_Matrix *Q = ((quad_form_expr *) node)->Q;
140-
double *H = node->wsum_hess->x;
175+
expr *x = node->left;
141176
double two_w = 2.0 * w[0];
142-
for (int i = 0; i < Q->nnz; i++)
177+
178+
if (x->var_id != NOT_A_VARIABLE)
143179
{
144-
H[i] = two_w * Q->x[i];
180+
/* TODO: do we want to compute this hessian only once (up to a scaling)?
181+
* Maybe unnecessary optimization. */
182+
double *H = node->wsum_hess->x;
183+
for (int i = 0; i < Q->nnz; i++)
184+
{
185+
H[i] = two_w * Q->x[i];
186+
}
187+
}
188+
else
189+
{
190+
/* fill the CSC representation of the Jacobian of the child */
191+
CSC_Matrix *Jf = x->work->jacobian_csc;
192+
if (!x->work->jacobian_csc_filled)
193+
{
194+
csr_to_csc_fill_values(x->jacobian, Jf, x->work->csc_work);
195+
196+
if (x->is_affine(x))
197+
{
198+
x->work->jacobian_csc_filled = true;
199+
}
200+
}
201+
202+
CSC_Matrix *QJf = ((quad_form_expr *) node)->QJf;
203+
CSR_Matrix *term1 = node->work->hess_term1;
204+
CSR_Matrix *term2 = node->work->hess_term2;
205+
206+
/* term1 = J_f^T Q J_f = J_f^T B */
207+
csr_csc_multiply_fill_values(Q, Jf, QJf);
208+
BTDA_fill_values(Jf, QJf, NULL, term1);
209+
210+
/* term2 */
211+
x->eval_wsum_hess(x, node->work->dwork);
212+
memcpy(term2->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
213+
214+
/* scale both terms by 2w */
215+
cblas_dscal(term1->nnz, two_w, term1->x, 1);
216+
cblas_dscal(term2->nnz, two_w, term2->x, 1);
217+
218+
/* sum the two terms */
219+
sum_csr_matrices_fill_values(term1, term2, node->wsum_hess);
145220
}
146221
}
147222

@@ -150,12 +225,17 @@ static void free_type_data(expr *node)
150225
quad_form_expr *qnode = (quad_form_expr *) node;
151226
free_csr_matrix(qnode->Q);
152227
qnode->Q = NULL;
228+
if (qnode->QJf != NULL)
229+
{
230+
free_csc_matrix(qnode->QJf);
231+
qnode->QJf = NULL;
232+
}
153233
}
154234

155235
static bool is_affine(const expr *node)
156236
{
157237
(void) node;
158-
/* TODO: it is affine if both children are constant */
238+
/* TODO: it is affine (constant) if both children are constant */
159239
return false;
160240
}
161241

src/utils/CSC_Matrix.c

Lines changed: 142 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,34 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A)
111111
return C;
112112
}
113113

114+
/* Unweighted inner product of two sparse vectors.
 *
 * Each vector is given as parallel arrays of values (a_x / b_x) and indices
 * (a_i / b_i) holding a_nnz / b_nnz stored entries. The two index arrays are
 * walked in lock-step, advancing whichever side currently has the smaller
 * index, so both must be sorted in ascending order for the result to be the
 * true dot product. Returns 0.0 when either vector has no stored entries. */
static inline double sparse_dot(const double *a_x, const int *a_i, int a_nnz,
                                const double *b_x, const int *b_i, int b_nnz)
{
    double acc = 0.0;
    int pa = 0;
    int pb = 0;

    for (;;)
    {
        /* stop as soon as either vector is exhausted */
        if (pa >= a_nnz || pb >= b_nnz)
        {
            break;
        }

        const int ia = a_i[pa];
        const int ib = b_i[pb];

        if (ia < ib)
        {
            /* index only present in a (so far): skip it */
            pa++;
            continue;
        }

        if (ib < ia)
        {
            /* index only present in b (so far): skip it */
            pb++;
            continue;
        }

        /* matching index: contributes to the product */
        acc += a_x[pa] * b_x[pb];
        pa++;
        pb++;
    }

    return acc;
}
141+
114142
static inline double sparse_wdot(const double *a_x, const int *a_i, int a_nnz,
115143
const double *b_x, const int *b_i, int b_nnz,
116144
const double *d)
@@ -158,9 +186,17 @@ void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C)
158186
int nnz_ai = A->p[ii + 1] - A->p[ii];
159187
int nnz_aj = A->p[j + 1] - A->p[j];
160188

161-
/* compute Cij = weighted inner product of column i and column j */
162-
double sum = sparse_wdot(A->x + A->p[ii], A->i + A->p[ii], nnz_ai,
163-
A->x + A->p[j], A->i + A->p[j], nnz_aj, d);
189+
double sum;
190+
if (d != NULL)
191+
{
192+
sum = sparse_wdot(A->x + A->p[ii], A->i + A->p[ii], nnz_ai,
193+
A->x + A->p[j], A->i + A->p[j], nnz_aj, d);
194+
}
195+
else
196+
{
197+
sum = sparse_dot(A->x + A->p[ii], A->i + A->p[ii], nnz_ai,
198+
A->x + A->p[j], A->i + A->p[j], nnz_aj);
199+
}
164200

165201
C->x[jj] = sum;
166202
}
@@ -443,15 +479,115 @@ void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
443479
int nnz_bi = B->p[i + 1] - B->p[i];
444480
int nnz_aj = A->p[j + 1] - A->p[j];
445481

446-
/* compute Cij = weighted inner product of col i of B and col j of A */
447-
double sum = sparse_wdot(B->x + B->p[i], B->i + B->p[i], nnz_bi,
448-
A->x + A->p[j], A->i + A->p[j], nnz_aj, d);
482+
double sum;
483+
if (d != NULL)
484+
{
485+
sum = sparse_wdot(B->x + B->p[i], B->i + B->p[i], nnz_bi,
486+
A->x + A->p[j], A->i + A->p[j], nnz_aj, d);
487+
}
488+
else
489+
{
490+
sum = sparse_dot(B->x + B->p[i], B->i + B->p[i], nnz_bi,
491+
A->x + A->p[j], A->i + A->p[j], nnz_aj);
492+
}
449493

450494
C->x[jj] = sum;
451495
}
452496
}
453497
}
454498

499+
CSC_Matrix *csr_csc_multiply_fill_sparsity(const CSR_Matrix *Q, const CSC_Matrix *A)
500+
{
501+
/* Allocate B = Q * A (sparsity only).
502+
* Q is CSR (m x m), A is CSC (m x n), B is CSC (m x n). */
503+
504+
int m = Q->m;
505+
int n = A->n;
506+
507+
int *marker = (int *) malloc(m * sizeof(int));
508+
memset(marker, -1, m * sizeof(int));
509+
510+
int *Bp = (int *) malloc((n + 1) * sizeof(int));
511+
iVec *Bi = iVec_new(A->nnz);
512+
Bp[0] = 0;
513+
514+
for (int j = 0; j < n; j++)
515+
{
516+
int col_nnz = 0;
517+
518+
for (int t = A->p[j]; t < A->p[j + 1]; t++)
519+
{
520+
int k = A->i[t];
521+
522+
for (int s = Q->p[k]; s < Q->p[k + 1]; s++)
523+
{
524+
int row = Q->i[s];
525+
if (marker[row] != j)
526+
{
527+
marker[row] = j;
528+
iVec_append(Bi, row);
529+
col_nnz++;
530+
}
531+
}
532+
}
533+
534+
Bp[j + 1] = Bp[j] + col_nnz;
535+
}
536+
537+
int total_nnz = Bp[n];
538+
CSC_Matrix *B = new_csc_matrix(m, n, total_nnz);
539+
memcpy(B->p, Bp, (n + 1) * sizeof(int));
540+
memcpy(B->i, Bi->data, total_nnz * sizeof(int));
541+
542+
free(marker);
543+
free(Bp);
544+
iVec_free(Bi);
545+
546+
return B;
547+
}
548+
549+
/* Fill the numerical values of B for the product allocated by
 * csr_csc_multiply_fill_sparsity. B->p and B->i are read, B->x is overwritten.
 *
 * For every column j, each stored A[k, j] scatters A[k, j] * Q[k, :] into the
 * stored entries of column j of B (which is Q * A when Q is symmetric). A
 * temporary row -> storage-slot map is used for the scatter, so the order of
 * the row indices inside a column of B does not matter here. Every row
 * reached through Q must be present in the pattern of B. */
void csr_csc_multiply_fill_values(const CSR_Matrix *Q, const CSC_Matrix *A,
                                  CSC_Matrix *B)
{
    const int n_rows = Q->m;
    const int n_cols = B->n;

    /* products are accumulated, so every stored entry starts at zero */
    memset(B->x, 0, B->nnz * sizeof(double));

    /* slot_of_row[r] = index into B->x of entry (r, current column), or -1 */
    int *slot_of_row = (int *) malloc(n_rows * sizeof(int));
    memset(slot_of_row, -1, n_rows * sizeof(int));

    for (int col = 0; col < n_cols; col++)
    {
        const int b_begin = B->p[col];
        const int b_end = B->p[col + 1];
        const int a_end = A->p[col + 1];

        /* register where each row of this column lives in B->x */
        for (int pos = b_begin; pos < b_end; pos++)
        {
            slot_of_row[B->i[pos]] = pos;
        }

        /* scatter A[k, col] * Q[k, :] into the column */
        for (int pos = A->p[col]; pos < a_end; pos++)
        {
            const int k = A->i[pos];
            const double a_kj = A->x[pos];
            const int q_end = Q->p[k + 1];

            for (int q = Q->p[k]; q < q_end; q++)
            {
                const int dest = slot_of_row[Q->i[q]];
                B->x[dest] += a_kj * Q->x[q];
            }
        }

        /* clear the map again before moving to the next column */
        for (int pos = b_begin; pos < b_end; pos++)
        {
            slot_of_row[B->i[pos]] = -1;
        }
    }

    free(slot_of_row);
}
590+
455591
int count_nonzero_cols_csc(const CSC_Matrix *A)
456592
{
457593
int count = 0;

tests/all_tests.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ int main(void)
276276
mu_run_test(test_wsum_hess_x_x_multiply, tests_run);
277277
mu_run_test(test_wsum_hess_AX_BX_multiply, tests_run);
278278
mu_run_test(test_wsum_hess_multiply_deep_composite, tests_run);
279+
mu_run_test(test_wsum_hess_quad_form_Ax, tests_run);
280+
mu_run_test(test_wsum_hess_quad_form_sin_Ax, tests_run);
281+
mu_run_test(test_wsum_hess_quad_form_exp, tests_run);
279282

280283
printf("\n--- Utility Tests ---\n");
281284
mu_run_test(test_cblas_ddot, tests_run);

0 commit comments

Comments
 (0)