Skip to content

Commit 770244d

Browse files
committed
add Halko sketcher
1 parent fcacd7d commit 770244d

5 files changed

Lines changed: 162 additions & 26 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies = [
2828
"numpy==2.0.0",
2929
"scipy==1.13.1",
3030
"seaborn==0.13.2",
31+
"scikit-learn==1.5.0",
3132
]
3233
description = "Quantum-inspired algorithms"
3334
keywords = []

src/quantum_inspired_algorithms/estimator.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from quantum_inspired_algorithms.quantum_inspired import sample_from_b
1212
from quantum_inspired_algorithms.quantum_inspired import sample_from_x
1313
from quantum_inspired_algorithms.sketching import FKV
14+
from quantum_inspired_algorithms.sketching import Halko
1415

1516

1617
class EstimatorError(Exception):
@@ -32,6 +33,7 @@ def __init__(
3233
n_samples: int,
3334
random_state: np.random.RandomState,
3435
sigma_threshold: float = 1e-15,
36+
sketcher_name: str = "fkv",
3537
func: Optional[Callable[[float], float]] = None,
3638
) -> None:
3739
"""Init QILinearEstimator.
@@ -46,6 +48,7 @@ def __init__(
4648
random_state: random state.
4749
sigma_threshold: the argument `rank` is recomputed in case it is higher than
4850
the number of singular values below this threshold.
51+
sketcher_name: name of sketching method: "fkv" or "halko".
4952
func: function to transform singular values when estimating lambda coefficients.
5053
This can be used for Tikhonov regularization purposes.
5154
"""
@@ -55,6 +58,7 @@ def __init__(
5558
self.n_samples = n_samples
5659
self.random_state = random_state
5760
self.sigma_threshold = sigma_threshold
61+
self.sketcher_name = sketcher_name
5862
self.func = func
5963

6064
def fit(
@@ -84,24 +88,37 @@ def fit(
8488
# 1. Generate length-square probability distributions to sample from matrix `A`
8589
logging.info("1. Generate length-square probability distributions to sample from matrix `A`")
8690
(
87-
self.A_ls_prob_rows_,
88-
self.A_ls_prob_columns_2d_,
91+
A_ls_prob_rows,
92+
A_ls_prob_columns_2d,
93+
A_ls_prob_columns,
8994
_,
90-
_,
91-
self.A_frobenius_,
95+
A_frobenius,
9296
) = compute_ls_probs(A)
9397

9498
# 2. Build matrix `C`
9599
logging.info("2. Build matrix `C`")
96-
self.sketcher_ = FKV(
97-
A,
98-
self.r,
99-
self.c,
100-
self.A_ls_prob_rows_,
101-
self.A_ls_prob_columns_2d_,
102-
self.A_frobenius_,
103-
self.random_state,
104-
)
100+
if self.sketcher_name == "fkv":
101+
self.sketcher_ = FKV(
102+
A,
103+
self.r,
104+
self.c,
105+
A_ls_prob_rows,
106+
A_ls_prob_columns_2d,
107+
A_frobenius,
108+
self.random_state,
109+
)
110+
elif self.sketcher_name == "halko":
111+
self.sketcher_ = Halko(
112+
A,
113+
self.r,
114+
self.c,
115+
A_ls_prob_rows,
116+
A_ls_prob_columns,
117+
self.random_state,
118+
)
119+
else:
120+
raise ValueError('`sketcher_name` should be either "fkv" or "halko"')
121+
105122
C = self.sketcher_.right_project(self.sketcher_.left_project(A))
106123

107124
# 3. Compute the SVD of `C`
@@ -135,9 +152,9 @@ def func_(arg: float) -> float:
135152
self.w_left_,
136153
self.sigma_,
137154
self.sketcher_,
138-
self.A_ls_prob_rows_,
139-
self.A_ls_prob_columns_2d_,
140-
self.A_frobenius_,
155+
A_ls_prob_rows,
156+
A_ls_prob_columns_2d,
157+
A_frobenius,
141158
self.random_state,
142159
func,
143160
)
@@ -147,9 +164,6 @@ def func_(arg: float) -> float:
147164
def _check_is_fitted(self):
148165
"""Check if the `fit` method has been called."""
149166
for attribute_name in [
150-
"A_ls_prob_rows_",
151-
"A_ls_prob_columns_2d_",
152-
"A_frobenius_",
153167
"sketcher_",
154168
"w_left_",
155169
"sigma_",

src/quantum_inspired_algorithms/sketching.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from numpy import linalg as la
55
from numpy.typing import NDArray
6+
from sklearn.utils.extmath import randomized_range_finder
67

78

89
class Sketcher(ABC):
@@ -46,7 +47,7 @@ def __init__(
4647
frobenius: NDArray[np.float64],
4748
rng: np.random.RandomState,
4849
) -> None:
49-
"""Init QILinearEstimator.
50+
"""Init FKV.
5051
5152
Note: LS stands for length-square.
5253
@@ -130,3 +131,89 @@ def sample_row_idx(self, rng: np.random.RandomState) -> int:
130131
sample_i = rng.choice(self._n_rows, 1, p=self._C_ls_prob_rows[:, sample_j])[0]
131132

132133
return sample_i
134+
135+
136+
class Halko(Sketcher):
    """Halko sketching.

    The left and right projection matrices are built with scikit-learn's
    `randomized_range_finder`. Row and column indices are sampled directly from
    the length-square (LS) probability distributions given at construction.
    """

    def __init__(
        self,
        A: NDArray[np.float64],
        r: int,
        c: int,
        ls_prob_rows: NDArray[np.float64],
        ls_prob_columns: NDArray[np.float64],
        rng: np.random.RandomState,
    ) -> None:
        """Init Halko.

        Note: LS stands for length-square.

        Args:
            A: coefficient matrix.
            r: target number of rows for left projection matrix. 10 oversamples
                are added, so the left projector is built with `r + 10` rows.
            c: target number of columns for right projection matrix. 10 oversamples
                are added, so the right projector is built with `c + 10` columns.
            ls_prob_rows: row LS probability distribution of `A`.
            ls_prob_columns: column LS probability distribution of `A`.
            rng: random state.
        """
        # Both projectors are built from the same `rng`, left first; keep this
        # order so that seeded runs stay reproducible.
        self._Q_left = Halko._get_low_dimensional_projector(
            A, axis=0, n_components=r, random_state=rng
        )
        self._Q_right = Halko._get_low_dimensional_projector(
            A, axis=1, n_components=c, random_state=rng
        )
        self._ls_prob_rows = ls_prob_rows
        self._ls_prob_columns = ls_prob_columns

    @classmethod
    def _get_low_dimensional_projector(
        cls,
        M: NDArray[np.float64],
        axis: int,
        n_components: int,
        random_state: np.random.RandomState,
    ) -> NDArray[np.float64]:
        """Find random matrix to reduce dimensionality of axis of `M`.

        Args:
            M: matrix whose axis is to be reduced.
            axis: 0 for a projector applied from the left (`Q @ M`),
                1 for a projector applied from the right (`M @ Q`).
            n_components: target dimension; 10 oversamples are added on top.
            random_state: random state forwarded to `randomized_range_finder`.

        Returns:
            Output of `randomized_range_finder`, transposed when `axis == 0` so
            that it can be multiplied from the left.
        """
        # NOTE: the return annotation is `np.float64` on purpose. `np.float128`
        # is a platform-dependent alias, so referencing it here can fail at
        # definition time on platforms that do not provide it.
        n_oversamples = 10
        n_random = n_components + n_oversamples
        # Use more power iterations when the target dimension is small relative
        # to the smallest dimension of `M`.
        n_iter = 7 if n_components < 0.1 * min(M.shape) else 4
        if axis == 1:
            # The range finder works on the first axis, so transpose to reduce columns.
            M = M.T
        Q = np.asarray(
            randomized_range_finder(
                M,
                size=n_random,
                n_iter=n_iter,
                power_iteration_normalizer="auto",
                random_state=random_state,
            )
        )
        if axis == 0:
            Q = Q.T

        return Q

    def left_project(self, M: NDArray[np.float64]) -> NDArray[np.float64]:
        """Define left projector: multiply `M` from the left by the left matrix."""
        return self._Q_left @ M

    def right_project(self, M: NDArray[np.float64]) -> NDArray[np.float64]:
        """Define right projector: multiply `M` from the right by the right matrix."""
        return M @ self._Q_right

    def set_up_column_sampler(self, A: NDArray[np.float64]) -> None:
        """No setup required."""

    def set_up_row_sampler(self, A: NDArray[np.float64]) -> None:
        """No setup required."""

    def sample_column_idx(self, rng: np.random.RandomState) -> int:
        """Sample a column index from the stored column LS distribution."""
        n_cols = self._ls_prob_columns.size
        sample_j = rng.choice(n_cols, 1, p=self._ls_prob_columns)[0]

        return sample_j

    def sample_row_idx(self, rng: np.random.RandomState) -> int:
        """Sample a row index from the stored row LS distribution."""
        n_rows = self._ls_prob_rows.size
        sample_i = rng.choice(n_rows, 1, p=self._ls_prob_rows)[0]

        return sample_i

tests/test_estimator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def func(arg: float) -> float:
141141
)
142142

143143

144-
def test_finding_largest_entries_b_underdetermined():
144+
@pytest.mark.parametrize("sketcher_name,n_matches_expected", [("fkv", 48), ("halko", 49)])
145+
def test_finding_largest_entries_b_underdetermined(sketcher_name: str, n_matches_expected: int):
145146
"""Test quantum-inspired least squares."""
146147
# Load data
147148
A, b, _ = _load_data(underdetermined=True)
@@ -154,7 +155,7 @@ def test_finding_largest_entries_b_underdetermined():
154155
n_samples = 100
155156
n_entries_b = 1000
156157
rng = np.random.RandomState(111)
157-
qi = QILinearEstimator(r, c, rank, n_samples, rng)
158+
qi = QILinearEstimator(r, c, rank, n_samples, rng, sketcher_name=sketcher_name)
158159
qi = qi.fit(A, b)
159160
sampled_indices, sampled_b = qi.predict_b(A, n_entries_b)
160161

@@ -169,14 +170,14 @@ def test_finding_largest_entries_b_underdetermined():
169170
n_matches = plot_solution(
170171
b,
171172
b_idx,
172-
"test_finding_largest_entries_b_underdetermined",
173+
f"test_finding_largest_entries_b_underdetermined_{sketcher_name}",
173174
expected_solution=b[unique_sampled_indices],
174175
solution=unique_sampled_b,
175176
expected_counts=n_entries_b * np.abs(b / norm(b))[unique_sampled_indices] ** 2,
176177
counts=np.squeeze(np.round(df_counts.values)),
177178
)
178179

179-
assert n_matches == 48
180+
assert n_matches == n_matches_expected
180181

181182

182183
def test_finding_largest_entries_b():

tests/test_sketching.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from numpy.typing import NDArray
44
from quantum_inspired_algorithms.quantum_inspired import compute_ls_probs
55
from quantum_inspired_algorithms.sketching import FKV
6+
from quantum_inspired_algorithms.sketching import Halko
67

78

89
def _get_FKV_sketcher(A: NDArray[np.float64], r: int, c: int) -> FKV:
9-
"""Load dummy FKV sketcher."""
10+
"""Load dummy sketcher."""
1011
A_ls_prob_rows, A_ls_prob_columns_2d, _, _, A_frobenius = compute_ls_probs(A)
1112
r = 30
1213
c = 40
@@ -22,8 +23,24 @@ def _get_FKV_sketcher(A: NDArray[np.float64], r: int, c: int) -> FKV:
2223
)
2324

2425

26+
def _get_Halko_sketcher(A: NDArray[np.float64], r: int, c: int) -> Halko:
    """Load dummy Halko sketcher.

    Args:
        A: matrix to sketch.
        r: number of rows requested for the left projection.
        c: number of columns requested for the right projection.

    Returns:
        `Halko` sketcher built from the LS probabilities of `A` with a fixed seed.
    """
    # `r` and `c` are passed through as given; they used to be overwritten
    # with hard-coded values, which silently ignored the caller's arguments.
    A_ls_prob_rows, _, A_ls_prob_columns, _, _ = compute_ls_probs(A)
    random_state = np.random.RandomState(7)
    return Halko(
        A,
        r,
        c,
        A_ls_prob_rows,
        A_ls_prob_columns,
        random_state,
    )
40+
41+
2542
def test_FKV_dimensions():
26-
"""Test dimensions of FKV-based sketches."""
43+
"""Test dimensions of sketches."""
2744
A = np.arange(100 * 100, dtype=np.float64).reshape((100, 100))
2845
r = 30
2946
c = 40
@@ -38,6 +55,22 @@ def test_FKV_dimensions():
3855
assert left_right_sketch_matrix.shape == (30, 40)
3956

4057

58+
def test_Halko_dimensions():
    """Check the shapes of the left, right and combined Halko sketches.

    The expected shapes are the requested sizes (30 and 40) plus 10.
    """
    r = 30
    c = 40
    A = np.arange(100 * 100, dtype=np.float64).reshape((100, 100))
    halko = _get_Halko_sketcher(A, r, c)

    # Same projection order as before: left only, right only, then both.
    projected_left = halko.left_project(A)
    projected_right = halko.right_project(A)
    projected_both = halko.right_project(halko.left_project(A))

    observed = (projected_left.shape, projected_right.shape, projected_both.shape)
    expected = ((30 + 10, 100), (100, 40 + 10), (30 + 10, 40 + 10))
    for observed_shape, expected_shape in zip(observed, expected):
        assert observed_shape == expected_shape
72+
73+
4174
def test_FKV_samplers():
4275
"""Test FKV-based samplers."""
4376
A = np.arange(100 * 100, dtype=np.float64).reshape((100, 100))

0 commit comments

Comments
 (0)