Skip to content

Commit b21e021

Browse files
authored
split up elementwise atoms (#57)
1 parent 19b1b0a commit b21e021

59 files changed

Lines changed: 475 additions & 697 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/elementwise_full_dom.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef ELEMENTWISE_FULL_DOM_H
2+
#define ELEMENTWISE_FULL_DOM_H
3+
4+
#include "expr.h"
5+
6+
/* Helper function to initialize an elementwise expr
7+
* (can be used with derived types) */
8+
void init_elementwise(expr *node, expr *child);
9+
10+
expr *new_exp(expr *child);
11+
expr *new_sin(expr *child);
12+
expr *new_cos(expr *child);
13+
expr *new_sinh(expr *child);
14+
expr *new_tanh(expr *child);
15+
expr *new_asinh(expr *child);
16+
expr *new_logistic(expr *child);
17+
expr *new_power(expr *child, double p);
18+
expr *new_xexp(expr *child);
19+
expr *new_normal_cdf(expr *child);
20+
21+
/* the jacobian and wsum_hess for elementwise full domain
22+
atoms are always initialized in the same way and
23+
implement the chain rule in the same way */
24+
void jacobian_init_elementwise(expr *node);
25+
void eval_jacobian_elementwise(expr *node);
26+
void wsum_hess_init_elementwise(expr *node);
27+
void eval_wsum_hess_elementwise(expr *node, const double *w);
28+
expr *new_elementwise(expr *child);
29+
30+
/* no elementwise atoms are affine according to our
31+
convention, so we can have a common implementation */
32+
bool is_affine_elementwise(const expr *node);
33+
34+
#endif /* ELEMENTWISE_FULL_DOM_H */
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef ELEMENTWISE_RESTRICTED_DOM_H
2+
#define ELEMENTWISE_RESTRICTED_DOM_H
3+
4+
#include "expr.h"
5+
6+
/* Shared init functions for restricted domain atoms
7+
* (variable-child only, no linear operator support) */
8+
void jacobian_init_restricted(expr *node);
9+
void wsum_hess_init_restricted(expr *node);
10+
bool is_affine_restricted(const expr *node);
11+
expr *new_restricted(expr *child);
12+
13+
expr *new_log(expr *child);
14+
expr *new_entr(expr *child);
15+
expr *new_atanh(expr *child);
16+
expr *new_tan(expr *child);
17+
18+
#endif /* ELEMENTWISE_RESTRICTED_DOM_H */

include/elementwise_univariate.h

Lines changed: 0 additions & 54 deletions
This file was deleted.
Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,4 @@
1-
/*
2-
* Copyright 2026 Daniel Cederberg and William Zhang
3-
*
4-
* This file is part of the DNLP-differentiation-engine project.
5-
*
6-
* Licensed under the Apache License, Version 2.0 (the "License");
7-
* you may not use this file except in compliance with the License.
8-
* You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
18-
#include "elementwise_univariate.h"
19-
#include "expr.h"
1+
#include "elementwise_full_dom.h"
202
#include "subexpr.h"
213
#include <stdio.h>
224
#include <stdlib.h>
@@ -121,10 +103,6 @@ bool is_affine_elementwise(const expr *node)
121103
return false;
122104
}
123105

124-
/* Helper function to initialize an already-allocated expr for elementwise operations
125-
* This is called when a power_expr or other type-specific struct is allocated
126-
* and we need to initialize the base expr fields
127-
*/
128106
void init_elementwise(expr *node, expr *child)
129107
{
130108
/* Initialize base fields */
Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,4 @@
1-
/*
2-
* Copyright 2026 Daniel Cederberg and William Zhang
3-
*
4-
* This file is part of the DNLP-differentiation-engine project.
5-
*
6-
* Licensed under the Apache License, Version 2.0 (the "License");
7-
* you may not use this file except in compliance with the License.
8-
* You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
18-
#include "elementwise_univariate.h"
1+
#include "elementwise_full_dom.h"
192
#include <math.h>
203
#include <string.h>
214

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "elementwise_univariate.h"
1+
#include "elementwise_full_dom.h"
22
#include <assert.h>
33
#include <math.h>
44

@@ -114,41 +114,3 @@ expr *new_asinh(expr *child)
114114
node->local_wsum_hess = asinh_local_wsum_hess;
115115
return node;
116116
}
117-
118-
/* ----------------------- atanh ----------------------- */
119-
static void atanh_forward(expr *node, const double *u)
120-
{
121-
node->left->forward(node->left, u);
122-
for (int i = 0; i < node->size; i++)
123-
{
124-
node->value[i] = atanh(node->left->value[i]);
125-
}
126-
}
127-
128-
static void atanh_local_jacobian(expr *node, double *vals)
129-
{
130-
expr *child = node->left;
131-
for (int j = 0; j < node->size; j++)
132-
{
133-
vals[j] = 1.0 / (1.0 - child->value[j] * child->value[j]);
134-
}
135-
}
136-
137-
static void atanh_local_wsum_hess(expr *node, double *out, const double *w)
138-
{
139-
double *x = node->left->value;
140-
for (int j = 0; j < node->size; j++)
141-
{
142-
double c = 1.0 - x[j] * x[j];
143-
out[j] = w[j] * (2.0 * x[j]) / (c * c);
144-
}
145-
}
146-
147-
expr *new_atanh(expr *child)
148-
{
149-
expr *node = new_elementwise(child);
150-
node->forward = atanh_forward;
151-
node->local_jacobian = atanh_local_jacobian;
152-
node->local_wsum_hess = atanh_local_wsum_hess;
153-
return node;
154-
}
Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,4 @@
1-
/*
2-
* Copyright 2026 Daniel Cederberg and William Zhang
3-
*
4-
* This file is part of the DNLP-differentiation-engine project.
5-
*
6-
* Licensed under the Apache License, Version 2.0 (the "License");
7-
* you may not use this file except in compliance with the License.
8-
* You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
18-
#include "elementwise_univariate.h"
1+
#include "elementwise_full_dom.h"
192
#include <math.h>
203

214
static void forward(expr *node, const double *u)
Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,4 @@
1-
/*
2-
* Copyright 2026 Daniel Cederberg and William Zhang
3-
*
4-
* This file is part of the DNLP-differentiation-engine project.
5-
*
6-
* Licensed under the Apache License, Version 2.0 (the "License");
7-
* you may not use this file except in compliance with the License.
8-
* You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
18-
#include "elementwise_univariate.h"
1+
#include "elementwise_full_dom.h"
192
#include <math.h>
203

214
#ifndef M_PI
@@ -53,7 +36,6 @@ static void local_wsum_hess(expr *node, double *out, const double *w)
5336
double *x = node->left->value;
5437
for (int j = 0; j < node->size; j++)
5538
{
56-
/* could avoid recomputing this (like in logistic) */
5739
double phi = INV_SQRT_2PI * exp(-0.5 * x[j] * x[j]);
5840
out[j] = w[j] * (-x[j] * phi);
5941
}
Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,4 @@
1-
/*
2-
* Copyright 2026 Daniel Cederberg and William Zhang
3-
*
4-
* This file is part of the DNLP-differentiation-engine project.
5-
*
6-
* Licensed under the Apache License, Version 2.0 (the "License");
7-
* you may not use this file except in compliance with the License.
8-
* You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
18-
#include "elementwise_univariate.h"
1+
#include "elementwise_full_dom.h"
192
#include "subexpr.h"
203
#include <math.h>
214
#include <stdlib.h>
Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,4 @@
1-
/*
2-
* Copyright 2026 Daniel Cederberg and William Zhang
3-
*
4-
* This file is part of the DNLP-differentiation-engine project.
5-
*
6-
* Licensed under the Apache License, Version 2.0 (the "License");
7-
* you may not use this file except in compliance with the License.
8-
* You may obtain a copy of the License at
9-
*
10-
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
12-
* Unless required by applicable law or agreed to in writing, software
13-
* distributed under the License is distributed on an "AS IS" BASIS,
14-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15-
* See the License for the specific language governing permissions and
16-
* limitations under the License.
17-
*/
18-
#include "elementwise_univariate.h"
1+
#include "elementwise_full_dom.h"
192
#include <math.h>
203

214
/* ----------------------- sin ----------------------- */
@@ -89,43 +72,3 @@ expr *new_cos(expr *child)
8972
node->local_wsum_hess = cos_local_wsum_hess;
9073
return node;
9174
}
92-
93-
/* ----------------------- tan ----------------------- */
94-
static void tan_forward(expr *node, const double *u)
95-
{
96-
node->left->forward(node->left, u);
97-
for (int i = 0; i < node->size; i++)
98-
{
99-
node->value[i] = tan(node->left->value[i]);
100-
}
101-
}
102-
103-
static void tan_local_jacobian(expr *node, double *vals)
104-
{
105-
expr *child = node->left;
106-
for (int j = 0; j < node->size; j++)
107-
{
108-
double c = cos(child->value[j]);
109-
vals[j] = 1.0 / (c * c);
110-
}
111-
}
112-
113-
static void tan_local_wsum_hess(expr *node, double *out, const double *w)
114-
{
115-
double *x = node->left->value;
116-
117-
for (int j = 0; j < node->size; j++)
118-
{
119-
double c = cos(x[j]);
120-
out[j] = 2.0 * w[j] * node->value[j] / (c * c);
121-
}
122-
}
123-
124-
expr *new_tan(expr *child)
125-
{
126-
expr *node = new_elementwise(child);
127-
node->forward = tan_forward;
128-
node->local_jacobian = tan_local_jacobian;
129-
node->local_wsum_hess = tan_local_wsum_hess;
130-
return node;
131-
}

0 commit comments

Comments
 (0)