@@ -141,7 +141,6 @@ def __init__(
141141 A : NDArray [np .float64 ],
142142 r : int ,
143143 c : int ,
144- ls_prob_rows : NDArray [np .float64 ],
145144 ls_prob_columns : NDArray [np .float64 ],
146145 rng : np .random .RandomState ,
147146 ) -> None :
@@ -153,13 +152,11 @@ def __init__(
153152 A: coefficient matrix.
154153 r: number of rows for left projection matrix.
155154 c: number of columns for right projection matrix.
156- ls_prob_rows: row LS probability distribution of `A`.
157155 ls_prob_columns: column LS probability distribution of `A`.
158156 rng: random state.
159157 """
160158 self ._Q_left = Halko ._get_low_dimensional_projector (A , axis = 0 , n_components = r , random_state = rng )
161159 self ._Q_right = Halko ._get_low_dimensional_projector (self ._Q_left @ A , axis = 1 , n_components = c , random_state = rng )
162- self ._ls_prob_rows = ls_prob_rows
163160 self ._ls_prob_columns = ls_prob_columns
164161
165162 @classmethod
@@ -169,7 +166,7 @@ def _get_low_dimensional_projector(
169166 axis : int ,
170167 n_components : int ,
171168 random_state : np .random .RandomState ,
172- ) -> NDArray [np .float128 ]:
169+ ) -> NDArray [np .float64 ]:
173170 """Find random matrix to reduce dimensionality of axis of `M`."""
174171 n_oversamples = 10
175172 n_random = n_components + n_oversamples
@@ -202,7 +199,12 @@ def set_up_column_sampler(self, A: NDArray[np.float64]) -> None:
202199 """No setup required."""
203200
204201 def set_up_row_sampler (self , A : NDArray [np .float64 ]) -> None :
205- """No setup required."""
202+ """Build LS distribution to sample rows from matrix `C`."""
203+ A_row_norms = la .norm (A , axis = 1 )
204+ A_row_norms_squared = A_row_norms ** 2
205+ A_frobenius = np .sqrt (np .sum (A_row_norms_squared ))
206+ self ._ls_prob_rows = A_row_norms_squared / A_frobenius ** 2
207+ self ._n_rows = A .shape [0 ]
206208
207209 def sample_column_idx (self , rng : np .random .RandomState ) -> int :
208210 """Sample a column index."""
@@ -213,7 +215,6 @@ def sample_column_idx(self, rng: np.random.RandomState) -> int:
213215
214216 def sample_row_idx (self , rng : np .random .RandomState ) -> int :
215217 """Sample a row index."""
216- n_rows = self ._ls_prob_rows .size
217- sample_i = rng .choice (n_rows , 1 , p = self ._ls_prob_rows )[0 ]
218+ sample_i = rng .choice (self ._n_rows , 1 , p = self ._ls_prob_rows )[0 ]
218219
219220 return sample_i
0 commit comments