Skip to content

Commit 92eda5c

Browse files
committed
fix sampling from b
1 parent bdaa283 commit 92eda5c

3 files changed

Lines changed: 12 additions & 15 deletions

File tree

src/quantum_inspired_algorithms/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def fit(
112112
A,
113113
self.r,
114114
self.c,
115-
A_ls_prob_rows,
116115
A_ls_prob_columns,
117116
self.random_state,
118117
)

src/quantum_inspired_algorithms/sketching.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ def __init__(
141141
A: NDArray[np.float64],
142142
r: int,
143143
c: int,
144-
ls_prob_rows: NDArray[np.float64],
145144
ls_prob_columns: NDArray[np.float64],
146145
rng: np.random.RandomState,
147146
) -> None:
@@ -153,13 +152,11 @@ def __init__(
153152
A: coefficient matrix.
154153
r: number of rows for left projection matrix.
155154
c: number of columns for right projection matrix.
156-
ls_prob_rows: row LS probability distribution of `A`.
157155
ls_prob_columns: column LS probability distribution of `A`.
158156
rng: random state.
159157
"""
160158
self._Q_left = Halko._get_low_dimensional_projector(A, axis=0, n_components=r, random_state=rng)
161159
self._Q_right = Halko._get_low_dimensional_projector(self._Q_left @ A, axis=1, n_components=c, random_state=rng)
162-
self._ls_prob_rows = ls_prob_rows
163160
self._ls_prob_columns = ls_prob_columns
164161

165162
@classmethod
@@ -169,7 +166,7 @@ def _get_low_dimensional_projector(
169166
axis: int,
170167
n_components: int,
171168
random_state: np.random.RandomState,
172-
) -> NDArray[np.float128]:
169+
) -> NDArray[np.float64]:
173170
"""Find random matrix to reduce dimensionality of axis of `M`."""
174171
n_oversamples = 10
175172
n_random = n_components + n_oversamples
@@ -202,7 +199,12 @@ def set_up_column_sampler(self, A: NDArray[np.float64]) -> None:
202199
"""No setup required."""
203200

204201
def set_up_row_sampler(self, A: NDArray[np.float64]) -> None:
205-
"""No setup required."""
202+
"""Build LS distribution to sample rows from matrix `C`."""
203+
A_row_norms = la.norm(A, axis=1)
204+
A_row_norms_squared = A_row_norms**2
205+
A_frobenius = np.sqrt(np.sum(A_row_norms_squared))
206+
self._ls_prob_rows = A_row_norms_squared / A_frobenius**2
207+
self._n_rows = A.shape[0]
206208

207209
def sample_column_idx(self, rng: np.random.RandomState) -> int:
208210
"""Sample a column index."""
@@ -213,7 +215,6 @@ def sample_column_idx(self, rng: np.random.RandomState) -> int:
213215

214216
def sample_row_idx(self, rng: np.random.RandomState) -> int:
215217
"""Sample a row index."""
216-
n_rows = self._ls_prob_rows.size
217-
sample_i = rng.choice(n_rows, 1, p=self._ls_prob_rows)[0]
218+
sample_i = rng.choice(self._n_rows, 1, p=self._ls_prob_rows)[0]
218219

219220
return sample_i

tests/test_sketching.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,12 @@ def _get_FKV_sketcher(A: NDArray[np.float64], r: int, c: int) -> FKV:
2525

2626
def _get_Halko_sketcher(A: NDArray[np.float64], r: int, c: int) -> Halko:
2727
"""Load dummy sketcher."""
28-
A_ls_prob_rows, _, A_ls_prob_columns, _, _ = compute_ls_probs(A)
29-
r = 30
30-
c = 40
28+
_, _, A_ls_prob_columns, _, _ = compute_ls_probs(A)
3129
random_state = np.random.RandomState(7)
3230
return Halko(
3331
A,
3432
r,
3533
c,
36-
A_ls_prob_rows,
3734
A_ls_prob_columns,
3835
random_state,
3936
)
@@ -66,9 +63,9 @@ def test_Halko_dimensions():
6663
right_sketch_matrix = sketcher.right_project(A)
6764
left_right_sketch_matrix = sketcher.right_project(sketcher.left_project(A))
6865

69-
assert left_sketch_matrix.shape == (30 + 10, 100)
70-
assert right_sketch_matrix.shape == (100, 40 + 10)
71-
assert left_right_sketch_matrix.shape == (30 + 10, 40 + 10)
66+
assert left_sketch_matrix.shape == (r + 10, 100)
67+
assert right_sketch_matrix.shape == (100, r + 10)
68+
assert left_right_sketch_matrix.shape == (r + 10, r + 10)
7269

7370

7471
def test_FKV_samplers():

0 commit comments

Comments
 (0)