|
49 | 49 | #include <string.h> |
50 | 50 |
|
51 | 51 | /* Refresh block-diagonal A values from param_source and recompute AT values. |
52 | | - The block-diagonal has n_blocks copies of the src_m x src_n source matrix. |
53 | | - block_diag_repeat_csr lays out values as: for each block, copy the entire |
54 | | - source nnz array in order. So A->x is [src_nnz | src_nnz | ... | src_nnz]. */ |
| 52 | + The block-diagonal has n_blocks copies of the dense src_m x src_n source matrix. |
| 53 | + A->x is laid out as [src_nnz | src_nnz | ... | src_nnz]. */ |
55 | 54 | static void refresh_param_values(left_matmul_expr *lin_node) |
56 | 55 | { |
57 | 56 | const double *src = lin_node->param_source->value; |
58 | 57 | CSR_Matrix *A = lin_node->A; |
59 | 58 | int src_m = lin_node->src_m; |
60 | 59 | int src_n = lin_node->src_n; |
61 | | - int total_rows = A->m; |
62 | | - int n_blocks = total_rows / src_m; |
| 60 | + int src_nnz = src_m * src_n; |
| 61 | + int n_blocks = A->m / src_m; |
63 | 62 |
|
64 | | - /* Rebuild A values from column-major source matrix. |
65 | | - For each block, iterate rows of the source matrix and fill CSR values. */ |
66 | | - int nnz_cursor = 0; |
67 | | - for (int block = 0; block < n_blocks; block++) |
| 63 | + /* Build first block: column-major source -> row-major CSR values */ |
| 64 | + for (int row = 0; row < src_m; row++) |
68 | 65 | { |
69 | | - for (int row = 0; row < src_m; row++) |
| 66 | + for (int col = 0; col < src_n; col++) |
70 | 67 | { |
71 | | - int dest_row = block * src_m + row; |
72 | | - for (int j = A->p[dest_row]; j < A->p[dest_row + 1]; j++) |
73 | | - { |
74 | | - /* column index in local block coordinates */ |
75 | | - int col = A->i[j] - block * src_n; |
76 | | - /* source is column-major: src[row + col * src_m] */ |
77 | | - A->x[nnz_cursor] = src[row + col * src_m]; |
78 | | - nnz_cursor++; |
79 | | - } |
| 68 | + A->x[row * src_n + col] = src[row + col * src_m]; |
80 | 69 | } |
81 | 70 | } |
82 | 71 |
|
| 72 | + /* Copy first block to remaining blocks */ |
| 73 | + for (int block = 1; block < n_blocks; block++) |
| 74 | + { |
| 75 | + memcpy(A->x + block * src_nnz, A->x, src_nnz * sizeof(double)); |
| 76 | + } |
| 77 | + |
83 | 78 | /* Recompute AT values from updated A */ |
84 | 79 | AT_fill_values(A, lin_node->AT, lin_node->base.iwork); |
85 | 80 | } |
|
0 commit comments