Skip to content

Commit 00f7732

Browse files
Transurgeonclaude
andcommitted
Skip AT recomputation in param refresh; param matmul is always affine
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 27acb7d commit 00f7732

1 file changed

Lines changed: 6 additions & 9 deletions

File tree

src/bivariate/left_matmul.c

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@
4848
#include "utils/utils.h"
4949
#include <string.h>
5050

51-
/* Refresh block-diagonal A values from param_source and recompute AT values.
51+
/* Refresh block-diagonal A values from param_source.
5252
The block-diagonal has n_blocks copies of the dense src_m x src_n source matrix.
53-
A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. */
53+
A->x is laid out as [src_nnz | src_nnz | ... | src_nnz].
54+
Note: AT is not refreshed because param matmul is always affine, so the
55+
weighted Hessian is always zero regardless of AT values. */
5456
static void refresh_param_values(left_matmul_expr *lin_node)
5557
{
5658
const double *src = lin_node->param_source->value;
@@ -75,8 +77,6 @@ static void refresh_param_values(left_matmul_expr *lin_node)
7577
memcpy(A->x + block * src_nnz, A->x, src_nnz * sizeof(double));
7678
}
7779

78-
/* Recompute AT values from updated A */
79-
AT_fill_values(A, lin_node->AT, lin_node->base.iwork);
8080
}
8181

8282
static void forward(expr *node, const double *u)
@@ -168,11 +168,8 @@ static void eval_wsum_hess(expr *node, const double *w)
168168
{
169169
left_matmul_expr *lin_node = (left_matmul_expr *) node;
170170

171-
/* refresh AT if parameter-sourced */
172-
if (lin_node->param_source)
173-
{
174-
refresh_param_values(lin_node);
175-
}
171+
/* No need to refresh AT for param-sourced nodes: param matmul is always
172+
affine, so the child's weighted Hessian is zero regardless of AT values. */
176173

177174
/* compute A^T w*/
178175
csr_matvec_wo_offset(lin_node->AT, w, node->dwork);

0 commit comments

Comments
 (0)