1919from .clustering_tools import remove_duplicates_via_matching
2020from spikeinterface .core .recording_tools import get_noise_levels , get_channel_distances
2121from spikeinterface .sortingcomponents .peak_selection import select_peaks
22- from spikeinterface .sortingcomponents .waveforms .temporal_pca import TemporalPCAProjection
23- from spikeinterface .sortingcomponents .waveforms .hanning_filter import HanningFilter
2422from spikeinterface .core .template import Templates
2523from spikeinterface .core .sparsity import compute_sparsity
2624from spikeinterface .sortingcomponents .tools import remove_empty_templates
27- import pickle , json
28- from spikeinterface .core .node_pipeline import (
29- run_node_pipeline ,
30- ExtractSparseWaveforms ,
31- PeakRetriever ,
32- )
25+ from spikeinterface .sortingcomponents .clustering .peak_svd import extract_peaks_svd
3326
3427
3528from spikeinterface .sortingcomponents .tools import extract_waveform_at_max_channel
@@ -48,20 +41,24 @@ class CircusClustering:
4841 "allow_single_cluster" : True ,
4942 },
5043 "cleaning_kwargs" : {},
44+ "remove_mixtures" : False ,
5145 "waveforms" : {"ms_before" : 2 , "ms_after" : 2 },
5246 "sparsity" : {"method" : "snr" , "amplitude_mode" : "peak_to_peak" , "threshold" : 0.25 },
5347 "recursive_kwargs" : {
5448 "recursive" : True ,
5549 "recursive_depth" : 3 ,
5650 "returns_split_count" : True ,
5751 },
52+ "split_kwargs" : {"projection_mode" : "tsvd" , "n_pca_features" : 0.9 },
5853 "radius_um" : 100 ,
54+ "neighbors_radius_um" : 50 ,
5955 "n_svd" : 5 ,
6056 "few_waveforms" : None ,
6157 "ms_before" : 0.5 ,
6258 "ms_after" : 0.5 ,
6359 "noise_threshold" : 4 ,
6460 "rank" : 5 ,
61+ "templates_from_svd" : False ,
6562 "noise_levels" : None ,
6663 "tmp_folder" : None ,
6764 "verbose" : True ,
@@ -78,6 +75,8 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
7875 fs = recording .get_sampling_frequency ()
7976 ms_before = params ["ms_before" ]
8077 ms_after = params ["ms_after" ]
78+ radius_um = params ["radius_um" ]
79+ neighbors_radius_um = params ["neighbors_radius_um" ]
8180 nbefore = int (ms_before * fs / 1000.0 )
8281 nafter = int (ms_after * fs / 1000.0 )
8382 if params ["tmp_folder" ] is None :
@@ -108,210 +107,139 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
108107 valid = np .argmax (np .abs (wfs ), axis = 1 ) == nbefore
109108 wfs = wfs [valid ]
110109
111- # Perform Hanning filtering
112- hanning_before = np .hanning (2 * nbefore )
113- hanning_after = np .hanning (2 * nafter )
114- hanning = np .concatenate ((hanning_before [:nbefore ], hanning_after [nafter :]))
115- wfs *= hanning
116-
117110 from sklearn .decomposition import TruncatedSVD
118111
119- tsvd = TruncatedSVD (params ["n_svd" ])
120- tsvd .fit (wfs )
121-
122- model_folder = tmp_folder / "tsvd_model"
123-
124- model_folder .mkdir (exist_ok = True )
125- with open (model_folder / "pca_model.pkl" , "wb" ) as f :
126- pickle .dump (tsvd , f )
127-
128- model_params = {
129- "ms_before" : ms_before ,
130- "ms_after" : ms_after ,
131- "sampling_frequency" : float (fs ),
132- }
133-
134- with open (model_folder / "params.json" , "w" ) as f :
135- json .dump (model_params , f )
112+ svd_model = TruncatedSVD (params ["n_svd" ])
113+ svd_model .fit (wfs )
114+ features_folder = tmp_folder / "tsvd_features"
115+ features_folder .mkdir (exist_ok = True )
136116
137- # features
138- node0 = PeakRetriever (recording , peaks )
139-
140- radius_um = params ["radius_um" ]
141- node1 = ExtractSparseWaveforms (
117+ peaks_svd , sparse_mask , svd_model = extract_peaks_svd (
142118 recording ,
143- parents = [node0 ],
144- return_output = False ,
119+ peaks ,
145120 ms_before = ms_before ,
146121 ms_after = ms_after ,
122+ svd_model = svd_model ,
147123 radius_um = radius_um ,
124+ folder = features_folder ,
125+ ** job_kwargs ,
148126 )
149127
150- node2 = HanningFilter (recording , parents = [ node0 , node1 ], return_output = False )
128+ neighbours_mask = get_channel_distances (recording ) <= neighbors_radius_um
151129
152- node3 = TemporalPCAProjection (
153- recording , parents = [ node0 , node2 ], return_output = True , model_folder_path = model_folder
154- )
130+ if params [ "debug" ]:
131+ np . save ( features_folder / "sparse_mask.npy" , sparse_mask )
132+ np . save ( features_folder / "peaks.npy" , peaks )
155133
156- pipeline_nodes = [node0 , node1 , node2 , node3 ]
134+ original_labels = peaks ["channel_index" ]
135+ from spikeinterface .sortingcomponents .clustering .split import split_clusters
157136
158- if len (params ["recursive_kwargs" ]) == 0 :
159- from sklearn .decomposition import PCA
137+ split_kwargs = params ["split_kwargs" ].copy ()
138+ split_kwargs ["neighbours_mask" ] = neighbours_mask
139+ split_kwargs ["waveforms_sparse_mask" ] = sparse_mask
140+ split_kwargs ["min_size_split" ] = 2 * params ["hdbscan_kwargs" ].get ("min_cluster_size" , 50 )
141+ split_kwargs ["clusterer_kwargs" ] = params ["hdbscan_kwargs" ]
160142
161- all_pc_data = run_node_pipeline (
162- recording ,
163- pipeline_nodes ,
164- job_kwargs ,
165- job_name = "extracting features" ,
166- )
167-
168- peak_labels = - 1 * np .ones (len (peaks ), dtype = int )
169- nb_clusters = 0
170- for c in np .unique (peaks ["channel_index" ]):
171- mask = peaks ["channel_index" ] == c
172- sub_data = all_pc_data [mask ]
173- sub_data = sub_data .reshape (len (sub_data ), - 1 )
174-
175- if all_pc_data .shape [1 ] > params ["n_svd" ]:
176- tsvd = PCA (params ["n_svd" ], whiten = True )
177- else :
178- tsvd = PCA (all_pc_data .shape [1 ], whiten = True )
179-
180- hdbscan_data = tsvd .fit_transform (sub_data )
181- try :
182- clustering = hdbscan .hdbscan (hdbscan_data , ** d ["hdbscan_kwargs" ])
183- local_labels = clustering [0 ]
184- except Exception :
185- local_labels = np .zeros (len (hdbscan_data ))
186- valid_clusters = local_labels > - 1
187- if np .sum (valid_clusters ) > 0 :
188- local_labels [valid_clusters ] += nb_clusters
189- peak_labels [mask ] = local_labels
190- nb_clusters += len (np .unique (local_labels [valid_clusters ]))
143+ if params ["debug" ]:
144+ debug_folder = tmp_folder / "split"
191145 else :
146+ debug_folder = None
192147
193- features_folder = tmp_folder / "tsvd_features"
194- features_folder .mkdir (exist_ok = True )
195-
196- _ = run_node_pipeline (
197- recording ,
198- pipeline_nodes ,
199- job_kwargs ,
200- job_name = "extracting features" ,
201- gather_mode = "npy" ,
202- gather_kwargs = dict (exist_ok = True ),
203- folder = features_folder ,
204- names = ["sparse_tsvd" ],
205- )
206-
207- sparse_mask = node1 .neighbours_mask
208- neighbours_mask = get_channel_distances (recording ) <= radius_um
209-
210- # np.save(features_folder / "sparse_mask.npy", sparse_mask)
211- np .save (features_folder / "peaks.npy" , peaks )
212-
213- original_labels = peaks ["channel_index" ]
214- from spikeinterface .sortingcomponents .clustering .split import split_clusters
148+ peak_labels , _ = split_clusters (
149+ original_labels ,
150+ recording ,
151+ {"peaks" : peaks , "sparse_tsvd" : peaks_svd },
152+ method = "local_feature_clustering" ,
153+ method_kwargs = split_kwargs ,
154+ debug_folder = debug_folder ,
155+ ** params ["recursive_kwargs" ],
156+ ** job_kwargs ,
157+ )
215158
216- min_size = 2 * params ["hdbscan_kwargs" ].get ("min_cluster_size" , 20 )
159+ if params ["noise_levels" ] is None :
160+ params ["noise_levels" ] = get_noise_levels (recording , return_scaled = False , ** job_kwargs )
217161
218- if params ["debug" ]:
219- debug_folder = tmp_folder / "split"
220- else :
221- debug_folder = None
162+ if not params ["templates_from_svd" ]:
163+ from spikeinterface .sortingcomponents .clustering .tools import get_templates_from_peaks_and_recording
222164
223- peak_labels , _ = split_clusters (
224- original_labels ,
165+ templates = get_templates_from_peaks_and_recording (
225166 recording ,
226- features_folder ,
227- method = "local_feature_clustering" ,
228- method_kwargs = dict (
229- clusterer = "hdbscan" ,
230- feature_name = "sparse_tsvd" ,
231- neighbours_mask = neighbours_mask ,
232- waveforms_sparse_mask = sparse_mask ,
233- min_size_split = min_size ,
234- clusterer_kwargs = d ["hdbscan_kwargs" ],
235- n_pca_features = 5 ,
236- ),
237- debug_folder = debug_folder ,
238- ** params ["recursive_kwargs" ],
167+ peaks ,
168+ peak_labels ,
169+ ms_before ,
170+ ms_after ,
239171 ** job_kwargs ,
240172 )
173+ else :
174+ from spikeinterface .sortingcomponents .clustering .tools import get_templates_from_peaks_and_svd
241175
242- non_noise = peak_labels > - 1
243- labels , inverse = np .unique (peak_labels [non_noise ], return_inverse = True )
244- peak_labels [non_noise ] = inverse
245- labels = np .unique (inverse )
246-
247- spikes = np .zeros (non_noise .sum (), dtype = minimum_spike_dtype )
248- spikes ["sample_index" ] = peaks [non_noise ]["sample_index" ]
249- spikes ["segment_index" ] = peaks [non_noise ]["segment_index" ]
250- spikes ["unit_index" ] = peak_labels [non_noise ]
251-
252- unit_ids = labels
253-
254- nbefore = int (params ["waveforms" ]["ms_before" ] * fs / 1000.0 )
255- nafter = int (params ["waveforms" ]["ms_after" ] * fs / 1000.0 )
256-
257- if params ["noise_levels" ] is None :
258- params ["noise_levels" ] = get_noise_levels (recording , return_scaled = False , ** job_kwargs )
259-
260- templates_array = estimate_templates (
261- recording ,
262- spikes ,
263- unit_ids ,
264- nbefore ,
265- nafter ,
266- return_scaled = False ,
267- job_name = None ,
268- ** job_kwargs ,
269- )
176+ templates = get_templates_from_peaks_and_svd (
177+ recording ,
178+ peaks ,
179+ peak_labels ,
180+ ms_before ,
181+ ms_after ,
182+ svd_model ,
183+ peaks_svd ,
184+ sparse_mask ,
185+ operator = "median" ,
186+ )
270187
188+ templates_array = templates .templates_array
271189 best_channels = np .argmax (np .abs (templates_array [:, nbefore , :]), axis = 1 )
272190 peak_snrs = np .abs (templates_array [:, nbefore , :])
273191 best_snrs_ratio = (peak_snrs / params ["noise_levels" ])[np .arange (len (peak_snrs )), best_channels ]
192+ old_unit_ids = templates .unit_ids .copy ()
274193 valid_templates = best_snrs_ratio > params ["noise_threshold" ]
275194
276- if d [ "rank" ] is not None :
277- from spikeinterface . sortingcomponents . matching . circus import compress_templates
195+ mask = np . isin ( peak_labels , old_unit_ids [ ~ valid_templates ])
196+ peak_labels [ mask ] = - 1
278197
279- _ , _ , _ , templates_array = compress_templates ( templates_array , d [ "rank" ])
198+ from spikeinterface . core . template import Templates
280199
281200 templates = Templates (
282201 templates_array = templates_array [valid_templates ],
283202 sampling_frequency = fs ,
284- nbefore = nbefore ,
203+ nbefore = templates . nbefore ,
285204 sparsity_mask = None ,
286205 channel_ids = recording .channel_ids ,
287- unit_ids = unit_ids [valid_templates ],
206+ unit_ids = templates . unit_ids [valid_templates ],
288207 probe = recording .get_probe (),
289208 is_scaled = False ,
290209 )
291210
211+ if params ["debug" ]:
212+ templates_folder = tmp_folder / "dense_templates"
213+ templates .to_zarr (folder_path = templates_folder )
214+
292215 sparsity = compute_sparsity (templates , noise_levels = params ["noise_levels" ], ** params ["sparsity" ])
293216 templates = templates .to_sparse (sparsity )
294217 empty_templates = templates .sparsity_mask .sum (axis = 1 ) == 0
218+ old_unit_ids = templates .unit_ids .copy ()
295219 templates = remove_empty_templates (templates )
296220
297- mask = np .isin (peak_labels , np . where ( empty_templates )[ 0 ])
221+ mask = np .isin (peak_labels , old_unit_ids [ empty_templates ])
298222 peak_labels [mask ] = - 1
299223
300- mask = np .isin (peak_labels , np . where ( ~ valid_templates )[ 0 ] )
301- peak_labels [ mask ] = - 1
224+ labels = np .unique (peak_labels )
225+ labels = labels [ labels >= 0 ]
302226
303- if verbose :
304- print ("Found %d raw clusters, starting to clean with matching" % (len (templates .unit_ids )))
227+ if params ["remove_mixtures" ]:
228+ if verbose :
229+ print ("Found %d raw clusters, starting to clean with matching" % (len (templates .unit_ids )))
305230
306- cleaning_job_kwargs = job_kwargs .copy ()
307- cleaning_job_kwargs ["progress_bar" ] = False
308- cleaning_params = params ["cleaning_kwargs" ].copy ()
231+ cleaning_job_kwargs = job_kwargs .copy ()
232+ cleaning_job_kwargs ["progress_bar" ] = False
233+ cleaning_params = params ["cleaning_kwargs" ].copy ()
309234
310- labels , peak_labels = remove_duplicates_via_matching (
311- templates , peak_labels , job_kwargs = cleaning_job_kwargs , ** cleaning_params
312- )
235+ labels , peak_labels = remove_duplicates_via_matching (
236+ templates , peak_labels , job_kwargs = cleaning_job_kwargs , ** cleaning_params
237+ )
313238
314- if verbose :
315- print ("Kept %d non-duplicated clusters" % len (labels ))
239+ if verbose :
240+ print ("Kept %d non-duplicated clusters" % len (labels ))
241+ else :
242+ if verbose :
243+ print ("Kept %d raw clusters" % len (labels ))
316244
317- return labels , peak_labels
245+ return labels , peak_labels , svd_model , peaks_svd , sparse_mask
0 commit comments