@@ -39,7 +39,24 @@ typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double
3939typedef bool (* is_affine_fn )(const struct expr * node );
4040typedef void (* free_type_data_fn )(struct expr * node );
4141
42- /* Base expression node structure - contains only common fields */
42+ /* Workspace for derivative computation; an expression node reaches it through its `work` pointer. */
43+ typedef struct
44+ {
45+ double * dwork ; /* double-typed work array; NOTE(review): length and allocator are not visible here */
46+ int * iwork ; /* int-typed work array; NOTE(review): length and allocator are not visible here */
47+ CSC_Matrix * jacobian_csc ; /* CSC form of the Jacobian, set up by jacobian_csc_init() from the CSR Jacobian */
48+ int * csc_work ; /* for CSR-CSC conversion */
49+
50+ /* jacobian_csc_filled is currently consulted only for affine functions, to
51+ skip redundant conversions. It could become relevant for non-affine
52+ functions if common subexpressions are ever supported on the Python side. */
53+ bool jacobian_csc_filled ;
54+ double * local_jac_diag ; /* cached f'(g(x)) diagonal */
55+ CSR_Matrix * hess_term1 ; /* Jg^T D Jg workspace */
56+ CSR_Matrix * hess_term2 ; /* child wsum_hess workspace */
57+ } Expr_Work ;
58+
59+ /* Base expression node structure */
4360typedef struct expr
4461{
4562 // ------------------------------------------------------------------------
@@ -48,8 +65,6 @@ typedef struct expr
4865 int d1 , d2 , size , n_vars , refcount , var_id ;
4966 struct expr * left ;
5067 struct expr * right ;
51- double * dwork ;
52- int * iwork ;
5368
5469 // ------------------------------------------------------------------------
5570 // oracle related quantities
@@ -58,8 +73,8 @@ typedef struct expr
5873 CSR_Matrix * jacobian ;
5974 CSR_Matrix * wsum_hess ;
6075 forward_fn forward ;
61- jacobian_init_fn jacobian_init ;
62- wsum_hess_init_fn wsum_hess_init ;
76+ jacobian_init_fn jacobian_init_impl ;
77+ wsum_hess_init_fn wsum_hess_init_impl ;
6378 eval_jacobian_fn eval_jacobian ;
6479 wsum_hess_fn eval_wsum_hess ;
6580
@@ -70,6 +85,13 @@ typedef struct expr
7085 local_jacobian_fn local_jacobian ; /* used by elementwise univariate atoms*/
7186 local_wsum_hess_fn local_wsum_hess ; /* used by elementwise univariate atoms*/
7287 free_type_data_fn free_type_data ; /* Cleanup for type-specific fields */
88+ Expr_Work * work ; /* derivative workspace */
89+ /* Set to true on all nodes by problem_update_params() via
90+ expr_set_needs_refresh(). Atoms that cache parameter data
91+ (e.g. left_matmul_dense) check this flag before their forward
92+ pass: if it is set, they refresh their cached matrices from
93+ param_source->value and then clear it. */
94+ bool needs_parameter_refresh ;
7395
7496 // name of node just for debugging - should be removed later
7597 char name [32 ];
@@ -83,6 +105,18 @@ void init_expr(expr *node, int d1, int d2, int n_vars, forward_fn forward,
83105
84106void free_expr (expr * node );
85107
108+ /* Guarded init: skips if already initialized (safe for DAGs
109+ * where a node may be visited through multiple parents). */
110+ void jacobian_init (expr * node );
111+ void wsum_hess_init (expr * node );
112+
113+ /* Initialize CSC form of the Jacobian from the CSR Jacobian.
114+ * Must be called after jacobian_init. */
115+ void jacobian_csc_init (expr * node );
116+
117+ /* Recursively set needs_parameter_refresh on node and all children */
118+ void expr_set_needs_refresh (expr * node );
119+
86120/* Reference counting helpers */
87121void expr_retain (expr * node );
88122
0 commit comments