11#include "bivariate.h"
2+ #include "subexpr.h"
3+ #include "utils/CSC_Matrix.h"
24#include <assert.h>
35#include <math.h>
46#include <stdlib.h>
@@ -58,14 +60,15 @@ static void jacobian_init(expr *node)
5860 }
5961 }
6062 }
61- else /* left node is not a variable */
63+ else /* left node is not a variable (guaranteed to be a linear operator) */
6264 {
65+ linear_op_expr * lin_x = (linear_op_expr * ) x ;
6366 node -> dwork = (double * ) malloc (x -> d1 * sizeof (double ));
6467
6568 /* compute required allocation and allocate jacobian */
6669 bool * col_nz = (bool * ) calloc (
6770 node -> n_vars , sizeof (bool )); /* TODO: could use iwork here instead*/
68- int nonzero_cols = count_nonzero_cols (x -> jacobian , col_nz );
71+ int nonzero_cols = count_nonzero_cols (lin_x -> base . jacobian , col_nz );
6972 node -> jacobian = new_csr_matrix (1 , node -> n_vars , nonzero_cols + 1 );
7073
7174 /* precompute column indices */
@@ -88,11 +91,8 @@ static void jacobian_init(expr *node)
8891 node -> jacobian -> p [0 ] = 0 ;
8992 node -> jacobian -> p [1 ] = node -> jacobian -> nnz ;
9093
91- /* store A^T of child's A to simplify chain rule computation */
92- node -> iwork = (int * ) malloc (x -> jacobian -> n * sizeof (int ));
93- node -> CSR_work = transpose (x -> jacobian , node -> iwork );
94-
9594 /* find position where y should be inserted */
95+ node -> iwork = (int * ) malloc (sizeof (int ));
9696 for (int j = 0 ; j < node -> jacobian -> nnz ; j ++ )
9797 {
9898 if (node -> jacobian -> i [j ] == y -> var_id )
@@ -132,14 +132,16 @@ static void eval_jacobian(expr *node)
132132 }
133133 else /* x is not a variable */
134134 {
135+ CSC_Matrix * A_csc = ((linear_op_expr * ) x )-> A_csc ;
136+
135137 /* local jacobian */
136138 for (int j = 0 ; j < x -> d1 ; j ++ )
137139 {
138140 node -> dwork [j ] = (2.0 * x -> value [j ]) / y -> value [0 ];
139141 }
140142
141- /* chain rule (no derivative wrt y) */
142- csr_matvec_fill_values ( node -> CSR_work , node -> dwork , node -> jacobian );
143+ /* chain rule (no derivative wrt y) using CSC format */
144+ csc_matvec_fill_values ( A_csc , node -> dwork , node -> jacobian );
143145
144146 /* insert derivative wrt y at right place (for correctness this assumes
145147 that y does not appear in the denominator, but this will always be
0 commit comments