|
8 | 8 | ) |
9 | 9 | from spikeinterface.comparison.comparisontools import make_matching_events |
10 | 10 | from spikeinterface.core import get_noise_levels |
11 | | - |
| 11 | +from spikeinterface.benchmark.benchmark_plot_tools import despine |
12 | 12 |
|
13 | 13 | import numpy as np |
14 | 14 | from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs |
@@ -152,54 +152,29 @@ def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)): |
152 | 152 | ax.set_title(self.cases[key]["label"]) |
153 | 153 | plot_agreement_matrix(self.get_result(key)["sliced_gt_comparison"], ax=ax) |
154 | 154 |
|
155 | | - def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15), detect_threshold=None): |
156 | | - if case_keys is None: |
157 | | - case_keys = list(self.cases.keys()) |
158 | | - |
159 | | - import matplotlib.pyplot as plt |
160 | | - |
161 | | - fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) |
162 | | - |
163 | | - for count, k in enumerate(("accuracy", "recall", "precision")): |
164 | | - |
165 | | - ax = axs[count] |
166 | | - for key in case_keys: |
167 | | - color = self.get_colors()[key] |
168 | | - label = self.cases[key]["label"] |
169 | | - |
170 | | - analyzer = self.get_sorting_analyzer(key) |
171 | | - metrics = analyzer.get_extension("quality_metrics").get_data() |
172 | | - x = metrics["snr"].values |
173 | | - y = self.get_result(key)["sliced_gt_comparison"].get_performance()[k].values |
174 | | - ax.scatter(x, y, marker=".", label=label, color=color) |
175 | | - ax.set_title(k) |
176 | | - if detect_threshold is not None: |
177 | | - ymin, ymax = ax.get_ylim() |
178 | | - ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--") |
179 | | - |
180 | | - popt = fit_sigmoid(x, y, p0=None) |
181 | | - xfit = np.linspace(0, max(metrics["snr"].values), 100) |
182 | | - ax.plot(xfit, sigmoid(xfit, *popt), color=color) |
183 | | - |
184 | | - if count == 2: |
185 | | - ax.legend() |
186 | | - |
187 | | - def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None): |
| 155 | + def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None, axs=None): |
188 | 156 |
|
189 | 157 | if case_keys is None: |
190 | 158 | case_keys = list(self.cases.keys()) |
191 | 159 | import matplotlib.pyplot as plt |
192 | 160 |
|
193 | | - fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) |
| 161 | + if axs is None: |
| 162 | + fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False) |
| 163 | + else: |
| 164 | + fig = axs[0].get_figure() |
| 165 | + assert len(axs) == len(case_keys), "axs should be the same length as case_keys" |
194 | 166 |
|
195 | 167 | for count, key in enumerate(case_keys): |
196 | | - ax = axs[0, count] |
| 168 | + ax = axs[count] |
| 169 | + despine(ax) |
197 | 170 | data1 = self.get_result(key)["peaks"]["amplitude"] |
198 | 171 | data2 = self.get_result(key)["gt_amplitudes"] |
| 172 | + color = self.get_colors()[key] |
199 | 173 | bins = np.linspace(data2.min(), data2.max(), 100) |
200 | | - ax.hist(data1, bins=bins, alpha=0.5, label="detected") |
201 | | - ax.hist(data2, bins=bins, alpha=0.5, label="gt") |
202 | | - ax.set_title(self.cases[key]["label"]) |
| 174 | + ax.hist(data1, bins=bins, label="detected", histtype="step", color=color, linewidth=2) |
| 175 | + ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k") |
| 176 | + ax.set_yscale("log") |
| 177 | + # ax.set_title(self.cases[key]["label"]) |
203 | 178 | ax.legend() |
204 | 179 | if detect_threshold is not None: |
205 | 180 | noise_levels = get_noise_levels(self.benchmarks[key].recording, return_scaled=False).mean() |
@@ -230,13 +205,62 @@ def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)): |
230 | 205 | ax.set_ylabel("# frames") |
231 | 206 | ax.set_xlabel("unit id") |
232 | 207 |
|
233 | | - def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5), detect_threshold=None): |
| 208 | + def plot_mean_deltas(self, case_keys=None, figsize=(15, 5), ax=None): |
234 | 209 |
|
235 | 210 | if case_keys is None: |
236 | 211 | case_keys = list(self.cases.keys()) |
237 | 212 | import matplotlib.pyplot as plt |
238 | 213 |
|
239 | | - fig, ax = plt.subplots(ncols=1, nrows=1, figsize=figsize, squeeze=True) |
| 214 | + if ax is None: |
| 215 | + fig, ax = plt.subplots(1, figsize=figsize) |
| 216 | + else: |
| 217 | + fig = ax.get_figure() |
| 218 | + |
| 219 | + results = {} |
| 220 | + labels = [] |
| 221 | + colors = [] |
| 222 | + for count, key in enumerate(case_keys): |
| 223 | + gt_sorting = self.benchmarks[key].gt_sorting |
| 224 | + results[key] = [] |
| 225 | + labels += [self.cases[key]["label"]] |
| 226 | + data = self.get_result(key)["matches"] |
| 227 | + for unit_ind, unit_id in enumerate(gt_sorting.unit_ids): |
| 228 | + mask = data["labels"] == unit_id |
| 229 | + results[key] += [np.mean(data["deltas"][mask])] |
| 230 | + |
| 231 | + colors += [self.get_colors()[key]] |
| 232 | + despine(ax) |
| 233 | + plots = ax.violinplot( |
| 234 | + results.values(), |
| 235 | + range(len(case_keys)), |
| 236 | + showmeans=True, |
| 237 | + showmedians=False, |
| 238 | + showextrema=False, |
| 239 | + ) |
| 240 | + |
| 241 | + # Set the color of the violin patches |
| 242 | + for pc, color in zip(plots["bodies"], colors): |
| 243 | + pc.set_facecolor(color) |
| 244 | + pc.set_edgecolor(color) |
| 245 | + |
| 246 | + plots["cmeans"].set_colors(colors) |
| 247 | + |
| 248 | + # ax.set_title(self.cases[key]["label"]) |
| 249 | + ax.set_xticks(range(len(case_keys)), labels, rotation=45) |
| 250 | + ax.set_ylabel("# frames") |
| 251 | + # ax.set_xlabel("unit id") |
| 252 | + |
| 253 | + def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5), detect_threshold=None, ax=None): |
| 254 | + |
| 255 | + if case_keys is None: |
| 256 | + case_keys = list(self.cases.keys()) |
| 257 | + import matplotlib.pyplot as plt |
| 258 | + |
| 259 | + if ax is None: |
| 260 | + fig, ax = plt.subplots(ncols=1, nrows=1, figsize=figsize, squeeze=True) |
| 261 | + else: |
| 262 | + fig = ax.get_figure() |
| 263 | + |
240 | 264 | for key in case_keys: |
241 | 265 |
|
242 | 266 | import sklearn.metrics |
@@ -266,7 +290,7 @@ def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5 |
266 | 290 | x = metrics["snr"].values |
267 | 291 | y = distances |
268 | 292 | ax.scatter(x, y, marker=".", label=label, color=color) |
269 | | - |
| 293 | + despine(ax) |
270 | 294 | popt = fit_sigmoid(x, y, p0=None) |
271 | 295 | xfit = np.linspace(0, max(metrics["snr"].values), 100) |
272 | 296 | ax.plot(xfit, sigmoid(xfit, *popt), color=color) |
|
0 commit comments