Skip to content

Commit 5220cc7

Browse files
Enhance split with various options (#3835)
* Enhance split
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8c14897 commit 5220cc7

1 file changed

Lines changed: 55 additions & 13 deletions

File tree

  • src/spikeinterface/sortingcomponents/clustering

src/spikeinterface/sortingcomponents/clustering/split.py

Lines changed: 55 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -122,10 +122,10 @@ def split_clusters(
122122
if recursive:
123123
recursion_level = np.max(split_count[peak_indices])
124124
if recursive_depth is not None:
125-
# stop reccursivity when recursive_depth is reach
125+
# stop recursion once recursive_depth is reached
126126
extra_ball = recursion_level < recursive_depth
127127
else:
128-
# reccurssive always
128+
# recursive always
129129
extra_ball = True
130130

131131
if extra_ball:
@@ -211,6 +211,7 @@ def split(
211211
waveforms_sparse_mask=None,
212212
min_size_split=25,
213213
n_pca_features=2,
214+
projection_mode="tsvd",
214215
minimum_overlap_ratio=0.25,
215216
):
216217
local_labels = np.zeros(peak_indices.size, dtype=np.int64)
@@ -247,16 +248,36 @@ def split(
247248

248249
is_split = False
249250

250-
if flatten_features.shape[1] > n_pca_features:
251-
from sklearn.decomposition import PCA
251+
if isinstance(n_pca_features, float):
252+
assert 0 < n_pca_features < 1, "n_components should be in ]0, 1["
253+
nb_dimensions = min(flatten_features.shape[0], flatten_features.shape[1])
254+
if projection_mode == "pca":
255+
from sklearn.decomposition import PCA
252256

253-
# from sklearn.decomposition import TruncatedSVD
254-
# tsvd = TruncatedSVD(n_pca_features)
255-
tsvd = PCA(n_pca_features, whiten=True)
257+
tsvd = PCA(nb_dimensions, whiten=True)
258+
elif projection_mode == "tsvd":
259+
from sklearn.decomposition import TruncatedSVD
260+
261+
tsvd = TruncatedSVD(nb_dimensions)
256262
final_features = tsvd.fit_transform(flatten_features)
257-
del tsvd
258-
else:
259-
final_features = flatten_features
263+
n_explain = np.sum(np.cumsum(tsvd.explained_variance_ratio_) <= n_pca_features) + 1
264+
final_features = final_features[:, :n_explain]
265+
n_pca_features = final_features.shape[1]
266+
elif isinstance(n_pca_features, int):
267+
if flatten_features.shape[1] > n_pca_features:
268+
if projection_mode == "pca":
269+
from sklearn.decomposition import PCA
270+
271+
tsvd = PCA(n_pca_features, whiten=True)
272+
elif projection_mode == "tsvd":
273+
from sklearn.decomposition import TruncatedSVD
274+
275+
tsvd = TruncatedSVD(n_pca_features)
276+
277+
final_features = tsvd.fit_transform(flatten_features)
278+
else:
279+
final_features = flatten_features
280+
tsvd = None
260281

261282
if clusterer == "hdbscan":
262283
from hdbscan import HDBSCAN
@@ -270,7 +291,8 @@ def split(
270291
min_cluster_size = clusterer_kwargs["min_cluster_size"]
271292
dipscore, cutpoint = isocut5(final_features[:, 0])
272293
possible_labels = np.zeros(final_features.shape[0])
273-
if dipscore > 1.5:
294+
min_dip = clusterer_kwargs.get("min_dip", 1.5)
295+
if dipscore > min_dip:
274296
mask = final_features[:, 0] > cutpoint
275297
if np.sum(mask) > min_cluster_size and np.sum(~mask):
276298
possible_labels[mask] = 1
@@ -289,10 +311,13 @@ def split(
289311
colors = plt.colormaps["tab10"].resampled(len(labels_set))
290312
colors = {k: colors(i) for i, k in enumerate(labels_set)}
291313
colors[-1] = "k"
292-
fig, axs = plt.subplots(nrows=2)
314+
fig, axs = plt.subplots(nrows=4)
293315

294316
flatten_wfs = aligned_wfs.swapaxes(1, 2).reshape(aligned_wfs.shape[0], -1)
295317

318+
if final_features.shape[1] == 1:
319+
final_features = np.hstack((final_features, np.zeros_like(final_features)))
320+
296321
sl = slice(None, None, 100)
297322
for k in np.unique(possible_labels):
298323
mask = possible_labels == k
@@ -302,10 +327,27 @@ def split(
302327
centroid = final_features[:, :2][mask].mean(axis=0)
303328
ax.text(centroid[0], centroid[1], f"Label {k}", fontsize=10, color="k")
304329
ax = axs[1]
305-
ax.plot(flatten_wfs[mask][sl].T, color=colors[k], alpha=0.5)
330+
ax.plot(flatten_wfs[mask].T, color=colors[k], alpha=0.1)
306331
if k > -1:
307332
ax.plot(np.median(flatten_wfs[mask].T, axis=1), color=colors[k], lw=2)
308333
ax.set_xlabel("PCA features")
334+
335+
ax = axs[3]
336+
if n_pca_features == 1:
337+
bins = np.linspace(final_features[:, 0].min(), final_features[:, 0].max(), 100)
338+
ax.hist(final_features[mask, 0], bins, color=colors[k], alpha=0.1)
339+
else:
340+
ax.plot(final_features[mask].T, color=colors[k], alpha=0.1)
341+
if k > -1 and n_pca_features > 1:
342+
ax.plot(np.median(final_features[mask].T, axis=1), color=colors[k], lw=2)
343+
ax.set_xlabel("Projected PCA features")
344+
345+
if tsvd is not None:
346+
ax = axs[2]
347+
sorted_components = np.argsort(tsvd.explained_variance_ratio_)[::-1]
348+
ax.plot(tsvd.explained_variance_ratio_[sorted_components], c="k")
349+
del tsvd
350+
309351
ymin, ymax = ax.get_ylim()
310352
ax.plot([n_pca_features, n_pca_features], [ymin, ymax], "k--")
311353

0 commit comments

Comments
 (0)