Skip to content

Commit 71e6021

Browse files
committed
jacobian of sum
1 parent 5c6fcc2 commit 71e6021

45 files changed

Lines changed: 1033 additions & 193 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/affine.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef AFFINE_H
2+
#define AFFINE_H
3+
4+
#include "expr.h"
5+
#include "utils/CSR_Matrix.h"
6+
7+
expr *new_linear(expr *u, const CSR_Matrix *A);
8+
9+
expr *new_add(expr *left, expr *right);
10+
expr *new_sum(expr *child, int axis);
11+
12+
expr *new_constant(int d1, int d2, const double *values);
13+
expr *new_variable(int d1, int d2, int var_id, int n_vars);
14+
15+
#endif /* AFFINE_H */

include/affine/add.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

include/affine/constant.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

include/affine/linear_op.h

Lines changed: 0 additions & 10 deletions
This file was deleted.

include/affine/variable.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

include/expr.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
/* Forward declarations */
1111
struct expr;
12+
struct int_double_pair;
1213

1314
/* Function pointer types */
1415
typedef void (*forward_fn)(struct expr *node, const double *u);
@@ -23,15 +24,17 @@ typedef struct expr
2324
// ------------------------------------------------------------------------
2425
// general quantities
2526
// ------------------------------------------------------------------------
26-
int m;
27+
int d1, d2, size;
2728
int n_vars;
2829
int var_id;
2930
int refcount;
3031
struct expr *left;
3132
struct expr *right;
3233
double *dwork;
3334
int *iwork;
34-
int p; /* power of power expression */
35+
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
36+
int p; /* power of power expression */
37+
int axis; /* axis for sum or similar operations */
3538

3639
// ------------------------------------------------------------------------
3740
// forward pass related quantities
@@ -52,7 +55,7 @@ typedef struct expr
5255

5356
} expr;
5457

55-
expr *new_expr(int m, int n_vars);
58+
expr *new_expr(int d1, int d2, int n_vars);
5659
void free_expr(expr *node);
5760

5861
/* Reference counting helpers */

include/utils/CSR_Matrix.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#define CSR_MATRIX_H
33
#include <stdbool.h>
44

5+
/* forward declaration */
6+
struct int_double_pair;
7+
58
/* CSR (Compressed Sparse Row) Matrix Format
69
*
710
* For an m x n matrix with nnz nonzeros:
@@ -59,11 +62,24 @@ void sum_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C);
5962
void sum_scaled_csr_matrices(const CSR_Matrix *A, const CSR_Matrix *B, CSR_Matrix *C,
6063
const double *d1, const double *d2);
6164

65+
/* Sum all rows of A into a single row matrix C */
66+
void sum_all_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
67+
struct int_double_pair *pairs);
68+
69+
/* Sum blocks of rows of A into a matrix C */
70+
void sum_block_of_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
71+
struct int_double_pair *pairs, int row_block_size);
72+
73+
/* Sum evenly spaced rows of A into a matrix C */
74+
void sum_evenly_spaced_rows_csr(const CSR_Matrix *A, CSR_Matrix *C,
75+
struct int_double_pair *pairs, int row_spacing);
76+
6277
/* Count number of columns with nonzero entries */
6378
int count_nonzero_cols(const CSR_Matrix *A, bool *col_nz);
6479

6580
/* inserts 'idx' into array 'arr' in sorted order, and moves the other elements */
6681
void insert_idx(int idx, int *arr, int len);
6782

6883
CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
84+
6985
#endif /* CSR_MATRIX_H */

include/utils/int_double_pair.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
#ifndef INT_DOUBLE_PAIR_H
3+
#define INT_DOUBLE_PAIR_H
4+
5+
typedef struct int_double_pair
6+
{
7+
int col;
8+
double val;
9+
} int_double_pair;
10+
11+
int_double_pair *new_int_double_pair_array(int size);
12+
void set_int_double_pair_array(int_double_pair *pair, int *ints, double *doubles,
13+
int size);
14+
void free_int_double_pair_array(int_double_pair *array);
15+
void sort_int_double_pair_array(int_double_pair *array, int size);
16+
17+
#endif /* INT_DOUBLE_PAIR_H */

src/affine/add.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "affine/add.h"
1+
#include "affine.h"
22

33
static void add_forward(expr *node, const double *u)
44
{
@@ -7,7 +7,7 @@ static void add_forward(expr *node, const double *u)
77
node->right->forward(node->right, u);
88

99
/* add left and right values */
10-
for (int i = 0; i < node->m; i++)
10+
for (int i = 0; i < node->size; i++)
1111
{
1212
node->value[i] = node->left->value[i] + node->right->value[i];
1313
}
@@ -21,7 +21,7 @@ static void jacobian_init(expr *node)
2121

2222
/* we never have to store more than the sum of children's nnz */
2323
int nnz_max = node->left->jacobian->nnz + node->right->jacobian->nnz;
24-
node->jacobian = new_csr_matrix(node->m, node->n_vars, nnz_max);
24+
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);
2525
}
2626

2727
static void eval_jacobian(expr *node)
@@ -42,9 +42,10 @@ static bool is_affine(expr *node)
4242
expr *new_add(expr *left, expr *right)
4343
{
4444
if (!left || !right) return NULL;
45-
if (left->m != right->m) return NULL;
45+
if (left->d1 != right->d1) return NULL;
46+
if (left->d2 != right->d2) return NULL;
4647

47-
expr *node = new_expr(left->m, left->n_vars);
48+
expr *node = new_expr(left->d1, left->d2, left->n_vars);
4849
if (!node) return NULL;
4950

5051
node->left = left;

src/affine/constant.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "affine/constant.h"
1+
#include "affine.h"
22
#include <string.h>
33

44
static void forward(expr *node, const double *u)
@@ -14,13 +14,13 @@ static bool is_affine(expr *node)
1414
return true; /* constant is affine */
1515
}
1616

17-
expr *new_constant(int m, const double *values)
17+
expr *new_constant(int d1, int d2, const double *values)
1818
{
19-
expr *node = new_expr(m, 0);
19+
expr *node = new_expr(d1, d2, node->n_vars);
2020
if (!node) return NULL;
2121

2222
/* Copy constant values */
23-
memcpy(node->value, values, m * sizeof(double));
23+
memcpy(node->value, values, d1 * d2 * sizeof(double));
2424

2525
node->forward = forward;
2626
node->is_affine = is_affine;

0 commit comments

Comments
 (0)