Skip to content

Commit af1f9a3

Browse files
committed
relative entropy hessian
1 parent 3a23c26 commit af1f9a3

11 files changed

Lines changed: 192 additions & 8 deletions

File tree

src/bivariate/rel_entr.c

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "bivariate.h"
2+
#include <assert.h>
23
#include <math.h>
34
#include <stdlib.h>
45

@@ -28,6 +29,8 @@ static void jacobian_init_vectors_args(expr *node)
2829

2930
expr *x = node->left;
3031
expr *y = node->right;
32+
assert(x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE);
33+
assert(x->var_id != y->var_id);
3134

3235
/* if x has lower variable idx than y it should appear first */
3336
if (x->var_id < y->var_id)
@@ -76,6 +79,97 @@ static void eval_jacobian_vector_args(expr *node)
7679
}
7780
}
7881

82+
static void wsum_hess_init_vector_args(expr *node)
83+
{
84+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1);
85+
expr *x = node->left;
86+
expr *y = node->right;
87+
88+
int i, var1_id, var2_id;
89+
90+
if (x->var_id < y->var_id)
91+
{
92+
var1_id = x->var_id;
93+
var2_id = y->var_id;
94+
}
95+
else
96+
{
97+
var1_id = y->var_id;
98+
var2_id = x->var_id;
99+
}
100+
101+
/* var1 rows of Hessian */
102+
for (i = 0; i < node->d1; i++)
103+
{
104+
node->wsum_hess->p[var1_id + i] = 2 * i;
105+
node->wsum_hess->i[2 * i] = var1_id + i;
106+
node->wsum_hess->i[2 * i + 1] = var2_id + i;
107+
}
108+
109+
int nnz = 2 * node->d1;
110+
111+
/* rows between var1 and var2 */
112+
for (i = var1_id + node->d1; i < var2_id; i++)
113+
{
114+
node->wsum_hess->p[i] = nnz;
115+
}
116+
117+
/* var2 rows of Hessian */
118+
for (i = 0; i < node->d1; i++)
119+
{
120+
node->wsum_hess->p[var2_id + i] = nnz + 2 * i;
121+
}
122+
memcpy(node->wsum_hess->i + nnz, node->wsum_hess->i, nnz * sizeof(int));
123+
124+
/* remaining rows */
125+
for (i = var2_id + node->d1; i <= node->n_vars; i++)
126+
{
127+
node->wsum_hess->p[i] = 4 * node->d1;
128+
}
129+
}
130+
131+
static void eval_wsum_hess_vector_args(expr *node, const double *w)
132+
{
133+
134+
int i, x_id, y_id;
135+
double *x = node->left->value;
136+
double *y = node->right->value;
137+
double *hess = node->wsum_hess->x;
138+
139+
if (node->left->var_id < node->right->var_id)
140+
{
141+
for (i = 0; i < node->d1; i++)
142+
{
143+
hess[2 * i] = w[i] / x[i];
144+
hess[2 * i + 1] = -w[i] / y[i];
145+
}
146+
147+
hess += 2 * node->d1;
148+
149+
for (i = 0; i < node->d1; i++)
150+
{
151+
hess[2 * i] = -w[i] / y[i];
152+
hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]);
153+
}
154+
}
155+
else
156+
{
157+
for (i = 0; i < node->d1; i++)
158+
{
159+
hess[2 * i] = w[i] * x[i] / (y[i] * y[i]);
160+
hess[2 * i + 1] = -w[i] / y[i];
161+
}
162+
163+
hess += 2 * node->d1;
164+
165+
for (i = 0; i < node->d1; i++)
166+
{
167+
hess[2 * i] = -w[i] / y[i];
168+
hess[2 * i + 1] = w[i] / x[i];
169+
}
170+
}
171+
}
172+
79173
expr *new_rel_entr_vector_args(expr *left, expr *right)
80174
{
81175
expr *node = new_expr(left->d1, 1, left->n_vars);
@@ -86,6 +180,8 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
86180
node->forward = forward_vector_args;
87181
node->jacobian_init = jacobian_init_vectors_args;
88182
node->eval_jacobian = eval_jacobian_vector_args;
183+
node->wsum_hess_init = wsum_hess_init_vector_args;
184+
node->eval_wsum_hess = eval_wsum_hess_vector_args;
89185
// node->is_affine = is_affine_elementwise;
90186
// node->local_jacobian = local_jacobian;
91187
return node;

tests/all_tests.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@
2121
#include "jacobian_tests/test_sum.h"
2222
#include "utils/test_csc_matrix.h"
2323
#include "utils/test_csr_matrix.h"
24-
#include "wsum_hess/test_entr.h"
25-
#include "wsum_hess/test_exp.h"
26-
#include "wsum_hess/test_hyperbolic.h"
27-
#include "wsum_hess/test_log.h"
28-
#include "wsum_hess/test_logistic.h"
29-
#include "wsum_hess/test_power.h"
24+
#include "wsum_hess/elementwise/test_entr.h"
25+
#include "wsum_hess/elementwise/test_exp.h"
26+
#include "wsum_hess/elementwise/test_hyperbolic.h"
27+
#include "wsum_hess/elementwise/test_log.h"
28+
#include "wsum_hess/elementwise/test_logistic.h"
29+
#include "wsum_hess/elementwise/test_power.h"
30+
#include "wsum_hess/elementwise/test_trig.h"
31+
#include "wsum_hess/elementwise/test_xexp.h"
32+
#include "wsum_hess/test_rel_entr.h"
3033
#include "wsum_hess/test_sum.h"
31-
#include "wsum_hess/test_trig.h"
32-
#include "wsum_hess/test_xexp.h"
3334

3435
int main(void)
3536
{
@@ -95,6 +96,8 @@ int main(void)
9596
mu_run_test(test_wsum_hess_sum_log_linear, tests_run);
9697
mu_run_test(test_wsum_hess_sum_log_axis0, tests_run);
9798
mu_run_test(test_wsum_hess_sum_log_axis1, tests_run);
99+
mu_run_test(test_wsum_hess_rel_entr_1, tests_run);
100+
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
98101

99102
printf("\n--- Utility Tests ---\n");
100103
mu_run_test(test_diag_csr_mult, tests_run);
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)