@@ -122,10 +122,10 @@ def split_clusters(
122122 if recursive :
123123 recursion_level = np .max (split_count [peak_indices ])
124124 if recursive_depth is not None :
125- # stop reccursivity when recursive_depth is reach
125+ # stop recursivity when recursive_depth is reach
126126 extra_ball = recursion_level < recursive_depth
127127 else :
128- # reccurssive always
128+ # recursive always
129129 extra_ball = True
130130
131131 if extra_ball :
@@ -211,6 +211,7 @@ def split(
211211 waveforms_sparse_mask = None ,
212212 min_size_split = 25 ,
213213 n_pca_features = 2 ,
214+ projection_mode = "tsvd" ,
214215 minimum_overlap_ratio = 0.25 ,
215216 ):
216217 local_labels = np .zeros (peak_indices .size , dtype = np .int64 )
@@ -247,16 +248,36 @@ def split(
247248
248249 is_split = False
249250
250- if flatten_features .shape [1 ] > n_pca_features :
251- from sklearn .decomposition import PCA
251+ if isinstance (n_pca_features , float ):
252+ assert 0 < n_pca_features < 1 , "n_components should be in ]0, 1["
253+ nb_dimensions = min (flatten_features .shape [0 ], flatten_features .shape [1 ])
254+ if projection_mode == "pca" :
255+ from sklearn .decomposition import PCA
252256
253- # from sklearn.decomposition import TruncatedSVD
254- # tsvd = TruncatedSVD(n_pca_features)
255- tsvd = PCA (n_pca_features , whiten = True )
257+ tsvd = PCA (nb_dimensions , whiten = True )
258+ elif projection_mode == "tsvd" :
259+ from sklearn .decomposition import TruncatedSVD
260+
261+ tsvd = TruncatedSVD (nb_dimensions )
256262 final_features = tsvd .fit_transform (flatten_features )
257- del tsvd
258- else :
259- final_features = flatten_features
263+ n_explain = np .sum (np .cumsum (tsvd .explained_variance_ratio_ ) <= n_pca_features ) + 1
264+ final_features = final_features [:, :n_explain ]
265+ n_pca_features = final_features .shape [1 ]
266+ elif isinstance (n_pca_features , int ):
267+ if flatten_features .shape [1 ] > n_pca_features :
268+ if projection_mode == "pca" :
269+ from sklearn .decomposition import PCA
270+
271+ tsvd = PCA (n_pca_features , whiten = True )
272+ elif projection_mode == "tsvd" :
273+ from sklearn .decomposition import TruncatedSVD
274+
275+ tsvd = TruncatedSVD (n_pca_features )
276+
277+ final_features = tsvd .fit_transform (flatten_features )
278+ else :
279+ final_features = flatten_features
280+ tsvd = None
260281
261282 if clusterer == "hdbscan" :
262283 from hdbscan import HDBSCAN
@@ -270,7 +291,8 @@ def split(
270291 min_cluster_size = clusterer_kwargs ["min_cluster_size" ]
271292 dipscore , cutpoint = isocut5 (final_features [:, 0 ])
272293 possible_labels = np .zeros (final_features .shape [0 ])
273- if dipscore > 1.5 :
294+ min_dip = clusterer_kwargs .get ("min_dip" , 1.5 )
295+ if dipscore > min_dip :
274296 mask = final_features [:, 0 ] > cutpoint
275297 if np .sum (mask ) > min_cluster_size and np .sum (~ mask ):
276298 possible_labels [mask ] = 1
@@ -289,10 +311,13 @@ def split(
289311 colors = plt .colormaps ["tab10" ].resampled (len (labels_set ))
290312 colors = {k : colors (i ) for i , k in enumerate (labels_set )}
291313 colors [- 1 ] = "k"
292- fig , axs = plt .subplots (nrows = 2 )
314+ fig , axs = plt .subplots (nrows = 4 )
293315
294316 flatten_wfs = aligned_wfs .swapaxes (1 , 2 ).reshape (aligned_wfs .shape [0 ], - 1 )
295317
318+ if final_features .shape [1 ] == 1 :
319+ final_features = np .hstack ((final_features , np .zeros_like (final_features )))
320+
296321 sl = slice (None , None , 100 )
297322 for k in np .unique (possible_labels ):
298323 mask = possible_labels == k
@@ -302,10 +327,27 @@ def split(
302327 centroid = final_features [:, :2 ][mask ].mean (axis = 0 )
303328 ax .text (centroid [0 ], centroid [1 ], f"Label { k } " , fontsize = 10 , color = "k" )
304329 ax = axs [1 ]
305- ax .plot (flatten_wfs [mask ][ sl ] .T , color = colors [k ], alpha = 0.5 )
330+ ax .plot (flatten_wfs [mask ].T , color = colors [k ], alpha = 0.1 )
306331 if k > - 1 :
307332 ax .plot (np .median (flatten_wfs [mask ].T , axis = 1 ), color = colors [k ], lw = 2 )
308333 ax .set_xlabel ("PCA features" )
334+
335+ ax = axs [3 ]
336+ if n_pca_features == 1 :
337+ bins = np .linspace (final_features [:, 0 ].min (), final_features [:, 0 ].max (), 100 )
338+ ax .hist (final_features [mask , 0 ], bins , color = colors [k ], alpha = 0.1 )
339+ else :
340+ ax .plot (final_features [mask ].T , color = colors [k ], alpha = 0.1 )
341+ if k > - 1 and n_pca_features > 1 :
342+ ax .plot (np .median (final_features [mask ].T , axis = 1 ), color = colors [k ], lw = 2 )
343+ ax .set_xlabel ("Projected PCA features" )
344+
345+ if tsvd is not None :
346+ ax = axs [2 ]
347+ sorted_components = np .argsort (tsvd .explained_variance_ratio_ )[::- 1 ]
348+ ax .plot (tsvd .explained_variance_ratio_ [sorted_components ], c = "k" )
349+ del tsvd
350+
309351 ymin , ymax = ax .get_ylim ()
310352 ax .plot ([n_pca_features , n_pca_features ], [ymin , ymax ], "k--" )
311353
0 commit comments