Skip to content

Commit a83b584

Browse files
committed
new parameter abstraction that fixess the fialing test
1 parent f9f4516 commit a83b584

6 files changed

Lines changed: 26 additions & 8 deletions

File tree

include/expr.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ typedef struct expr
8686
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
8787
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
8888
Expr_Work *work; /* derivative workspace */
89+
/* Set to true on all nodes by problem_update_params() via
90+
expr_set_needs_refresh(). Atoms that cache parameter data
91+
(e.g. left_matmul_dense) check this flag before their forward
92+
pass: if true, they refresh their cached matrices from
93+
param_source->value and clear the flag to false. */
94+
bool needs_parameter_refresh;
8995

9096
// name of node just for debugging - should be removed later
9197
char name[32];
@@ -108,6 +114,9 @@ void wsum_hess_init(expr *node);
108114
* Must be called after jacobian_init. */
109115
void jacobian_csc_init(expr *node);
110116

117+
/* Recursively set needs_parameter_refresh on node and all children */
118+
void expr_set_needs_refresh(expr *node);
119+
111120
/* Reference counting helpers */
112121
void expr_retain(expr *node);
113122

include/subexpr.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@ typedef struct parameter_expr
3737
{
3838
expr base;
3939
int param_id;
40-
/* Set to true by problem_update_params(), cleared by
41-
refresh_param_values() after propagating new values. */
42-
bool needs_refresh;
4340
} parameter_expr;
4441

4542
/* Linear operator: y = A * x + b

src/atoms/affine/left_matmul.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ static void refresh_param_values(left_matmul_expr *lnode)
5555
{
5656
return;
5757
}
58-
parameter_expr *param = (parameter_expr *) lnode->param_source;
59-
if (!param->needs_refresh)
58+
if (!lnode->base.needs_parameter_refresh)
6059
{
6160
return;
6261
}
63-
param->needs_refresh = false;
62+
lnode->base.needs_parameter_refresh = false;
6463
lnode->refresh_param_values(lnode);
6564
}
6665

src/atoms/affine/parameter.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *valu
6565
is_affine, wsum_hess_init_impl, eval_wsum_hess, NULL);
6666

6767
pnode->param_id = param_id;
68-
pnode->needs_refresh = false;
6968

7069
if (values != NULL)
7170
{

src/expr.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ void wsum_hess_init(expr *node)
108108
node->wsum_hess_init_impl(node);
109109
}
110110

111+
void expr_set_needs_refresh(expr *node)
112+
{
113+
if (node == NULL) return;
114+
node->needs_parameter_refresh = true;
115+
expr_set_needs_refresh(node->left);
116+
expr_set_needs_refresh(node->right);
117+
}
118+
111119
void expr_retain(expr *node)
112120
{
113121
if (node == NULL) return;

src/problem.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,13 @@ void problem_update_params(problem *prob, const double *theta)
382382
if (param->param_id == PARAM_FIXED) continue;
383383
int offset = param->param_id;
384384
memcpy(pnode->value, theta + offset, pnode->size * sizeof(double));
385-
param->needs_refresh = true;
385+
}
386+
387+
/* Propagate needs_parameter_refresh to all expressions */
388+
expr_set_needs_refresh(prob->objective);
389+
for (int i = 0; i < prob->n_constraints; i++)
390+
{
391+
expr_set_needs_refresh(prob->constraints[i]);
386392
}
387393

388394
/* Force re-evaluation of affine Jacobians on next call */

0 commit comments

Comments
 (0)