4848
4949#include "utils/utils.h"
5050
51+ static void refresh_param_values (left_matmul_expr * lnode )
52+ {
53+ if (lnode -> param_source == NULL )
54+ {
55+ return ;
56+ }
57+ parameter_expr * param = (parameter_expr * ) lnode -> param_source ;
58+ if (param -> has_been_refreshed )
59+ {
60+ return ;
61+ }
62+ param -> has_been_refreshed = true;
63+ lnode -> refresh_param_values (lnode );
64+ }
65+
5166static void forward (expr * node , const double * u )
5267{
68+ left_matmul_expr * lnode = (left_matmul_expr * ) node ;
69+ refresh_param_values (lnode );
70+
5371 expr * x = node -> left ;
5472
5573 /* child's forward pass */
5674 node -> left -> forward (node -> left , u );
5775
5876 /* y = A_kron @ vec(f(x)) */
59- Matrix * A = (( left_matmul_expr * ) node ) -> A ;
60- int n_blocks = (( left_matmul_expr * ) node ) -> n_blocks ;
77+ Matrix * A = lnode -> A ;
78+ int n_blocks = lnode -> n_blocks ;
6179 A -> block_left_mult_vec (A , x -> value , node -> value , n_blocks );
6280}
6381
@@ -74,11 +92,16 @@ static void free_type_data(expr *node)
7492 free_csc_matrix (lnode -> Jchild_CSC );
7593 free_csc_matrix (lnode -> J_CSC );
7694 free (lnode -> csc_to_csr_work );
95+ if (lnode -> param_source != NULL )
96+ {
97+ free_expr (lnode -> param_source );
98+ }
7799 lnode -> A = NULL ;
78100 lnode -> AT = NULL ;
79101 lnode -> Jchild_CSC = NULL ;
80102 lnode -> J_CSC = NULL ;
81103 lnode -> csc_to_csr_work = NULL ;
104+ lnode -> param_source = NULL ;
82105}
83106
84107static void jacobian_init (expr * node )
@@ -98,8 +121,9 @@ static void jacobian_init(expr *node)
98121
99122static void eval_jacobian (expr * node )
100123{
101- expr * x = node -> left ;
102124 left_matmul_expr * lnode = (left_matmul_expr * ) node ;
125+ refresh_param_values (lnode );
126+ expr * x = node -> left ;
103127
104128 CSC_Matrix * Jchild_CSC = lnode -> Jchild_CSC ;
105129 CSC_Matrix * J_CSC = lnode -> J_CSC ;
@@ -132,17 +156,46 @@ static void wsum_hess_init(expr *node)
132156
133157static void eval_wsum_hess (expr * node , const double * w )
134158{
159+ left_matmul_expr * lnode = (left_matmul_expr * ) node ;
160+ refresh_param_values (lnode );
161+
135162 /* compute A^T w*/
136- Matrix * AT = (( left_matmul_expr * ) node ) -> AT ;
137- int n_blocks = (( left_matmul_expr * ) node ) -> n_blocks ;
163+ Matrix * AT = lnode -> AT ;
164+ int n_blocks = lnode -> n_blocks ;
138165 AT -> block_left_mult_vec (AT , w , node -> dwork , n_blocks );
139166
140167 node -> left -> eval_wsum_hess (node -> left , node -> dwork );
141168 memcpy (node -> wsum_hess -> x , node -> left -> wsum_hess -> x ,
142169 node -> wsum_hess -> nnz * sizeof (double ));
143170}
144171
145- expr * new_left_matmul (expr * u , const CSR_Matrix * A )
172+ static void refresh_sparse_left (left_matmul_expr * lnode )
173+ {
174+ Sparse_Matrix * sm_A = (Sparse_Matrix * ) lnode -> A ;
175+ Sparse_Matrix * sm_AT = (Sparse_Matrix * ) lnode -> AT ;
176+ lnode -> A -> update_values (lnode -> A , lnode -> param_source -> value );
177+ /* Recompute AT values from A */
178+ AT_fill_values (sm_A -> csr , sm_AT -> csr , lnode -> base .iwork );
179+ }
180+
181+ static void refresh_dense_left (left_matmul_expr * lnode )
182+ {
183+ Dense_Matrix * dm_A = (Dense_Matrix * ) lnode -> A ;
184+ int m = dm_A -> base .m ;
185+ int n = dm_A -> base .n ;
186+ lnode -> A -> update_values (lnode -> A , lnode -> param_source -> value );
187+ /* Recompute AT data (transpose of row-major A) */
188+ Dense_Matrix * dm_AT = (Dense_Matrix * ) lnode -> AT ;
189+ for (int i = 0 ; i < m ; i ++ )
190+ {
191+ for (int j = 0 ; j < n ; j ++ )
192+ {
193+ dm_AT -> x [j * m + i ] = dm_A -> x [i * n + j ];
194+ }
195+ }
196+ }
197+
198+ expr * new_left_matmul (expr * param_node , expr * u , const CSR_Matrix * A )
146199{
147200 /* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
148201 to do A @ u where u is (n, ) which in C is actually (1, n). In that case
@@ -188,10 +241,19 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
188241 lnode -> A = new_sparse_matrix (A );
189242 lnode -> AT = sparse_matrix_trans ((const Sparse_Matrix * ) lnode -> A , node -> iwork );
190243
244+ /* parameter support */
245+ lnode -> param_source = param_node ;
246+ if (param_node != NULL )
247+ {
248+ expr_retain (param_node );
249+ lnode -> refresh_param_values = refresh_sparse_left ;
250+ }
251+
191252 return node ;
192253}
193254
194- expr * new_left_matmul_dense (expr * u , int m , int n , const double * data )
255+ expr * new_left_matmul_dense (expr * param_node , expr * u , int m , int n ,
256+ const double * data )
195257{
196258 int d1 , d2 , n_blocks ;
197259 if (u -> d1 == n )
@@ -227,5 +289,13 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
227289 lnode -> A = new_dense_matrix (m , n , data );
228290 lnode -> AT = dense_matrix_trans ((const Dense_Matrix * ) lnode -> A );
229291
292+ /* parameter support */
293+ lnode -> param_source = param_node ;
294+ if (param_node != NULL )
295+ {
296+ expr_retain (param_node );
297+ lnode -> refresh_param_values = refresh_dense_left ;
298+ }
299+
230300 return node ;
231301}
0 commit comments