Skip to content

Commit 413d451

Browse files
authored
Merge branch 'main' into add_del_method_to_binary
2 parents bc301ec + 5220cc7 commit 413d451

6 files changed

Lines changed: 319 additions & 213 deletions

File tree

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,68 @@ def run(self):
619619
def compute_result(self):
620620
# run benchmark result
621621
raise NotImplementedError
622+
623+
624+
# Common features across some benchmarks: sorter + matching
625+
class MixinStudyUnitCount:
    """Mixin shared by benchmark studies (sorter, matching, clustering) whose
    per-case results contain a ``"gt_comparison"`` ground-truth comparison.

    Provides tabular summaries (unit counts and per-unit performance) built
    from ``self.get_result(key)["gt_comparison"]`` for each case key.
    """

    def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
        """Return a DataFrame of unit counts per case.

        Columns are ``num_gt``, ``num_sorter``, ``num_well_detected`` and,
        when the comparison has exhaustive ground truth, additionally
        ``num_false_positive``, ``num_redundant``, ``num_overmerged``,
        ``num_bad``.

        Parameters
        ----------
        case_keys : list | None
            Case keys to include. Defaults to all cases of the study.
        well_detected_score : float | None
            Agreement-score threshold forwarded to ``count_well_detected_units``.
        redundant_score : float | None
            Threshold forwarded to ``count_redundant_units`` and
            ``count_false_positive_units``.
        overmerged_score : float | None
            Threshold forwarded to ``count_overmerged_units``.

        Returns
        -------
        count_units : pd.DataFrame
            One row per case key, indexed by the study levels.
        """
        import pandas as pd

        if case_keys is None:
            case_keys = list(self.cases.keys())

        # Plain string keys -> flat Index; tuple keys -> MultiIndex on the study levels.
        if isinstance(case_keys[0], str):
            index = pd.Index(case_keys, name=self.levels)
        else:
            index = pd.MultiIndex.from_tuples(case_keys, names=self.levels)

        columns = ["num_gt", "num_sorter", "num_well_detected"]
        key0 = case_keys[0]
        comp = self.get_result(key0)["gt_comparison"]
        if comp.exhaustive_gt:
            columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"])
        count_units = pd.DataFrame(index=index, columns=columns, dtype=int)

        for key in case_keys:
            comp = self.get_result(key)["gt_comparison"]
            # Guard restored from the pre-refactor MatchingStudy implementation:
            # the comparison is None until study.run_comparisons() has been run.
            assert comp is not None, "You need to do study.run_comparisons() first"

            gt_sorting = comp.sorting1
            sorting = comp.sorting2

            count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
            count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
            count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)

            if comp.exhaustive_gt:
                count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
                count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
                count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
                count_units.loc[key, "num_bad"] = comp.count_bad_units()

        return count_units

    def get_performance_by_unit(self, case_keys=None):
        """Concatenate per-unit performance tables across cases.

        Parameters
        ----------
        case_keys : iterable | None
            Case keys to include. Defaults to all cases of the study.

        Returns
        -------
        perf_by_unit : pd.DataFrame
            Per-unit performance rows for all cases, indexed and sorted by the
            study levels.
        """
        import pandas as pd

        if case_keys is None:
            case_keys = self.cases.keys()

        perf_by_unit = []
        for key in case_keys:
            comp = self.get_result(key)["gt_comparison"]

            perf = comp.get_performance(method="by_unit", output="pandas")

            # Tag each row with its case key so all cases share one index.
            if isinstance(key, str):
                perf[self.levels] = key
            elif isinstance(key, tuple):
                for col, k in zip(self.levels, key):
                    perf[col] = k

            perf = perf.reset_index()
            perf_by_unit.append(perf)

        perf_by_unit = pd.concat(perf_by_unit)
        perf_by_unit = perf_by_unit.set_index(self.levels)
        perf_by_unit = perf_by_unit.sort_index()
        return perf_by_unit

src/spikeinterface/benchmark/benchmark_clustering.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
14-
from .benchmark_base import Benchmark, BenchmarkStudy
14+
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
1515
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
1616
from spikeinterface.core.template_tools import get_template_extremum_channel
1717

@@ -93,7 +93,7 @@ def compute_result(self, **result_params):
9393
]
9494

9595

96-
class ClusteringStudy(BenchmarkStudy):
96+
class ClusteringStudy(BenchmarkStudy, MixinStudyUnitCount):
9797

9898
benchmark_class = ClusteringBenchmark
9999

@@ -196,6 +196,11 @@ def plot_performances_vs_depth_and_snr(self, *args, **kwargs):
196196

197197
return plot_performances_vs_depth_and_snr(self, *args, **kwargs)
198198

199+
def plot_performances_ordered(self, *args, **kwargs):
    """Delegate to ``benchmark_plot_tools.plot_performances_ordered`` for this study."""
    from . import benchmark_plot_tools

    return benchmark_plot_tools.plot_performances_ordered(self, *args, **kwargs)
203+
199204
def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
200205

201206
if case_keys is None:

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212

1313
import numpy as np
14-
from .benchmark_base import Benchmark, BenchmarkStudy
14+
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
1515
from spikeinterface.core.basesorting import minimum_spike_dtype
1616

1717

@@ -53,7 +53,7 @@ def compute_result(self, with_collision=False, **result_params):
5353
_result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")]
5454

5555

56-
class MatchingStudy(BenchmarkStudy):
56+
class MatchingStudy(BenchmarkStudy, MixinStudyUnitCount):
5757

5858
benchmark_class = MatchingBenchmark
5959

@@ -84,6 +84,11 @@ def plot_performances_vs_depth_and_snr(self, *args, **kwargs):
8484

8585
return plot_performances_vs_depth_and_snr(self, *args, **kwargs)
8686

87+
def plot_performances_ordered(self, *args, **kwargs):
    """Delegate to ``benchmark_plot_tools.plot_performances_ordered`` for this study."""
    from . import benchmark_plot_tools

    return benchmark_plot_tools.plot_performances_ordered(self, *args, **kwargs)
91+
8792
def plot_collisions(self, case_keys=None, figsize=None):
8893
if case_keys is None:
8994
case_keys = list(self.cases.keys())
@@ -104,42 +109,6 @@ def plot_collisions(self, case_keys=None, figsize=None):
104109

105110
return fig
106111

107-
def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
108-
import pandas as pd
109-
110-
if case_keys is None:
111-
case_keys = list(self.cases.keys())
112-
113-
if isinstance(case_keys[0], str):
114-
index = pd.Index(case_keys, name=self.levels)
115-
else:
116-
index = pd.MultiIndex.from_tuples(case_keys, names=self.levels)
117-
118-
columns = ["num_gt", "num_sorter", "num_well_detected"]
119-
comp = self.get_result(case_keys[0])["gt_comparison"]
120-
if comp.exhaustive_gt:
121-
columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"])
122-
count_units = pd.DataFrame(index=index, columns=columns, dtype=int)
123-
124-
for key in case_keys:
125-
comp = self.get_result(key)["gt_comparison"]
126-
assert comp is not None, "You need to do study.run_comparisons() first"
127-
128-
gt_sorting = comp.sorting1
129-
sorting = comp.sorting2
130-
131-
count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids())
132-
count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids())
133-
count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score)
134-
135-
if comp.exhaustive_gt:
136-
count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score)
137-
count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score)
138-
count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score)
139-
count_units.loc[key, "num_bad"] = comp.count_bad_units()
140-
141-
return count_units
142-
143112
def plot_unit_counts(self, case_keys=None, **kwargs):
144113
from .benchmark_plot_tools import plot_unit_counts
145114

0 commit comments

Comments
 (0)