Skip to content

Commit 2cd8f34

Browse files
committed
added hessian for hstack
1 parent 315c0a1 commit 2cd8f34

6 files changed

Lines changed: 275 additions & 6 deletions

File tree

include/subexpr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ typedef struct hstack_expr
4646
expr base;
4747
expr **args;
4848
int n_args;
49+
CSR_Matrix *CSR_work; /* for summing Hessians of children */
4950
} hstack_expr;
5051

5152
#endif /* SUBEXPR_H */

include/utils/CSR_Matrix.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C);
5555

5656
/* Compute C = A + B where A, B, C are CSR matrices
5757
* A and B must have same dimensions
58-
* C must be pre-allocated with sufficient nnz capacity */
58+
* C must be pre-allocated with sufficient nnz capacity.
59+
* C must be different from A and B */
5960
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
6061

6162
/* Compute C = diag(d1) * A + diag(d2) * B where A, B, C are CSR matrices */

src/affine/hstack.c

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void jacobian_init(expr *node)
3737

3838
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);
3939

40-
/* precompute sparsity pattern of this jacobian's node */
40+
/* precompute sparsity pattern of this node's jacobian */
4141
int row_offset = 0;
4242
CSR_Matrix *A = node->jacobian;
4343
A->nnz = 0;
@@ -80,6 +80,40 @@ static void eval_jacobian(expr *node)
8080
}
8181
}
8282

83+
/* Allocate the weighted-sum Hessian storage of an hstack node.
 *
 * Every argument's own wsum_hess_init is invoked first so that the child
 * Hessians exist; their nnz counts are then added up and used as the
 * capacity of both node->wsum_hess and the scratch matrix CSR_work. */
static void wsum_hess_init(expr *node)
{
    hstack_expr *stack = (hstack_expr *) node;
    int total_nnz = 0;
    int k = 0;

    /* children first: their Hessians are needed to size this node's one */
    while (k < stack->n_args)
    {
        expr *arg = stack->args[k];
        arg->wsum_hess_init(arg);
        total_nnz += arg->wsum_hess->nnz;
        k++;
    }

    /* in the worst case no entries of the children overlap, so the sum of
       their nnz bounds the nnz of this node's Hessian */
    node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, total_nnz);
    stack->CSR_work = new_csr_matrix(node->n_vars, node->n_vars, total_nnz);
}
99+
100+
/* Evaluate the weighted-sum Hessian of an hstack node into node->wsum_hess.
 *
 * w holds the weights of the stacked output; the slice handed to child i
 * starts at the summed sizes of the children before it. Each child's Hessian
 * is evaluated with its slice and added to the running total. CSR_work holds
 * a copy of the running total because sum_csr_matrices requires its output
 * to be distinct from both inputs. */
static void wsum_hess_eval(expr *node, const double *w)
{
    hstack_expr *hnode = (hstack_expr *) node;
    CSR_Matrix *H = node->wsum_hess;
    int row_offset = 0;

    /* Reset H to a genuinely empty matrix. Clearing nnz alone leaves the row
       pointers of the previous evaluation in place; they would be copied
       into CSR_work below and make a repeated evaluation start from stale
       row structure instead of from zero. */
    H->nnz = 0;
    for (int row = 0; row <= H->m; row++)
    {
        H->p[row] = 0;
    }

    for (int i = 0; i < hnode->n_args; i++)
    {
        expr *child = hnode->args[i];
        child->eval_wsum_hess(child, w + row_offset);

        /* H = H + child Hessian, going through the scratch copy */
        copy_csr_matrix(H, hnode->CSR_work);
        sum_csr_matrices(hnode->CSR_work, child->wsum_hess, H);

        row_offset += child->size;
    }
}
116+
83117
static bool is_affine(const expr *node)
84118
{
85119
const hstack_expr *hnode = (const hstack_expr *) node;
@@ -121,6 +155,8 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
121155
/* Set type-specific fields */
122156
hnode->args = args;
123157
hnode->n_args = n_args;
158+
node->wsum_hess_init = wsum_hess_init;
159+
node->eval_wsum_hess = wsum_hess_eval;
124160

125161
for (int i = 0; i < n_args; i++)
126162
{

src/utils/CSR_Matrix.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ void diag_csr_mult(const double *d, const CSR_Matrix *A, CSR_Matrix *C)
9999

100100
void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C)
101101
{
102+
/* A and B must be different from C */
103+
assert(A != C && B != C);
104+
102105
C->nnz = 0;
103106

104107
for (int row = 0; row < A->m; row++)
@@ -138,17 +141,17 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C)
138141
if (a_ptr < a_end)
139142
{
140143
int a_remaining = a_end - a_ptr;
141-
memmove(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int));
142-
memmove(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double));
144+
memcpy(C->i + C->nnz, A->i + a_ptr, a_remaining * sizeof(int));
145+
memcpy(C->x + C->nnz, A->x + a_ptr, a_remaining * sizeof(double));
143146
C->nnz += a_remaining;
144147
}
145148

146149
/* Copy remaining elements from B */
147150
if (b_ptr < b_end)
148151
{
149152
int b_remaining = b_end - b_ptr;
150-
memmove(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int));
151-
memmove(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double));
153+
memcpy(C->i + C->nnz, B->i + b_ptr, b_remaining * sizeof(int));
154+
memcpy(C->x + C->nnz, B->x + b_ptr, b_remaining * sizeof(double));
152155
C->nnz += b_remaining;
153156
}
154157
}

tests/all_tests.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "wsum_hess/elementwise/test_power.h"
3030
#include "wsum_hess/elementwise/test_trig.h"
3131
#include "wsum_hess/elementwise/test_xexp.h"
32+
#include "wsum_hess/test_hstack.h"
3233
#include "wsum_hess/test_rel_entr.h"
3334
#include "wsum_hess/test_sum.h"
3435

@@ -98,6 +99,8 @@ int main(void)
9899
mu_run_test(test_wsum_hess_sum_log_axis1, tests_run);
99100
mu_run_test(test_wsum_hess_rel_entr_1, tests_run);
100101
mu_run_test(test_wsum_hess_rel_entr_2, tests_run);
102+
mu_run_test(test_wsum_hess_hstack, tests_run);
103+
mu_run_test(test_wsum_hess_hstack_matrix, tests_run);
101104

102105
printf("\n--- Utility Tests ---\n");
103106
mu_run_test(test_diag_csr_mult, tests_run);

tests/wsum_hess/test_hstack.h

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
#include "affine.h"
2+
#include "elementwise_univariate.h"
3+
#include "expr.h"
4+
#include "minunit.h"
5+
#include "test_helpers.h"
6+
#include <math.h>
7+
#include <stdio.h>
8+
9+
const char *test_wsum_hess_hstack()
10+
{
11+
/* Test: hstack([log(x), log(z), exp(x), sin(y)])
12+
* Variables: x at idx 0, z at idx 3, y at idx 6
13+
* x = [1, 2, 3], z = [4, 5, 6], y = [7, 8, 9]
14+
* Total 9 variables
15+
* hStacked vectorized output is 12x1: [log(x),
16+
* log(z),
17+
* exp(x),
18+
* sin(y)]
19+
* w = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
20+
*/
21+
22+
double u_vals[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
23+
double w[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
24+
25+
expr *x = new_variable(3, 1, 0, 9);
26+
expr *z = new_variable(3, 1, 3, 9);
27+
expr *y = new_variable(3, 1, 6, 9);
28+
29+
expr *log_x = new_log(x);
30+
expr *log_z = new_log(z);
31+
expr *exp_x = new_exp(x);
32+
expr *sin_y = new_sin(y);
33+
34+
expr *args[4] = {log_x, log_z, exp_x, sin_y};
35+
expr *hstack_node = new_hstack(args, 4, 9);
36+
37+
hstack_node->forward(hstack_node, u_vals);
38+
hstack_node->wsum_hess_init(hstack_node);
39+
hstack_node->eval_wsum_hess(hstack_node, w);
40+
41+
/* Expected Hessian:
42+
* log(x): d²/dx² = -1/x²
43+
* w[0] * (-1/1²) = 1 * (-1) = -1 at (0,0)
44+
* w[1] * (-1/2²) = 2 * (-0.25) = -0.5 at (1,1)
45+
* w[2] * (-1/3²) = 3 * (-1/9) = -1/3 at (2,2)
46+
*
47+
* log(z): d²/dz² = -1/z²
48+
* w[3] * (-1/4²) = 4 * (-1/16) = -0.25 at (3,3)
49+
* w[4] * (-1/5²) = 5 * (-1/25) = -0.2 at (4,4)
50+
* w[5] * (-1/6²) = 6 * (-1/36) = -1/6 at (5,5)
51+
*
52+
* exp(x): d²/dx² = exp(x)
53+
* w[6] * exp(1) = 7 * e at (0,0)
54+
* w[7] * exp(2) = 8 * e² at (1,1)
55+
* w[8] * exp(3) = 9 * e³ at (2,2)
56+
*
57+
* sin(y): d²/dy² = -sin(y)
58+
* w[9] * (-sin(7)) at (6,6)
59+
* w[10] * (-sin(8)) at (7,7)
60+
* w[11] * (-sin(9)) at (8,8)
61+
*
62+
* Accumulated:
63+
* (0,0): -1 + 7*e
64+
* (1,1): -0.5 + 8*e²
65+
* (2,2): -1/3 + 9*e³
66+
* (3,3): -0.25
67+
* (4,4): -0.2
68+
* (5,5): -1/6
69+
* (6,6): -10*sin(7)
70+
* (7,7): -11*sin(8)
71+
* (8,8): -12*sin(9)
72+
*/
73+
74+
double e = exp(1.0);
75+
double e2 = exp(2.0);
76+
double e3 = exp(3.0);
77+
78+
double expected_x[9] = {-1.0 + 7.0 * e,
79+
-0.5 + 8.0 * e2,
80+
-1.0 / 3.0 + 9.0 * e3,
81+
-0.25,
82+
-0.2,
83+
-1.0 / 6.0,
84+
-10.0 * sin(7.0),
85+
-11.0 * sin(8.0),
86+
-12.0 * sin(9.0)};
87+
88+
int expected_p[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
89+
int expected_i[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8};
90+
91+
mu_assert("vals incorrect",
92+
cmp_double_array(hstack_node->wsum_hess->x, expected_x, 9));
93+
mu_assert("rows incorrect",
94+
cmp_int_array(hstack_node->wsum_hess->p, expected_p, 10));
95+
mu_assert("cols incorrect",
96+
cmp_int_array(hstack_node->wsum_hess->i, expected_i, 9));
97+
98+
free_expr(hstack_node);
99+
free_expr(sin_y);
100+
free_expr(exp_x);
101+
free_expr(log_z);
102+
free_expr(log_x);
103+
free_expr(y);
104+
free_expr(z);
105+
free_expr(x);
106+
107+
return 0;
108+
}
109+
110+
const char *test_wsum_hess_hstack_matrix()
111+
{
112+
/* Test: hstack([log(x), log(z), exp(x), sin(y)]) with matrix variables
113+
* Variables: x at idx 0, z at idx 6, y at idx 12
114+
* Each is 3x2, so 6 elements per variable
115+
* x = [1 4] z = [7 10] y = [13 16]
116+
* [2 5] [8 11] [14 17]
117+
* [3 6] [9 12] [15 18]
118+
* Vectorized column-wise: x = [1,2,3,4,5,6], z = [7,8,9,10,11,12], y =
119+
* [13,14,15,16,17,18] Total 18 variables Stacked output is 24x1: [log(x),
120+
* log(z), exp(x), sin(y)] each 6x1 w = [1, 2, 3, ..., 24]
121+
*/
122+
123+
double u_vals[18] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
124+
10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0};
125+
double w[24];
126+
for (int i = 0; i < 24; i++)
127+
{
128+
w[i] = i + 1.0;
129+
}
130+
131+
expr *x = new_variable(3, 2, 0, 18);
132+
expr *z = new_variable(3, 2, 6, 18);
133+
expr *y = new_variable(3, 2, 12, 18);
134+
135+
expr *log_x = new_log(x);
136+
expr *log_z = new_log(z);
137+
expr *exp_x = new_exp(x);
138+
expr *sin_y = new_sin(y);
139+
140+
expr *args[4] = {log_x, log_z, exp_x, sin_y};
141+
expr *hstack_node = new_hstack(args, 4, 18);
142+
143+
hstack_node->forward(hstack_node, u_vals);
144+
hstack_node->wsum_hess_init(hstack_node);
145+
hstack_node->eval_wsum_hess(hstack_node, w);
146+
147+
/* Expected Hessian (diagonal):
148+
* log(x): w[0:5] * (-1/x[0:5]²) at indices 0-5
149+
* log(z): w[6:11] * (-1/z[0:5]²) at indices 6-11
150+
* exp(x): w[12:17] * exp(x[0:5]) at indices 0-5 (accumulates with log(x))
151+
* sin(y): w[18:23] * (-sin(y[0:5])) at indices 12-17
152+
*
153+
* For x indices (0-5):
154+
* i=0: -1/1² + 13*e¹ = -1 + 13*e
155+
* i=1: -2/2² + 14*e² = -0.5 + 14*e²
156+
* i=2: -3/3² + 15*e³ = -1/3 + 15*e³
157+
* i=3: -4/4² + 16*e⁴ = -0.25 + 16*e⁴
158+
* i=4: -5/5² + 17*e⁵ = -0.2 + 17*e⁵
159+
* i=5: -6/6² + 18*e⁶ = -1/6 + 18*e⁶
160+
*
161+
* For z indices (6-11):
162+
* i=0: -7/7² = -1/7
163+
* i=1: -8/8² = -1/8
164+
* i=2: -9/9² = -1/9
165+
* i=3: -10/10² = -0.1
166+
* i=4: -11/11² = -11/121
167+
* i=5: -12/12² = -1/12
168+
*
169+
* For y indices (12-17):
170+
* i=0: -19*sin(13)
171+
* i=1: -20*sin(14)
172+
* i=2: -21*sin(15)
173+
* i=3: -22*sin(16)
174+
* i=4: -23*sin(17)
175+
* i=5: -24*sin(18)
176+
*/
177+
178+
double expected_x[18];
179+
// x indices (0-5) - accumulation of log and exp
180+
expected_x[0] = -1.0 + 13.0 * exp(1.0);
181+
expected_x[1] = -0.5 + 14.0 * exp(2.0);
182+
expected_x[2] = -1.0 / 3.0 + 15.0 * exp(3.0);
183+
expected_x[3] = -0.25 + 16.0 * exp(4.0);
184+
expected_x[4] = -0.2 + 17.0 * exp(5.0);
185+
expected_x[5] = -1.0 / 6.0 + 18.0 * exp(6.0);
186+
187+
// z indices (6-11) - only log
188+
expected_x[6] = -1.0 / 7.0;
189+
expected_x[7] = -1.0 / 8.0;
190+
expected_x[8] = -1.0 / 9.0;
191+
expected_x[9] = -0.1;
192+
expected_x[10] = -11.0 / 121.0;
193+
expected_x[11] = -1.0 / 12.0;
194+
195+
// y indices (12-17) - only sin
196+
expected_x[12] = -19.0 * sin(13.0);
197+
expected_x[13] = -20.0 * sin(14.0);
198+
expected_x[14] = -21.0 * sin(15.0);
199+
expected_x[15] = -22.0 * sin(16.0);
200+
expected_x[16] = -23.0 * sin(17.0);
201+
expected_x[17] = -24.0 * sin(18.0);
202+
203+
int expected_p[19] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
204+
10, 11, 12, 13, 14, 15, 16, 17, 18};
205+
int expected_i[18] = {0, 1, 2, 3, 4, 5, 6, 7, 8,
206+
9, 10, 11, 12, 13, 14, 15, 16, 17};
207+
208+
mu_assert("vals incorrect",
209+
cmp_double_array(hstack_node->wsum_hess->x, expected_x, 18));
210+
mu_assert("rows incorrect",
211+
cmp_int_array(hstack_node->wsum_hess->p, expected_p, 19));
212+
mu_assert("cols incorrect",
213+
cmp_int_array(hstack_node->wsum_hess->i, expected_i, 18));
214+
215+
free_expr(hstack_node);
216+
free_expr(sin_y);
217+
free_expr(exp_x);
218+
free_expr(log_z);
219+
free_expr(log_x);
220+
free_expr(y);
221+
free_expr(z);
222+
free_expr(x);
223+
224+
return 0;
225+
}

0 commit comments

Comments
 (0)