1111)
1212
1313import numpy as np
14- from .benchmark_base import Benchmark , BenchmarkStudy
14+ from .benchmark_base import Benchmark , BenchmarkStudy , MixinStudyUnitCount
1515from spikeinterface .core .basesorting import minimum_spike_dtype
1616
1717
@@ -53,7 +53,7 @@ def compute_result(self, with_collision=False, **result_params):
5353 _result_key_saved = [("gt_collision" , "pickle" ), ("gt_comparison" , "pickle" )]
5454
5555
56- class MatchingStudy (BenchmarkStudy ):
56+ class MatchingStudy (BenchmarkStudy , MixinStudyUnitCount ):
5757
5858 benchmark_class = MatchingBenchmark
5959
@@ -84,6 +84,11 @@ def plot_performances_vs_depth_and_snr(self, *args, **kwargs):
8484
8585 return plot_performances_vs_depth_and_snr (self , * args , ** kwargs )
8686
87+ def plot_performances_ordered (self , * args , ** kwargs ):
88+ from .benchmark_plot_tools import plot_performances_ordered
89+
90+ return plot_performances_ordered (self , * args , ** kwargs )
91+
8792 def plot_collisions (self , case_keys = None , figsize = None ):
8893 if case_keys is None :
8994 case_keys = list (self .cases .keys ())
@@ -104,42 +109,6 @@ def plot_collisions(self, case_keys=None, figsize=None):
104109
105110 return fig
106111
107- def get_count_units (self , case_keys = None , well_detected_score = None , redundant_score = None , overmerged_score = None ):
108- import pandas as pd
109-
110- if case_keys is None :
111- case_keys = list (self .cases .keys ())
112-
113- if isinstance (case_keys [0 ], str ):
114- index = pd .Index (case_keys , name = self .levels )
115- else :
116- index = pd .MultiIndex .from_tuples (case_keys , names = self .levels )
117-
118- columns = ["num_gt" , "num_sorter" , "num_well_detected" ]
119- comp = self .get_result (case_keys [0 ])["gt_comparison" ]
120- if comp .exhaustive_gt :
121- columns .extend (["num_false_positive" , "num_redundant" , "num_overmerged" , "num_bad" ])
122- count_units = pd .DataFrame (index = index , columns = columns , dtype = int )
123-
124- for key in case_keys :
125- comp = self .get_result (key )["gt_comparison" ]
126- assert comp is not None , "You need to do study.run_comparisons() first"
127-
128- gt_sorting = comp .sorting1
129- sorting = comp .sorting2
130-
131- count_units .loc [key , "num_gt" ] = len (gt_sorting .get_unit_ids ())
132- count_units .loc [key , "num_sorter" ] = len (sorting .get_unit_ids ())
133- count_units .loc [key , "num_well_detected" ] = comp .count_well_detected_units (well_detected_score )
134-
135- if comp .exhaustive_gt :
136- count_units .loc [key , "num_redundant" ] = comp .count_redundant_units (redundant_score )
137- count_units .loc [key , "num_overmerged" ] = comp .count_overmerged_units (overmerged_score )
138- count_units .loc [key , "num_false_positive" ] = comp .count_false_positive_units (redundant_score )
139- count_units .loc [key , "num_bad" ] = comp .count_bad_units ()
140-
141- return count_units
142-
143112 def plot_unit_counts (self , case_keys = None , ** kwargs ):
144113 from .benchmark_plot_tools import plot_unit_counts
145114
0 commit comments