11#include "elementwise_univariate.h"
2+ #include <assert.h>
23#include <math.h>
34
45/* ----------------------- sinh ----------------------- */
@@ -11,7 +12,7 @@ static void sinh_forward(expr *node, const double *u)
1112 }
1213}
1314
14- static void sinh_eval_local_jacobian (expr * node , double * vals )
15+ static void sinh_local_jacobian (expr * node , double * vals )
1516{
1617 expr * child = node -> left ;
1718 for (int j = 0 ; j < node -> size ; j ++ )
@@ -20,16 +21,21 @@ static void sinh_eval_local_jacobian(expr *node, double *vals)
2021 }
2122}
2223
24+ static void sinh_local_wsum_hess (expr * node , double * out , double * w )
25+ {
26+ double * x = node -> left -> value ;
27+ for (int j = 0 ; j < node -> size ; j ++ )
28+ {
29+ out [j ] = w [j ] * sinh (x [j ]);
30+ }
31+ }
32+
2333expr * new_sinh (expr * child )
2434{
25- expr * node = new_expr (child -> d1 , child -> d2 , child -> n_vars );
26- node -> left = child ;
27- expr_retain (child );
35+ expr * node = new_elementwise (child );
2836 node -> forward = sinh_forward ;
29- node -> jacobian_init = jacobian_init_elementwise ;
30- node -> eval_jacobian = eval_jacobian_elementwise ;
31- node -> eval_local_jacobian = sinh_eval_local_jacobian ;
32- node -> is_affine = is_affine_elementwise ;
37+ node -> local_jacobian = sinh_local_jacobian ;
38+ node -> local_wsum_hess = sinh_local_wsum_hess ;
3339 return node ;
3440}
3541
@@ -43,7 +49,7 @@ static void tanh_forward(expr *node, const double *u)
4349 }
4450}
4551
46- static void tanh_eval_local_jacobian (expr * node , double * vals )
52+ static void tanh_local_jacobian (expr * node , double * vals )
4753{
4854 expr * child = node -> left ;
4955 for (int j = 0 ; j < node -> size ; j ++ )
@@ -53,16 +59,22 @@ static void tanh_eval_local_jacobian(expr *node, double *vals)
5359 }
5460}
5561
62+ static void tanh_local_wsum_hess (expr * node , double * out , double * w )
63+ {
64+ double * x = node -> left -> value ;
65+ for (int j = 0 ; j < node -> size ; j ++ )
66+ {
67+ double c = cosh (x [j ]);
68+ out [j ] = w [j ] * (-2.0 * tanh (x [j ]) / (c * c ));
69+ }
70+ }
71+
5672expr * new_tanh (expr * child )
5773{
58- expr * node = new_expr (child -> d1 , child -> d2 , child -> n_vars );
59- node -> left = child ;
60- expr_retain (child );
74+ expr * node = new_elementwise (child );
6175 node -> forward = tanh_forward ;
62- node -> jacobian_init = jacobian_init_elementwise ;
63- node -> eval_jacobian = eval_jacobian_elementwise ;
64- node -> eval_local_jacobian = tanh_eval_local_jacobian ;
65- node -> is_affine = is_affine_elementwise ;
76+ node -> local_jacobian = tanh_local_jacobian ;
77+ node -> local_wsum_hess = tanh_local_wsum_hess ;
6678 return node ;
6779}
6880
@@ -76,7 +88,7 @@ static void asinh_forward(expr *node, const double *u)
7688 }
7789}
7890
79- static void asinh_eval_local_jacobian (expr * node , double * vals )
91+ static void asinh_local_jacobian (expr * node , double * vals )
8092{
8193 expr * child = node -> left ;
8294 for (int j = 0 ; j < node -> size ; j ++ )
@@ -85,16 +97,22 @@ static void asinh_eval_local_jacobian(expr *node, double *vals)
8597 }
8698}
8799
100+ static void asinh_local_wsum_hess (expr * node , double * out , double * w )
101+ {
102+ double * x = node -> left -> value ;
103+ for (int j = 0 ; j < node -> size ; j ++ )
104+ {
105+ double c = 1.0 + x [j ] * x [j ];
106+ out [j ] = w [j ] * (- x [j ]) / pow (c , 1.5 );
107+ }
108+ }
109+
88110expr * new_asinh (expr * child )
89111{
90- expr * node = new_expr (child -> d1 , child -> d2 , child -> n_vars );
91- node -> left = child ;
92- expr_retain (child );
112+ expr * node = new_elementwise (child );
93113 node -> forward = asinh_forward ;
94- node -> jacobian_init = jacobian_init_elementwise ;
95- node -> eval_jacobian = eval_jacobian_elementwise ;
96- node -> eval_local_jacobian = asinh_eval_local_jacobian ;
97- node -> is_affine = is_affine_elementwise ;
114+ node -> local_jacobian = asinh_local_jacobian ;
115+ node -> local_wsum_hess = asinh_local_wsum_hess ;
98116 return node ;
99117}
100118
@@ -108,7 +126,7 @@ static void atanh_forward(expr *node, const double *u)
108126 }
109127}
110128
111- static void atanh_eval_local_jacobian (expr * node , double * vals )
129+ static void atanh_local_jacobian (expr * node , double * vals )
112130{
113131 expr * child = node -> left ;
114132 for (int j = 0 ; j < node -> size ; j ++ )
@@ -117,15 +135,21 @@ static void atanh_eval_local_jacobian(expr *node, double *vals)
117135 }
118136}
119137
138+ static void atanh_local_wsum_hess (expr * node , double * out , double * w )
139+ {
140+ double * x = node -> left -> value ;
141+ for (int j = 0 ; j < node -> size ; j ++ )
142+ {
143+ double c = 1.0 - x [j ] * x [j ];
144+ out [j ] = w [j ] * (2.0 * x [j ]) / (c * c );
145+ }
146+ }
147+
120148expr * new_atanh (expr * child )
121149{
122- expr * node = new_expr (child -> d1 , child -> d2 , child -> n_vars );
123- node -> left = child ;
124- expr_retain (child );
150+ expr * node = new_elementwise (child );
125151 node -> forward = atanh_forward ;
126- node -> jacobian_init = jacobian_init_elementwise ;
127- node -> eval_jacobian = eval_jacobian_elementwise ;
128- node -> eval_local_jacobian = atanh_eval_local_jacobian ;
129- node -> is_affine = is_affine_elementwise ;
152+ node -> local_jacobian = atanh_local_jacobian ;
153+ node -> local_wsum_hess = atanh_local_wsum_hess ;
130154 return node ;
131155}
0 commit comments