11#include "elementwise_full_dom.h"
22#include "subexpr.h"
3+ #include "utils/CSR_Matrix.h"
34#include <stdio.h>
45#include <stdlib.h>
56#include <string.h>
@@ -22,14 +23,15 @@ void jacobian_init_elementwise(expr *node)
2223 /* otherwise it will be a linear operator */
2324 else
2425 {
25- CSR_Matrix * J = ((linear_op_expr * ) child )-> A_csr ;
26- node -> jacobian = new_csr_matrix (J -> m , J -> n , J -> nnz );
27-
26+ /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
27+ child -> jacobian_init (child );
28+ CSR_Matrix * Jg = child -> jacobian ;
29+ node -> jacobian = new_csr_matrix (Jg -> m , Jg -> n , Jg -> nnz );
2830 node -> dwork = (double * ) malloc (node -> size * sizeof (double ));
2931
3032 /* copy sparsity pattern of child */
31- memcpy (node -> jacobian -> p , J -> p , sizeof (int ) * (J -> m + 1 ));
32- memcpy (node -> jacobian -> i , J -> i , sizeof (int ) * J -> nnz );
33+ memcpy (node -> jacobian -> p , Jg -> p , sizeof (int ) * (Jg -> m + 1 ));
34+ memcpy (node -> jacobian -> i , Jg -> i , sizeof (int ) * Jg -> nnz );
3335 }
3436}
3537
@@ -43,10 +45,11 @@ void eval_jacobian_elementwise(expr *node)
4345 }
4446 else
4547 {
46- /* Child will be a linear operator */
47- linear_op_expr * lin_child = (linear_op_expr * ) child ;
48+ /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
49+ child -> eval_jacobian (child );
50+ CSR_Matrix * Jg = child -> jacobian ;
4851 node -> local_jacobian (node , node -> dwork );
49- diag_csr_mult_fill_values (node -> dwork , lin_child -> A_csr , node -> jacobian );
52+ diag_csr_mult_fill_values (node -> dwork , Jg , node -> jacobian );
5053 }
5154}
5255
0 commit comments