@@ -400,4 +400,69 @@ const char *test_param_right_matmul_rectangular(void)
400400 return 0 ;
401401}
402402
403+ const char * test_param_shared_left_matmul_problem (void )
404+ {
405+ int n = 4 ;
406+
407+ /* minimize sum(x) subject to A@x and A@y, shared A parameter */
408+ expr * x = new_variable (2 , 1 , 0 , n );
409+ expr * y = new_variable (2 , 1 , 2 , n );
410+ expr * objective = new_sum (x , -1 );
411+ expr * A_param = new_parameter (2 , 2 , 0 , n , NULL );
412+
413+ /* dense 2x2 identity */
414+ double Ax [4 ] = {1.0 , 0.0 , 0.0 , 1.0 };
415+ expr * constraints [2 ];
416+ constraints [0 ] = new_left_matmul_dense (A_param , x , 2 , 2 , Ax );
417+ constraints [1 ] = new_left_matmul_dense (A_param , y , 2 , 2 , Ax );
418+ problem * prob = new_problem (objective , constraints , 2 , false);
419+
420+ /* register parameters and fill sparsity patterns */
421+ expr * param_nodes [1 ] = {A_param };
422+ problem_register_params (prob , param_nodes , 1 );
423+ problem_init_derivatives (prob );
424+
425+ /* point for evaluating and utilities for test */
426+ double x_vals [4 ] = {1.0 , 2.0 , 3.0 , 4.0 };
427+ int Ap [5 ] = {0 , 2 , 4 , 6 , 8 };
428+ int Ai [8 ] = {0 , 1 , 0 , 1 , 2 , 3 , 2 , 3 };
429+
430+ /* test 1: initial identity jacobian */
431+ problem_constraint_forward (prob , x_vals );
432+ double constrs [4 ] = {1.0 , 2.0 , 3.0 , 4.0 };
433+ problem_jacobian (prob );
434+ double jac_x [8 ] = {1.0 , 0.0 , 0.0 , 1.0 , 1.0 , 0.0 , 0.0 , 1.0 };
435+ mu_assert ("vals fail" , cmp_double_array (prob -> constraint_values , constrs , 4 ));
436+ mu_assert ("vals fail" , cmp_double_array (prob -> jacobian -> x , jac_x , 8 ));
437+ mu_assert ("rows fail" , cmp_int_array (prob -> jacobian -> p , Ap , 5 ));
438+ mu_assert ("cols fail" , cmp_int_array (prob -> jacobian -> i , Ai , 8 ));
439+
440+ /* test 2: A = [[1,2],[3,4]] (column-major [1,3,2,4]) */
441+ double theta [4 ] = {1.0 , 3.0 , 2.0 , 4.0 };
442+ problem_update_params (prob , theta );
443+ problem_constraint_forward (prob , x_vals );
444+ problem_jacobian (prob );
445+ constrs [0 ] = 5.0 ;
446+ constrs [1 ] = 11.0 ;
447+ constrs [2 ] = 11.0 ;
448+ constrs [3 ] = 25.0 ;
449+ jac_x [0 ] = 1.0 ;
450+ jac_x [1 ] = 2.0 ;
451+ jac_x [2 ] = 3.0 ;
452+ jac_x [3 ] = 4.0 ;
453+ jac_x [4 ] = 1.0 ;
454+ jac_x [5 ] = 2.0 ;
455+ jac_x [6 ] = 3.0 ;
456+ jac_x [7 ] = 4.0 ;
457+ mu_assert ("vals fail" , cmp_double_array (prob -> constraint_values , constrs , 4 ));
458+ mu_assert ("vals fail" , cmp_double_array (prob -> jacobian -> x , jac_x , 8 ));
459+ mu_assert ("rows fail" , cmp_int_array (prob -> jacobian -> p , Ap , 5 ));
460+ mu_assert ("cols fail" , cmp_int_array (prob -> jacobian -> i , Ai , 8 ));
461+ mu_assert ("vals fail" , cmp_double_array (prob -> constraint_values , constrs , 4 ));
462+
463+ free_problem (prob );
464+
465+ return 0 ;
466+ }
467+
403468#endif /* TEST_PARAM_PROB_H */
0 commit comments