Skip to content

Commit a5dfe1b

Browse files
Transurgeon and claude
committed
Add upper_tri and diag_mat affine atoms
Both are thin wrappers that compute flat column-major index arrays and delegate to new_index. diag_mat extracts the diagonal of a square matrix into a vector. upper_tri extracts strict upper triangular elements (excluding diagonal). Also removes a duplicate new_diag_vec declaration in affine.h. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d3a5747 commit a5dfe1b

10 files changed

Lines changed: 407 additions & 1 deletion

File tree

include/affine.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
3939
expr *new_reshape(expr *child, int d1, int d2);
4040
expr *new_broadcast(expr *child, int target_d1, int target_d2);
4141
expr *new_diag_vec(expr *child);
42+
expr *new_diag_mat(expr *child);
43+
expr *new_upper_tri(expr *child);
4244
expr *new_transpose(expr *child);
43-
expr *new_diag_vec(expr *child);
4445

4546
#endif /* AFFINE_H */

src/affine/diag_mat.c

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright 2026 Daniel Cederberg and William Zhang
3+
*
4+
* This file is part of the DNLP-differentiation-engine project.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
// SPDX-License-Identifier: Apache-2.0
19+
20+
#include "affine.h"
21+
#include <assert.h>
22+
#include <stdlib.h>
23+
24+
/* Extract diagonal from a square matrix into a column vector.
25+
* For an (n, n) matrix in column-major order, diagonal element i
26+
* is at flat index i * (n + 1). */
27+
28+
expr *new_diag_mat(expr *child)
29+
{
30+
assert(child->d1 == child->d2);
31+
int n = child->d1;
32+
33+
int *indices = (int *) malloc((size_t) n * sizeof(int));
34+
for (int i = 0; i < n; i++)
35+
{
36+
indices[i] = i * (n + 1);
37+
}
38+
39+
expr *node = new_index(child, n, 1, indices, n);
40+
free(indices);
41+
return node;
42+
}

src/affine/upper_tri.c

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2026 Daniel Cederberg and William Zhang
3+
*
4+
* This file is part of the DNLP-differentiation-engine project.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
// SPDX-License-Identifier: Apache-2.0
19+
20+
#include "affine.h"
21+
#include <assert.h>
22+
#include <stdlib.h>
23+
24+
/* Extract strict upper triangular elements (excluding diagonal)
25+
* from a square matrix, in column-major order.
26+
*
27+
* For an (n, n) matrix, element (i, j) with i < j is at flat
28+
* index j * n + i. Output has n * (n - 1) / 2 elements. */
29+
30+
expr *new_upper_tri(expr *child)
31+
{
32+
assert(child->d1 == child->d2);
33+
int n = child->d1;
34+
int n_elems = n * (n - 1) / 2;
35+
36+
int *indices = NULL;
37+
if (n_elems > 0)
38+
{
39+
indices = (int *) malloc((size_t) n_elems * sizeof(int));
40+
int k = 0;
41+
for (int j = 0; j < n; j++)
42+
{
43+
for (int i = 0; i < j; i++)
44+
{
45+
indices[k++] = j * n + i;
46+
}
47+
}
48+
assert(k == n_elems);
49+
}
50+
51+
expr *node = new_index(child, n_elems, 1, indices, n_elems);
52+
free(indices);
53+
return node;
54+
}

tests/all_tests.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
#ifndef PROFILE_ONLY
77
#include "forward_pass/affine/test_add.h"
88
#include "forward_pass/affine/test_broadcast.h"
9+
#include "forward_pass/affine/test_diag_mat.h"
910
#include "forward_pass/affine/test_hstack.h"
1011
#include "forward_pass/affine/test_linear_op.h"
1112
#include "forward_pass/affine/test_neg.h"
1213
#include "forward_pass/affine/test_promote.h"
1314
#include "forward_pass/affine/test_sum.h"
15+
#include "forward_pass/affine/test_upper_tri.h"
1416
#include "forward_pass/affine/test_variable_constant.h"
1517
#include "forward_pass/composite/test_composite.h"
1618
#include "forward_pass/elementwise/test_exp.h"
@@ -24,6 +26,7 @@
2426
#include "jacobian_tests/test_composite.h"
2527
#include "jacobian_tests/test_const_scalar_mult.h"
2628
#include "jacobian_tests/test_const_vector_mult.h"
29+
#include "jacobian_tests/test_diag_mat.h"
2730
#include "jacobian_tests/test_elementwise_mult.h"
2831
#include "jacobian_tests/test_hstack.h"
2932
#include "jacobian_tests/test_index.h"
@@ -44,6 +47,7 @@
4447
#include "jacobian_tests/test_sum.h"
4548
#include "jacobian_tests/test_trace.h"
4649
#include "jacobian_tests/test_transpose.h"
50+
#include "jacobian_tests/test_upper_tri.h"
4751
#include "problem/test_problem.h"
4852
#include "utils/test_cblas.h"
4953
#include "utils/test_coo_matrix.h"
@@ -63,6 +67,7 @@
6367
#include "wsum_hess/test_broadcast.h"
6468
#include "wsum_hess/test_const_scalar_mult.h"
6569
#include "wsum_hess/test_const_vector_mult.h"
70+
#include "wsum_hess/test_diag_mat.h"
6671
#include "wsum_hess/test_hstack.h"
6772
#include "wsum_hess/test_index.h"
6873
#include "wsum_hess/test_left_matmul.h"
@@ -80,6 +85,7 @@
8085
#include "wsum_hess/test_sum.h"
8186
#include "wsum_hess/test_trace.h"
8287
#include "wsum_hess/test_transpose.h"
88+
#include "wsum_hess/test_upper_tri.h"
8389
#endif /* PROFILE_ONLY */
8490

8591
#ifdef PROFILE_ONLY
@@ -116,6 +122,8 @@ int main(void)
116122
mu_run_test(test_forward_prod_axis_one, tests_run);
117123
mu_run_test(test_matmul, tests_run);
118124
mu_run_test(test_left_matmul_dense, tests_run);
125+
mu_run_test(test_diag_mat_forward, tests_run);
126+
mu_run_test(test_upper_tri_forward, tests_run);
119127

120128
printf("\n--- Jacobian Tests ---\n");
121129
mu_run_test(test_neg_jacobian, tests_run);
@@ -181,6 +189,10 @@ int main(void)
181189
mu_run_test(test_jacobian_right_matmul_log_vector, tests_run);
182190
mu_run_test(test_jacobian_matmul, tests_run);
183191
mu_run_test(test_jacobian_transpose, tests_run);
192+
mu_run_test(test_diag_mat_jacobian_variable, tests_run);
193+
mu_run_test(test_diag_mat_jacobian_of_log, tests_run);
194+
mu_run_test(test_upper_tri_jacobian_variable, tests_run);
195+
mu_run_test(test_upper_tri_jacobian_of_log, tests_run);
184196

185197
printf("\n--- Weighted Sum of Hessian Tests ---\n");
186198
mu_run_test(test_wsum_hess_log, tests_run);
@@ -246,6 +258,8 @@ int main(void)
246258
mu_run_test(test_wsum_hess_trace_log_variable, tests_run);
247259
mu_run_test(test_wsum_hess_trace_composite, tests_run);
248260
mu_run_test(test_wsum_hess_transpose, tests_run);
261+
mu_run_test(test_wsum_hess_diag_mat_log, tests_run);
262+
mu_run_test(test_wsum_hess_upper_tri_log, tests_run);
249263

250264
printf("\n--- Utility Tests ---\n");
251265
mu_run_test(test_cblas_ddot, tests_run);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include <stdio.h>
4+
5+
#include "affine.h"
6+
#include "expr.h"
7+
#include "minunit.h"
8+
#include "test_helpers.h"
9+
10+
const char *test_diag_mat_forward(void)
11+
{
12+
/* 3x3 matrix variable (column-major): [1,2,3,4,5,6,7,8,9]
13+
* Matrix: 1 4 7
14+
* 2 5 8
15+
* 3 6 9
16+
* Diagonal: (0,0)=1, (1,1)=5, (2,2)=9 */
17+
double u[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
18+
expr *var = new_variable(3, 3, 0, 9);
19+
expr *dm = new_diag_mat(var);
20+
21+
mu_assert("diag_mat d1", dm->d1 == 3);
22+
mu_assert("diag_mat d2", dm->d2 == 1);
23+
24+
dm->forward(dm, u);
25+
26+
double expected[3] = {1.0, 5.0, 9.0};
27+
mu_assert("diag_mat forward", cmp_double_array(dm->value, expected, 3));
28+
29+
free_expr(dm);
30+
return 0;
31+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include <stdio.h>
4+
5+
#include "affine.h"
6+
#include "expr.h"
7+
#include "minunit.h"
8+
#include "test_helpers.h"
9+
10+
const char *test_upper_tri_forward(void)
11+
{
12+
/* 3x3 matrix variable (column-major): [1,2,3,4,5,6,7,8,9]
13+
* Matrix: 1 4 7
14+
* 2 5 8
15+
* 3 6 9
16+
* Upper tri (i < j): (0,1)=4, (0,2)=7, (1,2)=8
17+
* Flat indices: 3, 6, 7 */
18+
double u[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
19+
expr *var = new_variable(3, 3, 0, 9);
20+
expr *ut = new_upper_tri(var);
21+
22+
mu_assert("upper_tri d1", ut->d1 == 3);
23+
mu_assert("upper_tri d2", ut->d2 == 1);
24+
25+
ut->forward(ut, u);
26+
27+
double expected[3] = {4.0, 7.0, 8.0};
28+
mu_assert("upper_tri forward", cmp_double_array(ut->value, expected, 3));
29+
30+
free_expr(ut);
31+
return 0;
32+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include <stdio.h>
4+
5+
#include "affine.h"
6+
#include "elementwise_univariate.h"
7+
#include "expr.h"
8+
#include "minunit.h"
9+
#include "test_helpers.h"
10+
11+
const char *test_diag_mat_jacobian_variable(void)
12+
{
13+
/* diag_mat of a 2x2 variable (4 vars total)
14+
* Diagonal indices in column-major: [0, 3]
15+
* Jacobian is 2x4 CSR: row 0 has col 0, row 1 has col 3 */
16+
double u[4] = {1.0, 2.0, 3.0, 4.0};
17+
expr *var = new_variable(2, 2, 0, 4);
18+
expr *dm = new_diag_mat(var);
19+
20+
dm->forward(dm, u);
21+
dm->jacobian_init(dm);
22+
dm->eval_jacobian(dm);
23+
24+
double expected_x[2] = {1.0, 1.0};
25+
int expected_p[3] = {0, 1, 2};
26+
int expected_i[2] = {0, 3};
27+
28+
mu_assert("diag_mat jac vals", cmp_double_array(dm->jacobian->x, expected_x, 2));
29+
mu_assert("diag_mat jac p", cmp_int_array(dm->jacobian->p, expected_p, 3));
30+
mu_assert("diag_mat jac i", cmp_int_array(dm->jacobian->i, expected_i, 2));
31+
32+
free_expr(dm);
33+
return 0;
34+
}
35+
36+
const char *test_diag_mat_jacobian_of_log(void)
37+
{
38+
/* diag_mat(log(X)) where X is 2x2 variable
39+
* X = [[1, 3], [2, 4]] (column-major: [1, 2, 3, 4])
40+
* Diagonal: x[0]=1, x[3]=4
41+
* d/dx log at diagonal positions:
42+
* Row 0: 1/1 = 1.0 at col 0
43+
* Row 1: 1/4 = 0.25 at col 3 */
44+
double u[4] = {1.0, 2.0, 3.0, 4.0};
45+
expr *var = new_variable(2, 2, 0, 4);
46+
expr *log_node = new_log(var);
47+
expr *dm = new_diag_mat(log_node);
48+
49+
dm->forward(dm, u);
50+
dm->jacobian_init(dm);
51+
dm->eval_jacobian(dm);
52+
53+
double expected_x[2] = {1.0, 0.25};
54+
int expected_i[2] = {0, 3};
55+
56+
mu_assert("diag_mat log jac vals",
57+
cmp_double_array(dm->jacobian->x, expected_x, 2));
58+
mu_assert("diag_mat log jac cols",
59+
cmp_int_array(dm->jacobian->i, expected_i, 2));
60+
61+
free_expr(dm);
62+
return 0;
63+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include <stdio.h>
4+
5+
#include "affine.h"
6+
#include "elementwise_univariate.h"
7+
#include "expr.h"
8+
#include "minunit.h"
9+
#include "test_helpers.h"
10+
11+
const char *test_upper_tri_jacobian_variable(void)
12+
{
13+
/* upper_tri of a 3x3 variable (9 vars total)
14+
* Upper tri flat indices: [3, 6, 7]
15+
* Jacobian is 3x9 CSR: row k has col indices[k] */
16+
double u[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
17+
expr *var = new_variable(3, 3, 0, 9);
18+
expr *ut = new_upper_tri(var);
19+
20+
ut->forward(ut, u);
21+
ut->jacobian_init(ut);
22+
ut->eval_jacobian(ut);
23+
24+
double expected_x[3] = {1.0, 1.0, 1.0};
25+
int expected_p[4] = {0, 1, 2, 3};
26+
int expected_i[3] = {3, 6, 7};
27+
28+
mu_assert("upper_tri jac vals",
29+
cmp_double_array(ut->jacobian->x, expected_x, 3));
30+
mu_assert("upper_tri jac p", cmp_int_array(ut->jacobian->p, expected_p, 4));
31+
mu_assert("upper_tri jac i", cmp_int_array(ut->jacobian->i, expected_i, 3));
32+
33+
free_expr(ut);
34+
return 0;
35+
}
36+
37+
const char *test_upper_tri_jacobian_of_log(void)
38+
{
39+
/* upper_tri(log(X)) where X is 3x3 variable
40+
* Upper tri flat indices: [3, 6, 7]
41+
* X values at those positions: x[3]=4, x[6]=7, x[7]=8
42+
* d/dx log at those positions:
43+
* Row 0: 1/4 = 0.25 at col 3
44+
* Row 1: 1/7 at col 6
45+
* Row 2: 1/8 = 0.125 at col 7 */
46+
double u[9] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
47+
expr *var = new_variable(3, 3, 0, 9);
48+
expr *log_node = new_log(var);
49+
expr *ut = new_upper_tri(log_node);
50+
51+
ut->forward(ut, u);
52+
ut->jacobian_init(ut);
53+
ut->eval_jacobian(ut);
54+
55+
double expected_x[3] = {0.25, 1.0 / 7.0, 0.125};
56+
int expected_i[3] = {3, 6, 7};
57+
58+
mu_assert("upper_tri log jac vals",
59+
cmp_double_array(ut->jacobian->x, expected_x, 3));
60+
mu_assert("upper_tri log jac cols",
61+
cmp_int_array(ut->jacobian->i, expected_i, 3));
62+
63+
free_expr(ut);
64+
return 0;
65+
}

0 commit comments

Comments
 (0)