Skip to content

Commit fe8cf7b

Browse files
Transurgeonclaude
andcommitted
Unify constants and parameters into single parameter_expr type
Constants and updatable parameters are now represented by a single parameter_expr node. Constants use param_id == PARAM_FIXED (-1), while updatable parameters use param_id >= 0. Bivariate ops (left_matmul, right_matmul, scalar_mult, vector_mult) accept an optional param_source node and refresh their internal data on forward/jacobian/hessian evaluation. Adds problem_register_params and problem_update_params for problem-level parameter management. Also adds update_values to the Matrix interface for parameter refresh. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7e7875d commit fe8cf7b

33 files changed

Lines changed: 723 additions & 160 deletions

include/affine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ expr *new_hstack(expr **args, int n_args, int n_vars);
3232
expr *new_promote(expr *child, int d1, int d2);
3333
expr *new_trace(expr *child);
3434

35-
expr *new_constant(int d1, int d2, int n_vars, const double *values);
35+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values);
3636
expr *new_variable(int d1, int d2, int var_id, int n_vars);
3737

3838
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);

include/bivariate.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,26 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
3030
/* Matrix multiplication: Z = X @ Y */
3131
expr *new_matmul(expr *x, expr *y);
3232

33-
/* Left matrix multiplication: A @ f(x) where A is a constant sparse matrix */
34-
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
33+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
34+
* sparse matrix. param_node is NULL for fixed constants. */
35+
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
3536

36-
/* Left matrix multiplication: A @ f(x) where A is a constant dense matrix
37-
* (row-major, m x n). Uses CBLAS for efficient computation. */
38-
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data);
37+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
38+
* dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */
39+
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
40+
const double *data);
3941

40-
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
41-
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
42+
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
43+
* matrix. */
44+
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
4245

43-
expr *new_right_matmul_dense(expr *u, int m, int n, const double *data);
46+
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
47+
const double *data);
4448

45-
/* Constant scalar multiplication: a * f(x) where a is a constant double */
46-
expr *new_const_scalar_mult(double a, expr *child);
49+
/* Scalar multiplication: a * f(x) where a comes from param_node */
50+
expr *new_scalar_mult(expr *param_node, expr *child);
4751

48-
/* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */
49-
expr *new_const_vector_mult(const double *a, expr *child);
52+
/* Vector elementwise multiplication: a ∘ f(x) where a comes from param_node */
53+
expr *new_vector_mult(expr *param_node, expr *child);
5054

5155
#endif /* BIVARIATE_H */

include/problem.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ typedef struct problem
4646
int n_vars;
4747
int total_constraint_size;
4848

49+
/* parameter support */
50+
expr **param_nodes;
51+
int n_param_nodes;
52+
int total_parameter_size;
53+
4954
/* allocated by new_problem */
5055
double *constraint_values;
5156
double *gradient_values;
@@ -76,6 +81,9 @@ void problem_init_jacobian_coo(problem *prob);
7681
void problem_init_hessian_coo_lower_triangular(problem *prob);
7782
void free_problem(problem *prob);
7883

84+
void problem_register_params(problem *prob, expr **param_nodes, int n_param_nodes);
85+
void problem_update_params(problem *prob, const double *theta);
86+
7987
double problem_objective_forward(problem *prob, const double *u);
8088
void problem_constraint_forward(problem *prob, const double *u);
8189
void problem_gradient(problem *prob);

include/subexpr.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,20 @@
2626
/* Forward declaration */
2727
struct int_double_pair;
2828

29+
/* Parameter ID for fixed constants (not updatable) */
30+
#define PARAM_FIXED -1
31+
2932
/* Type-specific expression structures that "inherit" from expr */
3033

34+
/* Unified constant/parameter node. Constants use param_id == PARAM_FIXED.
35+
* Updatable parameters use param_id >= 0 (offset into global theta). */
36+
typedef struct parameter_expr
37+
{
38+
expr base;
39+
int param_id;
40+
bool has_been_refreshed;
41+
} parameter_expr;
42+
3143
/* Linear operator: y = A * x + b */
3244
typedef struct linear_op_expr
3345
{
@@ -114,6 +126,8 @@ typedef struct left_matmul_expr
114126
CSC_Matrix *Jchild_CSC;
115127
CSC_Matrix *J_CSC;
116128
int *csc_to_csr_work;
129+
expr *param_source;
130+
void (*refresh_param_values)(struct left_matmul_expr *);
117131
} left_matmul_expr;
118132

119133
/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.
@@ -127,19 +141,20 @@ typedef struct right_matmul_expr
127141
CSC_Matrix *CSC_work;
128142
} right_matmul_expr;
129143

130-
/* Constant scalar multiplication: y = a * child where a is a constant double */
131-
typedef struct const_scalar_mult_expr
144+
/* Scalar multiplication: y = a * child where a comes from param_source */
145+
typedef struct scalar_mult_expr
132146
{
133147
expr base;
134-
double a;
135-
} const_scalar_mult_expr;
148+
expr *param_source;
149+
} scalar_mult_expr;
136150

137-
/* Constant vector elementwise multiplication: y = a \circ child for constant a */
138-
typedef struct const_vector_mult_expr
151+
/* Vector elementwise multiplication: y = a \circ child where a comes from
152+
* param_source */
153+
typedef struct vector_mult_expr
139154
{
140155
expr base;
141-
double *a; /* length equals node->size */
142-
} const_vector_mult_expr;
156+
expr *param_source;
157+
} vector_mult_expr;
143158

144159
/* Index/slicing: y = child[indices] where indices is a list of flat positions */
145160
typedef struct index_expr

include/utils/matrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ typedef struct Matrix
3131
const CSC_Matrix *J, int p);
3232
void (*block_left_mult_values)(const struct Matrix *self, const CSC_Matrix *J,
3333
CSC_Matrix *C);
34+
void (*update_values)(struct Matrix *self, const double *new_values);
3435
void (*free_fn)(struct Matrix *self);
3536
} Matrix;
3637

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,36 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19+
#include "subexpr.h"
1920
#include <stdlib.h>
2021
#include <string.h>
2122

2223
static void forward(expr *node, const double *u)
2324
{
24-
/* Constants don't depend on u; values are already set */
25+
/* Parameters/constants don't depend on u; values are already set */
2526
(void) node;
2627
(void) u;
2728
}
2829

2930
static void jacobian_init(expr *node)
3031
{
31-
/* Constant jacobian is all zeros: size x n_vars with 0 nonzeros.
32-
* new_csr_matrix uses calloc for row pointers, so they're already 0. */
32+
/* Zero jacobian: size x n_vars with 0 nonzeros. */
3333
node->jacobian = new_csr_matrix(node->size, node->n_vars, 0);
3434
}
3535

3636
static void eval_jacobian(expr *node)
3737
{
38-
/* Constant jacobian never changes - nothing to evaluate */
3938
(void) node;
4039
}
4140

4241
static void wsum_hess_init(expr *node)
4342
{
44-
/* Constant Hessian is all zeros: n_vars x n_vars with 0 nonzeros. */
43+
/* Zero Hessian: n_vars x n_vars with 0 nonzeros. */
4544
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
4645
}
4746

4847
static void eval_wsum_hess(expr *node, const double *w)
4948
{
50-
/* Constant Hessian is always zero - nothing to compute */
5149
(void) node;
5250
(void) w;
5351
}
@@ -58,12 +56,20 @@ static bool is_affine(const expr *node)
5856
return true;
5957
}
6058

61-
expr *new_constant(int d1, int d2, int n_vars, const double *values)
59+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values)
6260
{
63-
expr *node = (expr *) calloc(1, sizeof(expr));
61+
parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr));
62+
expr *node = &pnode->base;
6463
init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine,
6564
wsum_hess_init, eval_wsum_hess, NULL);
66-
memcpy(node->value, values, node->size * sizeof(double));
65+
66+
pnode->param_id = param_id;
67+
pnode->has_been_refreshed = true;
68+
69+
if (values != NULL)
70+
{
71+
memcpy(node->value, values, node->size * sizeof(double));
72+
}
6773

6874
return node;
6975
}

src/bivariate/left_matmul.c

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,34 @@
4848

4949
#include "utils/utils.h"
5050

51+
static void refresh_param_values(left_matmul_expr *lnode)
52+
{
53+
if (lnode->param_source == NULL)
54+
{
55+
return;
56+
}
57+
parameter_expr *param = (parameter_expr *) lnode->param_source;
58+
if (param->has_been_refreshed)
59+
{
60+
return;
61+
}
62+
param->has_been_refreshed = true;
63+
lnode->refresh_param_values(lnode);
64+
}
65+
5166
static void forward(expr *node, const double *u)
5267
{
68+
left_matmul_expr *lnode = (left_matmul_expr *) node;
69+
refresh_param_values(lnode);
70+
5371
expr *x = node->left;
5472

5573
/* child's forward pass */
5674
node->left->forward(node->left, u);
5775

5876
/* y = A_kron @ vec(f(x)) */
59-
Matrix *A = ((left_matmul_expr *) node)->A;
60-
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
77+
Matrix *A = lnode->A;
78+
int n_blocks = lnode->n_blocks;
6179
A->block_left_mult_vec(A, x->value, node->value, n_blocks);
6280
}
6381

@@ -74,11 +92,16 @@ static void free_type_data(expr *node)
7492
free_csc_matrix(lnode->Jchild_CSC);
7593
free_csc_matrix(lnode->J_CSC);
7694
free(lnode->csc_to_csr_work);
95+
if (lnode->param_source != NULL)
96+
{
97+
free_expr(lnode->param_source);
98+
}
7799
lnode->A = NULL;
78100
lnode->AT = NULL;
79101
lnode->Jchild_CSC = NULL;
80102
lnode->J_CSC = NULL;
81103
lnode->csc_to_csr_work = NULL;
104+
lnode->param_source = NULL;
82105
}
83106

84107
static void jacobian_init(expr *node)
@@ -98,8 +121,9 @@ static void jacobian_init(expr *node)
98121

99122
static void eval_jacobian(expr *node)
100123
{
101-
expr *x = node->left;
102124
left_matmul_expr *lnode = (left_matmul_expr *) node;
125+
refresh_param_values(lnode);
126+
expr *x = node->left;
103127

104128
CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC;
105129
CSC_Matrix *J_CSC = lnode->J_CSC;
@@ -132,17 +156,46 @@ static void wsum_hess_init(expr *node)
132156

133157
static void eval_wsum_hess(expr *node, const double *w)
134158
{
159+
left_matmul_expr *lnode = (left_matmul_expr *) node;
160+
refresh_param_values(lnode);
161+
135162
/* compute A^T w*/
136-
Matrix *AT = ((left_matmul_expr *) node)->AT;
137-
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
163+
Matrix *AT = lnode->AT;
164+
int n_blocks = lnode->n_blocks;
138165
AT->block_left_mult_vec(AT, w, node->dwork, n_blocks);
139166

140167
node->left->eval_wsum_hess(node->left, node->dwork);
141168
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,
142169
node->wsum_hess->nnz * sizeof(double));
143170
}
144171

145-
expr *new_left_matmul(expr *u, const CSR_Matrix *A)
172+
static void refresh_sparse_left(left_matmul_expr *lnode)
173+
{
174+
Sparse_Matrix *sm_A = (Sparse_Matrix *) lnode->A;
175+
Sparse_Matrix *sm_AT = (Sparse_Matrix *) lnode->AT;
176+
lnode->A->update_values(lnode->A, lnode->param_source->value);
177+
/* Recompute AT values from A */
178+
AT_fill_values(sm_A->csr, sm_AT->csr, lnode->base.iwork);
179+
}
180+
181+
static void refresh_dense_left(left_matmul_expr *lnode)
182+
{
183+
Dense_Matrix *dm_A = (Dense_Matrix *) lnode->A;
184+
int m = dm_A->base.m;
185+
int n = dm_A->base.n;
186+
lnode->A->update_values(lnode->A, lnode->param_source->value);
187+
/* Recompute AT data (transpose of row-major A) */
188+
Dense_Matrix *dm_AT = (Dense_Matrix *) lnode->AT;
189+
for (int i = 0; i < m; i++)
190+
{
191+
for (int j = 0; j < n; j++)
192+
{
193+
dm_AT->x[j * m + i] = dm_A->x[i * n + j];
194+
}
195+
}
196+
}
197+
198+
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
146199
{
147200
/* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
148201
to do A @ u where u is (n, ) which in C is actually (1, n). In that case
@@ -188,10 +241,19 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
188241
lnode->A = new_sparse_matrix(A);
189242
lnode->AT = sparse_matrix_trans((const Sparse_Matrix *) lnode->A, node->iwork);
190243

244+
/* parameter support */
245+
lnode->param_source = param_node;
246+
if (param_node != NULL)
247+
{
248+
expr_retain(param_node);
249+
lnode->refresh_param_values = refresh_sparse_left;
250+
}
251+
191252
return node;
192253
}
193254

194-
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
255+
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
256+
const double *data)
195257
{
196258
int d1, d2, n_blocks;
197259
if (u->d1 == n)
@@ -227,5 +289,13 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
227289
lnode->A = new_dense_matrix(m, n, data);
228290
lnode->AT = dense_matrix_trans((const Dense_Matrix *) lnode->A);
229291

292+
/* parameter support */
293+
lnode->param_source = param_node;
294+
if (param_node != NULL)
295+
{
296+
expr_retain(param_node);
297+
lnode->refresh_param_values = refresh_dense_left;
298+
}
299+
230300
return node;
231301
}

0 commit comments

Comments
 (0)