Skip to content

Commit 4202921

Browse files
Transurgeon, claude, and dance858
authored
Fix column-major parameter ordering in parameterized matmul (#72)
* Fix column-major parameter ordering in parameterized matmul

  CVXPY sends parameter values in Fortran (column-major) order, but the matmul refresh functions assumed row-major/CSR order via raw memcpy. This produced incorrect matrix values for non-symmetric matrices.

  For sparse matrices, iterate the CSR pattern and index into the column-major source array. For dense matrices, exploit the fact that column-major A is row-major A^T to memcpy directly into AT, then transpose to get A.

  Also fixes a latent bug where sparse update_values would blindly copy the first nnz values from the full d1*d2 parameter array, which is wrong for matrices with structural zeros.

  Adds tests for rectangular (3x2) and sparse (3x3 with zeros) cases.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Run clang-format on changed files

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* introduce explicit transpose function for dense matrix
* clean up refresh dense right
* clean up tests...
* one more test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: dance858 <danielcederberg1@gmail.com>
1 parent d673b62 commit 4202921

9 files changed

Lines changed: 326 additions & 342 deletions

File tree

include/subexpr.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,9 @@ typedef struct parameter_expr
3737
{
3838
expr base;
3939
int param_id;
40-
bool has_been_refreshed;
40+
/* Set to true by problem_update_params(), cleared by
41+
refresh_param_values() after propagating new values. */
42+
bool needs_refresh;
4143
} parameter_expr;
4244

4345
/* Linear operator: y = A * x + b

include/utils/dense_matrix.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,4 +17,6 @@ Matrix *new_dense_matrix(int m, int n, const double *data);
1717
/* Transpose helper */
1818
Matrix *dense_matrix_trans(const Dense_Matrix *self);
1919

20+
void A_transpose(double *AT, const double *A, int m, int n);
21+
2022
#endif /* DENSE_MATRIX_H */

src/atoms/affine/left_matmul.c

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -56,11 +56,11 @@ static void refresh_param_values(left_matmul_expr *lnode)
5656
return;
5757
}
5858
parameter_expr *param = (parameter_expr *) lnode->param_source;
59-
if (param->has_been_refreshed)
59+
if (!param->needs_refresh)
6060
{
6161
return;
6262
}
63-
param->has_been_refreshed = true;
63+
param->needs_refresh = false;
6464
lnode->refresh_param_values(lnode);
6565
}
6666

@@ -168,28 +168,25 @@ static void eval_wsum_hess(expr *node, const double *w)
168168

169169
static void refresh_sparse_left(left_matmul_expr *lnode)
170170
{
171-
Sparse_Matrix *sm_A = (Sparse_Matrix *) lnode->A;
172-
Sparse_Matrix *sm_AT = (Sparse_Matrix *) lnode->AT;
173-
lnode->A->update_values(lnode->A, lnode->param_source->value);
174-
/* Recompute AT values from A */
175-
AT_fill_values(sm_A->csr, sm_AT->csr, lnode->base.work->iwork);
171+
(void) lnode;
172+
fprintf(stderr,
173+
"Error in refresh_sparse_left: parameter for a sparse matrix not "
174+
"supported \n");
175+
exit(1);
176176
}
177177

178178
static void refresh_dense_left(left_matmul_expr *lnode)
179179
{
180180
Dense_Matrix *dm_A = (Dense_Matrix *) lnode->A;
181+
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
181182
int m = dm_A->base.m;
182183
int n = dm_A->base.n;
183-
lnode->A->update_values(lnode->A, lnode->param_source->value);
184-
/* Recompute AT data (transpose of row-major A) */
185-
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
186-
for (int i = 0; i < m; i++)
187-
{
188-
for (int j = 0; j < n; j++)
189-
{
190-
dm_AT->x[j * m + i] = dm_A->x[i * n + j];
191-
}
192-
}
184+
185+
/* The parameter represents the A in left_matmul_dense(A, x) in column-major.
186+
In this diffengine, we store A in row-major order. Hence, param_source->value
187+
actually corresponds to the transpose of A, and we transpose AT to get A. */
188+
memcpy(dm_AT->x, lnode->param_source->value, m * n * sizeof(double));
189+
A_transpose(dm_A->x, dm_AT->x, n, m);
193190
}
194191

195192
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
@@ -243,6 +240,11 @@ expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
243240
lnode->param_source = param_node;
244241
if (param_node != NULL)
245242
{
243+
244+
fprintf(stderr, "Error in new_left_matmul: parameter for a sparse matrix "
245+
"not supported \n");
246+
exit(1);
247+
246248
expr_retain(param_node);
247249
lnode->refresh_param_values = refresh_sparse_left;
248250
}

src/atoms/affine/parameter.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
6565
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
6666

6767
pnode->param_id = param_id;
68-
pnode->has_been_refreshed = false;
68+
pnode->needs_refresh = false;
6969

7070
if (values != NULL)
7171
{

src/atoms/affine/right_matmul.c

Lines changed: 25 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
#include "utils/CSR_Matrix.h"
2121
#include "utils/dense_matrix.h"
2222
#include "utils/tracked_alloc.h"
23+
#include <stdio.h>
2324
#include <stdlib.h>
2425

2526
/* This file implements the atom 'right_matmul' corresponding to the operation y =
@@ -38,30 +39,27 @@
3839
So: update lnode->AT from param values, then recompute lnode->A. */
3940
static void refresh_sparse_right(left_matmul_expr *lnode)
4041
{
41-
Sparse_Matrix *sm_AT_inner = (Sparse_Matrix *) lnode->A;
42-
Sparse_Matrix *sm_A_inner = (Sparse_Matrix *) lnode->AT;
43-
/* lnode->AT holds the original A; update its values from param */
44-
lnode->AT->update_values(lnode->AT, lnode->param_source->value);
45-
/* Recompute A^T (lnode->A) from A (lnode->AT) */
46-
AT_fill_values(sm_A_inner->csr, sm_AT_inner->csr, lnode->base.work->iwork);
42+
(void) lnode;
43+
fprintf(stderr,
44+
"Error in refresh_sparse_right: parameter for a sparse matrix not "
45+
"supported \n");
46+
exit(1);
4747
}
4848

4949
static void refresh_dense_right(left_matmul_expr *lnode)
5050
{
51-
Dense_Matrix *dm_AT_inner = (Dense_Matrix *) lnode->A;
52-
Dense_Matrix *dm_A_inner = (Dense_Matrix *) lnode->AT;
53-
int m_orig = dm_A_inner->base.m; /* original A is m x n */
54-
int n_orig = dm_A_inner->base.n;
55-
/* Update original A (inner's AT) from param values */
56-
lnode->AT->update_values(lnode->AT, lnode->param_source->value);
57-
/* Recompute A^T (inner's A) from A */
58-
for (int i = 0; i < m_orig; i++)
59-
{
60-
for (int j = 0; j < n_orig; j++)
61-
{
62-
dm_AT_inner->x[j * m_orig + i] = dm_A_inner->x[i * n_orig + j];
63-
}
64-
}
51+
/* This left_matmul_expr node corresponds to left multiplication with B = AT,
52+
where A is the original (m x n) matrix given to the right_matmul function.
53+
Furthermore, lnode->param_source->value corresponds to the column-major
54+
version of A, which is BT (an m x n matrix) */
55+
56+
Dense_Matrix *B = (Dense_Matrix *) lnode->AT;
57+
Dense_Matrix *BT = (Dense_Matrix *) lnode->A;
58+
int m = B->base.n;
59+
int n = B->base.m;
60+
61+
memcpy(BT->x, lnode->param_source->value, m * n * sizeof(double));
62+
A_transpose(B->x, BT->x, m, n);
6563
}
6664

6765
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
@@ -78,6 +76,11 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
7876
left_matmul */
7977
if (param_node != NULL)
8078
{
79+
80+
fprintf(stderr, "Error in new_right_matmul: parameter for a sparse matrix "
81+
"not supported \n");
82+
exit(1);
83+
8184
left_matmul_expr *lnode = (left_matmul_expr *) left_matmul;
8285
lnode->param_source = param_node;
8386
expr_retain(param_node);
@@ -94,16 +97,9 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
9497
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
9598
const double *data)
9699
{
97-
/* We express: u @ A = (A^T @ u^T)^T
98-
A is m x n, so A^T is n x m. */
100+
/* We express: u @ A = (A^T @ u^T)^T. A is m x n, so A^T is n x m. */
99101
double *AT = (double *) SP_MALLOC(n * m * sizeof(double));
100-
for (int i = 0; i < m; i++)
101-
{
102-
for (int j = 0; j < n; j++)
103-
{
104-
AT[j * m + i] = data[i * n + j];
105-
}
106-
}
102+
A_transpose(AT, data, m, n);
107103

108104
expr *u_transpose = new_transpose(u);
109105
expr *left_matmul_node = new_left_matmul_dense(NULL, u_transpose, n, m, AT);

src/problem.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -382,7 +382,7 @@ void problem_update_params(problem *prob, const double *theta)
382382
if (param->param_id == PARAM_FIXED) continue;
383383
int offset = param->param_id;
384384
memcpy(pnode->value, theta + offset, pnode->size * sizeof(double));
385-
param->has_been_refreshed = false;
385+
param->needs_refresh = true;
386386
}
387387

388388
/* Force re-evaluation of affine Jacobians on next call */

src/utils/dense_matrix.c

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -78,15 +78,20 @@ Matrix *dense_matrix_trans(const Dense_Matrix *A)
7878
int n = A->base.n;
7979
double *AT_x = (double *) SP_MALLOC(m * n * sizeof(double));
8080

81+
A_transpose(AT_x, A->x, m, n);
82+
83+
Matrix *result = new_dense_matrix(n, m, AT_x);
84+
free(AT_x);
85+
return result;
86+
}
87+
88+
void A_transpose(double *AT, const double *A, int m, int n)
89+
{
8190
for (int i = 0; i < m; i++)
8291
{
8392
for (int j = 0; j < n; j++)
8493
{
85-
AT_x[j * m + i] = A->x[i * n + j];
94+
AT[j * m + i] = A[i * n + j];
8695
}
8796
}
88-
89-
Matrix *result = new_dense_matrix(n, m, AT_x);
90-
free(AT_x);
91-
return result;
9297
}

tests/all_tests.c

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -354,6 +354,8 @@ int main(void)
354354
mu_run_test(test_param_vector_mult_problem, tests_run);
355355
mu_run_test(test_param_left_matmul_problem, tests_run);
356356
mu_run_test(test_param_right_matmul_problem, tests_run);
357+
mu_run_test(test_param_left_matmul_rectangular, tests_run);
358+
mu_run_test(test_param_right_matmul_rectangular, tests_run);
357359
mu_run_test(test_param_fixed_skip_in_update, tests_run);
358360
#endif /* PROFILE_ONLY */
359361

0 commit comments

Comments
 (0)