Skip to content

Commit af8da7d

Browse files
committed
removed CSR_work from expr
1 parent 43f51b8 commit af8da7d

14 files changed

Lines changed: 110 additions & 82 deletions

File tree

include/affine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define AFFINE_H
33

44
#include "expr.h"
5+
#include "subexpr.h"
56
#include "utils/CSR_Matrix.h"
67

78
/* Helper function to initialize a linear operator expr (can be used with derived

include/elementwise_univariate.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define ELEMENTWISE_H
33

44
#include "expr.h"
5+
#include "subexpr.h"
56

67
/* Helper function to initialize an elementwise expr (can be used with derived types)
78
*/

include/expr.h

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88

99
#define JAC_IDXS_NOT_SET -1
1010

11-
/* Forward declarations */
12-
struct expr;
13-
struct int_double_pair;
14-
1511
/* Function pointer types */
12+
struct expr;
1613
typedef void (*forward_fn)(struct expr *node, const double *u);
1714
typedef void (*jacobian_init_fn)(struct expr *node);
1815
typedef void (*wsum_hess_init_fn)(struct expr *node);
@@ -50,54 +47,13 @@ typedef struct expr
5047
// ------------------------------------------------------------------------
5148
// other things
5249
// ------------------------------------------------------------------------
53-
CSR_Matrix *CSR_work;
5450
is_affine_fn is_affine;
5551
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms*/
5652
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms*/
5753
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
5854

5955
} expr;
6056

61-
/* Type-specific expression structures that "inherit" from expr */
62-
63-
/* Linear operator: y = A * x */
64-
typedef struct linear_op_expr
65-
{
66-
expr base; /* MUST be first member for casting to work */
67-
CSC_Matrix *A_csc;
68-
CSR_Matrix *A_csr;
69-
} linear_op_expr;
70-
71-
/* Power: y = x^p */
72-
typedef struct power_expr
73-
{
74-
expr base; /* MUST be first member for casting to work */
75-
int p;
76-
} power_expr;
77-
78-
/* Quadratic form: y = x'*Q*x */
79-
typedef struct quad_form_expr
80-
{
81-
expr base; /* MUST be first member for casting to work */
82-
CSR_Matrix *Q;
83-
} quad_form_expr;
84-
85-
/* Sum reduction along an axis */
86-
typedef struct sum_expr
87-
{
88-
expr base; /* MUST be first member for casting to work */
89-
int axis;
90-
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
91-
} sum_expr;
92-
93-
/* Horizontal stack (concatenate) */
94-
typedef struct hstack_expr
95-
{
96-
expr base; /* MUST be first member for casting to work */
97-
expr **args;
98-
int n_args;
99-
} hstack_expr;
100-
10157
expr *new_expr(int d1, int d2, int n_vars);
10258
void free_expr(expr *node);
10359

include/other.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define OTHER_H
33

44
#include "expr.h"
5+
#include "subexpr.h"
56
#include "utils/CSR_Matrix.h"
67

78
/* Helper function to initialize a quad_form expr (can be used with derived types) */

include/subexpr.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#ifndef SUBEXPR_H
2+
#define SUBEXPR_H
3+
4+
#include "expr.h"
5+
#include "utils/CSC_Matrix.h"
6+
#include "utils/CSR_Matrix.h"
7+
8+
/* Forward declaration */
9+
struct int_double_pair;
10+
11+
/* Type-specific expression structures that "inherit" from expr */
12+
13+
/* Linear operator: y = A * x */
14+
typedef struct linear_op_expr
15+
{
16+
expr base;
17+
CSC_Matrix *A_csc;
18+
CSR_Matrix *A_csr;
19+
} linear_op_expr;
20+
21+
/* Power: y = x^p */
22+
typedef struct power_expr
23+
{
24+
expr base;
25+
int p;
26+
} power_expr;
27+
28+
/* Quadratic form: y = x'*Q*x */
29+
typedef struct quad_form_expr
30+
{
31+
expr base;
32+
CSR_Matrix *Q;
33+
} quad_form_expr;
34+
35+
/* Sum reduction along an axis */
36+
typedef struct sum_expr
37+
{
38+
expr base;
39+
int axis;
40+
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
41+
} sum_expr;
42+
43+
/* Horizontal stack (concatenate) */
44+
typedef struct hstack_expr
45+
{
46+
expr base;
47+
expr **args;
48+
int n_args;
49+
} hstack_expr;
50+
51+
#endif /* SUBEXPR_H */

include/utils/CSC_Matrix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,9 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
3939
*/
4040
void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
4141

42+
/* C = z^T A where A is in CSC format and C is assumed to have one row.
43+
* C must have column indices pre-computed. Fills in values of C only.
44+
*/
45+
void csc_matvec_fill_values(const CSC_Matrix *A, const double *z, CSR_Matrix *C);
46+
4247
#endif /* CSC_MATRIX_H */

src/affine/hstack.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ void init_hstack(expr *node, int d1, int d2, int n_vars)
108108
node->value = (double *) calloc(node->size, sizeof(double));
109109
node->jacobian = NULL;
110110
node->wsum_hess = NULL;
111-
node->CSR_work = NULL;
112111
node->jacobian_init = jacobian_init;
113112
node->wsum_hess_init = NULL;
114113
node->eval_jacobian = eval_jacobian;

src/affine/linear_op.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ void init_linear_op(expr *node, expr *child, int d1, int d2)
3939
node->value = (double *) calloc(node->size, sizeof(double));
4040
node->jacobian = NULL;
4141
node->wsum_hess = NULL;
42-
node->CSR_work = NULL;
4342
node->jacobian_init = NULL;
4443
node->wsum_hess_init = NULL;
4544
node->eval_jacobian = NULL;

src/affine/sum.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ void init_sum(expr *node, expr *child, int d1)
161161
node->value = (double *) calloc(node->size, sizeof(double));
162162
node->jacobian = NULL;
163163
node->wsum_hess = NULL;
164-
node->CSR_work = NULL;
165164
node->jacobian_init = jacobian_init;
166165
node->wsum_hess_init = wsum_hess_init;
167166
node->eval_jacobian = eval_jacobian;

src/bivariate/quad_over_lin.c

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "bivariate.h"
2+
#include "subexpr.h"
3+
#include "utils/CSC_Matrix.h"
24
#include <assert.h>
35
#include <math.h>
46
#include <stdlib.h>
@@ -58,14 +60,15 @@ static void jacobian_init(expr *node)
5860
}
5961
}
6062
}
61-
else /* left node is not a variable */
63+
else /* left node is not a variable (guaranteed to be a linear operator) */
6264
{
65+
linear_op_expr *lin_x = (linear_op_expr *) x;
6366
node->dwork = (double *) malloc(x->d1 * sizeof(double));
6467

6568
/* compute required allocation and allocate jacobian */
6669
bool *col_nz = (bool *) calloc(
6770
node->n_vars, sizeof(bool)); /* TODO: could use iwork here instead*/
68-
int nonzero_cols = count_nonzero_cols(x->jacobian, col_nz);
71+
int nonzero_cols = count_nonzero_cols(lin_x->base.jacobian, col_nz);
6972
node->jacobian = new_csr_matrix(1, node->n_vars, nonzero_cols + 1);
7073

7174
/* precompute column indices */
@@ -88,11 +91,8 @@ static void jacobian_init(expr *node)
8891
node->jacobian->p[0] = 0;
8992
node->jacobian->p[1] = node->jacobian->nnz;
9093

91-
/* store A^T of child's A to simplify chain rule computation */
92-
node->iwork = (int *) malloc(x->jacobian->n * sizeof(int));
93-
node->CSR_work = transpose(x->jacobian, node->iwork);
94-
9594
/* find position where y should be inserted */
95+
node->iwork = (int *) malloc(sizeof(int));
9696
for (int j = 0; j < node->jacobian->nnz; j++)
9797
{
9898
if (node->jacobian->i[j] == y->var_id)
@@ -132,14 +132,16 @@ static void eval_jacobian(expr *node)
132132
}
133133
else /* x is not a variable */
134134
{
135+
CSC_Matrix *A_csc = ((linear_op_expr *) x)->A_csc;
136+
135137
/* local jacobian */
136138
for (int j = 0; j < x->d1; j++)
137139
{
138140
node->dwork[j] = (2.0 * x->value[j]) / y->value[0];
139141
}
140142

141-
/* chain rule (no derivative wrt y) */
142-
csr_matvec_fill_values(node->CSR_work, node->dwork, node->jacobian);
143+
/* chain rule (no derivative wrt y) using CSC format */
144+
csc_matvec_fill_values(A_csc, node->dwork, node->jacobian);
143145

144146
/* insert derivative wrt y at right place (for correctness this assumes
145147
that y does not appear in the denominator, but this will always be

0 commit comments

Comments
 (0)