Skip to content

Commit 1f25f23

Browse files
authored
Merge pull request #4 from dance858/rel_entr_hess
[Ready for review] Rel entr and hstack hessian
2 parents a2ebd92 + c16d663 commit 1f25f23

17 files changed

Lines changed: 471 additions & 15 deletions

File tree

include/subexpr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ typedef struct hstack_expr
4646
expr base;
4747
expr **args;
4848
int n_args;
49+
CSR_Matrix *CSR_work; /* for summing Hessians of children */
4950
} hstack_expr;
5051

5152
#endif /* SUBEXPR_H */

include/utils/CSR_Matrix.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C);
5555

5656
/* Compute C = A + B where A, B, C are CSR matrices
5757
* A and B must have same dimensions
58-
* C must be pre-allocated with sufficient nnz capacity */
58+
* C must be pre-allocated with sufficient nnz capacity.
59+
* C must be different from A and B */
5960
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
6061

6162
/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */

src/affine/hstack.c

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void jacobian_init(expr *node)
3737

3838
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);
3939

40-
/* precompute sparsity pattern of this jacobian's node */
40+
/* precompute sparsity pattern of this node's jacobian */
4141
int row_offset = 0;
4242
CSR_Matrix *A = node->jacobian;
4343
A->nnz = 0;
@@ -80,6 +80,40 @@ static void eval_jacobian(expr *node)
8080
}
8181
}
8282

83+
/* Set up weighted-sum-Hessian storage for an hstack node.
 *
 * Calls wsum_hess_init on every child in hnode->args, adds up the nnz of
 * each child's wsum_hess, and then allocates two matrices with
 * new_csr_matrix(n_vars, n_vars, total): node->wsum_hess (the result) and
 * hnode->CSR_work (scratch used by wsum_hess_eval). */
static void wsum_hess_init(expr *node)
84+
{
85+
/* initialize children's hessians */
86+
hstack_expr *hnode = (hstack_expr *) node;
87+
int nnz = 0;
88+
for (int i = 0; i < hnode->n_args; i++)
89+
{
90+
/* the child's nnz is read right after its own init, so each child's
   wsum_hess_init must leave wsum_hess->nnz populated */
hnode->args[i]->wsum_hess_init(hnode->args[i]);
91+
nnz += hnode->args[i]->wsum_hess->nnz;
92+
}
93+
94+
/* worst-case scenario the nnz of node->wsum_hess is the sum of children's
95+
nnz */
96+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz);
97+
/* same shape and capacity as the result; released in free_type_data */
hnode->CSR_work = new_csr_matrix(node->n_vars, node->n_vars, nnz);
98+
}
99+
100+
/* Evaluate the weighted-sum Hessian of an hstack node into node->wsum_hess.
 *
 * w is the weight vector for the stacked output; child i receives the slice
 * starting at w + row_offset, where row_offset is the running sum of the
 * sizes of the preceding children. Each child's Hessian is added to the
 * running total H.
 *
 * sum_csr_matrices asserts that its output differs from both inputs, so the
 * running total is first copied into hnode->CSR_work and the sum is written
 * back into H.
 *
 * NOTE(review): only H->nnz is reset before the loop. Whether that is enough
 * to make the first copy_csr_matrix produce an empty matrix depends on
 * copy_csr_matrix / sum_csr_matrices (the latter walks A->p row ranges) --
 * confirm that a second call does not pick up stale row pointers from the
 * previous evaluation. */
static void wsum_hess_eval(expr *node, const double *w)
101+
{
102+
hstack_expr *hnode = (hstack_expr *) node;
103+
CSR_Matrix *H = node->wsum_hess;
104+
int row_offset = 0;
105+
H->nnz = 0;
106+
107+
for (int i = 0; i < hnode->n_args; i++)
108+
{
109+
expr *child = hnode->args[i];
110+
child->eval_wsum_hess(child, w + row_offset);
111+
/* presumably copy_csr_matrix(src, dst): snapshot H into the scratch
   matrix -- TODO confirm argument order */
copy_csr_matrix(H, hnode->CSR_work);
112+
/* H = CSR_work + child's Hessian */
sum_csr_matrices(hnode->CSR_work, child->wsum_hess, H);
113+
row_offset += child->size;
114+
}
115+
}
116+
83117
static bool is_affine(const expr *node)
84118
{
85119
const hstack_expr *hnode = (const hstack_expr *) node;
@@ -100,7 +134,11 @@ static void free_type_data(expr *node)
100134
for (int i = 0; i < hnode->n_args; i++)
101135
{
102136
free_expr(hnode->args[i]);
137+
hnode->args[i] = NULL;
103138
}
139+
140+
free_csr_matrix(hnode->CSR_work);
141+
hnode->CSR_work = NULL;
104142
}
105143

106144
expr *new_hstack(expr **args, int n_args, int n_vars)
@@ -121,6 +159,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
121159
/* Set type-specific fields */
122160
hnode->args = args;
123161
hnode->n_args = n_args;
162+
node->wsum_hess_init = wsum_hess_init;
163+
node->eval_wsum_hess = wsum_hess_eval;
124164

125165
for (int i = 0; i < n_args; i++)
126166
{

src/bivariate/multiply.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ static void jacobian_init(expr *node)
2727
{
2828
/* if a child is a variable we initialize its jacobian for a
2929
short chain rule implementation */
30-
if (node->left->var_id != -1)
30+
if (node->left->var_id != NOT_A_VARIABLE)
3131
{
3232
node->left->jacobian_init(node->left);
3333
}

src/bivariate/rel_entr.c

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "bivariate.h"
2+
#include <assert.h>
23
#include <math.h>
34
#include <stdlib.h>
5+
#include <string.h>
46

57
// --------------------------------------------------------------------
68
// Implementation of relative entropy when both arguments are vectors.
@@ -28,6 +30,8 @@ static void jacobian_init_vectors_args(expr *node)
2830

2931
expr *x = node->left;
3032
expr *y = node->right;
33+
assert(x->var_id != NOT_A_VARIABLE && y->var_id != NOT_A_VARIABLE);
34+
assert(x->var_id != y->var_id);
3135

3236
/* if x has lower variable idx than y it should appear first */
3337
if (x->var_id < y->var_id)
@@ -76,6 +80,95 @@ static void eval_jacobian_vector_args(expr *node)
7680
}
7781
}
7882

83+
/* Allocate node->wsum_hess (n_vars x n_vars, capacity 4 * d1) and fill in
 * its CSR row pointers (p) and column indices (i) for the vector-argument
 * relative entropy node.
 *
 * The two children are ordered by var_id (var1_id < var2_id). Each of the
 * d1 rows starting at var1_id and each of the d1 rows starting at var2_id
 * gets exactly two column entries: var1_id + k and var2_id + k.
 *
 * NOTE(review): row pointers p[0 .. var1_id - 1] are not written here;
 * presumably new_csr_matrix zero-initialises p -- confirm.
 * NOTE(review): wsum_hess->nnz is not assigned here; presumably
 * new_csr_matrix records the requested nnz -- confirm.
 * NOTE(review): unlike jacobian_init_vectors_args, this function does not
 * assert that both children are variables with distinct var_ids. */
static void wsum_hess_init_vector_args(expr *node)
84+
{
85+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, 4 * node->d1);
86+
expr *x = node->left;
87+
expr *y = node->right;
88+
89+
int i, var1_id, var2_id;
90+
91+
/* var1_id is the smaller of the two var_ids, var2_id the larger */
if (x->var_id < y->var_id)
92+
{
93+
var1_id = x->var_id;
94+
var2_id = y->var_id;
95+
}
96+
else
97+
{
98+
var1_id = y->var_id;
99+
var2_id = x->var_id;
100+
}
101+
102+
/* var1 rows of Hessian */
103+
for (i = 0; i < node->d1; i++)
104+
{
105+
node->wsum_hess->p[var1_id + i] = 2 * i;
106+
node->wsum_hess->i[2 * i] = var1_id + i;
107+
node->wsum_hess->i[2 * i + 1] = var2_id + i;
108+
}
109+
110+
/* entries stored so far: two per var1 row */
int nnz = 2 * node->d1;
111+
112+
/* rows between var1 and var2 */
113+
for (i = var1_id + node->d1; i < var2_id; i++)
114+
{
115+
node->wsum_hess->p[i] = nnz;
116+
}
117+
118+
/* var2 rows of Hessian */
119+
for (i = 0; i < node->d1; i++)
120+
{
121+
node->wsum_hess->p[var2_id + i] = nnz + 2 * i;
122+
}
123+
/* the var2 rows use the same column pairs as the var1 rows, so the first
   nnz column indices are duplicated into the second half; source and
   destination ranges do not overlap */
memcpy(node->wsum_hess->i + nnz, node->wsum_hess->i, nnz * sizeof(int));
124+
125+
/* remaining rows */
126+
/* inclusive bound: the final entry p[n_vars] is written as well */
for (i = var2_id + node->d1; i <= node->n_vars; i++)
127+
{
128+
node->wsum_hess->p[i] = 4 * node->d1;
129+
}
130+
}
131+
132+
/* Fill node->wsum_hess->x with the weighted second-derivative values for
 * the vector-argument relative entropy node.
 *
 * x and y are the value arrays of the left and right children, and w holds
 * one weight per output element (d1 of them). Three distinct values are
 * written per element k:
 *     w[k] / x[k]                    (x-x entry)
 *    -w[k] / y[k]                    (x-y and y-x entries)
 *     w[k] * x[k] / (y[k] * y[k])    (y-y entry)
 * These match the second partials of x * log(x / y); the forward function is
 * not visible here -- TODO confirm.
 *
 * The values are stored two per row: first the d1 rows of the child with the
 * smaller var_id, then the d1 rows of the other child. Within each pair the
 * entry for the smaller-var_id column comes first, matching the column order
 * set up in wsum_hess_init_vector_args.
 *
 * NOTE(review): there is no guard against x[k] == 0 or y[k] == 0. */
static void eval_wsum_hess_vector_args(expr *node, const double *w)
133+
{
134+
double *x = node->left->value;
135+
double *y = node->right->value;
136+
double *hess = node->wsum_hess->x;
137+
138+
/* left child (x) has the smaller var_id: x rows first, then y rows */
if (node->left->var_id < node->right->var_id)
139+
{
140+
for (int i = 0; i < node->d1; i++)
141+
{
142+
hess[2 * i] = w[i] / x[i];
143+
hess[2 * i + 1] = -w[i] / y[i];
144+
}
145+
146+
/* advance past the first block of 2 * d1 values */
hess += 2 * node->d1;
147+
148+
for (int i = 0; i < node->d1; i++)
149+
{
150+
hess[2 * i] = -w[i] / y[i];
151+
hess[2 * i + 1] = w[i] * x[i] / (y[i] * y[i]);
152+
}
153+
}
154+
/* right child (y) comes first: y rows first, then x rows */
else
155+
{
156+
for (int i = 0; i < node->d1; i++)
157+
{
158+
hess[2 * i] = w[i] * x[i] / (y[i] * y[i]);
159+
hess[2 * i + 1] = -w[i] / y[i];
160+
}
161+
162+
/* advance past the first block of 2 * d1 values */
hess += 2 * node->d1;
163+
164+
for (int i = 0; i < node->d1; i++)
165+
{
166+
hess[2 * i] = -w[i] / y[i];
167+
hess[2 * i + 1] = w[i] / x[i];
168+
}
169+
}
170+
}
171+
79172
expr *new_rel_entr_vector_args(expr *left, expr *right)
80173
{
81174
expr *node = new_expr(left->d1, 1, left->n_vars);
@@ -86,6 +179,8 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
86179
node->forward = forward_vector_args;
87180
node->jacobian_init = jacobian_init_vectors_args;
88181
node->eval_jacobian = eval_jacobian_vector_args;
182+
node->wsum_hess_init = wsum_hess_init_vector_args;
183+
node->eval_wsum_hess = eval_wsum_hess_vector_args;
89184
// node->is_affine = is_affine_elementwise;
90185
// node->local_jacobian = local_jacobian;
91186
return node;

src/utils/CSR_Matrix.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C)
9999

100100
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C)
101101
{
102+
/* A and B must be different from C */
103+
assert(A != C && B != C);
104+
102105
C->nnz = 0;
103106

104107
for (int row = 0; row < A->m; row++)
@@ -138,17 +141,17 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C)
138141
if (a_ptr < a_end)
139142
{
140143
int a_remaining = a_end - a_ptr;
141-
memmove(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int));
142-
memmove(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double));
144+
memcpy(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int));
145+
memcpy(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double));
143146
C->nnz += a_remaining;
144147
}
145148

146149
/* Copy remaining elements from B */
147150
if (b_ptr < b_end)
148151
{
149152
int b_remaining = b_end - b_ptr;
150-
memmove(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int));
151-
memmove(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double));
153+
memcpy(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int));
154+
memcpy(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double));
152155
C->nnz += b_remaining;
153156
}
154157
}

tests/all_tests.c

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@
2121
#include "jacobian_tests/test_sum.h"
2222
#include "utils/test_csc_matrix.h"
2323
#include "utils/test_csr_matrix.h"
24-
#include "wsum_hess/test_entr.h"
25-
#include "wsum_hess/test_exp.h"
26-
#include "wsum_hess/test_hyperbolic.h"
27-
#include "wsum_hess/test_log.h"
28-
#include "wsum_hess/test_logistic.h"
29-
#include "wsum_hess/test_power.h"
24+
#include "wsum_hess/elementwise/test_entr.h"
25+
#include "wsum_hess/elementwise/test_exp.h"
26+
#include "wsum_hess/elementwise/test_hyperbolic.h"
27+
#include "wsum_hess/elementwise/test_log.h"
28+
#include "wsum_hess/elementwise/test_logistic.h"
29+
#include "wsum_hess/elementwise/test_power.h"
30+
#include "wsum_hess/elementwise/test_trig.h"
31+
#include "wsum_hess/elementwise/test_xexp.h"
32+
#include "wsum_hess/test_hstack.h"
33+
#include "wsum_hess/test_rel_entr.h"
3034
#include "wsum_hess/test_sum.h"
31-
#include "wsum_hess/test_trig.h"
32-
#include "wsum_hess/test_xexp.h"
3335

3436
int main(void)
3537
{
@@ -95,6 +97,10 @@ int main(void)
9597
mu_run_test(test_wsum_hess_sum_log_linear, tests_run);
9698
mu_run_test(test_wsum_hess_sum_log_axis0, tests_run);
9799
mu_run_test(test_wsum_hess_sum_log_axis1, tests_run);
100+
mu_run_test(test_wsum_hess_rel_entr_1, tests_run);
101+
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
102+
mu_run_test(test_wsum_hess_hstack, tests_run);
103+
mu_run_test(test_wsum_hess_hstack_matrix, tests_run);
98104

99105
printf("\n--- Utility Tests ---\n");
100106
mu_run_test(test_diag_csr_mult, tests_run);
File renamed without changes.

0 commit comments

Comments
 (0)