@@ -208,20 +208,17 @@ static void eval_wsum_hess(expr *node, const double *w)
208208 }
209209 }
210210
211- /* both are linear operators */
212211 CSC_Matrix * Jg1 = x -> work -> jacobian_csc ;
213212 CSC_Matrix * Jg2 = y -> work -> jacobian_csc ;
213+
214+ // -----------------------------------------------------------------------
215+ // compute C = Jg2^T diag(w) Jg1, CT = C^T, and sum to get hess = C + CT
216+ // -----------------------------------------------------------------------
214217 CSR_Matrix * C = ((elementwise_mult_expr * ) node )-> CSR_work1 ;
215218 CSR_Matrix * CT = ((elementwise_mult_expr * ) node )-> CSR_work2 ;
216-
217- /* Compute C = B^T diag(w) A */
218- BTDA_fill_values (Jg1 , Jg2 , w , C );
219-
220- /* Compute CT = C^T = A^T diag(w) B */
221- AT_fill_values (C , CT , node -> work -> iwork );
222-
223- /* Hessian = C + CT = B^T diag(w) A + A^T diag(w) B */
224- sum_csr_matrices_fill_values (C , CT , node -> wsum_hess );
219+ BTDA_fill_values (Jg1 , Jg2 , w , C ); /* compute C */
220+ AT_fill_values (C , CT , node -> work -> iwork ); /* compute CT */
221+ sum_csr_matrices_fill_values (C , CT , node -> wsum_hess ); /* hess = C + CT */
225222 }
226223}
227224
0 commit comments