Skip to content

Commit 7882efe

Browse files
authored
Merge pull request #18 from TuragaLab/copilot/fix-flash-response-index-alignment
Align `flash_response_index` labels with filtered FRI coordinate order
2 parents bdf6e44 + 15dd21c commit 7882efe

2 files changed

Lines changed: 57 additions & 0 deletions

File tree

flyvis/network/ensemble_view.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,10 @@ def flash_response_index(
237237
responses = self.flash_responses()
238238
fris = flash_response_index(responses, radius=6)
239239
if cell_types is not None:
240+
requested_cell_types = cell_types
240241
fris = fris.custom.where(cell_type=cell_types)
242+
cell_types = fris.cell_type.values
243+
kwargs.setdefault("sorted_type_list", requested_cell_types)
241244
else:
242245
cell_types = fris.cell_type.values
243246
task_error = self.task_error()

tests/test_ensemble_view.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from types import SimpleNamespace
2+
3+
import numpy as np
4+
5+
from flyvis.network import ensemble_view as ensemble_view_module
6+
from flyvis.network.ensemble_view import EnsembleView
7+
8+
9+
class _FakeFRI:
10+
def __init__(self, values, cell_types, filtered_cell_types=None):
11+
self.values = values
12+
self.cell_type = SimpleNamespace(values=np.array(cell_types))
13+
self.filtered_cell_types = filtered_cell_types
14+
self.custom = SimpleNamespace(where=self._where)
15+
16+
def _where(self, cell_type):
17+
if self.filtered_cell_types is not None:
18+
return _FakeFRI(self.values, self.filtered_cell_types)
19+
return _FakeFRI(self.values, list(cell_type))
20+
21+
22+
class _DummyView:
23+
def flash_responses(self):
24+
return "responses"
25+
26+
def task_error(self):
27+
return SimpleNamespace(values=np.array([0.5, 0.1, 0.2]))
28+
29+
30+
def test_flash_response_index_aligns_labels_with_filtered_data(monkeypatch):
31+
requested = ["Mi1", "Tm3", "CT1(M10)"]
32+
filtered = ["CT1(M10)", "Mi1", "Tm3"]
33+
fake_fris = _FakeFRI(np.ones((3, 2, 1)), requested, filtered_cell_types=filtered)
34+
35+
monkeypatch.setattr(
36+
ensemble_view_module, "flash_response_index", lambda *_args, **_kwargs: fake_fris
37+
)
38+
39+
captured = SimpleNamespace(fris=None, cell_types=None, sorted_type_list=None)
40+
41+
def _fake_plot_fris(fris, cell_types, **kwargs):
42+
captured.fris = fris
43+
captured.cell_types = list(cell_types)
44+
captured.sorted_type_list = kwargs.get("sorted_type_list")
45+
return "fig", "ax"
46+
47+
monkeypatch.setattr(ensemble_view_module, "plot_fris", _fake_plot_fris)
48+
49+
fig, ax = EnsembleView.flash_response_index(_DummyView(), cell_types=requested)
50+
51+
assert (fig, ax) == ("fig", "ax")
52+
assert captured.fris.shape == (3, 2, 1)
53+
assert captured.cell_types == filtered
54+
assert captured.sorted_type_list == requested

0 commit comments

Comments
 (0)