1717 */
1818#include "bivariate.h"
1919#include "subexpr.h"
20- #include "utils/Timer.h"
21- #include "utils/linalg_sparse_matmuls.h"
20+ #include "utils/matrix.h"
2221#include <assert.h>
2322#include <stdio.h>
2423#include <stdlib.h>
4544
4645 Working in terms of A_kron unifies the implementation of f(x) being
4746 vector-valued or matrix-valued.
48-
49- I (dance858) think we can get additional big speedups when A is dense by
50- introducing a dense matrix class.
5147*/
5248
5349#include "utils/utils.h"
@@ -60,9 +56,9 @@ static void forward(expr *node, const double *u)
6056 node -> left -> forward (node -> left , u );
6157
6258 /* y = A_kron @ vec(f(x)) */
63- CSR_Matrix * A = ((left_matmul_expr * ) node )-> A ;
59+ Matrix * A = ((left_matmul_expr * ) node )-> A ;
6460 int n_blocks = ((left_matmul_expr * ) node )-> n_blocks ;
65- block_left_multiply_vec (A , x -> value , node -> value , n_blocks );
61+ A -> block_left_mult_vec (A , x -> value , node -> value , n_blocks );
6662}
6763
6864static bool is_affine (const expr * node )
@@ -72,33 +68,32 @@ static bool is_affine(const expr *node)
7268
7369static void free_type_data (expr * node )
7470{
75- left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
76- free_csr_matrix ( lin_node -> A );
77- free_csr_matrix ( lin_node -> AT );
78- free_csc_matrix (lin_node -> Jchild_CSC );
79- free_csc_matrix (lin_node -> J_CSC );
80- free (lin_node -> csc_to_csr_workspace );
81- lin_node -> A = NULL ;
82- lin_node -> AT = NULL ;
83- lin_node -> Jchild_CSC = NULL ;
84- lin_node -> J_CSC = NULL ;
85- lin_node -> csc_to_csr_workspace = NULL ;
71+ left_matmul_expr * lnode = (left_matmul_expr * ) node ;
72+ free_matrix ( lnode -> A );
73+ free_matrix ( lnode -> AT );
74+ free_csc_matrix (lnode -> Jchild_CSC );
75+ free_csc_matrix (lnode -> J_CSC );
76+ free (lnode -> csc_to_csr_work );
77+ lnode -> A = NULL ;
78+ lnode -> AT = NULL ;
79+ lnode -> Jchild_CSC = NULL ;
80+ lnode -> J_CSC = NULL ;
81+ lnode -> csc_to_csr_work = NULL ;
8682}
8783
8884static void jacobian_init (expr * node )
8985{
9086 expr * x = node -> left ;
91- left_matmul_expr * lin_node = (left_matmul_expr * ) node ;
87+ left_matmul_expr * lnode = (left_matmul_expr * ) node ;
9288
9389 /* initialize child's jacobian and precompute sparsity of its CSC */
9490 x -> jacobian_init (x );
95- lin_node -> Jchild_CSC = csr_to_csc_fill_sparsity (x -> jacobian , node -> iwork );
91+ lnode -> Jchild_CSC = csr_to_csc_fill_sparsity (x -> jacobian , node -> iwork );
9692
9793 /* precompute sparsity of this node's jacobian in CSC and CSR */
98- lin_node -> J_CSC = block_left_multiply_fill_sparsity (
99- lin_node -> A , lin_node -> Jchild_CSC , lin_node -> n_blocks );
100- node -> jacobian =
101- csc_to_csr_fill_sparsity (lin_node -> J_CSC , lin_node -> csc_to_csr_workspace );
94+ lnode -> J_CSC = lnode -> A -> block_left_mult_sparsity (lnode -> A , lnode -> Jchild_CSC ,
95+ lnode -> n_blocks );
96+ node -> jacobian = csc_to_csr_fill_sparsity (lnode -> J_CSC , lnode -> csc_to_csr_work );
10297}
10398
10499static void eval_jacobian (expr * node )
@@ -114,8 +109,8 @@ static void eval_jacobian(expr *node)
114109 csr_to_csc_fill_values (x -> jacobian , Jchild_CSC , node -> iwork );
115110
116111 /* compute this node's jacobian: */
117- block_left_multiply_fill_values (lnode -> A , Jchild_CSC , J_CSC );
118- csc_to_csr_fill_values (J_CSC , node -> jacobian , lnode -> csc_to_csr_workspace );
112+ lnode -> A -> block_left_mult_values (lnode -> A , Jchild_CSC , J_CSC );
113+ csc_to_csr_fill_values (J_CSC , node -> jacobian , lnode -> csc_to_csr_work );
119114}
120115
121116static void wsum_hess_init (expr * node )
@@ -131,16 +126,16 @@ static void wsum_hess_init(expr *node)
131126
132127 /* work for computing A^T w*/
133128 int n_blocks = ((left_matmul_expr * ) node )-> n_blocks ;
134- int dim = ((left_matmul_expr * ) node )-> A -> n * n_blocks ;
129+ int dim = ((left_matmul_expr * ) node )-> AT -> m * n_blocks ;
135130 node -> dwork = (double * ) malloc (dim * sizeof (double ));
136131}
137132
138133static void eval_wsum_hess (expr * node , const double * w )
139134{
140135 /* compute A^T w*/
141- CSR_Matrix * AT = ((left_matmul_expr * ) node )-> AT ;
136+ Matrix * AT = ((left_matmul_expr * ) node )-> AT ;
142137 int n_blocks = ((left_matmul_expr * ) node )-> n_blocks ;
143- block_left_multiply_vec (AT , w , node -> dwork , n_blocks );
138+ AT -> block_left_mult_vec (AT , w , node -> dwork , n_blocks );
144139
145140 node -> left -> eval_wsum_hess (node -> left , node -> dwork );
146141 memcpy (node -> wsum_hess -> x , node -> left -> wsum_hess -> x ,
@@ -173,25 +168,64 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
173168 }
174169
175170 /* Allocate the type-specific struct */
176- left_matmul_expr * lin_node =
171+ left_matmul_expr * lnode =
177172 (left_matmul_expr * ) calloc (1 , sizeof (left_matmul_expr ));
178- expr * node = & lin_node -> base ;
173+ expr * node = & lnode -> base ;
179174 init_expr (node , d1 , d2 , u -> n_vars , forward , jacobian_init , eval_jacobian ,
180175 is_affine , wsum_hess_init , eval_wsum_hess , free_type_data );
181176 node -> left = u ;
182177 expr_retain (u );
183178
184- /* allocate workspace. iwork is used for transposing A (requiring size A->n)
185- and for converting J_child csr to csc (requring size node->n_vars ).
186- csc_to_csr_workspace is used for converting J_CSC to CSR (requring node->size)
187- */
179+ /* allocate workspace. iwork is used for converting J_child csr to csc
180+ (requiring size node->n_vars) and for transposing A (requiring size A->n ).
181+ csc_to_csr_work is used for converting J_CSC to CSR (requiring
182+ node->size) */
188183 node -> iwork = (int * ) malloc (MAX (A -> n , node -> n_vars ) * sizeof (int ));
189- lin_node -> csc_to_csr_workspace = (int * ) malloc (node -> size * sizeof (int ));
190- lin_node -> n_blocks = n_blocks ;
184+ lnode -> csc_to_csr_work = (int * ) malloc (node -> size * sizeof (int ));
185+ lnode -> n_blocks = n_blocks ;
191186
192187 /* store A and AT */
193- lin_node -> A = new_csr (A );
194- lin_node -> AT = transpose (lin_node -> A , node -> iwork );
188+ lnode -> A = new_sparse_matrix (A );
189+ lnode -> AT = sparse_matrix_trans ((const Sparse_Matrix * ) lnode -> A , node -> iwork );
190+
191+ return node ;
192+ }
193+
194+ expr * new_left_matmul_dense (expr * u , int m , int n , const double * data )
195+ {
196+ int d1 , d2 , n_blocks ;
197+ if (u -> d1 == n )
198+ {
199+ d1 = m ;
200+ d2 = u -> d2 ;
201+ n_blocks = u -> d2 ;
202+ }
203+ else if (u -> d2 == n && u -> d1 == 1 )
204+ {
205+ d1 = 1 ;
206+ d2 = m ;
207+ n_blocks = 1 ;
208+ }
209+ else
210+ {
211+ fprintf (stderr , "Error in new_left_matmul_dense: dimension mismatch\n" );
212+ exit (1 );
213+ }
214+
215+ left_matmul_expr * lnode =
216+ (left_matmul_expr * ) calloc (1 , sizeof (left_matmul_expr ));
217+ expr * node = & lnode -> base ;
218+ init_expr (node , d1 , d2 , u -> n_vars , forward , jacobian_init , eval_jacobian ,
219+ is_affine , wsum_hess_init , eval_wsum_hess , free_type_data );
220+ node -> left = u ;
221+ expr_retain (u );
222+
223+ node -> iwork = (int * ) malloc (MAX (n , node -> n_vars ) * sizeof (int ));
224+ lnode -> csc_to_csr_work = (int * ) malloc (node -> size * sizeof (int ));
225+ lnode -> n_blocks = n_blocks ;
226+
227+ lnode -> A = new_dense_matrix (m , n , data );
228+ lnode -> AT = dense_matrix_trans ((const Dense_Matrix * ) lnode -> A );
195229
196230 return node ;
197231}
0 commit comments