Skip to content

Commit 95c379c

Browse files
committed
new approach
1 parent 4fa5cdb commit 95c379c

71 files changed

Lines changed: 329 additions & 355 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

include/expr.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
#include "utils/CSC_Matrix.h"
2222
#include "utils/CSR_Matrix.h"
2323
#include <stdbool.h>
24-
#include <stddef.h> /* size_t */
24+
#include <stddef.h> /* size_t */
2525
#include <string.h>
2626

2727
#define JAC_IDXS_NOT_SET -1

include/utils/CSC_Matrix.h

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,18 +24,19 @@ typedef struct CSC_Matrix
2424
int nnz;
2525
} CSC_Matrix;
2626

27-
/* constructor and destructor */
28-
CSC_Matrix *new_csc_matrix(int m, int n, int nnz);
27+
/* constructor and destructor.
28+
If mem is non-NULL, *mem is incremented by the bytes allocated. */
29+
CSC_Matrix *new_csc_matrix(int m, int n, int nnz, size_t *mem);
2930
void free_csc_matrix(CSC_Matrix *matrix);
3031

3132
/* Fill sparsity of C = A^T D A for diagonal D */
32-
CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
33+
CSR_Matrix *ATA_alloc(const CSC_Matrix *A, size_t *mem);
3334

3435
/* Fill sparsity of C = B^T D A for diagonal D */
35-
CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B);
36+
CSR_Matrix *BTA_alloc(const CSC_Matrix *A, const CSC_Matrix *B, size_t *mem);
3637

3738
/* Fill sparsity of C = BA, where B is symmetric. */
38-
CSC_Matrix *symBA_alloc(const CSR_Matrix *B, const CSC_Matrix *A);
39+
CSC_Matrix *symBA_alloc(const CSR_Matrix *B, const CSC_Matrix *A, size_t *mem);
3940

4041
/* Compute values for C = A^T D A (null d corresponds to D as identity) */
4142
void ATDA_fill_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
@@ -54,11 +55,11 @@ void yTA_fill_values(const CSC_Matrix *A, const double *x, CSR_Matrix *C);
5455
int count_nonzero_cols_csc(const CSC_Matrix *A);
5556

5657
/* convert from CSR to CSC format */
57-
CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork);
58+
CSC_Matrix *csr_to_csc_alloc(const CSR_Matrix *A, int *iwork, size_t *mem);
5859
void csr_to_csc_fill_values(const CSR_Matrix *A, CSC_Matrix *C, int *iwork);
5960

6061
/* convert from CSC to CSR format */
61-
CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork);
62+
CSR_Matrix *csc_to_csr_alloc(const CSC_Matrix *A, int *iwork, size_t *mem);
6263
void csc_to_csr_fill_values(const CSC_Matrix *A, CSR_Matrix *C, int *iwork);
6364

6465
/* Returns total bytes used by p, i, x arrays (0 if A is NULL) */

include/utils/CSR_Matrix.h

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,16 +23,17 @@ typedef struct CSR_Matrix
2323
int nnz;
2424
} CSR_Matrix;
2525

26-
/* constructors and destructors */
27-
CSR_Matrix *new_csr_matrix(int m, int n, int nnz);
28-
CSR_Matrix *new_csr(const CSR_Matrix *A);
29-
CSR_Matrix *new_csr_copy_sparsity(const CSR_Matrix *A);
26+
/* constructors and destructors.
27+
If mem is non-NULL, *mem is incremented by the bytes allocated. */
28+
CSR_Matrix *new_csr_matrix(int m, int n, int nnz, size_t *mem);
29+
CSR_Matrix *new_csr(const CSR_Matrix *A, size_t *mem);
30+
CSR_Matrix *new_csr_copy_sparsity(const CSR_Matrix *A, size_t *mem);
3031
void free_csr_matrix(CSR_Matrix *matrix);
3132
void copy_csr_matrix(const CSR_Matrix *A, CSR_Matrix *C);
3233

3334
/* transpose functionality (iwork must be of size A->n) */
3435
CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
35-
CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork);
36+
CSR_Matrix *AT_alloc(const CSR_Matrix *A, int *iwork, size_t *mem);
3637
void AT_fill_values(const CSR_Matrix *A, CSR_Matrix *AT, int *iwork);
3738

3839
/* computes dense y = Ax */

include/utils/CSR_sum.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ void sum_spaced_rows_into_row_csr_alloc(const CSR_Matrix *A, CSR_Matrix *C,
4646
/* Compute sparsity pattern of out = A + B + C + D */
4747
CSR_Matrix *sum_4_csr_alloc(const CSR_Matrix *A, const CSR_Matrix *B,
4848
const CSR_Matrix *C, const CSR_Matrix *D,
49-
int *idx_maps[4]);
49+
int *idx_maps[4], size_t *mem);
5050
// ------------------------------------------------------------------------------------
5151

5252
/* Accumulates values from A according to map. Must memset to zero before calling. */

include/utils/linalg_dense_sparse_matmuls.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,13 @@ void I_kron_A_fill_values(const Matrix *A, const CSC_Matrix *J, CSC_Matrix *C);
1313

1414
/* Sparsity and values of C = (Y^T kron I_m) @ J where Y is k x n, J is (m*k) x p,
1515
and C is (m*n) x p. Y is given in column-major dense format. */
16-
CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J);
16+
CSR_Matrix *YT_kron_I_alloc(int m, int k, int n, const CSC_Matrix *J, size_t *mem);
1717
void YT_kron_I_fill_values(int m, int k, int n, const double *Y, const CSC_Matrix *J,
1818
CSR_Matrix *C);
1919

2020
/* Sparsity and values of C = (I_n kron X) @ J where X is m x k (col-major dense),
2121
J is (k*n) x p, and C is (m*n) x p. */
22-
CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J);
22+
CSR_Matrix *I_kron_X_alloc(int m, int k, int n, const CSC_Matrix *J, size_t *mem);
2323
void I_kron_X_fill_values(int m, int k, int n, const double *X, const CSC_Matrix *J,
2424
CSR_Matrix *C);
2525

include/utils/linalg_sparse_matmuls.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ void csr_csc_matmul_fill_values(const CSR_Matrix *A, const CSC_Matrix *B,
3636
/* C = A @ B where A is CSR, B is CSC. Result C is CSR.
3737
* Allocates and precomputes sparsity pattern. No workspace required.
3838
*/
39-
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B);
39+
CSR_Matrix *csr_csc_matmul_alloc(const CSR_Matrix *A, const CSC_Matrix *B,
40+
size_t *mem);
4041

4142
#endif /* LINALG_H */

src/atoms/affine/add.c

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,8 +42,8 @@ static void jacobian_init_impl(expr *node)
4242

4343
/* we never have to store more than the sum of children's nnz */
4444
int nnz_max = node->left->jacobian->nnz + node->right->jacobian->nnz;
45-
node->jacobian = new_csr_matrix(node->size, node->n_vars, nnz_max);
46-
node->memory_bytes += csr_memory_bytes(node->jacobian);
45+
node->jacobian =
46+
new_csr_matrix(node->size, node->n_vars, nnz_max, &node->memory_bytes);
4747

4848
/* fill sparsity pattern */
4949
sum_csr_alloc(node->left->jacobian, node->right->jacobian, node->jacobian);
@@ -67,8 +67,8 @@ static void wsum_hess_init_impl(expr *node)
6767

6868
/* we never have to store more than the sum of children's nnz */
6969
int nnz_max = node->left->wsum_hess->nnz + node->right->wsum_hess->nnz;
70-
node->wsum_hess = new_csr_matrix(node->n_vars, node->n_vars, nnz_max);
71-
node->memory_bytes += csr_memory_bytes(node->wsum_hess);
70+
node->wsum_hess =
71+
new_csr_matrix(node->n_vars, node->n_vars, nnz_max, &node->memory_bytes);
7272

7373
/* fill sparsity pattern of hessian */
7474
sum_csr_alloc(node->left->wsum_hess, node->right->wsum_hess, node->wsum_hess);

src/atoms/affine/broadcast.c

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -92,8 +92,8 @@ static void jacobian_init_impl(expr *node)
9292
total_nnz = x->jacobian->nnz * node->size;
9393
}
9494

95-
node->jacobian = new_csr_matrix(node->size, node->n_vars, total_nnz);
96-
node->memory_bytes += csr_memory_bytes(node->jacobian);
95+
node->jacobian =
96+
new_csr_matrix(node->size, node->n_vars, total_nnz, &node->memory_bytes);
9797

9898
// ---------------------------------------------------------------------
9999
// fill sparsity pattern
@@ -192,8 +192,7 @@ static void wsum_hess_init_impl(expr *node)
192192
wsum_hess_init(x);
193193

194194
/* Same sparsity as child - weights get summed */
195-
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
196-
node->memory_bytes += csr_memory_bytes(node->wsum_hess);
195+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess, &node->memory_bytes);
197196

198197
/* allocate space for weight vector */
199198
node->work->dwork = malloc(node->size * sizeof(double));

src/atoms/affine/const_scalar_mult.c

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,7 @@ static void jacobian_init_impl(expr *node)
4747
jacobian_init(x);
4848

4949
/* same sparsity as child */
50-
node->jacobian = new_csr_copy_sparsity(x->jacobian);
51-
node->memory_bytes += csr_memory_bytes(node->jacobian);
50+
node->jacobian = new_csr_copy_sparsity(x->jacobian, &node->memory_bytes);
5251
}
5352

5453
static void eval_jacobian(expr *node)
@@ -74,8 +73,7 @@ static void wsum_hess_init_impl(expr *node)
7473
wsum_hess_init(x);
7574

7675
/* same sparsity as child */
77-
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
78-
node->memory_bytes += csr_memory_bytes(node->wsum_hess);
76+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess, &node->memory_bytes);
7977
}
8078

8179
static void eval_wsum_hess(expr *node, const double *w)

src/atoms/affine/const_vector_mult.c

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,7 @@ static void jacobian_init_impl(expr *node)
4646
jacobian_init(x);
4747

4848
/* same sparsity as child */
49-
node->jacobian = new_csr_copy_sparsity(x->jacobian);
50-
node->memory_bytes += csr_memory_bytes(node->jacobian);
49+
node->jacobian = new_csr_copy_sparsity(x->jacobian, &node->memory_bytes);
5150
}
5251

5352
static void eval_jacobian(expr *node)
@@ -76,8 +75,7 @@ static void wsum_hess_init_impl(expr *node)
7675
wsum_hess_init(x);
7776

7877
/* same sparsity as child */
79-
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess);
80-
node->memory_bytes += csr_memory_bytes(node->wsum_hess);
78+
node->wsum_hess = new_csr_copy_sparsity(x->wsum_hess, &node->memory_bytes);
8179

8280
node->work->dwork = (double *) malloc(node->size * sizeof(double));
8381
node->memory_bytes += node->size * sizeof(double);

0 commit comments

Comments
 (0)