Skip to content

Commit b611912

Browse files
committed
more infrastructure prep for hessian
1 parent 72b5698 commit b611912

8 files changed

Lines changed: 244 additions & 0 deletions

File tree

include/utils/CSC_Matrix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,9 @@ void free_csc_matrix(CSC_Matrix *matrix);
3333
*/
3434
CSR_Matrix *ATA_alloc(const CSC_Matrix *A);
3535

36+
/* Compute values for C = A^T D A
37+
* C must have precomputed sparsity pattern
38+
*/
39+
void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C);
40+
3641
#endif /* CSC_MATRIX_H */

include/utils/CSR_Matrix.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ int count_nonzero_cols(const CSR_Matrix *A, bool *col_nz);
8080
/* inserts 'idx' into array 'arr' in sorted order, and moves the other elements */
8181
void insert_idx(int idx, int *arr, int len);
8282

83+
double csr_get_value(const CSR_Matrix *A, int row, int col);
84+
8385
CSR_Matrix *transpose(const CSR_Matrix *A, int *iwork);
8486

8587
/* Expand symmetric CSR matrix A to full matrix C. A is assumed to store

src/utils/CSC_Matrix.c

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,60 @@ CSR_Matrix *ATA_alloc(const CSC_Matrix *A)
9595

9696
return C;
9797
}
98+
99+
static inline double sparse_wdot(const double *a_x, const int *a_i, int a_nnz,
100+
const double *b_x, const int *b_i, int b_nnz,
101+
const double *d)
102+
{
103+
int ii = 0;
104+
int jj = 0;
105+
double sum = 0.0;
106+
107+
while (ii < a_nnz && jj < b_nnz)
108+
{
109+
if (a_i[ii] == b_i[jj])
110+
{
111+
sum += a_x[ii] * b_x[jj] * d[a_i[ii]];
112+
ii++;
113+
jj++;
114+
}
115+
else if (a_i[ii] < b_i[jj])
116+
{
117+
ii++;
118+
}
119+
else
120+
{
121+
jj++;
122+
}
123+
}
124+
125+
return sum;
126+
}
127+
128+
void ATDA_values(const CSC_Matrix *A, const double *d, CSR_Matrix *C)
129+
{
130+
int i, j, ii, jj;
131+
for (i = 0; i < C->m; i++)
132+
{
133+
for (jj = C->p[i]; jj < C->p[i + 1]; jj++)
134+
{
135+
j = C->i[jj];
136+
137+
if (j < i)
138+
{
139+
C->x[jj] = csr_get_value(C, j, i);
140+
}
141+
else
142+
{
143+
int nnz_ai = A->p[i + 1] - A->p[i];
144+
int nnz_aj = A->p[j + 1] - A->p[j];
145+
146+
/* compute Cij = weighted inner product of column i and column j */
147+
double sum = sparse_wdot(A->x + A->p[i], A->i + A->p[i], nnz_ai,
148+
A->x + A->p[j], A->i + A->p[j], nnz_aj, d);
149+
150+
C->x[jj] = sum;
151+
}
152+
}
153+
}
154+
}

src/utils/CSR_Matrix.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,18 @@ void csr_matvec_fill_values(const CSR_Matrix *AT, const double *z, CSR_Matrix *C
475475
// }
476476
// }
477477

478+
double csr_get_value(const CSR_Matrix *A, int row, int col)
479+
{
480+
for (int j = A->p[row]; j < A->p[row + 1]; j++)
481+
{
482+
if (A->i[j] == col)
483+
{
484+
return A->x[j];
485+
}
486+
}
487+
return 0.0;
488+
}
489+
478490
void symmetrize_csr(const int *Ap, const int *Ai, int m, CSR_Matrix *C)
479491
{
480492
int i, j, col;

tests/all_tests.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ int main(void)
7979
mu_run_test(test_ATA_alloc_simple, tests_run);
8080
mu_run_test(test_ATA_alloc_diagonal_like, tests_run);
8181
mu_run_test(test_ATA_alloc_random, tests_run);
82+
mu_run_test(test_ATA_alloc_random2, tests_run);
8283

8384
printf("\n=== All %d tests passed ===\n", tests_run);
8485

tests/utils/python_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import scipy.sparse as sp
2+
import numpy as np
3+
4+
np.random.seed(42)
5+
m = 10
6+
n = 15
7+
# density = 0.1
8+
# A = sp.random(m, n, density=density, format="csc", dtype=float)
9+
10+
Ap = np.array([0, 1, 1, 1, 1, 4, 5, 6, 7, 8, 9, 11, 11, 11, 13, 15])
11+
Ai = np.array([5, 0, 6, 9, 0, 5, 1, 3, 6, 0, 6, 3, 6, 6, 8])
12+
Ax = np.random.randint(1, 10, size=len(Ai))
13+
d = np.random.randint(1, 10, size=m)
14+
A = sp.csc_matrix((Ax, Ai, Ap), shape=(m, n))
15+
16+
17+
C = A.T @ sp.diags(d) @ A
18+
19+
C.sort_indices()
20+
21+
import pdb
22+
23+
pdb.set_trace()

tests/utils/python_test2.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import scipy.sparse as sp
2+
import numpy as np
3+
4+
np.random.seed(42)
5+
m = 15
6+
n = 10
7+
density = 0.1
8+
# A = sp.random(m, n, density=density, format="csc", dtype=float)
9+
# A.data = np.round(A.data, 2) + 0.1
10+
# d = np.round(np.random.randn(m), 2) + 0.1
11+
12+
13+
# Ap = [0 2 4 6 6 9 12 12 14 14 15]
14+
# Ai = [9 12 3 4 1 6 4 8 13 1 3 7 5 13 6]
15+
# Ax =
16+
# [0.99 0.9 0.51 0.64 0.39 0.29 0.26 0.91 0.35 0.18 0.33 0.73 0.97 0.86 1.03]
17+
# d =
18+
# [-0.6 - 0.23 - 0.29 - 1.36 0.4 0.36 0.11 - 0.13 - 1.32 - 0.32 - 0.24 - 0.7 -
19+
# 0.06 0.5 1.99] Cp = [0 1 4 7 7 10 13 13 15 15 17] Ci =
20+
# [0 1 4 5 2 5 9 1 4 7 1 2 5 4 7 2 9] Cx =
21+
# [-0.362232 - 0.189896 0.06656 - 0.228888 - 0.025732 -
22+
# 0.016146 0.032857 0.06656 - 1.004802 0.1505 - 0.228888 -
23+
# 0.016146 - 0.224833 0.1505 0.708524 0.032857 0.116699] *
24+
# */
25+
26+
27+
Ap = np.array([0, 2, 4, 6, 6, 9, 12, 12, 14, 14, 15])
28+
Ai = np.array([9, 12, 3, 4, 1, 6, 4, 8, 13, 1, 3, 7, 5, 13, 6])
29+
Ax = np.array(
30+
[
31+
0.99,
32+
0.9,
33+
0.51,
34+
0.64,
35+
0.39,
36+
0.29,
37+
0.26,
38+
0.91,
39+
0.35,
40+
0.18,
41+
0.33,
42+
0.73,
43+
0.97,
44+
0.86,
45+
1.03,
46+
]
47+
)
48+
49+
d = np.array(
50+
[
51+
-0.6,
52+
-0.23,
53+
-0.29,
54+
-1.36,
55+
0.4,
56+
0.36,
57+
0.11,
58+
-0.13,
59+
-1.32,
60+
-0.32,
61+
-0.24,
62+
-0.7,
63+
-0.06,
64+
0.5,
65+
1.99,
66+
]
67+
)
68+
A = sp.csc_matrix((Ax, Ai, Ap), shape=(m, n))
69+
70+
C = A.T @ sp.diags(d) @ A
71+
C.sort_indices()
72+
73+
Ap = A.indptr
74+
Ai = A.indices
75+
Ax = A.data
76+
77+
Cp = C.indptr
78+
Ci = C.indices
79+
Cx = C.data
80+
81+
# set precision for printing
82+
np.set_printoptions(precision=10, suppress=True)
83+
84+
print("Ap =", Ap)
85+
print("Ai =", Ai)
86+
print("Ax =", Ax)
87+
print("d =", d)
88+
print("Cp =", Cp)
89+
print("Ci =", Ci)
90+
print("Cx =", Cx)
91+
92+
import pdb
93+
94+
pdb.set_trace()

tests/utils/test_csc_matrix.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@ const char *test_ATA_alloc_random()
8484
CSC_Matrix *A = new_csc_matrix(10, 15, 15);
8585
int Ap[16] = {0, 1, 1, 1, 1, 4, 5, 6, 7, 8, 9, 11, 11, 11, 13, 15};
8686
int Ai[15] = {5, 0, 6, 9, 0, 5, 1, 3, 6, 0, 6, 3, 6, 6, 8};
87+
double Ax[15] = {7, 4, 8, 5, 7, 3, 7, 8, 5, 4, 8, 8, 3, 6, 5};
8788
memcpy(A->p, Ap, 16 * sizeof(int));
8889
memcpy(A->i, Ai, 15 * sizeof(int));
90+
memcpy(A->x, Ax, 15 * sizeof(double));
8991
CSR_Matrix *C = ATA_alloc(A);
9092

9193
int expected_p[16] = {0, 2, 2, 2, 2, 8, 11, 13, 14, 16, 21, 27, 27, 27, 33, 38};
@@ -97,6 +99,54 @@ const char *test_ATA_alloc_random()
9799
mu_assert("i incorrect", cmp_int_array(C->i, expected_i, C->nnz));
98100
mu_assert("nnz incorrect", C->nnz == 38);
99101

102+
double d[10] = {2, 8, 6, 2, 5, 1, 6, 9, 1, 3};
103+
104+
ATDA_values(A, d, C);
105+
106+
double Cx_correct[38] = {
107+
49., 21., 491., 56., 240., 416., 144., 288., 56., 98., 56., 21., 9.,
108+
392., 128., 128., 240., 150., 240., 90., 180., 416., 56., 240., 416., 144.,
109+
288., 144., 128., 90., 144., 182., 108., 288., 180., 288., 108., 241.};
110+
mu_assert("x incorrect", cmp_double_array(C->x, Cx_correct, C->nnz));
111+
112+
free_csr_matrix(C);
113+
free_csc_matrix(A);
114+
115+
return 0;
116+
}
117+
118+
const char *test_ATA_alloc_random2()
119+
{
120+
/* Create A in CSC format */
121+
int m = 15;
122+
int n = 10;
123+
CSC_Matrix *A = new_csc_matrix(m, n, 15);
124+
int Ap[11] = {0, 2, 4, 6, 6, 9, 12, 12, 14, 14, 15};
125+
int Ai[15] = {9, 12, 3, 4, 1, 6, 4, 8, 13, 1, 3, 7, 5, 13, 6};
126+
double Ax[15] = {0.99, 0.9, 0.51, 0.64, 0.39, 0.29, 0.26, 0.91,
127+
0.35, 0.18, 0.33, 0.73, 0.97, 0.86, 1.03};
128+
memcpy(A->p, Ap, 11 * sizeof(int));
129+
memcpy(A->i, Ai, 15 * sizeof(int));
130+
memcpy(A->x, Ax, 15 * sizeof(double));
131+
CSR_Matrix *C = ATA_alloc(A);
132+
133+
int expected_p[11] = {0, 1, 4, 7, 7, 10, 13, 13, 15, 15, 17};
134+
int expected_i[17] = {0, 1, 4, 5, 2, 5, 9, 1, 4, 7, 1, 2, 5, 4, 7, 2, 9};
135+
136+
mu_assert("p incorrect", cmp_int_array(C->p, expected_p, 11));
137+
mu_assert("i incorrect", cmp_int_array(C->i, expected_i, C->nnz));
138+
mu_assert("nnz incorrect", C->nnz == 17);
139+
double d[15] = {-0.6, -0.23, -0.29, -1.36, 0.4, 0.36, 0.11, -0.13,
140+
-1.32, -0.32, -0.24, -0.7, -0.06, 0.5, 1.99};
141+
142+
ATDA_values(A, d, C);
143+
144+
double Cx_correct[17] = {-0.362232, -0.189896, 0.06656, -0.228888, -0.025732,
145+
-0.016146, 0.032857, 0.06656, -1.004802, 0.1505,
146+
-0.228888, -0.016146, -0.224833, 0.1505, 0.708524,
147+
0.032857, 0.116699};
148+
mu_assert("x incorrect", cmp_double_array(C->x, Cx_correct, C->nnz));
149+
100150
free_csr_matrix(C);
101151
free_csc_matrix(A);
102152

0 commit comments

Comments
 (0)