Skip to content

Commit cb7b17a

Browse files
authored
transpose implementation (#36)
1 parent 195a4b0 commit cb7b17a

8 files changed

Lines changed: 250 additions & 3 deletions

File tree

include/affine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ expr *new_variable(int d1, int d2, int var_id, int n_vars);
2121
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
2222
expr *new_reshape(expr *child, int d1, int d2);
2323
expr *new_broadcast(expr *child, int target_d1, int target_d2);
24+
expr *new_transpose(expr *child);
2425

2526
#endif /* AFFINE_H */

python/atoms/trace.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#ifndef ATOM_TRACE_H
2-
32
#define ATOM_TRACE_H
43

54
#include "common.h"

python/atoms/transpose.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef PYTHON_ATOMS_TRANSPOSE_H
2+
#define PYTHON_ATOMS_TRANSPOSE_H
3+
4+
#include "common.h"
5+
6+
// Python binding for the transpose atom
7+
/*
 * make_transpose(child): wrap new_transpose() for Python callers.
 *
 * args  - tuple holding one object, a capsule named EXPR_CAPSULE_NAME that
 *         carries the child expr pointer.
 * Returns a new capsule (same name, destructor expr_capsule_destructor)
 * holding the transpose node, or NULL with a Python exception set when the
 * arguments cannot be parsed, the capsule yields no pointer, or node
 * construction fails.
 */
static PyObject *py_make_transpose(PyObject *self, PyObject *args)
{
    PyObject *child_capsule;
    if (!PyArg_ParseTuple(args, "O", &child_capsule))
    {
        return NULL;
    }

    expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
    if (!child)
    {
        PyErr_SetString(PyExc_ValueError, "invalid child capsule");
        return NULL;
    }

    expr *node = new_transpose(child);
    if (!node)
    {
        /* the message must name this atom so failures are attributable */
        PyErr_SetString(PyExc_RuntimeError, "failed to create transpose node");
        return NULL;
    }

    expr_retain(node); /* Capsule owns a reference */
    return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
}
31+
32+
#endif // PYTHON_ATOMS_TRANSPOSE_H

python/bindings.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "atoms/tan.h"
4242
#include "atoms/tanh.h"
4343
#include "atoms/trace.h"
44+
#include "atoms/transpose.h"
4445
#include "atoms/variable.h"
4546
#include "atoms/xexp.h"
4647

@@ -71,8 +72,8 @@ static PyMethodDef DNLPMethods[] = {
7172
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
7273
{"make_index", py_make_index, METH_VARARGS, "Create index node"},
7374
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
74-
{"make_trace", py_make_trace, METH_VARARGS,
75-
"Create trace node from an expr capsule (make_trace(child))"},
75+
{"make_trace", py_make_trace, METH_VARARGS, "Create trace node"},
76+
{"make_transpose", py_make_transpose, METH_VARARGS, "Create transpose node"},
7677
{"make_hstack", py_make_hstack, METH_VARARGS,
7778
"Create hstack node from list of expr capsules and n_vars (make_hstack([e1, "
7879
"e2, ...], n_vars))"},

src/affine/transpose.c

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#include "affine.h"
2+
#include <assert.h>
3+
#include <stdlib.h>
4+
#include <string.h>
5+
6+
// Forward pass for transpose atom
7+
/*
 * Evaluate the child at u, then write the transposed values into node->value.
 *
 * With d1 = node->d1 and d2 = node->d2, entry src[i * d2 + j] of the child is
 * stored at dst[j * d1 + i] for 0 <= i < d1 and 0 <= j < d2.
 */
static void forward(expr *node, const double *u)
{
    /* the child's value must be up to date before it is permuted */
    expr *child = node->left;
    child->forward(child, u);

    const int n_outer = node->d1;
    const int n_inner = node->d2;
    const double *src = child->value;
    double *dst = node->value;

    for (int outer = 0; outer < n_outer; ++outer)
    {
        for (int inner = 0; inner < n_inner; ++inner)
        {
            dst[inner * n_outer + outer] = src[outer * n_inner + inner];
        }
    }
}
25+
26+
/*
 * Allocate this node's Jacobian and fill in its sparsity pattern.
 *
 * The Jacobian is a row permutation of the child's Jacobian: row `r` here
 * takes the column indices of child row (r / d1) + (r % d1) * d2, so the total
 * number of nonzeros equals the child's. Only the row pointers and column
 * indices are written here.
 */
static void jacobian_init(expr *node)
{
    /* the child's pattern has to exist before it can be permuted */
    expr *x = node->left;
    x->jacobian_init(x);
    const CSR_Matrix *Jx = x->jacobian;

    CSR_Matrix *Jt = new_csr_matrix(node->size, node->n_vars, Jx->nnz);
    node->jacobian = Jt;

    const int n_rows = node->d1;
    const int n_cols = node->d2;
    int filled = 0; /* column indices written so far */
    Jt->p[0] = 0;

    for (int r = 0; r < Jt->m; ++r)
    {
        /* child row that lands in row r after the transpose */
        const int src_row = (r / n_rows) + (r % n_rows) * n_cols;
        const int begin = Jx->p[src_row];
        const int count = Jx->p[src_row + 1] - begin;

        memcpy(Jt->i + filled, Jx->i + begin, count * sizeof(int));
        filled += count;
        Jt->p[r + 1] = filled;
    }
}
51+
52+
/*
 * Refresh the numerical values of this node's Jacobian.
 *
 * Evaluates the child's Jacobian, then copies its values row by row using the
 * same row permutation as jacobian_init: row `r` here receives the values of
 * child row (r / d1) + (r % d1) * d2.
 */
static void eval_jacobian(expr *node)
{
    expr *x = node->left;
    x->eval_jacobian(x);

    const CSR_Matrix *Jx = x->jacobian;
    CSR_Matrix *Jt = node->jacobian;
    const int n_rows = node->d1;
    const int n_cols = node->d2;

    /* rows are written back to back, so a running cursor replaces the offset */
    double *cursor = Jt->x;
    for (int r = 0; r < Jt->m; ++r)
    {
        const int src_row = (r / n_rows) + (r % n_rows) * n_cols;
        const int begin = Jx->p[src_row];
        const int count = Jx->p[src_row + 1] - begin;

        memcpy(cursor, Jx->x + begin, count * sizeof(double));
        cursor += count;
    }
}
70+
71+
/*
 * Allocate the weighted-sum Hessian and the weight workspace for this node.
 *
 * The Hessian gets the same row pointers and column indices as the child's
 * Hessian. node->dwork is sized to hold node->size doubles and is used by
 * eval_wsum_hess for the permuted weight vector.
 */
static void wsum_hess_init(expr *node)
{
    /* the child's Hessian pattern is the template for ours */
    expr *child = node->left;
    child->wsum_hess_init(child);
    const CSR_Matrix *Hc = child->wsum_hess;

    /* clone the sparsity pattern (row pointers and column indices only) */
    CSR_Matrix *Ht = new_csr_matrix(Hc->m, node->n_vars, Hc->nnz);
    memcpy(Ht->p, Hc->p, (Ht->m + 1) * sizeof(int));
    memcpy(Ht->i, Hc->i, Ht->nnz * sizeof(int));
    node->wsum_hess = Ht;

    /* for computing Kw where K is the commutation matrix */
    node->dwork = (double *) malloc(node->size * sizeof(double));
}
87+
/*
 * Evaluate the weighted-sum Hessian for weights w.
 *
 * The weights are first permuted into node->dwork (w[i * d1 + j] goes to
 * dwork[j * d2 + i]), the child's Hessian is evaluated at the permuted
 * weights, and the resulting values are copied into this node's Hessian.
 */
static void eval_wsum_hess(expr *node, const double *w)
{
    const int n_cols = node->d2;
    const int n_rows = node->d1;
    // TODO: maybe more efficient to do this with memcpy first

    /* build Kw, the weights reordered to match the child's layout */
    for (int c = 0; c < n_cols; ++c)
    {
        for (int r = 0; r < n_rows; ++r)
        {
            node->dwork[r * n_cols + c] = w[c * n_rows + r];
        }
    }

    /* evaluate hessian of child at Kw */
    expr *child = node->left;
    child->eval_wsum_hess(child, node->dwork);

    /* the pattern matches the child's, so the values carry over unchanged */
    memcpy(node->wsum_hess->x, child->wsum_hess->x,
           node->wsum_hess->nnz * sizeof(double));
}
108+
109+
static bool is_affine(const expr *node)
110+
{
111+
return node->left->is_affine(node->left);
112+
}
113+
114+
/*
 * Create a transpose node over `child`.
 *
 * The new node has dimensions (child->d2, child->d1) and the same n_vars as
 * the child. The child is retained via expr_retain, so the node keeps its own
 * reference to it.
 *
 * Returns the new node, or NULL if the allocation fails (callers such as the
 * Python binding already test for a NULL result).
 */
expr *new_transpose(expr *child)
{
    expr *node = (expr *) calloc(1, sizeof(expr));
    if (!node)
    {
        /* do not hand a NULL node to init_expr */
        return NULL;
    }

    init_expr(node, child->d2, child->d1, child->n_vars, forward, jacobian_init,
              eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess, NULL);
    node->left = child;
    expr_retain(child);

    return node;
}

tests/all_tests.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "jacobian_tests/test_right_matmul.h"
4141
#include "jacobian_tests/test_sum.h"
4242
#include "jacobian_tests/test_trace.h"
43+
#include "jacobian_tests/test_transpose.h"
4344
#include "problem/test_problem.h"
4445
#include "utils/test_csc_matrix.h"
4546
#include "utils/test_csr_matrix.h"
@@ -70,6 +71,7 @@
7071
#include "wsum_hess/test_right_matmul.h"
7172
#include "wsum_hess/test_sum.h"
7273
#include "wsum_hess/test_trace.h"
74+
#include "wsum_hess/test_transpose.h"
7375

7476
int main(void)
7577
{
@@ -161,6 +163,7 @@ int main(void)
161163
mu_run_test(test_jacobian_right_matmul_log, tests_run);
162164
mu_run_test(test_jacobian_right_matmul_log_vector, tests_run);
163165
mu_run_test(test_jacobian_matmul, tests_run);
166+
mu_run_test(test_jacobian_transpose, tests_run);
164167

165168
printf("\n--- Weighted Sum of Hessian Tests ---\n");
166169
mu_run_test(test_wsum_hess_log, tests_run);
@@ -225,6 +228,7 @@ int main(void)
225228
mu_run_test(test_wsum_hess_trace_variable, tests_run);
226229
mu_run_test(test_wsum_hess_trace_log_variable, tests_run);
227230
mu_run_test(test_wsum_hess_trace_composite, tests_run);
231+
mu_run_test(test_wsum_hess_transpose, tests_run);
228232

229233
printf("\n--- Utility Tests ---\n");
230234
mu_run_test(test_diag_csr_mult, tests_run);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
#ifndef TEST_TRANSPOSE_H
3+
#define TEST_TRANSPOSE_H
4+
5+
#include "affine.h"
6+
#include "minunit.h"
7+
#include "test_helpers.h"
8+
#include <math.h>
9+
#include <stdio.h>
10+
11+
const char *test_jacobian_transpose()
12+
{
13+
// A = [1 2; 3 4]
14+
CSR_Matrix *A = new_csr_matrix(2, 2, 4);
15+
int A_p[3] = {0, 2, 4};
16+
int A_i[4] = {0, 1, 0, 1};
17+
double A_x[4] = {1, 2, 3, 4};
18+
memcpy(A->p, A_p, 3 * sizeof(int));
19+
memcpy(A->i, A_i, 4 * sizeof(int));
20+
memcpy(A->x, A_x, 4 * sizeof(double));
21+
22+
// X = [1 2; 3 4] (columnwise: x = [1 3 2 4])
23+
expr *X = new_variable(2, 2, 0, 4);
24+
expr *AX = new_left_matmul(X, A);
25+
expr *transpose_AX = new_transpose(AX);
26+
double u[4] = {1, 3, 2, 4};
27+
transpose_AX->forward(transpose_AX, u);
28+
transpose_AX->jacobian_init(transpose_AX);
29+
transpose_AX->eval_jacobian(transpose_AX);
30+
31+
// Jacobian of transpose_AX
32+
double expected_x[8] = {1, 2, 1, 2, 3, 4, 3, 4};
33+
int expected_p[5] = {0, 2, 4, 6, 8};
34+
int expected_i[8] = {0, 1, 2, 3, 0, 1, 2, 3};
35+
36+
mu_assert("jacobian values fail",
37+
cmp_double_array(transpose_AX->jacobian->x, expected_x, 8));
38+
mu_assert("jacobian row ptr fail",
39+
cmp_int_array(transpose_AX->jacobian->p, expected_p, 5));
40+
mu_assert("jacobian col idx fail",
41+
cmp_int_array(transpose_AX->jacobian->i, expected_i, 8));
42+
free_expr(AX);
43+
free_csr_matrix(A);
44+
return 0;
45+
}
46+
47+
#endif // TEST_TRANSPOSE_H

tests/wsum_hess/test_transpose.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef TEST_WSUM_HESS_TRANSPOSE_H
2+
#define TEST_WSUM_HESS_TRANSPOSE_H
3+
4+
#include "affine.h"
5+
#include "minunit.h"
6+
#include "test_helpers.h"
7+
#include <math.h>
8+
#include <stdio.h>
9+
10+
const char *test_wsum_hess_transpose()
11+
{
12+
13+
expr *X = new_variable(2, 2, 0, 8);
14+
expr *Y = new_variable(2, 2, 4, 8);
15+
16+
expr *XY = new_matmul(X, Y);
17+
expr *XYT = new_transpose(XY);
18+
19+
double u[8] = {1, 3, 2, 4, 5, 7, 6, 8};
20+
XYT->forward(XYT, u);
21+
XYT->wsum_hess_init(XYT);
22+
double w[4] = {1, 2, 3, 4};
23+
XYT->eval_wsum_hess(XYT, w);
24+
25+
double expected_x[16] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 3, 1, 3, 2, 4, 2, 4};
26+
int expected_p[9] = {0, 2, 4, 6, 8, 10, 12, 14, 16};
27+
int expected_i[16] = {4, 6, 4, 6, 5, 7, 5, 7, 0, 1, 2, 3, 0, 1, 2, 3};
28+
29+
mu_assert("hess values fail",
30+
cmp_double_array(XYT->wsum_hess->x, expected_x, 8));
31+
mu_assert("jacobian row ptr fail",
32+
cmp_int_array(XYT->wsum_hess->p, expected_p, 5));
33+
mu_assert("jacobian col idx fail",
34+
cmp_int_array(XYT->wsum_hess->i, expected_i, 8));
35+
free_expr(XYT);
36+
37+
return 0;
38+
}
39+
40+
#endif // TEST_WSUM_HESS_TRANSPOSE_H

0 commit comments

Comments
 (0)