Skip to content

Commit 27acb7d

Browse files
Transurgeonclaude
andcommitted
Simplify refresh_param_values: fill one block, memcpy the rest
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bd4f3b0 commit 27acb7d

1 file changed

Lines changed: 14 additions & 19 deletions

File tree

src/bivariate/left_matmul.c

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,37 +49,32 @@
4949
#include <string.h>
5050

5151
/* Refresh block-diagonal A values from param_source and recompute AT values.
52-
The block-diagonal has n_blocks copies of the src_m x src_n source matrix.
53-
block_diag_repeat_csr lays out values as: for each block, copy the entire
54-
source nnz array in order. So A->x is [src_nnz | src_nnz | ... | src_nnz]. */
52+
The block-diagonal has n_blocks copies of the dense src_m x src_n source matrix.
53+
A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. */
5554
static void refresh_param_values(left_matmul_expr *lin_node)
5655
{
5756
const double *src = lin_node->param_source->value;
5857
CSR_Matrix *A = lin_node->A;
5958
int src_m = lin_node->src_m;
6059
int src_n = lin_node->src_n;
61-
int total_rows = A->m;
62-
int n_blocks = total_rows / src_m;
60+
int src_nnz = src_m * src_n;
61+
int n_blocks = A->m / src_m;
6362

64-
/* Rebuild A values from column-major source matrix.
65-
For each block, iterate rows of the source matrix and fill CSR values. */
66-
int nnz_cursor = 0;
67-
for (int block = 0; block < n_blocks; block++)
63+
/* Build first block: column-major source -> row-major CSR values */
64+
for (int row = 0; row < src_m; row++)
6865
{
69-
for (int row = 0; row < src_m; row++)
66+
for (int col = 0; col < src_n; col++)
7067
{
71-
int dest_row = block * src_m + row;
72-
for (int j = A->p[dest_row]; j < A->p[dest_row + 1]; j++)
73-
{
74-
/* column index in local block coordinates */
75-
int col = A->i[j] - block * src_n;
76-
/* source is column-major: src[row + col * src_m] */
77-
A->x[nnz_cursor] = src[row + col * src_m];
78-
nnz_cursor++;
79-
}
68+
A->x[row * src_n + col] = src[row + col * src_m];
8069
}
8170
}
8271

72+
/* Copy first block to remaining blocks */
73+
for (int block = 1; block < n_blocks; block++)
74+
{
75+
memcpy(A->x + block * src_nnz, A->x, src_nnz * sizeof(double));
76+
}
77+
8378
/* Recompute AT values from updated A */
8479
AT_fill_values(A, lin_node->AT, lin_node->base.iwork);
8580
}

0 commit comments

Comments
 (0)