Skip to content

Commit 36864d4

Browse files
Transurgeon and claude committed
Simplify new_left_matmul: accept CSR directly, remove sparse/dense branching
- Change signature to (expr *param_node, expr *child, const CSR_Matrix *A)
- Constructor copies A with new_csr() instead of rebuilding from dense values
- Remove src_m/src_n fields from left_matmul_expr (use A->m directly)
- Allow param_node=NULL for fixed constants (no-op in refresh_param_values)
- Update all tests to pass CSR directly; fixed-constant tests use NULL param

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 71dddf1 commit 36864d4

9 files changed

Lines changed: 116 additions & 84 deletions

File tree

include/bivariate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
3131
expr *new_matmul(expr *x, expr *y);
3232

3333
/* Left matrix multiplication: A @ f(x) where A comes from a parameter node */
34-
expr *new_left_matmul(expr *param_node, expr *child);
34+
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A);
3535

3636
/* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */
3737
expr *new_right_matmul(expr *u, const CSR_Matrix *A);

include/subexpr.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ typedef struct left_matmul_expr
127127
CSC_Matrix *J_CSC;
128128
int *csc_to_csr_workspace;
129129
expr *param_source; /* parameter node; A/AT values are refreshed from this */
130-
int src_m, src_n; /* original matrix dimensions */
131130
} left_matmul_expr;
132131

133132
/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.

src/bivariate/left_matmul.c

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@
5151
#include <string.h>
5252

5353
/* Refresh A and AT values from param_source.
54-
A is the small m x n matrix (NOT block-diagonal). */
54+
A is the small m x n matrix (NOT block-diagonal).
55+
No-op when param_source is NULL (fixed constant — values already in A). */
5556
static void refresh_param_values(left_matmul_expr *lin_node)
5657
{
58+
if (!lin_node->param_source) return;
59+
5760
const double *src = lin_node->param_source->value;
58-
int m = lin_node->src_m;
61+
int m = lin_node->A->m;
5962
CSR_Matrix *A = lin_node->A;
6063

6164
/* Fill A values from column-major source, following existing sparsity pattern */
@@ -163,10 +166,10 @@ static void eval_wsum_hess(expr *node, const double *w)
163166
node->wsum_hess->nnz * sizeof(double));
164167
}
165168

166-
expr *new_left_matmul(expr *param_node, expr *child)
169+
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
167170
{
168-
int A_m = param_node->d1;
169-
int A_n = param_node->d2;
171+
int A_m = A->m;
172+
int A_n = A->n;
170173

171174
/* Dimension logic: handle numpy broadcasting (1, n) as (n, ) */
172175
int d1, d2, n_blocks;
@@ -188,40 +191,6 @@ expr *new_left_matmul(expr *param_node, expr *child)
188191
exit(1);
189192
}
190193

191-
/* Build CSR from param_node's column-major values.
192-
* For fixed parameters (PARAM_FIXED), skip zeros to preserve sparsity.
193-
* For updatable parameters, build dense CSR since sparsity may change. */
194-
parameter_expr *pnode = (parameter_expr *) param_node;
195-
int sparse = (pnode->param_id == PARAM_FIXED);
196-
197-
int nnz = 0;
198-
if (sparse)
199-
{
200-
for (int row = 0; row < A_m; row++)
201-
for (int col = 0; col < A_n; col++)
202-
if (param_node->value[row + col * A_m] != 0.0) nnz++;
203-
}
204-
else
205-
{
206-
nnz = A_m * A_n;
207-
}
208-
209-
CSR_Matrix *A = new_csr_matrix(A_m, A_n, nnz);
210-
int idx = 0;
211-
for (int row = 0; row < A_m; row++)
212-
{
213-
A->p[row] = idx;
214-
for (int col = 0; col < A_n; col++)
215-
{
216-
double val = param_node->value[row + col * A_m];
217-
if (sparse && val == 0.0) continue;
218-
A->i[idx] = col;
219-
A->x[idx] = val;
220-
idx++;
221-
}
222-
}
223-
A->p[A_m] = idx;
224-
225194
/* Allocate the type-specific struct */
226195
left_matmul_expr *lin_node =
227196
(left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
@@ -235,13 +204,11 @@ expr *new_left_matmul(expr *param_node, expr *child)
235204
node->iwork = (int *) malloc(MAX(A_n, node->n_vars) * sizeof(int));
236205
lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int));
237206
lin_node->n_blocks = n_blocks;
238-
lin_node->A = A; /* transfer ownership */
239-
lin_node->AT = transpose(A, node->iwork);
207+
lin_node->A = new_csr(A);
208+
lin_node->AT = transpose(lin_node->A, node->iwork);
240209

241210
lin_node->param_source = param_node;
242-
lin_node->src_m = A_m;
243-
lin_node->src_n = A_n;
244-
expr_retain(param_node);
211+
if (param_node) expr_retain(param_node);
245212

246213
return node;
247214
}

src/bivariate/right_matmul.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A)
4343

4444
expr *u_transpose = new_transpose(u);
4545
expr *param_node = new_parameter(m, n, PARAM_FIXED, u->n_vars, col_major);
46-
expr *left_matmul_node = new_left_matmul(param_node, u_transpose);
46+
expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT);
4747
expr *node = new_transpose(left_matmul_node);
4848

4949
free(col_major);

tests/jacobian_tests/test_left_matmul.h

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <math.h>
22
#include <stdio.h>
3+
#include <string.h>
34

45
#include "bivariate.h"
56
#include "elementwise_univariate.h"
@@ -31,12 +32,18 @@ const char *test_jacobian_left_matmul_log()
3132
double x_vals[3] = {1.0, 2.0, 3.0};
3233
expr *x = new_variable(3, 1, 0, 3);
3334

34-
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */
35-
double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0};
36-
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals);
35+
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */
36+
CSR_Matrix *A = new_csr_matrix(4, 3, 7);
37+
int A_p[5] = {0, 2, 4, 6, 7};
38+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
39+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
40+
memcpy(A->p, A_p, 5 * sizeof(int));
41+
memcpy(A->i, A_i, 7 * sizeof(int));
42+
memcpy(A->x, A_x, 7 * sizeof(double));
3743

3844
expr *log_x = new_log(x);
39-
expr *A_log_x = new_left_matmul(A_param, log_x);
45+
expr *A_log_x = new_left_matmul(NULL, log_x, A);
46+
free_csr_matrix(A);
4047

4148
A_log_x->forward(A_log_x, x_vals);
4249
A_log_x->jacobian_init(A_log_x);
@@ -69,12 +76,18 @@ const char *test_jacobian_left_matmul_log_matrix()
6976
double x_vals[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
7077
expr *x = new_variable(3, 2, 0, 6);
7178

72-
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */
73-
double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0};
74-
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals);
79+
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */
80+
CSR_Matrix *A = new_csr_matrix(4, 3, 7);
81+
int A_p[5] = {0, 2, 4, 6, 7};
82+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
83+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
84+
memcpy(A->p, A_p, 5 * sizeof(int));
85+
memcpy(A->i, A_i, 7 * sizeof(int));
86+
memcpy(A->x, A_x, 7 * sizeof(double));
7587

7688
expr *log_x = new_log(x);
77-
expr *A_log_x = new_left_matmul(A_param, log_x);
89+
expr *A_log_x = new_left_matmul(NULL, log_x, A);
90+
free_csr_matrix(A);
7891

7992
A_log_x->forward(A_log_x, x_vals);
8093
A_log_x->jacobian_init(A_log_x);
@@ -134,13 +147,19 @@ const char *test_jacobian_left_matmul_log_composite()
134147
memcpy(B->i, B_i, 9 * sizeof(int));
135148
memcpy(B->x, B_x, 9 * sizeof(double));
136149

137-
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */
138-
double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0};
139-
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals);
150+
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */
151+
CSR_Matrix *A = new_csr_matrix(4, 3, 7);
152+
int A_p[5] = {0, 2, 4, 6, 7};
153+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
154+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
155+
memcpy(A->p, A_p, 5 * sizeof(int));
156+
memcpy(A->i, A_i, 7 * sizeof(int));
157+
memcpy(A->x, A_x, 7 * sizeof(double));
140158

141159
expr *Bx = new_linear(x, B, NULL);
142160
expr *log_Bx = new_log(Bx);
143-
expr *A_log_Bx = new_left_matmul(A_param, log_Bx);
161+
expr *A_log_Bx = new_left_matmul(NULL, log_Bx, A);
162+
free_csr_matrix(A);
144163

145164
A_log_Bx->forward(A_log_Bx, x_vals);
146165
A_log_Bx->jacobian_init(A_log_Bx);

tests/jacobian_tests/test_transpose.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,29 @@
33
#define TEST_TRANSPOSE_H
44

55
#include "affine.h"
6+
#include "bivariate.h"
67
#include "minunit.h"
78
#include "subexpr.h"
89
#include "test_helpers.h"
910
#include <math.h>
1011
#include <stdio.h>
12+
#include <string.h>
1113

1214
const char *test_jacobian_transpose()
1315
{
14-
/* A = [1 2; 3 4] in column-major order: [1, 3, 2, 4] */
15-
double A_vals[4] = {1.0, 3.0, 2.0, 4.0};
16-
expr *A_param = new_parameter(2, 2, PARAM_FIXED, 2, A_vals);
16+
/* A = [1 2; 3 4] as dense 2x2 CSR */
17+
CSR_Matrix *A = new_csr_matrix(2, 2, 4);
18+
int Ap[3] = {0, 2, 4};
19+
int Ai[4] = {0, 1, 0, 1};
20+
double Ax[4] = {1.0, 2.0, 3.0, 4.0};
21+
memcpy(A->p, Ap, 3 * sizeof(int));
22+
memcpy(A->i, Ai, 4 * sizeof(int));
23+
memcpy(A->x, Ax, 4 * sizeof(double));
1724

1825
// X = [1 2; 3 4] (columnwise: x = [1 3 2 4])
1926
expr *X = new_variable(2, 2, 0, 4);
20-
expr *AX = new_left_matmul(A_param, X);
27+
expr *AX = new_left_matmul(NULL, X, A);
28+
free_csr_matrix(A);
2129
expr *transpose_AX = new_transpose(AX);
2230
double u[4] = {1, 3, 2, 4};
2331
transpose_AX->forward(transpose_AX, u);

tests/problem/test_param_prob.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <math.h>
55
#include <stdio.h>
6+
#include <string.h>
67

78
#include "affine.h"
89
#include "bivariate.h"
@@ -185,7 +186,18 @@ const char *test_param_left_matmul_problem(void)
185186
/* Constraint: A @ x */
186187
expr *x_con = new_variable(2, 1, 0, n_vars);
187188
expr *A_param = new_parameter(2, 2, 0, n_vars, NULL);
188-
expr *constraint = new_left_matmul(A_param, x_con);
189+
190+
/* Dense 2x2 CSR with placeholder zeros (values refreshed from A_param) */
191+
CSR_Matrix *A = new_csr_matrix(2, 2, 4);
192+
int Ap[3] = {0, 2, 4};
193+
int Ai[4] = {0, 1, 0, 1};
194+
double Ax[4] = {0.0, 0.0, 0.0, 0.0};
195+
memcpy(A->p, Ap, 3 * sizeof(int));
196+
memcpy(A->i, Ai, 4 * sizeof(int));
197+
memcpy(A->x, Ax, 4 * sizeof(double));
198+
199+
expr *constraint = new_left_matmul(A_param, x_con, A);
200+
free_csr_matrix(A);
189201

190202
expr *constraints[1] = {constraint};
191203

tests/profiling/profile_left_matmul.h

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@ const char *profile_left_matmul()
1818
int n = 100;
1919
expr *X = new_variable(n, n, 0, n * n);
2020

21-
/* Create n x n parameter of all ones (column-major, but all ones so order
22-
* doesn't matter) */
23-
double *A_vals = (double *) malloc(n * n * sizeof(double));
24-
for (int i = 0; i < n * n; i++)
21+
/* Build dense n x n CSR (all ones) */
22+
int nnz = n * n;
23+
CSR_Matrix *A = new_csr_matrix(n, n, nnz);
2524
{
26-
A_vals[i] = 1.0;
25+
int idx = 0;
26+
for (int row = 0; row < n; row++)
27+
{
28+
A->p[row] = idx;
29+
for (int col = 0; col < n; col++)
30+
{
31+
A->i[idx] = col;
32+
A->x[idx] = 1.0;
33+
idx++;
34+
}
35+
}
36+
A->p[n] = idx;
2737
}
28-
expr *A_param = new_parameter(n, n, PARAM_FIXED, n, A_vals);
29-
free(A_vals);
3038

31-
expr *AX = new_left_matmul(A_param, X);
39+
expr *AX = new_left_matmul(NULL, X, A);
40+
free_csr_matrix(A);
3241

3342
double *x_vals = (double *) malloc(n * n * sizeof(double));
3443
for (int i = 0; i < n * n; i++)

tests/wsum_hess/test_left_matmul.h

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,18 @@ const char *test_wsum_hess_left_matmul()
5353

5454
expr *x = new_variable(3, 1, 0, 3);
5555

56-
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */
57-
double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0};
58-
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals);
56+
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */
57+
CSR_Matrix *A = new_csr_matrix(4, 3, 7);
58+
int A_p[5] = {0, 2, 4, 6, 7};
59+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
60+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
61+
memcpy(A->p, A_p, 5 * sizeof(int));
62+
memcpy(A->i, A_i, 7 * sizeof(int));
63+
memcpy(A->x, A_x, 7 * sizeof(double));
5964

6065
expr *log_x = new_log(x);
61-
expr *A_log_x = new_left_matmul(A_param, log_x);
66+
expr *A_log_x = new_left_matmul(NULL, log_x, A);
67+
free_csr_matrix(A);
6268

6369
A_log_x->forward(A_log_x, x_vals);
6470
A_log_x->jacobian_init(A_log_x);
@@ -145,13 +151,19 @@ const char *test_wsum_hess_left_matmul_composite()
145151
memcpy(B->i, B_i, 9 * sizeof(int));
146152
memcpy(B->x, B_x, 9 * sizeof(double));
147153

148-
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */
149-
double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0};
150-
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals);
154+
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */
155+
CSR_Matrix *A = new_csr_matrix(4, 3, 7);
156+
int A_p[5] = {0, 2, 4, 6, 7};
157+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
158+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
159+
memcpy(A->p, A_p, 5 * sizeof(int));
160+
memcpy(A->i, A_i, 7 * sizeof(int));
161+
memcpy(A->x, A_x, 7 * sizeof(double));
151162

152163
expr *Bx = new_linear(x, B, NULL);
153164
expr *log_Bx = new_log(Bx);
154-
expr *A_log_Bx = new_left_matmul(A_param, log_Bx);
165+
expr *A_log_Bx = new_left_matmul(NULL, log_Bx, A);
166+
free_csr_matrix(A);
155167

156168
A_log_Bx->forward(A_log_Bx, x_vals);
157169
A_log_Bx->jacobian_init(A_log_Bx);
@@ -212,12 +224,18 @@ const char *test_wsum_hess_left_matmul_matrix()
212224

213225
expr *x = new_variable(3, 2, 0, 6);
214226

215-
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] in column-major order */
216-
double A_vals[12] = {1.0, 3.0, 5.0, 7.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 6.0, 0.0};
217-
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals);
227+
/* A is 4x3: [1, 0, 2; 3, 0, 4; 5, 0, 6; 7, 0, 0] */
228+
CSR_Matrix *A = new_csr_matrix(4, 3, 7);
229+
int A_p[5] = {0, 2, 4, 6, 7};
230+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
231+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
232+
memcpy(A->p, A_p, 5 * sizeof(int));
233+
memcpy(A->i, A_i, 7 * sizeof(int));
234+
memcpy(A->x, A_x, 7 * sizeof(double));
218235

219236
expr *log_x = new_log(x);
220-
expr *A_log_x = new_left_matmul(A_param, log_x);
237+
expr *A_log_x = new_left_matmul(NULL, log_x, A);
238+
free_csr_matrix(A);
221239

222240
A_log_x->forward(A_log_x, x_vals);
223241
A_log_x->jacobian_init(A_log_x);

0 commit comments

Comments
 (0)