Skip to content

Commit e32030a

Browse files
Transurgeonclaudedance858
authored
[WIP] Unify constants and parameters into single parameter_expr type (#53)
* Unify constants and parameters into single parameter_expr type Constants and updatable parameters are now represented by a single parameter_expr node. Constants use param_id == PARAM_FIXED (-1), while updatable parameters use param_id >= 0. Bivariate ops (left_matmul, right_matmul, scalar_mult, vector_mult) accept an optional param_source node and refresh their internal data on forward/jacobian/hessian evaluation. Adds problem_register_params and problem_update_params for problem-level parameter management. Also adds update_values to the Matrix interface for parameter refresh. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Address review comments on parameter-support-v2 - Remove redundant refresh_param_values calls from eval_jacobian and eval_wsum_hess in left_matmul (forward always runs first) - Use memcpy in problem_register_params for pointer array copy - Add PARAM_FIXED guard in problem_update_params to skip fixed constants - Remove unused right_matmul_expr struct from subexpr.h - Add test_param_fixed_skip_in_update covering mixed fixed/updatable params - Add CLAUDE.md for Claude Code guidance Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Run clang-format and remove CLAUDE.md Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Set has_been_refreshed to false in parameter constructor Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Remove CLAUDE.md from PR Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * remove unused constructors * run formatter and push it up * Update includes in test_param_prob.h to use atoms/ prefix Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Fix column-major parameter ordering in parameterized matmul (#72) * Fix column-major parameter ordering in parameterized matmul CVXPY sends parameter values in Fortran (column-major) order, but the matmul refresh functions assumed row-major/CSR order via raw memcpy. This produced incorrect matrix values for non-symmetric matrices. For sparse matrices, iterate the CSR pattern and index into the column-major source array. For dense matrices, exploit the fact that column-major A is row-major A^T to memcpy directly into AT, then transpose to get A. Also fixes a latent bug where sparse update_values would blindly copy the first nnz values from the full d1*d2 parameter array, which is wrong for matrices with structural zeros. Adds tests for rectangular (3x2) and sparse (3x3 with zeros) cases. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Run clang-format on changed files Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * introduce explicit transpose function for dense matrix * clean up refresh dense right * clean up tests... * one more test --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: dance858 <danielcederberg1@gmail.com> * add test that fails with current appraoch * new parameter abstraction that fixess the fialing test * remove dead code * reset csc * add raise error to catch bugs * add initial attempt for fixing parameters and broadcasting (#73) * add initial attempt for fixing parameters and broadcasting * add test for params with broadcast * cleanup test to fit more with other tests style * some progress on supporting backwards compatible constants * add some parameter broadcast tests as well * cleanup left matmul as well * some very minor cleanups * some error checks and numerical diff to tests * we don't always have to run forward of parameter in left matmul * we don't always have to call forward for parameter node in vector mult * comment out forward parameter pass in scalar mult because it is not needed, I think * add test for scalar case to be consistent with vector mult --------- Co-authored-by: dance858 <danielcederberg1@gmail.com> * error message and format --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: dance858 <danielcederberg1@gmail.com>
1 parent deb7cab commit e32030a

42 files changed

Lines changed: 1308 additions & 239 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/atoms/affine.h

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ expr *new_vstack(expr **args, int n_args, int n_vars);
3131
expr *new_promote(expr *child, int d1, int d2);
3232
expr *new_trace(expr *child);
3333

34-
expr *new_constant(int d1, int d2, int n_vars, const double *values);
34+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values);
3535
expr *new_variable(int d1, int d2, int var_id, int n_vars);
3636

3737
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
@@ -40,26 +40,27 @@ expr *new_broadcast(expr *child, int target_d1, int target_d2);
4040
expr *new_diag_vec(expr *child);
4141
expr *new_transpose(expr *child);
4242

43-
/* Left matrix multiplication: A @ f(x) where A is a constant sparse
44-
* matrix */
45-
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
43+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
44+
* sparse matrix. param_node is NULL for fixed constants. */
45+
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
4646

47-
/* Left matrix multiplication: A @ f(x) where A is a constant dense
48-
* matrix (row-major, m x n). Uses CBLAS for efficient computation. */
49-
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data);
47+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
48+
* dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */
49+
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
50+
const double *data);
5051

51-
/* Right matrix multiplication: f(x) @ A where A is a constant
52-
* matrix */
53-
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
52+
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
53+
* matrix. */
54+
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
5455

55-
expr *new_right_matmul_dense(expr *u, int m, int n, const double *data);
56+
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
57+
const double *data);
5658

57-
/* Constant scalar multiplication: a * f(x) where a is a constant
58-
* double */
59-
expr *new_const_scalar_mult(double a, expr *child);
59+
/* Scalar multiplication: a * f(x) where a comes from param_node */
60+
expr *new_scalar_mult(expr *param_node, expr *child);
6061

61-
/* Constant vector elementwise multiplication: a . f(x) where a is
62-
* constant */
63-
expr *new_const_vector_mult(const double *a, expr *child);
62+
/* Vector elementwise multiplication: a . f(x) where a comes from
63+
* param_node */
64+
expr *new_vector_mult(expr *param_node, expr *child);
6465

6566
#endif /* AFFINE_H */

include/expr.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ typedef struct expr
8686
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
8787
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
8888
Expr_Work *work; /* derivative workspace */
89+
/* Set to true on all nodes by problem_update_params() via
90+
expr_set_needs_refresh(). Atoms that cache parameter data
91+
(e.g. left_matmul_dense) check this flag before their forward
92+
pass: if true, they refresh their cached matrices from
93+
param_source->value and clear the flag to false. */
94+
bool needs_parameter_refresh;
8995

9096
// name of node just for debugging - should be removed later
9197
char name[32];
@@ -108,6 +114,9 @@ void wsum_hess_init(expr *node);
108114
* Must be called after jacobian_init. */
109115
void jacobian_csc_init(expr *node);
110116

117+
/* Recursively set needs_parameter_refresh on node and all children */
118+
void expr_set_needs_refresh(expr *node);
119+
111120
/* Reference counting helpers */
112121
void expr_retain(expr *node);
113122

include/problem.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ typedef struct problem
4949
int n_vars;
5050
int total_constraint_size;
5151

52+
/* parameter support */
53+
expr **param_nodes;
54+
int n_param_nodes;
55+
int total_parameter_size;
56+
5257
/* allocated by new_problem */
5358
double *constraint_values;
5459
double *gradient_values;
@@ -79,6 +84,9 @@ void problem_init_jacobian_coo(problem *prob);
7984
void problem_init_hessian_coo_lower_triangular(problem *prob);
8085
void free_problem(problem *prob);
8186

87+
void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes);
88+
void problem_update_params(problem *prob, const double *theta);
89+
8290
double problem_objective_forward(problem *prob, const double *u);
8391
void problem_constraint_forward(problem *prob, const double *u);
8492
void problem_gradient(problem *prob);

include/subexpr.h

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,19 @@
2626
/* Forward declaration */
2727
struct int_double_pair;
2828

29+
/* Parameter ID for fixed constants (not updatable) */
30+
#define PARAM_FIXED -1
31+
2932
/* Type-specific expression structures that "inherit" from expr */
3033

34+
/* Unified constant/parameter node. Constants use param_id == PARAM_FIXED.
35+
* Updatable parameters use param_id >= 0 (offset into global theta). */
36+
typedef struct parameter_expr
37+
{
38+
expr base;
39+
int param_id;
40+
} parameter_expr;
41+
3142
/* Linear operator: y = A * x + b
3243
* The matrix A is stored as node->jacobian (CSR). */
3344
typedef struct linear_op_expr
@@ -118,18 +129,24 @@ typedef struct left_matmul_expr
118129
CSC_Matrix *Jchild_CSC;
119130
CSC_Matrix *J_CSC;
120131
int *csc_to_csr_work;
132+
expr *param_source;
133+
void (*refresh_param_values)(struct left_matmul_expr *);
121134
} left_matmul_expr;
122135

123-
/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.
124-
* f(x) has shape p x n, A has shape n x q, output y has shape p x q.
125-
* Uses vec(y) = B * vec(f(x)) where B = A^T kron I_p. */
126-
typedef struct right_matmul_expr
136+
/* Scalar multiplication: y = a * child where a comes from param_source */
137+
typedef struct scalar_mult_expr
138+
{
139+
expr base;
140+
expr *param_source;
141+
} scalar_mult_expr;
142+
143+
/* Vector elementwise multiplication: y = a \circ child where a comes from
144+
* param_source */
145+
typedef struct vector_mult_expr
127146
{
128147
expr base;
129-
CSR_Matrix *B; /* B = A^T kron I_p */
130-
CSR_Matrix *BT; /* B^T for backpropagating Hessian weights */
131-
CSC_Matrix *CSC_work;
132-
} right_matmul_expr;
148+
expr *param_source;
149+
} vector_mult_expr;
133150

134151
/* Bivariate matrix multiplication: Z = f(u) @ g(u) where both children
135152
* may be composite expressions. */
@@ -153,20 +170,6 @@ typedef struct matmul_expr
153170
int *idx_map_Hg;
154171
} matmul_expr;
155172

156-
/* Constant scalar multiplication: y = a * child where a is a constant double */
157-
typedef struct const_scalar_mult_expr
158-
{
159-
expr base;
160-
double a;
161-
} const_scalar_mult_expr;
162-
163-
/* Constant vector elementwise multiplication: y = a \circ child for constant a */
164-
typedef struct const_vector_mult_expr
165-
{
166-
expr base;
167-
double *a; /* length equals node->size */
168-
} const_vector_mult_expr;
169-
170173
/* Index/slicing: y = child[indices] where indices is a list of flat positions */
171174
typedef struct index_expr
172175
{

include/utils/dense_matrix.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ Matrix *new_dense_matrix(int m, int n, const double *data);
1717
/* Transpose helper */
1818
Matrix *dense_matrix_trans(const Dense_Matrix *self);
1919

20+
void A_transpose(double *AT, const double *A, int m, int n);
21+
2022
#endif /* DENSE_MATRIX_H */

include/utils/matrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ typedef struct Matrix
3131
const CSC_Matrix *J, int p);
3232
void (*block_left_mult_values)(const struct Matrix *self, const CSC_Matrix *J,
3333
CSC_Matrix *C);
34+
void (*update_values)(struct Matrix *self, const double *new_values);
3435
void (*free_fn)(struct Matrix *self);
3536
} Matrix;
3637

src/atoms/affine/left_matmul.c

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,37 @@
4949
#include "utils/tracked_alloc.h"
5050
#include "utils/utils.h"
5151

52+
static void refresh_param_values(left_matmul_expr *lnode)
53+
{
54+
if (lnode->param_source == NULL || !lnode->base.needs_parameter_refresh)
55+
{
56+
return;
57+
}
58+
59+
lnode->base.needs_parameter_refresh = false;
60+
lnode->refresh_param_values(lnode);
61+
}
62+
5263
static void forward(expr *node, const double *u)
5364
{
65+
left_matmul_expr *lnode = (left_matmul_expr *) node;
66+
67+
/* call forward on param_source if it exists and needs refresh */
68+
if (lnode->param_source != NULL && lnode->base.needs_parameter_refresh)
69+
{
70+
lnode->param_source->forward(lnode->param_source, NULL);
71+
}
72+
73+
refresh_param_values(lnode);
74+
5475
expr *x = node->left;
5576

5677
/* child's forward pass */
5778
node->left->forward(node->left, u);
5879

5980
/* y = A_kron @ vec(f(x)) */
60-
Matrix *A = ((left_matmul_expr *) node)->A;
61-
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
81+
Matrix *A = lnode->A;
82+
int n_blocks = lnode->n_blocks;
6283
A->block_left_mult_vec(A, x->value, node->value, n_blocks);
6384
}
6485

@@ -75,11 +96,16 @@ static void free_type_data(expr *node)
7596
free_csc_matrix(lnode->Jchild_CSC);
7697
free_csc_matrix(lnode->J_CSC);
7798
free(lnode->csc_to_csr_work);
99+
if (lnode->param_source != NULL)
100+
{
101+
free_expr(lnode->param_source);
102+
}
78103
lnode->A = NULL;
79104
lnode->AT = NULL;
80105
lnode->Jchild_CSC = NULL;
81106
lnode->J_CSC = NULL;
82107
lnode->csc_to_csr_work = NULL;
108+
lnode->param_source = NULL;
83109
}
84110

85111
static void jacobian_init_impl(expr *node)
@@ -99,8 +125,8 @@ static void jacobian_init_impl(expr *node)
99125

100126
static void eval_jacobian(expr *node)
101127
{
102-
expr *x = node->left;
103128
left_matmul_expr *lnode = (left_matmul_expr *) node;
129+
expr *x = node->left;
104130

105131
CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC;
106132
CSC_Matrix *J_CSC = lnode->J_CSC;
@@ -131,17 +157,33 @@ static void wsum_hess_init_impl(expr *node)
131157

132158
static void eval_wsum_hess(expr *node, const double *w)
133159
{
160+
left_matmul_expr *lnode = (left_matmul_expr *) node;
161+
134162
/* compute A^T w*/
135-
Matrix *AT = ((left_matmul_expr *) node)->AT;
136-
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
163+
Matrix *AT = lnode->AT;
164+
int n_blocks = lnode->n_blocks;
137165
AT->block_left_mult_vec(AT, w, node->work->dwork, n_blocks);
138166

139167
node->left->eval_wsum_hess(node->left, node->work->dwork);
140168
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,
141169
node->wsum_hess->nnz * sizeof(double));
142170
}
143171

144-
expr *new_left_matmul(expr *u, const CSR_Matrix *A)
172+
static void refresh_dense_left(left_matmul_expr *lnode)
173+
{
174+
Dense_Matrix *dm_A = (Dense_Matrix *) lnode->A;
175+
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
176+
int m = dm_A->base.m;
177+
int n = dm_A->base.n;
178+
179+
/* The parameter represents the A in left_matmul_dense(A, x) in column-major.
180+
In this diffengine, we store A in row-major order. Hence, param->vals
181+
actually corresponds to the transpose of A, and we transpose AT to get A. */
182+
memcpy(dm_AT->x, lnode->param_source->value, m * n * sizeof(double));
183+
A_transpose(dm_A->x, dm_AT->x, n, m);
184+
}
185+
186+
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
145187
{
146188
/* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
147189
to do A @ u where u is (n, ) which in C is actually (1, n). In that case
@@ -188,10 +230,20 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
188230
lnode->AT =
189231
sparse_matrix_trans((const Sparse_Matrix *) lnode->A, node->work->iwork);
190232

233+
/* parameter support */
234+
lnode->param_source = param_node;
235+
if (param_node != NULL)
236+
{
237+
fprintf(stderr, "Error in new_left_matmul: parameter for a sparse matrix "
238+
"not supported \n");
239+
exit(1);
240+
}
241+
191242
return node;
192243
}
193244

194-
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
245+
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
246+
const double *data)
195247
{
196248
int d1, d2, n_blocks;
197249
if (u->d1 == n)
@@ -227,5 +279,13 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
227279
lnode->A = new_dense_matrix(m, n, data);
228280
lnode->AT = dense_matrix_trans((const Dense_Matrix *) lnode->A);
229281

282+
/* parameter support */
283+
lnode->param_source = param_node;
284+
if (param_node != NULL)
285+
{
286+
expr_retain(param_node);
287+
lnode->refresh_param_values = refresh_dense_left;
288+
}
289+
230290
return node;
231291
}
Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,38 @@
1616
* limitations under the License.
1717
*/
1818
#include "atoms/affine.h"
19+
#include "subexpr.h"
1920
#include "utils/tracked_alloc.h"
21+
#include <stdio.h>
2022
#include <stdlib.h>
2123
#include <string.h>
2224

2325
static void forward(expr *node, const double *u)
2426
{
25-
/* Constants don't depend on u; values are already set */
27+
/* Parameters/constants don't depend on u; values are already set */
2628
(void) node;
2729
(void) u;
2830
}
2931

3032
static void jacobian_init_impl(expr *node)
3133
{
32-
/* Constant jacobian is all zeros: size x n_vars with 0 nonzeros.
33-
* new_csr_matrix uses calloc for row pointers, so they're already 0. */
34+
/* Zero jacobian: size x n_vars with 0 nonzeros. */
3435
node->jacobian = new_csr_matrix(node->size, node->n_vars, 0);
3536
}
3637

3738
static void eval_jacobian(expr *node)
3839
{
39-
/* Constant jacobian never changes - nothing to evaluate */
4040
(void) node;
4141
}
4242

4343
static void wsum_hess_init_impl(expr *node)
4444
{
45-
/* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */
45+
/* Zero Hessian: n_vars x n_vars with 0 nonzeros. */
4646
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
4747
}
4848

4949
static void eval_wsum_hess(expr *node, const double *w)
5050
{
51-
/* Constant Hessian is always zero - nothing to compute */
5251
(void) node;
5352
(void) w;
5453
}
@@ -59,12 +58,22 @@ static bool is_affine(const expr *node)
5958
return true;
6059
}
6160

62-
expr *new_constant(int d1, int d2, int n_vars, const double *values)
61+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values)
6362
{
64-
expr *node = (expr *) SP_CALLOC(1, sizeof(expr));
63+
parameter_expr *pnode = (parameter_expr *) SP_CALLOC(1, sizeof(parameter_expr));
64+
expr *node = &pnode->base;
6565
init_expr(node, d1, d2, n_vars, forward, jacobian_init_impl, eval_jacobian,
6666
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
67-
memcpy(node->value, values, node->size * sizeof(double));
6867

68+
// TODO we should assert that the values array has the correct size.
69+
pnode->param_id = param_id;
70+
71+
if (values == NULL)
72+
{
73+
fprintf(stderr, "Parameter values should always be set, this is a bug and"
74+
" should be reported\n");
75+
exit(1);
76+
}
77+
memcpy(node->value, values, node->size * sizeof(double));
6978
return node;
7079
}

0 commit comments

Comments
 (0)