Skip to content

Commit 7f671cb

Browse files
authored
copy sparsity function (#61)
* add work struct for expression * run formatter * copy sparsity function
1 parent 28f61f9 commit 7f671cb

15 files changed

Lines changed: 28 additions & 70 deletions

File tree

include/utils/CSR_Matrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ typedef struct CSR_Matrix
2828
/* constructors and destructors */
2929
CSR_Matrix *new_csr_matrix(int m, int n, int nnz);
3030
CSR_Matrix *new_csr(const CSR_Matrix *A);
31+
CSR_Matrix *new_csr_copy_sparsity(const CSR_Matrix *A);
3132
void free_csr_matrix(CSR_Matrix *matrix);
3233
void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C);
3334

src/affine/broadcast.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,7 @@ static void wsum_hess_init(expr *node)
191191
x->wsum_hess_init(x);
192192

193193
/* Same sparsity as child - weights get summed */
194-
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
195-
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->wsum_hess->m + 1) * sizeof(int));
196-
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
194+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
197195

198196
/* allocate space for weight vector */
199197
node->work->dwork = malloc(node->size * sizeof(double));

src/affine/diag_vec.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ static void wsum_hess_init(expr *node)
105105
/* Copy child's Hessian structure (diag_vec is linear, so its own Hessian is
106106
* zero) */
107107
CSR_Matrix *Hx = x->wsum_hess;
108-
node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz);
109-
memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int));
110-
memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int));
108+
node->wsum_hess = new_csr_copy_sparsity(Hx);
111109
}
112110

113111
static void eval_wsum_hess(expr *node, const double *w)

src/affine/index.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ static void wsum_hess_init(expr *node)
113113
structural zeros, but we do not try to exploit that sparsity
114114
right now. */
115115
CSR_Matrix *Hx = x->wsum_hess;
116-
node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz);
117-
memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int));
118-
memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int));
116+
node->wsum_hess = new_csr_copy_sparsity(Hx);
119117
}
120118

121119
static void eval_wsum_hess(expr *node, const double *w)

src/affine/neg.c

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,7 @@ static void jacobian_init(expr *node)
3939
x->jacobian_init(x);
4040

4141
/* same sparsity pattern as child */
42-
node->jacobian = new_csr_matrix(node->size, node->n_vars, x->jacobian->nnz);
43-
44-
/* copy row pointers and column indices (sparsity pattern is constant) */
45-
memcpy(node->jacobian->p, x->jacobian->p, (x->jacobian->m + 1) * sizeof(int));
46-
memcpy(node->jacobian->i, x->jacobian->i, x->jacobian->nnz * sizeof(int));
42+
node->jacobian = new_csr_copy_sparsity(x->jacobian);
4743
}
4844

4945
static void eval_jacobian(expr *node)
@@ -68,11 +64,7 @@ static void wsum_hess_init(expr *node)
6864

6965
/* same sparsity pattern as child */
7066
CSR_Matrix *child_hess = x->wsum_hess;
71-
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, child_hess->nnz);
72-
73-
/* copy row pointers and column indices (sparsity pattern is constant) */
74-
memcpy(node->wsum_hess->p, child_hess->p, (child_hess->m + 1) * sizeof(int));
75-
memcpy(node->wsum_hess->i, child_hess->i, child_hess->nnz * sizeof(int));
67+
node->wsum_hess = new_csr_copy_sparsity(child_hess);
7668
}
7769

7870
static void eval_wsum_hess(expr *node, const double *w)

src/affine/promote.c

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,7 @@ static void wsum_hess_init(expr *node)
7878

7979
/* same sparsity as child since we're summing weights */
8080
CSR_Matrix *child_hess = node->left->wsum_hess;
81-
node->wsum_hess = new_csr_matrix(child_hess->m, child_hess->n, child_hess->nnz);
82-
83-
/* copy sparsity pattern */
84-
memcpy(node->wsum_hess->p, child_hess->p, (child_hess->m + 1) * sizeof(int));
85-
memcpy(node->wsum_hess->i, child_hess->i, child_hess->nnz * sizeof(int));
86-
node->wsum_hess->nnz = child_hess->nnz;
81+
node->wsum_hess = new_csr_copy_sparsity(child_hess);
8782
}
8883

8984
static void eval_wsum_hess(expr *node, const double *w)

src/affine/reshape.c

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@ static void jacobian_init(expr *node)
3535
{
3636
expr *x = node->left;
3737
x->jacobian_init(x);
38-
node->jacobian = new_csr_matrix(node->size, node->n_vars, x->jacobian->nnz);
39-
CSR_Matrix *jac = node->jacobian;
40-
memcpy(jac->p, x->jacobian->p, (x->size + 1) * sizeof(int));
41-
memcpy(jac->i, x->jacobian->i, x->jacobian->nnz * sizeof(int));
38+
node->jacobian = new_csr_copy_sparsity(x->jacobian);
4239
}
4340

4441
static void eval_jacobian(expr *node)
@@ -52,10 +49,7 @@ static void wsum_hess_init(expr *node)
5249
{
5350
expr *x = node->left;
5451
x->wsum_hess_init(x);
55-
node->wsum_hess =
56-
new_csr_matrix(x->wsum_hess->m, x->wsum_hess->n, x->wsum_hess->nnz);
57-
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->wsum_hess->m + 1) * sizeof(int));
58-
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
52+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
5953
}
6054

6155
static void eval_wsum_hess(expr *node, const double *w)

src/affine/sum.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,8 @@ static void wsum_hess_init(expr *node)
136136
x->wsum_hess_init(x);
137137

138138
/* we never have to store more than the child's nnz */
139-
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
139+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
140140
node->work->dwork = malloc(x->size * sizeof(double));
141-
142-
/* copy sparsity pattern */
143-
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->n_vars + 1) * sizeof(int));
144-
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
145141
}
146142

147143
static void eval_wsum_hess(expr *node, const double *w)

src/affine/trace.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,14 @@ static void wsum_hess_init(expr *node)
102102

103103
/* initialize child's hessian */
104104
x->wsum_hess_init(x);
105-
106-
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, x->wsum_hess->nnz);
105+
107106
node->work->dwork = (double *) calloc(x->size, sizeof(double));
108107

109108
/* We copy over the sparsity pattern from the child. This also includes the
110109
contribution to wsum_hess of entries of the child that will always have
111110
zero weight in eval_wsum_hess. We do this for simplicity. But the Hessian
112111
can for sure be made more sophisticated. */
113-
memcpy(node->wsum_hess->p, x->wsum_hess->p, (x->n_vars + 1) * sizeof(int));
114-
memcpy(node->wsum_hess->i, x->wsum_hess->i, x->wsum_hess->nnz * sizeof(int));
112+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
115113
}
116114

117115
static void eval_wsum_hess(expr *node, const double *w)

src/affine/transpose.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,7 @@ static void wsum_hess_init(expr *node)
9292
x->wsum_hess_init(x);
9393

9494
/* same sparsity pattern as child */
95-
CSR_Matrix *H = node->wsum_hess;
96-
H = new_csr_matrix(x->wsum_hess->m, node->n_vars, x->wsum_hess->nnz);
97-
memcpy(H->p, x->wsum_hess->p, (H->m + 1) * sizeof(int));
98-
memcpy(H->i, x->wsum_hess->i, H->nnz * sizeof(int));
99-
node->wsum_hess = H;
95+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
10096

10197
/* for computing Kw where K is the commutation matrix */
10298
node->work->dwork = (double *) malloc(node->size * sizeof(double));

0 commit comments

Comments
 (0)