|
| 1 | +#include "affine.h" |
| 2 | +#include "elementwise_univariate.h" |
| 3 | +#include "expr.h" |
| 4 | +#include "minunit.h" |
| 5 | +#include "test_helpers.h" |
| 6 | +#include <math.h> |
| 7 | +#include <stdio.h> |
| 8 | + |
| 9 | +const char *test_wsum_hess_hstack() |
| 10 | +{ |
| 11 | + /* Test: hstack([log(x), log(z), exp(x), sin(y)]) |
| 12 | + * Variables: x at idx 0, z at idx 3, y at idx 6 |
| 13 | + * x = [1, 2, 3], z = [4, 5, 6], y = [7, 8, 9] |
| 14 | + * Total 9 variables |
| 15 | + * hStacked vectorized output is 12x1: [log(x), |
| 16 | + * log(z), |
| 17 | + * exp(x), |
| 18 | + * sin(y)] |
| 19 | + * w = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] |
| 20 | + */ |
| 21 | + |
| 22 | + double u_vals[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; |
| 23 | + double w[12] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}; |
| 24 | + |
| 25 | + expr *x = new_variable(3, 1, 0, 9); |
| 26 | + expr *z = new_variable(3, 1, 3, 9); |
| 27 | + expr *y = new_variable(3, 1, 6, 9); |
| 28 | + |
| 29 | + expr *log_x = new_log(x); |
| 30 | + expr *log_z = new_log(z); |
| 31 | + expr *exp_x = new_exp(x); |
| 32 | + expr *sin_y = new_sin(y); |
| 33 | + |
| 34 | + expr *args[4] = {log_x, log_z, exp_x, sin_y}; |
| 35 | + expr *hstack_node = new_hstack(args, 4, 9); |
| 36 | + |
| 37 | + hstack_node->forward(hstack_node, u_vals); |
| 38 | + hstack_node->wsum_hess_init(hstack_node); |
| 39 | + hstack_node->eval_wsum_hess(hstack_node, w); |
| 40 | + |
| 41 | + /* Expected Hessian: |
| 42 | + * log(x): d²/dx² = -1/x² |
| 43 | + * w[0] * (-1/1²) = 1 * (-1) = -1 at (0,0) |
| 44 | + * w[1] * (-1/2²) = 2 * (-0.25) = -0.5 at (1,1) |
| 45 | + * w[2] * (-1/3²) = 3 * (-1/9) = -1/3 at (2,2) |
| 46 | + * |
| 47 | + * log(z): d²/dz² = -1/z² |
| 48 | + * w[3] * (-1/4²) = 4 * (-1/16) = -0.25 at (3,3) |
| 49 | + * w[4] * (-1/5²) = 5 * (-1/25) = -0.2 at (4,4) |
| 50 | + * w[5] * (-1/6²) = 6 * (-1/36) = -1/6 at (5,5) |
| 51 | + * |
| 52 | + * exp(x): d²/dx² = exp(x) |
| 53 | + * w[6] * exp(1) = 7 * e at (0,0) |
| 54 | + * w[7] * exp(2) = 8 * e² at (1,1) |
| 55 | + * w[8] * exp(3) = 9 * e³ at (2,2) |
| 56 | + * |
| 57 | + * sin(y): d²/dy² = -sin(y) |
| 58 | + * w[9] * (-sin(7)) at (6,6) |
| 59 | + * w[10] * (-sin(8)) at (7,7) |
| 60 | + * w[11] * (-sin(9)) at (8,8) |
| 61 | + * |
| 62 | + * Accumulated: |
| 63 | + * (0,0): -1 + 7*e |
| 64 | + * (1,1): -0.5 + 8*e² |
| 65 | + * (2,2): -1/3 + 9*e³ |
| 66 | + * (3,3): -0.25 |
| 67 | + * (4,4): -0.2 |
| 68 | + * (5,5): -1/6 |
| 69 | + * (6,6): -10*sin(7) |
| 70 | + * (7,7): -11*sin(8) |
| 71 | + * (8,8): -12*sin(9) |
| 72 | + */ |
| 73 | + |
| 74 | + double e = exp(1.0); |
| 75 | + double e2 = exp(2.0); |
| 76 | + double e3 = exp(3.0); |
| 77 | + |
| 78 | + double expected_x[9] = {-1.0 + 7.0 * e, |
| 79 | + -0.5 + 8.0 * e2, |
| 80 | + -1.0 / 3.0 + 9.0 * e3, |
| 81 | + -0.25, |
| 82 | + -0.2, |
| 83 | + -1.0 / 6.0, |
| 84 | + -10.0 * sin(7.0), |
| 85 | + -11.0 * sin(8.0), |
| 86 | + -12.0 * sin(9.0)}; |
| 87 | + |
| 88 | + int expected_p[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; |
| 89 | + int expected_i[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; |
| 90 | + |
| 91 | + mu_assert("vals incorrect", |
| 92 | + cmp_double_array(hstack_node->wsum_hess->x, expected_x, 9)); |
| 93 | + mu_assert("rows incorrect", |
| 94 | + cmp_int_array(hstack_node->wsum_hess->p, expected_p, 10)); |
| 95 | + mu_assert("cols incorrect", |
| 96 | + cmp_int_array(hstack_node->wsum_hess->i, expected_i, 9)); |
| 97 | + |
| 98 | + free_expr(hstack_node); |
| 99 | + free_expr(sin_y); |
| 100 | + free_expr(exp_x); |
| 101 | + free_expr(log_z); |
| 102 | + free_expr(log_x); |
| 103 | + free_expr(y); |
| 104 | + free_expr(z); |
| 105 | + free_expr(x); |
| 106 | + |
| 107 | + return 0; |
| 108 | +} |
| 109 | + |
| 110 | +const char *test_wsum_hess_hstack_matrix() |
| 111 | +{ |
| 112 | + /* Test: hstack([log(x), log(z), exp(x), sin(y)]) with matrix variables |
| 113 | + * Variables: x at idx 0, z at idx 6, y at idx 12 |
| 114 | + * Each is 3x2, so 6 elements per variable |
| 115 | + * x = [1 4] z = [7 10] y = [13 16] |
| 116 | + * [2 5] [8 11] [14 17] |
| 117 | + * [3 6] [9 12] [15 18] |
| 118 | + * Vectorized column-wise: x = [1,2,3,4,5,6], z = [7,8,9,10,11,12], y = |
| 119 | + * [13,14,15,16,17,18] Total 18 variables Stacked output is 24x1: [log(x), |
| 120 | + * log(z), exp(x), sin(y)] each 6x1 w = [1, 2, 3, ..., 24] |
| 121 | + */ |
| 122 | + |
| 123 | + double u_vals[18] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, |
| 124 | + 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0}; |
| 125 | + double w[24]; |
| 126 | + for (int i = 0; i < 24; i++) |
| 127 | + { |
| 128 | + w[i] = i + 1.0; |
| 129 | + } |
| 130 | + |
| 131 | + expr *x = new_variable(3, 2, 0, 18); |
| 132 | + expr *z = new_variable(3, 2, 6, 18); |
| 133 | + expr *y = new_variable(3, 2, 12, 18); |
| 134 | + |
| 135 | + expr *log_x = new_log(x); |
| 136 | + expr *log_z = new_log(z); |
| 137 | + expr *exp_x = new_exp(x); |
| 138 | + expr *sin_y = new_sin(y); |
| 139 | + |
| 140 | + expr *args[4] = {log_x, log_z, exp_x, sin_y}; |
| 141 | + expr *hstack_node = new_hstack(args, 4, 18); |
| 142 | + |
| 143 | + hstack_node->forward(hstack_node, u_vals); |
| 144 | + hstack_node->wsum_hess_init(hstack_node); |
| 145 | + hstack_node->eval_wsum_hess(hstack_node, w); |
| 146 | + |
| 147 | + /* Expected Hessian (diagonal): |
| 148 | + * log(x): w[0:5] * (-1/x[0:5]²) at indices 0-5 |
| 149 | + * log(z): w[6:11] * (-1/z[0:5]²) at indices 6-11 |
| 150 | + * exp(x): w[12:17] * exp(x[0:5]) at indices 0-5 (accumulates with log(x)) |
| 151 | + * sin(y): w[18:23] * (-sin(y[0:5])) at indices 12-17 |
| 152 | + * |
| 153 | + * For x indices (0-5): |
| 154 | + * i=0: -1/1² + 13*e¹ = -1 + 13*e |
| 155 | + * i=1: -2/2² + 14*e² = -0.5 + 14*e² |
| 156 | + * i=2: -3/3² + 15*e³ = -1/3 + 15*e³ |
| 157 | + * i=3: -4/4² + 16*e⁴ = -0.25 + 16*e⁴ |
| 158 | + * i=4: -5/5² + 17*e⁵ = -0.2 + 17*e⁵ |
| 159 | + * i=5: -6/6² + 18*e⁶ = -1/6 + 18*e⁶ |
| 160 | + * |
| 161 | + * For z indices (6-11): |
| 162 | + * i=0: -7/7² = -1/7 |
| 163 | + * i=1: -8/8² = -1/8 |
| 164 | + * i=2: -9/9² = -1/9 |
| 165 | + * i=3: -10/10² = -0.1 |
| 166 | + * i=4: -11/11² = -11/121 |
| 167 | + * i=5: -12/12² = -1/12 |
| 168 | + * |
| 169 | + * For y indices (12-17): |
| 170 | + * i=0: -19*sin(13) |
| 171 | + * i=1: -20*sin(14) |
| 172 | + * i=2: -21*sin(15) |
| 173 | + * i=3: -22*sin(16) |
| 174 | + * i=4: -23*sin(17) |
| 175 | + * i=5: -24*sin(18) |
| 176 | + */ |
| 177 | + |
| 178 | + double expected_x[18]; |
| 179 | + // x indices (0-5) - accumulation of log and exp |
| 180 | + expected_x[0] = -1.0 + 13.0 * exp(1.0); |
| 181 | + expected_x[1] = -0.5 + 14.0 * exp(2.0); |
| 182 | + expected_x[2] = -1.0 / 3.0 + 15.0 * exp(3.0); |
| 183 | + expected_x[3] = -0.25 + 16.0 * exp(4.0); |
| 184 | + expected_x[4] = -0.2 + 17.0 * exp(5.0); |
| 185 | + expected_x[5] = -1.0 / 6.0 + 18.0 * exp(6.0); |
| 186 | + |
| 187 | + // z indices (6-11) - only log |
| 188 | + expected_x[6] = -1.0 / 7.0; |
| 189 | + expected_x[7] = -1.0 / 8.0; |
| 190 | + expected_x[8] = -1.0 / 9.0; |
| 191 | + expected_x[9] = -0.1; |
| 192 | + expected_x[10] = -11.0 / 121.0; |
| 193 | + expected_x[11] = -1.0 / 12.0; |
| 194 | + |
| 195 | + // y indices (12-17) - only sin |
| 196 | + expected_x[12] = -19.0 * sin(13.0); |
| 197 | + expected_x[13] = -20.0 * sin(14.0); |
| 198 | + expected_x[14] = -21.0 * sin(15.0); |
| 199 | + expected_x[15] = -22.0 * sin(16.0); |
| 200 | + expected_x[16] = -23.0 * sin(17.0); |
| 201 | + expected_x[17] = -24.0 * sin(18.0); |
| 202 | + |
| 203 | + int expected_p[19] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, |
| 204 | + 10, 11, 12, 13, 14, 15, 16, 17, 18}; |
| 205 | + int expected_i[18] = {0, 1, 2, 3, 4, 5, 6, 7, 8, |
| 206 | + 9, 10, 11, 12, 13, 14, 15, 16, 17}; |
| 207 | + |
| 208 | + mu_assert("vals incorrect", |
| 209 | + cmp_double_array(hstack_node->wsum_hess->x, expected_x, 18)); |
| 210 | + mu_assert("rows incorrect", |
| 211 | + cmp_int_array(hstack_node->wsum_hess->p, expected_p, 19)); |
| 212 | + mu_assert("cols incorrect", |
| 213 | + cmp_int_array(hstack_node->wsum_hess->i, expected_i, 18)); |
| 214 | + |
| 215 | + free_expr(hstack_node); |
| 216 | + free_expr(sin_y); |
| 217 | + free_expr(exp_x); |
| 218 | + free_expr(log_z); |
| 219 | + free_expr(log_x); |
| 220 | + free_expr(y); |
| 221 | + free_expr(z); |
| 222 | + free_expr(x); |
| 223 | + |
| 224 | + return 0; |
| 225 | +} |
0 commit comments