11#include "bivariate.h"
2+ #include <assert.h>
23#include <math.h>
34#include <stdlib.h>
5+ #include <string.h>
46
57// --------------------------------------------------------------------
68// Implementation of relative entropy when both arguments are vectors.
@@ -28,6 +30,8 @@ static void jacobian_init_vectors_args(expr *node)
2830
2931 expr * x = node -> left ;
3032 expr * y = node -> right ;
33+ assert (x -> var_id != NOT_A_VARIABLE && y -> var_id != NOT_A_VARIABLE );
34+ assert (x -> var_id != y -> var_id );
3135
3236 /* if x has lower variable idx than y it should appear first */
3337 if (x -> var_id < y -> var_id )
@@ -76,6 +80,95 @@ static void eval_jacobian_vector_args(expr *node)
7680 }
7781}
7882
83+ static void wsum_hess_init_vector_args (expr * node )
84+ {
85+ node -> wsum_hess = new_csr_matrix (node -> n_vars , node -> n_vars , 4 * node -> d1 );
86+ expr * x = node -> left ;
87+ expr * y = node -> right ;
88+
89+ int i , var1_id , var2_id ;
90+
91+ if (x -> var_id < y -> var_id )
92+ {
93+ var1_id = x -> var_id ;
94+ var2_id = y -> var_id ;
95+ }
96+ else
97+ {
98+ var1_id = y -> var_id ;
99+ var2_id = x -> var_id ;
100+ }
101+
102+ /* var1 rows of Hessian */
103+ for (i = 0 ; i < node -> d1 ; i ++ )
104+ {
105+ node -> wsum_hess -> p [var1_id + i ] = 2 * i ;
106+ node -> wsum_hess -> i [2 * i ] = var1_id + i ;
107+ node -> wsum_hess -> i [2 * i + 1 ] = var2_id + i ;
108+ }
109+
110+ int nnz = 2 * node -> d1 ;
111+
112+ /* rows between var1 and var2 */
113+ for (i = var1_id + node -> d1 ; i < var2_id ; i ++ )
114+ {
115+ node -> wsum_hess -> p [i ] = nnz ;
116+ }
117+
118+ /* var2 rows of Hessian */
119+ for (i = 0 ; i < node -> d1 ; i ++ )
120+ {
121+ node -> wsum_hess -> p [var2_id + i ] = nnz + 2 * i ;
122+ }
123+ memcpy (node -> wsum_hess -> i + nnz , node -> wsum_hess -> i , nnz * sizeof (int ));
124+
125+ /* remaining rows */
126+ for (i = var2_id + node -> d1 ; i <= node -> n_vars ; i ++ )
127+ {
128+ node -> wsum_hess -> p [i ] = 4 * node -> d1 ;
129+ }
130+ }
131+
132+ static void eval_wsum_hess_vector_args (expr * node , const double * w )
133+ {
134+ double * x = node -> left -> value ;
135+ double * y = node -> right -> value ;
136+ double * hess = node -> wsum_hess -> x ;
137+
138+ if (node -> left -> var_id < node -> right -> var_id )
139+ {
140+ for (int i = 0 ; i < node -> d1 ; i ++ )
141+ {
142+ hess [2 * i ] = w [i ] / x [i ];
143+ hess [2 * i + 1 ] = - w [i ] / y [i ];
144+ }
145+
146+ hess += 2 * node -> d1 ;
147+
148+ for (int i = 0 ; i < node -> d1 ; i ++ )
149+ {
150+ hess [2 * i ] = - w [i ] / y [i ];
151+ hess [2 * i + 1 ] = w [i ] * x [i ] / (y [i ] * y [i ]);
152+ }
153+ }
154+ else
155+ {
156+ for (int i = 0 ; i < node -> d1 ; i ++ )
157+ {
158+ hess [2 * i ] = w [i ] * x [i ] / (y [i ] * y [i ]);
159+ hess [2 * i + 1 ] = - w [i ] / y [i ];
160+ }
161+
162+ hess += 2 * node -> d1 ;
163+
164+ for (int i = 0 ; i < node -> d1 ; i ++ )
165+ {
166+ hess [2 * i ] = - w [i ] / y [i ];
167+ hess [2 * i + 1 ] = w [i ] / x [i ];
168+ }
169+ }
170+ }
171+
79172expr * new_rel_entr_vector_args (expr * left , expr * right )
80173{
81174 expr * node = new_expr (left -> d1 , 1 , left -> n_vars );
@@ -86,6 +179,8 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
86179 node -> forward = forward_vector_args ;
87180 node -> jacobian_init = jacobian_init_vectors_args ;
88181 node -> eval_jacobian = eval_jacobian_vector_args ;
182+ node -> wsum_hess_init = wsum_hess_init_vector_args ;
183+ node -> eval_wsum_hess = eval_wsum_hess_vector_args ;
89184 // node->is_affine = is_affine_elementwise;
90185 // node->local_jacobian = local_jacobian;
91186 return node ;
0 commit comments