4848#include "utils/utils.h"
4949#include <string.h>
5050
51- /* Refresh block-diagonal A values from param_source and recompute AT values .
51+ /* Refresh block-diagonal A values from param_source.
5252 The block-diagonal has n_blocks copies of the dense src_m x src_n source matrix.
53- A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. */
53+ A->x is laid out as [src_nnz | src_nnz | ... | src_nnz].
54+ Note: AT is not refreshed because param matmul is always affine, so the
55+ weighted Hessian is always zero regardless of AT values. */
5456static void refresh_param_values (left_matmul_expr * lin_node )
5557{
5658 const double * src = lin_node -> param_source -> value ;
@@ -75,8 +77,6 @@ static void refresh_param_values(left_matmul_expr *lin_node)
7577 memcpy (A -> x + block * src_nnz , A -> x , src_nnz * sizeof (double ));
7678 }
7779
78- /* Recompute AT values from updated A */
79- AT_fill_values (A , lin_node -> AT , lin_node -> base .iwork );
8080}
8181
8282static void forward (expr * node , const double * u )
@@ -168,11 +168,8 @@ static void eval_wsum_hess(expr *node, const double *w)
168168{
169169 left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
170170
171- /* refresh AT if parameter-sourced */
172- if (lin_node -> param_source )
173- {
174- refresh_param_values (lin_node );
175- }
171+ /* No need to refresh AT for param-sourced nodes: param matmul is always
172+ affine, so the child's weighted Hessian is zero regardless of AT values. */
176173
177174 /* compute A^T w*/
178175 csr_matvec_wo_offset (lin_node -> AT , w , node -> dwork );
0 commit comments