Skip to content

Commit 0edf928

Browse files
committed
fix AT workspace in left_matmul and add has_been_refreshed
1 parent b3e2304 commit 0edf928

6 files changed

Lines changed: 25 additions & 15 deletions

File tree

include/bivariate.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
3030
/* Matrix multiplication: Z = X @ Y */
3131
expr *new_matmul(expr *x, expr *y);
3232

33-
/* Left matrix multiplication: A @ f(x) where A comes from a parameter node */
33+
/* Left matrix multiplication: A @ f(x) where A comes from a parameter node.
34+
Only the forward pass possibly updates the parameter. */
3435
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A);
3536

3637
/* Right matrix multiplication: f(x) @ A where A is a fixed parameter matrix */

include/problem.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ typedef struct problem
5959
* hessian are called */
6060
bool jacobian_called;
6161

62-
/* Parameter tracking for fast parameter updates */
62+
/* Parameter tracking for fast parameter updates. */
6363
expr **param_nodes; /* weak references to parameter nodes in tree */
6464
int n_param_nodes;
6565
int total_parameter_size;

include/subexpr.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ struct int_double_pair;
3535
typedef struct parameter_expr
3636
{
3737
expr base;
38-
int param_id; /* offset into global theta vector, or PARAM_FIXED */
38+
int param_id; /* offset into global theta vector, or PARAM_FIXED */
39+
bool has_been_refreshed; /* tracks whether parameter has been refreshed */
3940
} parameter_expr;
4041

4142
/* Type-specific expression structures that "inherit" from expr */
@@ -126,6 +127,7 @@ typedef struct left_matmul_expr
126127
CSC_Matrix *Jchild_CSC;
127128
CSC_Matrix *J_CSC;
128129
int *csc_to_csr_workspace;
130+
int *AT_iwork; /* work for computing AT values from A */
129131
expr *param_source; /* parameter node; A/AT values are refreshed from this */
130132
} left_matmul_expr;
131133

src/affine/parameter.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
7474
init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine,
7575
wsum_hess_init, eval_wsum_hess, NULL);
7676
pnode->param_id = param_id;
77+
pnode->has_been_refreshed = false;
7778

7879
/* If values provided (fixed constant), copy them now.
7980
Otherwise values will be populated by problem_update_params. */

src/bivariate/left_matmul.c

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,25 @@
5555
No-op when param_source is NULL (fixed constant — values already in A). */
5656
static void refresh_param_values(left_matmul_expr *lin_node)
5757
{
58-
if (!lin_node->param_source) return;
58+
parameter_expr *param = (parameter_expr *) lin_node->param_source;
5959

60+
if (!param || param->has_been_refreshed) return;
61+
param->has_been_refreshed = true;
62+
63+
/* update values of A */
6064
memcpy(lin_node->A->x, lin_node->param_source->value,
6165
lin_node->A->nnz * sizeof(double));
6266

63-
/* Recompute AT values from updated A */
64-
AT_fill_values(lin_node->A, lin_node->AT, lin_node->base.iwork);
67+
/* update values of AT */
68+
AT_fill_values(lin_node->A, lin_node->AT, lin_node->AT_iwork);
6569
}
6670

6771
static void forward(expr *node, const double *u)
6872
{
6973
expr *x = node->left;
7074
left_matmul_expr *lin_node = (left_matmul_expr *) node;
7175

72-
/* refresh A/AT from parameter source */
76+
/* possibly refresh A and AT */
7377
refresh_param_values(lin_node);
7478

7579
/* child's forward pass */
@@ -92,6 +96,7 @@ static void free_type_data(expr *node)
9296
free_csc_matrix(lin_node->Jchild_CSC);
9397
free_csc_matrix(lin_node->J_CSC);
9498
free(lin_node->csc_to_csr_workspace);
99+
free(lin_node->AT_iwork);
95100
free_expr(lin_node->param_source);
96101
}
97102

@@ -119,9 +124,6 @@ static void eval_jacobian(expr *node)
119124
CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC;
120125
CSC_Matrix *J_CSC = lnode->J_CSC;
121126

122-
/* refresh A from parameter source */
123-
refresh_param_values(lnode);
124-
125127
/* evaluate child's jacobian and convert to CSC */
126128
x->eval_jacobian(x);
127129
csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->iwork);
@@ -167,7 +169,6 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
167169
to do A @ u where u is (n, ) which in C is actually (1, n). In that case
168170
the result of A @ u is (m, ), which is (1, m) according to broadcasting
169171
rules. We therefore check if this is the case. */
170-
171172
int d1, d2, n_blocks;
172173
if (child->d1 == A->n)
173174
{
@@ -197,13 +198,14 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
197198
expr_retain(child);
198199

199200
/* Store small A (NOT block-diagonal) — block functions handle the rest */
200-
node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int));
201+
node->iwork = (int *) malloc(node->n_vars * sizeof(int));
202+
lin_node->AT_iwork = (int *) malloc(A->n * sizeof(int));
201203
lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int));
202204
lin_node->n_blocks = n_blocks;
203205
lin_node->A = new_csr(A);
204-
lin_node->AT = transpose(lin_node->A, node->iwork);
205-
206+
lin_node->AT = transpose(lin_node->A, lin_node->AT_iwork);
206207
lin_node->param_source = param_node;
208+
207209
if (param_node) expr_retain(param_node);
208210

209211
return node;

src/problem.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,9 @@ void problem_register_params(problem *prob, expr **param_nodes, int n_param_node
456456

457457
prob->total_parameter_size = 0;
458458
for (int i = 0; i < n_param_nodes; i++)
459+
{
459460
prob->total_parameter_size += param_nodes[i]->size;
461+
}
460462
}
461463

462464
void problem_update_params(problem *prob, const double *theta)
@@ -466,7 +468,9 @@ void problem_update_params(problem *prob, const double *theta)
466468
parameter_expr *p = (parameter_expr *) prob->param_nodes[i];
467469
if (p->param_id == PARAM_FIXED) continue;
468470
memcpy(p->base.value, theta + p->param_id, p->base.size * sizeof(double));
471+
p->has_been_refreshed = false;
469472
}
470-
/* Force re-evaluation of affine Jacobians on next call */
473+
474+
/* force re-evaluation of affine Jacobians on next call */
471475
prob->jacobian_called = false;
472476
}

0 commit comments

Comments
 (0)