Skip to content

Commit e13ce04

Browse files
Transurgeonclaude
andcommitted
Simplify new_left_matmul: accept CSR directly, remove sparse/dense branching
Change signature from (param_node, child) to (param_node, child, A) where A is a const CSR_Matrix providing the sparsity pattern and initial values. The constructor copies A with new_csr() instead of rebuilding CSR from dense column-major values. Remove src_m/src_n fields from left_matmul_expr (read dimensions from A->m instead). Update all call sites and tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 71dddf1 commit e13ce04

File tree

9 files changed

+114
-54
lines changed

9 files changed

+114
-54
lines changed

include/bivariate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
3131
expr *new_matmul(expr *x, expr *y);
3232

3333
/* Left matrix multiplication: A @ f(x) where A comes from a parameter node */
34-
expr *new_left_matmul(expr *param_node, expr *child);
34+
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A);
3535

3636
/* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */
3737
expr *new_right_matmul(expr *u, const CSR_Matrix *A);

include/subexpr.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ typedef struct left_matmul_expr
127127
CSC_Matrix *J_CSC;
128128
int *csc_to_csr_workspace;
129129
expr *param_source; /* parameter node; A/AT values are refreshed from this */
130-
int src_m, src_n; /* original matrix dimensions */
131130
} left_matmul_expr;
132131

133132
/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.

src/bivariate/left_matmul.c

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
static void refresh_param_values(left_matmul_expr *lin_node)
5656
{
5757
const double *src = lin_node->param_source->value;
58-
int m = lin_node->src_m;
58+
int m = lin_node->A->m;
5959
CSR_Matrix *A = lin_node->A;
6060

6161
/* Fill A values from column-major source, following existing sparsity pattern */
@@ -163,10 +163,10 @@ static void eval_wsum_hess(expr *node, const double *w)
163163
node->wsum_hess->nnz * sizeof(double));
164164
}
165165

166-
expr *new_left_matmul(expr *param_node, expr *child)
166+
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
167167
{
168-
int A_m = param_node->d1;
169-
int A_n = param_node->d2;
168+
int A_m = A->m;
169+
int A_n = A->n;
170170

171171
/* Dimension logic: handle numpy broadcasting (1, n) as (n, ) */
172172
int d1, d2, n_blocks;
@@ -188,40 +188,6 @@ expr *new_left_matmul(expr *param_node, expr *child)
188188
exit(1);
189189
}
190190

191-
/* Build CSR from param_node's column-major values.
192-
* For fixed parameters (PARAM_FIXED), skip zeros to preserve sparsity.
193-
* For updatable parameters, build dense CSR since sparsity may change. */
194-
parameter_expr *pnode = (parameter_expr *) param_node;
195-
int sparse = (pnode->param_id == PARAM_FIXED);
196-
197-
int nnz = 0;
198-
if (sparse)
199-
{
200-
for (int row = 0; row < A_m; row++)
201-
for (int col = 0; col < A_n; col++)
202-
if (param_node->value[row + col * A_m] != 0.0) nnz++;
203-
}
204-
else
205-
{
206-
nnz = A_m * A_n;
207-
}
208-
209-
CSR_Matrix *A = new_csr_matrix(A_m, A_n, nnz);
210-
int idx = 0;
211-
for (int row = 0; row < A_m; row++)
212-
{
213-
A->p[row] = idx;
214-
for (int col = 0; col < A_n; col++)
215-
{
216-
double val = param_node->value[row + col * A_m];
217-
if (sparse && val == 0.0) continue;
218-
A->i[idx] = col;
219-
A->x[idx] = val;
220-
idx++;
221-
}
222-
}
223-
A->p[A_m] = idx;
224-
225191
/* Allocate the type-specific struct */
226192
left_matmul_expr *lin_node =
227193
(left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
@@ -235,12 +201,10 @@ expr *new_left_matmul(expr *param_node, expr *child)
235201
node->iwork = (int *) malloc(MAX(A_n, node->n_vars) * sizeof(int));
236202
lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int));
237203
lin_node->n_blocks = n_blocks;
238-
lin_node->A = A; /* transfer ownership */
239-
lin_node->AT = transpose(A, node->iwork);
204+
lin_node->A = new_csr(A);
205+
lin_node->AT = transpose(lin_node->A, node->iwork);
240206

241207
lin_node->param_source = param_node;
242-
lin_node->src_m = A_m;
243-
lin_node->src_n = A_n;
244208
expr_retain(param_node);
245209

246210
return node;

src/bivariate/right_matmul.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A)
4343

4444
expr *u_transpose = new_transpose(u);
4545
expr *param_node = new_parameter(m, n, PARAM_FIXED, u->n_vars, col_major);
46-
expr *left_matmul_node = new_left_matmul(param_node, u_transpose);
46+
expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT);
4747
expr *node = new_transpose(left_matmul_node);
4848

4949
free(col_major);

tests/jacobian_tests/test_left_matmul.h

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <math.h>
22
#include <stdio.h>
3+
#include <string.h>
34

45
#include "bivariate.h"
56
#include "elementwise_univariate.h"
@@ -36,7 +37,16 @@ const char *test_jacobian_left_matmul_log()
3637
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals);
3738

3839
expr *log_x = new_log(x);
39-
expr *A_log_x = new_left_matmul(A_param, log_x);
40+
CSR_Matrix *A_csr = new_csr_matrix(4, 3, 7);
41+
int A_p[5] = {0, 2, 4, 6, 7};
42+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
43+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
44+
memcpy(A_csr->p, A_p, 5 * sizeof(int));
45+
memcpy(A_csr->i, A_i, 7 * sizeof(int));
46+
memcpy(A_csr->x, A_x, 7 * sizeof(double));
47+
48+
expr *A_log_x = new_left_matmul(A_param, log_x, A_csr);
49+
free_csr_matrix(A_csr);
4050

4151
A_log_x->forward(A_log_x, x_vals);
4252
A_log_x->jacobian_init(A_log_x);
@@ -74,7 +84,16 @@ const char *test_jacobian_left_matmul_log_matrix()
7484
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals);
7585

7686
expr *log_x = new_log(x);
77-
expr *A_log_x = new_left_matmul(A_param, log_x);
87+
CSR_Matrix *A_csr = new_csr_matrix(4, 3, 7);
88+
int A_p[5] = {0, 2, 4, 6, 7};
89+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
90+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
91+
memcpy(A_csr->p, A_p, 5 * sizeof(int));
92+
memcpy(A_csr->i, A_i, 7 * sizeof(int));
93+
memcpy(A_csr->x, A_x, 7 * sizeof(double));
94+
95+
expr *A_log_x = new_left_matmul(A_param, log_x, A_csr);
96+
free_csr_matrix(A_csr);
7897

7998
A_log_x->forward(A_log_x, x_vals);
8099
A_log_x->jacobian_init(A_log_x);
@@ -140,7 +159,16 @@ const char *test_jacobian_left_matmul_log_composite()
140159

141160
expr *Bx = new_linear(x, B, NULL);
142161
expr *log_Bx = new_log(Bx);
143-
expr *A_log_Bx = new_left_matmul(A_param, log_Bx);
162+
CSR_Matrix *A_csr = new_csr_matrix(4, 3, 7);
163+
int A_p[5] = {0, 2, 4, 6, 7};
164+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
165+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
166+
memcpy(A_csr->p, A_p, 5 * sizeof(int));
167+
memcpy(A_csr->i, A_i, 7 * sizeof(int));
168+
memcpy(A_csr->x, A_x, 7 * sizeof(double));
169+
170+
expr *A_log_Bx = new_left_matmul(A_param, log_Bx, A_csr);
171+
free_csr_matrix(A_csr);
144172

145173
A_log_Bx->forward(A_log_Bx, x_vals);
146174
A_log_Bx->jacobian_init(A_log_Bx);

tests/jacobian_tests/test_transpose.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
#define TEST_TRANSPOSE_H
44

55
#include "affine.h"
6+
#include "bivariate.h"
67
#include "minunit.h"
78
#include "subexpr.h"
89
#include "test_helpers.h"
910
#include <math.h>
1011
#include <stdio.h>
12+
#include <string.h>
1113

1214
const char *test_jacobian_transpose()
1315
{
@@ -17,7 +19,17 @@ const char *test_jacobian_transpose()
1719

1820
// X = [1 2; 3 4] (columnwise: x = [1 3 2 4])
1921
expr *X = new_variable(2, 2, 0, 4);
20-
expr *AX = new_left_matmul(A_param, X);
22+
/* Build dense 2x2 CSR */
23+
CSR_Matrix *A_csr = new_csr_matrix(2, 2, 4);
24+
int Ap[3] = {0, 2, 4};
25+
int Ai[4] = {0, 1, 0, 1};
26+
double Ax[4] = {1.0, 2.0, 3.0, 4.0};
27+
memcpy(A_csr->p, Ap, 3 * sizeof(int));
28+
memcpy(A_csr->i, Ai, 4 * sizeof(int));
29+
memcpy(A_csr->x, Ax, 4 * sizeof(double));
30+
31+
expr *AX = new_left_matmul(A_param, X, A_csr);
32+
free_csr_matrix(A_csr);
2133
expr *transpose_AX = new_transpose(AX);
2234
double u[4] = {1, 3, 2, 4};
2335
transpose_AX->forward(transpose_AX, u);

tests/problem/test_param_prob.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <math.h>
55
#include <stdio.h>
6+
#include <string.h>
67

78
#include "affine.h"
89
#include "bivariate.h"
@@ -185,7 +186,17 @@ const char *test_param_left_matmul_problem(void)
185186
/* Constraint: A @ x */
186187
expr *x_con = new_variable(2, 1, 0, n_vars);
187188
expr *A_param = new_parameter(2, 2, 0, n_vars, NULL);
188-
expr *constraint = new_left_matmul(A_param, x_con);
189+
190+
CSR_Matrix *A_csr = new_csr_matrix(2, 2, 4);
191+
int Ap[3] = {0, 2, 4};
192+
int Ai[4] = {0, 1, 0, 1};
193+
double Ax[4] = {0.0, 0.0, 0.0, 0.0};
194+
memcpy(A_csr->p, Ap, 3 * sizeof(int));
195+
memcpy(A_csr->i, Ai, 4 * sizeof(int));
196+
memcpy(A_csr->x, Ax, 4 * sizeof(double));
197+
198+
expr *constraint = new_left_matmul(A_param, x_con, A_csr);
199+
free_csr_matrix(A_csr);
189200

190201
expr *constraints[1] = {constraint};
191202

tests/profiling/profile_left_matmul.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,26 @@ const char *profile_left_matmul()
2828
expr *A_param = new_parameter(n, n, PARAM_FIXED, n, A_vals);
2929
free(A_vals);
3030

31-
expr *AX = new_left_matmul(A_param, X);
31+
/* Build dense n x n CSR (all ones) */
32+
int nnz = n * n;
33+
CSR_Matrix *A_csr = new_csr_matrix(n, n, nnz);
34+
{
35+
int idx = 0;
36+
for (int row = 0; row < n; row++)
37+
{
38+
A_csr->p[row] = idx;
39+
for (int col = 0; col < n; col++)
40+
{
41+
A_csr->i[idx] = col;
42+
A_csr->x[idx] = 1.0;
43+
idx++;
44+
}
45+
}
46+
A_csr->p[n] = idx;
47+
}
48+
49+
expr *AX = new_left_matmul(A_param, X, A_csr);
50+
free_csr_matrix(A_csr);
3251

3352
double *x_vals = (double *) malloc(n * n * sizeof(double));
3453
for (int i = 0; i < n * n; i++)

tests/wsum_hess/test_left_matmul.h

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,16 @@ const char *test_wsum_hess_left_matmul()
5858
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 3, A_vals);
5959

6060
expr *log_x = new_log(x);
61-
expr *A_log_x = new_left_matmul(A_param, log_x);
61+
CSR_Matrix *A_csr = new_csr_matrix(4, 3, 7);
62+
int A_p[5] = {0, 2, 4, 6, 7};
63+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
64+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
65+
memcpy(A_csr->p, A_p, 5 * sizeof(int));
66+
memcpy(A_csr->i, A_i, 7 * sizeof(int));
67+
memcpy(A_csr->x, A_x, 7 * sizeof(double));
68+
69+
expr *A_log_x = new_left_matmul(A_param, log_x, A_csr);
70+
free_csr_matrix(A_csr);
6271

6372
A_log_x->forward(A_log_x, x_vals);
6473
A_log_x->jacobian_init(A_log_x);
@@ -151,7 +160,16 @@ const char *test_wsum_hess_left_matmul_composite()
151160

152161
expr *Bx = new_linear(x, B, NULL);
153162
expr *log_Bx = new_log(Bx);
154-
expr *A_log_Bx = new_left_matmul(A_param, log_Bx);
163+
CSR_Matrix *A_csr = new_csr_matrix(4, 3, 7);
164+
int A_p[5] = {0, 2, 4, 6, 7};
165+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
166+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
167+
memcpy(A_csr->p, A_p, 5 * sizeof(int));
168+
memcpy(A_csr->i, A_i, 7 * sizeof(int));
169+
memcpy(A_csr->x, A_x, 7 * sizeof(double));
170+
171+
expr *A_log_Bx = new_left_matmul(A_param, log_Bx, A_csr);
172+
free_csr_matrix(A_csr);
155173

156174
A_log_Bx->forward(A_log_Bx, x_vals);
157175
A_log_Bx->jacobian_init(A_log_Bx);
@@ -217,7 +235,16 @@ const char *test_wsum_hess_left_matmul_matrix()
217235
expr *A_param = new_parameter(4, 3, PARAM_FIXED, 6, A_vals);
218236

219237
expr *log_x = new_log(x);
220-
expr *A_log_x = new_left_matmul(A_param, log_x);
238+
CSR_Matrix *A_csr = new_csr_matrix(4, 3, 7);
239+
int A_p[5] = {0, 2, 4, 6, 7};
240+
int A_i[7] = {0, 2, 0, 2, 0, 2, 0};
241+
double A_x[7] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
242+
memcpy(A_csr->p, A_p, 5 * sizeof(int));
243+
memcpy(A_csr->i, A_i, 7 * sizeof(int));
244+
memcpy(A_csr->x, A_x, 7 * sizeof(double));
245+
246+
expr *A_log_x = new_left_matmul(A_param, log_x, A_csr);
247+
free_csr_matrix(A_csr);
221248

222249
A_log_x->forward(A_log_x, x_vals);
223250
A_log_x->jacobian_init(A_log_x);

0 commit comments

Comments
 (0)