Skip to content

Commit c198d6d

Browse files
committed
Solve conflicts
2 parents b4578ca + f552464 commit c198d6d

19 files changed

Lines changed: 504 additions & 701 deletions

spikeinterface_gui/basescatterview.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class BaseScatterView(ViewBase):
1717
{'name': 'display_high_percentiles', 'type': 'float', 'value' : 98.0, 'limits':(50, 100), 'step':0.5},
1818
]
1919
_need_compute = False
20-
20+
2121
def __init__(self, spike_data, y_label, controller=None, parent=None, backend="qt"):
2222

2323
# compute data bounds
@@ -46,6 +46,12 @@ def get_unit_data(self, unit_id, segment_index=0):
4646
spike_times = self.controller.sample_index_to_time(spike_indices)
4747
spike_data = self.spike_data[inds]
4848

49+
if len(spike_data) == 1:
50+
spike_data_value = spike_data[0]
51+
ymin = spike_data_value - 0.1 * spike_data_value
52+
ymax = spike_data_value + 0.1 * spike_data_value
53+
return spike_times, spike_data, np.array([1]), np.array([ymin, ymax]), ymin, ymax, inds
54+
4955
# avoid clear outliers in the plot and histogram by using percentiles
5056
ymin, ymax = np.percentile(spike_data, [self.settings['display_low_percentiles'], self.settings['display_high_percentiles']])
5157
min_bin_size = np.min(np.diff(np.unique(spike_data)))
@@ -185,6 +191,9 @@ def on_unit_visibility_changed(self):
185191
self._current_selected = self.controller.get_indices_spike_selected().size
186192
self.refresh(set_scatter_range=True)
187193

194+
def on_spike_selection_changed(self):
195+
self.refresh()
196+
188197
def on_use_times_updated(self):
189198
self.refresh(set_scatter_range=True)
190199

@@ -295,6 +304,8 @@ def _qt_refresh(self, set_scatter_range=False):
295304
unit_id,
296305
segment_index=segment_index
297306
)
307+
if len(spike_times) == 0:
308+
continue
298309

299310
# make a copy of the color
300311
color = QT.QColor(self.get_unit_color(unit_id))
@@ -314,7 +325,7 @@ def _qt_refresh(self, set_scatter_range=False):
314325

315326
# set x range to time range of the current segment for scatter, and max count for histogram
316327
# set y range to min and max of visible spike amplitudes
317-
if set_scatter_range or not self._first_refresh_done:
328+
if len(ymins) > 0 and (set_scatter_range or not self._first_refresh_done):
318329
ymin = np.min(ymins)
319330
ymax = np.max(ymaxs)
320331
t_start, t_stop = self.controller.get_t_start_t_stop()
@@ -490,6 +501,7 @@ def _panel_make_layout(self):
490501
self.plotted_inds = []
491502

492503
def _panel_refresh(self, set_scatter_range=False):
504+
import panel as pn
493505
from bokeh.models import FixedTicker
494506

495507
self.plotted_inds = []
@@ -516,6 +528,8 @@ def _panel_refresh(self, set_scatter_range=False):
516528
unit_id,
517529
segment_index=segment_index
518530
)
531+
if len(spike_times) == 0:
532+
continue
519533
color = self.get_unit_color(unit_id)
520534
xs.extend(spike_times)
521535
ys.extend(spike_data)
@@ -555,28 +569,45 @@ def _panel_refresh(self, set_scatter_range=False):
555569
# handle selected spikes
556570
self._panel_update_selected_spikes()
557571

558-
# set y range to min and max of visible spike amplitudes
572+
# Defer Range updates to avoid nested document lock issues
573+
# def update_ranges():
559574
if set_scatter_range or not self._first_refresh_done:
560575
self.y_range.start = np.min(ymins)
561576
self.y_range.end = np.max(ymaxs)
562577
self._first_refresh_done = True
563578
self.hist_fig.x_range.end = max_count
564579
self.hist_fig.xaxis.ticker = FixedTicker(ticks=[0, max_count // 2, max_count])
565580

581+
# Schedule the update to run after the current event loop iteration
582+
# pn.state.execute(update_ranges, schedule=True)
583+
566584
def _panel_on_select_button(self, event):
567-
if self.select_toggle_button.value:
568-
self.scatter_fig.toolbar.active_drag = self.lasso_tool
569-
else:
570-
self.scatter_fig.toolbar.active_drag = None
571-
self.scatter_source.selected.indices = []
585+
import panel as pn
586+
587+
value = self.select_toggle_button.value
588+
589+
def _do_update():
590+
if value:
591+
self.scatter_fig.toolbar.active_drag = self.lasso_tool
592+
else:
593+
self.scatter_fig.toolbar.active_drag = None
594+
self.scatter_source.selected.indices = []
595+
596+
pn.state.execute(_do_update, schedule=True)
572597

573598
def _panel_change_segment(self, event):
599+
import panel as pn
600+
574601
self._current_selected = 0
575602
segment_index = int(self.segment_selector.value.split()[-1])
576603
self.controller.set_time(segment_index=segment_index)
577604
t_start, t_end = self.controller.get_t_start_t_stop()
578-
self.scatter_fig.x_range.start = t_start
579-
self.scatter_fig.x_range.end = t_end
605+
606+
def _do_update():
607+
self.scatter_fig.x_range.start = t_start
608+
self.scatter_fig.x_range.end = t_end
609+
610+
pn.state.execute(_do_update, schedule=True)
580611
self.refresh(set_scatter_range=True)
581612
self.notify_time_info_updated()
582613

@@ -618,9 +649,17 @@ def _panel_split(self, event):
618649
self.split()
619650

620651
def _panel_update_selected_spikes(self):
652+
import panel as pn
653+
621654
# handle selected spikes
622655
selected_spike_indices = self.controller.get_indices_spike_selected()
623656
selected_spike_indices = np.intersect1d(selected_spike_indices, self.plotted_inds)
657+
if len(selected_spike_indices) == 1:
658+
selected_segment = self.controller.spikes[selected_spike_indices[0]]['segment_index']
659+
segment_index = self.controller.get_time()[1]
660+
if selected_segment != segment_index:
661+
self.segment_selector.value = f"Segment {selected_segment}"
662+
self._panel_change_segment(None)
624663
if len(selected_spike_indices) > 0:
625664
# map absolute indices to visible spikes
626665
segment_index = self.controller.get_time()[1]
@@ -634,23 +673,16 @@ def _panel_update_selected_spikes(self):
634673
# set selected spikes in scatter plot
635674
if self.settings["auto_decimate"] and len(selected_indices) > 0:
636675
selected_indices, = np.nonzero(np.isin(self.plotted_inds, selected_spike_indices))
637-
self.scatter_source.selected.indices = list(selected_indices)
638676
else:
639-
self.scatter_source.selected.indices = []
677+
selected_indices = []
678+
679+
def _do_update():
680+
self.scatter_source.selected.indices = list(selected_indices)
681+
682+
pn.state.execute(_do_update, schedule=True)
640683

641684
def _panel_on_spike_selection_changed(self):
642-
# set selection in scatter plot
643-
selected_indices = self.controller.get_indices_spike_selected()
644-
if len(selected_indices) == 0:
645-
self.scatter_source.selected.indices = []
646-
return
647-
elif len(selected_indices) == 1:
648-
selected_segment = self.controller.spikes[selected_indices[0]]['segment_index']
649-
segment_index = self.controller.get_time()[1]
650-
if selected_segment != segment_index:
651-
self.segment_selector.value = f"Segment {selected_segment}"
652-
self._panel_change_segment(None)
653-
# update selected spikes
685+
# update selected spikes (scheduled via pn.state.execute inside)
654686
self._panel_update_selected_spikes()
655687

656688
def _panel_handle_shortcut(self, event):

spikeinterface_gui/controller.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from spikeinterface.widgets.utils import get_unit_colors
1010
from spikeinterface import compute_sparsity
1111
from spikeinterface.core import get_template_extremum_channel
12-
import spikeinterface.postprocessing
13-
import spikeinterface.qualitymetrics
1412
from spikeinterface.core.sorting_tools import spike_vector_to_indices
15-
from spikeinterface.core.core_tools import check_json
1613
from spikeinterface.curation import validate_curation_dict
1714
from spikeinterface.curation.curation_model import CurationModel
1815
from spikeinterface.widgets.utils import make_units_table_from_analyzer
@@ -345,7 +342,6 @@ def __init__(
345342
self.curation_callback = curation_callback
346343
self.curation_callback_kwargs = curation_callback_kwargs
347344

348-
# TODO: Reload the dictionary if it already exists
349345
if self.curation:
350346
# rules:
351347
# * if user sends curation_data, then it is used
@@ -354,6 +350,7 @@ def __init__(
354350

355351
if curation_data is not None:
356352
# validate the curation data
353+
curation_data = deepcopy(curation_data)
357354
format_version = curation_data.get("format_version", None)
358355
# assume version 2 if not present
359356
if format_version is None:
@@ -363,24 +360,6 @@ def __init__(
363360
except Exception as e:
364361
raise ValueError(f"Invalid curation data.\nError: {e}")
365362

366-
if curation_data.get("merges") is None:
367-
curation_data["merges"] = []
368-
else:
369-
# here we reset the merges for better formatting (str)
370-
existing_merges = curation_data["merges"]
371-
new_merges = []
372-
for m in existing_merges:
373-
if "unit_ids" not in m:
374-
continue
375-
if len(m["unit_ids"]) < 2:
376-
continue
377-
new_merges = add_merge(new_merges, m["unit_ids"])
378-
curation_data["merges"] = new_merges
379-
if curation_data.get("splits") is None:
380-
curation_data["splits"] = []
381-
if curation_data.get("removed") is None:
382-
curation_data["removed"] = []
383-
384363
elif self.analyzer.format == "binary_folder":
385364
json_file = self.analyzer.folder / "spikeinterface_gui" / "curation_data.json"
386365
if json_file.exists():
@@ -395,24 +374,27 @@ def __init__(
395374

396375
if curation_data is None:
397376
curation_data = deepcopy(empty_curation_data)
377+
curation_data["unit_ids"] = self.unit_ids.tolist()
398378

399-
self.curation_data = curation_data
400-
401-
self.has_default_quality_labels = False
402-
if "label_definitions" not in self.curation_data:
379+
if "label_definitions" not in curation_data:
403380
if label_definitions is not None:
404-
self.curation_data["label_definitions"] = label_definitions
381+
curation_data["label_definitions"] = label_definitions
405382
else:
406-
self.curation_data["label_definitions"] = default_label_definitions.copy()
383+
curation_data["label_definitions"] = default_label_definitions.copy()
407384

408-
if "quality" in self.curation_data["label_definitions"]:
409-
curation_dict_quality_labels = self.curation_data["label_definitions"]["quality"]["label_options"]
385+
# This will enable the default shortcuts if has default quality labels
386+
self.has_default_quality_labels = False
387+
if "quality" in curation_data["label_definitions"]:
388+
curation_dict_quality_labels = curation_data["label_definitions"]["quality"]["label_options"]
410389
default_quality_labels = default_label_definitions["quality"]["label_options"]
411390
if set(curation_dict_quality_labels) == set(default_quality_labels):
412391
if self.verbose:
413392
print('Curation quality labels are the default ones')
414393
self.has_default_quality_labels = True
415394

395+
curation_data = CurationModel(**curation_data).model_dump()
396+
self.curation_data = curation_data
397+
416398
def check_is_view_possible(self, view_name):
417399
from .viewlist import get_all_possible_views
418400
possible_class_views = get_all_possible_views()
@@ -854,7 +836,7 @@ def compute_auto_merge(self, **params):
854836
)
855837

856838
return merge_unit_groups, extra
857-
839+
858840
def curation_can_be_saved(self):
859841
return self.analyzer.format != "memory"
860842

0 commit comments

Comments
 (0)