Skip to content

Commit 891f127

Browse files
committed
refactored elementwise univariate to resuse already computed values
1 parent 2ed525b commit 891f127

12 files changed

Lines changed: 29 additions & 53 deletions

File tree

include/expr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stddef.h>
88

99
#define JAC_IDXS_NOT_SET -1
10+
#define NOT_A_VARIABLE -1
1011

1112
/* Function pointer types */
1213
struct expr;

src/bivariate/multiply.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ static void jacobian_init(expr *node)
3232
node->left->jacobian_init(node->left);
3333
}
3434

35-
if (node->right->var_id != -1)
35+
if (node->right->var_id != NOT_A_VARIABLE)
3636
{
3737
node->right->jacobian_init(node->right);
3838
}

src/bivariate/quad_over_lin.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static void jacobian_init(expr *node)
3636
expr *y = node->right;
3737

3838
/* if left node is a variable */
39-
if (x->var_id != -1)
39+
if (x->var_id != NOT_A_VARIABLE)
4040
{
4141
node->jacobian = new_csr_matrix(1, node->n_vars, x->d1 + 1);
4242
node->jacobian->p[0] = 0;
@@ -110,7 +110,7 @@ static void eval_jacobian(expr *node)
110110
expr *y = node->right;
111111

112112
/* if x is a variable */
113-
if (x->var_id != -1)
113+
if (x->var_id != NOT_A_VARIABLE)
114114
{
115115
/* if x has lower idx than y*/
116116
if (x->var_id < y->var_id)

src/elementwise_univariate/common.c

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ void jacobian_init_elementwise(expr *node)
88
expr *child = node->left;
99

1010
/* if the variable is a child */
11-
if (child->var_id != -1)
11+
if (child->var_id != NOT_A_VARIABLE)
1212
{
1313
node->jacobian = new_csr_matrix(node->size, node->n_vars, node->size);
1414
for (int j = 0; j < node->size; j++)
@@ -18,7 +18,7 @@ void jacobian_init_elementwise(expr *node)
1818
}
1919
node->jacobian->p[node->size] = node->size;
2020
}
21-
/* otherwise it should be a linear operator */
21+
/* otherwise it will be a linear operator */
2222
else
2323
{
2424
node->jacobian = new_csr_matrix(child->jacobian->m, child->jacobian->n,
@@ -31,13 +31,13 @@ void eval_jacobian_elementwise(expr *node)
3131
{
3232
expr *child = node->left;
3333

34-
if (child->var_id != -1)
34+
if (child->var_id != NOT_A_VARIABLE)
3535
{
3636
node->local_jacobian(node, node->jacobian->x);
3737
}
3838
else
3939
{
40-
/* Child must be a linear operator */
40+
/* Child will be a linear operator */
4141
linear_op_expr *lin_child = (linear_op_expr *) child;
4242
node->local_jacobian(node, node->dwork);
4343
diag_csr_mult(node->dwork, lin_child->A_csr, node->jacobian);
@@ -51,7 +51,7 @@ void wsum_hess_init_elementwise(expr *node)
5151
int i;
5252

5353
/* if the variable is a child*/
54-
if (id != -1)
54+
if (id != NOT_A_VARIABLE)
5555
{
5656
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, node->size);
5757

@@ -78,13 +78,13 @@ void eval_wsum_hess_elementwise(expr *node, const double *w)
7878
{
7979
expr *child = node->left;
8080

81-
if (child->var_id != -1)
81+
if (child->var_id != NOT_A_VARIABLE)
8282
{
8383
node->local_wsum_hess(node, node->wsum_hess->x, w);
8484
}
8585
else
8686
{
87-
/* Child must be a linear operator */
87+
/* Child will be a linear operator */
8888
linear_op_expr *lin_child = (linear_op_expr *) child;
8989
node->local_wsum_hess(node, node->dwork, w);
9090
ATDA_values(lin_child->A_csc, node->dwork, node->wsum_hess);

src/elementwise_univariate/exp.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ static void local_jacobian(expr *node, double *vals)
2121

2222
static void local_wsum_hess(expr *node, double *out, const double *w)
2323
{
24-
double *x = node->left->value;
2524
for (int j = 0; j < node->size; j++)
2625
{
27-
out[j] = w[j] * exp(x[j]);
26+
out[j] = w[j] * node->value[j];
2827
}
2928
}
3029

src/elementwise_univariate/hyperbolic.c

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@ static void sinh_local_jacobian(expr *node, double *vals)
2323

2424
static void sinh_local_wsum_hess(expr *node, double *out, const double *w)
2525
{
26-
double *x = node->left->value;
2726
for (int j = 0; j < node->size; j++)
2827
{
29-
out[j] = w[j] * sinh(x[j]);
28+
out[j] = w[j] * node->value[j];
3029
}
3130
}
3231

@@ -51,10 +50,10 @@ static void tanh_forward(expr *node, const double *u)
5150

5251
static void tanh_local_jacobian(expr *node, double *vals)
5352
{
54-
expr *child = node->left;
53+
double *x = node->left->value;
5554
for (int j = 0; j < node->size; j++)
5655
{
57-
double c = cosh(child->value[j]);
56+
double c = cosh(x[j]);
5857
vals[j] = 1.0 / (c * c);
5958
}
6059
}
@@ -65,7 +64,7 @@ static void tanh_local_wsum_hess(expr *node, double *out, const double *w)
6564
for (int j = 0; j < node->size; j++)
6665
{
6766
double c = cosh(x[j]);
68-
out[j] = w[j] * (-2.0 * tanh(x[j]) / (c * c));
67+
out[j] = w[j] * (-2.0 * node->value[j] / (c * c));
6968
}
7069
}
7170

src/elementwise_univariate/logistic.c

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,9 @@ static void local_jacobian(expr *node, double *vals)
4242

4343
static void local_wsum_hess(expr *node, double *out, const double *w)
4444
{
45-
double *x = node->left->value;
4645
double *sigmas;
4746

48-
if (node->left->var_id != -1)
47+
if (node->left->var_id != NOT_A_VARIABLE)
4948
{
5049
sigmas = node->jacobian->x;
5150
}
@@ -54,23 +53,8 @@ static void local_wsum_hess(expr *node, double *out, const double *w)
5453
sigmas = node->dwork;
5554
}
5655

57-
// double sigma;
58-
5956
for (int j = 0; j < node->size; j++)
6057
{
61-
/*
62-
if (x[j] >= 0)
63-
{
64-
sigma = 1.0 / (1.0 + exp(-x[j]));
65-
}
66-
else
67-
{
68-
double exp_x = exp(x[j]);
69-
sigma = exp_x / (1.0 + exp_x);
70-
}
71-
72-
out[j] = w[j] * sigma * (1.0 - sigma);
73-
*/
7458
out[j] = sigmas[j] * (1.0 - sigmas[j]) * w[j];
7559
}
7660
}

src/elementwise_univariate/power.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ expr *new_power(expr *child, int p)
4545
/* Allocate the type-specific struct */
4646
power_expr *pnode = (power_expr *) calloc(1, sizeof(power_expr));
4747
expr *node = &pnode->base;
48-
49-
/* Initialize base elementwise fields */
5048
init_elementwise(node, child);
51-
5249
node->forward = forward;
5350
node->local_jacobian = local_jacobian;
5451
node->local_wsum_hess = local_wsum_hess;

src/elementwise_univariate/trig.c

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,18 @@ static void sin_forward(expr *node, const double *u)
1313

1414
static void sin_local_jacobian(expr *node, double *vals)
1515
{
16-
expr *child = node->left;
16+
double *x = node->left->value;
1717
for (int j = 0; j < node->size; j++)
1818
{
19-
vals[j] = cos(child->value[j]);
19+
vals[j] = cos(x[j]);
2020
}
2121
}
2222

2323
static void sin_local_wsum_hess(expr *node, double *out, const double *w)
2424
{
25-
double *x = node->left->value;
26-
2725
for (int j = 0; j < node->size; j++)
2826
{
29-
out[j] = -w[j] * sin(x[j]);
27+
out[j] = -w[j] * node->value[j];
3028
}
3129
}
3230

@@ -51,20 +49,18 @@ static void cos_forward(expr *node, const double *u)
5149

5250
static void cos_local_jacobian(expr *node, double *vals)
5351
{
54-
expr *child = node->left;
52+
double *x = node->left->value;
5553
for (int j = 0; j < node->size; j++)
5654
{
57-
vals[j] = -sin(child->value[j]);
55+
vals[j] = -sin(x[j]);
5856
}
5957
}
6058

6159
static void cos_local_wsum_hess(expr *node, double *out, const double *w)
6260
{
63-
double *x = node->left->value;
64-
6561
for (int j = 0; j < node->size; j++)
6662
{
67-
out[j] = -w[j] * cos(x[j]);
63+
out[j] = -w[j] * node->value[j];
6864
}
6965
}
7066

@@ -104,7 +100,7 @@ static void tan_local_wsum_hess(expr *node, double *out, const double *w)
104100
for (int j = 0; j < node->size; j++)
105101
{
106102
double c = cos(x[j]);
107-
out[j] = 2.0 * w[j] * tan(x[j]) / (c * c);
103+
out[j] = 2.0 * w[j] * node->value[j] / (c * c);
108104
}
109105
}
110106

src/elementwise_univariate/xexp.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ static void forward(expr *node, const double *u)
1616

1717
static void local_jacobian(expr *node, double *vals)
1818
{
19-
expr *child = node->left;
19+
double *x = node->left->value;
2020
for (int j = 0; j < node->size; j++)
2121
{
22-
vals[j] = (1.0 + child->value[j]) * exp(child->value[j]);
22+
vals[j] = (1.0 + x[j]) * exp(x[j]);
2323
}
2424
}
2525

0 commit comments

Comments
 (0)