Skip to content

Commit 50e02ee

Browse files
authored
Merge pull request #241 from alejoe91/curation-callback
Implement `curation_callback` and `set_external_curation`
2 parents 26882bf + f6ed515 commit 50e02ee

6 files changed

Lines changed: 170 additions & 30 deletions

File tree

spikeinterface_gui/backend_panel.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,29 @@ def _handle_shortcut(self, event):
418418
tabs.stylesheets = []
419419

420420

421+
def set_external_curation(self, curation_data):
422+
"""Set external curation to controlled and triggers curation and unitlist refresh
423+
424+
Parameters
425+
----------
426+
curation_data : dict
427+
The external curation data to be set.
428+
"""
429+
if "curation" not in self.views:
430+
return
431+
432+
curation_view = self.views["curation"]
433+
self.controller.set_curation_data(curation_data)
434+
self.controller.current_curation_saved = True
435+
curation_view.notify_manual_curation_updated()
436+
curation_view.refresh()
437+
438+
# we also need to refresh the unit list view to update the unit visibility according to the new curation
439+
if "unitlist" in self.views:
440+
unitlist_view = self.views["unitlist"]
441+
unitlist_view.update_manual_labels()
442+
443+
421444
def get_local_ip():
422445
"""
423446
Get the local IP address of the machine.

spikeinterface_gui/controller.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def __init__(
4747
skip_extensions=None,
4848
disable_save_settings_button=False,
4949
external_data=None,
50+
curation_callback=None,
51+
curation_callback_kwargs=None,
5052
user_main_settings=None,
5153
):
5254
self.views = []
@@ -338,6 +340,9 @@ def __init__(
338340
self.update_time_info()
339341

340342
self.curation = curation
343+
self.curation_callback = curation_callback
344+
self.curation_callback_kwargs = curation_callback_kwargs
345+
341346
if self.curation:
342347
# rules:
343348
# * if user sends curation_data, then it is used
@@ -850,6 +855,23 @@ def construct_final_curation(self):
850855
model = CurationModel(**d)
851856
return model
852857

858+
def set_curation_data(self, curation_data):
859+
print("Setting curation data")
860+
new_curation_data = empty_curation_data.copy()
861+
new_curation_data.update(curation_data)
862+
863+
if "unit_ids" not in curation_data:
864+
print("Setting unit_ids from controller")
865+
new_curation_data["unit_ids"] = self.unit_ids.tolist()
866+
867+
if "label_definitions" not in curation_data:
868+
print("Setting default label definitions")
869+
new_curation_data["label_definitions"] = default_label_definitions.copy()
870+
871+
# validate the curation data
872+
model = CurationModel(**new_curation_data)
873+
self.curation_data = model.model_dump()
874+
853875
def save_curation_in_analyzer(self):
854876
if self.analyzer.format == "memory":
855877
print("Analyzer is an in-memory object. Cannot save curation file in it.")
@@ -872,6 +894,16 @@ def save_curation_in_analyzer(self):
872894
sigui_group.attrs["curation_data"] = curation_model.model_dump(mode="json")
873895
self.current_curation_saved = True
874896

897+
def save_curation_callback(self):
898+
curation = self.construct_final_curation()
899+
curation_data = curation.model_dump()
900+
if self.curation_callback_kwargs is None:
901+
curation_callback_kwargs = {}
902+
else:
903+
curation_callback_kwargs = self.curation_callback_kwargs
904+
self.curation_callback(curation_data, **curation_callback_kwargs)
905+
self.current_curation_saved = True
906+
875907
def get_split_unit_ids(self):
876908
if not self.curation:
877909
return []

spikeinterface_gui/curationview.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,19 @@ def select_and_notify_split(self, split_unit_id):
6161
## Qt
6262
def _qt_make_layout(self):
6363
from .myqt import QT
64-
import pyqtgraph as pg
6564

6665
self.merge_info = {}
6766
self.layout = QT.QVBoxLayout()
6867

6968
tb = self.qt_widget.view_toolbar
70-
if self.controller.curation_can_be_saved():
69+
if self.controller.curation_callback is not None:
70+
but = QT.QPushButton("Save curation")
71+
tb.addWidget(but)
72+
but.clicked.connect(self.controller.save_curation_callback)
73+
elif self.controller.curation_can_be_saved():
7174
but = QT.QPushButton("Save in analyzer")
7275
tb.addWidget(but)
73-
but.clicked.connect(self.save_in_analyzer)
74-
76+
but.clicked.connect(self.controller.save_curation_in_analyzer)
7577
but = QT.QPushButton("Export JSON")
7678
but.clicked.connect(self._qt_export_json)
7779
tb.addWidget(but)
@@ -275,9 +277,6 @@ def _qt_on_unit_visibility_changed(self):
275277
def on_manual_curation_updated(self):
276278
self.refresh()
277279

278-
def save_in_analyzer(self):
279-
self.controller.save_curation_in_analyzer()
280-
281280
def _qt_export_json(self):
282281
from .myqt import QT
283282

@@ -351,37 +350,47 @@ def _panel_make_layout(self):
351350
)
352351

353352
# Create buttons
354-
buttons_row = []
355-
self.save_button = None
356-
if self.controller.curation_can_be_saved():
357-
self.save_button = pn.widgets.Button(name="Save in analyzer", button_type="primary", height=30)
358-
self.save_button.on_click(self._panel_save_in_analyzer)
359-
buttons_row.append(self.save_button)
360-
361-
self.download_button = pn.widgets.FileDownload(
353+
if self.controller.curation_callback is not None:
354+
save_button_name = "Save curation"
355+
save_button_callback = self._panel_save_curation_callback
356+
else:
357+
save_button_name = "Save in analyzer"
358+
save_button_callback = self._panel_save_in_analyzer
359+
save_button = pn.widgets.Button(
360+
name=save_button_name,
361+
button_type="primary",
362+
height=30
363+
)
364+
save_button.on_click(save_button_callback)
365+
366+
download_button = pn.widgets.FileDownload(
362367
button_type="primary", filename="curation.json", callback=self._panel_generate_json, height=30
363368
)
364-
buttons_row.append(self.download_button)
365369

366370
restore_button = pn.widgets.Button(name="Restore", button_type="primary", height=30)
367371
restore_button.on_click(self._panel_restore_units)
368372

369373
remove_merge_button = pn.widgets.Button(name="Unmerge", button_type="primary", height=30)
370374
remove_merge_button.on_click(self._panel_unmerge)
371375

372-
remove_split = pn.widgets.Button(name="Unsplit", button_type="primary", height=30)
373-
remove_split.on_click(self._panel_unsplit)
376+
remove_split_button = pn.widgets.Button(name="Unsplit", button_type="primary", height=30)
377+
remove_split_button.on_click(self._panel_unsplit)
374378

375379
# Create layout
376-
self.buttons_save = pn.Row(
377-
*buttons_row,
380+
buttons_save = pn.Row(
381+
save_button,
382+
download_button,
383+
sizing_mode="stretch_width",
384+
)
385+
save_sections = pn.Column(
386+
buttons_save,
378387
sizing_mode="stretch_width",
379388
)
380389

381390
buttons_curate = pn.Row(
382391
restore_button,
383392
remove_merge_button,
384-
remove_split,
393+
remove_split_button,
385394
sizing_mode="stretch_width",
386395
)
387396

@@ -397,7 +406,7 @@ def _panel_make_layout(self):
397406
# Create main layout with proper sizing
398407
sections = pn.Row(self.table_delete, self.table_merge, self.table_split, sizing_mode="stretch_width")
399408
self.layout = pn.Column(
400-
self.buttons_save, buttons_curate, sections, shortcuts_component, scroll=True, sizing_mode="stretch_both"
409+
save_sections, buttons_curate, sections, shortcuts_component, scroll=True, sizing_mode="stretch_both"
401410
)
402411

403412
def _panel_refresh(self):
@@ -447,7 +456,7 @@ def _panel_ensure_save_warning_message(self):
447456
import panel as pn
448457

449458
alert_markdown = pn.pane.Markdown(
450-
f"""⚠️⚠️⚠️ Your curation is not saved""",
459+
f"""⚠️ Your curation is not saved!""",
451460
hard_line_break=True,
452461
styles={"color": "red", "font-size": "16px"},
453462
name="curation_save_warning",
@@ -490,7 +499,11 @@ def _panel_unsplit(self, event):
490499
self.unsplit()
491500

492501
def _panel_save_in_analyzer(self, event):
493-
self.save_in_analyzer()
502+
self.controller.save_curation_in_analyzer()
503+
self.refresh()
504+
505+
def _panel_save_curation_callback(self, event):
506+
self.controller.save_curation_callback()
494507
self.refresh()
495508

496509
def _panel_generate_json(self):
@@ -612,12 +625,12 @@ def _conditional_refresh_split(self):
612625
revert, and export the curation data.
613626
614627
### Controls
615-
- **save in analyzer**: Save the current curation state in the analyzer.
628+
- **save in analyzer**/**save data**: Save the current curation state in the analyzer.
629+
If a custom save callback is provided, it will be used instead.
616630
- **export/download JSON**: Export the current curation state to a JSON file.
617631
- **restore**: Restore the selected unit from the deleted units table.
618632
- **unmerge**: Unmerge the selected merges from the merged units table.
619633
- **unsplit**: Unsplit the selected split groups from the split units table.
620-
- **submit to parent**: Submit the current curation state to the parent window (for use in web applications).
621634
- **press 'ctrl+r'**: Restore the selected units from the deleted units table.
622635
- **press 'ctrl+u'**: Unmerge the selected merges from the merged units table.
623636
- **press 'ctrl+x'**: Unsplit the selected split groups from the split units table.

spikeinterface_gui/main.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@ def run_mainwindow(
2929
start_app=True,
3030
layout_preset=None,
3131
layout=None,
32+
external_data=None,
33+
curation_callback=None,
34+
curation_callback_kwargs=None,
3235
address="localhost",
3336
port=0,
3437
panel_start_server_kwargs=None,
3538
panel_window_servable=True,
3639
verbose=False,
3740
user_settings=None,
3841
disable_save_settings_button=False,
39-
external_data=None,
4042
):
4143
"""
4244
Create the main window and start the QT app loop.
@@ -72,6 +74,15 @@ def run_mainwindow(
7274
The name of the layout preset. None is default.
7375
layout : dict | None
7476
The layout dictionary to use instead of the preset.
77+
external_data: object, default: None
78+
Whatever is passed to `external_data` is attached to the controller as the attribute
79+
`external_data`. Useful for custom views.
80+
curation_callback: function, default: None
81+
A function that is called when the curation is saved. It should take two arguments:
82+
- `curation_data`: a dictionary containing the curation data (merges, splits, removed units)
83+
- `curation_callback_kwargs`: a dictionary of additional keyword arguments specified in `curation_callback_kwargs`
84+
curation_callback_kwargs: dict, default: None
85+
A dictionary of additional keyword arguments to pass to the `curation_callback` when it is called.
7586
address: str, default : "localhost"
7687
For "web" mode only. By default it is "localhost".
7788
Use "auto-ip" to use the real IP address of the machine.
@@ -93,9 +104,6 @@ def run_mainwindow(
93104
A dictionary of user settings for each view, which overwrite the default settings.
94105
disable_save_settings_button: bool, default: False
95106
If True, disables the "save default settings" button, so that user cannot do this.
96-
external_data: object, default: None
97-
Whatever is passed to `external_data` is attached to the controller as the attribute
98-
`external_data`. Useful for custom views.
99107
"""
100108

101109
if mode == "desktop":
@@ -149,6 +157,8 @@ def run_mainwindow(
149157
skip_extensions=skip_extensions,
150158
disable_save_settings_button=disable_save_settings_button,
151159
external_data=external_data,
160+
curation_callback=curation_callback,
161+
curation_callback_kwargs=curation_callback_kwargs,
152162
user_main_settings=user_main_settings
153163
)
154164
if verbose:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import panel as pn
2+
from pathlib import Path
3+
from spikeinterface import load_sorting_analyzer
4+
from spikeinterface_gui import run_mainwindow
5+
6+
pn.extension()
7+
8+
9+
test_folder = Path(__file__).parent / 'my_dataset_small'
10+
11+
12+
analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer")
13+
14+
# State in the parent app
15+
status_md = pn.pane.Markdown("No curation submitted yet.")
16+
17+
18+
def on_curation_saved(curation_data, title):
19+
"""This runs in the parent app's context — pure Python, no JS."""
20+
status_md.object = f"{title}\n\nReceived curation data:\n```\n{curation_data}\n```"
21+
# You can do anything here: save to DB, trigger a pipeline, etc.
22+
23+
# Create the embedded GUI with the callback
24+
win = run_mainwindow(
25+
analyzer,
26+
mode="web",
27+
start_app=False,
28+
panel_window_servable=False,
29+
curation=True,
30+
curation_callback=on_curation_saved,
31+
curation_callback_kwargs={"title": "✅ Curation received!\n"},
32+
)
33+
34+
# Compose the parent layout
35+
parent_layout = pn.Column(
36+
"# Parent Application",
37+
status_md,
38+
pn.layout.Divider(),
39+
win.main_layout,
40+
sizing_mode="stretch_both",
41+
)
42+
43+
parent_layout.servable()
44+
45+
pn.serve(parent_layout, port=12345, show=True)

spikeinterface_gui/unitlistview.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ def get_selected_unit_ids(self):
2626
elif self.backend == 'panel':
2727
return self._panel_get_selected_unit_ids()
2828

29+
def update_manual_labels(self):
30+
if self.backend == 'qt':
31+
self._qt_full_table_refresh()
32+
elif self.backend == 'panel':
33+
self._panel_update_labels()
34+
2935
## Qt ##
3036
def _qt_make_layout(self):
3137

@@ -708,6 +714,17 @@ def _panel_on_edit(self, event):
708714
self.notify_manual_curation_updated()
709715
self.notifier.notify_active_view_updated()
710716

717+
def _panel_update_labels(self):
718+
# this is called after a label change to update the table values
719+
for col in self.label_definitions:
720+
for row in range(len(self.table.value)):
721+
unit_id = self.table.value.index[row]
722+
label_value = self.controller.get_unit_label(unit_id, col)
723+
if label_value is None:
724+
label_value = ""
725+
self.table.value.at[unit_id, col] = label_value
726+
self.refresh()
727+
711728
def _panel_on_only_selection(self):
712729
selected_unit = self.table.selection[0]
713730
unit_id = self.table.value.index.values[selected_unit]

0 commit comments

Comments
 (0)