Skip to content

Commit 97a2881

Browse files
committed
Merge branch 'main' of github.com:SpikeInterface/spikeinterface-gui into fix-order-mergeview
2 parents 06ace68 + 50e02ee commit 97a2881

22 files changed

Lines changed: 729 additions & 798 deletions

spikeinterface_gui/backend_panel.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,29 @@ def _handle_shortcut(self, event):
418418
tabs.stylesheets = []
419419

420420

421+
def set_external_curation(self, curation_data):
422+
"""Set external curation to controlled and triggers curation and unitlist refresh
423+
424+
Parameters
425+
----------
426+
curation_data : dict
427+
The external curation data to be set.
428+
"""
429+
if "curation" not in self.views:
430+
return
431+
432+
curation_view = self.views["curation"]
433+
self.controller.set_curation_data(curation_data)
434+
self.controller.current_curation_saved = True
435+
curation_view.notify_manual_curation_updated()
436+
curation_view.refresh()
437+
438+
# we also need to refresh the unit list view to update the unit visibility according to the new curation
439+
if "unitlist" in self.views:
440+
unitlist_view = self.views["unitlist"]
441+
unitlist_view.update_manual_labels()
442+
443+
421444
def get_local_ip():
422445
"""
423446
Get the local IP address of the machine.

spikeinterface_gui/basescatterview.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class BaseScatterView(ViewBase):
1717
{'name': 'display_high_percentiles', 'type': 'float', 'value' : 98.0, 'limits':(50, 100), 'step':0.5},
1818
]
1919
_need_compute = False
20-
20+
2121
def __init__(self, spike_data, y_label, controller=None, parent=None, backend="qt"):
2222

2323
# compute data bounds
@@ -46,6 +46,12 @@ def get_unit_data(self, unit_id, segment_index=0):
4646
spike_times = self.controller.sample_index_to_time(spike_indices)
4747
spike_data = self.spike_data[inds]
4848

49+
if len(spike_data) == 1:
50+
spike_data_value = spike_data[0]
51+
ymin = spike_data_value - 0.1 * spike_data_value
52+
ymax = spike_data_value + 0.1 * spike_data_value
53+
return spike_times, spike_data, np.array([1]), np.array([ymin, ymax]), ymin, ymax, inds
54+
4955
# avoid clear outliers in the plot and histogram by using percentiles
5056
ymin, ymax = np.percentile(spike_data, [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']])
5157
min_bin_size = np.min(np.diff(np.unique(spike_data)))
@@ -185,6 +191,9 @@ def on_unit_visibility_changed(self):
185191
self._current_selected = self.controller.get_indices_spike_selected().size
186192
self.refresh(set_scatter_range=True)
187193

194+
def on_spike_selection_changed(self):
195+
self.refresh()
196+
188197
def on_use_times_updated(self):
189198
self.refresh(set_scatter_range=True)
190199

@@ -295,6 +304,8 @@ def _qt_refresh(self, set_scatter_range=False):
295304
unit_id,
296305
segment_index=segment_index
297306
)
307+
if len(spike_times) == 0:
308+
continue
298309

299310
# make a copy of the color
300311
color = QT.QColor(self.get_unit_color(unit_id))
@@ -314,7 +325,7 @@ def _qt_refresh(self, set_scatter_range=False):
314325

315326
# set x range to time range of the current segment for scatter, and max count for histogram
316327
# set y range to min and max of visible spike amplitudes
317-
if set_scatter_range or not self._first_refresh_done:
328+
if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done):
318329
ymin = np.min(ymins)
319330
ymax = np.max(ymaxs)
320331
t_start, t_stop = self.controller.get_t_start_t_stop()
@@ -490,6 +501,7 @@ def _panel_make_layout(self):
490501
self.plotted_inds = []
491502

492503
def _panel_refresh(self, set_scatter_range=False):
504+
import panel as pn
493505
from bokeh.models import FixedTicker
494506

495507
self.plotted_inds = []
@@ -516,6 +528,8 @@ def _panel_refresh(self, set_scatter_range=False):
516528
unit_id,
517529
segment_index=segment_index
518530
)
531+
if len(spike_times) == 0:
532+
continue
519533
color = self.get_unit_color(unit_id)
520534
xs.extend(spike_times)
521535
ys.extend(spike_data)
@@ -555,28 +569,45 @@ def _panel_refresh(self, set_scatter_range=False):
555569
# handle selected spikes
556570
self._panel_update_selected_spikes()
557571

558-
# set y range to min and max of visible spike amplitudes
572+
# Defer Range updates to avoid nested document lock issues
573+
# def update_ranges():
559574
if set_scatter_range or not self._first_refresh_done:
560575
self.y_range.start = np.min(ymins)
561576
self.y_range.end = np.max(ymaxs)
562577
self._first_refresh_done = True
563578
self.hist_fig.x_range.end = max_count
564579
self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count])
565580

581+
# Schedule the update to run after the current event loop iteration
582+
# pn.state.execute(update_ranges, schedule=True)
583+
566584
def _panel_on_select_button(self, event):
567-
if self.select_toggle_button.value:
568-
self.scatter_fig.toolbar.active_drag = self.lasso_tool
569-
else:
570-
self.scatter_fig.toolbar.active_drag = None
571-
self.scatter_source.selected.indices = []
585+
import panel as pn
586+
587+
value = self.select_toggle_button.value
588+
589+
def _do_update():
590+
if value:
591+
self.scatter_fig.toolbar.active_drag = self.lasso_tool
592+
else:
593+
self.scatter_fig.toolbar.active_drag = None
594+
self.scatter_source.selected.indices = []
595+
596+
pn.state.execute(_do_update, schedule=True)
572597

573598
def _panel_change_segment(self, event):
599+
import panel as pn
600+
574601
self._current_selected = 0
575602
segment_index = int(self.segment_selector.value.split()[-1])
576603
self.controller.set_time(segment_index=segment_index)
577604
t_start, t_end = self.controller.get_t_start_t_stop()
578-
self.scatter_fig.x_range.start = t_start
579-
self.scatter_fig.x_range.end = t_end
605+
606+
def _do_update():
607+
self.scatter_fig.x_range.start = t_start
608+
self.scatter_fig.x_range.end = t_end
609+
610+
pn.state.execute(_do_update, schedule=True)
580611
self.refresh(set_scatter_range=True)
581612
self.notify_time_info_updated()
582613

@@ -618,9 +649,17 @@ def _panel_split(self, event):
618649
self.split()
619650

620651
def _panel_update_selected_spikes(self):
652+
import panel as pn
653+
621654
# handle selected spikes
622655
selected_spike_indices = self.controller.get_indices_spike_selected()
623656
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
657+
if len(selected_spike_indices) == 1:
658+
selected_segment = self.controller.spikes[selected_spike_indices[0]]['segment_index']
659+
segment_index = self.controller.get_time()[1]
660+
if selected_segment != segment_index:
661+
self.segment_selector.value = f"Segment {selected_segment}"
662+
self._panel_change_segment(None)
624663
if len(selected_spike_indices) > 0:
625664
# map absolute indices to visible spikes
626665
segment_index = self.controller.get_time()[1]
@@ -634,23 +673,16 @@ def _panel_update_selected_spikes(self):
634673
# set selected spikes in scatter plot
635674
if self.settings["auto_decimate"] and len(selected_indices) > 0:
636675
selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices))
637-
self.scatter_source.selected.indices = list(selected_indices)
638676
else:
639-
self.scatter_source.selected.indices = []
677+
selected_indices = []
678+
679+
def _do_update():
680+
self.scatter_source.selected.indices = list(selected_indices)
681+
682+
pn.state.execute(_do_update, schedule=True)
640683

641684
def _panel_on_spike_selection_changed(self):
642-
# set selection in scatter plot
643-
selected_indices = self.controller.get_indices_spike_selected()
644-
if len(selected_indices) == 0:
645-
self.scatter_source.selected.indices = []
646-
return
647-
elif len(selected_indices) == 1:
648-
selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
649-
segment_index = self.controller.get_time()[1]
650-
if selected_segment != segment_index:
651-
self.segment_selector.value = f"Segment {selected_segment}"
652-
self._panel_change_segment(None)
653-
# update selected spikes
685+
# update selected spikes (scheduled via pn.state.execute inside)
654686
self._panel_update_selected_spikes()
655687

656688
def _panel_handle_shortcut(self, event):

spikeinterface_gui/controller.py

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from spikeinterface.widgets.utils import get_unit_colors
1010
from spikeinterface import compute_sparsity
1111
from spikeinterface.core import get_template_extremum_channel
12-
import spikeinterface.postprocessing
13-
import spikeinterface.qualitymetrics
1412
from spikeinterface.core.sorting_tools import spike_vector_to_indices
15-
from spikeinterface.core.core_tools import check_json
1613
from spikeinterface.curation import validate_curation_dict
1714
from spikeinterface.curation.curation_model import CurationModel
1815
from spikeinterface.widgets.utils import make_units_table_from_analyzer
@@ -50,6 +47,8 @@ def __init__(
5047
skip_extensions=None,
5148
disable_save_settings_button=False,
5249
external_data=None,
50+
curation_callback=None,
51+
curation_callback_kwargs=None,
5352
user_main_settings=None,
5453
):
5554
self.views = []
@@ -329,6 +328,7 @@ def __init__(
329328
self._traces_cached = {}
330329

331330
self.units_table = make_units_table_from_analyzer(analyzer, extra_properties=extra_unit_properties)
331+
332332
if displayed_unit_properties is None:
333333
displayed_unit_properties = list(_default_displayed_unit_properties)
334334
if extra_unit_properties is not None:
@@ -340,7 +340,9 @@ def __init__(
340340
self.update_time_info()
341341

342342
self.curation = curation
343-
# TODO: Reload the dictionary if it already exists
343+
self.curation_callback = curation_callback
344+
self.curation_callback_kwargs = curation_callback_kwargs
345+
344346
if self.curation:
345347
# rules:
346348
# * if user sends curation_data, then it is used
@@ -349,6 +351,7 @@ def __init__(
349351

350352
if curation_data is not None:
351353
# validate the curation data
354+
curation_data = deepcopy(curation_data)
352355
format_version = curation_data.get("format_version", None)
353356
# assume version 2 if not present
354357
if format_version is None:
@@ -358,24 +361,6 @@ def __init__(
358361
except Exception as e:
359362
raise ValueError(f"Invalid curation data.\nError: {e}")
360363

361-
if curation_data.get("merges") is None:
362-
curation_data["merges"] = []
363-
else:
364-
# here we reset the merges for better formatting (str)
365-
existing_merges = curation_data["merges"]
366-
new_merges = []
367-
for m in existing_merges:
368-
if "unit_ids" not in m:
369-
continue
370-
if len(m["unit_ids"]) < 2:
371-
continue
372-
new_merges = add_merge(new_merges, m["unit_ids"])
373-
curation_data["merges"] = new_merges
374-
if curation_data.get("splits") is None:
375-
curation_data["splits"] = []
376-
if curation_data.get("removed") is None:
377-
curation_data["removed"] = []
378-
379364
elif self.analyzer.format == "binary_folder":
380365
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
381366
if json_file.exists():
@@ -390,24 +375,27 @@ def __init__(
390375

391376
if curation_data is None:
392377
curation_data = deepcopy(empty_curation_data)
378+
curation_data["unit_ids"] = self.unit_ids.tolist()
393379

394-
self.curation_data = curation_data
395-
396-
self.has_default_quality_labels = False
397-
if "label_definitions" not in self.curation_data:
380+
if "label_definitions" not in curation_data:
398381
if label_definitions is not None:
399-
self.curation_data["label_definitions"] = label_definitions
382+
curation_data["label_definitions"] = label_definitions
400383
else:
401-
self.curation_data["label_definitions"] = default_label_definitions.copy()
384+
curation_data["label_definitions"] = default_label_definitions.copy()
402385

403-
if "quality" in self.curation_data["label_definitions"]:
404-
curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
386+
# This will enable the default shortcuts if has default quality labels
387+
self.has_default_quality_labels = False
388+
if "quality" in curation_data["label_definitions"]:
389+
curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"]
405390
default_quality_labels = default_label_definitions["quality"]["label_options"]
406391
if set(curation_dict_quality_labels) == set(default_quality_labels):
407392
if self.verbose:
408393
print('Curation quality labels are the default ones')
409394
self.has_default_quality_labels = True
410395

396+
curation_data = CurationModel(**curation_data).model_dump()
397+
self.curation_data = curation_data
398+
411399
def check_is_view_possible(self, view_name):
412400
from .viewlist import get_all_possible_views
413401
possible_class_views = get_all_possible_views()
@@ -710,6 +698,12 @@ def get_traces(self, trace_source='preprocessed', **kargs):
710698
def get_contact_location(self):
711699
location = self.analyzer.get_channel_locations()
712700
return location
701+
702+
def get_channel_groups(self):
703+
if self.has_extension("recording"):
704+
return self.analyzer.recording.get_channel_groups()
705+
else:
706+
return np.zeros(self.analyzer.get_num_channels(), dtype=int)
713707

714708
def get_waveform_sweep(self):
715709
return self.nbefore, self.nafter
@@ -721,7 +715,7 @@ def get_waveforms(self, unit_id):
721715
wfs = self.waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False)
722716
if self.analyzer.sparsity is None:
723717
# dense waveforms
724-
chan_inds = np.arange(self.analyzer.recording.get_num_channels(), dtype='int64')
718+
chan_inds = np.arange(self.analyzer.get_num_channels(), dtype='int64')
725719
else:
726720
# sparse waveforms
727721
chan_inds = self.analyzer.sparsity.unit_id_to_channel_indices[unit_id]
@@ -849,7 +843,7 @@ def compute_auto_merge(self, **params):
849843
)
850844

851845
return merge_unit_groups, extra
852-
846+
853847
def curation_can_be_saved(self):
854848
return self.analyzer.format != "memory"
855849

@@ -861,6 +855,23 @@ def construct_final_curation(self):
861855
model = CurationModel(**d)
862856
return model
863857

858+
def set_curation_data(self, curation_data):
859+
print("Setting curation data")
860+
new_curation_data = empty_curation_data.copy()
861+
new_curation_data.update(curation_data)
862+
863+
if "unit_ids" not in curation_data:
864+
print("Setting unit_ids from controller")
865+
new_curation_data["unit_ids"] = self.unit_ids.tolist()
866+
867+
if "label_definitions" not in curation_data:
868+
print("Setting default label definitions")
869+
new_curation_data["label_definitions"] = default_label_definitions.copy()
870+
871+
# validate the curation data
872+
model = CurationModel(**new_curation_data)
873+
self.curation_data = model.model_dump()
874+
864875
def save_curation_in_analyzer(self):
865876
if self.analyzer.format == "memory":
866877
print("Analyzer is an in-memory object. Cannot save curation file in it.")
@@ -883,6 +894,16 @@ def save_curation_in_analyzer(self):
883894
sigui_group.attrs["curation_data"] = curation_model.model_dump(mode="json")
884895
self.current_curation_saved = True
885896

897+
def save_curation_callback(self):
898+
curation = self.construct_final_curation()
899+
curation_data = curation.model_dump()
900+
if self.curation_callback_kwargs is None:
901+
curation_callback_kwargs = {}
902+
else:
903+
curation_callback_kwargs = self.curation_callback_kwargs
904+
self.curation_callback(curation_data, **curation_callback_kwargs)
905+
self.current_curation_saved = True
906+
886907
def get_split_unit_ids(self):
887908
if not self.curation:
888909
return []

0 commit comments

Comments
 (0)