Skip to content

Commit 608e390

Browse files
Transurgeonclaude
andcommitted
Merge origin/main into parameter-support-v2
Resolve conflicts from matmul chain rule PR (#67): - subexpr.h: keep both vector_mult_expr and new matmul_expr/const types - dense_matrix.c: adopt shared I_kron_A functions, keep update_values - Add missing dense_matrix.h include in right_matmul.c - Update new matmul tests to use 3-arg left_matmul signature Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2 parents 7a5e56d + a70fcf8 commit 608e390

24 files changed

Lines changed: 1239 additions & 275 deletions

include/subexpr.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,42 @@ typedef struct vector_mult_expr
149149
expr *param_source;
150150
} vector_mult_expr;
151151

152+
/* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children
153+
* may be composite expressions. */
154+
typedef struct matmul_expr
155+
{
156+
expr base;
157+
/* Jacobian workspace */
158+
CSR_Matrix *term1_CSR; /* (Y^T x I_m) @ J_f */
159+
CSR_Matrix *term2_CSR; /* (I_n x X) @ J_g */
160+
161+
/* Hessian workspace (composite only) */
162+
CSR_Matrix *B; /* cross-Hessian B(w), mk x kn */
163+
CSR_Matrix *BJg; /* B @ J_g */
164+
CSC_Matrix *BJg_CSC; /* BJg in CSC */
165+
int *BJg_csc_work; /* CSR-to-CSC workspace */
166+
CSR_Matrix *C; /* J_f^T @ B @ J_g */
167+
CSR_Matrix *CT; /* C^T */
168+
int *idx_map_C;
169+
int *idx_map_CT;
170+
int *idx_map_Hf;
171+
int *idx_map_Hg;
172+
} matmul_expr;
173+
174+
/* Constant scalar multiplication: y = a * child where a is a constant double */
175+
typedef struct const_scalar_mult_expr
176+
{
177+
expr base;
178+
double a;
179+
} const_scalar_mult_expr;
180+
181+
/* Constant vector elementwise multiplication: y = a \circ child for constant a */
182+
typedef struct const_vector_mult_expr
183+
{
184+
expr base;
185+
double *a; /* length equals node->size */
186+
} const_vector_mult_expr;
187+
152188
/* Index/slicing: y = child[indices] where indices is a list of flat positions */
153189
typedef struct index_expr
154190
{

include/utils/CSR_sum.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
1414

1515
/* Compute sparsity pattern of A + B where A, B, C are CSR matrices.
1616
* Fills C->p, C->i, and C->nnz; does not touch C->x. */
17-
void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B,
18-
CSR_Matrix *C);
17+
void sum_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
1918

2019
/* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i)
2120
* is already filled and matches the union of A and B per row. Does not modify
2221
* C->p, C->i, or C->nnz. */
23-
void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
24-
CSR_Matrix *C);
22+
void sum_csr_fill_values(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
2523

2624
/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */
2725
void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C,

include/utils/dense_matrix.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef DENSE_MATRIX_H
2+
#define DENSE_MATRIX_H
3+
4+
#include "matrix.h"
5+
6+
/* Dense matrix (row-major) */
7+
typedef struct Dense_Matrix
8+
{
9+
Matrix base;
10+
double *x;
11+
double *work; /* scratch buffer, length n */
12+
} Dense_Matrix;
13+
14+
/* Constructors */
15+
Matrix *new_dense_matrix(int m, int n, const double *data);
16+
17+
/* Transpose helper */
18+
Matrix *dense_matrix_trans(const Dense_Matrix *self);
19+
20+
#endif /* DENSE_MATRIX_H */
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef LINALG_DENSE_SPARSE_H
2+
#define LINALG_DENSE_SPARSE_H
3+
4+
#include "CSC_Matrix.h"
5+
#include "CSR_Matrix.h"
6+
#include "matrix.h"
7+
8+
/* C = (I_p kron A) @ J via the polymorphic Matrix interface.
9+
* A is dense m x n, J is (n*p) x k in CSC, C is (m*p) x k in CSC. */
10+
// TODO: maybe we can replace these with I_kron_X functionality?
11+
CSC_Matrix *I_kron_A_alloc(const Matrix *A, const CSC_Matrix *J, int p);
12+
void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C);
13+
14+
/* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p,
15+
and C is (m*n) x p. Y is given in column-major dense format. */
16+
CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J);
17+
void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J,
18+
CSR_Matrix *C);
19+
20+
/* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense),
21+
J is (k*n) x p, and C is (m*n) x p. */
22+
CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J);
23+
void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J,
24+
CSR_Matrix *C);
25+
26+
#endif /* LINALG_DENSE_SPARSE_H */
Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#ifndef LINALG_H
22
#define LINALG_H
33

4-
/* Forward declarations */
5-
struct CSR_Matrix;
6-
struct CSC_Matrix;
4+
#include "CSC_Matrix.h"
5+
#include "CSR_Matrix.h"
76

87
/* Compute sparsity pattern and values for the matrix-matrix multiplication
98
C = (I_p kron A) @ J where A is m x n, J is (n*p) x k, and C is (m*p) x k,
@@ -15,32 +14,28 @@ struct CSC_Matrix;
1514
* Mathematically it corresponds to C = [A @ J1; A @ J2; ...; A @ Jp],
1615
where J = [J1; J2; ...; Jp]
1716
*/
18-
struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A,
19-
const struct CSC_Matrix *J,
20-
int p);
17+
CSC_Matrix *block_left_multiply_fill_sparsity(const CSR_Matrix *A,
18+
const CSC_Matrix *J, int p);
2119

22-
void block_left_multiply_fill_values(const struct CSR_Matrix *A,
23-
const struct CSC_Matrix *J,
24-
struct CSC_Matrix *C);
20+
void block_left_multiply_fill_values(const CSR_Matrix *A, const CSC_Matrix *J,
21+
CSC_Matrix *C);
2522

2623
/* Compute y = kron(I_p, A) @ x where A is m x n and x is(n*p)-length vector.
2724
The output y is m*p-length vector corresponding to
2825
y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n
2926
elements.
3027
*/
31-
void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double *y,
32-
int p);
28+
void block_left_multiply_vec(const CSR_Matrix *A, const double *x, double *y, int p);
3329

3430
/* Fill values of C = A @ B where A is CSR, B is CSC.
3531
* C must have sparsity pattern already computed.
3632
*/
37-
void csr_csc_matmul_fill_values(const struct CSR_Matrix *A,
38-
const struct CSC_Matrix *B, struct CSR_Matrix *C);
33+
void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B,
34+
CSR_Matrix *C);
3935

4036
/* C = A @ B where A is CSR, B is CSC. Result C is CSR.
4137
* Allocates and precomputes sparsity pattern. No workspace required.
4238
*/
43-
struct CSR_Matrix *csr_csc_matmul_alloc(const struct CSR_Matrix *A,
44-
const struct CSC_Matrix *B);
39+
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B);
4540

4641
#endif /* LINALG_H */

include/utils/matrix.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,11 @@ typedef struct Sparse_Matrix
4242
CSR_Matrix *csr;
4343
} Sparse_Matrix;
4444

45-
/* Dense matrix (row-major) */
46-
typedef struct Dense_Matrix
47-
{
48-
Matrix base;
49-
double *x;
50-
double *work; /* scratch buffer, length n */
51-
} Dense_Matrix;
52-
5345
/* Constructors */
5446
Matrix *new_sparse_matrix(const CSR_Matrix *A);
55-
Matrix *new_dense_matrix(int m, int n, const double *data);
5647

57-
/* Transpose helpers */
48+
/* Transpose helper */
5849
Matrix *sparse_matrix_trans(const Sparse_Matrix *self, int *iwork);
59-
Matrix *dense_matrix_trans(const Dense_Matrix *self);
6050

6151
/* Free helper */
6252
static inline void free_matrix(Matrix *m)

include/utils/mini_numpy.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,14 @@ void scaled_ones(double *result, int size, double value);
2323
/* Naive implementation of Z = X @ Y, X is m x k, Y is k x n, Z is m x n */
2424
void mat_mat_mult(const double *X, const double *Y, double *Z, int m, int k, int n);
2525

26+
/* Compute v = (Y kron I_m) @ w where Y is k x n (col-major), w has
27+
length m*n, and v has length m*k. Equivalently, reshape w as the
28+
m x n matrix W (col-major) and compute v = vec(W @ Y^T). */
29+
void Y_kron_I_vec(int m, int k, int n, const double *Y, const double *w, double *v);
30+
31+
/* Compute v = (I_n kron X^T) @ w where X is m x k (col-major), w has
32+
length m*n, and v has length k*n. Equivalently, reshape w as the
33+
m x n matrix W (col-major) and compute v = vec(X^T @ W). */
34+
void I_kron_XT_vec(int m, int k, int n, const double *X, const double *w, double *v);
35+
2636
#endif /* MINI_NUMPY_H */

src/affine/add.c

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ static void jacobian_init_impl(expr *node)
4545
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);
4646

4747
/* fill sparsity pattern */
48-
sum_csr_matrices_fill_sparsity(node->left->jacobian, node->right->jacobian,
49-
node->jacobian);
48+
sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian);
5049
}
5150

5251
static void eval_jacobian(expr *node)
@@ -56,8 +55,7 @@ static void eval_jacobian(expr *node)
5655
node->right->eval_jacobian(node->right);
5756

5857
/* sum children's jacobians */
59-
sum_csr_matrices_fill_values(node->left->jacobian, node->right->jacobian,
60-
node->jacobian);
58+
sum_csr_fill_values(node->left->jacobian, node->right->jacobian, node->jacobian);
6159
}
6260

6361
static void wsum_hess_init_impl(expr *node)
@@ -71,8 +69,7 @@ static void wsum_hess_init_impl(expr *node)
7169
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);
7270

7371
/* fill sparsity pattern of hessian */
74-
sum_csr_matrices_fill_sparsity(node->left->wsum_hess, node->right->wsum_hess,
75-
node->wsum_hess);
72+
sum_csr_alloc(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);
7673
}
7774

7875
static void eval_wsum_hess(expr *node, const double *w)
@@ -82,8 +79,8 @@ static void eval_wsum_hess(expr *node, const double *w)
8279
node->right->eval_wsum_hess(node->right, w);
8380

8481
/* sum children's wsum_hess */
85-
sum_csr_matrices_fill_values(node->left->wsum_hess, node->right->wsum_hess,
86-
node->wsum_hess);
82+
sum_csr_fill_values(node->left->wsum_hess, node->right->wsum_hess,
83+
node->wsum_hess);
8784
}
8885

8986
static bool is_affine(const expr *node)

src/affine/hstack.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ static void wsum_hess_init_impl(expr *node)
124124
{
125125
expr *child = hnode->args[i];
126126
copy_csr_matrix(H, hnode->CSR_work);
127-
sum_csr_matrices_fill_sparsity(hnode->CSR_work, child->wsum_hess, H);
127+
sum_csr_alloc(hnode->CSR_work, child->wsum_hess, H);
128128
}
129129
}
130130

@@ -140,7 +140,7 @@ static void wsum_hess_eval(expr *node, const double *w)
140140
expr *child = hnode->args[i];
141141
child->eval_wsum_hess(child, w + row_offset);
142142
copy_csr_matrix(H, hnode->CSR_work);
143-
sum_csr_matrices_fill_values(hnode->CSR_work, child->wsum_hess, H);
143+
sum_csr_fill_values(hnode->CSR_work, child->wsum_hess, H);
144144
row_offset += child->size;
145145
}
146146
}

src/affine/left_matmul.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
*/
1818
#include "affine.h"
1919
#include "subexpr.h"
20-
#include "utils/matrix.h"
20+
#include "utils/dense_matrix.h"
2121
#include <assert.h>
2222
#include <stdio.h>
2323
#include <stdlib.h>

0 commit comments

Comments
 (0)