Skip to content

Commit 0e804a0

Browse files
authored
Left matmul 100x performance improvements (#44)
* test for profiling * 90 times faster sparsity pattern * fill values without forming A kron * update forward pass * fix test * ran formatter * profile forward pass * ran formatter * removed Kronecker product from hessian * minor changes * broadcast fix * ran formatter * improved documentation of block_left_mult
1 parent fa65481 commit 0e804a0

17 files changed

Lines changed: 1244 additions & 166 deletions

File tree

CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ set_property(TARGET dnlp_diff PROPERTY POSITION_INDEPENDENT_CODE ON)
8181
# =============================================================================
8282
# C tests (only for standalone builds)
8383
# =============================================================================
84+
option(PROFILE_ONLY "Build only profiling tests" OFF)
85+
8486
if(NOT SKBUILD)
8587
include_directories(${PROJECT_SOURCE_DIR}/tests)
8688
enable_testing()
@@ -90,5 +92,11 @@ if(NOT SKBUILD)
9092
tests/test_helpers.c
9193
)
9294
target_link_libraries(all_tests dnlp_diff)
95+
96+
if(PROFILE_ONLY)
97+
target_compile_definitions(all_tests PRIVATE PROFILE_ONLY)
98+
message(STATUS "Building ONLY profiling tests")
99+
endif()
100+
93101
add_test(NAME AllTests COMMAND all_tests)
94102
endif()

include/subexpr.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ typedef struct left_matmul_expr
109109
expr base;
110110
CSR_Matrix *A;
111111
CSR_Matrix *AT;
112-
CSC_Matrix *CSC_work;
112+
int n_blocks;
113+
CSC_Matrix *Jchild_CSC;
114+
CSC_Matrix *J_CSC;
115+
int *csc_to_csr_workspace;
113116
} left_matmul_expr;
114117

115118
/* Right matrix multiplication: y = f(x) * A where f(x) is an expression.

include/utils/CSC_Matrix.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,7 @@ void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C)
5252
CSC_Matrix *csr_to_csc_fill_sparsity(const CSR_Matrix *A, int *iwork);
5353
void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork);
5454

55-
/* Allocate CSR matrix for C = A @ B where A is CSR, B is CSC
56-
* Precomputes sparsity pattern. No workspace required.
57-
*/
58-
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B);
59-
60-
/* Fill values of C = A @ B where A is CSR, B is CSC
61-
* C must have sparsity pattern already computed
62-
*/
63-
void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B,
64-
CSR_Matrix *C);
55+
CSR_Matrix *csc_to_csr_fill_sparsity(const CSC_Matrix *A, int *iwork);
56+
void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork);
6557

6658
#endif /* CSC_MATRIX_H */

include/utils/CSR_Matrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ typedef struct CSR_Matrix
2727

2828
/* Allocate a new CSR matrix with given dimensions and nnz */
2929
CSR_Matrix *new_csr_matrix(int m, int n, int nnz);
30+
CSR_Matrix *new_csr(const CSR_Matrix *A);
3031

3132
/* Free a CSR matrix */
3233
void free_csr_matrix(CSR_Matrix *matrix);

include/utils/linalg.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef LINALG_H
2+
#define LINALG_H
3+
4+
/* Forward declarations */
5+
struct CSR_Matrix;
6+
struct CSC_Matrix;
7+
8+
/* Compute sparsity pattern and values for the matrix-matrix multiplication
9+
C = (I_p kron A) @ J where A is m x n, J is (n*p) x k, and C is (m*p) x k,
10+
without relying on generic sparse matrix-matrix multiplication. Specialized
11+
logic for this is much faster (50-100x) than generic sparse matmul.
12+
13+
* J is provided in CSC format and is split into p blocks of n rows each
14+
* C is returned in CSC format
15+
* Mathematically it corresponds to C = [A @ J1; A @ J2; ...; A @ Jp],
16+
where J = [J1; J2; ...; Jp]
17+
*/
18+
struct CSC_Matrix *block_left_multiply_fill_sparsity(const struct CSR_Matrix *A,
19+
const struct CSC_Matrix *J,
20+
int p);
21+
22+
void block_left_multiply_fill_values(const struct CSR_Matrix *A,
23+
const struct CSC_Matrix *J,
24+
struct CSC_Matrix *C);
25+
26+
/* Compute y = kron(I_p, A) @ x where A is m x n and x is an (n*p)-length vector.
27+
The output y is an (m*p)-length vector corresponding to
28+
y = [A @ x1; A @ x2; ...; A @ xp] where x is divided into p blocks of n
29+
elements.
30+
*/
31+
void block_left_multiply_vec(const struct CSR_Matrix *A, const double *x, double *y,
32+
int p);
33+
34+
/* Fill values of C = A @ B where A is CSR, B is CSC.
35+
* C must have sparsity pattern already computed.
36+
*/
37+
void csr_csc_matmul_fill_values(const struct CSR_Matrix *A,
38+
const struct CSC_Matrix *B, struct CSR_Matrix *C);
39+
40+
/* C = A @ B where A is CSR, B is CSC. Result C is CSR.
41+
* Allocates and precomputes sparsity pattern. No workspace required.
42+
*/
43+
struct CSR_Matrix *csr_csc_matmul_alloc(const struct CSR_Matrix *A,
44+
const struct CSC_Matrix *B);
45+
46+
#endif /* LINALG_H */

src/affine/broadcast.c

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static void jacobian_init(expr *node)
8989
else
9090
{
9191
/* Scalar broadcast: (1, 1) -> (m, n) */
92-
total_nnz = x->jacobian->nnz * node->d1 * node->d2;
92+
total_nnz = x->jacobian->nnz * node->size;
9393
}
9494

9595
node->jacobian = new_csr_matrix(node->size, node->n_vars, total_nnz);
@@ -99,10 +99,10 @@ static void jacobian_init(expr *node)
9999
// ---------------------------------------------------------------------
100100
CSR_Matrix *Jx = x->jacobian;
101101
CSR_Matrix *J = node->jacobian;
102-
J->nnz = 0;
103102

104103
if (bcast->type == BROADCAST_ROW)
105104
{
105+
J->nnz = 0;
106106
for (int i = 0; i < node->d2; i++)
107107
{
108108
int nnz_in_row = Jx->p[i + 1] - Jx->p[i];
@@ -117,22 +117,23 @@ static void jacobian_init(expr *node)
117117
J->nnz += nnz_in_row;
118118
}
119119
}
120+
assert(J->nnz == total_nnz);
120121
J->p[node->size] = total_nnz;
121122
}
122123
else if (bcast->type == BROADCAST_COL)
123124
{
124-
125125
/* copy column indices */
126126
tile_int(J->i, Jx->i, Jx->nnz, node->d2);
127127

128128
/* set row pointers */
129129
int offset = 0;
130130
for (int i = 0; i < node->d2; i++)
131131
{
132+
int nnz_in_row = Jx->p[i + 1] - Jx->p[i];
132133
for (int j = 0; j < node->d1; j++)
133134
{
134135
J->p[i * node->d1 + j] = offset;
135-
offset += Jx->p[1] - Jx->p[0];
136+
offset += nnz_in_row;
136137
}
137138
}
138139
assert(offset == total_nnz);
@@ -141,12 +142,12 @@ static void jacobian_init(expr *node)
141142
else
142143
{
143144
/* copy column indices */
144-
tile_int(J->i, Jx->i, Jx->nnz, node->d1 * node->d2);
145+
tile_int(J->i, Jx->i, Jx->nnz, node->size);
145146

146147
/* set row pointers */
147148
int offset = 0;
148149
int nnz = Jx->p[1] - Jx->p[0];
149-
for (int i = 0; i < node->d1 * node->d2; i++)
150+
for (int i = 0; i < node->size; i++)
150151
{
151152
J->p[i] = offset;
152153
offset += nnz;
@@ -163,10 +164,10 @@ static void eval_jacobian(expr *node)
163164
broadcast_expr *bcast = (broadcast_expr *) node;
164165
CSR_Matrix *Jx = node->left->jacobian;
165166
CSR_Matrix *J = node->jacobian;
166-
J->nnz = 0;
167167

168168
if (bcast->type == BROADCAST_ROW)
169169
{
170+
J->nnz = 0;
170171
for (int i = 0; i < node->d2; i++)
171172
{
172173
int nnz_in_row = Jx->p[i + 1] - Jx->p[i];
@@ -180,7 +181,7 @@ static void eval_jacobian(expr *node)
180181
}
181182
else
182183
{
183-
tile_double(J->x, Jx->x, Jx->nnz, node->d1 * node->d2);
184+
tile_double(J->x, Jx->x, Jx->nnz, node->size);
184185
}
185186
}
186187

@@ -268,9 +269,9 @@ expr *new_broadcast(expr *child, int d1, int d2)
268269
}
269270
else
270271
{
271-
fprintf(
272-
stderr,
273-
"ERROR: inconsistency of broadcasting between DNLP-diff and CVXPY. \n");
272+
fprintf(stderr,
273+
"ERROR: inconsistency of broadcasting between SparseDifferentiation"
274+
" and CVXPY. \n");
274275
exit(1);
275276
}
276277

src/bivariate/left_matmul.c

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
*/
1818
#include "bivariate.h"
1919
#include "subexpr.h"
20+
#include "utils/Timer.h"
21+
#include "utils/linalg.h"
2022
#include <assert.h>
2123
#include <stdio.h>
2224
#include <stdlib.h>
@@ -31,7 +33,9 @@
3133
* To compute the forward pass: vec(y) = A_kron @ vec(f(x)),
3234
where A_kron = I_p kron A is a Kronecker product of size (m*p) x (n*p),
3335
or more specifically, a block-diagonal matrix with p blocks of A along the
34-
diagonal.
36+
diagonal. In the refactored implementation we don't form A_kron explicitly,
37+
only conceptually. This led to a 100x speedup in the initialization of the
38+
Jacobian sparsity pattern.
3539
3640
* To compute the Jacobian: J_y = A_kron @ J_f(x), where J_f(x) is the
3741
Jacobian of f(x) of size (n*p) x n_vars.
@@ -42,7 +46,8 @@
4246
Working in terms of A_kron unifies the implementation of f(x) being
4347
vector-valued or matrix-valued.
4448
45-
49+
I (dance858) think we can get additional big speedups when A is dense by
50+
introducing a dense matrix class.
4651
*/
4752

4853
#include "utils/utils.h"
@@ -55,7 +60,9 @@ static void forward(expr *node, const double *u)
5560
node->left->forward(node->left, u);
5661

5762
/* y = A_kron @ vec(f(x)) */
58-
csr_matvec_wo_offset(((left_matmul_expr *) node)->A, x->value, node->value);
63+
CSR_Matrix *A = ((left_matmul_expr *) node)->A;
64+
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
65+
block_left_multiply_vec(A, x->value, node->value, n_blocks);
5966
}
6067

6168
static bool is_affine(const expr *node)
@@ -68,39 +75,47 @@ static void free_type_data(expr *node)
6875
left_matmul_expr *lin_node = (left_matmul_expr *) node;
6976
free_csr_matrix(lin_node->A);
7077
free_csr_matrix(lin_node->AT);
71-
if (lin_node->CSC_work)
72-
{
73-
free_csc_matrix(lin_node->CSC_work);
74-
}
78+
free_csc_matrix(lin_node->Jchild_CSC);
79+
free_csc_matrix(lin_node->J_CSC);
80+
free(lin_node->csc_to_csr_workspace);
7581
lin_node->A = NULL;
7682
lin_node->AT = NULL;
77-
lin_node->CSC_work = NULL;
83+
lin_node->Jchild_CSC = NULL;
84+
lin_node->J_CSC = NULL;
85+
lin_node->csc_to_csr_workspace = NULL;
7886
}
7987

8088
static void jacobian_init(expr *node)
8189
{
8290
expr *x = node->left;
8391
left_matmul_expr *lin_node = (left_matmul_expr *) node;
8492

85-
/* initialize child's jacobian and precompute sparsity of its transpose */
93+
/* initialize child's jacobian and precompute sparsity of its CSC */
8694
x->jacobian_init(x);
87-
lin_node->CSC_work = csr_to_csc_fill_sparsity(x->jacobian, node->iwork);
95+
lin_node->Jchild_CSC = csr_to_csc_fill_sparsity(x->jacobian, node->iwork);
8896

89-
/* precompute sparsity of this node's jacobian */
90-
node->jacobian = csr_csc_matmul_alloc(lin_node->A, lin_node->CSC_work);
97+
/* precompute sparsity of this node's jacobian in CSC and CSR */
98+
lin_node->J_CSC = block_left_multiply_fill_sparsity(
99+
lin_node->A, lin_node->Jchild_CSC, lin_node->n_blocks);
100+
node->jacobian =
101+
csc_to_csr_fill_sparsity(lin_node->J_CSC, lin_node->csc_to_csr_workspace);
91102
}
92103

93104
static void eval_jacobian(expr *node)
94105
{
95106
expr *x = node->left;
96-
left_matmul_expr *lin_node = (left_matmul_expr *) node;
107+
left_matmul_expr *lnode = (left_matmul_expr *) node;
108+
109+
CSC_Matrix *Jchild_CSC = lnode->Jchild_CSC;
110+
CSC_Matrix *J_CSC = lnode->J_CSC;
97111

98112
/* evaluate child's jacobian and convert to CSC */
99113
x->eval_jacobian(x);
100-
csr_to_csc_fill_values(x->jacobian, lin_node->CSC_work, node->iwork);
114+
csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->iwork);
101115

102-
/* compute this node's jacobian */
103-
csr_csc_matmul_fill_values(lin_node->A, lin_node->CSC_work, node->jacobian);
116+
/* compute this node's jacobian: */
117+
block_left_multiply_fill_values(lnode->A, Jchild_CSC, J_CSC);
118+
csc_to_csr_fill_values(J_CSC, node->jacobian, lnode->csc_to_csr_workspace);
104119
}
105120

106121
static void wsum_hess_init(expr *node)
@@ -115,15 +130,17 @@ static void wsum_hess_init(expr *node)
115130
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
116131

117132
/* work for computing A^T w*/
118-
int A_n = ((left_matmul_expr *) node)->A->n;
119-
node->dwork = (double *) malloc(A_n * sizeof(double));
133+
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
134+
int dim = ((left_matmul_expr *) node)->A->n * n_blocks;
135+
node->dwork = (double *) malloc(dim * sizeof(double));
120136
}
121137

122138
static void eval_wsum_hess(expr *node, const double *w)
123139
{
124140
/* compute A^T w*/
125-
left_matmul_expr *lin_node = (left_matmul_expr *) node;
126-
csr_matvec_wo_offset(lin_node->AT, w, node->dwork);
141+
CSR_Matrix *AT = ((left_matmul_expr *) node)->AT;
142+
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
143+
block_left_multiply_vec(AT, w, node->dwork, n_blocks);
127144

128145
node->left->eval_wsum_hess(node->left, node->dwork);
129146
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,
@@ -132,10 +149,10 @@ static void eval_wsum_hess(expr *node, const double *w)
132149

133150
expr *new_left_matmul(expr *u, const CSR_Matrix *A)
134151
{
135-
/* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users to do
136-
A @ u where u is (n, ) which in C is actually (1, n). In that case the result
137-
of A @ u is (m, ), which is (1, m) according to broadcasting rules. We
138-
therefore check if this is the case. */
152+
/* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
153+
to do A @ u where u is (n, ) which in C is actually (1, n). In that case
154+
the result of A @ u is (m, ), which is (1, m) according to broadcasting
155+
rules. We therefore check if this is the case. */
139156
int d1, d2, n_blocks;
140157
if (u->d1 == A->n)
141158
{
@@ -164,12 +181,17 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
164181
node->left = u;
165182
expr_retain(u);
166183

167-
/* Initialize type-specific fields */
168-
lin_node->A = block_diag_repeat_csr(A, n_blocks);
169-
int alloc = MAX(lin_node->A->n, node->n_vars);
170-
node->iwork = (int *) malloc(alloc * sizeof(int));
184+
/* allocate workspace. iwork is used for transposing A (requiring size A->n)
185+
and for converting J_child csr to csc (requiring size node->n_vars).
186+
csc_to_csr_workspace is used for converting J_CSC to CSR (requiring size node->size)
187+
*/
188+
node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int));
189+
lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int));
190+
lin_node->n_blocks = n_blocks;
191+
192+
/* store A and AT */
193+
lin_node->A = new_csr(A);
171194
lin_node->AT = transpose(lin_node->A, node->iwork);
172-
lin_node->CSC_work = NULL;
173195

174196
return node;
175197
}

src/bivariate/right_matmul.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
#include "bivariate.h"
1919
#include "subexpr.h"
20+
#include "utils/linalg.h"
2021
#include <stdlib.h>
2122

2223
/* This file implements the atom 'right_matmul' corresponding to the operation y =

src/problem.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ static inline void print_end_message(const Diff_engine_stats *stats)
246246
{
247247
printf("\n"
248248
"============================================================\n"
249-
" DNLP Differentiation Engine v%s\n"
249+
" SparseDifferentiation v%s\n"
250250
" (c) D. Cederberg and W. Zhang, Stanford University, 2026\n"
251251
"============================================================\n",
252252
DIFF_ENGINE_VERSION);

0 commit comments

Comments
 (0)