11#include "elementwise_full_dom.h"
22#include "subexpr.h"
3+ #include "utils/CSC_Matrix.h"
34#include "utils/CSR_Matrix.h"
5+ #include "utils/CSR_sum.h"
46#include <stdio.h>
57#include <stdlib.h>
68#include <string.h>
@@ -20,14 +22,14 @@ void jacobian_init_elementwise(expr *node)
2022 }
2123 node -> jacobian -> p [node -> size ] = node -> size ;
2224 }
23- /* otherwise it will be a linear operator */
2425 else
2526 {
2627 /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
2728 child -> jacobian_init (child );
2829 CSR_Matrix * Jg = child -> jacobian ;
2930 node -> jacobian = new_csr_matrix (Jg -> m , Jg -> n , Jg -> nnz );
3031 node -> dwork = (double * ) malloc (node -> size * sizeof (double ));
32+ node -> local_jac_diag = (double * ) malloc (node -> size * sizeof (double ));
3133
3234 /* copy sparsity pattern of child */
3335 memcpy (node -> jacobian -> p , Jg -> p , sizeof (int ) * (Jg -> m + 1 ));
@@ -48,7 +50,8 @@ void eval_jacobian_elementwise(expr *node)
4850 /* jacobian of h(x) = f(g(x)) is Jf @ Jg, and here Jf is diagonal */
4951 child -> eval_jacobian (child );
5052 CSR_Matrix * Jg = child -> jacobian ;
51- node -> local_jacobian (node , node -> dwork );
53+ node -> local_jacobian (node , node -> local_jac_diag );
54+ memcpy (node -> dwork , node -> local_jac_diag , node -> size * sizeof (double ));
5255 diag_csr_mult_fill_values (node -> dwork , Jg , node -> jacobian );
5356 }
5457}
@@ -59,7 +62,7 @@ void wsum_hess_init_elementwise(expr *node)
5962 int id = child -> var_id ;
6063 int i ;
6164
62- /* if the variable is a child*/
65+ /* if the variable is a child */
6366 if (id != NOT_A_VARIABLE )
6467 {
6568 node -> wsum_hess = new_csr_matrix (node -> n_vars , node -> n_vars , node -> size );
@@ -75,11 +78,38 @@ void wsum_hess_init_elementwise(expr *node)
7578 node -> wsum_hess -> p [i ] = node -> size ;
7679 }
7780 }
78- /* otherwise it will be a linear operator */
7981 else
8082 {
81- linear_op_expr * lin_child = (linear_op_expr * ) child ;
82- node -> wsum_hess = ATA_alloc (lin_child -> A_csc );
83+ /* Hessian of h(x) = w^T f(g(x)) is term1 + term2 where
84+ term1 = J_g^T @ D @ J_g with D = sum_i w_i Hf_i,
85+ term2 = sum_i (J_f^T w)_i Hg_i.
86+
87+ For elementwise functions, D is diagonal. */
88+ jacobian_csc_init (child );
89+ CSC_Matrix * Jg = child -> jacobian_csc ;
90+
91+ if (child -> is_affine (child ))
92+ {
93+ node -> wsum_hess = ATA_alloc (Jg );
94+ }
95+ else
96+ {
97+ /* term1: Jg^T @ D @ Jg */
98+ node -> hess_term1 = ATA_alloc (Jg );
99+
100+ /* term2: child's Hessian */
101+ child -> wsum_hess_init (child );
102+ CSR_Matrix * Hg = child -> wsum_hess ;
103+ node -> hess_term2 = new_csr_matrix (Hg -> m , Hg -> n , Hg -> nnz );
104+ memcpy (node -> hess_term2 -> p , Hg -> p , (Hg -> m + 1 ) * sizeof (int ));
105+ memcpy (node -> hess_term2 -> i , Hg -> i , Hg -> nnz * sizeof (int ));
106+
107+ /* wsum_hess = term1 + term2 */
108+ int max_nnz = node -> hess_term1 -> nnz + node -> hess_term2 -> nnz ;
109+ node -> wsum_hess = new_csr_matrix (node -> n_vars , node -> n_vars , max_nnz );
110+ sum_csr_matrices_fill_sparsity (node -> hess_term1 , node -> hess_term2 ,
111+ node -> wsum_hess );
112+ }
83113 }
84114}
85115
@@ -93,10 +123,43 @@ void eval_wsum_hess_elementwise(expr *node, const double *w)
93123 }
94124 else
95125 {
96- /* Child will be a linear operator */
97- linear_op_expr * lin_child = (linear_op_expr * ) child ;
98- node -> local_wsum_hess (node , node -> dwork , w );
99- ATDA_fill_values (lin_child -> A_csc , node -> dwork , node -> wsum_hess );
126+ if (child -> is_affine (child ))
127+ {
128+ if (!child -> jacobian_csc_filled )
129+ {
130+ csr_to_csc_fill_values (child -> jacobian , child -> jacobian_csc ,
131+ child -> csc_work );
132+ child -> jacobian_csc_filled = true;
133+ }
134+
135+ node -> local_wsum_hess (node , node -> dwork , w );
136+ ATDA_fill_values (child -> jacobian_csc , node -> dwork , node -> wsum_hess );
137+ }
138+ else
139+ {
140+ /* refresh CSC jacobian values */
141+ csr_to_csc_fill_values (child -> jacobian , child -> jacobian_csc ,
142+ child -> csc_work );
143+
144+ /* term1: Jg^T @ D @ Jg */
145+ node -> local_wsum_hess (node , node -> dwork , w );
146+ ATDA_fill_values (child -> jacobian_csc , node -> dwork , node -> hess_term1 );
147+
148+ /* term2: child Hessian with weight Jf^T w */
149+ memcpy (node -> dwork , node -> local_jac_diag , node -> size * sizeof (double ));
150+ for (int k = 0 ; k < node -> size ; k ++ )
151+ {
152+ node -> dwork [k ] *= w [k ];
153+ }
154+
155+ child -> eval_wsum_hess (child , node -> dwork );
156+ memcpy (node -> hess_term2 -> x , child -> wsum_hess -> x ,
157+ child -> wsum_hess -> nnz * sizeof (double ));
158+
159+ /* wsum_hess = term1 + term2 */
160+ sum_csr_matrices_fill_values (node -> hess_term1 , node -> hess_term2 ,
161+ node -> wsum_hess );
162+ }
100163 }
101164}
102165
0 commit comments