Skip to content

Commit b985028

Browse files
committed
add parameter support for right matmul
1 parent 0edf928 commit b985028

10 files changed

Lines changed: 151 additions & 25 deletions

File tree

include/bivariate.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ expr *new_matmul(expr *x, expr *y);
3434
Only the forward pass possibly updates the parameter. */
3535
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A);
3636

37-
/* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */
38-
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
37+
/* Right matrix multiplication: f(x) @ A where A comes from a parameter node. */
38+
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
3939

4040
/* Scalar multiplication: a * f(x) where a comes from a parameter node */
4141
expr *new_scalar_mult(expr *param_node, expr *child);

include/subexpr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ typedef struct left_matmul_expr
129129
int *csc_to_csr_workspace;
130130
int *AT_iwork; /* work for computing AT values from A */
131131
expr *param_source; /* parameter node; A/AT values are refreshed from this */
132+
void (*refresh_param_values)(struct left_matmul_expr *lin_node);
132133
} left_matmul_expr;
133134

134135
/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.

src/bivariate/left_matmul.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
*/
4949

5050
#include "utils/utils.h"
51+
#include <assert.h>
5152
#include <string.h>
5253

5354
/* Refresh A and AT values from param_source.
@@ -60,6 +61,8 @@ static void refresh_param_values(left_matmul_expr *lin_node)
6061
if (!param || param->has_been_refreshed) return;
6162
param->has_been_refreshed = true;
6263

64+
assert(param->param_id != PARAM_FIXED);
65+
6366
/* update values of A */
6467
memcpy(lin_node->A->x, lin_node->param_source->value,
6568
lin_node->A->nnz * sizeof(double));
@@ -74,7 +77,7 @@ static void forward(expr *node, const double *u)
7477
left_matmul_expr *lin_node = (left_matmul_expr *) node;
7578

7679
/* possibly refresh A and AT */
77-
refresh_param_values(lin_node);
80+
if (lin_node->refresh_param_values) lin_node->refresh_param_values(lin_node);
7881

7982
/* child's forward pass */
8083
node->left->forward(node->left, u);
@@ -205,6 +208,7 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
205208
lin_node->A = new_csr(A);
206209
lin_node->AT = transpose(lin_node->A, lin_node->AT_iwork);
207210
lin_node->param_source = param_node;
211+
lin_node->refresh_param_values = refresh_param_values;
208212

209213
if (param_node) expr_retain(param_node);
210214

src/bivariate/right_matmul.c

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,49 @@
2020
#include "subexpr.h"
2121
#include "utils/CSR_Matrix.h"
2222
#include "utils/linalg_sparse_matmuls.h"
23+
#include <assert.h>
2324
#include <stdlib.h>
25+
#include <string.h>
26+
27+
/* Refresh AT and A values from param_source for right matmul.
28+
param_source stores values in CSR order for the original A. */
29+
static void refresh_param_values(left_matmul_expr *lin_node)
30+
{
31+
parameter_expr *param = (parameter_expr *) lin_node->param_source;
32+
33+
if (!param || param->has_been_refreshed) return;
34+
param->has_been_refreshed = true;
35+
36+
assert(param->param_id != PARAM_FIXED);
37+
38+
/* update values of original A (stored in lin_node->AT) */
39+
memcpy(lin_node->AT->x, lin_node->param_source->value,
40+
lin_node->AT->nnz * sizeof(double));
41+
42+
/* update values of A^T (stored in lin_node->A) */
43+
AT_fill_values(lin_node->AT, lin_node->A, lin_node->AT_iwork);
44+
}
2445

2546
/* This file implements the atom 'right_matmul' corresponding to the operation y =
26-
f(x) @ A, where A is a given matrix and f(x) is an arbitrary expression.
47+
f(x) @ A, where A is a given matrix and f(x) is an arbitrary expression.
2748
We implement this by expressing right matmul in terms of left matmul and
2849
transpose: f(x) @ A = (A^T @ f(x)^T)^T. */
29-
expr *new_right_matmul(expr *u, const CSR_Matrix *A)
50+
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A)
3051
{
3152
/* We can express right matmul using left matmul and transpose:
3253
u @ A = (A^T @ u^T)^T. */
3354
int *work_transpose = (int *) malloc(A->n * sizeof(int));
3455
CSR_Matrix *AT = transpose(A, work_transpose);
35-
36-
/* Parameter stores CSR data order (same as AT->x) */
3756
expr *u_transpose = new_transpose(u);
38-
expr *param_node = new_parameter(AT->nnz, 1, PARAM_FIXED, u->n_vars, AT->x);
3957
expr *left_matmul_node = new_left_matmul(param_node, u_transpose, AT);
4058
expr *node = new_transpose(left_matmul_node);
4159

60+
/* functionality for parameter */
61+
left_matmul_expr *left_matmul_data = (left_matmul_expr *) left_matmul_node;
62+
free(left_matmul_data->AT_iwork);
63+
left_matmul_data->AT_iwork = work_transpose;
64+
left_matmul_data->refresh_param_values = refresh_param_values;
65+
4266
free_csr_matrix(AT);
43-
free(work_transpose);
4467
return node;
4568
}

src/bivariate/scalar_mult.c

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ static void forward(expr *node, const double *u)
3232
child->forward(child, u);
3333

3434
/* local forward pass: multiply each element by scalar a */
35-
scalar_mult_expr *sn = (scalar_mult_expr *) node;
36-
double a = sn->param_source->value[0];
35+
double a = ((scalar_mult_expr *) node)->param_source->value[0];
3736
for (int i = 0; i < node->size; i++)
3837
{
3938
node->value[i] = a * child->value[i];
@@ -56,8 +55,7 @@ static void jacobian_init(expr *node)
5655
static void eval_jacobian(expr *node)
5756
{
5857
expr *child = node->left;
59-
scalar_mult_expr *sn = (scalar_mult_expr *) node;
60-
double a = sn->param_source->value[0];
58+
double a = ((scalar_mult_expr *) node)->param_source->value[0];
6159

6260
/* evaluate child */
6361
child->eval_jacobian(child);
@@ -87,8 +85,7 @@ static void eval_wsum_hess(expr *node, const double *w)
8785
expr *x = node->left;
8886
x->eval_wsum_hess(x, w);
8987

90-
scalar_mult_expr *sn = (scalar_mult_expr *) node;
91-
double a = sn->param_source->value[0];
88+
double a = ((scalar_mult_expr *) node)->param_source->value[0];
9289
for (int j = 0; j < x->wsum_hess->nnz; j++)
9390
{
9491
node->wsum_hess->x[j] = a * x->wsum_hess->x[j];

src/bivariate/vector_mult.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
static void forward(expr *node, const double *u)
2828
{
2929
expr *child = node->left;
30-
vector_mult_expr *vn = (vector_mult_expr *) node;
31-
const double *a = vn->param_source->value;
30+
// vector_mult_expr *vn = (vector_mult_expr *) node;
31+
// const double *a = vn->param_source->value;
32+
33+
const double *a = ((vector_mult_expr *) node)->param_source->value;
3234

3335
/* child's forward pass */
3436
child->forward(child, u);
@@ -56,8 +58,10 @@ static void jacobian_init(expr *node)
5658
static void eval_jacobian(expr *node)
5759
{
5860
expr *x = node->left;
59-
vector_mult_expr *vn = (vector_mult_expr *) node;
60-
const double *a = vn->param_source->value;
61+
// vector_mult_expr *vn = (vector_mult_expr *) node;
62+
// const double *a = vn->param_source->value;
63+
64+
const double *a = ((vector_mult_expr *) node)->param_source->value;
6165

6266
/* evaluate x */
6367
x->eval_jacobian(x);
@@ -90,8 +94,10 @@ static void wsum_hess_init(expr *node)
9094
static void eval_wsum_hess(expr *node, const double *w)
9195
{
9296
expr *x = node->left;
93-
vector_mult_expr *vn = (vector_mult_expr *) node;
94-
const double *a = vn->param_source->value;
97+
// vector_mult_expr *vn = (vector_mult_expr *) node;
98+
// const double *a = vn->param_source->value;
99+
100+
const double *a = ((vector_mult_expr *) node)->param_source->value;
95101

96102
/* scale weights w by a */
97103
for (int i = 0; i < node->size; i++)

tests/all_tests.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ int main(void)
284284
mu_run_test(test_param_scalar_mult_problem, tests_run);
285285
mu_run_test(test_param_vector_mult_problem, tests_run);
286286
mu_run_test(test_param_left_matmul_problem, tests_run);
287+
mu_run_test(test_param_right_matmul_problem, tests_run);
287288
#endif /* PROFILE_ONLY */
288289

289290
#ifdef PROFILE_ONLY

tests/jacobian_tests/test_right_matmul.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const char *test_jacobian_right_matmul_log()
2727
memcpy(A->x, A_x, 4 * sizeof(double));
2828

2929
expr *log_x = new_log(x);
30-
expr *log_x_A = new_right_matmul(log_x, A);
30+
expr *log_x_A = new_right_matmul(NULL, log_x, A);
3131

3232
log_x_A->forward(log_x_A, x_vals);
3333
log_x_A->jacobian_init(log_x_A);
@@ -76,7 +76,7 @@ const char *test_jacobian_right_matmul_log_vector()
7676
memcpy(A->x, A_x, 4 * sizeof(double));
7777

7878
expr *log_x = new_log(x);
79-
expr *log_x_A = new_right_matmul(log_x, A);
79+
expr *log_x_A = new_right_matmul(NULL, log_x, A);
8080

8181
log_x_A->forward(log_x_A, x_vals);
8282
log_x_A->jacobian_init(log_x_A);

tests/problem/test_param_prob.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,98 @@ const char *test_param_left_matmul_problem(void)
254254
return 0;
255255
}
256256

257+
/*
258+
* Test 4: right_param_matmul in constraint
259+
*
260+
* Problem: minimize sum(x), subject to x @ A, x size 1x2, A is 2x2
261+
* A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order)
262+
* A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4]
263+
*
264+
* At x=[1,2]:
265+
* constraint_values = [1*1+2*3, 1*2+2*4] = [7, 10]
266+
* jacobian = [[1,3],[2,4]] = A^T
267+
*
268+
* After update A = [[5,6],[7,8]] → theta = [5,6,7,8]:
269+
* constraint_values = [1*5+2*7, 1*6+2*8] = [19, 22]
270+
* jacobian = [[5,7],[6,8]] = A^T
271+
*/
272+
const char *test_param_right_matmul_problem(void)
273+
{
274+
int n_vars = 2;
275+
276+
/* Objective: sum(x) */
277+
expr *x_obj = new_variable(1, 2, 0, n_vars);
278+
expr *objective = new_sum(x_obj, -1);
279+
280+
/* Constraint: x @ A */
281+
expr *x_con = new_variable(1, 2, 0, n_vars);
282+
expr *A_param = new_parameter(2, 2, 0, n_vars, NULL);
283+
284+
/* Dense 2x2 CSR with placeholder zeros (values refreshed from A_param) */
285+
CSR_Matrix *A = new_csr_matrix(2, 2, 4);
286+
int Ap[3] = {0, 2, 4};
287+
int Ai[4] = {0, 1, 0, 1};
288+
double Ax[4] = {0.0, 0.0, 0.0, 0.0};
289+
memcpy(A->p, Ap, 3 * sizeof(int));
290+
memcpy(A->i, Ai, 4 * sizeof(int));
291+
memcpy(A->x, Ax, 4 * sizeof(double));
292+
293+
expr *constraint = new_right_matmul(A_param, x_con, A);
294+
free_csr_matrix(A);
295+
296+
expr *constraints[1] = {constraint};
297+
298+
/* Create problem */
299+
problem *prob = new_problem(objective, constraints, 1, true);
300+
301+
expr *param_nodes[1] = {A_param};
302+
problem_register_params(prob, param_nodes, 1);
303+
problem_init_derivatives(prob);
304+
305+
/* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */
306+
double theta[4] = {1.0, 2.0, 3.0, 4.0};
307+
problem_update_params(prob, theta);
308+
309+
double u[2] = {1.0, 2.0};
310+
problem_constraint_forward(prob, u);
311+
problem_jacobian(prob);
312+
313+
double expected_cv[2] = {7.0, 10.0};
314+
mu_assert("constraint values wrong (A1)",
315+
cmp_double_array(prob->constraint_values, expected_cv, 2));
316+
317+
CSR_Matrix *jac = prob->jacobian;
318+
mu_assert("jac rows wrong", jac->m == 2);
319+
mu_assert("jac cols wrong", jac->n == 2);
320+
321+
/* Dense jacobian = [[1,3],[2,4]] = A^T, CSR: row 0 → cols 0,1 vals 1,3;
322+
* row 1 → cols 0,1 vals 2,4 */
323+
int expected_p[3] = {0, 2, 4};
324+
mu_assert("jac->p wrong (A1)", cmp_int_array(jac->p, expected_p, 3));
325+
326+
int expected_i[4] = {0, 1, 0, 1};
327+
mu_assert("jac->i wrong (A1)", cmp_int_array(jac->i, expected_i, 4));
328+
329+
double expected_x[4] = {1.0, 3.0, 2.0, 4.0};
330+
mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4));
331+
332+
/* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */
333+
double theta2[4] = {5.0, 6.0, 7.0, 8.0};
334+
problem_update_params(prob, theta2);
335+
336+
problem_constraint_forward(prob, u);
337+
problem_jacobian(prob);
338+
339+
double expected_cv2[2] = {19.0, 22.0};
340+
mu_assert("constraint values wrong (A2)",
341+
cmp_double_array(prob->constraint_values, expected_cv2, 2));
342+
343+
double expected_x2[4] = {5.0, 7.0, 6.0, 8.0};
344+
mu_assert("jac->x wrong (A2)", cmp_double_array(jac->x, expected_x2, 4));
345+
346+
free_problem(prob);
347+
348+
return 0;
349+
}
350+
257351
#endif /* TEST_PARAM_PROB_H */

tests/wsum_hess/test_right_matmul.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const char *test_wsum_hess_right_matmul()
3333
memcpy(A->x, A_x, 4 * sizeof(double));
3434

3535
expr *log_x = new_log(x);
36-
expr *log_x_A = new_right_matmul(log_x, A);
36+
expr *log_x_A = new_right_matmul(NULL, log_x, A);
3737

3838
log_x_A->forward(log_x_A, x_vals);
3939
log_x_A->jacobian_init(log_x_A);
@@ -83,7 +83,7 @@ const char *test_wsum_hess_right_matmul_vector()
8383
memcpy(A->x, A_x, 4 * sizeof(double));
8484

8585
expr *log_x = new_log(x);
86-
expr *log_x_A = new_right_matmul(log_x, A);
86+
expr *log_x_A = new_right_matmul(NULL, log_x, A);
8787

8888
log_x_A->forward(log_x_A, x_vals);
8989
log_x_A->jacobian_init(log_x_A);

0 commit comments

Comments
 (0)