Skip to content

Commit ca61c00

Browse files
committed
hessian for elementwise univariate atoms
1 parent 56d3434 commit ca61c00

28 files changed

Lines changed: 1213 additions & 124 deletions

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
1. power should be double
2+
2. can we reuse calculations, like in hessian of logistic

include/elementwise_univariate.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ expr *new_asinh(expr *child);
1515
expr *new_atanh(expr *child);
1616
expr *new_logistic(expr *child);
1717
expr *new_power(expr *child, int p);
18+
expr *new_xexp(expr *child);
1819

19-
/* the jacobian for elementwise atoms are always initialized in the
20-
same way and implement the chain rule in the same way */
20+
/* the jacobian and wsum_hess for elementwise univariate atoms are always
21+
initialized in the same way and implement the chain rule in the same way */
2122
void jacobian_init_elementwise(expr *node);
2223
void eval_jacobian_elementwise(expr *node);
24+
void wsum_hess_init_elementwise(expr *node);
25+
void eval_wsum_hess_elementwise(expr *node, double *w);
26+
expr *new_elementwise(expr *child);
2327

2428
/* no elementwise atoms are affine according to our convention,
2529
so we can have a common implementation */

include/expr.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ struct int_double_pair;
1515
/* Function pointer types */
1616
typedef void (*forward_fn)(struct expr *node, const double *u);
1717
typedef void (*jacobian_init_fn)(struct expr *node);
18+
typedef void (*wsum_hess_init_fn)(struct expr *node);
1819
typedef void (*eval_jacobian_fn)(struct expr *node);
19-
typedef void (*eval_local_jacobian_fn)(struct expr *node, double *out);
20+
typedef void (*wsum_hess_fn)(struct expr *node, double *w);
21+
typedef void (*local_jacobian_fn)(struct expr *node, double *out);
22+
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w);
2023
typedef bool (*is_affine_fn)(struct expr *node);
2124

2225
/* TODO: implement proper polymorphism */
@@ -52,10 +55,14 @@ typedef struct expr
5255
// jacobian related quantities
5356
// ------------------------------------------------------------------------
5457
CSR_Matrix *jacobian;
58+
CSR_Matrix *wsum_hess;
5559
CSR_Matrix *CSR_work;
5660
jacobian_init_fn jacobian_init;
61+
wsum_hess_init_fn wsum_hess_init;
5762
eval_jacobian_fn eval_jacobian;
58-
eval_local_jacobian_fn eval_local_jacobian;
63+
wsum_hess_fn eval_wsum_hess;
64+
local_jacobian_fn local_jacobian;
65+
local_wsum_hess_fn local_wsum_hess;
5966
is_affine_fn is_affine;
6067

6168
// for every linear operator we store A in CSR and CSC

src/bivariate/rel_entr.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ expr *new_rel_entr_vector_args(expr *left, expr *right)
8787
node->jacobian_init = jacobian_init_vectors_args;
8888
node->eval_jacobian = eval_jacobian_vector_args;
8989
// node->is_affine = is_affine_elementwise;
90-
// node->eval_local_jacobian = eval_local_jacobian;
90+
// node->local_jacobian = local_jacobian;
9191
return node;
9292
}
9393

src/elementwise_univariate/common.c

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ void jacobian_init_elementwise(expr *node)
55
{
66
expr *child = node->left;
77

8-
// if the variable is a child
8+
/* if the variable is a child */
99
if (child->var_id != -1)
1010
{
1111
node->jacobian = new_csr_matrix(node->size, node->n_vars, node->size);
@@ -16,7 +16,7 @@ void jacobian_init_elementwise(expr *node)
1616
}
1717
node->jacobian->p[node->size] = node->size;
1818
}
19-
// otherwise it should be a linear operator
19+
/* otherwise it should be a linear operator */
2020
else
2121
{
2222
node->jacobian = new_csr_matrix(child->jacobian->m, child->jacobian->n,
@@ -31,17 +31,73 @@ void eval_jacobian_elementwise(expr *node)
3131

3232
if (child->var_id != -1)
3333
{
34-
node->eval_local_jacobian(node, node->jacobian->x);
34+
node->local_jacobian(node, node->jacobian->x);
3535
}
3636
else
3737
{
38-
node->eval_local_jacobian(node, node->dwork);
38+
node->local_jacobian(node, node->dwork);
3939
diag_csr_mult(node->dwork, child->A_csr, node->jacobian);
4040
}
4141
}
4242

43+
void wsum_hess_init_elementwise(expr *node)
44+
{
45+
expr *child = node->left;
46+
int id = child->var_id;
47+
int i;
48+
49+
/* if the variable is a child*/
50+
if (id != -1)
51+
{
52+
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size);
53+
54+
for (i = 0; i < node->size; i++)
55+
{
56+
node->wsum_hess->p[id + i] = i;
57+
node->wsum_hess->i[i] = id + i;
58+
}
59+
60+
for (i = id + node->size; i <= node->n_vars; i++)
61+
{
62+
node->wsum_hess->p[i] = node->size;
63+
}
64+
}
65+
/* otherwise it will be a linear operator */
66+
else
67+
{
68+
node->wsum_hess = ATA_alloc(child->A_csc);
69+
}
70+
}
71+
72+
void eval_wsum_hess_elementwise(expr *node, double *w)
73+
{
74+
expr *child = node->left;
75+
76+
if (child->var_id != -1)
77+
{
78+
node->local_wsum_hess(node, node->wsum_hess->x, w);
79+
}
80+
else
81+
{
82+
node->local_wsum_hess(node, node->dwork, w);
83+
ATDA_values(child->A_csc, node->dwork, node->wsum_hess);
84+
}
85+
}
86+
4387
bool is_affine_elementwise(expr *node)
4488
{
4589
(void) node;
4690
return false;
4791
}
92+
93+
expr *new_elementwise(expr *child)
94+
{
95+
expr *node = new_expr(child->d1, child->d2, child->n_vars);
96+
node->left = child;
97+
expr_retain(child);
98+
node->is_affine = is_affine_elementwise;
99+
node->jacobian_init = jacobian_init_elementwise;
100+
node->eval_jacobian = eval_jacobian_elementwise;
101+
node->wsum_hess_init = wsum_hess_init_elementwise;
102+
node->eval_wsum_hess = eval_wsum_hess_elementwise;
103+
}

src/elementwise_univariate/entr.c

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ static void forward(expr *node, const double *u)
1515
}
1616
}
1717

18-
static void eval_local_jacobian(expr *node, double *vals)
18+
static void local_jacobian(expr *node, double *vals)
1919
{
2020
expr *child = node->left;
2121
for (int j = 0; j < node->size; j++)
@@ -24,15 +24,21 @@ static void eval_local_jacobian(expr *node, double *vals)
2424
}
2525
}
2626

27+
static void local_wsum_hess(expr *node, double *out, double *w)
28+
{
29+
double *x = node->left->value;
30+
31+
for (int j = 0; j < node->size; j++)
32+
{
33+
out[j] = -w[j] / x[j];
34+
}
35+
}
36+
2737
expr *new_entr(expr *child)
2838
{
29-
expr *node = new_expr(child->d1, child->d2, child->n_vars);
30-
node->left = child;
31-
expr_retain(child);
39+
expr *node = new_elementwise(child);
3240
node->forward = forward;
33-
node->jacobian_init = jacobian_init_elementwise;
34-
node->eval_jacobian = eval_jacobian_elementwise;
35-
node->is_affine = is_affine_elementwise;
36-
node->eval_local_jacobian = eval_local_jacobian;
41+
node->local_jacobian = local_jacobian;
42+
node->local_wsum_hess = local_wsum_hess;
3743
return node;
3844
}

src/elementwise_univariate/exp.c

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,25 @@ static void forward(expr *node, const double *u)
1414
}
1515
}
1616

17-
static void eval_local_jacobian(expr *node, double *vals)
17+
static void local_jacobian(expr *node, double *vals)
1818
{
1919
memcpy(vals, node->value, node->size * sizeof(double));
2020
}
2121

22-
expr *new_exp(expr *child)
22+
static void local_wsum_hess(expr *node, double *out, double *w)
2323
{
24-
if (!child) return NULL;
25-
26-
expr *node = new_expr(child->d1, child->d2, child->n_vars);
27-
if (!node) return NULL;
24+
double *x = node->left->value;
25+
for (int j = 0; j < node->size; j++)
26+
{
27+
out[j] = w[j] * exp(x[j]);
28+
}
29+
}
2830

29-
node->left = child;
30-
expr_retain(child);
31+
expr *new_exp(expr *child)
32+
{
33+
expr *node = new_elementwise(child);
3134
node->forward = forward;
32-
node->is_affine = is_affine_elementwise;
33-
node->jacobian_init = jacobian_init_elementwise;
34-
node->eval_jacobian = eval_jacobian_elementwise;
35-
node->eval_local_jacobian = eval_local_jacobian;
36-
35+
node->local_jacobian = local_jacobian;
36+
node->local_wsum_hess = local_wsum_hess;
3737
return node;
3838
}
Lines changed: 56 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "elementwise_univariate.h"
2+
#include <assert.h>
23
#include <math.h>
34

45
/* ----------------------- sinh ----------------------- */
@@ -11,7 +12,7 @@ static void sinh_forward(expr *node, const double *u)
1112
}
1213
}
1314

14-
static void sinh_eval_local_jacobian(expr *node, double *vals)
15+
static void sinh_local_jacobian(expr *node, double *vals)
1516
{
1617
expr *child = node->left;
1718
for (int j = 0; j < node->size; j++)
@@ -20,16 +21,21 @@ static void sinh_eval_local_jacobian(expr *node, double *vals)
2021
}
2122
}
2223

24+
static void sinh_local_wsum_hess(expr *node, double *out, double *w)
25+
{
26+
double *x = node->left->value;
27+
for (int j = 0; j < node->size; j++)
28+
{
29+
out[j] = w[j] * sinh(x[j]);
30+
}
31+
}
32+
2333
expr *new_sinh(expr *child)
2434
{
25-
expr *node = new_expr(child->d1, child->d2, child->n_vars);
26-
node->left = child;
27-
expr_retain(child);
35+
expr *node = new_elementwise(child);
2836
node->forward = sinh_forward;
29-
node->jacobian_init = jacobian_init_elementwise;
30-
node->eval_jacobian = eval_jacobian_elementwise;
31-
node->eval_local_jacobian = sinh_eval_local_jacobian;
32-
node->is_affine = is_affine_elementwise;
37+
node->local_jacobian = sinh_local_jacobian;
38+
node->local_wsum_hess = sinh_local_wsum_hess;
3339
return node;
3440
}
3541

@@ -43,7 +49,7 @@ static void tanh_forward(expr *node, const double *u)
4349
}
4450
}
4551

46-
static void tanh_eval_local_jacobian(expr *node, double *vals)
52+
static void tanh_local_jacobian(expr *node, double *vals)
4753
{
4854
expr *child = node->left;
4955
for (int j = 0; j < node->size; j++)
@@ -53,16 +59,22 @@ static void tanh_eval_local_jacobian(expr *node, double *vals)
5359
}
5460
}
5561

62+
static void tanh_local_wsum_hess(expr *node, double *out, double *w)
63+
{
64+
double *x = node->left->value;
65+
for (int j = 0; j < node->size; j++)
66+
{
67+
double c = cosh(x[j]);
68+
out[j] = w[j] * (-2.0 * tanh(x[j]) / (c * c));
69+
}
70+
}
71+
5672
expr *new_tanh(expr *child)
5773
{
58-
expr *node = new_expr(child->d1, child->d2, child->n_vars);
59-
node->left = child;
60-
expr_retain(child);
74+
expr *node = new_elementwise(child);
6175
node->forward = tanh_forward;
62-
node->jacobian_init = jacobian_init_elementwise;
63-
node->eval_jacobian = eval_jacobian_elementwise;
64-
node->eval_local_jacobian = tanh_eval_local_jacobian;
65-
node->is_affine = is_affine_elementwise;
76+
node->local_jacobian = tanh_local_jacobian;
77+
node->local_wsum_hess = tanh_local_wsum_hess;
6678
return node;
6779
}
6880

@@ -76,7 +88,7 @@ static void asinh_forward(expr *node, const double *u)
7688
}
7789
}
7890

79-
static void asinh_eval_local_jacobian(expr *node, double *vals)
91+
static void asinh_local_jacobian(expr *node, double *vals)
8092
{
8193
expr *child = node->left;
8294
for (int j = 0; j < node->size; j++)
@@ -85,16 +97,22 @@ static void asinh_eval_local_jacobian(expr *node, double *vals)
8597
}
8698
}
8799

100+
static void asinh_local_wsum_hess(expr *node, double *out, double *w)
101+
{
102+
double *x = node->left->value;
103+
for (int j = 0; j < node->size; j++)
104+
{
105+
double c = 1.0 + x[j] * x[j];
106+
out[j] = w[j] * (-x[j]) / pow(c, 1.5);
107+
}
108+
}
109+
88110
expr *new_asinh(expr *child)
89111
{
90-
expr *node = new_expr(child->d1, child->d2, child->n_vars);
91-
node->left = child;
92-
expr_retain(child);
112+
expr *node = new_elementwise(child);
93113
node->forward = asinh_forward;
94-
node->jacobian_init = jacobian_init_elementwise;
95-
node->eval_jacobian = eval_jacobian_elementwise;
96-
node->eval_local_jacobian = asinh_eval_local_jacobian;
97-
node->is_affine = is_affine_elementwise;
114+
node->local_jacobian = asinh_local_jacobian;
115+
node->local_wsum_hess = asinh_local_wsum_hess;
98116
return node;
99117
}
100118

@@ -108,7 +126,7 @@ static void atanh_forward(expr *node, const double *u)
108126
}
109127
}
110128

111-
static void atanh_eval_local_jacobian(expr *node, double *vals)
129+
static void atanh_local_jacobian(expr *node, double *vals)
112130
{
113131
expr *child = node->left;
114132
for (int j = 0; j < node->size; j++)
@@ -117,15 +135,21 @@ static void atanh_eval_local_jacobian(expr *node, double *vals)
117135
}
118136
}
119137

138+
static void atanh_local_wsum_hess(expr *node, double *out, double *w)
139+
{
140+
double *x = node->left->value;
141+
for (int j = 0; j < node->size; j++)
142+
{
143+
double c = 1.0 - x[j] * x[j];
144+
out[j] = w[j] * (2.0 * x[j]) / (c * c);
145+
}
146+
}
147+
120148
expr *new_atanh(expr *child)
121149
{
122-
expr *node = new_expr(child->d1, child->d2, child->n_vars);
123-
node->left = child;
124-
expr_retain(child);
150+
expr *node = new_elementwise(child);
125151
node->forward = atanh_forward;
126-
node->jacobian_init = jacobian_init_elementwise;
127-
node->eval_jacobian = eval_jacobian_elementwise;
128-
node->eval_local_jacobian = atanh_eval_local_jacobian;
129-
node->is_affine = is_affine_elementwise;
152+
node->local_jacobian = atanh_local_jacobian;
153+
node->local_wsum_hess = atanh_local_wsum_hess;
130154
return node;
131155
}

0 commit comments

Comments
 (0)