11import pytest
22import numpy as np
33
4- from spikeinterface .core import aggregate_units
4+ from spikeinterface .core import aggregate_units , generate_sorting
55
6- from spikeinterface .core import NpzSortingExtractor
7- from spikeinterface .core import create_sorting_npz
8- from spikeinterface .core import generate_sorting
96
7+ def create_three_sortings (num_units ):
8+ sorting1 = generate_sorting (seed = 1205 , num_units = num_units )
9+ sorting2 = generate_sorting (seed = 1206 , num_units = num_units )
10+ sorting3 = generate_sorting (seed = 1207 , num_units = num_units )
1011
11- def test_unitsaggregationsorting (create_cache_folder ):
12- cache_folder = create_cache_folder
12+ return (sorting1 , sorting2 , sorting3 )
1313
14- num_seg = 2
15- file_path = cache_folder / "test_BaseSorting.npz"
1614
17- create_sorting_npz (num_seg , file_path )
15+ def test_unitsaggregationsorting_spiketrains ():
16+ """Aggregates three sortings, then checks that the number of units and spike trains are equal
17+ for pre-aggregated sorting and the aggregated sorting."""
1818
19- sorting1 = NpzSortingExtractor (file_path )
20- sorting2 = sorting1 .clone ()
21- sorting3 = sorting1 .clone ()
22- print (sorting1 )
23- num_units = len (sorting1 .get_unit_ids ())
19+ num_units = 5
20+ sorting1 , sorting2 , sorting3 = create_three_sortings (num_units = num_units )
2421
2522 # test num units
2623 sorting_agg = aggregate_units ([sorting1 , sorting2 , sorting3 ])
27- print ( sorting_agg )
28- assert len (sorting_agg . get_unit_ids () ) == 3 * num_units
24+ unit_ids = sorting_agg . get_unit_ids ( )
25+ assert len (unit_ids ) == 3 * num_units
2926
3027 # test spike trains
31- unit_ids = sorting1 .get_unit_ids ()
32-
33- for seg in range (num_seg ):
34- spiketrain1_1 = sorting1 .get_unit_spike_train (unit_ids [1 ], segment_index = seg )
35- spiketrains2_0 = sorting2 .get_unit_spike_train (unit_ids [0 ], segment_index = seg )
36- spiketrains3_2 = sorting3 .get_unit_spike_train (unit_ids [2 ], segment_index = seg )
37- assert np .allclose (spiketrain1_1 , sorting_agg .get_unit_spike_train (unit_ids [1 ], segment_index = seg ))
38- assert np .allclose (spiketrains2_0 , sorting_agg .get_unit_spike_train (num_units + unit_ids [0 ], segment_index = seg ))
39- assert np .allclose (
40- spiketrains3_2 , sorting_agg .get_unit_spike_train (2 * num_units + unit_ids [2 ], segment_index = seg )
28+ for segment_index in range (sorting1 .get_num_segments ()):
29+
30+ spiketrain1 = sorting1 .get_unit_spike_train (unit_ids [1 ], segment_index = segment_index )
31+ assert np .all (spiketrain1 == sorting_agg .get_unit_spike_train (unit_ids [1 ], segment_index = segment_index ))
32+
33+ spiketrain2 = sorting2 .get_unit_spike_train (unit_ids [0 ], segment_index = segment_index )
34+ assert np .all (
35+ spiketrain2 == sorting_agg .get_unit_spike_train (unit_ids [0 + num_units ], segment_index = segment_index )
36+ )
37+
38+ spiketrain3 = sorting3 .get_unit_spike_train (unit_ids [2 ], segment_index = segment_index )
39+ assert np .all (
40+ spiketrain3 == sorting_agg .get_unit_spike_train (unit_ids [2 + num_units * 2 ], segment_index = segment_index )
4141 )
4242
4343 # test rename units
4444 renamed_unit_ids = [f"#Unit { i } " for i in range (3 * num_units )]
4545 sorting_agg_renamed = aggregate_units ([sorting1 , sorting2 , sorting3 ], renamed_unit_ids = renamed_unit_ids )
4646 assert all (unit in renamed_unit_ids for unit in sorting_agg_renamed .get_unit_ids ())
4747
48- # test annotations
4948
50- # matching annotation
49+ def test_unitsaggregationsorting_annotations ():
50+ """Aggregates a sorting and check if annotations were correctly propagated."""
51+
52+ num_units = 5
53+ sorting1 , sorting2 , sorting3 = create_three_sortings (num_units = num_units )
54+
55+ # Annotations the same, so can be propagated to aggregated sorting
5156 sorting1 .annotate (organ = "brain" )
5257 sorting2 .annotate (organ = "brain" )
5358 sorting3 .annotate (organ = "brain" )
5459
55- # not matching annotation
60+ # Annotations are not equal, so cannot be propagated to aggregated sorting
5661 sorting1 .annotate (area = "CA1" )
5762 sorting2 .annotate (area = "CA2" )
5863 sorting3 .annotate (area = "CA3" )
5964
60- # incomplete annotation
65+ # Annotations are not known for all sortings, so cannot be propagated to aggregated sorting
6166 sorting1 .annotate (date = "2022-10-13" )
6267 sorting2 .annotate (date = "2022-10-13" )
6368
@@ -66,31 +71,45 @@ def test_unitsaggregationsorting(create_cache_folder):
6671 assert "area" not in sorting_agg_prop .get_annotation_keys ()
6772 assert "date" not in sorting_agg_prop .get_annotation_keys ()
6873
69- # test properties
7074
71- # complete property
75+ def test_unitsaggregationsorting_properties ():
76+ """Aggregates a sorting and check if properties were correctly propagated."""
77+
78+ num_units = 5
79+ sorting1 , sorting2 , sorting3 = create_three_sortings (num_units = num_units )
80+
81+ # Can propagated property
7282 sorting1 .set_property ("brain_area" , ["CA1" ] * num_units )
7383 sorting2 .set_property ("brain_area" , ["CA2" ] * num_units )
7484 sorting3 .set_property ("brain_area" , ["CA3" ] * num_units )
7585
76- # skip for inconsistency
77- sorting1 .set_property ("template" , np .zeros ((num_units , 4 , 30 )))
78- sorting1 .set_property ("template" , np .zeros ((num_units , 20 , 50 )))
79- sorting1 .set_property ("template" , np .zeros ((num_units , 2 , 10 )))
80-
81- # incomplete property (str can't be propagated)
82- sorting1 .set_property ("quality" , ["good" ] * num_units )
83- sorting2 .set_property ("quality" , ["bad" ] * num_units )
86+ # Can propagated, even though the dtype is different, since dtype.kind is the same
87+ sorting1 .set_property ("quality_string" , ["good" ] * num_units )
88+ sorting2 .set_property ("quality_string" , ["bad" ] * num_units )
89+ sorting3 .set_property ("quality_string" , ["bad" ] * num_units )
8490
85- # incomplete property (object can be propagated)
91+ # Can propagated. Although we don't know the "rand" property for sorting3, we can
92+ # use the Extractor's `default_missing_property_values`
8693 sorting1 .set_property ("rand" , np .random .rand (num_units ))
8794 sorting2 .set_property ("rand" , np .random .rand (num_units ))
8895
96+ # Cannot propagate as arrays are different shapes for each sorting
97+ sorting1 .set_property ("template" , np .zeros ((num_units , 4 , 30 )))
98+ sorting2 .set_property ("template" , np .zeros ((num_units , 20 , 50 )))
99+ sorting3 .set_property ("template" , np .zeros ((num_units , 2 , 10 )))
100+
101+ # Cannot propagate as dtypes are different
102+ sorting1 .set_property ("quality_mixed" , ["good" ] * num_units )
103+ sorting2 .set_property ("quality_mixed" , [1 ] * num_units )
104+ sorting3 .set_property ("quality_mixed" , [2 ] * num_units )
105+
89106 sorting_agg_prop = aggregate_units ([sorting1 , sorting2 , sorting3 ])
107+
90108 assert "brain_area" in sorting_agg_prop .get_property_keys ()
91- assert "quality" not in sorting_agg_prop .get_property_keys ()
109+ assert "quality_string" in sorting_agg_prop .get_property_keys ()
92110 assert "rand" in sorting_agg_prop .get_property_keys ()
93- print (sorting_agg_prop .get_property ("brain_area" ))
111+ assert "template" not in sorting_agg_prop .get_property_keys ()
112+ assert "quality_mixed" not in sorting_agg_prop .get_property_keys ()
94113
95114
96115def test_unit_aggregation_preserve_ids ():
0 commit comments