99from spikeinterface .widgets .utils import get_unit_colors
1010from spikeinterface import compute_sparsity
1111from spikeinterface .core import get_template_extremum_channel
12- import spikeinterface .postprocessing
13- import spikeinterface .qualitymetrics
1412from spikeinterface .core .sorting_tools import spike_vector_to_indices
15- from spikeinterface .core .core_tools import check_json
1613from spikeinterface .curation import validate_curation_dict
1714from spikeinterface .curation .curation_model import CurationModel
1815from spikeinterface .widgets .utils import make_units_table_from_analyzer
@@ -50,6 +47,8 @@ def __init__(
5047 skip_extensions = None ,
5148 disable_save_settings_button = False ,
5249 external_data = None ,
50+ curation_callback = None ,
51+ curation_callback_kwargs = None ,
5352 user_main_settings = None ,
5453 ):
5554 self .views = []
@@ -329,6 +328,7 @@ def __init__(
329328 self ._traces_cached = {}
330329
331330 self .units_table = make_units_table_from_analyzer (analyzer , extra_properties = extra_unit_properties )
331+
332332 if displayed_unit_properties is None :
333333 displayed_unit_properties = list (_default_displayed_unit_properties )
334334 if extra_unit_properties is not None :
@@ -340,7 +340,9 @@ def __init__(
340340 self .update_time_info ()
341341
342342 self .curation = curation
343- # TODO: Reload the dictionary if it already exists
343+ self .curation_callback = curation_callback
344+ self .curation_callback_kwargs = curation_callback_kwargs
345+
344346 if self .curation :
345347 # rules:
346348 # * if user sends curation_data, then it is used
@@ -349,6 +351,7 @@ def __init__(
349351
350352 if curation_data is not None :
351353 # validate the curation data
354+ curation_data = deepcopy (curation_data )
352355 format_version = curation_data .get ("format_version" , None )
353356 # assume version 2 if not present
354357 if format_version is None :
@@ -358,24 +361,6 @@ def __init__(
358361 except Exception as e :
359362 raise ValueError (f"Invalid curation data.\n Error: { e } " )
360363
361- if curation_data .get ("merges" ) is None :
362- curation_data ["merges" ] = []
363- else :
364- # here we reset the merges for better formatting (str)
365- existing_merges = curation_data ["merges" ]
366- new_merges = []
367- for m in existing_merges :
368- if "unit_ids" not in m :
369- continue
370- if len (m ["unit_ids" ]) < 2 :
371- continue
372- new_merges = add_merge (new_merges , m ["unit_ids" ])
373- curation_data ["merges" ] = new_merges
374- if curation_data .get ("splits" ) is None :
375- curation_data ["splits" ] = []
376- if curation_data .get ("removed" ) is None :
377- curation_data ["removed" ] = []
378-
379364 elif self .analyzer .format == "binary_folder" :
380365 json_file = self .analyzer .folder / "spikeinterface_gui" / "curation_data.json"
381366 if json_file .exists ():
@@ -390,24 +375,27 @@ def __init__(
390375
391376 if curation_data is None :
392377 curation_data = deepcopy (empty_curation_data )
378+ curation_data ["unit_ids" ] = self .unit_ids .tolist ()
393379
394- self .curation_data = curation_data
395-
396- self .has_default_quality_labels = False
397- if "label_definitions" not in self .curation_data :
380+ if "label_definitions" not in curation_data :
398381 if label_definitions is not None :
399- self . curation_data ["label_definitions" ] = label_definitions
382+ curation_data ["label_definitions" ] = label_definitions
400383 else :
401- self . curation_data ["label_definitions" ] = default_label_definitions .copy ()
384+ curation_data ["label_definitions" ] = default_label_definitions .copy ()
402385
403- if "quality" in self .curation_data ["label_definitions" ]:
404- curation_dict_quality_labels = self .curation_data ["label_definitions" ]["quality" ]["label_options" ]
386+ # This will enable the default shortcuts if has default quality labels
387+ self .has_default_quality_labels = False
388+ if "quality" in curation_data ["label_definitions" ]:
389+ curation_dict_quality_labels = curation_data ["label_definitions" ]["quality" ]["label_options" ]
405390 default_quality_labels = default_label_definitions ["quality" ]["label_options" ]
406391 if set (curation_dict_quality_labels ) == set (default_quality_labels ):
407392 if self .verbose :
408393 print ('Curation quality labels are the default ones' )
409394 self .has_default_quality_labels = True
410395
396+ curation_data = CurationModel (** curation_data ).model_dump ()
397+ self .curation_data = curation_data
398+
411399 def check_is_view_possible (self , view_name ):
412400 from .viewlist import get_all_possible_views
413401 possible_class_views = get_all_possible_views ()
@@ -710,6 +698,12 @@ def get_traces(self, trace_source='preprocessed', **kargs):
710698 def get_contact_location (self ):
711699 location = self .analyzer .get_channel_locations ()
712700 return location
701+
702+ def get_channel_groups (self ):
703+ if self .has_extension ("recording" ):
704+ return self .analyzer .recording .get_channel_groups ()
705+ else :
706+ return np .zeros (self .analyzer .get_num_channels (), dtype = int )
713707
714708 def get_waveform_sweep (self ):
715709 return self .nbefore , self .nafter
@@ -721,7 +715,7 @@ def get_waveforms(self, unit_id):
721715 wfs = self .waveforms_ext .get_waveforms_one_unit (unit_id , force_dense = False )
722716 if self .analyzer .sparsity is None :
723717 # dense waveforms
724- chan_inds = np .arange (self .analyzer .recording . get_num_channels (), dtype = 'int64' )
718+ chan_inds = np .arange (self .analyzer .get_num_channels (), dtype = 'int64' )
725719 else :
726720 # sparse waveforms
727721 chan_inds = self .analyzer .sparsity .unit_id_to_channel_indices [unit_id ]
@@ -849,7 +843,7 @@ def compute_auto_merge(self, **params):
849843 )
850844
851845 return merge_unit_groups , extra
852-
846+
853847 def curation_can_be_saved (self ):
854848 return self .analyzer .format != "memory"
855849
@@ -861,6 +855,23 @@ def construct_final_curation(self):
861855 model = CurationModel (** d )
862856 return model
863857
858+ def set_curation_data (self , curation_data ):
859+ print ("Setting curation data" )
860+ new_curation_data = empty_curation_data .copy ()
861+ new_curation_data .update (curation_data )
862+
863+ if "unit_ids" not in curation_data :
864+ print ("Setting unit_ids from controller" )
865+ new_curation_data ["unit_ids" ] = self .unit_ids .tolist ()
866+
867+ if "label_definitions" not in curation_data :
868+ print ("Setting default label definitions" )
869+ new_curation_data ["label_definitions" ] = default_label_definitions .copy ()
870+
871+ # validate the curation data
872+ model = CurationModel (** new_curation_data )
873+ self .curation_data = model .model_dump ()
874+
864875 def save_curation_in_analyzer (self ):
865876 if self .analyzer .format == "memory" :
866877 print ("Analyzer is an in-memory object. Cannot save curation file in it." )
@@ -883,6 +894,16 @@ def save_curation_in_analyzer(self):
883894 sigui_group .attrs ["curation_data" ] = curation_model .model_dump (mode = "json" )
884895 self .current_curation_saved = True
885896
897+ def save_curation_callback (self ):
898+ curation = self .construct_final_curation ()
899+ curation_data = curation .model_dump ()
900+ if self .curation_callback_kwargs is None :
901+ curation_callback_kwargs = {}
902+ else :
903+ curation_callback_kwargs = self .curation_callback_kwargs
904+ self .curation_callback (curation_data , ** curation_callback_kwargs )
905+ self .current_curation_saved = True
906+
886907 def get_split_unit_ids (self ):
887908 if not self .curation :
888909 return []
0 commit comments