Skip to content

Commit 6ef50f0

Browse files
authored
Merge pull request #3778 from chrishalcrow/new-string-check-on-aggregate-sorting
Update sorting property dtype check on unit aggregation to allow for different string types
2 parents 526e308 + b08b81b commit 6ef50f0

6 files changed

Lines changed: 103 additions & 83 deletions

File tree

src/spikeinterface/core/tests/test_unitsaggregationsorting.py

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,68 @@
11
import pytest
22
import numpy as np
33

4-
from spikeinterface.core import aggregate_units
4+
from spikeinterface.core import aggregate_units, generate_sorting
55

6-
from spikeinterface.core import NpzSortingExtractor
7-
from spikeinterface.core import create_sorting_npz
8-
from spikeinterface.core import generate_sorting
96

7+
def create_three_sortings(num_units):
8+
sorting1 = generate_sorting(seed=1205, num_units=num_units)
9+
sorting2 = generate_sorting(seed=1206, num_units=num_units)
10+
sorting3 = generate_sorting(seed=1207, num_units=num_units)
1011

11-
def test_unitsaggregationsorting(create_cache_folder):
12-
cache_folder = create_cache_folder
12+
return (sorting1, sorting2, sorting3)
1313

14-
num_seg = 2
15-
file_path = cache_folder / "test_BaseSorting.npz"
1614

17-
create_sorting_npz(num_seg, file_path)
15+
def test_unitsaggregationsorting_spiketrains():
16+
"""Aggregates three sortings, then checks that the number of units and spike trains are equal
17+
for pre-aggregated sorting and the aggregated sorting."""
1818

19-
sorting1 = NpzSortingExtractor(file_path)
20-
sorting2 = sorting1.clone()
21-
sorting3 = sorting1.clone()
22-
print(sorting1)
23-
num_units = len(sorting1.get_unit_ids())
19+
num_units = 5
20+
sorting1, sorting2, sorting3 = create_three_sortings(num_units=num_units)
2421

2522
# test num units
2623
sorting_agg = aggregate_units([sorting1, sorting2, sorting3])
27-
print(sorting_agg)
28-
assert len(sorting_agg.get_unit_ids()) == 3 * num_units
24+
unit_ids = sorting_agg.get_unit_ids()
25+
assert len(unit_ids) == 3 * num_units
2926

3027
# test spike trains
31-
unit_ids = sorting1.get_unit_ids()
32-
33-
for seg in range(num_seg):
34-
spiketrain1_1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=seg)
35-
spiketrains2_0 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=seg)
36-
spiketrains3_2 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=seg)
37-
assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=seg))
38-
assert np.allclose(spiketrains2_0, sorting_agg.get_unit_spike_train(num_units + unit_ids[0], segment_index=seg))
39-
assert np.allclose(
40-
spiketrains3_2, sorting_agg.get_unit_spike_train(2 * num_units + unit_ids[2], segment_index=seg)
28+
for segment_index in range(sorting1.get_num_segments()):
29+
30+
spiketrain1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=segment_index)
31+
assert np.all(spiketrain1 == sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=segment_index))
32+
33+
spiketrain2 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=segment_index)
34+
assert np.all(
35+
spiketrain2 == sorting_agg.get_unit_spike_train(unit_ids[0 + num_units], segment_index=segment_index)
36+
)
37+
38+
spiketrain3 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=segment_index)
39+
assert np.all(
40+
spiketrain3 == sorting_agg.get_unit_spike_train(unit_ids[2 + num_units * 2], segment_index=segment_index)
4141
)
4242

4343
# test rename units
4444
renamed_unit_ids = [f"#Unit {i}" for i in range(3 * num_units)]
4545
sorting_agg_renamed = aggregate_units([sorting1, sorting2, sorting3], renamed_unit_ids=renamed_unit_ids)
4646
assert all(unit in renamed_unit_ids for unit in sorting_agg_renamed.get_unit_ids())
4747

48-
# test annotations
4948

50-
# matching annotation
49+
def test_unitsaggregationsorting_annotations():
50+
"""Aggregates a sorting and check if annotations were correctly propagated."""
51+
52+
num_units = 5
53+
sorting1, sorting2, sorting3 = create_three_sortings(num_units=num_units)
54+
55+
# Annotations the same, so can be propagated to aggregated sorting
5156
sorting1.annotate(organ="brain")
5257
sorting2.annotate(organ="brain")
5358
sorting3.annotate(organ="brain")
5459

55-
# not matching annotation
60+
# Annotations are not equal, so cannot be propagated to aggregated sorting
5661
sorting1.annotate(area="CA1")
5762
sorting2.annotate(area="CA2")
5863
sorting3.annotate(area="CA3")
5964

60-
# incomplete annotation
65+
# Annotations are not known for all sortings, so cannot be propagated to aggregated sorting
6166
sorting1.annotate(date="2022-10-13")
6267
sorting2.annotate(date="2022-10-13")
6368

@@ -66,31 +71,45 @@ def test_unitsaggregationsorting(create_cache_folder):
6671
assert "area" not in sorting_agg_prop.get_annotation_keys()
6772
assert "date" not in sorting_agg_prop.get_annotation_keys()
6873

69-
# test properties
7074

71-
# complete property
75+
def test_unitsaggregationsorting_properties():
76+
"""Aggregates a sorting and check if properties were correctly propagated."""
77+
78+
num_units = 5
79+
sorting1, sorting2, sorting3 = create_three_sortings(num_units=num_units)
80+
81+
# Can propagate property
7282
sorting1.set_property("brain_area", ["CA1"] * num_units)
7383
sorting2.set_property("brain_area", ["CA2"] * num_units)
7484
sorting3.set_property("brain_area", ["CA3"] * num_units)
7585

76-
# skip for inconsistency
77-
sorting1.set_property("template", np.zeros((num_units, 4, 30)))
78-
sorting1.set_property("template", np.zeros((num_units, 20, 50)))
79-
sorting1.set_property("template", np.zeros((num_units, 2, 10)))
80-
81-
# incomplete property (str can't be propagated)
82-
sorting1.set_property("quality", ["good"] * num_units)
83-
sorting2.set_property("quality", ["bad"] * num_units)
86+
# Can propagate, even though the dtype is different, since dtype.kind is the same
87+
sorting1.set_property("quality_string", ["good"] * num_units)
88+
sorting2.set_property("quality_string", ["bad"] * num_units)
89+
sorting3.set_property("quality_string", ["bad"] * num_units)
8490

85-
# incomplete property (object can be propagated)
91+
# Can propagate. Although we don't know the "rand" property for sorting3, we can
92+
# use the Extractor's `default_missing_property_values`
8693
sorting1.set_property("rand", np.random.rand(num_units))
8794
sorting2.set_property("rand", np.random.rand(num_units))
8895

96+
# Cannot propagate as arrays are different shapes for each sorting
97+
sorting1.set_property("template", np.zeros((num_units, 4, 30)))
98+
sorting2.set_property("template", np.zeros((num_units, 20, 50)))
99+
sorting3.set_property("template", np.zeros((num_units, 2, 10)))
100+
101+
# Cannot propagate as dtypes are different
102+
sorting1.set_property("quality_mixed", ["good"] * num_units)
103+
sorting2.set_property("quality_mixed", [1] * num_units)
104+
sorting3.set_property("quality_mixed", [2] * num_units)
105+
89106
sorting_agg_prop = aggregate_units([sorting1, sorting2, sorting3])
107+
90108
assert "brain_area" in sorting_agg_prop.get_property_keys()
91-
assert "quality" not in sorting_agg_prop.get_property_keys()
109+
assert "quality_string" in sorting_agg_prop.get_property_keys()
92110
assert "rand" in sorting_agg_prop.get_property_keys()
93-
print(sorting_agg_prop.get_property("brain_area"))
111+
assert "template" not in sorting_agg_prop.get_property_keys()
112+
assert "quality_mixed" not in sorting_agg_prop.get_property_keys()
94113

95114

96115
def test_unit_aggregation_preserve_ids():

src/spikeinterface/core/unitsaggregationsorting.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -77,43 +77,44 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
7777
if np.all(annotations == annotations[0]):
7878
self.set_annotation(annotation_name, sorting_list[0].get_annotation(annotation_name))
7979

80-
property_keys = {}
81-
property_dict = {}
82-
deleted_keys = []
83-
for sort in sorting_list:
84-
for prop_name in sort.get_property_keys():
85-
if prop_name in deleted_keys:
86-
continue
87-
if prop_name in property_keys:
88-
if property_keys[prop_name] != sort.get_property(prop_name).dtype:
89-
print(f"Skipping property '{prop_name}: difference in dtype between sortings'")
90-
del property_keys[prop_name]
91-
deleted_keys.append(prop_name)
92-
else:
93-
property_keys[prop_name] = sort.get_property(prop_name).dtype
94-
for prop_name in property_keys:
95-
dtype = property_keys[prop_name]
96-
property_dict[prop_name] = np.array([], dtype=dtype)
80+
# Check if all the sortings have the same properties
81+
properties_set = set(np.concatenate([sorting.get_property_keys() for sorting in sorting_list]))
82+
for prop_name in properties_set:
9783

84+
dtypes_per_sorting = []
9885
for sort in sorting_list:
9986
if prop_name in sort.get_property_keys():
100-
values = sort.get_property(prop_name)
101-
else:
102-
if dtype.kind not in BaseExtractor.default_missing_property_values:
103-
del property_dict[prop_name]
87+
dtypes_per_sorting.append(sort.get_property(prop_name).dtype.kind)
88+
89+
if len(set(dtypes_per_sorting)) != 1:
90+
warnings.warn(
91+
f"Skipping property '{prop_name}'. Difference in dtype.kind between sortings: {dtypes_per_sorting}"
92+
)
93+
continue
94+
95+
all_property_values = []
96+
for sort in sorting_list:
97+
98+
# If one of the sortings doesn't have the property, use the default missing property value
99+
if prop_name not in sort.get_property_keys():
100+
try:
101+
values = np.full(
102+
sort.get_num_units(),
103+
BaseExtractor.default_missing_property_values[dtypes_per_sorting[0]],
104+
)
105+
except:
106+
warnings.warn(f"Skipping property '{prop_name}': cannot impute missing property values.")
104107
break
105-
values = np.full(
106-
sort.get_num_units(), BaseExtractor.default_missing_property_values[dtype.kind], dtype=dtype
107-
)
108-
109-
try:
110-
property_dict[prop_name] = np.concatenate((property_dict[prop_name], values))
111-
except Exception as e:
112-
print(f"Skipping property '{prop_name}' due to shape inconsistency")
113-
del property_dict[prop_name]
114-
break
115-
for prop_name, prop_values in property_dict.items():
116-
self.set_property(key=prop_name, values=prop_values)
108+
else:
109+
values = sort.get_property(prop_name)
110+
111+
all_property_values.append(values)
112+
113+
try:
114+
prop_values = np.concatenate(all_property_values)
115+
self.set_property(key=prop_name, values=prop_values)
116+
except Exception as ext:
117+
warnings.warn(f"Skipping property '{prop_name}' as numpy cannot concatenate. Numpy error: {ext}")
117118

118119
# add segments
119120
for i_seg in range(num_segments):

src/spikeinterface/postprocessing/template_metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ def _set_params(
153153
if delete_existing_metrics is False and tm_extension is not None:
154154

155155
existing_metric_names = tm_extension.params["metric_names"]
156-
existing_metric_names_propogated = [
156+
existing_metric_names_propagated = [
157157
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
158158
]
159-
metric_names = metrics_to_compute + existing_metric_names_propogated
159+
metric_names = metrics_to_compute + existing_metric_names_propagated
160160

161161
params = dict(
162162
metric_names=metric_names,
@@ -328,7 +328,7 @@ def _run(self, verbose=False):
328328

329329
existing_metrics = []
330330

331-
# Check if we need to propogate any old metrics. If so, we'll do that.
331+
# Check if we need to propagate any old metrics. If so, we'll do that.
332332
# Otherwise, we'll avoid attempting to load an empty template_metrics.
333333
if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]):
334334

src/spikeinterface/postprocessing/tests/test_template_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
9898

9999
def test_metric_names_in_same_order(small_sorting_analyzer):
100100
"""
101-
Computes sepecified template metrics and checks order is propogated.
101+
Computes specified template metrics and checks order is propagated.
102102
"""
103103
specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"]
104104
small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names)

src/spikeinterface/qualitymetrics/quality_metric_calculator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ def _set_params(
108108
if delete_existing_metrics is False and qm_extension is not None:
109109

110110
existing_metric_names = qm_extension.params["metric_names"]
111-
existing_metric_names_propogated = [
111+
existing_metric_names_propagated = [
112112
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
113113
]
114-
metric_names = metrics_to_compute + existing_metric_names_propogated
114+
metric_names = metrics_to_compute + existing_metric_names_propagated
115115

116116
params = dict(
117117
metric_names=metric_names,

src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
120120

121121
def test_metric_names_in_same_order(small_sorting_analyzer):
122122
"""
123-
Computes sepecified quality metrics and checks order is propogated.
123+
Computes specified quality metrics and checks order is propagated.
124124
"""
125125
specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"]
126126
small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names)

0 commit comments

Comments
 (0)