|
| 1 | +"""ezmsg-viewer — plot a specific ezmsg topic without the graph inspector.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +import sys |
| 5 | +from enum import Enum |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import typer |
| 9 | +from ezmsg.qt import EzSession, EzSubscriber |
| 10 | +from ezmsg.util.messages.axisarray import AxisArray |
| 11 | +from phosphor import ( |
| 12 | + ScatterConfig, |
| 13 | + ScatterWidget, |
| 14 | + SpectrumConfig, |
| 15 | + SpectrumWidget, |
| 16 | + SweepConfig, |
| 17 | + SweepEvent, |
| 18 | + SweepWidget, |
| 19 | +) |
| 20 | +from PySide6.QtWidgets import QApplication, QMainWindow, QWidget |
| 21 | + |
| 22 | +logger = logging.getLogger(__name__) |
| 23 | + |
| 24 | +GRAPH_IP = "127.0.0.1" |
| 25 | +GRAPH_PORT = 25978 |
| 26 | + |
| 27 | +# Event color palette (deterministic by label hash) |
| 28 | +EVENT_COLORS = [ |
| 29 | + (1.0, 1.0, 0.4), # yellow |
| 30 | + (0.4, 1.0, 1.0), # cyan |
| 31 | + (1.0, 0.4, 1.0), # magenta |
| 32 | + (1.0, 0.7, 0.3), # orange |
| 33 | + (0.4, 1.0, 0.4), # green |
| 34 | + (1.0, 0.4, 0.4), # red |
| 35 | +] |
| 36 | + |
| 37 | + |
| 38 | +class PlotMode(str, Enum): |
| 39 | + timeseries = "timeseries" |
| 40 | + spectral = "spectral" |
| 41 | + scatter = "scatter" |
| 42 | + |
| 43 | + |
| 44 | +def _extract_channel_meta(msg) -> tuple[list[str] | None, np.ndarray | None]: |
| 45 | + """Extract channel labels and 2D positions from AxisArray channel metadata.""" |
| 46 | + if "ch" not in msg.dims: |
| 47 | + return None, None |
| 48 | + |
| 49 | + ch_axis = msg.get_axis("ch") |
| 50 | + ch_data = getattr(ch_axis, "data", None) |
| 51 | + if ch_data is None or ch_data.dtype.names is None: |
| 52 | + return None, None |
| 53 | + |
| 54 | + labels = None |
| 55 | + if "label" in ch_data.dtype.names: |
| 56 | + labels = [str(v) for v in ch_data["label"]] |
| 57 | + |
| 58 | + positions = None |
| 59 | + if "x" in ch_data.dtype.names and "y" in ch_data.dtype.names: |
| 60 | + x = ch_data["x"].astype(np.float32) |
| 61 | + y = ch_data["y"].astype(np.float32) |
| 62 | + if np.any(x != 0) or np.any(y != 0): |
| 63 | + positions = np.column_stack([x, y]) |
| 64 | + |
| 65 | + return labels, positions |
| 66 | + |
| 67 | + |
| 68 | +def _event_label(msg) -> str: |
| 69 | + """Try to extract a human-readable label from an event message.""" |
| 70 | + if hasattr(msg, "dims") and hasattr(msg, "get_axis"): |
| 71 | + if "ch" in msg.dims: |
| 72 | + ch_axis = msg.get_axis("ch") |
| 73 | + ch_data = getattr(ch_axis, "data", None) |
| 74 | + if ch_data is not None and ch_data.dtype.names and "label" in ch_data.dtype.names: |
| 75 | + labels = ch_data["label"] |
| 76 | + if len(labels) > 0: |
| 77 | + return str(labels[0]) |
| 78 | + return "" |
| 79 | + |
| 80 | + |
| 81 | +class ViewerWindow(QMainWindow): |
| 82 | + def __init__( |
| 83 | + self, |
| 84 | + session: EzSession, |
| 85 | + mode: PlotMode, |
| 86 | + data_topic: str, |
| 87 | + event_topic: str | None = None, |
| 88 | + event_filter: str | None = None, |
| 89 | + parent: QWidget | None = None, |
| 90 | + ) -> None: |
| 91 | + super().__init__(parent) |
| 92 | + self.setWindowTitle(f"ezmsg Viewer — {data_topic}") |
| 93 | + self._mode = mode |
| 94 | + self._event_filter = event_filter |
| 95 | + |
| 96 | + self._data_sub = EzSubscriber(topic=data_topic, parent=self, session=session) |
| 97 | + self._data_sub.connect(self._on_data) |
| 98 | + |
| 99 | + self._event_sub: EzSubscriber | None = None |
| 100 | + if event_topic: |
| 101 | + self._event_sub = EzSubscriber(topic=event_topic, parent=self, session=session) |
| 102 | + self._event_sub.connect(self._on_event) |
| 103 | + |
| 104 | + self._plot_widget: QWidget | None = None |
| 105 | + self._first_message = True |
| 106 | + self._channel_labels: list[str] | None = None |
| 107 | + self._channel_positions: np.ndarray | None = None |
| 108 | + |
| 109 | + # ------------------------------------------------------------------ |
| 110 | + # Data handling |
| 111 | + # ------------------------------------------------------------------ |
| 112 | + |
| 113 | + def _on_data(self, msg) -> None: |
| 114 | + if self._first_message: |
| 115 | + self._channel_labels, self._channel_positions = _extract_channel_meta(msg) |
| 116 | + self._create_plot_widget(msg) |
| 117 | + self._first_message = False |
| 118 | + self._push_message(msg) |
| 119 | + |
| 120 | + def _create_plot_widget(self, msg) -> None: |
| 121 | + labels = self._channel_labels |
| 122 | + |
| 123 | + if self._mode == PlotMode.timeseries: |
| 124 | + if "time" in msg.dims: |
| 125 | + time_axis = msg.get_axis("time") |
| 126 | + srate = 1.0 / time_axis.gain |
| 127 | + time_idx = msg.get_axis_idx("time") |
| 128 | + n_samples = msg.shape[time_idx] |
| 129 | + n_channels = msg.data.size // n_samples |
| 130 | + else: |
| 131 | + logger.warning("No 'time' dimension — using shape[0] as time") |
| 132 | + n_samples = msg.shape[0] |
| 133 | + n_channels = msg.data.size // n_samples if n_samples > 0 else 1 |
| 134 | + srate = 1000.0 |
| 135 | + |
| 136 | + config = SweepConfig(n_channels=n_channels, srate=srate, channel_labels=labels) |
| 137 | + widget = SweepWidget(config) |
| 138 | + |
| 139 | + elif self._mode == PlotMode.spectral: |
| 140 | + if "freq" not in msg.dims: |
| 141 | + logger.error("Spectral mode requires 'freq' dimension in data") |
| 142 | + sys.exit(1) |
| 143 | + |
| 144 | + freq_axis = msg.get_axis("freq") |
| 145 | + freq_idx = msg.get_axis_idx("freq") |
| 146 | + n_bins = msg.shape[freq_idx] |
| 147 | + srate = 2.0 * freq_axis.gain * n_bins |
| 148 | + n_channels = msg.data.size // n_bins |
| 149 | + |
| 150 | + config = SpectrumConfig(n_channels=n_channels, srate=srate, n_bins=n_bins, channel_labels=labels) |
| 151 | + widget = SpectrumWidget(config) |
| 152 | + |
| 153 | + elif self._mode == PlotMode.scatter: |
| 154 | + if self._channel_positions is None: |
| 155 | + logger.error("Scatter mode requires channel position metadata (x, y fields in ch axis)") |
| 156 | + sys.exit(1) |
| 157 | + |
| 158 | + config = ScatterConfig(positions=self._channel_positions, channel_labels=labels) |
| 159 | + widget = ScatterWidget(config) |
| 160 | + |
| 161 | + self._plot_widget = widget |
| 162 | + self.setCentralWidget(widget) |
| 163 | + |
| 164 | + def _push_message(self, msg) -> None: |
| 165 | + widget = self._plot_widget |
| 166 | + |
| 167 | + if isinstance(widget, SweepWidget): |
| 168 | + time_idx = msg.get_axis_idx("time") if "time" in msg.dims else 0 |
| 169 | + n_samples = msg.shape[time_idx] |
| 170 | + n_channels = msg.data.size // n_samples if n_samples > 0 else 1 |
| 171 | + data_2d = np.moveaxis(msg.data, time_idx, 0).reshape(n_samples, n_channels) |
| 172 | + # Pass the AxisArray time-axis offset so the sweep buffer |
| 173 | + # tracks the same clock as the event timestamps. |
| 174 | + ts = msg.get_axis("time").offset if "time" in msg.dims else None |
| 175 | + widget.push_data(data_2d.astype(np.float32), timestamps=ts) |
| 176 | + |
| 177 | + elif isinstance(widget, SpectrumWidget): |
| 178 | + freq_idx = msg.get_axis_idx("freq") if "freq" in msg.dims else 0 |
| 179 | + n_bins = msg.shape[freq_idx] |
| 180 | + n_channels = msg.data.size // n_bins if n_bins > 0 else 1 |
| 181 | + data_2d = np.moveaxis(msg.data, freq_idx, 0).reshape(n_bins, n_channels) |
| 182 | + widget.push_data(data_2d.astype(np.float32)) |
| 183 | + |
| 184 | + elif isinstance(widget, ScatterWidget): |
| 185 | + if len(msg.shape) > 1: |
| 186 | + targ_idx = 0 |
| 187 | + if "time" in msg.dims or "freq" in msg.dims: |
| 188 | + targ_idx = msg.get_axis_idx("time") if "time" in msg.dims else msg.get_axis_idx("freq") |
| 189 | + n_items = msg.shape[targ_idx] |
| 190 | + n_channels = msg.data.size // n_items if n_items > 0 else 1 |
| 191 | + data_2d = np.moveaxis(msg.data, targ_idx, 0).reshape(n_items, n_channels) |
| 192 | + else: |
| 193 | + data_2d = msg.data.reshape(1, msg.data.size) |
| 194 | + widget.push_data(data_2d.astype(np.float32)) |
| 195 | + |
| 196 | + # ------------------------------------------------------------------ |
| 197 | + # Event handling |
| 198 | + # ------------------------------------------------------------------ |
| 199 | + |
| 200 | + def _on_event(self, msg: AxisArray) -> None: |
| 201 | + widget = self._plot_widget |
| 202 | + if not isinstance(widget, SweepWidget): |
| 203 | + return |
| 204 | + |
| 205 | + if "time" not in msg.dims: |
| 206 | + logger.warning("Event message must have 'time' dimension") |
| 207 | + return |
| 208 | + |
| 209 | + time_axis = msg.get_axis("time") |
| 210 | + time_idx = msg.get_axis_idx("time") |
| 211 | + timestamps = time_axis.value(list(range(msg.shape[time_idx]))) |
| 212 | + events: list[SweepEvent] = [] |
| 213 | + for ev_ix, ts in enumerate(timestamps): |
| 214 | + label = msg.data[ev_ix, 0] |
| 215 | + if self._event_filter and self._event_filter not in label: |
| 216 | + continue |
| 217 | + color = EVENT_COLORS[hash(label) % len(EVENT_COLORS)] |
| 218 | + events.append(SweepEvent(t_elapsed=ts, label=label, color=color)) |
| 219 | + widget.push_events(events) |
| 220 | + |
| 221 | + |
| 222 | +def _run( |
| 223 | + data_topic: str = typer.Argument(..., help="ezmsg topic for continuous data"), |
| 224 | + mode: PlotMode = typer.Option(PlotMode.timeseries, help="Plot mode"), |
| 225 | + event_topic: str | None = typer.Option(None, "--events", help="ezmsg topic for event markers"), |
| 226 | + event_filter: str | None = typer.Option( |
| 227 | + None, "--event-filter", help="Only show events whose label contains this string" |
| 228 | + ), |
| 229 | + graph_addr: str = typer.Option( |
| 230 | + ":".join((GRAPH_IP, str(GRAPH_PORT))), |
| 231 | + help="ezmsg graph address (ip:port)", |
| 232 | + ), |
| 233 | +) -> None: |
| 234 | + graph_ip, graph_port_str = graph_addr.split(":") |
| 235 | + graph_address = (graph_ip, int(graph_port_str)) |
| 236 | + |
| 237 | + app = QApplication.instance() or QApplication(sys.argv) |
| 238 | + session = EzSession(graph_address=graph_address) |
| 239 | + window = ViewerWindow( |
| 240 | + session=session, |
| 241 | + mode=mode, |
| 242 | + data_topic=data_topic, |
| 243 | + event_topic=event_topic, |
| 244 | + event_filter=event_filter, |
| 245 | + ) |
| 246 | + window.showMaximized() |
| 247 | + with session: |
| 248 | + app.exec() |
| 249 | + |
| 250 | + |
| 251 | +def main() -> None: |
| 252 | + typer.run(_run) |
| 253 | + |
| 254 | + |
| 255 | +if __name__ == "__main__": |
| 256 | + main() |
0 commit comments