Skip to content

Commit a70fcf8

Browse files
authored
Matmul chain rule (#67)
* new infrastructure * add chain rule for jacobian + tests * add comment * chain rule for hessian first draft * split up into several functions * change order of two functions * better infrastructure * even better infrastructure * free' * refactor accumulator * minor * redo name change * run formatter
1 parent 0016174 commit a70fcf8

23 files changed

Lines changed: 1224 additions & 275 deletions

include/subexpr.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,28 @@ typedef struct right_matmul_expr
131131
CSC_Matrix *CSC_work;
132132
} right_matmul_expr;
133133

134+
/* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children
135+
* may be composite expressions. */
136+
typedef struct matmul_expr
137+
{
138+
expr base;
139+
/* Jacobian workspace */
140+
CSR_Matrix *term1_CSR; /* (Y^T kron I_m) @ J_f */
141+
CSR_Matrix *term2_CSR; /* (I_n kron X) @ J_g */
142+
143+
/* Hessian workspace (composite only) */
144+
CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */
145+
CSR_Matrix *BJg; /* B @ J_g */
146+
CSC_Matrix *BJg_CSC; /* BJg in CSC */
147+
int *BJg_csc_work; /* CSR-to-CSC workspace */
148+
CSR_Matrix *C; /* J_f^T @ B @ J_g */
149+
CSR_Matrix *CT; /* C^T */
150+
int *idx_map_C;
151+
int *idx_map_CT;
152+
int *idx_map_Hf;
153+
int *idx_map_Hg;
154+
} matmul_expr;
155+
134156
/* Constant scalar multiplication: y = a * child where a is a constant double */
135157
typedef struct const_scalar_mult_expr
136158
{

include/utils/CSR_sum.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
1414

1515
/* Compute sparsity pattern of A + B where A, B, C are CSR matrices.
1616
* Fills C->p, C->i, and C->nnz; does not touch C->x. */
17-
void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B,
18-
CSR_Matrix *C);
17+
void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
1918

2019
/* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i)
2120
* is already filled and matches the union of A and B per row. Does not modify
2221
* C->p, C->i, or C->nnz. */
23-
void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
24-
CSR_Matrix *C);
22+
void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
2523

2624
/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */
2725
void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C,

include/utils/dense_matrix.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef DENSE_MATRIX_H
2+
#define DENSE_MATRIX_H
3+
4+
#include "matrix.h"
5+
6+
/* Dense matrix (row-major) */
7+
typedef struct Dense_Matrix
8+
{
9+
Matrix base;
10+
double *x;
11+
double *work; /* scratch buffer, length n */
12+
} Dense_Matrix;
13+
14+
/* Constructors */
15+
Matrix *new_dense_matrix(int m, int n, const double *data);
16+
17+
/* Transpose helper */
18+
Matrix *dense_matrix_trans(const Dense_Matrix *self);
19+
20+
#endif /* DENSE_MATRIX_H */
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef LINALG_DENSE_SPARSE_H
2+
#define LINALG_DENSE_SPARSE_H
3+
4+
#include "CSC_Matrix.h"
5+
#include "CSR_Matrix.h"
6+
#include "matrix.h"
7+
8+
/* C = (I_p kron A) @ J via the polymorphic Matrix interface.
9+
* A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. */
10+
// TODO: maybe we can replace these with I_kron_X functionality?
11+
CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p);
12+
void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C);
13+
14+
/* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p,
15+
and C is (m*n) x p. Y is given in column-major dense format. */
16+
CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J);
17+
void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J,
18+
CSR_Matrix *C);
19+
20+
/* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense),
21+
J is (k*n) x p, and C is (m*n) x p. */
22+
CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J);
23+
void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J,
24+
CSR_Matrix *C);
25+
26+
#endif /* LINALG_DENSE_SPARSE_H */
Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#ifndef LINALG_H
22
#define LINALG_H
33

4-
/* Forward declarations */
5-
struct CSR_Matrix;
6-
struct CSC_Matrix;
4+
#include "CSC_Matrix.h"
5+
#include "CSR_Matrix.h"
76

87
/* Compute sparsity pattern and values for the matrix-matrix multiplication
98
C = (I_p kron A) @ J where A is m x n, J is (n*p) x k, and C is (m*p) x k,
@@ -15,32 +14,28 @@ struct CSC_Matrix;
1514
* Mathematically it corresponds to C = [A @ J1; A @ J2; ...; A @ Jp],
1615
where J = [J1; J2; ...; Jp]
1716
*/
18-
struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A,
19-
const struct CSC_Matrix *J,
20-
int p);
17+
CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A,
18+
const CSC_Matrix *J, int p);
2119

22-
void block_left_multiply_fill_values(const struct CSR_Matrix *A,
23-
const struct CSC_Matrix *J,
24-
struct CSC_Matrix *C);
20+
void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J,
21+
CSC_Matrix *C);
2522

2623
/* Compute y = kron(I_p, A) @ x where A is m x n and x is an (n*p)-length vector.
2724
The output y is an (m*p)-length vector corresponding to
2825
y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n
2926
elements.
3027
*/
31-
void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double *y,
32-
int p);
28+
void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, int p);
3329

3430
/* Fill values of C = A @ B where A is CSR, B is CSC.
3531
* C must have sparsity pattern already computed.
3632
*/
37-
void csr_csc_matmul_fill_values(const struct CSR_Matrix *A,
38-
const struct CSC_Matrix *B, struct CSR_Matrix *C);
33+
void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B,
34+
CSR_Matrix *C);
3935

4036
/* C = A @ B where A is CSR, B is CSC. Result C is CSR.
4137
* Allocates and precomputes sparsity pattern. No workspace required.
4238
*/
43-
struct CSR_Matrix *csr_csc_matmul_alloc(const struct CSR_Matrix *A,
44-
const struct CSC_Matrix *B);
39+
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B);
4540

4641
#endif /* LINALG_H */

include/utils/matrix.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,11 @@ typedef struct Sparse_Matrix
4141
CSR_Matrix *csr;
4242
} Sparse_Matrix;
4343

44-
/* Dense matrix (row-major) */
45-
typedef struct Dense_Matrix
46-
{
47-
Matrix base;
48-
double *x;
49-
double *work; /* scratch buffer, length n */
50-
} Dense_Matrix;
51-
5244
/* Constructors */
5345
Matrix *new_sparse_matrix(const CSR_Matrix *A);
54-
Matrix *new_dense_matrix(int m, int n, const double *data);
5546

56-
/* Transpose helpers */
47+
/* Transpose helper */
5748
Matrix *sparse_matrix_trans(const Sparse_Matrix *self, int *iwork);
58-
Matrix *dense_matrix_trans(const Dense_Matrix *self);
5949

6050
/* Free helper */
6151
static inline void free_matrix(Matrix *m)

include/utils/mini_numpy.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,14 @@ void scaled_ones(double *result, int size, double value);
2323
/* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */
2424
void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n);
2525

26+
/* Compute v = (Y kron I_m) @ w where Y is k x n (col-major), w has
27+
length m*n, and v has length m*k. Equivalently, reshape w as the
28+
m x n matrix W (col-major) and compute v = vec(W @ Y^T). */
29+
void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v);
30+
31+
/* Compute v = (I_n kron X^T) @ w where X is m x k (col-major), w has
32+
length m*n, and v has length k*n. Equivalently, reshape w as the
33+
m x n matrix W (col-major) and compute v = vec(X^T @ W). */
34+
void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v);
35+
2636
#endif /* MINI_NUMPY_H */

src/affine/add.c

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ static void jacobian_init_impl(expr *node)
4545
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);
4646

4747
/* fill sparsity pattern */
48-
sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian,
49-
node->jacobian);
48+
sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian);
5049
}
5150

5251
static void eval_jacobian(expr *node)
@@ -56,8 +55,7 @@ static void eval_jacobian(expr *node)
5655
node->right->eval_jacobian(node->right);
5756

5857
/* sum children's jacobians */
59-
sum_csr_matrices_fill_values(node->left->jacobian, node->right->jacobian,
60-
node->jacobian);
58+
sum_csr_fill_values(node->left->jacobian, node->right->jacobian, node->jacobian);
6159
}
6260

6361
static void wsum_hess_init_impl(expr *node)
@@ -71,8 +69,7 @@ static void wsum_hess_init_impl(expr *node)
7169
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);
7270

7371
/* fill sparsity pattern of hessian */
74-
sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess,
75-
node->wsum_hess);
72+
sum_csr_alloc(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);
7673
}
7774

7875
static void eval_wsum_hess(expr *node, const double *w)
@@ -82,8 +79,8 @@ static void eval_wsum_hess(expr *node, const double *w)
8279
node->right->eval_wsum_hess(node->right, w);
8380

8481
/* sum children's wsum_hess */
85-
sum_csr_matrices_fill_values(node->left->wsum_hess, node->right->wsum_hess,
86-
node->wsum_hess);
82+
sum_csr_fill_values(node->left->wsum_hess, node->right->wsum_hess,
83+
node->wsum_hess);
8784
}
8885

8986
static bool is_affine(const expr *node)

src/affine/hstack.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ static void wsum_hess_init_impl(expr *node)
124124
{
125125
expr *child = hnode->args[i];
126126
copy_csr_matrix(H, hnode->CSR_work);
127-
sum_csr_matrices_fill_sparsity(hnode->CSR_work, child->wsum_hess, H);
127+
sum_csr_alloc(hnode->CSR_work, child->wsum_hess, H);
128128
}
129129
}
130130

@@ -140,7 +140,7 @@ static void wsum_hess_eval(expr *node, const double *w)
140140
expr *child = hnode->args[i];
141141
child->eval_wsum_hess(child, w + row_offset);
142142
copy_csr_matrix(H, hnode->CSR_work);
143-
sum_csr_matrices_fill_values(hnode->CSR_work, child->wsum_hess, H);
143+
sum_csr_fill_values(hnode->CSR_work, child->wsum_hess, H);
144144
row_offset += child->size;
145145
}
146146
}

src/affine/left_matmul.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
*/
1818
#include "affine.h"
1919
#include "subexpr.h"
20-
#include "utils/matrix.h"
20+
#include "utils/dense_matrix.h"
2121
#include <assert.h>
2222
#include <stdio.h>
2323
#include <stdlib.h>

0 commit comments

Comments
 (0)