Skip to content

Commit 7a5e56d

Browse files
Transurgeon and claude committed
Merge origin/main into parameter-support-v2
Resolve conflicts from quad-form chain rule PRs (#65, #66): - Keep parameter_expr struct from this branch - Update new_left_matmul calls in new tests to use 3-arg signature (NULL param_node) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2 parents ff5c33c + 0016174 commit 7a5e56d

40 files changed

Lines changed: 878 additions & 345 deletions

include/problem.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ typedef struct
3636
int nnz_affine;
3737
int nnz_nonlinear; /* jacobian of nonlinear constraints */
3838
int nnz_hessian;
39+
int n_vars;
40+
int total_constraint_size;
3941
} Diff_engine_stats;
4042

4143
typedef struct problem

include/subexpr.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,11 @@ typedef struct parameter_expr
4040
bool has_been_refreshed;
4141
} parameter_expr;
4242

43-
/* Linear operator: y = A * x + b */
43+
/* Linear operator: y = A * x + b
44+
* The matrix A is stored as node->jacobian (CSR). */
4445
typedef struct linear_op_expr
4546
{
4647
expr base;
47-
CSC_Matrix *A_csc;
48-
CSR_Matrix *A_csr;
4948
double *b; /* constant offset vector (NULL if no offset) */
5049
} linear_op_expr;
5150

@@ -61,6 +60,7 @@ typedef struct quad_form_expr
6160
{
6261
expr base;
6362
CSR_Matrix *Q;
63+
CSC_Matrix *QJf; /* Q * J_f in CSC (for chain rule hessian) */
6464
} quad_form_expr;
6565

6666
/* Sum reduction along an axis */
@@ -110,8 +110,12 @@ typedef struct hstack_expr
110110
typedef struct elementwise_mult_expr
111111
{
112112
expr base;
113-
CSR_Matrix *CSR_work1;
114-
CSR_Matrix *CSR_work2;
113+
CSR_Matrix *CSR_work1; /* C = Jg2^T diag(w) Jg1 */
114+
CSR_Matrix *CSR_work2; /* CT = C^T */
115+
int *idx_map_C; /* C[j] -> wsum_hess pos */
116+
int *idx_map_CT; /* CT[j] -> wsum_hess pos */
117+
int *idx_map_Hx; /* x->wsum_hess[j] -> pos */
118+
int *idx_map_Hy; /* y->wsum_hess[j] -> pos */
115119
} elementwise_mult_expr;
116120

117121
/* Left matrix multiplication: y = A * f(x) where f(x) is an expression. Note that

include/utils/CSC_Matrix.h

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,37 @@ CSC_Matrix *new_csc_matrix(int m, int n, int nnz);
2929
/* Free a CSC matrix */
3030
void free_csc_matrix(CSC_Matrix *matrix);
3131

32-
CSC_Matrix *csr_to_csc(const CSR_Matrix *A);
33-
34-
/* Allocate sparsity pattern for C = A^T D A for diagonal D */
32+
/* Fill sparsity of C = A^T D A for diagonal D */
3533
CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
3634

37-
/* Allocate sparsity pattern for C = B^T D A for diagonal D */
35+
/* Fill sparsity of C = B^T D A for diagonal D */
3836
CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B);
3937

40-
/* Compute values for C = A^T D A. C must have precomputed sparsity pattern */
38+
/* Fill sparsity of C = BA, where B is symmetric. */
39+
CSC_Matrix *symBA_alloc(const CSR_Matrix *B, const CSC_Matrix *A);
40+
41+
/* Compute values for C = A^T D A (null d corresponds to D as identity) */
4142
void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
4243

43-
/* Compute values for C = B^T D A. C must have precomputed sparsity pattern */
44+
/* Compute values for C = B^T D A (null d corresponds to D as identity) */
4445
void BTDA_fill_values(const CSC_Matrix *A, const CSC_Matrix *B, const double *d,
4546
CSR_Matrix *C);
4647

47-
/* C = z^T A where A is in CSC format and C is assumed to have one row.
48-
* C must have column indices pre-computed. Fills in values of C only.
49-
*/
50-
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);
48+
/* Fill values of C = BA. The matrix B does not have to be symmetric */
49+
void BA_fill_values(const CSR_Matrix *B, const CSC_Matrix *A, CSC_Matrix *C);
50+
51+
/* Fill values of C = x^T A. The sparsity pattern of C must already be filled. */
52+
void yTA_fill_values(const CSC_Matrix *A, const double *x, CSR_Matrix *C);
53+
54+
/* Count nonzero columns of a CSC matrix */
55+
int count_nonzero_cols_csc(const CSC_Matrix *A);
5156

52-
CSC_Matrix *csr_to_csc_fill_sparsity(const CSR_Matrix *A, int *iwork);
57+
/* convert from CSR to CSC format */
58+
CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork);
5359
void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork);
5460

55-
CSR_Matrix *csc_to_csr_fill_sparsity(const CSC_Matrix *A, int *iwork);
61+
/* convert from CSC to CSR format */
62+
CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork);
5663
void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork);
5764

5865
#endif /* CSC_MATRIX_H */

include/utils/CSR_sum.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,15 @@ void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
8686
int spacing, int *iwork,
8787
int *idx_map);
8888

89+
/* 4-way sorted merge of CSR matrices A, B, C, D (same dimensions).
90+
* Allocates and returns the output CSR with the union sparsity pattern.
91+
* Allocates and fills idx_maps[0..3] (one per input, size input->nnz
92+
* each) mapping each input entry to its position in the output.
93+
* Caller owns the returned CSR and all 4 idx_map arrays. */
94+
CSR_Matrix *sum_4_csr_fill_sparsity_and_idx_maps(const CSR_Matrix *A,
95+
const CSR_Matrix *B,
96+
const CSR_Matrix *C,
97+
const CSR_Matrix *D,
98+
int *idx_maps[4]);
99+
89100
#endif /* CSR_SUM_H */

src/affine/left_matmul.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,12 @@ static void jacobian_init_impl(expr *node)
111111

112112
/* initialize child's jacobian and precompute sparsity of its CSC */
113113
jacobian_init(x);
114-
lnode->Jchild_CSC = csr_to_csc_fill_sparsity(x->jacobian, node->work->iwork);
114+
lnode->Jchild_CSC = csr_to_csc_alloc(x->jacobian, node->work->iwork);
115115

116116
/* precompute sparsity of this node's jacobian in CSC and CSR */
117117
lnode->J_CSC = lnode->A->block_left_mult_sparsity(lnode->A, lnode->Jchild_CSC,
118118
lnode->n_blocks);
119-
node->jacobian = csc_to_csr_fill_sparsity(lnode->J_CSC, lnode->csc_to_csr_work);
119+
node->jacobian = csc_to_csr_alloc(lnode->J_CSC, lnode->csc_to_csr_work);
120120
}
121121

122122
static void eval_jacobian(expr *node)

src/affine/linear_op.c

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19+
#include "utils/CSR_Matrix.h"
1920
#include <assert.h>
2021
#include <stdlib.h>
2122
#include <string.h>
@@ -28,8 +29,8 @@ static void forward(expr *node, const double *u)
2829
/* child's forward pass */
2930
node->left->forward(node->left, u);
3031

31-
/* y = A * x */
32-
csr_matvec(lin_node->A_csr, x->value, node->value, x->var_id);
32+
/* y = A * x (A is stored as node->jacobian) */
33+
csr_matvec(node->jacobian, x->value, node->value, x->var_id);
3334

3435
/* y += b (if offset exists) */
3536
if (lin_node->b != NULL)
@@ -49,29 +50,17 @@ static bool is_affine(const expr *node)
4950
static void free_type_data(expr *node)
5051
{
5152
linear_op_expr *lin_node = (linear_op_expr *) node;
52-
/* memory pointed to by A_csr will be freed when the jacobian is freed,
53-
so if the jacobian is not null we must not free A_csr. */
54-
55-
if (!node->jacobian)
56-
{
57-
free_csr_matrix(lin_node->A_csr);
58-
}
59-
60-
free_csc_matrix(lin_node->A_csc);
61-
6253
if (lin_node->b != NULL)
6354
{
6455
free(lin_node->b);
6556
lin_node->b = NULL;
6657
}
67-
68-
lin_node->A_csr = NULL;
69-
lin_node->A_csc = NULL;
7058
}
7159

7260
static void jacobian_init_impl(expr *node)
7361
{
74-
node->jacobian = ((linear_op_expr *) node)->A_csr;
62+
/* jacobian is set at construction time — nothing to do */
63+
(void) node;
7564
}
7665

7766
static void eval_jacobian(expr *node)
@@ -80,21 +69,33 @@ static void eval_jacobian(expr *node)
8069
(void) node;
8170
}
8271

72+
static void wsum_hess_init_impl(expr *node)
73+
{
74+
/* Linear operator Hessian is always zero */
75+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
76+
}
77+
78+
static void eval_wsum_hess(expr *node, const double *w)
79+
{
80+
/* Linear operator Hessian is always zero - nothing to evaluate */
81+
(void) node;
82+
(void) w;
83+
}
84+
8385
expr *new_linear(expr *u, const CSR_Matrix *A, const double *b)
8486
{
8587
assert(u->d2 == 1);
8688
/* Allocate the type-specific struct */
8789
linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr));
8890
expr *node = &lin_node->base;
8991
init_expr(node, A->m, 1, u->n_vars, forward, jacobian_init_impl, eval_jacobian,
90-
is_affine, NULL, NULL, free_type_data);
92+
is_affine, wsum_hess_init_impl, eval_wsum_hess, free_type_data);
9193
node->left = u;
9294
expr_retain(u);
9395

94-
/* Initialize type-specific fields */
95-
lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz);
96-
copy_csr_matrix(A, lin_node->A_csr);
97-
lin_node->A_csc = csr_to_csc(A);
96+
/* Store A directly as the jacobian (linear op jacobian is constant) */
97+
node->jacobian = new_csr_matrix(A->m, A->n, A->nnz);
98+
copy_csr_matrix(A, node->jacobian);
9899

99100
/* Initialize offset (copy b if provided, otherwise NULL) */
100101
if (b != NULL)

0 commit comments

Comments
 (0)