11#include "bivariate.h"
2+ #include <assert.h>
23#include <math.h>
34#include <stdlib.h>
45
@@ -28,6 +29,8 @@ static void jacobian_init_vectors_args(expr *node)
2829
2930 expr * x = node -> left ;
3031 expr * y = node -> right ;
32+ assert (x -> var_id != NOT_A_VARIABLE && y -> var_id != NOT_A_VARIABLE );
33+ assert (x -> var_id != y -> var_id );
3134
3235 /* if x has lower variable idx than y it should appear first */
3336 if (x -> var_id < y -> var_id )
@@ -76,6 +79,97 @@ static void eval_jacobian_vector_args(expr *node)
7679 }
7780}
7881
82+ static void wsum_hess_init_vector_args (expr * node )
83+ {
84+ node -> wsum_hess = new_csr_matrix (node -> n_vars , node -> n_vars , 4 * node -> d1 );
85+ expr * x = node -> left ;
86+ expr * y = node -> right ;
87+
88+ int i , var1_id , var2_id ;
89+
90+ if (x -> var_id < y -> var_id )
91+ {
92+ var1_id = x -> var_id ;
93+ var2_id = y -> var_id ;
94+ }
95+ else
96+ {
97+ var1_id = y -> var_id ;
98+ var2_id = x -> var_id ;
99+ }
100+
101+ /* var1 rows of Hessian */
102+ for (i = 0 ; i < node -> d1 ; i ++ )
103+ {
104+ node -> wsum_hess -> p [var1_id + i ] = 2 * i ;
105+ node -> wsum_hess -> i [2 * i ] = var1_id + i ;
106+ node -> wsum_hess -> i [2 * i + 1 ] = var2_id + i ;
107+ }
108+
109+ int nnz = 2 * node -> d1 ;
110+
111+ /* rows between var1 and var2 */
112+ for (i = var1_id + node -> d1 ; i < var2_id ; i ++ )
113+ {
114+ node -> wsum_hess -> p [i ] = nnz ;
115+ }
116+
117+ /* var2 rows of Hessian */
118+ for (i = 0 ; i < node -> d1 ; i ++ )
119+ {
120+ node -> wsum_hess -> p [var2_id + i ] = nnz + 2 * i ;
121+ }
122+ memcpy (node -> wsum_hess -> i + nnz , node -> wsum_hess -> i , nnz * sizeof (int ));
123+
124+ /* remaining rows */
125+ for (i = var2_id + node -> d1 ; i <= node -> n_vars ; i ++ )
126+ {
127+ node -> wsum_hess -> p [i ] = 4 * node -> d1 ;
128+ }
129+ }
130+
131+ static void eval_wsum_hess_vector_args (expr * node , const double * w )
132+ {
133+
134+ int i , x_id , y_id ;
135+ double * x = node -> left -> value ;
136+ double * y = node -> right -> value ;
137+ double * hess = node -> wsum_hess -> x ;
138+
139+ if (node -> left -> var_id < node -> right -> var_id )
140+ {
141+ for (i = 0 ; i < node -> d1 ; i ++ )
142+ {
143+ hess [2 * i ] = w [i ] / x [i ];
144+ hess [2 * i + 1 ] = - w [i ] / y [i ];
145+ }
146+
147+ hess += 2 * node -> d1 ;
148+
149+ for (i = 0 ; i < node -> d1 ; i ++ )
150+ {
151+ hess [2 * i ] = - w [i ] / y [i ];
152+ hess [2 * i + 1 ] = w [i ] * x [i ] / (y [i ] * y [i ]);
153+ }
154+ }
155+ else
156+ {
157+ for (i = 0 ; i < node -> d1 ; i ++ )
158+ {
159+ hess [2 * i ] = w [i ] * x [i ] / (y [i ] * y [i ]);
160+ hess [2 * i + 1 ] = - w [i ] / y [i ];
161+ }
162+
163+ hess += 2 * node -> d1 ;
164+
165+ for (i = 0 ; i < node -> d1 ; i ++ )
166+ {
167+ hess [2 * i ] = - w [i ] / y [i ];
168+ hess [2 * i + 1 ] = w [i ] / x [i ];
169+ }
170+ }
171+ }
172+
79173expr * new_rel_entr_vector_args (expr * left , expr * right )
80174{
81175 expr * node = new_expr (left -> d1 , 1 , left -> n_vars );
@@ -86,6 +180,8 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
86180 node -> forward = forward_vector_args ;
87181 node -> jacobian_init = jacobian_init_vectors_args ;
88182 node -> eval_jacobian = eval_jacobian_vector_args ;
183+ node -> wsum_hess_init = wsum_hess_init_vector_args ;
184+ node -> eval_wsum_hess = eval_wsum_hess_vector_args ;
89185 // node->is_affine = is_affine_elementwise;
90186 // node->local_jacobian = local_jacobian;
91187 return node ;
0 commit comments