-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathlinalg_sparse_matmuls.h
More file actions
42 lines (34 loc) · 1.66 KB
/
linalg_sparse_matmuls.h
File metadata and controls
42 lines (34 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#ifndef LINALG_H
#define LINALG_H
#include "CSC_Matrix.h"
#include "CSR_Matrix.h"
/**
 * @brief Allocates the block-diagonal sparse product C = (I_p kron A) @ J.
 *
 * Instead of materializing kron(I_p, A) and running a generic sparse
 * matrix-matrix product, this uses specialized logic for the block-diagonal
 * structure. The original authors report this as roughly 50-100x faster
 * than generic sparse matmul.
 *
 * Shapes: A is m x n (CSR), J is (n*p) x k (CSC), C is (m*p) x k (CSC).
 * J is viewed as p vertically stacked blocks of n rows each,
 * J = [J1; J2; ...; Jp], so that
 *
 *     C = [A @ J1; A @ J2; ...; A @ Jp].
 *
 * This function computes the sparsity pattern of C.
 * NOTE(review): the original comment said "sparsity pattern and values";
 * confirm whether values are also written here or only by
 * block_left_multiply_fill_values.
 *
 * @param A    Left factor, CSR, m x n.
 * @param J    Right factor, CSC, (n*p) x k.
 * @param p    Number of diagonal copies of A (number of row blocks of J).
 * @param mem  Presumably receives (or accumulates) the number of bytes
 *             allocated for C -- TODO confirm against the implementation.
 * @return     Newly allocated C in CSC format (ownership and NULL-on-failure
 *             semantics are not visible here -- TODO confirm).
 */
CSC_Matrix *block_left_multiply_alloc(const CSR_Matrix *A,
                                      const CSC_Matrix *J,
                                      int p,
                                      size_t *mem);

/**
 * @brief Fills the values of C = (I_p kron A) @ J.
 *
 * Presumably expects C as returned by block_left_multiply_alloc for the same
 * A and J, with its sparsity pattern already in place -- TODO confirm.
 * p is not a parameter here. It is presumably recovered from the shapes of
 * A and J or from C -- confirm.
 *
 * @param A  Left factor, CSR, m x n.
 * @param J  Right factor, CSC, (n*p) x k.
 * @param C  Output, CSC, (m*p) x k; its values are written.
 */
void block_left_multiply_fill_values(const CSR_Matrix *A,
                                     const CSC_Matrix *J,
                                     CSC_Matrix *C);
/**
 * @brief Block-diagonal sparse matrix-vector product y = kron(I_p, A) @ x.
 *
 * A is m x n (CSR). x has length n*p and is viewed as p consecutive blocks
 * of n entries each, x = [x1; x2; ...; xp]. y has length m*p and receives
 *
 *     y = [A @ x1; A @ x2; ...; A @ xp].
 *
 * NOTE(review): the declaration does not show whether y is overwritten or
 * accumulated into -- confirm against the implementation.
 *
 * @param A  Sparse matrix in CSR format, m x n.
 * @param x  Input vector of length n*p (read only).
 * @param y  Output vector of length m*p.
 * @param p  Number of diagonal copies of A (number of blocks in x and y).
 */
void block_left_multiply_vec(const CSR_Matrix *A,
                             const double *x,
                             double *y,
                             int p);
/**
 * @brief Fills the values of C = A @ B into a precomputed sparsity pattern.
 *
 * A is in CSR format, B is in CSC format, and C is in CSR format. C must
 * already carry the sparsity pattern of the product; only its values are
 * computed here. The pattern presumably comes from csr_csc_matmul_alloc
 * called with the same A and B -- confirm.
 *
 * @param A  Left factor in CSR format.
 * @param B  Right factor in CSC format; its row count must equal A's column
 *           count for the product to be defined.
 * @param C  Output matrix in CSR format whose pattern is already set.
 */
void csr_csc_matmul_fill_values(const CSR_Matrix *A,
                                const CSC_Matrix *B,
                                CSR_Matrix *C);
/**
 * @brief Allocates C = A @ B (A in CSR, B in CSC) and computes its pattern.
 *
 * The result C is returned in CSR format with its sparsity pattern
 * precomputed. No caller-supplied workspace is required. Values are
 * presumably populated afterwards with csr_csc_matmul_fill_values.
 * NOTE(review): confirm whether this function also writes the values.
 *
 * @param A    Left factor in CSR format.
 * @param B    Right factor in CSC format.
 * @param mem  Presumably receives (or accumulates) the number of bytes
 *             allocated for C -- TODO confirm against the implementation.
 * @return     Newly allocated CSR matrix C (ownership and failure semantics
 *             are not visible here -- TODO confirm).
 */
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A,
                                 const CSC_Matrix *B,
                                 size_t *mem);
#endif /* LINALG_H */