|
| 1 | +from types import SimpleNamespace |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from flyvis.network import ensemble_view as ensemble_view_module |
| 6 | +from flyvis.network.ensemble_view import EnsembleView |
| 7 | + |
| 8 | + |
| 9 | +class _FakeFRI: |
| 10 | + def __init__(self, values, cell_types, filtered_cell_types=None): |
| 11 | + self.values = values |
| 12 | + self.cell_type = SimpleNamespace(values=np.array(cell_types)) |
| 13 | + self.filtered_cell_types = filtered_cell_types |
| 14 | + self.custom = SimpleNamespace(where=self._where) |
| 15 | + |
| 16 | + def _where(self, cell_type): |
| 17 | + if self.filtered_cell_types is not None: |
| 18 | + return _FakeFRI(self.values, self.filtered_cell_types) |
| 19 | + return _FakeFRI(self.values, list(cell_type)) |
| 20 | + |
| 21 | + |
| 22 | +class _DummyView: |
| 23 | + def flash_responses(self): |
| 24 | + return "responses" |
| 25 | + |
| 26 | + def task_error(self): |
| 27 | + return SimpleNamespace(values=np.array([0.5, 0.1, 0.2])) |
| 28 | + |
| 29 | + |
| 30 | +def test_flash_response_index_aligns_labels_with_filtered_data(monkeypatch): |
| 31 | + requested = ["Mi1", "Tm3", "CT1(M10)"] |
| 32 | + filtered = ["CT1(M10)", "Mi1", "Tm3"] |
| 33 | + fake_fris = _FakeFRI(np.ones((3, 2, 1)), requested, filtered_cell_types=filtered) |
| 34 | + |
| 35 | + monkeypatch.setattr( |
| 36 | + ensemble_view_module, "flash_response_index", lambda *_args, **_kwargs: fake_fris |
| 37 | + ) |
| 38 | + |
| 39 | + captured = SimpleNamespace(fris=None, cell_types=None, sorted_type_list=None) |
| 40 | + |
| 41 | + def _fake_plot_fris(fris, cell_types, **kwargs): |
| 42 | + captured.fris = fris |
| 43 | + captured.cell_types = list(cell_types) |
| 44 | + captured.sorted_type_list = kwargs.get("sorted_type_list") |
| 45 | + return "fig", "ax" |
| 46 | + |
| 47 | + monkeypatch.setattr(ensemble_view_module, "plot_fris", _fake_plot_fris) |
| 48 | + |
| 49 | + fig, ax = EnsembleView.flash_response_index(_DummyView(), cell_types=requested) |
| 50 | + |
| 51 | + assert (fig, ax) == ("fig", "ax") |
| 52 | + assert captured.fris.shape == (3, 2, 1) |
| 53 | + assert captured.cell_types == filtered |
| 54 | + assert captured.sorted_type_list == requested |
0 commit comments