Skip to content

Commit b8f7253

Browse files
authored
Improve plot components with custum axis and fix a bug
1 parent 8ae93c5 commit b8f7253

2 files changed

Lines changed: 148 additions & 72 deletions

File tree

src/spikeinterface/benchmark/benchmark_peak_detection.py

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
from spikeinterface.comparison.comparisontools import make_matching_events
1010
from spikeinterface.core import get_noise_levels
11-
11+
from spikeinterface.benchmark.benchmark_plot_tools import despine
1212

1313
import numpy as np
1414
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
@@ -152,54 +152,29 @@ def plot_agreements_by_units(self, case_keys=None, figsize=(15, 15)):
152152
ax.set_title(self.cases[key]["label"])
153153
plot_agreement_matrix(self.get_result(key)["sliced_gt_comparison"], ax=ax)
154154

155-
def plot_performances_vs_snr(self, case_keys=None, figsize=(15, 15), detect_threshold=None):
156-
if case_keys is None:
157-
case_keys = list(self.cases.keys())
158-
159-
import matplotlib.pyplot as plt
160-
161-
fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize)
162-
163-
for count, k in enumerate(("accuracy", "recall", "precision")):
164-
165-
ax = axs[count]
166-
for key in case_keys:
167-
color = self.get_colors()[key]
168-
label = self.cases[key]["label"]
169-
170-
analyzer = self.get_sorting_analyzer(key)
171-
metrics = analyzer.get_extension("quality_metrics").get_data()
172-
x = metrics["snr"].values
173-
y = self.get_result(key)["sliced_gt_comparison"].get_performance()[k].values
174-
ax.scatter(x, y, marker=".", label=label, color=color)
175-
ax.set_title(k)
176-
if detect_threshold is not None:
177-
ymin, ymax = ax.get_ylim()
178-
ax.plot([detect_threshold, detect_threshold], [ymin, ymax], "k--")
179-
180-
popt = fit_sigmoid(x, y, p0=None)
181-
xfit = np.linspace(0, max(metrics["snr"].values), 100)
182-
ax.plot(xfit, sigmoid(xfit, *popt), color=color)
183-
184-
if count == 2:
185-
ax.legend()
186-
187-
def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None):
155+
def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_threshold=None, axs=None):
188156

189157
if case_keys is None:
190158
case_keys = list(self.cases.keys())
191159
import matplotlib.pyplot as plt
192160

193-
fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
161+
if axs is None:
162+
fig, axs = plt.subplots(ncols=len(case_keys), figsize=figsize, squeeze=False)
163+
else:
164+
fig = axs[0].get_figure()
165+
assert len(axs) == len(case_keys), "axs should be the same length as case_keys"
194166

195167
for count, key in enumerate(case_keys):
196-
ax = axs[0, count]
168+
ax = axs[count]
169+
despine(ax)
197170
data1 = self.get_result(key)["peaks"]["amplitude"]
198171
data2 = self.get_result(key)["gt_amplitudes"]
172+
color = self.get_colors()[key]
199173
bins = np.linspace(data2.min(), data2.max(), 100)
200-
ax.hist(data1, bins=bins, alpha=0.5, label="detected")
201-
ax.hist(data2, bins=bins, alpha=0.5, label="gt")
202-
ax.set_title(self.cases[key]["label"])
174+
ax.hist(data1, bins=bins, label="detected", histtype="step", color=color, linewidth=2)
175+
ax.hist(data2, bins=bins, alpha=0.1, label="gt", color="k")
176+
ax.set_yscale("log")
177+
# ax.set_title(self.cases[key]["label"])
203178
ax.legend()
204179
if detect_threshold is not None:
205180
noise_levels = get_noise_levels(self.benchmarks[key].recording, return_scaled=False).mean()
@@ -230,13 +205,62 @@ def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)):
230205
ax.set_ylabel("# frames")
231206
ax.set_xlabel("unit id")
232207

233-
def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5), detect_threshold=None):
208+
def plot_mean_deltas(self, case_keys=None, figsize=(15, 5), ax=None):
234209

235210
if case_keys is None:
236211
case_keys = list(self.cases.keys())
237212
import matplotlib.pyplot as plt
238213

239-
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=figsize, squeeze=True)
214+
if ax is None:
215+
fig, ax = plt.subplots(1, figsize=figsize)
216+
else:
217+
fig = ax.get_figure()
218+
219+
results = {}
220+
labels = []
221+
colors = []
222+
for count, key in enumerate(case_keys):
223+
gt_sorting = self.benchmarks[key].gt_sorting
224+
results[key] = []
225+
labels += [self.cases[key]["label"]]
226+
data = self.get_result(key)["matches"]
227+
for unit_ind, unit_id in enumerate(gt_sorting.unit_ids):
228+
mask = data["labels"] == unit_id
229+
results[key] += [np.mean(data["deltas"][mask])]
230+
231+
colors += [self.get_colors()[key]]
232+
despine(ax)
233+
plots = ax.violinplot(
234+
results.values(),
235+
range(len(case_keys)),
236+
showmeans=True,
237+
showmedians=False,
238+
showextrema=False,
239+
)
240+
241+
# Set the color of the violin patches
242+
for pc, color in zip(plots["bodies"], colors):
243+
pc.set_facecolor(color)
244+
pc.set_edgecolor(color)
245+
246+
plots["cmeans"].set_colors(colors)
247+
248+
# ax.set_title(self.cases[key]["label"])
249+
ax.set_xticks(range(len(case_keys)), labels, rotation=45)
250+
ax.set_ylabel("# frames")
251+
# ax.set_xlabel("unit id")
252+
253+
def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5), detect_threshold=None, ax=None):
254+
255+
if case_keys is None:
256+
case_keys = list(self.cases.keys())
257+
import matplotlib.pyplot as plt
258+
259+
if ax is None:
260+
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=figsize, squeeze=True)
261+
else:
262+
fig = ax.get_figure()
263+
240264
for key in case_keys:
241265

242266
import sklearn.metrics
@@ -266,7 +290,7 @@ def plot_template_similarities(self, case_keys=None, metric="l2", figsize=(15, 5
266290
x = metrics["snr"].values
267291
y = distances
268292
ax.scatter(x, y, marker=".", label=label, color=color)
269-
293+
despine(ax)
270294
popt = fit_sigmoid(x, y, p0=None)
271295
xfit = np.linspace(0, max(metrics["snr"].values), 100)
272296
ax.plot(xfit, sigmoid(xfit, *popt), color=color)

0 commit comments

Comments
 (0)