5555 No-op when param_source is NULL (fixed constant — values already in A). */
5656static void refresh_param_values (left_matmul_expr * lin_node )
5757{
58- if (! lin_node -> param_source ) return ;
58+ parameter_expr * param = ( parameter_expr * ) lin_node -> param_source ;
5959
60+ if (!param || param -> has_been_refreshed ) return ;
61+ param -> has_been_refreshed = true;
62+
63+ /* update values of A */
6064 memcpy (lin_node -> A -> x , lin_node -> param_source -> value ,
6165 lin_node -> A -> nnz * sizeof (double ));
6266
63- /* Recompute AT values from updated A */
64- AT_fill_values (lin_node -> A , lin_node -> AT , lin_node -> base . iwork );
67+ /* update values of AT */
68+ AT_fill_values (lin_node -> A , lin_node -> AT , lin_node -> AT_iwork );
6569}
6670
6771static void forward (expr * node , const double * u )
6872{
6973 expr * x = node -> left ;
7074 left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
7175
72- /* refresh A/AT from parameter source */
76+ /* possibly refresh A and AT */
7377 refresh_param_values (lin_node );
7478
7579 /* child's forward pass */
@@ -92,6 +96,7 @@ static void free_type_data(expr *node)
9296 free_csc_matrix (lin_node -> Jchild_CSC );
9397 free_csc_matrix (lin_node -> J_CSC );
9498 free (lin_node -> csc_to_csr_workspace );
99+ free (lin_node -> AT_iwork );
95100 free_expr (lin_node -> param_source );
96101}
97102
@@ -119,9 +124,6 @@ static void eval_jacobian(expr *node)
119124 CSC_Matrix * Jchild_CSC = lnode -> Jchild_CSC ;
120125 CSC_Matrix * J_CSC = lnode -> J_CSC ;
121126
122- /* refresh A from parameter source */
123- refresh_param_values (lnode );
124-
125127 /* evaluate child's jacobian and convert to CSC */
126128 x -> eval_jacobian (x );
127129 csr_to_csc_fill_values (x -> jacobian , Jchild_CSC , node -> iwork );
@@ -167,7 +169,6 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
167169 to do A @ u where u is (n, ) which in C is actually (1, n). In that case
168170 the result of A @ u is (m, ), which is (1, m) according to broadcasting
169171 rules. We therefore check if this is the case. */
170-
171172 int d1 , d2 , n_blocks ;
172173 if (child -> d1 == A -> n )
173174 {
@@ -197,13 +198,14 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
197198 expr_retain (child );
198199
199200 /* Store small A (NOT block-diagonal) — block functions handle the rest */
200- node -> iwork = (int * ) malloc (MAX (A -> n , node -> n_vars ) * sizeof (int ));
201+ node -> iwork = (int * ) malloc (node -> n_vars * sizeof (int ));
202+ lin_node -> AT_iwork = (int * ) malloc (A -> n * sizeof (int ));
201203 lin_node -> csc_to_csr_workspace = (int * ) malloc (node -> size * sizeof (int ));
202204 lin_node -> n_blocks = n_blocks ;
203205 lin_node -> A = new_csr (A );
204- lin_node -> AT = transpose (lin_node -> A , node -> iwork );
205-
206+ lin_node -> AT = transpose (lin_node -> A , lin_node -> AT_iwork );
206207 lin_node -> param_source = param_node ;
208+
207209 if (param_node ) expr_retain (param_node );
208210
209211 return node ;
0 commit comments