Skip to content

Commit 19b1b0a

Browse files
authored
Numerical derivative checker (#56)
* wsum checker * formatter
1 parent 5139dce commit 19b1b0a

4 files changed

Lines changed: 131 additions & 0 deletions

File tree

tests/all_tests.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ int main(void)
300300

301301
printf("\n--- Numerical Diff Tests ---\n");
302302
mu_run_test(test_check_jacobian_composite_log, tests_run);
303+
mu_run_test(test_check_wsum_hess_log_composite, tests_run);
303304

304305
printf("\n--- Problem Struct Tests ---\n");
305306
mu_run_test(test_problem_new_free, tests_run);

tests/numerical_diff.c

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,98 @@ int check_jacobian(expr *node, const double *u, double h)
9797
free(J_num);
9898
return result;
9999
}
100+
101+
/* Compute g = J^T w where J is CSR (m x n) and w has m entries.
102+
* Result written into g (size n), which must be zero-initialized. */
103+
static void csr_transpose_mult_vec(const CSR_Matrix *J, const double *w, double *g)
104+
{
105+
for (int row = 0; row < J->m; row++)
106+
{
107+
for (int idx = J->p[row]; idx < J->p[row + 1]; idx++)
108+
{
109+
g[J->i[idx]] += J->x[idx] * w[row];
110+
}
111+
}
112+
}
113+
114+
double *numerical_wsum_hess(expr *node, const double *u, const double *w, double h)
115+
{
116+
int n = node->n_vars;
117+
double inv_2h = 1.0 / (2.0 * h);
118+
119+
/* Initialize jacobian sparsity once, then forward */
120+
node->jacobian_init(node);
121+
node->forward(node, u);
122+
123+
double *H = calloc((size_t) n * n, sizeof(double));
124+
double *u_work = malloc(n * sizeof(double));
125+
double *g_plus = malloc(n * sizeof(double));
126+
double *g_minus = malloc(n * sizeof(double));
127+
128+
memcpy(u_work, u, n * sizeof(double));
129+
130+
for (int j = 0; j < n; j++)
131+
{
132+
/* g(u + h*e_j) */
133+
u_work[j] = u[j] + h;
134+
node->forward(node, u_work);
135+
node->eval_jacobian(node);
136+
memset(g_plus, 0, n * sizeof(double));
137+
csr_transpose_mult_vec(node->jacobian, w, g_plus);
138+
139+
/* g(u - h*e_j) */
140+
u_work[j] = u[j] - h;
141+
node->forward(node, u_work);
142+
node->eval_jacobian(node);
143+
memset(g_minus, 0, n * sizeof(double));
144+
csr_transpose_mult_vec(node->jacobian, w, g_minus);
145+
146+
u_work[j] = u[j];
147+
148+
for (int i = 0; i < n; i++)
149+
{
150+
H[i * n + j] = (g_plus[i] - g_minus[i]) * inv_2h;
151+
}
152+
}
153+
154+
free(g_minus);
155+
free(g_plus);
156+
free(u_work);
157+
return H;
158+
}
159+
160+
int check_wsum_hess(expr *node, const double *u, const double *w, double h)
161+
{
162+
int n = node->n_vars;
163+
164+
/* Compute numerical first (does its own jacobian_init) */
165+
double *H_num = numerical_wsum_hess(node, u, w, h);
166+
167+
/* Now compute analytical (reuses jacobian from numerical) */
168+
node->wsum_hess_init(node);
169+
node->forward(node, u);
170+
node->eval_jacobian(node);
171+
node->eval_wsum_hess(node, w);
172+
173+
double *H_ana = calloc((size_t) n * n, sizeof(double));
174+
csr_to_dense(node->wsum_hess, H_ana);
175+
176+
int result = 1;
177+
for (int i = 0; i < n * n; i++)
178+
{
179+
if (!is_close(H_ana[i], H_num[i]))
180+
{
181+
int row = i / n;
182+
int col = i % n;
183+
printf(" check_wsum_hess FAILED at (%d, %d):"
184+
" analytical=%e, numerical=%e\n",
185+
row, col, H_ana[i], H_num[i]);
186+
result = 0;
187+
break;
188+
}
189+
}
190+
191+
free(H_ana);
192+
free(H_num);
193+
return result;
194+
}

tests/numerical_diff.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,15 @@ double *numerical_jacobian(expr *node, const double *u, double h);
1515
* Prints diagnostic on first failing entry. */
1616
int check_jacobian(expr *node, const double *u, double h);
1717

18+
/* Compute dense numerical weighted-sum Hessian via central
19+
* differences on the gradient g(u) = J(u)^T w.
20+
* Returns malloc'd row-major array (n_vars x n_vars).
21+
* Caller must free(). */
22+
double *numerical_wsum_hess(expr *node, const double *u, const double *w, double h);
23+
24+
/* Evaluate analytical wsum_hess, compute numerical wsum_hess,
25+
* and compare. Returns 1 on match, 0 on mismatch.
26+
* Prints diagnostic on first failing entry. */
27+
int check_wsum_hess(expr *node, const double *u, const double *w, double h);
28+
1829
#endif /* NUMERICAL_DIFF_H */

tests/numerical_diff/test_numerical_diff.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,27 @@ const char *test_check_jacobian_composite_log(void)
2828
free_csr_matrix(A);
2929
return 0;
3030
}
31+
32+
const char *test_check_wsum_hess_log_composite(void)
33+
{
34+
double u_vals[5] = {1, 2, 3, 4, 5};
35+
double w[3] = {-1, -2, -3};
36+
double Ax[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
37+
int Ai[] = {0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4};
38+
int Ap[] = {0, 5, 10, 15};
39+
CSR_Matrix *A_csr = new_csr_matrix(3, 5, 15);
40+
memcpy(A_csr->x, Ax, 15 * sizeof(double));
41+
memcpy(A_csr->i, Ai, 15 * sizeof(int));
42+
memcpy(A_csr->p, Ap, 4 * sizeof(int));
43+
44+
expr *x = new_variable(5, 1, 0, 5);
45+
expr *Ax_node = new_linear(x, A_csr, NULL);
46+
expr *log_node = new_log(Ax_node);
47+
48+
mu_assert("check_wsum_hess failed",
49+
check_wsum_hess(log_node, u_vals, w, NUMERICAL_DIFF_DEFAULT_H));
50+
51+
free_expr(log_node);
52+
free_csr_matrix(A_csr);
53+
return 0;
54+
}

0 commit comments

Comments
 (0)