Skip to content

Commit 978f319

Browse files
Transurgeonclaude
andcommitted
Add problem-level tests for parameter support
Exercises param_scalar_mult, param_vector_mult, and left_param_matmul with problem_register_params/problem_update_params to verify objective, gradient, constraint, and Jacobian values update correctly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dc56c90 commit 978f319

2 files changed

Lines changed: 249 additions & 0 deletions

File tree

tests/all_tests.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "jacobian_tests/test_sum.h"
4242
#include "jacobian_tests/test_trace.h"
4343
#include "jacobian_tests/test_transpose.h"
44+
#include "problem/test_param_prob.h"
4445
#include "problem/test_problem.h"
4546
#include "utils/test_csc_matrix.h"
4647
#include "utils/test_csr_matrix.h"
@@ -257,6 +258,9 @@ int main(void)
257258
mu_run_test(test_problem_jacobian_multi, tests_run);
258259
mu_run_test(test_problem_constraint_forward, tests_run);
259260
mu_run_test(test_problem_hessian, tests_run);
261+
mu_run_test(test_param_scalar_mult_problem, tests_run);
262+
mu_run_test(test_param_vector_mult_problem, tests_run);
263+
mu_run_test(test_param_left_matmul_problem, tests_run);
260264

261265
printf("\n=== All %d tests passed ===\n", tests_run);
262266

tests/problem/test_param_prob.h

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#ifndef TEST_PARAM_PROB_H
2+
#define TEST_PARAM_PROB_H
3+
4+
#include <math.h>
5+
#include <stdio.h>
6+
7+
#include "affine.h"
8+
#include "bivariate.h"
9+
#include "elementwise_univariate.h"
10+
#include "expr.h"
11+
#include "minunit.h"
12+
#include "problem.h"
13+
#include "test_helpers.h"
14+
15+
/*
16+
* Test 1: param_scalar_mult in objective
17+
*
18+
* Problem: minimize a * sum(log(x)), no constraints, x size 2
19+
* a is a scalar parameter (param_id=0)
20+
*
21+
* At x=[1,2], a=3:
22+
* obj = 3*(log(1)+log(2)) = 3*log(2)
23+
* gradient = [3/1, 3/2] = [3.0, 1.5]
24+
*
25+
* After update a=5:
26+
* obj = 5*log(2)
27+
* gradient = [5.0, 2.5]
28+
*/
29+
const char *test_param_scalar_mult_problem(void)
30+
{
31+
int n_vars = 2;
32+
33+
/* Build tree: sum(a * log(x)) */
34+
expr *x = new_variable(2, 1, 0, n_vars);
35+
expr *log_x = new_log(x);
36+
expr *a_param = new_parameter(1, 1, 0, n_vars);
37+
expr *scaled = new_param_scalar_mult(a_param, log_x);
38+
expr *objective = new_sum(scaled, -1);
39+
40+
/* Create problem (no constraints) */
41+
problem *prob = new_problem(objective, NULL, 0, true);
42+
43+
/* Register parameter */
44+
expr *param_nodes[1] = {a_param};
45+
problem_register_params(prob, param_nodes, 1);
46+
problem_init_derivatives(prob);
47+
48+
/* Set a=3 and evaluate at x=[1,2] */
49+
double theta[1] = {3.0};
50+
problem_update_params(prob, theta);
51+
52+
double u[2] = {1.0, 2.0};
53+
double obj_val = problem_objective_forward(prob, u);
54+
problem_gradient(prob);
55+
56+
double expected_obj = 3.0 * log(2.0);
57+
mu_assert("obj wrong (a=3)", fabs(obj_val - expected_obj) < 1e-10);
58+
59+
double expected_grad[2] = {3.0, 1.5};
60+
mu_assert("gradient wrong (a=3)",
61+
cmp_double_array(prob->gradient_values, expected_grad, 2));
62+
63+
/* Update a=5 and re-evaluate */
64+
theta[0] = 5.0;
65+
problem_update_params(prob, theta);
66+
67+
obj_val = problem_objective_forward(prob, u);
68+
problem_gradient(prob);
69+
70+
expected_obj = 5.0 * log(2.0);
71+
mu_assert("obj wrong (a=5)", fabs(obj_val - expected_obj) < 1e-10);
72+
73+
double expected_grad2[2] = {5.0, 2.5};
74+
mu_assert("gradient wrong (a=5)",
75+
cmp_double_array(prob->gradient_values, expected_grad2, 2));
76+
77+
free_problem(prob);
78+
79+
return 0;
80+
}
81+
82+
/*
83+
* Test 2: param_vector_mult in constraint
84+
*
85+
* Problem: minimize sum(x), subject to p ∘ x, x size 2
86+
* p is a vector parameter of size 2 (param_id=0)
87+
*
88+
* At x=[1,2], p=[3,4]:
89+
* constraint_values = [3, 8]
90+
* jacobian = diag([3, 4])
91+
*
92+
* After update p=[5,6]:
93+
* constraint_values = [5, 12]
94+
* jacobian = diag([5, 6])
95+
*/
96+
const char *test_param_vector_mult_problem(void)
97+
{
98+
int n_vars = 2;
99+
100+
/* Objective: sum(x) */
101+
expr *x_obj = new_variable(2, 1, 0, n_vars);
102+
expr *objective = new_sum(x_obj, -1);
103+
104+
/* Constraint: p ∘ x */
105+
expr *x_con = new_variable(2, 1, 0, n_vars);
106+
expr *p_param = new_parameter(2, 1, 0, n_vars);
107+
expr *constraint = new_param_vector_mult(p_param, x_con);
108+
109+
expr *constraints[1] = {constraint};
110+
111+
/* Create problem */
112+
problem *prob = new_problem(objective, constraints, 1, true);
113+
114+
expr *param_nodes[1] = {p_param};
115+
problem_register_params(prob, param_nodes, 1);
116+
problem_init_derivatives(prob);
117+
118+
/* Set p=[3,4] and evaluate at x=[1,2] */
119+
double theta[2] = {3.0, 4.0};
120+
problem_update_params(prob, theta);
121+
122+
double u[2] = {1.0, 2.0};
123+
problem_constraint_forward(prob, u);
124+
problem_jacobian(prob);
125+
126+
double expected_cv[2] = {3.0, 8.0};
127+
mu_assert("constraint values wrong (p=[3,4])",
128+
cmp_double_array(prob->constraint_values, expected_cv, 2));
129+
130+
CSR_Matrix *jac = prob->jacobian;
131+
mu_assert("jac rows wrong", jac->m == 2);
132+
mu_assert("jac cols wrong", jac->n == 2);
133+
134+
int expected_p[3] = {0, 1, 2};
135+
mu_assert("jac->p wrong (p=[3,4])", cmp_int_array(jac->p, expected_p, 3));
136+
137+
int expected_i[2] = {0, 1};
138+
mu_assert("jac->i wrong (p=[3,4])", cmp_int_array(jac->i, expected_i, 2));
139+
140+
double expected_x[2] = {3.0, 4.0};
141+
mu_assert("jac->x wrong (p=[3,4])", cmp_double_array(jac->x, expected_x, 2));
142+
143+
/* Update p=[5,6] and re-evaluate */
144+
double theta2[2] = {5.0, 6.0};
145+
problem_update_params(prob, theta2);
146+
147+
problem_constraint_forward(prob, u);
148+
problem_jacobian(prob);
149+
150+
double expected_cv2[2] = {5.0, 12.0};
151+
mu_assert("constraint values wrong (p=[5,6])",
152+
cmp_double_array(prob->constraint_values, expected_cv2, 2));
153+
154+
double expected_x2[2] = {5.0, 6.0};
155+
mu_assert("jac->x wrong (p=[5,6])", cmp_double_array(jac->x, expected_x2, 2));
156+
157+
free_problem(prob);
158+
159+
return 0;
160+
}
161+
162+
/*
163+
* Test 3: left_param_matmul in constraint
164+
*
165+
* Problem: minimize sum(x), subject to A @ x, x size 2, A is 2x2
166+
* A is a 2x2 matrix parameter (param_id=0, size=4, column-major)
167+
* A = [[1,2],[3,4]] → column-major theta = [1,3,2,4]
168+
*
169+
* At x=[1,2]:
170+
* constraint_values = [1*1+2*2, 3*1+4*2] = [5, 11]
171+
* jacobian = [[1,2],[3,4]]
172+
*
173+
* After update A = [[5,6],[7,8]] → theta = [5,7,6,8]:
174+
* constraint_values = [5*1+6*2, 7*1+8*2] = [17, 23]
175+
* jacobian = [[5,6],[7,8]]
176+
*/
177+
const char *test_param_left_matmul_problem(void)
178+
{
179+
int n_vars = 2;
180+
181+
/* Objective: sum(x) */
182+
expr *x_obj = new_variable(2, 1, 0, n_vars);
183+
expr *objective = new_sum(x_obj, -1);
184+
185+
/* Constraint: A @ x */
186+
expr *x_con = new_variable(2, 1, 0, n_vars);
187+
expr *A_param = new_parameter(2, 2, 0, n_vars);
188+
expr *constraint = new_left_param_matmul(A_param, x_con);
189+
190+
expr *constraints[1] = {constraint};
191+
192+
/* Create problem */
193+
problem *prob = new_problem(objective, constraints, 1, true);
194+
195+
expr *param_nodes[1] = {A_param};
196+
problem_register_params(prob, param_nodes, 1);
197+
problem_init_derivatives(prob);
198+
199+
/* Set A = [[1,2],[3,4]], column-major: [1,3,2,4] */
200+
double theta[4] = {1.0, 3.0, 2.0, 4.0};
201+
problem_update_params(prob, theta);
202+
203+
double u[2] = {1.0, 2.0};
204+
problem_constraint_forward(prob, u);
205+
problem_jacobian(prob);
206+
207+
double expected_cv[2] = {5.0, 11.0};
208+
mu_assert("constraint values wrong (A1)",
209+
cmp_double_array(prob->constraint_values, expected_cv, 2));
210+
211+
CSR_Matrix *jac = prob->jacobian;
212+
mu_assert("jac rows wrong", jac->m == 2);
213+
mu_assert("jac cols wrong", jac->n == 2);
214+
215+
/* Dense jacobian = [[1,2],[3,4]], CSR: row 0 → cols 0,1 vals 1,2;
216+
* row 1 → cols 0,1 vals 3,4 */
217+
int expected_p[3] = {0, 2, 4};
218+
mu_assert("jac->p wrong (A1)", cmp_int_array(jac->p, expected_p, 3));
219+
220+
int expected_i[4] = {0, 1, 0, 1};
221+
mu_assert("jac->i wrong (A1)", cmp_int_array(jac->i, expected_i, 4));
222+
223+
double expected_x[4] = {1.0, 2.0, 3.0, 4.0};
224+
mu_assert("jac->x wrong (A1)", cmp_double_array(jac->x, expected_x, 4));
225+
226+
/* Update A = [[5,6],[7,8]], column-major: [5,7,6,8] */
227+
double theta2[4] = {5.0, 7.0, 6.0, 8.0};
228+
problem_update_params(prob, theta2);
229+
230+
problem_constraint_forward(prob, u);
231+
problem_jacobian(prob);
232+
233+
double expected_cv2[2] = {17.0, 23.0};
234+
mu_assert("constraint values wrong (A2)",
235+
cmp_double_array(prob->constraint_values, expected_cv2, 2));
236+
237+
double expected_x2[4] = {5.0, 6.0, 7.0, 8.0};
238+
mu_assert("jac->x wrong (A2)", cmp_double_array(jac->x, expected_x2, 4));
239+
240+
free_problem(prob);
241+
242+
return 0;
243+
}
244+
245+
#endif /* TEST_PARAM_PROB_H */

0 commit comments

Comments
 (0)