@@ -191,12 +191,10 @@ static void wsum_hess_init(expr *node)
191191 x -> wsum_hess_init (x );
192192
193193 /* Same sparsity as child - weights get summed */
194- node -> wsum_hess = new_csr_matrix (node -> n_vars , node -> n_vars , x -> wsum_hess -> nnz );
195- memcpy (node -> wsum_hess -> p , x -> wsum_hess -> p , (x -> wsum_hess -> m + 1 ) * sizeof (int ));
196- memcpy (node -> wsum_hess -> i , x -> wsum_hess -> i , x -> wsum_hess -> nnz * sizeof (int ));
194+ node -> wsum_hess = new_csr_copy_sparsity (x -> wsum_hess );
197195
198196 /* allocate space for weight vector */
199- node -> dwork = malloc (node -> size * sizeof (double ));
197+ node -> work -> dwork = malloc (node -> size * sizeof (double ));
200198}
201199
202200static void eval_wsum_hess (expr * node , const double * w )
@@ -205,7 +203,7 @@ static void eval_wsum_hess(expr *node, const double *w)
205203 expr * x = node -> left ;
206204
207205 /* Zero out the work array first */
208- memset (node -> dwork , 0 , x -> size * sizeof (double ));
206+ memset (node -> work -> dwork , 0 , x -> size * sizeof (double ));
209207
210208 if (bcast -> type == BROADCAST_ROW )
211209 {
@@ -214,7 +212,7 @@ static void eval_wsum_hess(expr *node, const double *w)
214212 {
215213 for (int i = 0 ; i < node -> d1 ; i ++ )
216214 {
217- node -> dwork [j ] += w [i + j * node -> d1 ];
215+ node -> work -> dwork [j ] += w [i + j * node -> d1 ];
218216 }
219217 }
220218 }
@@ -225,21 +223,21 @@ static void eval_wsum_hess(expr *node, const double *w)
225223 {
226224 for (int i = 0 ; i < node -> d1 ; i ++ )
227225 {
228- node -> dwork [i ] += w [i + j * node -> d1 ];
226+ node -> work -> dwork [i ] += w [i + j * node -> d1 ];
229227 }
230228 }
231229 }
232230 else
233231 {
234232 /* (1, 1) -> (m, n): scalar has m*n weights to sum */
235- node -> dwork [0 ] = 0.0 ;
233+ node -> work -> dwork [0 ] = 0.0 ;
236234 for (int k = 0 ; k < node -> size ; k ++ )
237235 {
238- node -> dwork [0 ] += w [k ];
236+ node -> work -> dwork [0 ] += w [k ];
239237 }
240238 }
241239
242- x -> eval_wsum_hess (x , node -> dwork );
240+ x -> eval_wsum_hess (x , node -> work -> dwork );
243241 memcpy (node -> wsum_hess -> x , x -> wsum_hess -> x , x -> wsum_hess -> nnz * sizeof (double ));
244242}
245243
0 commit comments