1717 */
1818#include "bivariate.h"
1919#include "subexpr.h"
20+ #include "utils/Timer.h"
21+ #include "utils/linalg.h"
2022#include <assert.h>
2123#include <stdio.h>
2224#include <stdlib.h>
3133 * To compute the forward pass: vec(y) = A_kron @ vec(f(x)),
3234 where A_kron = I_p kron A is a Kronecker product of size (m*p) x (n*p),
3335 or more specifically, a block-diagonal matrix with p blocks of A along the
34- diagonal.
36+ diagonal. In the refactored implementation we don't form A_kron explicitly,
37+ only conceptually. This led to a 100x speedup in the initialization of the
38+ Jacobian sparsity pattern.
3539
3640 * To compute the Jacobian: J_y = A_kron @ J_f(x), where J_f(x) is the
3741 Jacobian of f(x) of size (n*p) x n_vars.
4246 Working in terms of A_kron unifies the implementation of f(x) being
4347 vector-valued or matrix-valued.
4448
45-
49+ I (dance858) think we can get additional big speedups when A is dense by
50+ introducing a dense matrix class.
4651*/
4752
4853#include "utils/utils.h"
@@ -55,7 +60,9 @@ static void forward(expr *node, const double *u)
5560 node -> left -> forward (node -> left , u );
5661
5762 /* y = A_kron @ vec(f(x)) */
58- csr_matvec_wo_offset (((left_matmul_expr * ) node )-> A , x -> value , node -> value );
63+ CSR_Matrix * A = ((left_matmul_expr * ) node )-> A ;
64+ int n_blocks = ((left_matmul_expr * ) node )-> n_blocks ;
65+ block_left_multiply_vec (A , x -> value , node -> value , n_blocks );
5966}
6067
6168static bool is_affine (const expr * node )
@@ -68,39 +75,47 @@ static void free_type_data(expr *node)
6875 left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
6976 free_csr_matrix (lin_node -> A );
7077 free_csr_matrix (lin_node -> AT );
71- if (lin_node -> CSC_work )
72- {
73- free_csc_matrix (lin_node -> CSC_work );
74- }
78+ free_csc_matrix (lin_node -> Jchild_CSC );
79+ free_csc_matrix (lin_node -> J_CSC );
80+ free (lin_node -> csc_to_csr_workspace );
7581 lin_node -> A = NULL ;
7682 lin_node -> AT = NULL ;
77- lin_node -> CSC_work = NULL ;
83+ lin_node -> Jchild_CSC = NULL ;
84+ lin_node -> J_CSC = NULL ;
85+ lin_node -> csc_to_csr_workspace = NULL ;
7886}
7987
8088static void jacobian_init (expr * node )
8189{
8290 expr * x = node -> left ;
8391 left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
8492
85- /* initialize child's jacobian and precompute sparsity of its transpose */
93+ /* initialize child's jacobian and precompute sparsity of its CSC */
8694 x -> jacobian_init (x );
87- lin_node -> CSC_work = csr_to_csc_fill_sparsity (x -> jacobian , node -> iwork );
95+ lin_node -> Jchild_CSC = csr_to_csc_fill_sparsity (x -> jacobian , node -> iwork );
8896
89- /* precompute sparsity of this node's jacobian */
90- node -> jacobian = csr_csc_matmul_alloc (lin_node -> A , lin_node -> CSC_work );
97+ /* precompute sparsity of this node's jacobian in CSC and CSR */
98+ lin_node -> J_CSC = block_left_multiply_fill_sparsity (
99+ lin_node -> A , lin_node -> Jchild_CSC , lin_node -> n_blocks );
100+ node -> jacobian =
101+ csc_to_csr_fill_sparsity (lin_node -> J_CSC , lin_node -> csc_to_csr_workspace );
91102}
92103
93104static void eval_jacobian (expr * node )
94105{
95106 expr * x = node -> left ;
96- left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
107+ left_matmul_expr * lnode = (left_matmul_expr * ) node ;
108+
109+ CSC_Matrix * Jchild_CSC = lnode -> Jchild_CSC ;
110+ CSC_Matrix * J_CSC = lnode -> J_CSC ;
97111
98112 /* evaluate child's jacobian and convert to CSC */
99113 x -> eval_jacobian (x );
100- csr_to_csc_fill_values (x -> jacobian , lin_node -> CSC_work , node -> iwork );
114+ csr_to_csc_fill_values (x -> jacobian , Jchild_CSC , node -> iwork );
101115
102- /* compute this node's jacobian */
103- csr_csc_matmul_fill_values (lin_node -> A , lin_node -> CSC_work , node -> jacobian );
116+ /* compute this node's jacobian: */
117+ block_left_multiply_fill_values (lnode -> A , Jchild_CSC , J_CSC );
118+ csc_to_csr_fill_values (J_CSC , node -> jacobian , lnode -> csc_to_csr_workspace );
104119}
105120
106121static void wsum_hess_init (expr * node )
@@ -115,15 +130,17 @@ static void wsum_hess_init(expr *node)
115130 memcpy (node -> wsum_hess -> i , x -> wsum_hess -> i , x -> wsum_hess -> nnz * sizeof (int ));
116131
117132 /* work for computing A^T w*/
118- int A_n = ((left_matmul_expr * ) node )-> A -> n ;
119- node -> dwork = (double * ) malloc (A_n * sizeof (double ));
133+ int n_blocks = ((left_matmul_expr * ) node )-> n_blocks ;
134+ int dim = ((left_matmul_expr * ) node )-> A -> n * n_blocks ;
135+ node -> dwork = (double * ) malloc (dim * sizeof (double ));
120136}
121137
122138static void eval_wsum_hess (expr * node , const double * w )
123139{
124140 /* compute A^T w*/
125- left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
126- csr_matvec_wo_offset (lin_node -> AT , w , node -> dwork );
141+ CSR_Matrix * AT = ((left_matmul_expr * ) node )-> AT ;
142+ int n_blocks = ((left_matmul_expr * ) node )-> n_blocks ;
143+ block_left_multiply_vec (AT , w , node -> dwork , n_blocks );
127144
128145 node -> left -> eval_wsum_hess (node -> left , node -> dwork );
129146 memcpy (node -> wsum_hess -> x , node -> left -> wsum_hess -> x ,
@@ -132,10 +149,10 @@ static void eval_wsum_hess(expr *node, const double *w)
132149
133150expr * new_left_matmul (expr * u , const CSR_Matrix * A )
134151{
135- /* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users to do
136- A @ u where u is (n, ) which in C is actually (1, n). In that case the result
137- of A @ u is (m, ), which is (1, m) according to broadcasting rules. We
138- therefore check if this is the case. */
152+ /* We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
153+ to do A @ u where u is (n, ) which in C is actually (1, n). In that case
154+ the result of A @ u is (m, ), which is (1, m) according to broadcasting
155+ rules. We therefore check if this is the case. */
139156 int d1 , d2 , n_blocks ;
140157 if (u -> d1 == A -> n )
141158 {
@@ -164,12 +181,17 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
164181 node -> left = u ;
165182 expr_retain (u );
166183
167- /* Initialize type-specific fields */
168- lin_node -> A = block_diag_repeat_csr (A , n_blocks );
169- int alloc = MAX (lin_node -> A -> n , node -> n_vars );
170- node -> iwork = (int * ) malloc (alloc * sizeof (int ));
184+ /* allocate workspace. iwork is used for transposing A (requiring size A->n)
185+ and for converting J_child csr to csc (requiring size node->n_vars).
186+ csc_to_csr_workspace is used for converting J_CSC to CSR (requiring size node->size)
187+ */
188+ node -> iwork = (int * ) malloc (MAX (A -> n , node -> n_vars ) * sizeof (int ));
189+ lin_node -> csc_to_csr_workspace = (int * ) malloc (node -> size * sizeof (int ));
190+ lin_node -> n_blocks = n_blocks ;
191+
192+ /* store A and AT */
193+ lin_node -> A = new_csr (A );
171194 lin_node -> AT = transpose (lin_node -> A , node -> iwork );
172- lin_node -> CSC_work = NULL ;
173195
174196 return node ;
175197}
0 commit comments