Skip to content

Commit 9571bc5

Browse files
committed
trace
1 parent 2b63224 commit 9571bc5

10 files changed

Lines changed: 205 additions & 55 deletions

File tree

TODO.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,5 @@
22
10. For performance reasons, is it useful to have a dense matmul with A and B as dense matrices?
33
11. right matmul, add broadcasting logic as in left matmul. Is this necessary?
44

5-
Going through all atoms to see that sparsity pattern is computed in initialization of jacobian:
6-
2. trace - not ok
7-
85
Going through all atoms to see that sparsity pattern is computed in initialization of hessian:
9-
2. hstack - not ok
106
3. trace - not ok

include/subexpr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ typedef struct sum_expr
4040
int *idx_map; /* maps child nnz to summed-row positions */
4141
} sum_expr;
4242

43-
/* Trace-like reduction: sums entries spaced by child->d1 */
43+
/* trace */
4444
typedef struct trace_expr
4545
{
4646
expr base;
47-
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
47+
int *idx_map; /* maps child nnz to summed-row positions */
4848
} trace_expr;
4949

5050
/* Product of all entries */

include/utils/CSR_Matrix.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix
106106
Must memset accumulator to zero before calling. */
107107
void idx_map_accumulator(const CSR_Matrix *A, const int *idx_map,
108108
double *accumulator);
109+
void idx_map_accumulator_with_spacing(const CSR_Matrix *A, const int *idx_map,
110+
double *accumulator, int spacing);
109111

110112
/* Sum blocks of rows of A into a matrix C */
111113
void sum_block_of_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
@@ -141,6 +143,11 @@ void sum_evenly_spaced_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
141143
void sum_spaced_rows_into_row_csr(const CSR_Matrix *A, CSR_Matrix *C,
142144
struct int_double_pair *pairs, int offset,
143145
int spacing);
146+
/* Fills the sparsity and index map for summing spaced rows into a row matrix */
147+
void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
148+
CSR_Matrix *C,
149+
int spacing, int *iwork,
150+
int *idx_map);
144151

145152
/* Count number of columns with nonzero entries */
146153
int count_nonzero_cols(const CSR_Matrix *A, bool *col_nz);

python/atoms/trace.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef ATOM_TRACE_H
2+
3+
#define ATOM_TRACE_H
4+
5+
#include "common.h"
6+
7+
static PyObject *py_make_trace(PyObject *self, PyObject *args)
8+
{
9+
PyObject *child_capsule;
10+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
11+
{
12+
return NULL;
13+
}
14+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
15+
if (!child)
16+
{
17+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
18+
return NULL;
19+
}
20+
expr *node = new_trace(child);
21+
if (!node)
22+
{
23+
PyErr_SetString(PyExc_RuntimeError, "failed to create trace node");
24+
return NULL;
25+
}
26+
expr_retain(node); /* Capsule owns a reference */
27+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
28+
}
29+
30+
#endif // ATOM_TRACE_H

python/bindings.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "atoms/sum.h"
4141
#include "atoms/tan.h"
4242
#include "atoms/tanh.h"
43+
#include "atoms/trace.h"
4344
#include "atoms/variable.h"
4445
#include "atoms/xexp.h"
4546

@@ -70,6 +71,8 @@ static PyMethodDef DNLPMethods[] = {
7071
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
7172
{"make_index", py_make_index, METH_VARARGS, "Create index node"},
7273
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
74+
{"make_trace", py_make_trace, METH_VARARGS,
75+
"Create trace node from an expr capsule (make_trace(child))"},
7376
{"make_hstack", py_make_hstack, METH_VARARGS,
7477
"Create hstack node from list of expr capsules and n_vars (make_hstack([e1, "
7578
"e2, ...], n_vars))"},

src/affine/sum.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ static void jacobian_init(expr *node)
7373
node->iwork = malloc(MAX(node->jacobian->n, x->jacobian->nnz) * sizeof(int));
7474
snode->idx_map = malloc(x->jacobian->nnz * sizeof(int));
7575

76+
/* the idx_map array maps each nonzero entry j in x->jacobian
77+
to the corresponding index in the output row matrix C. Specifically, for
78+
each nonzero entry j in A, idx_map[j] gives the position in C->x where
79+
the value from x->jacobian->x[j] should be accumulated. */
80+
7681
if (axis == -1)
7782
{
7883
sum_all_rows_csr_fill_sparsity_and_idx_map(x->jacobian, node->jacobian,

src/affine/trace.c

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include "affine.h"
22
#include "utils/int_double_pair.h"
3+
#include "utils/utils.h"
34
#include <assert.h>
5+
#include <stdio.h>
46
#include <stdlib.h>
57
#include <string.h>
8+
#include <utils/iVec.h>
69

710
static void forward(expr *node, const double *u)
811
{
@@ -25,21 +28,39 @@ static void forward(expr *node, const double *u)
2528
static void jacobian_init(expr *node)
2629
{
2730
expr *x = node->left;
31+
assert(x->d1 == x->d2);
2832

2933
/* initialize child's jacobian */
3034
x->jacobian_init(x);
3135

32-
/* count total nnz in the rows of the jacobian that should be summed */
36+
// ---------------------------------------------------------------
37+
// count total nnz and allocate matrix with sufficient space
38+
// ---------------------------------------------------------------
3339
const CSR_Matrix *A = x->jacobian;
3440
int total_nnz = 0;
3541
int row_spacing = x->d1 + 1;
42+
3643
for (int row = 0; row < A->m; row += row_spacing)
3744
{
38-
total_nnz += (A->p[row + 1] - A->p[row]);
45+
total_nnz += A->p[row + 1] - A->p[row];
3946
}
4047

4148
node->jacobian = new_csr_matrix(1, node->n_vars, total_nnz);
42-
((trace_expr *) node)->int_double_pairs = new_int_double_pair_array(total_nnz);
49+
50+
// ---------------------------------------------------------------
51+
// fill sparsity pattern and idx_map
52+
// ---------------------------------------------------------------
53+
trace_expr *tnode = (trace_expr *) node;
54+
node->iwork = malloc(MAX(node->jacobian->n, total_nnz) * sizeof(int));
55+
56+
/* the idx_map array maps each nonzero entry j in the original matrix A (from the
57+
selected, evenly spaced rows) to the corresponding index in the output row
58+
matrix C. Specifically, for each nonzero entry j in A (from the selected
59+
rows), idx_map[j] gives the position in C->x where the value from A->x[j]
60+
should be accumulated. */
61+
tnode->idx_map = malloc(x->jacobian->nnz * sizeof(int));
62+
sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(
63+
A, node->jacobian, row_spacing, node->iwork, tnode->idx_map);
4364
}
4465

4566
static void eval_jacobian(expr *node)
@@ -51,8 +72,9 @@ static void eval_jacobian(expr *node)
5172
x->eval_jacobian(x);
5273

5374
/* local jacobian */
54-
sum_spaced_rows_into_row_csr(x->jacobian, node->jacobian,
55-
tnode->int_double_pairs, 0, x->d1 + 1);
75+
memset(node->jacobian->x, 0, node->jacobian->nnz * sizeof(double));
76+
idx_map_accumulator_with_spacing(x->jacobian, tnode->idx_map, node->jacobian->x,
77+
x->d1 + 1);
5678
}
5779

5880
/* Placeholders for Hessian-related functions */
@@ -66,16 +88,12 @@ static void wsum_hess_init(expr *node)
6688
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
6789
node->dwork = (double *) calloc(x->size, sizeof(double));
6890

69-
/* TODO: here we could copy over sparsity pattern once we have checked
70-
that all atoms fill their sparsity pattern in the init functions. Perhaps
71-
we should only take sparsity pattern of rows that are summed? Not the rows
72-
which will get zero weight in the hessian. That would be very cool.
73-
But must eval_wsum_hess then also ignore contributions with zero weight? that
74-
would be bad. */
75-
76-
/* lets conclude that the hessian can be made more sophisticated */
77-
78-
/* but perhaps let's keep it as simple as possible! */
91+
/* We copy over the sparsity pattern from the child. This also includes the
92+
contribution to wsum_hess of entries of the child that will always have
93+
zero weight in eval_wsum_hess. We do this for simplicity. But the Hessian
94+
can for sure be made more sophisticated. */
95+
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->n_vars + 1) * sizeof(int));
96+
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
7997
}
8098

8199
static void eval_wsum_hess(expr *node, const double *w)
@@ -90,14 +108,7 @@ static void eval_wsum_hess(expr *node, const double *w)
90108

91109
x->eval_wsum_hess(x, node->dwork);
92110

93-
/* TODO: here we only need to copy over values once we have filled the sparsity
94-
* pattern in wsum_hess_init*/
95-
node->wsum_hess->nnz = x->wsum_hess->nnz;
96-
memcpy(node->wsum_hess->p, x->wsum_hess->p, sizeof(int) * (node->n_vars + 1));
97-
memcpy(node->wsum_hess->i, x->wsum_hess->i, sizeof(int) * x->wsum_hess->nnz);
98111
memcpy(node->wsum_hess->x, x->wsum_hess->x, sizeof(double) * x->wsum_hess->nnz);
99-
100-
/* This might contain many many zeros actually! Hmm...*/
101112
}
102113

103114
static bool is_affine(const expr *node)
@@ -107,27 +118,21 @@ static bool is_affine(const expr *node)
107118

108119
static void free_type_data(expr *node)
109120
{
110-
trace_expr *tnode = (trace_expr *) node;
111-
free_int_double_pair_array(tnode->int_double_pairs);
121+
if (node)
122+
{
123+
trace_expr *tnode = (trace_expr *) node;
124+
free(tnode->idx_map);
125+
}
112126
}
113127

114128
expr *new_trace(expr *child)
115129
{
116-
/* Output is a single scalar */
117-
int d1 = 1;
118-
119130
trace_expr *tnode = (trace_expr *) calloc(1, sizeof(trace_expr));
120131
expr *node = &tnode->base;
121-
init_expr(node, d1, 1, child->n_vars, forward, jacobian_init, eval_jacobian,
132+
init_expr(node, 1, 1, child->n_vars, forward, jacobian_init, eval_jacobian,
122133
is_affine, wsum_hess_init, eval_wsum_hess, free_type_data);
123134
node->left = child;
124135
expr_retain(child);
125136

126-
/* Initialize type-specific fields */
127-
tnode->int_double_pairs = NULL;
128-
129-
// just for debugging, should be removed
130-
strcpy(node->name, "trace");
131-
132137
return node;
133138
}

src/utils/CSR_Matrix.c

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,19 @@ void idx_map_accumulator(const CSR_Matrix *A, const int *idx_map,
754754
}
755755
}
756756

757+
void idx_map_accumulator_with_spacing(const CSR_Matrix *A, const int *idx_map,
758+
double *accumulator, int spacing)
759+
{
760+
/* don't forget to initialze accumulator to 0 before calling this */
761+
for (int row = 0; row < A->m; row += spacing)
762+
{
763+
for (int j = A->p[row]; j < A->p[row + 1]; j++)
764+
{
765+
accumulator[idx_map[j]] += A->x[j];
766+
}
767+
}
768+
}
769+
757770
void sum_spaced_rows_into_row_csr(const CSR_Matrix *A, CSR_Matrix *C,
758771
struct int_double_pair *pairs, int offset,
759772
int spacing)
@@ -1058,7 +1071,8 @@ void sum_all_rows_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A, CSR_Matrix
10581071
C->nnz = unique_nnz;
10591072

10601073
// -------------------------------------------------------------------
1061-
// Map child values to summed-row positions
1074+
// Map child values to summed-row positions. col_to_pos maps
1075+
// column indices to positions in C's row.
10621076
// -------------------------------------------------------------------
10631077
int *col_to_pos = iwork;
10641078
for (int idx = 0; idx < unique_nnz; idx++)
@@ -1090,3 +1104,61 @@ void sum_all_rows_csr_fill_values(const CSR_Matrix *A, CSR_Matrix *C,
10901104
}
10911105
}
10921106
*/
1107+
1108+
/*
1109+
* Sums evenly spaced rows from A into a single row in C and fills an index map.
1110+
* A: input CSR matrix
1111+
* C: output CSR matrix (must have m=1)
1112+
* spacing: row spacing
1113+
* iwork: workspace of size at least max(A->n, A->nnz)
1114+
* idx_map: output index map, size at least A->nnz
1115+
*/
1116+
void sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(const CSR_Matrix *A,
1117+
CSR_Matrix *C,
1118+
int spacing, int *iwork,
1119+
int *idx_map)
1120+
{
1121+
assert(C->m == 1);
1122+
C->n = A->n;
1123+
1124+
/* gather all column indices from the spaced rows */
1125+
int count = 0;
1126+
for (int row = 0; row < A->m; row += spacing)
1127+
{
1128+
int len = A->p[row + 1] - A->p[row];
1129+
memcpy(iwork + count, A->i + A->p[row], len * sizeof(int));
1130+
count += len;
1131+
}
1132+
1133+
/* fill sparsity pattern */
1134+
sort_int_array(iwork, count);
1135+
int unique_nnz = 0;
1136+
int prev_col = -1;
1137+
for (int i = 0; i < count; i++)
1138+
{
1139+
int col = iwork[i];
1140+
if (col != prev_col)
1141+
{
1142+
C->i[unique_nnz++] = col;
1143+
prev_col = col;
1144+
}
1145+
}
1146+
C->nnz = unique_nnz;
1147+
C->p[0] = 0;
1148+
C->p[1] = C->nnz;
1149+
1150+
/* fill idx_map */
1151+
int *col_to_pos = iwork;
1152+
for (int idx = 0; idx < unique_nnz; idx++)
1153+
{
1154+
col_to_pos[C->i[idx]] = idx;
1155+
}
1156+
1157+
for (int row = 0; row < A->m; row += spacing)
1158+
{
1159+
for (int j = A->p[row]; j < A->p[row + 1]; j++)
1160+
{
1161+
idx_map[j] = col_to_pos[A->i[j]];
1162+
}
1163+
}
1164+
}

tests/all_tests.c

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,9 @@ int main(void)
222222
mu_run_test(test_wsum_hess_broadcast_row, tests_run);
223223
mu_run_test(test_wsum_hess_broadcast_col, tests_run);
224224
mu_run_test(test_wsum_hess_broadcast_scalar_to_matrix, tests_run);
225-
// This test leads to seg fault
226-
// mu_run_test(test_wsum_hess_trace_variable, tests_run);
227-
228-
// This test fails - not sure how sophisticated we should make
229-
// wsum_hess for trace
230-
// mu_run_test(test_wsum_hess_trace_composite, tests_run);
225+
mu_run_test(test_wsum_hess_trace_variable, tests_run);
226+
mu_run_test(test_wsum_hess_trace_log_variable, tests_run);
227+
mu_run_test(test_wsum_hess_trace_composite, tests_run);
231228

232229
printf("\n--- Utility Tests ---\n");
233230
mu_run_test(test_diag_csr_mult, tests_run);
@@ -257,7 +254,7 @@ int main(void)
257254
mu_run_test(test_problem_constraint_forward, tests_run);
258255
mu_run_test(test_problem_hessian, tests_run);
259256

260-
printf("\n=== All %d tests passed ===\n", tests_run);
257+
printf("\n=== All %d tests passed ===\n", tests_run);
261258

262259
return 0;
263260
}

0 commit comments

Comments
 (0)