Skip to content

Commit 978aa03

Browse files
Transurgeonclaude
andcommitted
Store param values in CSR data order, simplify refresh to memcpy
- left_matmul: replace col-major loop in refresh_param_values with memcpy of nnz doubles (values now arrive in CSR data order) - right_matmul: pass AT->x directly to new_parameter(nnz, 1, ...), remove col-major round-trip allocation - test_param_prob: update theta arrays to CSR data order Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 36864d4 commit 978aa03

3 files changed

Lines changed: 12 additions & 26 deletions

File tree

src/bivariate/left_matmul.c

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,11 @@ static void refresh_param_values(left_matmul_expr *lin_node)
5757
{
5858
if (!lin_node->param_source) return;
5959

60-
const double *src = lin_node->param_source->value;
61-
int m = lin_node->A->m;
62-
CSR_Matrix *A = lin_node->A;
63-
64-
/* Fill A values from column-major source, following existing sparsity pattern */
65-
for (int row = 0; row < m; row++)
66-
for (int k = A->p[row]; k < A->p[row + 1]; k++)
67-
A->x[k] = src[row + A->i[k] * m];
60+
memcpy(lin_node->A->x, lin_node->param_source->value,
61+
lin_node->A->nnz * sizeof(double));
6862

6963
/* Recompute AT values from updated A */
70-
AT_fill_values(A, lin_node->AT, lin_node->base.iwork);
64+
AT_fill_values(lin_node->A, lin_node->AT, lin_node->base.iwork);
7165
}
7266

7367
static void forward(expr *node, const double *u)

src/bivariate/right_matmul.c

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,12 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A)
3333
int *work_transpose = (int *) malloc(A->n * sizeof(int));
3434
CSR_Matrix *AT = transpose(A, work_transpose);
3535

36-
/* Convert AT (CSR) to dense column-major array for parameter node */
37-
int m = AT->m; /* rows of AT = cols of A */
38-
int n = AT->n; /* cols of AT = rows of A */
39-
double *col_major = (double *) calloc(m * n, sizeof(double));
40-
for (int row = 0; row < m; row++)
41-
for (int k = AT->p[row]; k < AT->p[row + 1]; k++)
42-
col_major[row + AT->i[k] * m] = AT->x[k];
43-
36+
/* Parameter stores CSR data order (same as AT->x) */
4437
expr *u_transpose = new_transpose(u);
45-
expr *param_node = new_parameter(m, n, PARAM_FIXED, u->n_vars, col_major);
38+
expr *param_node = new_parameter(AT->nnz, 1, PARAM_FIXED, u->n_vars, AT->x);
4639
expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT);
4740
expr *node = new_transpose(left_matmul_node);
4841

49-
free(col_major);
5042
free_csr_matrix(AT);
5143
free(work_transpose);
5244
return node;

tests/problem/test_param_prob.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,14 +164,14 @@ const char *test_param_vector_mult_problem(void)
164164
* Test 3: left_param_matmul in constraint
165165
*
166166
* Problem: minimize sum(x), subject to A @ x, x size 2, A is 2x2
167-
* A is a 2x2 matrix parameter (param_id=0, size=4, column-major)
168-
* A = [[1,2],[3,4]] → column-major theta = [1,3,2,4]
167+
* A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order)
168+
* A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4]
169169
*
170170
* At x=[1,2]:
171171
* constraint_values = [1*1+2*2, 3*1+4*2] = [5, 11]
172172
* jacobian = [[1,2],[3,4]]
173173
*
174-
* After update A = [[5,6],[7,8]] → theta = [5,7,6,8]:
174+
* After update A = [[5,6],[7,8]] → theta = [5,6,7,8]:
175175
* constraint_values = [5*1+6*2, 7*1+8*2] = [17, 23]
176176
* jacobian = [[5,6],[7,8]]
177177
*/
@@ -208,8 +208,8 @@ const char *test_param_left_matmul_problem(void)
208208
problem_register_params(prob, param_nodes, 1);
209209
problem_init_derivatives(prob);
210210

211-
/* Set A = [[1,2],[3,4]], column-major: [1,3,2,4] */
212-
double theta[4] = {1.0, 3.0, 2.0, 4.0};
211+
/* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */
212+
double theta[4] = {1.0, 2.0, 3.0, 4.0};
213213
problem_update_params(prob, theta);
214214

215215
double u[2] = {1.0, 2.0};
@@ -235,8 +235,8 @@ const char *test_param_left_matmul_problem(void)
235235
double expected_x[4] = {1.0, 2.0, 3.0, 4.0};
236236
mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4));
237237

238-
/* Update A = [[5,6],[7,8]], column-major: [5,7,6,8] */
239-
double theta2[4] = {5.0, 7.0, 6.0, 8.0};
238+
/* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */
239+
double theta2[4] = {5.0, 6.0, 7.0, 8.0};
240240
problem_update_params(prob, theta2);
241241

242242
problem_constraint_forward(prob, u);

0 commit comments

Comments
 (0)