@@ -254,4 +254,98 @@ const char *test_param_left_matmul_problem(void)
254254 return 0 ;
255255}
256256
257+ /*
258+ * Test 4: right_param_matmul in constraint
259+ *
260+ * Problem: minimize sum(x), subject to x @ A, x size 1x2, A is 2x2
261+ * A is a 2x2 matrix parameter (param_id=0, size=4, CSR data order)
262+ * A = [[1,2],[3,4]] → CSR data order theta = [1,2,3,4]
263+ *
264+ * At x=[1,2]:
265+ * constraint_values = [1*1+2*3, 1*2+2*4] = [7, 10]
266+ * jacobian = [[1,3],[2,4]] = A^T
267+ *
268+ * After update A = [[5,6],[7,8]] → theta = [5,6,7,8]:
269+ * constraint_values = [1*5+2*7, 1*6+2*8] = [19, 22]
270+ * jacobian = [[5,7],[6,8]] = A^T
271+ */
272+ const char * test_param_right_matmul_problem (void )
273+ {
274+ int n_vars = 2 ;
275+
276+ /* Objective: sum(x) */
277+ expr * x_obj = new_variable (1 , 2 , 0 , n_vars );
278+ expr * objective = new_sum (x_obj , -1 );
279+
280+ /* Constraint: x @ A */
281+ expr * x_con = new_variable (1 , 2 , 0 , n_vars );
282+ expr * A_param = new_parameter (2 , 2 , 0 , n_vars , NULL );
283+
284+ /* Dense 2x2 CSR with placeholder zeros (values refreshed from A_param) */
285+ CSR_Matrix * A = new_csr_matrix (2 , 2 , 4 );
286+ int Ap [3 ] = {0 , 2 , 4 };
287+ int Ai [4 ] = {0 , 1 , 0 , 1 };
288+ double Ax [4 ] = {0.0 , 0.0 , 0.0 , 0.0 };
289+ memcpy (A -> p , Ap , 3 * sizeof (int ));
290+ memcpy (A -> i , Ai , 4 * sizeof (int ));
291+ memcpy (A -> x , Ax , 4 * sizeof (double ));
292+
293+ expr * constraint = new_right_matmul (A_param , x_con , A );
294+ free_csr_matrix (A );
295+
296+ expr * constraints [1 ] = {constraint };
297+
298+ /* Create problem */
299+ problem * prob = new_problem (objective , constraints , 1 , true);
300+
301+ expr * param_nodes [1 ] = {A_param };
302+ problem_register_params (prob , param_nodes , 1 );
303+ problem_init_derivatives (prob );
304+
305+ /* Set A = [[1,2],[3,4]], CSR data order: [1,2,3,4] */
306+ double theta [4 ] = {1.0 , 2.0 , 3.0 , 4.0 };
307+ problem_update_params (prob , theta );
308+
309+ double u [2 ] = {1.0 , 2.0 };
310+ problem_constraint_forward (prob , u );
311+ problem_jacobian (prob );
312+
313+ double expected_cv [2 ] = {7.0 , 10.0 };
314+ mu_assert ("constraint values wrong (A1)" ,
315+ cmp_double_array (prob -> constraint_values , expected_cv , 2 ));
316+
317+ CSR_Matrix * jac = prob -> jacobian ;
318+ mu_assert ("jac rows wrong" , jac -> m == 2 );
319+ mu_assert ("jac cols wrong" , jac -> n == 2 );
320+
321+ /* Dense jacobian = [[1,3],[2,4]] = A^T, CSR: row 0 → cols 0,1 vals 1,3;
322+ * row 1 → cols 0,1 vals 2,4 */
323+ int expected_p [3 ] = {0 , 2 , 4 };
324+ mu_assert ("jac->p wrong (A1)" , cmp_int_array (jac -> p , expected_p , 3 ));
325+
326+ int expected_i [4 ] = {0 , 1 , 0 , 1 };
327+ mu_assert ("jac->i wrong (A1)" , cmp_int_array (jac -> i , expected_i , 4 ));
328+
329+ double expected_x [4 ] = {1.0 , 3.0 , 2.0 , 4.0 };
330+ mu_assert ("jac->x wrong (A1)" , cmp_double_array (jac -> x , expected_x , 4 ));
331+
332+ /* Update A = [[5,6],[7,8]], CSR data order: [5,6,7,8] */
333+ double theta2 [4 ] = {5.0 , 6.0 , 7.0 , 8.0 };
334+ problem_update_params (prob , theta2 );
335+
336+ problem_constraint_forward (prob , u );
337+ problem_jacobian (prob );
338+
339+ double expected_cv2 [2 ] = {19.0 , 22.0 };
340+ mu_assert ("constraint values wrong (A2)" ,
341+ cmp_double_array (prob -> constraint_values , expected_cv2 , 2 ));
342+
343+ double expected_x2 [4 ] = {5.0 , 7.0 , 6.0 , 8.0 };
344+ mu_assert ("jac->x wrong (A2)" , cmp_double_array (jac -> x , expected_x2 , 4 ));
345+
346+ free_problem (prob );
347+
348+ return 0 ;
349+ }
350+
257351#endif /* TEST_PARAM_PROB_H */
0 commit comments