Skip to content

Commit 3a6e149

Browse files
authored
Right matmul refactor (#46)
* refactored and moved things from CSR_Matrix * clean up documentation of CSR_Matrix.h * replace right matmul with left matmul
1 parent 0e804a0 commit 3a6e149

17 files changed

Lines changed: 884 additions & 969 deletions

include/utils/CSR_Matrix.h

Lines changed: 11 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,29 @@ typedef struct CSR_Matrix
2525
int nnz;
2626
} CSR_Matrix;
2727

28-
/* Allocate a new CSR matrix with given dimensions and nnz */
28+
/* constructors and destructors */
2929
CSR_Matrix *new_csr_matrix(int m, int n, int nnz);
3030
CSR_Matrix *new_csr(const CSR_Matrix *A);
31-
32-
/* Free a CSR matrix */
3331
void free_csr_matrix(CSR_Matrix *matrix);
34-
35-
/* Copy CSR matrix A to C */
3632
void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C);
3733

38-
/* Build block-diagonal repeat A_blk = I_p kron A. Returns newly allocated CSR
39-
* matrix of size (p*A->m) x (p*A->n) with nnz = p*A->nnz. */
34+
/* transpose functionality (iwork must be of size A->n) */
35+
CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
36+
CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork);
37+
void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork);
38+
39+
/* Build (I_p kron A) = blkdiag(A, A, ..., A) of size (p*A->m) x (p*A->n) */
4040
CSR_Matrix *block_diag_repeat_csr(const CSR_Matrix *A, int p);
4141

42-
/* Build left-repeated Kronecker A_kron = A kron I_p. Returns newly allocated CSR
43-
* matrix of size (A->m * p) x (A->n * p) with nnz = A->nnz * p. */
42+
/* Build (A kron I_p) of size (A->m * p) x (A->n * p) with nnz = A->nnz * p. */
4443
CSR_Matrix *kron_identity_csr(const CSR_Matrix *A, int p);
4544

46-
/* matvec y = Ax, where A indices minus col_offset gives x indices. Returns y as
47-
* dense. */
45+
/* y = Ax, where y is returned as dense */
4846
void csr_matvec(const CSR_Matrix *A, const double *x, double *y, int col_offset);
4947
void csr_matvec_wo_offset(const CSR_Matrix *A, const double *x, double *y);
5048

51-
/* C = z^T A is assumed to have one row. C must have column indices pre-computed
52-
and transposed matrix AT must be provided. Fills in values of C only.
53-
*/
49+
/* Computes values of the row matrix C = z^T A (column indices must have been
50+
pre-computed) and transposed matrix AT must be provided) */
5451
void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C);
5552

5653
/* Insert value into CSR matrix A with just one row at col_idx. Assumes that A
@@ -64,92 +61,6 @@ void csr_insert_value(CSR_Matrix *A, int col_idx, double value);
6461
void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C);
6562
void diag_csr_mult_fill_values(const double *d, const CSR_Matrix *A, CSR_Matrix *C);
6663

67-
/* Compute C = A + B where A, B, C are CSR matrices
68-
* A and B must have same dimensions
69-
* C must be pre-allocated with sufficient nnz capacity.
70-
* C must be different from A and B */
71-
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
72-
/* Compute sparsity pattern of A + B where A, B, C are CSR matrices.
73-
* Fills C->p, C->i, and C->nnz; does not touch C->x. */
74-
void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B,
75-
CSR_Matrix *C);
76-
77-
/* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i)
78-
* is already filled and matches the union of A and B per row. Does not modify
79-
* C->p, C->i, or C->nnz. */
80-
void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
81-
CSR_Matrix *C);
82-
83-
/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */
84-
void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C,
85-
const double *d1, const double *d2);
86-
87-
/* Fill only the values of C = diag(d1) * A + diag(d2) * B, assuming C's sparsity
88-
* pattern (p and i) is already filled and matches the union of A and B per row.
89-
* Does not modify C->p, C->i, or C->nnz. */
90-
void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
91-
CSR_Matrix *C, const double *d1,
92-
const double *d2);
93-
94-
/* Sum all rows of A into a single row matrix C */
95-
void sum_all_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
96-
struct int_double_pair *pairs);
97-
98-
/* iwork must have size max(C->n, A->nnz), and idx_map must have size A->nnz. */
99-
void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix *C,
100-
int *iwork, int *idx_map);
101-
102-
/* Fill values of summed rows using precomputed idx_map and sparsity of C */
103-
// void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C,
104-
// const int *idx_map);
105-
106-
/* Fill accumulator for summing rows using precomputed idx_map for each nnz of A.
107-
Must memset accumulator to zero before calling. */
108-
void idx_map_accumulator(const CSR_Matrix *A, const int *idx_map,
109-
double *accumulator);
110-
void idx_map_accumulator_with_spacing(const CSR_Matrix *A, const int *idx_map,
111-
double *accumulator, int spacing);
112-
113-
/* Sum blocks of rows of A into a matrix C */
114-
void sum_block_of_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
115-
struct int_double_pair *pairs, int row_block_size);
116-
117-
/* Build sparsity and index map for summing blocks of rows.
118-
* iwork must have size max(A->n, A->nnz), and idx_map must have size A->nnz. */
119-
void sum_block_of_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
120-
CSR_Matrix *C,
121-
int row_block_size, int *iwork,
122-
int *idx_map);
123-
124-
/* Fill values for summing blocks of rows using precomputed idx_map */
125-
// void sum_block_of_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C,
126-
// const int *idx_map);
127-
128-
/* Sum evenly spaced rows of A into a matrix C */
129-
void sum_evenly_spaced_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
130-
struct int_double_pair *pairs, int row_spacing);
131-
132-
/* Build sparsity and index map for summing evenly spaced rows.
133-
* iwork must have size max(A->n, A->nnz), and idx_map must have size A->nnz. */
134-
void sum_evenly_spaced_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
135-
CSR_Matrix *C,
136-
int row_spacing,
137-
int *iwork, int *idx_map);
138-
139-
/* Fill values for summing evenly spaced rows using precomputed idx_map */
140-
// void sum_evenly_spaced_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C,
141-
// const int *idx_map);
142-
143-
/* Sum evenly spaced rows of A starting at offset into a row matrix C */
144-
void sum_spaced_rows_into_row_csr(const CSR_Matrix *A, CSR_Matrix *C,
145-
struct int_double_pair *pairs, int offset,
146-
int spacing);
147-
/* Fills the sparsity and index map for summing spaced rows into a row matrix */
148-
void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
149-
CSR_Matrix *C,
150-
int spacing, int *iwork,
151-
int *idx_map);
152-
15364
/* Count number of columns with nonzero entries */
15465
int count_nonzero_cols(const CSR_Matrix *A, bool *col_nz);
15566

@@ -158,13 +69,6 @@ void insert_idx(int idx, int *arr, int len);
15869

15970
double csr_get_value(const CSR_Matrix *A, int row, int col);
16071

161-
/* iwork must be of size A->n*/
162-
CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
163-
CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork);
164-
165-
/* Fill values of A^T given sparsity pattern is already computed */
166-
void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork);
167-
16872
/* Expand symmetric CSR matrix A to full matrix C. A is assumed to store
16973
only upper triangle. C must be pre-allocated with sufficient nnz */
17074
void symmetrize_csr(const int *Ap, const int *Ai, int m, CSR_Matrix *C);

include/utils/CSR_sum.h

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#ifndef CSR_SUM_H
2+
#define CSR_SUM_H
3+
4+
#include "utils/CSR_Matrix.h"
5+
6+
/* forward declaration */
7+
struct int_double_pair;
8+
9+
/* Compute C = A + B where A, B, C are CSR matrices
10+
* A and B must have same dimensions
11+
* C must be pre-allocated with sufficient nnz capacity.
12+
* C must be different from A and B */
13+
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
14+
15+
/* Compute sparsity pattern of A + B where A, B, C are CSR matrices.
16+
* Fills C->p, C->i, and C->nnz; does not touch C->x. */
17+
void sum_csr_matrices_fill_sparsity(const CSR_Matrix *A, const CSR_Matrix *B,
18+
CSR_Matrix *C);
19+
20+
/* Fill only the values of C = A + B, assuming C's sparsity pattern (p and i)
21+
* is already filled and matches the union of A and B per row. Does not modify
22+
* C->p, C->i, or C->nnz. */
23+
void sum_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
24+
CSR_Matrix *C);
25+
26+
/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */
27+
void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C,
28+
const double *d1, const double *d2);
29+
30+
/* Fill only the values of C = diag(d1) * A + diag(d2) * B, assuming C's sparsity
31+
* pattern (p and i) is already filled and matches the union of A and B per row.
32+
* Does not modify C->p, C->i, or C->nnz. */
33+
void sum_scaled_csr_matrices_fill_values(const CSR_Matrix *A, const CSR_Matrix *B,
34+
CSR_Matrix *C, const double *d1,
35+
const double *d2);
36+
37+
/* Sum all rows of A into a single row matrix C */
38+
void sum_all_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
39+
struct int_double_pair *pairs);
40+
41+
/* iwork must have size max(C->n, A->nnz), and idx_map must have size A->nnz. */
42+
void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix *C,
43+
int *iwork, int *idx_map);
44+
45+
/* Fill values of summed rows using precomputed idx_map and sparsity of C */
46+
// void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C,
47+
// const int *idx_map);
48+
49+
/* Fill accumulator for summing rows using precomputed idx_map for each nnz of A.
50+
Must memset accumulator to zero before calling. */
51+
void idx_map_accumulator(const CSR_Matrix *A, const int *idx_map,
52+
double *accumulator);
53+
void idx_map_accumulator_with_spacing(const CSR_Matrix *A, const int *idx_map,
54+
double *accumulator, int spacing);
55+
56+
/* Sum blocks of rows of A into a matrix C */
57+
void sum_block_of_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
58+
struct int_double_pair *pairs, int row_block_size);
59+
60+
/* Build sparsity and index map for summing blocks of rows.
61+
* iwork must have size max(A->n, A->nnz), and idx_map must have size A->nnz. */
62+
void sum_block_of_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
63+
CSR_Matrix *C,
64+
int row_block_size, int *iwork,
65+
int *idx_map);
66+
67+
/* Sum evenly spaced rows of A into a matrix C */
68+
void sum_evenly_spaced_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
69+
struct int_double_pair *pairs, int row_spacing);
70+
71+
/* Build sparsity and index map for summing evenly spaced rows.
72+
* iwork must have size max(A->n, A->nnz), and idx_map must have size A->nnz. */
73+
void sum_evenly_spaced_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
74+
CSR_Matrix *C,
75+
int row_spacing,
76+
int *iwork, int *idx_map);
77+
78+
/* Sum evenly spaced rows of A starting at offset into a row matrix C */
79+
void sum_spaced_rows_into_row_csr(const CSR_Matrix *A, CSR_Matrix *C,
80+
struct int_double_pair *pairs, int offset,
81+
int spacing);
82+
83+
/* Fills the sparsity and index map for summing spaced rows into a row matrix */
84+
void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
85+
CSR_Matrix *C,
86+
int spacing, int *iwork,
87+
int *idx_map);
88+
89+
#endif /* CSR_SUM_H */

src/affine/add.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19+
#include "utils/CSR_sum.h"
1920
#include <assert.h>
2021
#include <stdio.h>
2122
#include <stdlib.h>

src/affine/hstack.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19+
#include "utils/CSR_sum.h"
1920
#include <assert.h>
2021
#include <stdio.h>
2122
#include <stdlib.h>

src/affine/sum.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19+
#include "utils/CSR_sum.h"
1920
#include "utils/int_double_pair.h"
2021
#include "utils/mini_numpy.h"
2122
#include "utils/utils.h"

src/affine/trace.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19+
#include "utils/CSR_sum.h"
1920
#include "utils/int_double_pair.h"
2021
#include "utils/utils.h"
2122
#include <assert.h>

src/bivariate/left_matmul.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "bivariate.h"
1919
#include "subexpr.h"
2020
#include "utils/Timer.h"
21-
#include "utils/linalg.h"
21+
#include "utils/linalg_sparse_matmuls.h"
2222
#include <assert.h>
2323
#include <stdio.h>
2424
#include <stdlib.h>

src/bivariate/multiply.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
#include "bivariate.h"
1919
#include "subexpr.h"
20+
#include "utils/CSR_sum.h"
2021
#include <assert.h>
2122
#include <math.h>
2223
#include <stdio.h>

0 commit comments

Comments
 (0)