Skip to content

Commit 1acf4c8

Browse files
committed
wsum of add and sum
1 parent 9d6a923 commit 1acf4c8

8 files changed

Lines changed: 266 additions & 2 deletions

File tree

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
1. power should be double
22
2. can we reuse calculations, like in the Hessian of logistic?
3-
3. more tests for chain rule elementwise univariate hessian
3+
3. more tests for chain rule elementwise univariate hessian
4+
4. in the refactor, add consts
5+
5. multiply with one constant vector/scalar argument
6+
6. AX where X is a matrix. Can that happen? How is that canonicalized? Maybe it can't happen.

include/utils/mini_numpy.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef MINI_NUMPY_H
#define MINI_NUMPY_H

/*
 * repeat: expand 'a' element by element.
 *
 * Each of the 'len' entries of 'a' is written to 'result' 'repeats' times in
 * a row, so len * repeats values are written in total.
 *
 *   a = [1, 2], len = 2, repeats = 3   ->   result = [1, 1, 1, 2, 2, 2]
 */
void repeat(double *result, const double *a, int len, int repeats);

/*
 * tile: expand 'a' as a whole.
 *
 * The full array 'a' (of 'len' entries) is written to 'result' 'tiles' times
 * back to back, so len * tiles values are written in total.
 *
 *   a = [1, 2], len = 2, tiles = 3   ->   result = [1, 2, 1, 2, 1, 2]
 */
void tile(double *result, const double *a, int len, int tiles);

/*
 * scaled_ones: constant fill.
 *
 * The first 'size' entries of 'result' are set to 'value'.
 *
 *   size = 5, value = 3.0   ->   result = [3.0, 3.0, 3.0, 3.0, 3.0]
 */
void scaled_ones(double *result, int size, double value);

#endif /* MINI_NUMPY_H */

src/affine/add.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,27 @@ static void eval_jacobian(expr *node)
3434
sum_csr_matrices(node->left->jacobian, node->right->jacobian, node->jacobian);
3535
}
3636

37+
/* Allocate the weighted-sum Hessian of an add node.
 *
 * Both children are initialized first, because their nonzero counts determine
 * how much storage this node requests for its own n_vars x n_vars matrix.
 */
static void wsum_hess_init(expr *node)
{
    expr *lhs = node->left;
    expr *rhs = node->right;

    lhs->wsum_hess_init(lhs);
    rhs->wsum_hess_init(rhs);

    /* an entrywise sum cannot hold more nonzeros than both operands combined */
    int capacity = lhs->wsum_hess->nnz + rhs->wsum_hess->nnz;
    node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, capacity);
}
47+
48+
/* Evaluate the weighted-sum Hessian of an add node.
 *
 * The weight vector w is passed unchanged to both children; their resulting
 * Hessians are then summed into node->wsum_hess.
 */
static void eval_wsum_hess(expr *node, double *w)
{
    expr *lhs = node->left;
    expr *rhs = node->right;

    /* children must be up to date before their matrices are combined */
    lhs->eval_wsum_hess(lhs, w);
    rhs->eval_wsum_hess(rhs, w);

    sum_csr_matrices(lhs->wsum_hess, rhs->wsum_hess, node->wsum_hess);
}
57+
3758
static bool is_affine(expr *node)
3859
{
3960
return node->left->is_affine(node->left) && node->right->is_affine(node->right);
@@ -56,6 +77,8 @@ expr *new_add(expr *left, expr *right)
5677
node->is_affine = is_affine;
5778
node->jacobian_init = jacobian_init;
5879
node->eval_jacobian = eval_jacobian;
80+
node->wsum_hess_init = wsum_hess_init;
81+
node->eval_wsum_hess = eval_wsum_hess;
5982

6083
return node;
6184
}

src/affine/sum.c

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "affine.h"
22
#include "utils/int_double_pair.h"
3+
#include "utils/mini_numpy.h"
34
#include <assert.h>
45
#include <stdlib.h>
56
#include <string.h>
@@ -91,6 +92,40 @@ static void eval_jacobian(expr *node)
9192
}
9293
}
9394

95+
/* Allocate the weighted-sum Hessian and the weight workspace of a sum node.
 *
 * The child is initialized first; this node's n_vars x n_vars matrix is then
 * created with room for as many nonzeros as the child's Hessian has.
 * node->dwork receives one double per entry of the child (x->size) and is
 * filled later by eval_wsum_hess.
 *
 * NOTE(review): node->dwork is assigned unconditionally; confirm that no other
 * init routine has already allocated it for this node.
 */
static void wsum_hess_init(expr *node)
{
    expr *x = node->left;

    /* initialize child's wsum_hess */
    x->wsum_hess_init(x);

    /* we never have to store more than the child's nnz */
    node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);

    /* workspace for the per-entry weights handed to the child */
    node->dwork = malloc(x->size * sizeof(double));

    /* malloc may legitimately return NULL for a zero-byte request */
    assert(node->dwork != NULL || x->size == 0);
}
105+
106+
/* Evaluate the weighted-sum Hessian of a sum node.
 *
 * The incoming weights w are expanded into node->dwork so that the child x
 * receives x->size weights:
 *   axis == -1: the single weight *w is copied x->size times;
 *   axis ==  0: each of the first x->d2 weights is repeated x->d1 times;
 *   axis ==  1: the first x->d1 weights are tiled x->d2 times.
 * The child is then evaluated with the expanded weights and its Hessian is
 * copied into node->wsum_hess.
 *
 * Any other axis is a programming error. It trips an assert in debug builds;
 * in release builds the function returns without evaluating, instead of
 * passing the never-initialized workspace on to the child.
 */
static void eval_wsum_hess(expr *node, double *w)
{
    expr *x = node->left;

    if (node->axis == -1)
    {
        scaled_ones(node->dwork, x->size, *w);
    }
    else if (node->axis == 0)
    {
        repeat(node->dwork, w, x->d2, x->d1);
    }
    else if (node->axis == 1)
    {
        tile(node->dwork, w, x->d1, x->d2);
    }
    else
    {
        assert(0 && "sum: axis must be -1, 0 or 1");
        return;
    }

    x->eval_wsum_hess(x, node->dwork);

    /* todo: is this copy necessary or can we just change pointers? */
    copy_csr_matrix(x->wsum_hess, node->wsum_hess);
}
128+
94129
static bool is_affine(expr *node)
95130
{
96131
return node->left->is_affine(node->left);
@@ -125,6 +160,8 @@ expr *new_sum(expr *child, int axis)
125160
node->is_affine = is_affine;
126161
node->jacobian_init = jacobian_init;
127162
node->eval_jacobian = eval_jacobian;
163+
node->wsum_hess_init = wsum_hess_init;
164+
node->eval_wsum_hess = eval_wsum_hess;
128165
node->axis = axis;
129166

130167
return node;

src/utils/mini_numpy.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include "utils/mini_numpy.h"
2+
3+
/* Write each of the 'len' entries of 'a' to 'result' 'repeats' times in a row.
 *
 * Example: a = [1, 2], len = 2, repeats = 3 gives result = [1, 1, 1, 2, 2, 2].
 *
 * len * repeats values are written; nothing is written when either count is
 * zero or negative. The output position is tracked with a pointer rather than
 * an int index, so it cannot overflow when len * repeats exceeds INT_MAX.
 */
void repeat(double *result, const double *a, int len, int repeats)
{
    double *out = result;

    for (int i = 0; i < len; i++)
    {
        for (int j = 0; j < repeats; j++)
        {
            *out++ = a[i];
        }
    }
}
14+
15+
/* Write the whole array 'a' (of 'len' entries) to 'result' 'tiles' times,
 * back to back.
 *
 * Example: a = [1, 2], len = 2, tiles = 3 gives result = [1, 2, 1, 2, 1, 2].
 *
 * len * tiles values are written; nothing is written when either count is
 * zero or negative. The output position is tracked with a pointer rather than
 * an int index, so it cannot overflow when len * tiles exceeds INT_MAX.
 */
void tile(double *result, const double *a, int len, int tiles)
{
    double *out = result;

    for (int t = 0; t < tiles; t++)
    {
        for (int j = 0; j < len; j++)
        {
            *out++ = a[j];
        }
    }
}
26+
27+
/* Set the first 'size' entries of 'result' to 'value'.
 *
 * Example: size = 5, value = 3.0 gives result = [3.0, 3.0, 3.0, 3.0, 3.0].
 * Nothing is written when 'size' is zero or negative.
 */
void scaled_ones(double *result, int size, double value)
{
    double *cursor = result;
    int remaining = size;

    while (remaining > 0)
    {
        *cursor++ = value;
        remaining--;
    }
}

tests/all_tests.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "wsum_hess/test_log.h"
2828
#include "wsum_hess/test_logistic.h"
2929
#include "wsum_hess/test_power.h"
30+
#include "wsum_hess/test_sum.h"
3031
#include "wsum_hess/test_trig.h"
3132
#include "wsum_hess/test_xexp.h"
3233

@@ -90,6 +91,9 @@ int main(void)
9091
mu_run_test(test_wsum_hess_tanh, tests_run);
9192
mu_run_test(test_wsum_hess_asinh, tests_run);
9293
mu_run_test(test_wsum_hess_atanh, tests_run);
94+
mu_run_test(test_wsum_hess_sum_log_linear, tests_run);
95+
mu_run_test(test_wsum_hess_sum_log_axis0, tests_run);
96+
mu_run_test(test_wsum_hess_sum_log_axis1, tests_run);
9397

9498
printf("\n--- Utility Tests ---\n");
9599
mu_run_test(test_diag_csr_mult, tests_run);

tests/test_helpers.c

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,19 @@
55

66
#define EPSILON 1e-7
77

8+
/* Absolute and relative tolerances used by is_equal_double. */
#define ABS_TOL 1e-6
#define REL_TOL 1e-6

/* Approximate equality of two doubles.
 *
 * Returns 1 when a and b are exactly equal, or when both are finite and
 * |a - b| <= max(ABS_TOL, REL_TOL * max(|a|, |b|)); returns 0 otherwise.
 *
 * Infinities are only equal to an infinity of the same sign, and NaN is never
 * equal to anything. Without the two early checks, inf vs inf would compare
 * unequal (inf - inf is NaN) while inf vs any other non-NaN value would
 * compare equal (the relative tolerance REL_TOL * inf is itself inf).
 */
int is_equal_double(double a, double b)
{
    /* exact match, including infinities of the same sign */
    if (a == b)
    {
        return 1;
    }

    double abs_a = fabs(a);
    double abs_b = fabs(b);
    double magnitude = (abs_a > abs_b) ? abs_a : abs_b;

    /* an infinite operand that is not an exact match can never be "close" */
    if (isinf(magnitude))
    {
        return 0;
    }

    double rel_bound = REL_TOL * magnitude;
    double tolerance = (rel_bound > ABS_TOL) ? rel_bound : ABS_TOL;

    /* any comparison involving NaN is false, so NaN inputs yield 0 here */
    return fabs(a - b) <= tolerance;
}
15+
816
int cmp_double_array(const double *actual, const double *expected, int size)
917
{
1018
for (int i = 0; i < size; i++)
1119
{
12-
if (fabs(actual[i] - expected[i]) > EPSILON)
20+
if (!is_equal_double(actual[i], expected[i]))
1321
{
1422
printf(" FAILED: actual[%d] = %f, expected %f\n", i, actual[i],
1523
expected[i]);

tests/wsum_hess/test_sum.h

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#include <math.h>
2+
#include <stdio.h>
3+
#include <stdlib.h>
4+
#include <string.h>
5+
6+
#include "affine.h"
7+
#include "elementwise_univariate.h"
8+
#include "expr.h"
9+
#include "minunit.h"
10+
#include "test_helpers.h"
11+
12+
const char *test_wsum_hess_sum_log_linear()
13+
{
14+
double Ax[6] = {1, 1, 2, 3, 1, -1};
15+
int Ai[6] = {0, 1, 0, 1, 0, 1};
16+
int Ap[4] = {0, 2, 4, 6};
17+
CSR_Matrix *A = new_csr_matrix(3, 2, 6);
18+
memcpy(A->x, Ax, 6 * sizeof(double));
19+
memcpy(A->i, Ai, 6 * sizeof(int));
20+
memcpy(A->p, Ap, 4 * sizeof(int));
21+
double x_vals[2] = {2.0, 1.0};
22+
double w = 1.5;
23+
24+
expr *x = new_variable(2, 1, 0, 2);
25+
expr *Ax_node = new_linear(x, A);
26+
expr *log_node = new_log(Ax_node);
27+
expr *sum_node = new_sum(log_node, -1);
28+
29+
sum_node->forward(sum_node, x_vals);
30+
sum_node->jacobian_init(sum_node);
31+
// sum_node->eval_jacobian(sum_node);
32+
sum_node->wsum_hess_init(sum_node);
33+
sum_node->eval_wsum_hess(sum_node, &w);
34+
35+
double expected_x[4] = {-1.5 * 526.0 / 441.0, 1.5 * 338.0 / 441.0,
36+
1.5 * 338.0 / 441.0, -1.5 * 571.0 / 441.0};
37+
int expected_p[3] = {0, 2, 4};
38+
int expected_i[4] = {0, 1, 0, 1};
39+
40+
mu_assert("vals incorrect",
41+
cmp_double_array(sum_node->wsum_hess->x, expected_x, 4));
42+
mu_assert("rows incorrect",
43+
cmp_int_array(sum_node->wsum_hess->p, expected_p, 3));
44+
mu_assert("cols incorrect",
45+
cmp_int_array(sum_node->wsum_hess->i, expected_i, 4));
46+
47+
free_expr(sum_node);
48+
free_expr(log_node);
49+
free_expr(Ax_node);
50+
free_csr_matrix(A);
51+
free_expr(x);
52+
53+
return 0;
54+
}
55+
56+
const char *test_wsum_hess_sum_log_axis0()
57+
{
58+
/* Test: wsum_hess of sum(log(x), axis=0) where x is 3x2
59+
* x = [[1, 4],
60+
* [2, 5],
61+
* [3, 6]]
62+
*/
63+
64+
double x[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
65+
double w[2] = {-1.0, -2.0};
66+
expr *x_node = new_variable(3, 2, 0, 6);
67+
expr *log_node = new_log(x_node);
68+
expr *sum_node = new_sum(log_node, 0);
69+
70+
sum_node->forward(sum_node, x);
71+
sum_node->jacobian_init(sum_node);
72+
sum_node->wsum_hess_init(sum_node);
73+
sum_node->eval_wsum_hess(sum_node, w);
74+
75+
/* Expected diagonal values */
76+
double expected_x[6] = {-w[0] / (x[0] * x[0]), -w[0] / (x[1] * x[1]),
77+
-w[0] / (x[2] * x[2]), -w[1] / (x[3] * x[3]),
78+
-w[1] / (x[4] * x[4]), -w[1] / (x[5] * x[5])};
79+
int expected_p[7] = {0, 1, 2, 3, 4, 5, 6};
80+
int expected_i[6] = {0, 1, 2, 3, 4, 5};
81+
82+
mu_assert("vals incorrect",
83+
cmp_double_array(sum_node->wsum_hess->x, expected_x, 6));
84+
mu_assert("rows incorrect",
85+
cmp_int_array(sum_node->wsum_hess->p, expected_p, 7));
86+
mu_assert("cols incorrect",
87+
cmp_int_array(sum_node->wsum_hess->i, expected_i, 6));
88+
89+
free_expr(sum_node);
90+
free_expr(log_node);
91+
free_expr(x_node);
92+
93+
return 0;
94+
}
95+
96+
const char *test_wsum_hess_sum_log_axis1()
97+
{
98+
/* Test: wsum_hess of sum(log(x), axis=1) where x is 3x2
99+
* x = [[1, 4],
100+
* [2, 5],
101+
* [3, 6]]
102+
*/
103+
104+
double x[6] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
105+
double w[3] = {-1.0, -2.0, -3.0};
106+
expr *x_node = new_variable(3, 2, 0, 6);
107+
expr *log_node = new_log(x_node);
108+
expr *sum_node = new_sum(log_node, 1);
109+
110+
sum_node->forward(sum_node, x);
111+
sum_node->jacobian_init(sum_node);
112+
sum_node->wsum_hess_init(sum_node);
113+
sum_node->eval_wsum_hess(sum_node, w);
114+
115+
/* Expected diagonal values */
116+
double expected_x[6] = {-w[0] / (x[0] * x[0]), -w[1] / (x[1] * x[1]),
117+
-w[2] / (x[2] * x[2]), -w[0] / (x[3] * x[3]),
118+
-w[1] / (x[4] * x[4]), -w[2] / (x[5] * x[5])};
119+
int expected_p[7] = {0, 1, 2, 3, 4, 5, 6};
120+
int expected_i[6] = {0, 1, 2, 3, 4, 5};
121+
122+
mu_assert("vals incorrect",
123+
cmp_double_array(sum_node->wsum_hess->x, expected_x, 6));
124+
mu_assert("rows incorrect",
125+
cmp_int_array(sum_node->wsum_hess->p, expected_p, 7));
126+
mu_assert("cols incorrect",
127+
cmp_int_array(sum_node->wsum_hess->i, expected_i, 6));
128+
129+
free_expr(sum_node);
130+
free_expr(log_node);
131+
free_expr(x_node);
132+
133+
return 0;
134+
}

0 commit comments

Comments
 (0)