Skip to content

Commit 528e1c1

Browse files
Transurgeonclaude
andcommitted
Unify Constant and Parameter into single parameter type
Merge the separate Constant and Parameter leaf nodes into a unified parameter_expr with PARAM_FIXED sentinel (-1) for constants. This eliminates duplicate code paths and consolidates 7 bivariate constructors into 3 unified ones: - new_const_scalar_mult / new_param_scalar_mult -> new_scalar_mult - new_const_vector_mult / new_param_vector_mult -> new_vector_mult - new_left_matmul (CSR) / new_left_param_matmul -> new_left_matmul (param node) Key changes: - Add PARAM_FIXED define and extend new_parameter() to accept initial values - Delete constant.c (absorbed by parameter.c) - Remove direct value storage (double a, double *a) from scalar/vector mult structs; always read from param_source - left_matmul builds sparse CSR for fixed params (preserving sparsity) and dense CSR for updatable params - right_matmul internally creates a fixed parameter node from transposed A - problem_update_params skips PARAM_FIXED nodes - Update all test callers to use new_parameter with PARAM_FIXED Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bf8a55c commit 528e1c1

26 files changed

+193
-357
lines changed

include/affine.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars);
3232
expr *new_promote(expr *child, int d1, int d2);
3333
expr *new_trace(expr *child);
3434

35-
expr *new_constant(int d1, int d2, int n_vars, const double *values);
3635
expr *new_variable(int d1, int d2, int var_id, int n_vars);
37-
expr *new_parameter(int d1, int d2, int param_id, int n_vars);
36+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values);
3837

3938
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
4039
expr *new_reshape(expr *child, int d1, int d2);

include/bivariate.h

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,16 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
3030
/* Matrix multiplication: Z = X @ Y */
3131
expr *new_matmul(expr *x, expr *y);
3232

33-
/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
34-
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
33+
/* Left matrix multiplication: A @ f(x) where A comes from a parameter node */
34+
expr *new_left_matmul(expr *param_node, expr *child);
3535

3636
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
3737
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
3838

39-
/* Constant scalar multiplication: a * f(x) where a is a constant double */
40-
expr *new_const_scalar_mult(double a, expr *child);
39+
/* Scalar multiplication: a * f(x) where a comes from a parameter node */
40+
expr *new_scalar_mult(expr *param_node, expr *child);
4141

42-
/* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */
43-
expr *new_const_vector_mult(const double *a, expr *child);
44-
45-
/* Left matrix multiplication: P @ f(x) where P is a parameter */
46-
expr *new_left_param_matmul(expr *param_node, expr *child);
47-
48-
/* Parameter scalar multiplication: p * f(x) where p is a parameter */
49-
expr *new_param_scalar_mult(expr *param_node, expr *child);
50-
51-
/* Parameter vector elementwise multiplication: p ∘ f(x) where p is a parameter */
52-
expr *new_param_vector_mult(expr *param_node, expr *child);
42+
/* Vector elementwise multiplication: a ∘ f(x) where a comes from a parameter node */
43+
expr *new_vector_mult(expr *param_node, expr *child);
5344

5445
#endif /* BIVARIATE_H */

include/subexpr.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,17 @@
2525
/* Forward declaration */
2626
struct int_double_pair;
2727

28-
/* Parameter node: like constant but with updatable values via problem_update_params
28+
/* param_id value for fixed (constant) parameters */
29+
#define PARAM_FIXED -1
30+
31+
/* Parameter node: unified leaf for constants and updatable parameters.
32+
* Constants use param_id == PARAM_FIXED and have values set at creation.
33+
* Updatable parameters have param_id >= 0 and are updated via problem_update_params.
2934
*/
3035
typedef struct parameter_expr
3136
{
3237
expr base;
33-
int param_id; /* offset into global theta vector */
38+
int param_id; /* offset into global theta vector, or PARAM_FIXED */
3439
} parameter_expr;
3540

3641
/* Type-specific expression structures that "inherit" from expr */
@@ -136,20 +141,18 @@ typedef struct right_matmul_expr
136141
CSC_Matrix *CSC_work;
137142
} right_matmul_expr;
138143

139-
/* Constant scalar multiplication: y = a * child where a is a constant double */
144+
/* Scalar multiplication: y = a * child where a comes from a parameter node */
140145
typedef struct const_scalar_mult_expr
141146
{
142147
expr base;
143-
double a;
144-
expr *param_source; /* if non-NULL, read a from param_source->value[0] */
148+
expr *param_source; /* always set; read a from param_source->value[0] */
145149
} const_scalar_mult_expr;
146150

147-
/* Constant vector elementwise multiplication: y = a \circ child for constant a */
151+
/* Vector elementwise multiplication: y = a \circ child where a comes from a parameter node */
148152
typedef struct const_vector_mult_expr
149153
{
150154
expr base;
151-
double *a; /* length equals node->size */
152-
expr *param_source; /* if non-NULL, use param_source->value instead of a */
155+
expr *param_source; /* always set; read a from param_source->value */
153156
} const_vector_mult_expr;
154157

155158
/* Index/slicing: y = child[indices] where indices is a list of flat positions */

src/affine/constant.c

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/affine/parameter.c

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@
1616
* limitations under the License.
1717
*/
1818

19-
/* Parameter leaf node: behaviorally identical to constant (zero derivatives
20-
w.r.t. variables), but its values are updatable via problem_update_params.
21-
This allows re-solving with different parameter values without rebuilding
22-
the expression tree. */
19+
/* Unified parameter/constant leaf node.
20+
*
21+
* When param_id == PARAM_FIXED, this is a constant whose values are set at
22+
* creation and never change. When param_id >= 0, values are updated via
23+
* problem_update_params.
24+
*
25+
* In both cases the derivative behavior is identical: zero Jacobian and
26+
* Hessian with respect to variables (always affine). */
2327

2428
#include "affine.h"
2529
#include "subexpr.h"
@@ -28,26 +32,26 @@
2832

2933
static void forward(expr *node, const double *u)
3034
{
31-
/* Values are set by problem_update_params, not by forward pass */
35+
/* Values are set at creation (constants) or by problem_update_params */
3236
(void) node;
3337
(void) u;
3438
}
3539

3640
static void jacobian_init(expr *node)
3741
{
38-
/* Parameter jacobian is all zeros: size x n_vars with 0 nonzeros */
42+
/* Parameter/constant jacobian is all zeros: size x n_vars with 0 nonzeros */
3943
node->jacobian = new_csr_matrix(node->size, node->n_vars, 0);
4044
}
4145

4246
static void eval_jacobian(expr *node)
4347
{
44-
/* Parameter jacobian never changes */
48+
/* Jacobian never changes */
4549
(void) node;
4650
}
4751

4852
static void wsum_hess_init(expr *node)
4953
{
50-
/* Parameter Hessian is all zeros */
54+
/* Hessian is all zeros */
5155
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 0);
5256
}
5357

@@ -63,12 +67,19 @@ static bool is_affine(const expr *node)
6367
return true;
6468
}
6569

66-
expr *new_parameter(int d1, int d2, int param_id, int n_vars)
70+
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values)
6771
{
6872
parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr));
6973
init_expr(&pnode->base, d1, d2, n_vars, forward, jacobian_init, eval_jacobian,
7074
is_affine, wsum_hess_init, eval_wsum_hess, NULL);
7175
pnode->param_id = param_id;
72-
/* values will be populated by problem_update_params */
76+
77+
/* If values provided (fixed constant), copy them now */
78+
if (values != NULL)
79+
{
80+
memcpy(pnode->base.value, values, pnode->base.size * sizeof(double));
81+
}
82+
/* Otherwise values will be populated by problem_update_params */
83+
7384
return &pnode->base;
7485
}

src/bivariate/const_scalar_mult.c

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <stdlib.h>
2323
#include <string.h>
2424

25-
/* Constant scalar multiplication: y = a * child where a is a constant double */
25+
/* Scalar multiplication: y = a * child where a comes from a parameter node */
2626

2727
static void forward(expr *node, const double *u)
2828
{
@@ -33,7 +33,7 @@ static void forward(expr *node, const double *u)
3333

3434
/* local forward pass: multiply each element by scalar a */
3535
const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node;
36-
double a = sn->param_source ? sn->param_source->value[0] : sn->a;
36+
double a = sn->param_source->value[0];
3737
for (int i = 0; i < node->size; i++)
3838
{
3939
node->value[i] = a * child->value[i];
@@ -57,7 +57,7 @@ static void eval_jacobian(expr *node)
5757
{
5858
expr *child = node->left;
5959
const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node;
60-
double a = sn->param_source ? sn->param_source->value[0] : sn->a;
60+
double a = sn->param_source->value[0];
6161

6262
/* evaluate child */
6363
child->eval_jacobian(child);
@@ -88,7 +88,7 @@ static void eval_wsum_hess(expr *node, const double *w)
8888
x->eval_wsum_hess(x, w);
8989

9090
const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node;
91-
double a = sn->param_source ? sn->param_source->value[0] : sn->a;
91+
double a = sn->param_source->value[0];
9292
for (int j = 0; j < x->wsum_hess->nnz; j++)
9393
{
9494
node->wsum_hess->x[j] = a * x->wsum_hess->x[j];
@@ -101,23 +101,7 @@ static bool is_affine(const expr *node)
101101
return node->left->is_affine(node->left);
102102
}
103103

104-
expr *new_const_scalar_mult(double a, expr *child)
105-
{
106-
const_scalar_mult_expr *mult_node =
107-
(const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr));
108-
expr *node = &mult_node->base;
109-
110-
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init,
111-
eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL);
112-
node->left = child;
113-
mult_node->a = a;
114-
mult_node->param_source = NULL;
115-
expr_retain(child);
116-
117-
return node;
118-
}
119-
120-
static void free_param_type_data(expr *node)
104+
static void free_type_data(expr *node)
121105
{
122106
const_scalar_mult_expr *sn = (const_scalar_mult_expr *) node;
123107
if (sn->param_source)
@@ -126,17 +110,16 @@ static void free_param_type_data(expr *node)
126110
}
127111
}
128112

129-
expr *new_param_scalar_mult(expr *param_node, expr *child)
113+
expr *new_scalar_mult(expr *param_node, expr *child)
130114
{
131115
const_scalar_mult_expr *mult_node =
132116
(const_scalar_mult_expr *) calloc(1, sizeof(const_scalar_mult_expr));
133117
expr *node = &mult_node->base;
134118

135119
init_expr(node, child->d1, child->d2, child->n_vars, forward, jacobian_init,
136120
eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess,
137-
free_param_type_data);
121+
free_type_data);
138122
node->left = child;
139-
mult_node->a = param_node->value[0]; /* initial value */
140123
mult_node->param_source = param_node;
141124
expr_retain(child);
142125
expr_retain(param_node);

0 commit comments

Comments
 (0)