Skip to content

Commit 01b4c47

Browse files
authored
Add events handling: trace/map + new view for aliggned/rasters + PSTHs (#218)
* Add events handling: trace/map + new view for aliggned/rasters + PSTHs * EventView in panel * centralize toolbars in traces * Add events on Trace/TraceMap for panel * wip: move event functions to event_tools * fix multiple event lines in panel and qt * Fix conflicts2 * Fix conflicts3 * feat: sample events from all segments and subsample based on seg duration * feat: sample events from all segments and subsample based on seg duration * fix: filter events based on start/stop and sort them * cleanup test * Fix duplicated event source
1 parent 58d4eb7 commit 01b4c47

12 files changed

Lines changed: 763 additions & 88 deletions

spikeinterface_gui/basescatterview.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _qt_make_layout(self):
211211
tb = self.qt_widget.view_toolbar
212212
self.combo_seg = QT.QComboBox()
213213
tb.addWidget(self.combo_seg)
214-
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
214+
self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments)])
215215
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
216216
add_stretch_to_qtoolbar(tb)
217217
self.lasso_but = QT.QPushButton("select", checkable = True)
@@ -313,7 +313,7 @@ def _qt_refresh(self, set_scatter_range=False):
313313
# make a copy of the color
314314
color = QT.QColor(self.get_unit_color(unit_id))
315315
color.setAlpha(int(self.settings['alpha']*255))
316-
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)
316+
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)
317317

318318
color = self.get_unit_color(unit_id)
319319
curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color)

spikeinterface_gui/controller.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
from spikeinterface.widgets.utils import get_unit_colors
1010
from spikeinterface import compute_sparsity
11-
from spikeinterface.core import get_template_extremum_channel
11+
from spikeinterface.core import get_template_extremum_channel, BaseEvent
1212
from spikeinterface.core.sorting_tools import spike_vector_to_indices
1313
from spikeinterface.curation import validate_curation_dict
1414
from spikeinterface.curation.curation_model import Curation
1515
from spikeinterface.widgets.utils import make_units_table_from_analyzer
1616

1717
from .curation_tools import add_merge, default_label_definitions, empty_curation_data
18+
from .event_tools import parse_events
1819

1920
spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
2021
('channel_index', 'int64'), ('segment_index', 'int64'),
@@ -46,6 +47,7 @@ def __init__(
4647
extra_unit_properties=None,
4748
skip_extensions=None,
4849
disable_save_settings_button=False,
50+
events=None,
4951
external_data=None,
5052
curation_callback=None,
5153
curation_callback_kwargs=None,
@@ -256,6 +258,16 @@ def __init__(
256258
self.valid_periods = None
257259

258260
self._potential_merges = None
261+
# some direct attribute
262+
self.num_segments = self.analyzer.get_num_segments()
263+
self.sampling_frequency = self.analyzer.sampling_frequency
264+
265+
# parse events
266+
self.events = None
267+
if events is not None:
268+
self.events = parse_events(events, self, verbose=verbose)
269+
if len(self.events) == 0:
270+
self.events = None
259271

260272
t1 = time.perf_counter()
261273
if verbose:
@@ -266,10 +278,6 @@ def __init__(
266278
self._extremum_channel = get_template_extremum_channel(self.analyzer,
267279
mode="extremum", peak_sign='both', outputs='index')
268280

269-
# some direct attribute
270-
self.num_segments = self.analyzer.get_num_segments()
271-
self.sampling_frequency = self.analyzer.sampling_frequency
272-
273281
# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
274282
self.refresh_colors()
275283

@@ -468,9 +476,12 @@ def update_time_info(self):
468476
else:
469477
self.time_info['time_by_seg'] = time_by_seg
470478

471-
def get_t_start_t_stop(self):
472-
segment_index = self.time_info["segment_index"]
473-
if self.main_settings["use_times"] and self.has_extension("recording"):
479+
def get_t_start_t_stop(self, use_times=None, segment_index=None):
480+
if segment_index is None:
481+
segment_index = self.time_info["segment_index"]
482+
if use_times is None:
483+
use_times = self.main_settings["use_times"]
484+
if use_times and self.has_extension("recording"):
474485
t_start = self.analyzer.recording.get_start_time(segment_index=segment_index)
475486
t_stop = self.analyzer.recording.get_end_time(segment_index=segment_index)
476487
return t_start, t_stop
@@ -508,14 +519,26 @@ def sample_index_to_time(self, sample_index):
508519
else:
509520
return sample_index / self.sampling_frequency
510521

511-
def time_to_sample_index(self, time):
512-
segment_index = self.time_info["segment_index"]
513-
if self.main_settings["use_times"] and self.has_extension("recording"):
522+
def time_to_sample_index(self, time, segment_index=None, use_times=None):
523+
if segment_index is None:
524+
segment_index = self.time_info["segment_index"]
525+
if use_times is None:
526+
use_times = self.main_settings["use_times"]
527+
if use_times and self.has_extension("recording"):
514528
time = self.analyzer.recording.time_to_sample_index(time, segment_index=segment_index)
515529
return time
516530
else:
517531
return int(time * self.sampling_frequency)
518532

533+
def get_events(self, event_name, segment_index=None):
534+
if self.events is None:
535+
return None
536+
if event_name not in self.events:
537+
return None
538+
if segment_index is None:
539+
segment_index = self.time_info['segment_index']
540+
return self.events[event_name][segment_index]
541+
519542
def get_information_txt(self):
520543
nseg = self.analyzer.get_num_segments()
521544
nchan = self.analyzer.get_num_channels()
@@ -768,6 +791,8 @@ def set_channel_visibility(self, visible_channel_inds):
768791
def has_extension(self, extension_name):
769792
if extension_name == 'recording':
770793
return self.analyzer.has_recording() or self.analyzer.has_temporary_recording()
794+
elif extension_name == 'events':
795+
return self.events is not None
771796
else:
772797
# extension needs to be loaded
773798
if extension_name in self.skip_extensions:

spikeinterface_gui/event_tools.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
3+
from spikeinterface.core import BaseEvent
4+
5+
def parse_events(events, controller, verbose=False):
6+
"""Parse events input into a standard format.
7+
8+
Parameters
9+
----------
10+
events : dict | BaseEvent
11+
BaseEvent object or a dictionary where keys are event names and values are dictionaries with at least a
12+
'samples' key or 'times' key.
13+
controller : Controller
14+
Controller object managing the event parsing.
15+
verbose : bool, default: False
16+
Whether to print verbose messages.
17+
18+
Returns
19+
-------
20+
parsed_events : dict
21+
Parsed events dictionary. The keys are event names, and the values are lists of numpy arrays of event sample indices.
22+
Each element corresponds to a segment in the recording.
23+
"""
24+
parsed_events = {}
25+
if isinstance(events, dict):
26+
for key, val in events.items():
27+
if not isinstance(val, dict):
28+
if verbose:
29+
print(f'\tSkipping event {key}: not a dict')
30+
continue
31+
if 'samples' not in val and 'times' not in val:
32+
if verbose:
33+
print(f'\tSkipping event {key}: missing samples or times')
34+
continue
35+
if 'times' in val:
36+
samples_data = val['times']
37+
convert_to_samples = True
38+
else:
39+
samples_data = val['samples']
40+
convert_to_samples = False
41+
if controller.num_segments > 1:
42+
if not len(samples_data) == controller.num_segments:
43+
if verbose:
44+
print(f'\tSkipping event {key}: inconsistent number of samples')
45+
continue
46+
else:
47+
# here we make sure samples is a list of list
48+
if np.array(samples_data).ndim == 1:
49+
samples_data = [samples_data]
50+
if convert_to_samples:
51+
# filter events based on recording start/stop times
52+
filtered_samples_data = []
53+
parsed_events[key] = []
54+
for segment_index in range(controller.num_segments):
55+
t_start, t_end = controller.get_t_start_t_stop(use_times=True, segment_index=segment_index)
56+
s = samples_data[segment_index]
57+
filtered_samples_data = s[(s >= t_start) & (s <= t_end)]
58+
parsed_events[key].append(
59+
np.sort(
60+
controller.time_to_sample_index(
61+
filtered_samples_data,
62+
segment_index=segment_index,
63+
use_times=True
64+
)
65+
)
66+
)
67+
68+
else:
69+
parsed_events[key] = [np.sort(s) for s in samples_data]
70+
elif isinstance(events, BaseEvent):
71+
event_names = events.channel_ids
72+
parsed_events = {
73+
event_name: [] for event_name in event_names
74+
}
75+
for event_name in event_names:
76+
for segment_index in range(controller.num_segments):
77+
event_times_segment = events.get_event_times(
78+
channel_id=event_name,
79+
segment_index=segment_index
80+
)
81+
event_samples_segment = controller.time_to_sample_index(
82+
event_times_segment, segment_index=segment_index, use_times=True
83+
)
84+
parsed_events[event_name].append(np.sort(event_samples_segment))
85+
else:
86+
if verbose:
87+
print('\tSkipping events: not a dict or BaseEvent')
88+
89+
return parsed_events

0 commit comments

Comments
 (0)