Skip to content

Commit 5527b17

Browse files
committed
refactor hstack to precompute sparsity pattern + new init_expr to reuse functioanlity
1 parent af8da7d commit 5527b17

19 files changed

Lines changed: 70 additions & 101 deletions

File tree

include/affine.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ expr *new_add(expr *left, expr *right);
1616
/* Helper function to initialize a sum expr (can be used with derived types) */
1717
void init_sum(expr *node, expr *child, int d1);
1818

19-
/* Helper function to initialize an hstack expr (can be used with derived types) */
20-
void init_hstack(expr *node, int d1, int d2, int n_vars);
21-
2219
expr *new_sum(expr *child, int axis);
2320
expr *new_hstack(expr **args, int n_args, int n_vars);
2421

include/elementwise_univariate.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define ELEMENTWISE_H
33

44
#include "expr.h"
5-
#include "subexpr.h"
65

76
/* Helper function to initialize an elementwise expr (can be used with derived types)
87
*/
@@ -27,11 +26,11 @@ expr *new_xexp(expr *child);
2726
void jacobian_init_elementwise(expr *node);
2827
void eval_jacobian_elementwise(expr *node);
2928
void wsum_hess_init_elementwise(expr *node);
30-
void eval_wsum_hess_elementwise(expr *node, double *w);
29+
void eval_wsum_hess_elementwise(expr *node, const double *w);
3130
expr *new_elementwise(expr *child);
3231

3332
/* no elementwise atoms are affine according to our convention,
3433
so we can have a common implementation */
35-
bool is_affine_elementwise(expr *node);
34+
bool is_affine_elementwise(const expr *node);
3635

3736
#endif /* ELEMENTWISE_UNIVARIATE_H */

include/expr.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ typedef void (*forward_fn)(struct expr *node, const double *u);
1414
typedef void (*jacobian_init_fn)(struct expr *node);
1515
typedef void (*wsum_hess_init_fn)(struct expr *node);
1616
typedef void (*eval_jacobian_fn)(struct expr *node);
17-
typedef void (*wsum_hess_fn)(struct expr *node, double *w);
17+
typedef void (*wsum_hess_fn)(struct expr *node, const double *w);
1818
typedef void (*local_jacobian_fn)(struct expr *node, double *out);
19-
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, double *w);
20-
typedef bool (*is_affine_fn)(struct expr *node);
19+
typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double *w);
20+
typedef bool (*is_affine_fn)(const struct expr *node);
2121
typedef void (*free_type_data_fn)(struct expr *node);
2222

2323
/* Base expression node structure - contains only common fields */
@@ -54,6 +54,8 @@ typedef struct expr
5454

5555
} expr;
5656

57+
void init_expr(expr *node, int d1, int d2, int n_vars);
58+
5759
expr *new_expr(int d1, int d2, int n_vars);
5860
void free_expr(expr *node);
5961

src/affine/add.c

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ static void wsum_hess_init(expr *node)
4545
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);
4646
}
4747

48-
static void eval_wsum_hess(expr *node, double *w)
48+
static void eval_wsum_hess(expr *node, const double *w)
4949
{
5050
/* evaluate children's wsum_hess */
5151
node->left->eval_wsum_hess(node->left, w);
@@ -55,20 +55,14 @@ static void eval_wsum_hess(expr *node, double *w)
5555
sum_csr_matrices(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);
5656
}
5757

58-
static bool is_affine(expr *node)
58+
static bool is_affine(const expr *node)
5959
{
6060
return node->left->is_affine(node->left) && node->right->is_affine(node->right);
6161
}
6262

6363
expr *new_add(expr *left, expr *right)
6464
{
65-
if (!left || !right) return NULL;
66-
if (left->d1 != right->d1) return NULL;
67-
if (left->d2 != right->d2) return NULL;
68-
6965
expr *node = new_expr(left->d1, left->d2, left->n_vars);
70-
if (!node) return NULL;
71-
7266
node->left = left;
7367
node->right = right;
7468
expr_retain(left);

src/affine/constant.c

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,16 @@ static void forward(expr *node, const double *u)
88
(void) u;
99
}
1010

11-
static bool is_affine(expr *node)
11+
static bool is_affine(const expr *node)
1212
{
1313
(void) node;
14-
return true; /* constant is affine */
14+
return true;
1515
}
1616

1717
expr *new_constant(int d1, int d2, int n_vars, const double *values)
1818
{
1919
expr *node = new_expr(d1, d2, n_vars);
20-
if (!node) return NULL;
21-
22-
/* Copy constant values */
2320
memcpy(node->value, values, d1 * d2 * sizeof(double));
24-
2521
node->forward = forward;
2622
node->is_affine = is_affine;
2723

src/affine/hstack.c

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,18 @@ static void jacobian_init(expr *node)
3636
}
3737

3838
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz);
39-
}
40-
41-
static void eval_jacobian(expr *node)
42-
{
43-
hstack_expr *hnode = (hstack_expr *) node;
4439

45-
/* evaluate children's jacobians */
40+
/* precompute sparsity pattern of this jacobian's node */
4641
int row_offset = 0;
4742
CSR_Matrix *A = node->jacobian;
4843
A->nnz = 0;
4944

5045
for (int i = 0; i < hnode->n_args; i++)
5146
{
5247
expr *child = hnode->args[i];
53-
child->eval_jacobian(child);
5448
CSR_Matrix *B = child->jacobian;
5549

56-
/* copy columns and values */
57-
memcpy(A->x + A->nnz, B->x, B->nnz * sizeof(double));
50+
/* copy columns */
5851
memcpy(A->i + A->nnz, B->i, B->nnz * sizeof(int));
5952

6053
/* set row pointers */
@@ -69,9 +62,27 @@ static void eval_jacobian(expr *node)
6962
A->p[node->size] = A->nnz;
7063
}
7164

72-
static bool is_affine(expr *node)
65+
static void eval_jacobian(expr *node)
7366
{
7467
hstack_expr *hnode = (hstack_expr *) node;
68+
CSR_Matrix *A = node->jacobian;
69+
A->nnz = 0;
70+
71+
for (int i = 0; i < hnode->n_args; i++)
72+
{
73+
expr *child = hnode->args[i];
74+
child->eval_jacobian(child);
75+
76+
/* copy values */
77+
memcpy(A->x + A->nnz, child->jacobian->x,
78+
child->jacobian->nnz * sizeof(double));
79+
A->nnz += child->jacobian->nnz;
80+
}
81+
}
82+
83+
static bool is_affine(const expr *node)
84+
{
85+
const hstack_expr *hnode = (const hstack_expr *) node;
7586

7687
for (int i = 0; i < hnode->n_args; i++)
7788
{
@@ -92,33 +103,6 @@ static void free_type_data(expr *node)
92103
}
93104
}
94105

95-
/* Helper function to initialize an hstack expr */
96-
void init_hstack(expr *node, int d1, int d2, int n_vars)
97-
{
98-
node->d1 = d1;
99-
node->d2 = d2;
100-
node->size = d1 * d2;
101-
node->n_vars = n_vars;
102-
node->var_id = -1;
103-
node->refcount = 1;
104-
node->left = NULL;
105-
node->right = NULL;
106-
node->dwork = NULL;
107-
node->iwork = NULL;
108-
node->value = (double *) calloc(node->size, sizeof(double));
109-
node->jacobian = NULL;
110-
node->wsum_hess = NULL;
111-
node->jacobian_init = jacobian_init;
112-
node->wsum_hess_init = NULL;
113-
node->eval_jacobian = eval_jacobian;
114-
node->eval_wsum_hess = NULL;
115-
node->local_jacobian = NULL;
116-
node->local_wsum_hess = NULL;
117-
node->forward = forward;
118-
node->is_affine = is_affine;
119-
node->free_type_data = free_type_data;
120-
}
121-
122106
expr *new_hstack(expr **args, int n_args, int n_vars)
123107
{
124108
/* compute second dimension */
@@ -129,20 +113,18 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
129113
}
130114

131115
/* Allocate the type-specific struct */
132-
hstack_expr *hnode = (hstack_expr *) malloc(sizeof(hstack_expr));
133-
if (!hnode) return NULL;
134-
116+
hstack_expr *hnode = (hstack_expr *) calloc(1, sizeof(hstack_expr));
135117
expr *node = &hnode->base;
136118

137-
/* Initialize base hstack fields */
138-
init_hstack(node, args[0]->d1, d2, n_vars);
119+
/* Initialize basic fields */
120+
init_expr(node, args[0]->d1, d2, n_vars);
139121

140-
/* Check if allocation succeeded */
141-
if (!node->value)
142-
{
143-
free(hnode);
144-
return NULL;
145-
}
122+
/* Set function pointers */
123+
node->forward = forward;
124+
node->jacobian_init = jacobian_init;
125+
node->eval_jacobian = eval_jacobian;
126+
node->is_affine = is_affine;
127+
node->free_type_data = free_type_data;
146128

147129
/* Set type-specific fields */
148130
hnode->args = args;

src/affine/linear_op.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ static void forward(expr *node, const double *u)
1111
csr_matvec(node->jacobian, x->value, node->value, x->var_id);
1212
}
1313

14-
static bool is_affine(expr *node)
14+
static bool is_affine(const expr *node)
1515
{
1616
return node->left->is_affine(node->left);
1717
}

src/affine/sum.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ static void wsum_hess_init(expr *node)
109109
node->dwork = malloc(x->size * sizeof(double));
110110
}
111111

112-
static void eval_wsum_hess(expr *node, double *w)
112+
static void eval_wsum_hess(expr *node, const double *w)
113113
{
114114
expr *x = node->left;
115115
sum_expr *snode = (sum_expr *) node;
@@ -134,7 +134,7 @@ static void eval_wsum_hess(expr *node, double *w)
134134
copy_csr_matrix(x->wsum_hess, node->wsum_hess);
135135
}
136136

137-
static bool is_affine(expr *node)
137+
static bool is_affine(const expr *node)
138138
{
139139
return node->left->is_affine(node->left);
140140
}

src/affine/variable.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ static void jacobian_init(expr *node)
2020
node->jacobian->p[size] = size;
2121
}
2222

23-
static bool is_affine(expr *node)
23+
static bool is_affine(const expr *node)
2424
{
2525
(void) node;
2626
return true;

src/elementwise_univariate/common.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "elementwise_univariate.h"
2+
#include "expr.h"
3+
#include "subexpr.h"
24
#include <stdlib.h>
35

46
void jacobian_init_elementwise(expr *node)
@@ -72,7 +74,7 @@ void wsum_hess_init_elementwise(expr *node)
7274
}
7375
}
7476

75-
void eval_wsum_hess_elementwise(expr *node, double *w)
77+
void eval_wsum_hess_elementwise(expr *node, const double *w)
7678
{
7779
expr *child = node->left;
7880

@@ -89,7 +91,7 @@ void eval_wsum_hess_elementwise(expr *node, double *w)
8991
}
9092
}
9193

92-
bool is_affine_elementwise(expr *node)
94+
bool is_affine_elementwise(const expr *node)
9395
{
9496
(void) node;
9597
return false;

0 commit comments

Comments
 (0)