Skip to content

Commit e49d034

Browse files
Transurgeonclaude
andcommitted
Fix upper_tri to use row-major ordering to match CVXPY
CVXPY's upper_tri returns elements row-by-row (i outer, j inner), which differs from the engine's typical column-major convention. Swap the loop nesting to iterate rows-then-columns for compatibility. Add a 4x4 forward test that distinguishes the two orderings. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9d87ea1 commit e49d034

4 files changed

Lines changed: 52 additions & 7 deletions

File tree

src/atoms/affine/upper_tri.c

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@
2222
#include <stdlib.h>
2323

2424
/* Extract strict upper triangular elements (excluding diagonal)
25-
* from a square matrix, in column-major order.
25+
* from a square matrix, in ROW-MAJOR order to match CVXPY.
2626
*
27-
* For an (n, n) matrix, element (i, j) with i < j is at flat
28-
* index j * n + i. Output has n * (n - 1) / 2 elements. */
27+
* NOTE: This is an exception to the engine's column-major
28+
* convention. CVXPY's upper_tri iterates row-by-row across
29+
* columns (i outer, j inner), so we do the same here for
30+
* compatibility.
31+
*
32+
* For an (n, n) column-major matrix, element (i, j) with
33+
* i < j is at flat index j * n + i.
34+
* Output has n * (n - 1) / 2 elements. */
2935

3036
expr *new_upper_tri(expr *child)
3137
{
@@ -38,9 +44,9 @@ expr *new_upper_tri(expr *child)
3844
{
3945
indices = (int *) malloc((size_t) n_elems * sizeof(int));
4046
int k = 0;
41-
for (int j = 0; j < n; j++)
47+
for (int i = 0; i < n; i++)
4248
{
43-
for (int i = 0; i < j; i++)
49+
for (int j = i + 1; j < n; j++)
4450
{
4551
indices[k++] = j * n + i;
4652
}

tests/all_tests.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ int main(void)
136136
mu_run_test(test_left_matmul_dense, tests_run);
137137
mu_run_test(test_diag_mat_forward, tests_run);
138138
mu_run_test(test_upper_tri_forward, tests_run);
139+
mu_run_test(test_upper_tri_forward_4x4, tests_run);
139140

140141
printf("\n--- Jacobian Tests ---\n");
141142
mu_run_test(test_neg_jacobian, tests_run);

tests/forward_pass/affine/test_upper_tri.h

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ const char *test_upper_tri_forward(void)
1313
* Matrix: 1 4 7
1414
* 2 5 8
1515
* 3 6 9
16-
* Upper tri (i < j): (0,1)=4, (0,2)=7, (1,2)=8
16+
* Upper tri in row-major order (matching CVXPY):
17+
* Row 0: (0,1)=4, (0,2)=7
18+
* Row 1: (1,2)=8
1719
* Flat indices: 3, 6, 7 */
1820
double u[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
1921
expr *var = new_variable(3, 3, 0, 9);
@@ -30,3 +32,39 @@ const char *test_upper_tri_forward(void)
3032
free_expr(ut);
3133
return 0;
3234
}
35+
36+
const char *test_upper_tri_forward_4x4(void)
37+
{
38+
/* 4x4 matrix variable (column-major): [1..16]
39+
* Matrix: 1 5 9 13
40+
* 2 6 10 14
41+
* 3 7 11 15
42+
* 4 8 12 16
43+
* Upper tri in row-major order (matching CVXPY):
44+
* Row 0: (0,1)=5, (0,2)=9, (0,3)=13
45+
* Row 1: (1,2)=10, (1,3)=14
46+
* Row 2: (2,3)=15
47+
* Flat indices: 4, 8, 12, 9, 13, 14
48+
*
49+
* NOTE: column-major order would give [4, 8, 9, 12, 13, 14]
50+
* instead. This test verifies the row-major ordering. */
51+
double u[16];
52+
for (int k = 0; k < 16; k++)
53+
{
54+
u[k] = (double) (k + 1);
55+
}
56+
expr *var = new_variable(4, 4, 0, 16);
57+
expr *ut = new_upper_tri(var);
58+
59+
mu_assert("upper_tri 4x4 d1", ut->d1 == 6);
60+
mu_assert("upper_tri 4x4 d2", ut->d2 == 1);
61+
62+
ut->forward(ut, u);
63+
64+
double expected[6] = {5.0, 9.0, 13.0, 10.0, 14.0, 15.0};
65+
mu_assert("upper_tri forward 4x4",
66+
cmp_double_array(ut->value, expected, 6));
67+
68+
free_expr(ut);
69+
return 0;
70+
}

tests/wsum_hess/affine/test_upper_tri.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const char *test_wsum_hess_upper_tri_log(void)
1515
{
1616
/* upper_tri(log(X)) where X is 3x3, w = [1, 1, 1]
1717
* X (column-major): [1, 2, 3, 4, 5, 6, 7, 8, 9]
18-
* Upper tri flat indices: [3, 6, 7]
18+
* Upper tri flat indices (row-major): [3, 6, 7]
1919
* Hessian of log is diag(-1/x^2)
2020
* Weights scatter: parent_w[3]=1, parent_w[6]=1, parent_w[7]=1
2121
* All other parent_w entries = 0

0 commit comments

Comments
 (0)