88
99from spikeinterface .widgets .utils import get_unit_colors
1010from spikeinterface import compute_sparsity
11- from spikeinterface .core import get_template_extremum_channel
11+ from spikeinterface .core import get_template_extremum_channel , BaseEvent
1212from spikeinterface .core .sorting_tools import spike_vector_to_indices
1313from spikeinterface .curation import validate_curation_dict
1414from spikeinterface .curation .curation_model import Curation
1515from spikeinterface .widgets .utils import make_units_table_from_analyzer
1616
1717from .curation_tools import add_merge , default_label_definitions , empty_curation_data
18+ from .event_tools import parse_events
1819
1920spike_dtype = [('sample_index' , 'int64' ), ('unit_index' , 'int64' ),
2021 ('channel_index' , 'int64' ), ('segment_index' , 'int64' ),
@@ -46,6 +47,7 @@ def __init__(
4647 extra_unit_properties = None ,
4748 skip_extensions = None ,
4849 disable_save_settings_button = False ,
50+ events = None ,
4951 external_data = None ,
5052 curation_callback = None ,
5153 curation_callback_kwargs = None ,
@@ -256,6 +258,16 @@ def __init__(
256258 self .valid_periods = None
257259
258260 self ._potential_merges = None
261+ # some direct attribute
262+ self .num_segments = self .analyzer .get_num_segments ()
263+ self .sampling_frequency = self .analyzer .sampling_frequency
264+
265+ # parse events
266+ self .events = None
267+ if events is not None :
268+ self .events = parse_events (events , self , verbose = verbose )
269+ if len (self .events ) == 0 :
270+ self .events = None
259271
260272 t1 = time .perf_counter ()
261273 if verbose :
@@ -266,10 +278,6 @@ def __init__(
266278 self ._extremum_channel = get_template_extremum_channel (self .analyzer ,
267279 mode = "extremum" , peak_sign = 'both' , outputs = 'index' )
268280
269- # some direct attribute
270- self .num_segments = self .analyzer .get_num_segments ()
271- self .sampling_frequency = self .analyzer .sampling_frequency
272-
273281 # spikeinterface handle colors in matplotlib style tuple values in range (0,1)
274282 self .refresh_colors ()
275283
@@ -468,9 +476,12 @@ def update_time_info(self):
468476 else :
469477 self .time_info ['time_by_seg' ] = time_by_seg
470478
471- def get_t_start_t_stop (self ):
472- segment_index = self .time_info ["segment_index" ]
473- if self .main_settings ["use_times" ] and self .has_extension ("recording" ):
479+ def get_t_start_t_stop (self , use_times = None , segment_index = None ):
480+ if segment_index is None :
481+ segment_index = self .time_info ["segment_index" ]
482+ if use_times is None :
483+ use_times = self .main_settings ["use_times" ]
484+ if use_times and self .has_extension ("recording" ):
474485 t_start = self .analyzer .recording .get_start_time (segment_index = segment_index )
475486 t_stop = self .analyzer .recording .get_end_time (segment_index = segment_index )
476487 return t_start , t_stop
@@ -508,14 +519,26 @@ def sample_index_to_time(self, sample_index):
508519 else :
509520 return sample_index / self .sampling_frequency
510521
511- def time_to_sample_index (self , time ):
512- segment_index = self .time_info ["segment_index" ]
513- if self .main_settings ["use_times" ] and self .has_extension ("recording" ):
522+ def time_to_sample_index (self , time , segment_index = None , use_times = None ):
523+ if segment_index is None :
524+ segment_index = self .time_info ["segment_index" ]
525+ if use_times is None :
526+ use_times = self .main_settings ["use_times" ]
527+ if use_times and self .has_extension ("recording" ):
514528 time = self .analyzer .recording .time_to_sample_index (time , segment_index = segment_index )
515529 return time
516530 else :
517531 return int (time * self .sampling_frequency )
518532
533+ def get_events (self , event_name , segment_index = None ):
534+ if self .events is None :
535+ return None
536+ if event_name not in self .events :
537+ return None
538+ if segment_index is None :
539+ segment_index = self .time_info ['segment_index' ]
540+ return self .events [event_name ][segment_index ]
541+
519542 def get_information_txt (self ):
520543 nseg = self .analyzer .get_num_segments ()
521544 nchan = self .analyzer .get_num_channels ()
@@ -768,6 +791,8 @@ def set_channel_visibility(self, visible_channel_inds):
768791 def has_extension (self , extension_name ):
769792 if extension_name == 'recording' :
770793 return self .analyzer .has_recording () or self .analyzer .has_temporary_recording ()
794+ elif extension_name == 'events' :
795+ return self .events is not None
771796 else :
772797 # extension needs to be loaded
773798 if extension_name in self .skip_extensions :
0 commit comments