Skip to content

Commit 28f61f9

Browse files
authored
add work struct for expression (#60)
* add work struct for expression * run formatter
1 parent e613c15 commit 28f61f9

15 files changed

Lines changed: 135 additions & 119 deletions

File tree

include/expr.h

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,24 @@ typedef void (*local_wsum_hess_fn)(struct expr *node, double *out, const double
3939
typedef bool (*is_affine_fn)(const struct expr *node);
4040
typedef void (*free_type_data_fn)(struct expr *node);
4141

42-
/* Base expression node structure - contains only common fields */
42+
/* Workspace for derivative computation */
43+
typedef struct
44+
{
45+
double *dwork;
46+
int *iwork;
47+
CSC_Matrix *jacobian_csc;
48+
int *csc_work; /* for CSR-CSC conversion */
49+
50+
/* jacobian_csc_filled is only used for affine functions to avoid redundant
51+
conversions. Could become relevant for non-affine functions if we start
52+
supporting common subexpressions on the Python side. */
53+
bool jacobian_csc_filled;
54+
double *local_jac_diag; /* cached f'(g(x)) diagonal */
55+
CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */
56+
CSR_Matrix *hess_term2; /* child wsum_hess workspace */
57+
} Expr_Work;
58+
59+
/* Base expression node structure */
4360
typedef struct expr
4461
{
4562
// ------------------------------------------------------------------------
@@ -48,24 +65,13 @@ typedef struct expr
4865
int d1, d2, size, n_vars, refcount, var_id;
4966
struct expr *left;
5067
struct expr *right;
51-
double *dwork;
52-
int *iwork;
5368

5469
// ------------------------------------------------------------------------
5570
// oracle related quantities
5671
// ------------------------------------------------------------------------
5772
double *value;
5873
CSR_Matrix *jacobian;
59-
CSC_Matrix *jacobian_csc;
60-
int *csc_work; /* workspace for CSR-CSC conversion */
61-
62-
/* jacobian_csc_filled is only used for affine functions to avoid redundant
63-
conversions. Could become relevant for non-affine functions if we start
64-
supporting common subexpressions on the Python side. */
65-
bool jacobian_csc_filled;
6674
CSR_Matrix *wsum_hess;
67-
CSR_Matrix *hess_term1; /* Jg^T D Jg workspace */
68-
CSR_Matrix *hess_term2; /* child wsum_hess workspace */
6975
forward_fn forward;
7076
jacobian_init_fn jacobian_init;
7177
wsum_hess_init_fn wsum_hess_init;
@@ -76,10 +82,10 @@ typedef struct expr
7682
// other things
7783
// ------------------------------------------------------------------------
7884
is_affine_fn is_affine;
79-
double *local_jac_diag; /* cached f'(g(x)) diagonal */
8085
local_jacobian_fn local_jacobian; /* used by elementwise univariate atoms */
8186
local_wsum_hess_fn local_wsum_hess; /* used by elementwise univariate atoms */
8287
free_type_data_fn free_type_data; /* Cleanup for type-specific fields */
88+
Expr_Work *work; /* derivative workspace */
8389

8490
// name of node just for debugging - should be removed later
8591
char name[32];

src/affine/broadcast.c

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ static void wsum_hess_init(expr *node)
196196
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
197197

198198
/* allocate space for weight vector */
199-
node->dwork = malloc(node->size * sizeof(double));
199+
node->work->dwork = malloc(node->size * sizeof(double));
200200
}
201201

202202
static void eval_wsum_hess(expr *node, const double *w)
@@ -205,7 +205,7 @@ static void eval_wsum_hess(expr *node, const double *w)
205205
expr *x = node->left;
206206

207207
/* Zero out the work array first */
208-
memset(node->dwork, 0, x->size * sizeof(double));
208+
memset(node->work->dwork, 0, x->size * sizeof(double));
209209

210210
if (bcast->type == BROADCAST_ROW)
211211
{
@@ -214,7 +214,7 @@ static void eval_wsum_hess(expr *node, const double *w)
214214
{
215215
for (int i = 0; i < node->d1; i++)
216216
{
217-
node->dwork[j] += w[i + j * node->d1];
217+
node->work->dwork[j] += w[i + j * node->d1];
218218
}
219219
}
220220
}
@@ -225,21 +225,21 @@ static void eval_wsum_hess(expr *node, const double *w)
225225
{
226226
for (int i = 0; i < node->d1; i++)
227227
{
228-
node->dwork[i] += w[i + j * node->d1];
228+
node->work->dwork[i] += w[i + j * node->d1];
229229
}
230230
}
231231
}
232232
else
233233
{
234234
/* (1, 1) -> (m, n): scalar has m*n weights to sum */
235-
node->dwork[0] = 0.0;
235+
node->work->dwork[0] = 0.0;
236236
for (int k = 0; k < node->size; k++)
237237
{
238-
node->dwork[0] += w[k];
238+
node->work->dwork[0] += w[k];
239239
}
240240
}
241241

242-
x->eval_wsum_hess(x, node->dwork);
242+
x->eval_wsum_hess(x, node->work->dwork);
243243
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
244244
}
245245

src/affine/diag_vec.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ static void wsum_hess_init(expr *node)
100100
x->wsum_hess_init(x);
101101

102102
/* workspace for extracting diagonal weights */
103-
node->dwork = (double *) calloc(x->size, sizeof(double));
103+
node->work->dwork = (double *) calloc(x->size, sizeof(double));
104104

105105
/* Copy child's Hessian structure (diag_vec is linear, so its own Hessian is
106106
* zero) */
@@ -118,11 +118,11 @@ static void eval_wsum_hess(expr *node, const double *w)
118118
/* Extract weights from diagonal positions of w (which has n^2 elements) */
119119
for (int i = 0; i < n; i++)
120120
{
121-
node->dwork[i] = w[i * (n + 1)];
121+
node->work->dwork[i] = w[i * (n + 1)];
122122
}
123123

124124
/* Evaluate child's Hessian with extracted weights */
125-
x->eval_wsum_hess(x, node->dwork);
125+
x->eval_wsum_hess(x, node->work->dwork);
126126
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
127127
}
128128

src/affine/index.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ static void wsum_hess_init(expr *node)
104104
x->wsum_hess_init(x);
105105

106106
/* for setting weight vector to evaluate hessian of child */
107-
node->dwork = (double *) calloc(x->size, sizeof(double));
107+
node->work->dwork = (double *) calloc(x->size, sizeof(double));
108108

109109
/* in the implementation of eval_wsum_hess we evaluate the
110110
child's hessian with a weight vector that has w[i] = 0
@@ -126,23 +126,23 @@ static void eval_wsum_hess(expr *node, const double *w)
126126
if (idx->has_duplicates)
127127
{
128128
/* zero and accumulate for repeated indices */
129-
memset(node->dwork, 0, x->size * sizeof(double));
129+
memset(node->work->dwork, 0, x->size * sizeof(double));
130130
for (int i = 0; i < idx->n_idxs; i++)
131131
{
132-
node->dwork[idx->indices[i]] += w[i];
132+
node->work->dwork[idx->indices[i]] += w[i];
133133
}
134134
}
135135
else
136136
{
137137
/* direct write (no memset needed, no accumulation) */
138138
for (int i = 0; i < idx->n_idxs; i++)
139139
{
140-
node->dwork[idx->indices[i]] = w[i];
140+
node->work->dwork[idx->indices[i]] = w[i];
141141
}
142142
}
143143

144144
/* evaluate hessian of child */
145-
x->eval_wsum_hess(x, node->dwork);
145+
x->eval_wsum_hess(x, node->work->dwork);
146146
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
147147
}
148148

src/affine/sum.c

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ static void jacobian_init(expr *node)
8888

8989
/* we never have to store more than the child's nnz */
9090
node->jacobian = new_csr_matrix(node->size, node->n_vars, x->jacobian->nnz);
91-
node->iwork = malloc(MAX(node->jacobian->n, x->jacobian->nnz) * sizeof(int));
91+
node->work->iwork =
92+
malloc(MAX(node->jacobian->n, x->jacobian->nnz) * sizeof(int));
9293
snode->idx_map = malloc(x->jacobian->nnz * sizeof(int));
9394

9495
/* the idx_map array maps each nonzero entry j in x->jacobian
@@ -98,18 +99,19 @@ static void jacobian_init(expr *node)
9899

99100
if (axis == -1)
100101
{
101-
sum_all_rows_csr_fill_sparsity_and_idx_map(x->jacobian, node->jacobian,
102-
node->iwork, snode->idx_map);
102+
sum_all_rows_csr_fill_sparsity_and_idx_map(
103+
x->jacobian, node->jacobian, node->work->iwork, snode->idx_map);
103104
}
104105
else if (axis == 0)
105106
{
106107
sum_block_of_rows_csr_fill_sparsity_and_idx_map(
107-
x->jacobian, node->jacobian, x->d1, node->iwork, snode->idx_map);
108+
x->jacobian, node->jacobian, x->d1, node->work->iwork, snode->idx_map);
108109
}
109110
else if (axis == 1)
110111
{
111112
sum_evenly_spaced_rows_csr_fill_sparsity_and_idx_map(
112-
x->jacobian, node->jacobian, node->size, node->iwork, snode->idx_map);
113+
x->jacobian, node->jacobian, node->size, node->work->iwork,
114+
snode->idx_map);
113115
}
114116
}
115117

@@ -135,7 +137,7 @@ static void wsum_hess_init(expr *node)
135137

136138
/* we never have to store more than the child's nnz */
137139
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
138-
node->dwork = malloc(x->size * sizeof(double));
140+
node->work->dwork = malloc(x->size * sizeof(double));
139141

140142
/* copy sparsity pattern */
141143
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->n_vars + 1) * sizeof(int));
@@ -150,18 +152,18 @@ static void eval_wsum_hess(expr *node, const double *w)
150152

151153
if (axis == -1)
152154
{
153-
scaled_ones(node->dwork, x->size, *w);
155+
scaled_ones(node->work->dwork, x->size, *w);
154156
}
155157
else if (axis == 0)
156158
{
157-
repeat(node->dwork, w, x->d2, x->d1);
159+
repeat(node->work->dwork, w, x->d2, x->d1);
158160
}
159161
else if (axis == 1)
160162
{
161-
tile_double(node->dwork, w, x->d1, x->d2);
163+
tile_double(node->work->dwork, w, x->d1, x->d2);
162164
}
163165

164-
x->eval_wsum_hess(x, node->dwork);
166+
x->eval_wsum_hess(x, node->work->dwork);
165167

166168
/* copy values */
167169
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));

src/affine/trace.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ static void jacobian_init(expr *node)
6969
// fill sparsity pattern and idx_map
7070
// ---------------------------------------------------------------
7171
trace_expr *tnode = (trace_expr *) node;
72-
node->iwork = malloc(MAX(node->jacobian->n, total_nnz) * sizeof(int));
72+
node->work->iwork = malloc(MAX(node->jacobian->n, total_nnz) * sizeof(int));
7373

7474
/* the idx_map array maps each nonzero entry j in the original matrix A (from the
7575
selected, evenly spaced rows) to the corresponding index in the output row
@@ -78,7 +78,7 @@ static void jacobian_init(expr *node)
7878
should be accumulated. */
7979
tnode->idx_map = malloc(x->jacobian->nnz * sizeof(int));
8080
sum_spaced_rows_into_row_csr_fill_sparsity_and_idx_map(
81-
A, node->jacobian, row_spacing, node->iwork, tnode->idx_map);
81+
A, node->jacobian, row_spacing, node->work->iwork, tnode->idx_map);
8282
}
8383

8484
static void eval_jacobian(expr *node)
@@ -104,7 +104,7 @@ static void wsum_hess_init(expr *node)
104104
x->wsum_hess_init(x);
105105

106106
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
107-
node->dwork = (double *) calloc(x->size, sizeof(double));
107+
node->work->dwork = (double *) calloc(x->size, sizeof(double));
108108

109109
/* We copy over the sparsity pattern from the child. This also includes the
110110
contribution to wsum_hess of entries of the child that will always have
@@ -121,10 +121,10 @@ static void eval_wsum_hess(expr *node, const double *w)
121121
int row_spacing = x->d1 + 1;
122122
for (int i = 0; i < x->size; i += row_spacing)
123123
{
124-
node->dwork[i] = w[0];
124+
node->work->dwork[i] = w[0];
125125
}
126126

127-
x->eval_wsum_hess(x, node->dwork);
127+
x->eval_wsum_hess(x, node->work->dwork);
128128

129129
memcpy(node->wsum_hess->x, x->wsum_hess->x, sizeof(double) * x->wsum_hess->nnz);
130130
}

src/affine/transpose.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ static void wsum_hess_init(expr *node)
9999
node->wsum_hess = H;
100100

101101
/* for computing Kw where K is the commutation matrix */
102-
node->dwork = (double *) malloc(node->size * sizeof(double));
102+
node->work->dwork = (double *) malloc(node->size * sizeof(double));
103103
}
104104
static void eval_wsum_hess(expr *node, const double *w)
105105
{
@@ -112,11 +112,11 @@ static void eval_wsum_hess(expr *node, const double *w)
112112
{
113113
for (int j = 0; j < d1; ++j)
114114
{
115-
node->dwork[j * d2 + i] = w[i * d1 + j];
115+
node->work->dwork[j * d2 + i] = w[i * d1 + j];
116116
}
117117
}
118118

119-
node->left->eval_wsum_hess(node->left, node->dwork);
119+
node->left->eval_wsum_hess(node->left, node->work->dwork);
120120

121121
/* copy to this node's hessian */
122122
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,

src/bivariate/const_vector_mult.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ static void wsum_hess_init(expr *node)
8181
memcpy(node->wsum_hess->p, x->wsum_hess->p, (node->n_vars + 1) * sizeof(int));
8282
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
8383

84-
node->dwork = (double *) malloc(node->size * sizeof(double));
84+
node->work->dwork = (double *) malloc(node->size * sizeof(double));
8585
}
8686

8787
static void eval_wsum_hess(expr *node, const double *w)
@@ -92,10 +92,10 @@ static void eval_wsum_hess(expr *node, const double *w)
9292
/* scale weights w by a */
9393
for (int i = 0; i < node->size; i++)
9494
{
95-
node->dwork[i] = a[i] * w[i];
95+
node->work->dwork[i] = a[i] * w[i];
9696
}
9797

98-
x->eval_wsum_hess(x, node->dwork);
98+
x->eval_wsum_hess(x, node->work->dwork);
9999

100100
/* copy values from child to this node */
101101
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));

src/bivariate/left_matmul.c

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ static void jacobian_init(expr *node)
8888

8989
/* initialize child's jacobian and precompute sparsity of its CSC */
9090
x->jacobian_init(x);
91-
lnode->Jchild_CSC = csr_to_csc_fill_sparsity(x->jacobian, node->iwork);
91+
lnode->Jchild_CSC = csr_to_csc_fill_sparsity(x->jacobian, node->work->iwork);
9292

9393
/* precompute sparsity of this node's jacobian in CSC and CSR */
9494
lnode->J_CSC = lnode->A->block_left_mult_sparsity(lnode->A, lnode->Jchild_CSC,
@@ -106,7 +106,7 @@ static void eval_jacobian(expr *node)
106106

107107
/* evaluate child's jacobian and convert to CSC */
108108
x->eval_jacobian(x);
109-
csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->iwork);
109+
csr_to_csc_fill_values(x->jacobian, Jchild_CSC, node->work->iwork);
110110

111111
/* compute this node's jacobian: */
112112
lnode->A->block_left_mult_values(lnode->A, Jchild_CSC, J_CSC);
@@ -127,17 +127,17 @@ static void wsum_hess_init(expr *node)
127127
/* work for computing A^T w */
128128
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
129129
int dim = ((left_matmul_expr *) node)->AT->m * n_blocks;
130-
node->dwork = (double *) malloc(dim * sizeof(double));
130+
node->work->dwork = (double *) malloc(dim * sizeof(double));
131131
}
132132

133133
static void eval_wsum_hess(expr *node, const double *w)
134134
{
135135
/* compute A^T w */
136136
Matrix *AT = ((left_matmul_expr *) node)->AT;
137137
int n_blocks = ((left_matmul_expr *) node)->n_blocks;
138-
AT->block_left_mult_vec(AT, w, node->dwork, n_blocks);
138+
AT->block_left_mult_vec(AT, w, node->work->dwork, n_blocks);
139139

140-
node->left->eval_wsum_hess(node->left, node->dwork);
140+
node->left->eval_wsum_hess(node->left, node->work->dwork);
141141
memcpy(node->wsum_hess->x, node->left->wsum_hess->x,
142142
node->wsum_hess->nnz * sizeof(double));
143143
}
@@ -180,13 +180,14 @@ expr *new_left_matmul(expr *u, const CSR_Matrix *A)
180180
(requiring size node->n_vars) and for transposing A (requiring size A->n).
181181
csc_to_csr_work is used for converting J_CSC to CSR (requiring
182182
node->size) */
183-
node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int));
183+
node->work->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int));
184184
lnode->csc_to_csr_work = (int *) malloc(node->size * sizeof(int));
185185
lnode->n_blocks = n_blocks;
186186

187187
/* store A and AT */
188188
lnode->A = new_sparse_matrix(A);
189-
lnode->AT = sparse_matrix_trans((const Sparse_Matrix *) lnode->A, node->iwork);
189+
lnode->AT =
190+
sparse_matrix_trans((const Sparse_Matrix *) lnode->A, node->work->iwork);
190191

191192
return node;
192193
}
@@ -220,7 +221,7 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data)
220221
node->left = u;
221222
expr_retain(u);
222223

223-
node->iwork = (int *) malloc(MAX(n, node->n_vars) * sizeof(int));
224+
node->work->iwork = (int *) malloc(MAX(n, node->n_vars) * sizeof(int));
224225
lnode->csc_to_csr_work = (int *) malloc(node->size * sizeof(int));
225226
lnode->n_blocks = n_blocks;
226227

0 commit comments

Comments
 (0)