Skip to content

Commit 3cb6652

Browse files
committed
more infrastructure preparation for hess_vec
1 parent b611912 commit 3cb6652

9 files changed

Lines changed: 149 additions & 9 deletions

File tree

include/expr.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef EXPR_H
22
#define EXPR_H
33

4+
#include "utils/CSC_Matrix.h"
45
#include "utils/CSR_Matrix.h"
56
#include <stdbool.h>
67
#include <stddef.h>
@@ -18,6 +19,8 @@ typedef void (*eval_jacobian_fn)(struct expr *node);
1819
typedef void (*eval_local_jacobian_fn)(struct expr *node, double *out);
1920
typedef bool (*is_affine_fn)(struct expr *node);
2021

22+
/* TODO: implement proper polymorphism */
23+
2124
/* Expression node structure */
2225
typedef struct expr
2326
{
@@ -36,7 +39,8 @@ typedef struct expr
3639
int *iwork;
3740
struct int_double_pair *int_double_pairs; /* for sorting jacobian entries */
3841
int p; /* power of power expression */
39-
int axis; /* axis for sum or similar operations */
42+
int axis; /* axis for sum or similar operations */
43+
CSR_Matrix *Q; /* Q for quad_form */
4044

4145
// ------------------------------------------------------------------------
4246
// forward pass related quantities
@@ -48,13 +52,16 @@ typedef struct expr
4852
// jacobian related quantities
4953
// ------------------------------------------------------------------------
5054
CSR_Matrix *jacobian;
51-
CSR_Matrix *Q;
5255
CSR_Matrix *CSR_work;
5356
jacobian_init_fn jacobian_init;
5457
eval_jacobian_fn eval_jacobian;
5558
eval_local_jacobian_fn eval_local_jacobian;
5659
is_affine_fn is_affine;
5760

61+
// for every linear operator we store A in CSR and CSC
62+
CSC_Matrix *A_csc;
63+
CSR_Matrix *A_csr;
64+
5865
} expr;
5966

6067
expr *new_expr(int d1, int d2, int n_vars);

include/utils/CSC_Matrix.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ CSC_Matrix *new_csc_matrix(int m, int n, int nnz);
2929
/* Free a CSC matrix */
3030
void free_csc_matrix(CSC_Matrix *matrix);
3131

32-
/* Allocate sparsity pattern for C = A^T D A or C = A^T A
33-
*/
32+
CSC_Matrix *csr_to_csc(const CSR_Matrix *A);
33+
34+
/* Allocate sparsity pattern for C = A^T D A for diagonal D */
3435
CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
3536

3637
/* Compute values for C = A^T D A

src/affine/linear_op.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ expr *new_linear(expr *u, const CSR_Matrix *A)
2727
node->is_affine = is_affine;
2828

2929
/* allocate jacobian and copy A into it */
30+
// TODO: this should eventually be removed
3031
node->jacobian = new_csr_matrix(A->m, A->n, A->nnz);
3132
copy_csr_matrix(A, node->jacobian);
3233

34+
node->A_csr = new_csr_matrix(A->m, A->n, A->nnz);
35+
copy_csr_matrix(A, node->A_csr);
36+
node->A_csc = csr_to_csc(A);
37+
3338
return node;
3439
}

src/elementwise_univariate/common.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void eval_jacobian_elementwise(expr *node)
3636
else
3737
{
3838
node->eval_local_jacobian(node, node->dwork);
39-
diag_csr_mult(node->dwork, child->jacobian, node->jacobian);
39+
diag_csr_mult(node->dwork, child->A_csr, node->jacobian);
4040
}
4141
}
4242

src/expr.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ void free_expr(expr *node)
5959
free(node->value);
6060
free_csr_matrix(node->jacobian);
6161
free_csr_matrix(node->CSR_work);
62+
free_csr_matrix(node->A_csr);
63+
free_csc_matrix(node->A_csc);
6264
free(node->dwork);
6365
free(node->iwork);
6466
free_int_double_pair_array(node->int_double_pairs);

src/utils/CSC_Matrix.c

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,6 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A)
8484

8585
/* Allocate C and symmetrize it */
8686
CSR_Matrix *C = new_csr_matrix(n, n, nnz);
87-
88-
/* TODO: do we need to symmetrize here? If we are a bit careful with symmetry
89-
throughout the implementation we can skip this step. */
9087
symmetrize_csr(Cp, Ci->data, n, C);
9188

9289
/* free workspace */
@@ -152,3 +149,49 @@ void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C)
152149
}
153150
}
154151
}
152+
153+
CSC_Matrix *csr_to_csc(const CSR_Matrix *A)
154+
{
155+
CSC_Matrix *C = new_csc_matrix(A->m, A->n, A->nnz);
156+
157+
int i, j, start;
158+
int *count = malloc(A->n * sizeof(int));
159+
160+
memset(count, 0, A->n * sizeof(int));
161+
162+
// -------------------------------------------------------------------
163+
// compute nnz in each column of A
164+
// -------------------------------------------------------------------
165+
for (i = 0; i < A->m; ++i)
166+
{
167+
for (j = A->p[i]; j < A->p[i + 1]; ++j)
168+
{
169+
count[A->i[j]]++;
170+
}
171+
}
172+
173+
// ------------------------------------------------------------------
174+
// compute column pointers
175+
// ------------------------------------------------------------------
176+
C->p[0] = 0;
177+
for (i = 0; i < A->n; ++i)
178+
{
179+
C->p[i + 1] = C->p[i] + count[i];
180+
count[i] = C->p[i];
181+
}
182+
183+
// ------------------------------------------------------------------
184+
// fill matrix
185+
// ------------------------------------------------------------------
186+
for (i = 0; i < A->m; ++i)
187+
{
188+
for (j = A->p[i]; j < A->p[i + 1]; ++j)
189+
{
190+
C->x[count[A->i[j]]] = A->x[j];
191+
C->i[count[A->i[j]]] = i;
192+
count[A->i[j]]++;
193+
}
194+
}
195+
196+
return C;
197+
}

src/utils/CSR_Matrix.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stdio.h>
66
#include <stdlib.h>
77
#include <string.h>
8+
89
CSR_Matrix *new_csr_matrix(int m, int n, int nnz)
910
{
1011
CSR_Matrix *matrix = (CSR_Matrix *) malloc(sizeof(CSR_Matrix));
@@ -423,7 +424,7 @@ CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork)
423424
}
424425

425426
// ------------------------------------------------------------------
426-
// fill transposed matrix (this is a bottleneck)
427+
// fill transposed matrix
427428
// ------------------------------------------------------------------
428429
for (i = 0; i < A->m; ++i)
429430
{

tests/all_tests.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ int main(void)
7272
mu_run_test(test_csr_sum, tests_run);
7373
mu_run_test(test_csr_sum2, tests_run);
7474
mu_run_test(test_transpose, tests_run);
75+
mu_run_test(test_csr_to_csc1, tests_run);
76+
mu_run_test(test_csr_to_csc2, tests_run);
7577
mu_run_test(test_csr_vecmat_values_sparse, tests_run);
7678
mu_run_test(test_sum_all_rows_csr, tests_run);
7779
mu_run_test(test_sum_block_of_rows_csr, tests_run);

tests/utils/test_csc_matrix.h

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,85 @@
77
#include "test_helpers.h"
88
#include "utils/CSC_Matrix.h"
99

10+
const char *test_csr_to_csc1()
11+
{
12+
CSR_Matrix *A = new_csr_matrix(4, 5, 5);
13+
double Ax[5] = {1.0, 1.0, 3.0, 2.0, 4.0};
14+
int Ai[5] = {0, 4, 1, 0, 1};
15+
int Ap[5] = {0, 2, 3, 4, 5};
16+
memcpy(A->x, Ax, 5 * sizeof(double));
17+
memcpy(A->i, Ai, 5 * sizeof(int));
18+
memcpy(A->p, Ap, 5 * sizeof(int));
19+
20+
CSC_Matrix *C = csr_to_csc(A);
21+
22+
double Cx_correct[5] = {1.0, 2.0, 3.0, 4.0, 1.0};
23+
int Ci_correct[5] = {0, 2, 1, 3, 0};
24+
int Cp_correct[6] = {0, 2, 4, 4, 4, 5};
25+
26+
mu_assert("C vals incorrect", cmp_double_array(C->x, Cx_correct, 5));
27+
mu_assert("C rows incorrect", cmp_int_array(C->i, Ci_correct, 5));
28+
mu_assert("C cols incorrect", cmp_int_array(C->p, Cp_correct, 6));
29+
30+
free_csr_matrix(A);
31+
free_csc_matrix(C);
32+
33+
return 0;
34+
}
35+
36+
const char *test_csr_to_csc2()
37+
{
38+
CSR_Matrix *A = new_csr_matrix(20, 30, 120);
39+
double Ax[120] = {9, 6, 5, 9, 7, 3, 8, 2, 6, 1, 3, 9, 2, 8, 9, 1, 4, 9, 2, 1,
40+
3, 4, 2, 8, 6, 2, 9, 7, 3, 8, 3, 7, 9, 2, 2, 2, 5, 5, 3, 5,
41+
1, 6, 7, 2, 7, 3, 3, 7, 3, 5, 4, 7, 7, 3, 6, 3, 6, 1, 8, 8,
42+
3, 2, 2, 3, 4, 5, 5, 5, 8, 3, 5, 3, 7, 5, 1, 4, 9, 6, 6, 7,
43+
4, 6, 8, 2, 7, 3, 5, 3, 3, 4, 7, 3, 6, 4, 2, 1, 1, 5, 5, 8,
44+
1, 9, 5, 2, 3, 8, 5, 8, 4, 5, 5, 6, 9, 6, 4, 4, 1, 8, 9, 8};
45+
int Ai[120] = {1, 2, 3, 19, 21, 22, 9, 10, 19, 20, 25, 0, 6, 8, 9,
46+
12, 15, 19, 20, 21, 26, 2, 5, 6, 8, 12, 14, 16, 19, 27,
47+
8, 11, 13, 15, 25, 26, 27, 10, 12, 19, 22, 23, 24, 25, 28,
48+
1, 11, 12, 15, 18, 24, 13, 22, 2, 5, 6, 9, 18, 24, 3,
49+
6, 8, 22, 20, 27, 7, 9, 17, 26, 29, 0, 1, 11, 13, 15,
50+
16, 18, 23, 24, 4, 5, 8, 9, 16, 20, 23, 4, 6, 14, 15,
51+
24, 8, 9, 11, 12, 20, 22, 29, 2, 5, 12, 14, 15, 19, 21,
52+
10, 19, 27, 1, 5, 6, 9, 11, 15, 21, 26, 3, 15, 26, 27};
53+
int Ap[21] = {0, 6, 11, 21, 30, 37, 45, 51, 53, 59, 63,
54+
65, 70, 79, 86, 91, 98, 105, 108, 116, 120};
55+
memcpy(A->x, Ax, 120 * sizeof(double));
56+
memcpy(A->i, Ai, 120 * sizeof(int));
57+
memcpy(A->p, Ap, 21 * sizeof(int));
58+
59+
CSC_Matrix *C = csr_to_csc(A);
60+
61+
double Cx_correct[120] = {
62+
9, 5, 9, 3, 3, 4, 6, 4, 3, 5, 5, 8, 1, 7, 5, 2, 6, 4, 8, 5, 2, 8, 3, 3,
63+
3, 5, 5, 8, 6, 3, 2, 6, 3, 8, 9, 6, 5, 8, 6, 6, 2, 5, 8, 7, 3, 7, 4, 9,
64+
1, 2, 3, 7, 2, 1, 9, 7, 5, 9, 3, 9, 4, 2, 3, 1, 4, 5, 6, 8, 7, 4, 2, 5,
65+
5, 1, 9, 9, 6, 9, 3, 5, 2, 5, 1, 2, 3, 7, 1, 7, 1, 3, 4, 3, 1, 7, 2, 1,
66+
6, 6, 3, 7, 4, 8, 6, 7, 3, 2, 2, 3, 2, 8, 4, 9, 8, 5, 4, 8, 8, 7, 3, 5};
67+
int Ci_correct[120] = {
68+
2, 12, 0, 6, 12, 18, 0, 3, 8, 16, 0, 9, 19, 13, 14, 3, 8, 13,
69+
16, 18, 2, 3, 8, 9, 14, 18, 11, 2, 3, 4, 9, 13, 15, 1, 2, 8,
70+
11, 13, 15, 18, 1, 5, 17, 4, 6, 12, 15, 18, 2, 3, 5, 6, 15, 16,
71+
4, 7, 12, 3, 14, 16, 2, 4, 6, 12, 14, 16, 18, 19, 3, 12, 13, 11,
72+
6, 8, 12, 0, 1, 2, 3, 5, 16, 17, 1, 2, 10, 13, 15, 0, 2, 16,
73+
18, 0, 5, 7, 9, 15, 5, 12, 13, 5, 6, 8, 12, 14, 1, 4, 5, 2,
74+
4, 11, 18, 19, 3, 4, 10, 17, 19, 5, 11, 15};
75+
int Cp_correct[31] = {0, 2, 6, 10, 13, 15, 20, 26, 27, 33, 40,
76+
43, 48, 54, 57, 60, 68, 71, 72, 75, 82, 87,
77+
91, 96, 99, 104, 107, 112, 117, 118, 120};
78+
79+
mu_assert("C vals incorrect", cmp_double_array(C->x, Cx_correct, 120));
80+
mu_assert("C rows incorrect", cmp_int_array(C->i, Ci_correct, 120));
81+
mu_assert("C cols incorrect", cmp_int_array(C->p, Cp_correct, 31));
82+
83+
free_csr_matrix(A);
84+
free_csc_matrix(C);
85+
86+
return 0;
87+
}
88+
1089
/* Test ATA_alloc with a simple 3x3 example
1190
* A is 4x3 (4 rows, 3 columns):
1291
* [x 0 x]

0 commit comments

Comments
 (0)