diff --git a/DEVELOPER.md b/DEVELOPER.md index 402bfc98..6f8d7ad7 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -168,14 +168,20 @@ graph LR - `_stopped` flag makes `send()` a no-op after pipeline stop **`Receiver[T]`** — reads from one channel as an iterator: +- Registers with the channel on construction (`__init__` takes `channel` and `stop_event`) - **Blocking mode** (default): blocks on `__next__()` until data arrives or `stop_event` fires - **Non-blocking mode**: returns `None` immediately if no data - **Newest mode**: fast-forwards cursor to latest item (essential for video to prevent lag) - Tracks `_msg_count`, `_byte_count`, `lag` for metrics +- Handles are **ephemeral** — created fresh on each `run()`, stored on `Node.senders`/`Node.receivers`, discarded on `stop()` -#### Channel Reconciliation +#### Channel Topology -When edges are added/removed, `GraphManager._reconcile()` recomputes the optimal channel layout: +The system separates **graph topology** (nodes + edges), **channel topology** (the wiring plan), and **wiring** (live handle creation) into three layers: + +1. **Graph topology** — trivial CRUD on nodes and edges +2. **Channel topology** — `_reconcile()` computes a `(sender_plan, receiver_plan)` from the current edges. It only produces a plan, never creates Sender/Receiver handles. +3. **Wiring** — `run()` reads the plan and creates fresh handles, storing them on `Node.senders` and `Node.receivers`. `stop()` discards them. ```mermaid graph TD @@ -183,15 +189,15 @@ graph TD B --> C[Diff against existing channels] C --> D[Reuse unchanged channels] C --> E[Create new channels for new groups] - D & E --> F[Rebuild Sender/Receiver handles] - F --> G[Ensure every output slot has a Sender] + D & E --> F["Return (sender_plan, receiver_plan)"] + F --> G["run() creates fresh handles from plan"] ``` Receivers sharing the same set of upstream senders share a single `Channel` instance, minimizing memory and synchronization overhead. #### UI Channels -UI channels are type-system markers that route data to/from the WebSocket instead of inter-component edges: +UI channels are type-system markers that route data to/from the WebSocket instead of inter-component edges. They are managed by `UIChannelBridge` (in `src/api/ui/bridge.py`), **not** by GraphManager: | Marker Class | Direction | Use Case | |---|---|---| @@ -200,6 +206,8 @@ UI channels are type-system markers that route data to/from the WebSocket instea | `UITextReceiver` | frontend → component | Text input from node UI | | `UIKeystrokeReceiver` | frontend → component | Individual keystrokes from node UI | +The bridge creates UI channels via `wire(manager)`, which returns overrides passed to `GraphManager.run()`. It owns the WebSocket lifecycle via `run(ws)` — spawning outbound tasks per UI output receiver and handling inbound messages in a receive loop. + --- ### Frame Types @@ -273,7 +281,7 @@ Key design decisions: ### GraphManager -The `GraphManager` is the runtime orchestrator. It owns the graph definition, component instances, and all channel/handle state. +The `GraphManager` is the runtime orchestrator. It owns the graph definition, component instances, and channel topology. Sender/Receiver handles are stored on `Node` objects, not on the manager. ```mermaid graph TD @@ -281,14 +289,21 @@ graph TD Graph["Graph (nodes + edges)"] CompMap["Component instances"] ChanMap["Channel map"] - Senders["Sender handles"] - Receivers["Receiver handles"] - UIChan["UI channels"] + end + + subgraph Node + Senders["senders: dict"] + Receivers["receivers: dict"] + end + + subgraph UIChannelBridge + UISend["UI senders (server-side)"] + UIRecv["UI receivers (server-side)"] end Graph -->|"_reconcile()"| ChanMap - ChanMap --> Senders - ChanMap --> Receivers + ChanMap -->|"run(overrides)"| Node + UIChannelBridge -->|"wire() → overrides"| ChanMap CompMap -->|"run()"| Threads["Daemon threads"] ``` @@ -297,24 +312,24 @@ graph TD ```mermaid sequenceDiagram participant Client + participant Bridge as UIChannelBridge participant GM as GraphManager participant Comp as Components - Client->>GM: run() + Client->>Bridge: wire(manager) + Bridge-->>Client: (recv_overrides, send_overrides) + Client->>GM: run(recv_overrides, send_overrides) GM->>Comp: stop() all (if running) - GM->>GM: Clear UI channels + GM->>GM: _reconcile() → (sender_plan, receiver_plan) loop For each node - GM->>GM: Build input/output handles - GM->>GM: Create UI channels - GM->>GM: Wire receivers (register cursors) + GM->>GM: Create fresh Sender/Receiver from plan + overrides + GM->>GM: Store on node.senders / node.receivers end GM->>Comp: setup(outputs) — all components, sequential GM->>Comp: start(inputs, outputs) — spawns daemon threads - GM->>GM: Register threads with log store - GM->>GM: Notify WebSocket watchers Client->>GM: stop() - GM->>GM: Set _stopped on all senders + GM->>GM: Set _stopped on all node senders GM->>Comp: stop() — sets stop_event on each thread ``` @@ -322,9 +337,9 @@ sequenceDiagram | Method | Effect | |---|---| -| `add_node(type, init_args)` | Instantiate component, add to graph, reconcile channels | -| `delete_node(id)` | Stop component + connected neighbors, remove edges, reconcile | -| `update_node_init_args(id, args)` | Recreate component; if graph was running, auto-restart (hot-reload) | +| `add_primitive_node(type_, init_args)` | Instantiate component, add to graph, reconcile channels | +| `delete_node(id)` | Stop component + downstream nodes, remove edges, reconcile | +| `update_primitive_node_init_args(id, args)` | Recreate component; returns `(node, was_running)` — caller handles restart | | `add_edge(edge)` / `delete_edge(edge)` | Modify graph topology, reconcile channels | | `reset(graph)` | Replace entire graph — stop everything, re-instantiate all components | @@ -416,6 +431,8 @@ graph TD **State management** is local React hooks — no Redux or Zustand. ReactFlow manages node/edge state via `useNodesState` / `useEdgesState`. A single `UIChannelContext` provides the WebSocket manager. +The frontend calls the backend directly at `http://localhost:8000` via `API_BASE` / `WS_BASE` constants in `src/lib/api.ts` — no Vite proxy. + **Key hooks:** | Hook | Purpose | diff --git a/README.md b/README.md index 1327b647..4f915844 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,9 @@ cd frontend && bun install && cd .. # Start both backend + frontend bun run dev + +# For AMD GPUs (ROCm): +bun run dev -- --amd ``` This runs the backend (FastAPI on `:8000`) and frontend (Vite on `:5173`) concurrently. @@ -139,7 +142,18 @@ cd backend uv sync --group cuda12 ``` -For **AMD GPUs** (ROCm), PyTorch's ROCm builds are mapped to `cuda` internally — install the appropriate ROCm wheels for your platform. +For **AMD GPUs** (ROCm): + +```bash +cd backend +uv sync --group rocm --no-default-groups +``` + +Or use the dev script flag: + +```bash +bun run dev -- --amd +``` ### Environment Variables diff --git a/backend/profiling/pipeline_hop_test.py b/backend/profiling/pipeline_hop_test.py index bfcca22f..fc49d386 100644 --- a/backend/profiling/pipeline_hop_test.py +++ b/backend/profiling/pipeline_hop_test.py @@ -208,20 +208,20 @@ def run_test(model: str): # Wire: FileSource -> VAD -> ASR -> Adapter -> LLM -> NullSink ch1: Channel[AudioFrame] = Channel() s1 = Sender(ch1) - r1 = Receiver(ch1) + r1 = Receiver(ch1, threading.Event()) ch2: Channel[AudioFrame] = Channel() s2 = Sender(ch2) - r2 = Receiver(ch2) + r2 = Receiver(ch2, threading.Event()) ch3: Channel[TextFrame] = Channel() ch4: Channel[list[MessageFrame]] = Channel() s4 = Sender(ch4) - r4 = Receiver(ch4) + r4 = Receiver(ch4, threading.Event()) # Adapter thread def adapter(): comp = _Adapter() comp._stop_event = threading.Event() - for text_frame in Receiver(ch3)(comp): + for text_frame in Receiver(ch3, threading.Event()): if text_frame is None: break if comp.stop_event.is_set(): diff --git a/backend/profiling/ttfa_profile.py b/backend/profiling/ttfa_profile.py index 2977e3f3..b2797a58 100644 --- a/backend/profiling/ttfa_profile.py +++ b/backend/profiling/ttfa_profile.py @@ -585,7 +585,7 @@ def wrap_audio(item): def adapter3(): comp = _Stub() comp._stop_event = threading.Event() - for tf in Receiver(ch3)(comp): + for tf in Receiver(ch3, threading.Event()): if tf is None: break msgs = [ @@ -607,19 +607,22 @@ def adapter3(): vad_2 = VAD(config=VADConfig()) file_source_2 = FileSource(wav_path=wav_path) - null_sink_2.start(NullSinkInputs(audio=Receiver(ch6)), ()) + null_sink_2.start(NullSinkInputs(audio=Receiver(ch6, threading.Event())), ()) tts_2.start( - tts_mod.TTSInputs(text=Receiver(ch5)), + tts_mod.TTSInputs(text=Receiver(ch5, threading.Event())), tts_mod.TTSOutputs(audio=Sender(ch6), text=Sender()), ) llm_2.start( - llm_mod.LLMInputs(messages=Receiver(ch4)), llm_mod.LLMOutputs(token=Sender(ch5)) + llm_mod.LLMInputs(messages=Receiver(ch4, threading.Event())), + llm_mod.LLMOutputs(token=Sender(ch5)), ) asr_2.start( - asr_mod.ASRInputs(audio=Receiver(ch2)), asr_mod.ASROutputs(text=Sender(ch3)) + asr_mod.ASRInputs(audio=Receiver(ch2, threading.Event())), + asr_mod.ASROutputs(text=Sender(ch3)), ) vad_2.start( - vad_mod.VADInputs(audio=Receiver(ch1)), vad_mod.VADOutputs(audio=Sender(ch2)) + vad_mod.VADInputs(audio=Receiver(ch1, threading.Event())), + vad_mod.VADOutputs(audio=Sender(ch2)), ) file_source_2.start((), FileSourceOutputs(audio=Sender(ch1))) diff --git a/backend/src/api/component/controller.py b/backend/src/api/component/controller.py index 4224d8ad..060c6e47 100644 --- a/backend/src/api/component/controller.py +++ b/backend/src/api/component/controller.py @@ -114,6 +114,22 @@ def is_subtype(sub: str = Query(), sup: str = Query()) -> bool: return False +@router.post("/subtype-pairs") +def subtype_pairs(names: list[str]) -> list[list[str]]: + """Return all [sub, sup] pairs where sub is a subtype of sup.""" + result: list[list[str]] = [] + for a in names: + for b in names: + if a == b: + continue + try: + if issubclass(_resolve_type(a), _resolve_type(b)): + result.append([a, b]) + except ValueError: + pass + return result + + @router.post("/{component_name}/options") def get_options( component_name: str, diff --git a/backend/src/api/dep.py b/backend/src/api/dep.py index 299ef487..a30200e0 100644 --- a/backend/src/api/dep.py +++ b/backend/src/api/dep.py @@ -2,8 +2,13 @@ from fastapi import Request +from src.api.ui.bridge import UIChannelBridge from src.core.graph import GraphManager def get_manager(request: Request) -> GraphManager: return request.app.state.manager + + +def get_ui_bridge(request: Request) -> UIChannelBridge: + return request.app.state.ui_bridge diff --git a/backend/src/api/graph/node/controller.py b/backend/src/api/graph/node/controller.py index 35df8d24..6469e02e 100644 --- a/backend/src/api/graph/node/controller.py +++ b/backend/src/api/graph/node/controller.py @@ -2,7 +2,8 @@ from fastapi import APIRouter, Depends, HTTPException -from src.api.dep import get_manager +from src.api.dep import get_manager, get_ui_bridge +from src.api.ui.bridge import UIChannelBridge from src.api.graph.node.dto import ( NodeInitArgsUpdateRequest, NodeCreateRequest, @@ -79,12 +80,15 @@ def update_node( @router.patch("/nodes/{node_id}/init-args") -def update_node_init_args( +def update_primitive_node_init_args( node_id: str, req: NodeInitArgsUpdateRequest, manager: GraphManager = Depends(get_manager), + ui_bridge: UIChannelBridge = Depends(get_ui_bridge), ) -> NodeResponse: - node = service.update_node_init_args(manager, node_id, req.init_args) + node = service.update_primitive_node_init_args( + manager, ui_bridge, node_id, req.init_args + ) if node is None: raise HTTPException(status_code=404, detail=f"Node not found: {node_id}") return _node_response(node_id, node, manager) diff --git a/backend/src/api/graph/node/service.py b/backend/src/api/graph/node/service.py index 6da83068..017960f3 100644 --- a/backend/src/api/graph/node/service.py +++ b/backend/src/api/graph/node/service.py @@ -4,6 +4,7 @@ from typing import Any from src.api.graph.node.dto import NodeUpdateRequest +from src.api.ui.bridge import UIChannelBridge from src.core.config import PROJECTS_DIR from src.core.graph import Edge, Graph, GraphManager, Node @@ -17,14 +18,14 @@ def get_node(manager: GraphManager, node_id: str) -> Node | None: def create_node( - manager: GraphManager, node_type: str, init_args: dict[str, Any] + manager: GraphManager, type_: str, init_args: dict[str, Any] ) -> tuple[str, Node]: try: - return manager.add_node(node_type, init_args) + return manager.add_primitive_node(type_, init_args) except ValueError: pass # Fallback: try loading as a project - return create_from_project(manager, node_type) + return create_from_project(manager, type_) def update_node( @@ -37,12 +38,17 @@ def delete_node(manager: GraphManager, node_id: str) -> None: manager.delete_node(node_id) -def update_node_init_args( +def update_primitive_node_init_args( manager: GraphManager, + ui_bridge: UIChannelBridge, node_id: str, init_args: dict[str, Any], ) -> Node | None: - return manager.update_node_init_args(node_id, init_args) + node, was_running = manager.update_primitive_node_init_args(node_id, init_args) + if was_running and node is not None: + recv_overrides, send_overrides = ui_bridge.wire(manager) + manager.run(recv_overrides, send_overrides) + return node def create_subgraph( diff --git a/backend/src/api/graph/run/controller.py b/backend/src/api/graph/run/controller.py index 4ab354a2..e50af66b 100644 --- a/backend/src/api/graph/run/controller.py +++ b/backend/src/api/graph/run/controller.py @@ -2,16 +2,20 @@ from fastapi import APIRouter, Depends -from src.api.dep import get_manager +from src.api.dep import get_manager, get_ui_bridge from src.api.graph.run import service +from src.api.ui.bridge import UIChannelBridge from src.core.graph import GraphManager router = APIRouter(prefix="/graph") @router.post("/start", status_code=204) -def start_all(manager: GraphManager = Depends(get_manager)) -> None: - service.start_all(manager) +def start_all( + manager: GraphManager = Depends(get_manager), + ui_bridge: UIChannelBridge = Depends(get_ui_bridge), +) -> None: + service.start_all(manager, ui_bridge) @router.post("/stop", status_code=204) diff --git a/backend/src/api/graph/run/service.py b/backend/src/api/graph/run/service.py index 9d526204..b84e5c4b 100644 --- a/backend/src/api/graph/run/service.py +++ b/backend/src/api/graph/run/service.py @@ -1,10 +1,12 @@ from __future__ import annotations +from src.api.ui.bridge import UIChannelBridge from src.core.graph import GraphManager -def start_all(manager: GraphManager) -> None: - manager.run() +def start_all(manager: GraphManager, ui_bridge: UIChannelBridge) -> None: + recv_overrides, send_overrides = ui_bridge.wire(manager) + manager.run(recv_overrides, send_overrides) def stop_all(manager: GraphManager) -> None: diff --git a/backend/src/api/metrics/service.py b/backend/src/api/metrics/service.py index 19fc7215..67803d3b 100644 --- a/backend/src/api/metrics/service.py +++ b/backend/src/api/metrics/service.py @@ -9,7 +9,8 @@ ReceiverSnapshot, SenderSnapshot, ) -from src.core.graph import GraphManager, ReceiverKey, SenderKey +from src.core.graph import GraphManager +from src.core.utils import ReceiverKey, SenderKey class MetricsCollector: diff --git a/backend/src/api/ui/bridge.py b/backend/src/api/ui/bridge.py new file mode 100644 index 00000000..d999f8cd --- /dev/null +++ b/backend/src/api/ui/bridge.py @@ -0,0 +1,264 @@ +"""UI channel bridge: creates UI channels and manages the WebSocket connection. + +The bridge inspects components for UI slots, creates channels and handles, +passes component-side handles as overrides to GraphManager.run(), and +owns the WebSocket lifecycle for bidirectional UI communication. + +Wire format: + Binary frames: 2-byte header length (big-endian) + JSON header + raw payload + Text frames: JSON {"type": "ui_input"|"ui_output", "node_id", "channel", "payload"} +""" + +from __future__ import annotations + +import asyncio +import json +import struct +import threading +from typing import Any, get_args, get_origin + +from fastapi import WebSocket, WebSocketDisconnect +from pydantic import BaseModel +from src.core.channel import Channel, Receiver, Sender, UISender, UIReceiver +from src.core.frames import TextFrame +from src.core.graph import GraphManager +from src.core.utils import ReceiverKey, SenderKey + + +# -- Wire format: encode/decode pairs -- + + +def encode_binary(node_id: str, slot: str, payload: bytes) -> bytes: + """Pack a binary UI message: 2-byte header len + JSON header + payload.""" + header = json.dumps( + {"type": "ui_output", "node_id": node_id, "channel": slot} + ).encode() + return struct.pack(">H", len(header)) + header + payload + + +def decode_binary(buf: bytes) -> tuple[SenderKey | None, bytes]: + """Unpack a binary UI message. Returns (key, payload).""" + if len(buf) < 2: + return None, b"" + header_len = struct.unpack(">H", buf[:2])[0] + header = json.loads(buf[2 : 2 + header_len].decode("utf-8")) + payload = buf[2 + header_len :] + return (header.get("node_id"), header.get("channel")), payload + + +def encode_json(node_id: str, slot: str, item: Any) -> dict[str, Any]: + """Build a JSON UI output envelope.""" + payload: Any + if isinstance(item, BaseModel): + payload = item.model_dump() + elif isinstance(item, TextFrame): + payload = item.text + else: + payload = item + return { + "type": "ui_output", + "node_id": node_id, + "channel": slot, + "payload": payload, + } + + +def decode_json(text: str) -> tuple[SenderKey, Any] | None: + """Parse a JSON UI input message. Returns (key, raw_payload) or None.""" + msg = json.loads(text) + if msg.get("type") != "ui_input": + return None + return (msg["node_id"], msg["channel"]), msg.get("payload", "") + + +def deserialize_payload(payload: Any, inner_type: type | None) -> Any: + """Convert a raw JSON payload to the expected type.""" + if ( + inner_type is not None + and issubclass(inner_type, BaseModel) + and isinstance(payload, dict) + ): + return inner_type.model_validate(payload) + if ( + inner_type is not None + and hasattr(inner_type, "new") + and isinstance(payload, str) + ): + return inner_type.new(text=payload) + return payload + + +# -- Bridge -- + + +class UIChannelBridge: + def __init__(self) -> None: + self._ui_senders: dict[SenderKey, Sender[Any]] = {} + self._ui_receivers: dict[ReceiverKey, Receiver[Any]] = {} + self._manager: GraphManager | None = None + self._ws: WebSocket | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._send_tasks: dict[SenderKey, asyncio.Task[None]] = {} + + def wire( + self, manager: GraphManager + ) -> tuple[ + dict[ReceiverKey, Receiver[Any] | None], + dict[SenderKey, Sender[Any] | None], + ]: + """Create UI channels for all components and return overrides for run(). + + Called from a sync thread (start endpoint). If a WebSocket is + connected, schedules task (re)spawning on the event loop. + """ + self._manager = manager + self._ui_senders.clear() + self._ui_receivers.clear() + + recv_overrides: dict[ReceiverKey, Receiver[Any] | None] = {} + send_overrides: dict[SenderKey, Sender[Any] | None] = {} + + for node_id, comp in manager.components().items(): + stop_event = ( + comp.stop_event if hasattr(comp, "stop_event") else threading.Event() + ) + + for slot, slot_type in comp.get_ui_input_types().items(): + ch: Channel[Any] = Channel() + origin = get_origin(slot_type) or slot_type + recv_overrides[(node_id, slot)] = origin(ch, stop_event) + self._ui_senders[(node_id, slot)] = Sender(ch) + + for slot, slot_type in comp.get_ui_output_types().items(): + ch = Channel() + origin = get_origin(slot_type) or slot_type + send_overrides[(node_id, slot)] = origin(ch) + self._ui_receivers[(node_id, slot)] = Receiver(ch, threading.Event()) + + if self._ws is not None and self._loop is not None: + self._loop.call_soon_threadsafe(self._start_send_tasks) + + return recv_overrides, send_overrides + + async def run(self, ws: WebSocket) -> None: + """Own the WebSocket lifecycle. Blocks until disconnect.""" + self._ws = ws + self._loop = asyncio.get_running_loop() + + # If pipeline already started (receivers exist), start tasks now + if self._ui_receivers: + self._start_send_tasks() + + try: + await self._recv_msgs(ws) + finally: + self._ws = None + self._loop = None + for t in self._send_tasks.values(): + t.cancel() + await asyncio.gather(*self._send_tasks.values(), return_exceptions=True) + self._send_tasks.clear() + + # -- Outbound: component → frontend -- + + def _start_send_tasks(self) -> None: + """Spawn one task per UI output receiver. Must run in the event loop.""" + for key in list(self._send_tasks): + self._send_tasks.pop(key).cancel() + ws = self._ws + if ws is None: + return + for key, receiver in self._ui_receivers.items(): + node_id, slot = key + inner_type = self._resolve_ui_output_type(node_id, slot) + self._send_tasks[key] = asyncio.create_task( + self._send_msg_task(ws, node_id, slot, receiver, inner_type) + ) + + async def _send_msg_task( + self, + ws: WebSocket, + node_id: str, + slot: str, + receiver: Receiver[Any], + inner_type: type | None, + ) -> None: + try: + while True: + item = await asyncio.to_thread(next, receiver) + if item is None: + break + if inner_type is not None and issubclass(inner_type, bytes): + await ws.send_bytes(encode_binary(node_id, slot, item)) + else: + await ws.send_json(encode_json(node_id, slot, item)) + except (WebSocketDisconnect, RuntimeError, asyncio.CancelledError): + pass + + # -- Inbound: frontend → component -- + + async def _recv_msgs(self, ws: WebSocket) -> None: + while True: + ws_msg = await ws.receive() + if ws_msg.get("type") == "websocket.disconnect": + break + if "bytes" in ws_msg: + key, payload = decode_binary(ws_msg["bytes"]) + if key is not None: + sender = self._ui_senders.get(key) + if sender is not None: + sender.send(payload) + elif "text" in ws_msg: + result = decode_json(ws_msg["text"]) + if result is not None: + key, payload = result + sender = self._ui_senders.get(key) + if sender is not None: + inner_type = self._resolve_ui_input_type(*key) + sender.send(deserialize_payload(payload, inner_type)) + + # -- Type resolution helpers -- + + def _resolve_ui_output_type(self, node_id: str, slot: str) -> type | None: + if self._manager is None: + return None + comp = self._manager.components().get(node_id) + if comp is None: + return None + slot_type = comp.get_ui_output_types().get(slot) + if slot_type is None: + return None + args = get_args(slot_type) + if args: + return args[0] + origin = get_origin(slot_type) or slot_type + if isinstance(origin, type): + for base in getattr(origin, "__orig_bases__", ()): + base_origin = get_origin(base) + if base_origin is not None and issubclass(base_origin, UISender): + base_args = get_args(base) + if base_args: + return base_args[0] + return None + + def _resolve_ui_input_type(self, node_id: str, slot: str) -> type | None: + if self._manager is None: + return None + comp = self._manager.components().get(node_id) + if comp is None: + return None + slot_type = comp.get_ui_input_types().get(slot) + if slot_type is None: + return None + args = get_args(slot_type) + if args: + return args[0] + origin = get_origin(slot_type) or slot_type + if isinstance(origin, type): + for base in getattr(origin, "__orig_bases__", ()): + base_origin = get_origin(base) + if base_origin is not None and issubclass(base_origin, UIReceiver): + base_args = get_args(base) + if base_args: + return base_args[0] + return None diff --git a/backend/src/api/ui/controller.py b/backend/src/api/ui/controller.py index 0c0861eb..759f5c54 100644 --- a/backend/src/api/ui/controller.py +++ b/backend/src/api/ui/controller.py @@ -1,236 +1,14 @@ -"""Single WebSocket that bridges all UI channels between components and frontend. - -Inbound (text JSON): - {"type": "ui_input", "node_id": "...", "channel": "...", "payload": "..."} - -> finds the Sender for (node_id, channel) and calls .send() - -Outbound: - bytes channels: binary WS frame = 2-byte header length (big-endian) + JSON header + raw bytes - BaseModel channels: JSON {"type": "ui_output", ..., "payload": {model dict}} - Legacy (.get()) channels: JSON {"type": "ui_output", ..., "payload": "string"} -""" - from __future__ import annotations -import asyncio -import itertools -import json -import struct -import threading -from typing import Any, get_args, get_origin +from fastapi import APIRouter, WebSocket -from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from pydantic import BaseModel -from src.core.channel import Receiver, UISender, UIReceiver -from src.core.graph import GraphManager +from src.api.ui.bridge import UIChannelBridge router = APIRouter(prefix="/ui") -_sub_id_counter = itertools.count() - - -def _resolve_ui_output_type( - manager: GraphManager, node_id: str, slot: str -) -> type | None: - """Extract the inner T from UISender[T] for the given output slot.""" - comp = manager.components().get(node_id) - if comp is None: - return None - slot_type = comp.get_ui_output_types().get(slot) - if slot_type is None: - return None - # slot_type is e.g. UISender[bytes], UIVideoSender (which is UISender[bytes]), etc. - # Try get_args on the slot_type first (for generic aliases like UISender[MyModel]) - args = get_args(slot_type) - if args: - return args[0] - # For concrete subclasses like UIVideoSender, walk __orig_bases__ - origin = get_origin(slot_type) or slot_type - if isinstance(origin, type): - for base in getattr(origin, "__orig_bases__", ()): - base_origin = get_origin(base) - if base_origin is not None and issubclass(base_origin, UISender): - base_args = get_args(base) - if base_args: - return base_args[0] - return None - - -def _resolve_ui_input_type( - manager: GraphManager, node_id: str, slot: str -) -> type | None: - """Extract the inner T from UIReceiver[T] for the given input slot.""" - comp = manager.components().get(node_id) - if comp is None: - return None - slot_type = comp.get_ui_input_types().get(slot) - if slot_type is None: - return None - args = get_args(slot_type) - if args: - return args[0] - origin = get_origin(slot_type) or slot_type - if isinstance(origin, type): - for base in getattr(origin, "__orig_bases__", ()): - base_origin = get_origin(base) - if base_origin is not None and issubclass(base_origin, UIReceiver): - base_args = get_args(base) - if base_args: - return base_args[0] - return None - - -async def _read_ui_output( - ws: WebSocket, - node_id: str, - slot: str, - receiver: Receiver[Any], - inner_type: type | None, - stop_event: asyncio.Event, -) -> None: - """Async task: reads from a blocking Receiver and sends to the WebSocket.""" - thread_stop = threading.Event() - sub_id = next(_sub_id_counter) - receiver._channel._register(sub_id) - - try: - while not stop_event.is_set(): - item = await asyncio.to_thread(receiver._channel._get, sub_id, thread_stop) - if item is None: - break - try: - if inner_type is not None and issubclass(inner_type, bytes): - # Binary path (video frames, etc.) - header = json.dumps( - {"type": "ui_output", "node_id": node_id, "channel": slot} - ).encode() - prefix = struct.pack(">H", len(header)) - await ws.send_bytes(prefix + header + item) - elif isinstance(item, BaseModel): - # Pydantic model → JSON dict payload - await ws.send_json( - { - "type": "ui_output", - "node_id": node_id, - "channel": slot, - "payload": item.model_dump(), - } - ) - elif hasattr(item, "get"): - # Legacy TextFrame-style with .get() - await ws.send_json( - { - "type": "ui_output", - "node_id": node_id, - "channel": slot, - "payload": item.get(), - } - ) - else: - await ws.send_json( - { - "type": "ui_output", - "node_id": node_id, - "channel": slot, - "payload": item, - } - ) - except (WebSocketDisconnect, RuntimeError): - break - finally: - receiver._channel._unregister(sub_id) - thread_stop.set() - - -async def _watch_ui_channels( - ws: WebSocket, - manager: GraphManager, - stop_event: asyncio.Event, - tasks: dict[tuple[str, str], asyncio.Task[None]], -) -> None: - """Wait for GraphManager.run() to signal new UI channels, then spawn readers.""" - last_version = manager._ui_version - while not stop_event.is_set(): - # Spawn tasks for any receivers we haven't seen yet - for key, receiver in manager.ui_receivers().items(): - if key not in tasks: - node_id, slot = key - inner_type = _resolve_ui_output_type(manager, node_id, slot) - tasks[key] = asyncio.create_task( - _read_ui_output(ws, node_id, slot, receiver, inner_type, stop_event) - ) - - if manager._ui_version == last_version: - # Wait until run() fires the event - try: - await manager._ui_changed.wait() - except asyncio.CancelledError: - return - - # Cancel all stale tasks — even if the same keys exist, the - # underlying Receiver objects point to new channels after a restart. - for key in list(tasks): - tasks.pop(key).cancel() - - last_version = manager._ui_version - @router.websocket("/ws") async def ui_ws(ws: WebSocket) -> None: - print("[ui_ws] Got connection request from:", ws.client) - manager: GraphManager = ws.app.state.manager + bridge: UIChannelBridge = ws.app.state.ui_bridge await ws.accept() - print("[ui_ws] Accepted connection!!!") - - stop_event = asyncio.Event() - tasks: dict[tuple[str, str], asyncio.Task[None]] = {} - - watcher = asyncio.create_task(_watch_ui_channels(ws, manager, stop_event, tasks)) - - try: - while True: - ws_msg = await ws.receive() - if "bytes" in ws_msg: - # Binary frame: 2-byte header length + JSON header + payload - buf = ws_msg["bytes"] - if len(buf) >= 2: - header_len = struct.unpack(">H", buf[:2])[0] - header = json.loads(buf[2 : 2 + header_len].decode("utf-8")) - payload = buf[2 + header_len :] - - node_id = header.get("node_id") - channel = header.get("channel") - sender = manager.ui_senders().get((node_id, channel)) - if sender is not None: - sender.send(payload) - elif "text" in ws_msg: - msg = json.loads(ws_msg["text"]) - if msg.get("type") == "ui_input": - node_id = msg["node_id"] - channel = msg["channel"] - payload = msg.get("payload", "") - sender = manager.ui_senders().get((node_id, channel)) - if sender is not None: - inner_type = _resolve_ui_input_type(manager, node_id, channel) - if ( - inner_type is not None - and issubclass(inner_type, BaseModel) - and isinstance(payload, dict) - ): - sender.send(inner_type.model_validate(payload)) - elif ( - inner_type is not None - and hasattr(inner_type, "new") - and isinstance(payload, str) - ): - sender.send(inner_type.new(text=payload)) - else: - sender.send(payload) - except (WebSocketDisconnect, RuntimeError): - pass - finally: - stop_event.set() - watcher.cancel() - for t in tasks.values(): - t.cancel() - await asyncio.gather(watcher, *tasks.values(), return_exceptions=True) + await bridge.run(ws) diff --git a/backend/src/core/channel.py b/backend/src/core/channel.py index a12474da..d71bb26a 100644 --- a/backend/src/core/channel.py +++ b/backend/src/core/channel.py @@ -116,17 +116,20 @@ def buffer_depth(self) -> int: class Receiver[T]: - """Handle for receiving from a channel. Is itself the iterator.""" + """Handle for receiving from a channel. Is itself the iterator. - def __init__(self, channel: Channel[T]) -> None: + Registers with the channel on construction, unregisters on GC. + """ + + def __init__(self, channel: Channel[T], stop_event: threading.Event) -> None: self._channel = channel + self._stop_event = stop_event + self._sub_id: int = id(self) self._msg_count: int = 0 self._byte_count: int = 0 - self._sub_id: int | None = None - self._stop_event: threading.Event | None = None - self._wired: bool = False self._newest: bool = False self.blocking: bool = True + self._channel._register(self._sub_id) @property def newest(self) -> bool: @@ -137,32 +140,15 @@ def newest(self, value: bool) -> None: if self._newest == value: return self._newest = value - if self._wired: - self._channel._reregister(self._sub_id, newest=value) # type: ignore[arg-type] - - def _wire(self, stop_event: threading.Event) -> None: - """Register with the channel. Called by GraphManager.run().""" - if self._wired: - return - self._sub_id = id(self) - self._stop_event = stop_event - self._channel._register(self._sub_id, newest=self._newest) - self._wired = True - - def _unwire(self) -> None: - """Unregister from the channel. Idempotent.""" - if not self._wired: - return - self._channel._unregister(self._sub_id) # type: ignore[arg-type] - self._wired = False + self._channel._reregister(self._sub_id, newest=value) def __iter__(self) -> Iterator[T | None]: return self def __next__(self) -> T | None: item = self._channel._get( - self._sub_id, # type: ignore[arg-type] - self._stop_event, # type: ignore[arg-type] + self._sub_id, + self._stop_event, blocking=self.blocking, ) if item is not None: @@ -172,18 +158,17 @@ def __next__(self) -> T | None: def __del__(self) -> None: try: - self._unwire() + self._channel._unregister(self._sub_id) except Exception: pass @property def lag(self) -> int: - sub_id = self._sub_id - if sub_id is None or self._newest: + if self._newest: return 0 ch = self._channel with ch._condition: - cursor = ch._cursors.get(sub_id) + cursor = ch._cursors.get(self._sub_id) if cursor is None: return 0 head = ch._offset + len(ch._items) diff --git a/backend/src/core/component.py b/backend/src/core/component.py index 58c1c474..1f7f9667 100644 --- a/backend/src/core/component.py +++ b/backend/src/core/component.py @@ -12,6 +12,7 @@ from src.core.channel import Receiver, Sender from src.core.channel import UIReceiver, UISender from src.core.log_capture import get_log_store +from src.core.utils import ReceiverKey, SenderKey if TYPE_CHECKING: from src.core.graph import Graph, GraphManager @@ -342,8 +343,6 @@ def _safe_run(self, inputs: I, outputs: O) -> None: except Exception: traceback.print_exc() finally: - # Registration happens in GraphManager when start() is called. - # Unregister here to ensure buffered partial lines are flushed. get_log_store().unregister_thread() self._status = Status.STOPPED @@ -395,8 +394,15 @@ def emit(self, outputs: E) -> None: ... # --------------------------------------------------------------------------- -class CompositeComponent(Component[Any, Any]): - """A composite morphism: its interface is derived from unmatched ports in the subgraph.""" +class CompositeComponent( + Component[tuple[Receiver[Any] | None, ...], tuple[Sender[Any] | None, ...]] +): + """A composite component wrapping a subgraph. + + Its interface is derived from unmatched ports in the subgraph. + start() receives tuples whose element order matches + get_input_types() / get_output_types() key order. + """ _registerable = False @@ -421,29 +427,69 @@ def type_(self) -> str: def _compute_boundary( self, - ) -> tuple[dict[str, tuple[str, str]], dict[str, tuple[str, str]]]: - connected_inputs: set[tuple[str, str]] = set() - connected_outputs: set[tuple[str, str]] = set() + ) -> tuple[dict[str, ReceiverKey], dict[str, SenderKey]]: + connected_inputs: set[ReceiverKey] = set() + connected_outputs: set[SenderKey] = set() for edge in self._sub_graph.edges: connected_inputs.add((edge.target_node, edge.target_slot)) connected_outputs.add((edge.source_node, edge.source_slot)) classes = PrimitiveComponent.registered_subclasses() - ext_inputs: dict[str, tuple[str, str]] = {} - ext_outputs: dict[str, tuple[str, str]] = {} + + # Collect unconnected (boundary) slots: (node_id, slot, component_type) + raw_inputs: list[tuple[str, str, str]] = [] + raw_outputs: list[tuple[str, str, str]] = [] for node_id, node in self._sub_graph.nodes.items(): cls = classes.get(node.type) if cls is None: - # Node is a CompositeComponent — nested composites not yet supported continue for slot in cls._class_input_types(): if (node_id, slot) not in connected_inputs: - ext_inputs[f"{node_id}.{slot}"] = (node_id, slot) + raw_inputs.append((node_id, slot, node.type)) for slot in cls._class_output_types(): if (node_id, slot) not in connected_outputs: - ext_outputs[f"{node_id}.{slot}"] = (node_id, slot) + raw_outputs.append((node_id, slot, node.type)) + + ext_inputs = self._disambiguate(raw_inputs) + ext_outputs = self._disambiguate(raw_outputs) return ext_inputs, ext_outputs + @staticmethod + def _disambiguate( + slots: list[tuple[str, str, str]], + ) -> dict[str, tuple[str, str]]: + """Assign unique external names to boundary slots. + + Uses the slot name alone if unique, otherwise prefixes with the + component type. If still ambiguous, appends a numeric suffix. + """ + from collections import Counter + + # Count how many times each slot name appears + slot_counts = Counter(slot for _, slot, _ in slots) + + # For duplicates, try "Type.slot" + names: list[str] = [] + for _, slot, comp_type in slots: + if slot_counts[slot] == 1: + names.append(slot) + else: + names.append(f"{comp_type}.{slot}") + + # If still duplicates, append numbers + final_counts: dict[str, int] = Counter(names) + seen: dict[str, int] = {} + result: dict[str, tuple[str, str]] = {} + for i, name in enumerate(names): + node_id, slot, _ = slots[i] + if final_counts[name] > 1: + idx = seen.get(name, 0) + 1 + seen[name] = idx + result[f"{name}.{idx}"] = (node_id, slot) + else: + result[name] = (node_id, slot) + return result + def get_input_types(self) -> dict[str, type]: classes = PrimitiveComponent.registered_subclasses() result: dict[str, type] = {} @@ -474,7 +520,11 @@ def get_ui_input_types(self) -> dict[str, type]: def get_ui_output_types(self) -> dict[str, type]: return {} - def start(self, inputs: Any, outputs: Any) -> None: + def start( + self, + inputs: tuple[Receiver[Any] | None, ...], + outputs: tuple[Sender[Any] | None, ...], + ) -> None: if self.status == Status.RUNNING: return self._status = Status.SETUP @@ -483,20 +533,26 @@ def start(self, inputs: Any, outputs: Any) -> None: self._inner_manager = GraphManager(self._sub_graph) - for ext_name, (node_id, slot) in self._ext_inputs.items(): - outer_receiver = inputs.get(ext_name) if isinstance(inputs, dict) else None - if outer_receiver is not None: - self._inner_manager._receiver_handles[(node_id, slot)] = outer_receiver - - for ext_name, (node_id, slot) in self._ext_outputs.items(): - outer_sender = outputs.get(ext_name) if isinstance(outputs, dict) else None - if outer_sender is not None: - self._inner_manager._sender_handles[(node_id, slot)] = outer_sender + recv_overrides: dict[ReceiverKey, Receiver[Any] | None] = { + self._ext_inputs[ext_name]: recv + for ext_name, recv in zip(self._ext_inputs, inputs) + } + send_overrides: dict[SenderKey, Sender[Any] | None] = { + self._ext_outputs[ext_name]: send + for ext_name, send in zip(self._ext_outputs, outputs) + } - self._inner_manager.run() + self._inner_manager.run(recv_overrides, send_overrides) self._status = Status.RUNNING def stop(self) -> None: super().stop() if self._inner_manager is not None: - self._inner_manager.stop() + # Only signal inner components to stop, don't join — + # the outer GraphManager handles joining. + for node in self._inner_manager._graph.nodes.values(): + for sender in node.senders.values(): + if sender is not None: + sender._stopped = True + for comp in self._inner_manager._components.values(): + comp.stop() diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index 535d9e06..b0549a59 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -1,10 +1,10 @@ from __future__ import annotations -import asyncio +import threading import time import uuid from collections import defaultdict -from typing import Any, get_args, get_origin +from typing import Any from pydantic import BaseModel, Field @@ -17,17 +17,22 @@ from src.core.log_capture import get_log_store -SenderKey = tuple[str, str] # (node_id, slot_name) -ReceiverKey = tuple[str, str] # (node_id, slot_name) +from src.core.utils import SenderKey, ReceiverKey class Node(BaseModel): + model_config = {"arbitrary_types_allowed": True} + id_: str = Field(default_factory=lambda: str(uuid.uuid4())) type: str init_args: dict[str, Any] x: float = 0.0 y: float = 0.0 sub_graph: Graph | None = None + senders: dict[str, Sender[Any] | None] = Field(default_factory=dict, exclude=True) + receivers: dict[str, Receiver[Any] | None] = Field( + default_factory=dict, exclude=True + ) class Edge(BaseModel): @@ -51,25 +56,19 @@ def __init__(self, graph: Graph) -> None: self._graph = Graph(edges=[], nodes={}) self._components: dict[str, Component[Any, Any]] = {} self._channel_map: dict[frozenset[SenderKey], Channel[Any]] = {} - self._sender_handles: dict[SenderKey, Sender[Any]] = {} - self._receiver_handles: dict[ReceiverKey, Receiver[Any]] = {} - # UI channels: keyed by (node_id, slot_name) - self._ui_channels: dict[tuple[str, str], Channel[Any]] = {} - self._ui_senders: dict[tuple[str, str], Sender[Any]] = {} - self._ui_receivers: dict[tuple[str, str], Receiver[Any]] = {} - self._ui_version = 0 - self._ui_changed = asyncio.Event() self.reset(graph) # --- node CRUD --- - def add_node(self, node_type: str, init_args: dict[str, Any]) -> tuple[str, Node]: + def add_primitive_node( + self, type_: str, init_args: dict[str, Any] + ) -> tuple[str, Node]: classes = PrimitiveComponent.registered_subclasses() - cls = classes.get(node_type) + cls = classes.get(type_) if cls is None: - raise ValueError(f"Unknown node type: {node_type}") + raise ValueError(f"Unknown node type: {type_}") comp = cls.from_args(init_args) - node = Node(type=node_type, init_args=init_args) + node = Node(type=type_, init_args=init_args) self._graph.nodes[node.id_] = node self._components[node.id_] = comp return node.id_, node @@ -107,17 +106,22 @@ def update_node(self, node_id: str, x: float, y: float) -> Node | None: node.y = y return node - def update_node_init_args( + def update_primitive_node_init_args( self, node_id: str, init_args: dict[str, Any] - ) -> Node | None: + ) -> tuple[Node | None, bool]: + """Replace a node's init_args and recreate its component. + + Returns (node, was_running). The caller is responsible for + calling run() with UI channel overrides if was_running is True. + """ node = self._graph.nodes.get(node_id) if node is None: - return None + return None, False classes = PrimitiveComponent.registered_subclasses() cls = classes.get(node.type) if cls is None: - return None + return None, False was_running = any( c.status.value == "running" for c in self._components.values() @@ -128,10 +132,7 @@ def update_node_init_args( node.init_args = init_args self._components[node_id] = cls.from_args(init_args) self._reconcile() - - if was_running: - self.run() - return node + return node, was_running def delete_node(self, node_id: str) -> None: node = self._graph.nodes.get(node_id) @@ -142,13 +143,11 @@ def delete_node(self, node_id: str) -> None: if comp is not None: comp.stop() - # Collect connected components that need stopping + # Collect downstream components that need stopping affected: set[str] = set() for edge in self._graph.edges: if edge.source_node == node_id: affected.add(edge.target_node) - if edge.target_node == node_id: - affected.add(edge.source_node) self._graph.edges = [ e @@ -187,18 +186,36 @@ def components(self) -> dict[str, Component[Any, Any]]: return self._components def sender_handles(self) -> dict[SenderKey, Sender[Any]]: - return self._sender_handles + return { + (node_id, slot): sender + for node_id, node in self._graph.nodes.items() + for slot, sender in node.senders.items() + if sender is not None + } def receiver_handles(self) -> dict[ReceiverKey, Receiver[Any]]: - return self._receiver_handles - - def ui_senders(self) -> dict[tuple[str, str], Sender[Any]]: - """Server-side senders that push data into component UIReceiver slots.""" - return self._ui_senders + return { + (node_id, slot): receiver + for node_id, node in self._graph.nodes.items() + for slot, receiver in node.receivers.items() + if receiver is not None + } + + def ui_input_slots(self) -> list[ReceiverKey]: + """All (node_id, slot) pairs for UI input slots across the graph.""" + return [ + (node_id, slot) + for node_id, comp in self._components.items() + for slot in comp.get_ui_input_types() + ] - def ui_receivers(self) -> dict[tuple[str, str], Receiver[Any]]: - """Server-side receivers that read from component UISender slots.""" - return self._ui_receivers + def ui_output_slots(self) -> list[SenderKey]: + """All (node_id, slot) pairs for UI output slots across the graph.""" + return [ + (node_id, slot) + for node_id, comp in self._components.items() + for slot in comp.get_ui_output_types() + ] def get_node_output(self, node_id: str) -> dict[str, type]: return self._components[node_id].get_output_types() @@ -212,11 +229,6 @@ def reset(self, graph: Graph) -> None: self._graph = graph self._components.clear() self._channel_map.clear() - self._sender_handles.clear() - self._receiver_handles.clear() - self._ui_channels.clear() - self._ui_senders.clear() - self._ui_receivers.clear() classes = PrimitiveComponent.registered_subclasses() for node_id, node in self._graph.nodes.items(): @@ -248,8 +260,14 @@ def _group( return dict(groups) - def _reconcile(self) -> None: - """Recompute optimal channel layout and diff against existing.""" + def _reconcile( + self, + ) -> tuple[dict[SenderKey, list[Channel[Any]]], dict[ReceiverKey, Channel[Any]]]: + """Recompute channel topology from the current graph edges. + + Returns (sender_plan, receiver_plan) — the wiring blueprint. + Does not create or touch Sender/Receiver handles — that happens in run(). + """ edges: list[tuple[SenderKey, ReceiverKey]] = [ ((e.source_node, e.source_slot), (e.target_node, e.target_slot)) for e in self._graph.edges @@ -259,133 +277,98 @@ def _reconcile(self) -> None: old_keys = set(self._channel_map.keys()) new_keys = set(groups.keys()) - reuse = old_keys & new_keys - create = new_keys - old_keys - + # step 1: create/remove channels new_channel_map: dict[frozenset[SenderKey], Channel[Any]] = {} - for ckey in reuse: + for ckey in old_keys & new_keys: new_channel_map[ckey] = self._channel_map[ckey] - for ckey in create: + for ckey in new_keys - old_keys: new_channel_map[ckey] = Channel() - self._channel_map = new_channel_map - sender_channels: dict[SenderKey, list[Channel[Any]]] = defaultdict(list) + # step 2: build the wiring plan + sender_plan: dict[SenderKey, list[Channel[Any]]] = defaultdict(list) for sender_set, channel in self._channel_map.items(): for sender_key in sender_set: - sender_channels[sender_key].append(channel) - - new_sender_handles: dict[SenderKey, Sender[Any]] = {} - for skey, channels in sender_channels.items(): - old_sender = self._sender_handles.get(skey) - if old_sender is not None and set(old_sender._channels) == set(channels): - new_sender_handles[skey] = old_sender - else: - new_sender_handles[skey] = Sender(*channels) - - # Ensure every declared output slot has a sender handle, even when no edges - # are connected. This keeps per-output metrics (e.g. last_send_time) visible - # for standalone sources. - for node_id, comp in self._components.items(): - for slot in comp.get_output_types(): - output_key: SenderKey = (node_id, slot) - if output_key in new_sender_handles: - continue - old_sender = self._sender_handles.get(output_key) - if old_sender is not None and len(old_sender._channels) == 0: - new_sender_handles[output_key] = old_sender - else: - new_sender_handles[output_key] = Sender() - self._sender_handles = new_sender_handles + sender_plan[sender_key].append(channel) - new_receiver_handles: dict[ReceiverKey, Receiver[Any]] = {} + receiver_plan: dict[ReceiverKey, Channel[Any]] = {} for sender_set, recv_keys in groups.items(): channel = self._channel_map[sender_set] for recv_key in recv_keys: - old_recv: Receiver[Any] | None = self._receiver_handles.get(recv_key) # type: ignore[assignment] - if old_recv is not None and old_recv._channel is channel: - new_receiver_handles[recv_key] = old_recv - else: - new_receiver_handles[recv_key] = Receiver(channel) - self._receiver_handles = new_receiver_handles + receiver_plan[recv_key] = channel + + return dict(sender_plan), receiver_plan - def run(self) -> None: - """Stop all running components, then start each with wired handles.""" + def run( + self, + receiver_overrides: dict[ReceiverKey, Receiver[Any] | None] | None = None, + sender_overrides: dict[SenderKey, Sender[Any] | None] | None = None, + ) -> None: + """Stop all running components, then start each with fresh handles. + + Optional overrides let callers (e.g. CompositeComponent) inject + pre-built handles for specific slots. + """ self.stop() - for sender in self._sender_handles.values(): - sender._stopped = False - self._ui_channels.clear() - self._ui_senders.clear() - self._ui_receivers.clear() + + sender_plan, receiver_plan = self._reconcile() + _recv_over = receiver_overrides or {} + _send_over = sender_overrides or {} + start_queue: list[tuple[str, Component[Any, Any], Any, Any]] = [] - for node_id in self._graph.nodes: + for node_id, node in self._graph.nodes.items(): comp = self._components[node_id] cls = type(comp) input_type = cls._get_type_param(0) output_type = cls._get_type_param(1) - # Use instance calls so CompositeComponent overrides take effect input_slots = comp.get_input_types() output_slots = comp.get_output_types() ui_input_slots = comp.get_ui_input_types() ui_output_slots = comp.get_ui_output_types() - from src.core.component import CompositeComponent + stop_event = ( + comp.stop_event + if isinstance(comp, ThreadedComponent) + else threading.Event() + ) + + # Create fresh handles from plan, store on node. + # Overrides take priority over the plan. + for slot in input_slots: + if (node_id, slot) in _recv_over: + node.receivers[slot] = _recv_over[(node_id, slot)] + elif (node_id, slot) in receiver_plan: + node.receivers[slot] = Receiver( + receiver_plan[(node_id, slot)], stop_event + ) + else: + node.receivers[slot] = None - is_composite = isinstance(comp, CompositeComponent) + for slot in output_slots: + if (node_id, slot) in _send_over: + node.senders[slot] = _send_over[(node_id, slot)] or Sender() + elif (node_id, slot) in sender_plan: + node.senders[slot] = Sender(*sender_plan[(node_id, slot)]) + else: + node.senders[slot] = Sender() - input_handles: dict[str, Receiver[Any] | None] = {} - for slot, slot_type in input_slots.items(): - rkey: ReceiverKey = (node_id, slot) - if rkey in self._receiver_handles: - input_handles[slot] = self._receiver_handles[rkey] - elif type(None) in get_args(slot_type): - # Optional inputs with no edge get None so NamedTuple - # construction doesn't fail from a missing field. - input_handles[slot] = None + # UI slots: use overrides if provided, otherwise dummy handles + for slot in ui_input_slots: + if (node_id, slot) in _recv_over: + node.receivers[slot] = _recv_over[(node_id, slot)] + else: + node.receivers[slot] = None - output_handles: dict[str, Sender[Any]] = {} - for slot in output_slots: - skey: SenderKey = (node_id, slot) - if skey in self._sender_handles: - output_handles[slot] = self._sender_handles[skey] + for slot in ui_output_slots: + if (node_id, slot) in _send_over: + node.senders[slot] = _send_over[(node_id, slot)] or Sender() else: - # Unconnected output: no-op Sender (sends are discarded) - output_handles[slot] = Sender() - - # Wire UI input channels (frontend -> component) - for slot, slot_type in ui_input_slots.items(): - ch: Channel[Any] = Channel() - self._ui_channels[(node_id, slot)] = ch - # Component gets a UIReceiver to read from - origin = get_origin(slot_type) or slot_type - input_handles[slot] = origin(ch) - # Server keeps a Sender to push data in - self._ui_senders[(node_id, slot)] = Sender(ch) - - # Wire UI output channels (component -> frontend) - for slot, slot_type in ui_output_slots.items(): - ch = Channel() - self._ui_channels[(node_id, slot)] = ch - # Component gets a UISender to write to - origin = get_origin(slot_type) or slot_type - output_handles[slot] = origin(ch) - # Server keeps a Receiver to read from - self._ui_receivers[(node_id, slot)] = Receiver(ch) - - # Wire all receivers (registers cursors before setup) - if isinstance(comp, ThreadedComponent): - for handle in input_handles.values(): - if isinstance(handle, Receiver): - handle._wire(comp.stop_event) + node.senders[slot] = Sender() - if is_composite: - built_inputs: Any = input_handles - built_outputs: Any = output_handles - else: - built_inputs = self._build_tuple(input_type, input_handles) - built_outputs = self._build_tuple(output_type, output_handles) + built_inputs = self._build_tuple(input_type, dict(node.receivers)) + built_outputs = self._build_tuple(output_type, dict(node.senders)) start_queue.append((node_id, comp, built_inputs, built_outputs)) @@ -400,7 +383,6 @@ def run(self) -> None: if isinstance(comp, ThreadedComponent): ident = comp.get_ident() if ident is None: - # Thread ident may lag briefly after start(). for _ in range(10): time.sleep(0.005) ident = comp.get_ident() @@ -409,15 +391,6 @@ def run(self) -> None: if ident is not None: get_log_store().register_thread(node_id=node_id, ident=ident) - # Notify WS listeners that UI channels are ready. - # Save old event, replace with fresh one, *then* wake waiters. - # This avoids a race where a coroutine calling wait() between - # set() and reassignment would block on the now-orphaned event. - self._ui_version += 1 - old_event = self._ui_changed - self._ui_changed = asyncio.Event() - old_event.set() - @staticmethod def _build_tuple(tp: type | None, handles: dict[str, Any]) -> tuple[Any, ...]: """Build a NamedTuple (keyword) or plain tuple (positional) from handles.""" @@ -425,14 +398,15 @@ def _build_tuple(tp: type | None, handles: dict[str, Any]) -> tuple[Any, ...]: return () if hasattr(tp, "_fields"): return tp(**handles) - return tuple(handles[k] for k in sorted(handles.keys())) + # Composite or plain tuple: preserve insertion order of handles dict + return tuple(handles.values()) def stop(self) -> None: """Stop all components and await their threads.""" - for sender in self._sender_handles.values(): - sender._stopped = True - for sender in self._ui_senders.values(): - sender._stopped = True + for node in self._graph.nodes.values(): + for sender in node.senders.values(): + if sender is not None: + sender._stopped = True for comp in self._components.values(): comp.stop() for comp in self._components.values(): diff --git a/backend/src/core/utils.py b/backend/src/core/utils.py index b12aebe3..8a3876ea 100644 --- a/backend/src/core/utils.py +++ b/backend/src/core/utils.py @@ -13,6 +13,9 @@ import numpy as np +SenderKey = tuple[str, str] # (node_id, slot_name) +ReceiverKey = tuple[str, str] # (node_id, slot_name) + _COUNTS: collections.defaultdict[str, itertools.count[int]] = collections.defaultdict( itertools.count ) diff --git a/backend/src/lib/misc/do_nothing.py b/backend/src/lib/misc/do_nothing.py index 1e57fa36..5af795e7 100644 --- a/backend/src/lib/misc/do_nothing.py +++ b/backend/src/lib/misc/do_nothing.py @@ -15,5 +15,6 @@ class DoNothing[T](ThreadedComponent[DoNothingInputs[T], tuple[()]]): description = "Consumes input and discards it" def run(self, inputs: DoNothingInputs[T], outputs: tuple[()]) -> None: - for _ in inputs.input: - pass + for item in inputs.input: + if item is None: + break diff --git a/backend/src/lib/misc/throttle.py b/backend/src/lib/misc/throttle.py index 14514c93..8c1326a9 100644 --- a/backend/src/lib/misc/throttle.py +++ b/backend/src/lib/misc/throttle.py @@ -2,7 +2,6 @@ from __future__ import annotations -import time from typing import NamedTuple from pydantic import BaseModel @@ -40,4 +39,5 @@ def run(self, inputs: ThrottleInputs[T], outputs: ThrottleOutputs[T]) -> None: if item is None: break outputs.data.send(item) - time.sleep(self.config.interval) + if self.stop_event.wait(self.config.interval): + break diff --git a/backend/src/main.py b/backend/src/main.py index 48d6b07c..3ef16f3c 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -16,6 +16,7 @@ from src.api.project.controller import router as project_router from src.api.ui.controller import router as ui_router from src.api.env.controller import router as env_router +from src.api.ui.bridge import UIChannelBridge from src.core.graph import Graph from src.core.config import PROJECTS_DIR, PRESETS_DIR, AppConfig from src.core.graph import GraphManager @@ -47,6 +48,7 @@ async def lifespan(app: FastAPI): app.state.current_project = config.current_project app.state.manager = GraphManager(Graph(edges=[], nodes={})) + app.state.ui_bridge = UIChannelBridge() yield diff --git a/backend/tests/api/test_api_controllers.py b/backend/tests/api/test_api_controllers.py index c7795339..669416af 100644 --- a/backend/tests/api/test_api_controllers.py +++ b/backend/tests/api/test_api_controllers.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import threading import types import pytest @@ -29,7 +30,7 @@ def __init__(self) -> None: node = Node(id_="n1", type="A", init_args={}, x=1, y=2) self.graph = Graph(nodes={"n1": node}, edges=[]) self._sender = Sender(Channel()) - self._receiver = Receiver(Channel()) + self._receiver = Receiver(Channel(), threading.Event()) def component(self, _node_id: str): return types.SimpleNamespace(status=types.SimpleNamespace(value="running")) @@ -126,17 +127,17 @@ def test_node_controller_paths(monkeypatch) -> None: ) monkeypatch.setattr( - node_controller.service, "update_node_init_args", lambda *a, **k: None + node_controller.service, "update_primitive_node_init_args", lambda *a, **k: None ) with pytest.raises(HTTPException): - node_controller.update_node_init_args( + node_controller.update_primitive_node_init_args( "n1", NodeInitArgsUpdateRequest(init_args={}), manager ) monkeypatch.setattr( - node_controller.service, "update_node_init_args", lambda *a, **k: node + node_controller.service, "update_primitive_node_init_args", lambda *a, **k: node ) assert ( - node_controller.update_node_init_args( + node_controller.update_primitive_node_init_args( "n1", NodeInitArgsUpdateRequest(init_args={}), manager ).id == "n1" @@ -195,12 +196,17 @@ def test_run_save_metrics_project_controllers(monkeypatch, tmp_path) -> None: called = {} monkeypatch.setattr( - run_controller.service, "start_all", lambda m: called.setdefault("start", True) + run_controller.service, + "start_all", + lambda m, b: called.setdefault("start", True), ) monkeypatch.setattr( run_controller.service, "stop_all", lambda m: called.setdefault("stop", True) ) - run_controller.start_all(manager) + from src.api.ui.bridge import UIChannelBridge + + bridge = UIChannelBridge() + run_controller.start_all(manager, bridge) run_controller.stop_all(manager) assert called == {"start": True, "stop": True} diff --git a/backend/tests/api/test_api_services.py b/backend/tests/api/test_api_services.py index ca0fb75e..1cd3c9ce 100644 --- a/backend/tests/api/test_api_services.py +++ b/backend/tests/api/test_api_services.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import threading import types import pytest @@ -33,7 +34,7 @@ def __init__(self) -> None: self._reset_called = False channel = Channel() self._sender = Sender(channel) - self._receiver = Receiver(channel) + self._receiver = Receiver(channel, threading.Event()) def get_node(self, node_id: str): return self.graph.nodes.get(node_id) @@ -50,7 +51,7 @@ def add_edge(self, edge: Edge) -> None: def delete_edge(self, edge: Edge) -> None: self.graph.edges.remove(edge) - def run(self) -> None: + def run(self, receiver_overrides=None, sender_overrides=None) -> None: self._run_called = True def stop(self) -> None: @@ -60,9 +61,13 @@ def reset(self, _graph: Graph) -> None: self._reset_called = True def components(self): + _no_ui = lambda: {} # noqa: E731 return { "a": types.SimpleNamespace( - type_="A", status=types.SimpleNamespace(value="running") + type_="A", + status=types.SimpleNamespace(value="running"), + get_ui_input_types=_no_ui, + get_ui_output_types=_no_ui, ) } @@ -198,7 +203,10 @@ def test_edge_run_save_services(tmp_path, monkeypatch) -> None: with pytest.raises(KeyError): edge_service.delete_edge(manager, "a", "out", "b", "in") - run_service.start_all(manager) + from src.api.ui.bridge import UIChannelBridge + + bridge = UIChannelBridge() + run_service.start_all(manager, bridge) run_service.stop_all(manager) assert manager._run_called is True assert manager._stop_called is True @@ -224,9 +232,7 @@ def test_logs_controller(monkeypatch) -> None: def test_metrics_collector_collect() -> None: manager = FakeManager() frame = TextFrame.new(text="hello") - stop_event = types.SimpleNamespace(is_set=lambda: False) manager._receiver.blocking = False - manager._receiver._wire(stop_event) manager._sender.send(frame) next(manager._receiver) @@ -235,7 +241,6 @@ def test_metrics_collector_collect() -> None: second = collector.collect(manager) assert "a" in first.nodes assert second.nodes["a"].senders["out"].msg_count_delta == 0 - manager._receiver._unwire() def test_project_service(monkeypatch, tmp_path) -> None: diff --git a/backend/tests/api/test_node_service.py b/backend/tests/api/test_node_service.py index 4913a457..34ea67e1 100644 --- a/backend/tests/api/test_node_service.py +++ b/backend/tests/api/test_node_service.py @@ -35,10 +35,10 @@ def __init__(self) -> None: self._components = {"n1": _FakeComp(), "n2": _FakeComp()} self.reconciled = 0 - def add_node(self, node_type, init_args): - if node_type != "A": + def add_primitive_node(self, type_, init_args): + if type_ != "A": raise ValueError("bad") - n = Node(id_="n3", type=node_type, init_args=init_args) + n = Node(id_="n3", type=type_, init_args=init_args) self.graph.nodes["n3"] = n self._components["n3"] = _FakeComp() return "n3", n @@ -64,12 +64,12 @@ def delete_node(self, node_id): self.graph.nodes.pop(node_id, None) self._components.pop(node_id, None) - def update_node_init_args(self, node_id, init_args): + def update_primitive_node_init_args(self, node_id, init_args): n = self.graph.nodes.get(node_id) if n is None: - return None + return None, False n.init_args = init_args - return n + return n, False def components(self): return self._components @@ -88,7 +88,11 @@ def test_node_service_basic_crud() -> None: assert updated.x == 9 node_service.delete_node(m, "n2") assert "n2" not in m.graph.nodes - out = node_service.update_node_init_args(m, "n1", {"k": 1}) + from src.api.ui.bridge import UIChannelBridge + + out = node_service.update_primitive_node_init_args( + m, UIChannelBridge(), "n1", {"k": 1} + ) assert out.init_args == {"k": 1} diff --git a/backend/tests/api/test_ui_controller.py b/backend/tests/api/test_ui_controller.py index 0c793d0d..29a93b87 100644 --- a/backend/tests/api/test_ui_controller.py +++ b/backend/tests/api/test_ui_controller.py @@ -2,17 +2,21 @@ import asyncio import json -import struct +import threading import types from fastapi import WebSocketDisconnect from pydantic import BaseModel -from src.api.ui import controller as ui_controller +from src.api.ui.bridge import ( + UIChannelBridge, + encode_binary, + encode_json, + decode_binary, + decode_json, + deserialize_payload, +) from src.core.channel import ( - Channel, - Receiver, - Sender, UIReceiver, UISender, UITextReceiver, @@ -25,15 +29,22 @@ class _Payload(BaseModel): value: int -class _PayloadSender(UISender[_Payload]): +class _PayloadReceiver(UIReceiver[_Payload]): pass -class _PayloadReceiver(UIReceiver[_Payload]): +class _MissingSender(UISender): + pass + + +class _MissingReceiver(UIReceiver): pass class _FakeComponent: + def __init__(self) -> None: + self.stop_event = threading.Event() + def get_ui_output_types(self): return { "video": UIVideoSender, @@ -51,49 +62,17 @@ def get_ui_input_types(self): } -class _MissingSender(UISender): - pass - - -class _MissingReceiver(UIReceiver): - pass - - class _FakeManager: def __init__(self) -> None: self._component = _FakeComponent() - self._ui_version = 0 - self._ui_changed = asyncio.Event() - self._receiver = Receiver(Channel()) - self._ui_receivers = {("node", "video"): self._receiver} - self.sent_payloads: list[object] = [] - self._ui_senders = { - ("node", "payload_in"): types.SimpleNamespace( - send=lambda value: self.sent_payloads.append(value) - ), - ("node", "text_in"): types.SimpleNamespace( - send=lambda value: self.sent_payloads.append(value) - ), - ("node", "raw_in"): types.SimpleNamespace( - send=lambda value: self.sent_payloads.append(value) - ), - } def components(self): return {"node": self._component} - def ui_receivers(self): - return self._ui_receivers - - def ui_senders(self): - return self._ui_senders - class _FakeWebSocket: - def __init__( - self, manager: _FakeManager, messages: list[str] | None = None - ) -> None: - self.app = types.SimpleNamespace(state=types.SimpleNamespace(manager=manager)) + def __init__(self, messages: list[str | bytes] | None = None) -> None: + self.app = types.SimpleNamespace(state=types.SimpleNamespace()) self.client = ("127.0.0.1", 0) self._messages = iter(messages or []) self.accepted = False @@ -103,15 +82,12 @@ def __init__( async def accept(self) -> None: self.accepted = True - async def receive_text(self) -> str: + async def receive(self) -> dict[str, str | bytes]: try: - return next(self._messages) - except StopIteration as exc: - raise WebSocketDisconnect() from exc - - async def receive(self) -> dict[str, str]: - try: - return {"text": next(self._messages)} + msg = next(self._messages) + if isinstance(msg, bytes): + return {"bytes": msg} + return {"text": msg} except StopIteration as exc: raise WebSocketDisconnect() from exc @@ -122,143 +98,99 @@ async def send_bytes(self, payload: bytes) -> None: self.byte_messages.append(payload) -def test_ui_type_resolution_helpers() -> None: - manager = _FakeManager() - manager._component.get_ui_output_types = lambda: { - "video": UIVideoSender, - "payload": UISender[_Payload], - "missing": _MissingSender, - } - manager._component.get_ui_input_types = lambda: { - "payload_in": UIReceiver[_Payload], - "text_in": UITextReceiver, - "missing_args": _MissingReceiver, +def test_wire_format_encode_decode() -> None: + # Binary round-trip + encoded = encode_binary("n1", "video", b"\x00\x01") + key, payload = decode_binary(encoded) + assert key == ("n1", "video") + assert payload == b"\x00\x01" + + # Binary too short + assert decode_binary(b"\x00")[0] is None + + # JSON round-trip + envelope = encode_json("n1", "text", "hello") + assert envelope == { + "type": "ui_output", + "node_id": "n1", + "channel": "text", + "payload": "hello", } - assert ( - ui_controller._resolve_ui_output_type(manager, "missing-node", "video") is None - ) - assert ( - ui_controller._resolve_ui_output_type(manager, "node", "missing-slot") is None - ) - assert ui_controller._resolve_ui_output_type(manager, "node", "video") is bytes - assert ui_controller._resolve_ui_output_type(manager, "node", "payload") is _Payload - assert ui_controller._resolve_ui_output_type(manager, "node", "missing") is None + # JSON with BaseModel + env_model = encode_json("n1", "data", _Payload(value=5)) + assert env_model["payload"] == {"value": 5} - assert ( - ui_controller._resolve_ui_input_type(manager, "missing-node", "payload_in") - is None - ) - assert ui_controller._resolve_ui_input_type(manager, "node", "missing-slot") is None - assert ( - ui_controller._resolve_ui_input_type(manager, "node", "payload_in") is _Payload - ) - assert ui_controller._resolve_ui_input_type(manager, "node", "text_in") is TextFrame - assert ui_controller._resolve_ui_input_type(manager, "node", "missing_args") is None + # JSON with TextFrame + env_text = encode_json("n1", "data", TextFrame.new(text="hi")) + assert env_text["payload"] == "hi" + # decode_json + result = decode_json( + json.dumps( + {"type": "ui_input", "node_id": "n1", "channel": "c", "payload": "x"} + ) + ) + assert result == (("n1", "c"), "x") + assert decode_json(json.dumps({"type": "ignored"})) is None -def test_read_ui_output_variants() -> None: - async def run_case(item: object, inner_type: type | None, failing: bool = False): - stop_event = asyncio.Event() - channel = Channel() - sender = Sender(channel) - receiver = Receiver(channel) - ws = _FakeWebSocket(_FakeManager()) + # deserialize_payload + assert deserialize_payload({"value": 3}, _Payload) == _Payload(value=3) + assert isinstance(deserialize_payload("hi", TextFrame), TextFrame) + assert deserialize_payload(42, None) == 42 - if failing: - async def send_json(_payload: object) -> None: - raise RuntimeError("boom") +def test_type_resolution() -> None: + bridge = UIChannelBridge() + manager = _FakeManager() + bridge._manager = manager - ws.send_json = send_json # type: ignore[method-assign] + assert bridge._resolve_ui_output_type("missing-node", "video") is None + assert bridge._resolve_ui_output_type("node", "missing-slot") is None + assert bridge._resolve_ui_output_type("node", "video") is bytes + assert bridge._resolve_ui_output_type("node", "payload") is _Payload - task = asyncio.create_task( - ui_controller._read_ui_output( - ws, "node", "slot", receiver, inner_type, stop_event - ) - ) - await asyncio.sleep(0) - sender.send(item) - sender.send(None) - await task - return ws, channel - - async def run_all() -> None: - ws_bytes, _ = await run_case(b"\x00\x01", bytes) - header_len = struct.unpack(">H", ws_bytes.byte_messages[0][:2])[0] - header = json.loads(ws_bytes.byte_messages[0][2 : 2 + header_len].decode()) - assert header == {"type": "ui_output", "node_id": "node", "channel": "slot"} - assert ws_bytes.byte_messages[0][2 + header_len :] == b"\x00\x01" - - ws_model, _ = await run_case(_Payload(value=7), _Payload) - assert ws_model.json_messages[0] == { - "type": "ui_output", - "node_id": "node", - "channel": "slot", - "payload": {"value": 7}, - } + assert bridge._resolve_ui_input_type("missing-node", "payload_in") is None + assert bridge._resolve_ui_input_type("node", "missing-slot") is None + assert bridge._resolve_ui_input_type("node", "payload_in") is _Payload + assert bridge._resolve_ui_input_type("node", "text_in") is TextFrame + assert bridge._resolve_ui_input_type("node", "missing_args") is None - ws_legacy, _ = await run_case(TextFrame.new(text="hello"), TextFrame) - assert ws_legacy.json_messages[0]["payload"] == "hello" - ws_plain, _ = await run_case(5, int) - assert ws_plain.json_messages[0]["payload"] == 5 +def test_bridge_wire() -> None: + bridge = UIChannelBridge() + manager = _FakeManager() + recv_overrides, send_overrides = bridge.wire(manager) - ws_fail, channel = await run_case(9, int, failing=True) - assert ws_fail.json_messages == [] - assert channel._cursors == {} + # UI input slots should have receiver overrides + server senders + assert ("node", "payload_in") in recv_overrides + assert ("node", "text_in") in recv_overrides + assert ("node", "payload_in") in bridge._ui_senders + assert ("node", "text_in") in bridge._ui_senders - asyncio.run(run_all()) + # UI output slots should have sender overrides + server receivers + assert ("node", "video") in send_overrides + assert ("node", "payload") in send_overrides + assert ("node", "video") in bridge._ui_receivers + assert ("node", "payload") in bridge._ui_receivers -def test_watch_ui_channels_and_ui_ws(monkeypatch) -> None: - async def run_watch_ui_channels() -> None: +def test_bridge_recv_msgs() -> None: + async def run() -> None: + bridge = UIChannelBridge() manager = _FakeManager() - ws = _FakeWebSocket(manager) - stop_event = asyncio.Event() - tasks: dict[tuple[str, str], asyncio.Task[None]] = {} - calls: list[tuple[str, str, type | None]] = [] + bridge.wire(manager) - async def fake_read_ui_output( - ws, node_id, slot, receiver, inner_type, stop_event - ): - calls.append((node_id, slot, inner_type)) - await stop_event.wait() + sent: list[object] = [] + for sender in bridge._ui_senders.values(): + original_send = sender.send + sender.send = lambda item, _orig=original_send: ( + sent.append(item), + _orig(item), + ) # type: ignore[method-assign] - monkeypatch.setattr(ui_controller, "_read_ui_output", fake_read_ui_output) - - watcher = asyncio.create_task( - ui_controller._watch_ui_channels(ws, manager, stop_event, tasks) - ) - await asyncio.sleep(0.01) - assert calls == [("node", "video", bytes)] - assert ("node", "video") in tasks - - old_event = manager._ui_changed - manager._ui_version = 1 - manager._ui_changed = asyncio.Event() - old_event.set() - await asyncio.sleep(0.01) - assert calls[-1] == ("node", "video", bytes) - assert ("node", "video") in tasks - - stop_event.set() - watcher.cancel() - await asyncio.gather(watcher, return_exceptions=True) - - async def run_ui_ws() -> None: - manager = _FakeManager() ws = _FakeWebSocket( - manager, messages=[ - json.dumps( - { - "type": "ui_input", - "node_id": "node", - "channel": "payload_in", - "payload": {"value": 5}, - } - ), json.dumps( { "type": "ui_input", @@ -271,52 +203,41 @@ async def run_ui_ws() -> None: { "type": "ui_input", "node_id": "node", - "channel": "raw_in", - "payload": 3, - } - ), - json.dumps( - { - "type": "ui_input", - "node_id": "node", - "channel": "missing", - "payload": "skip", + "channel": "payload_in", + "payload": {"value": 5}, } ), json.dumps({"type": "ignored"}), ], ) - cancelled = {"watcher": False, "task": False} - original_receive = ws.receive - - async def delayed_receive() -> dict[str, str]: - await asyncio.sleep(0) - return await original_receive() - - ws.receive = delayed_receive # type: ignore[method-assign] + bridge._ws = ws + # _recv_msgs raises WebSocketDisconnect when messages run out + try: + await bridge._recv_msgs(ws) # type: ignore[arg-type] + except WebSocketDisconnect: + pass - async def fake_watch_ui_channels(ws, manager, stop_event, tasks): - class _Task: - def cancel(self) -> None: - cancelled["task"] = True + assert any(isinstance(s, TextFrame) and s.text == "hello" for s in sent) + assert any(isinstance(s, _Payload) and s.value == 5 for s in sent) - def __await__(self): - return stop_event.wait().__await__() + asyncio.run(run()) - tasks[("node", "payload")] = _Task() - await stop_event.wait() - monkeypatch.setattr(ui_controller, "_watch_ui_channels", fake_watch_ui_channels) +def test_encode_output_formats() -> None: + # bytes → encode_binary + binary = encode_binary("n1", "video", b"\xff") + key, payload = decode_binary(binary) + assert key == ("n1", "video") + assert payload == b"\xff" - await ui_controller.ui_ws(ws) + # BaseModel → encode_json + env = encode_json("n1", "data", _Payload(value=9)) + assert env["payload"] == {"value": 9} - assert ws.accepted is True - assert isinstance(manager.sent_payloads[0], _Payload) - assert manager.sent_payloads[0].value == 5 - assert isinstance(manager.sent_payloads[1], TextFrame) - assert manager.sent_payloads[1].get() == "hello" - assert manager.sent_payloads[2] == 3 - assert cancelled["task"] is True + # TextFrame → encode_json + env2 = encode_json("n1", "text", TextFrame.new(text="hi")) + assert env2["payload"] == "hi" - asyncio.run(run_watch_ui_channels()) - asyncio.run(run_ui_ws()) + # plain → encode_json + env3 = encode_json("n1", "num", 42) + assert env3["payload"] == 42 diff --git a/backend/tests/conduit/test_think_tool_throttle.py b/backend/tests/conduit/test_think_tool_throttle.py index be8264aa..ee53d216 100644 --- a/backend/tests/conduit/test_think_tool_throttle.py +++ b/backend/tests/conduit/test_think_tool_throttle.py @@ -65,14 +65,12 @@ def test_think_tool_setup_and_run_paths(capsys) -> None: assert all(result.content == "" for result in tool_results) -def test_throttle_forwards_newest_items(monkeypatch) -> None: - sleeps = [] +def test_throttle_forwards_newest_items() -> None: sent = [] + waits = [] throttle = Throttle[int](ThrottleConfig(interval=0.25)) - - monkeypatch.setattr( - "src.lib.misc.throttle.time.sleep", lambda value: sleeps.append(value) - ) + # Mock stop_event.wait to record calls and return False (not stopped) + throttle._stop_event.wait = lambda timeout: (waits.append(timeout), False)[1] # type: ignore[assignment] recv = _FakeRecv([1, 2, None]) throttle.run( @@ -84,4 +82,4 @@ def test_throttle_forwards_newest_items(monkeypatch) -> None: assert recv.newest is True assert sent == [1, 2] - assert sleeps == [0.25, 0.25] + assert waits == [0.25, 0.25] diff --git a/backend/tests/core/test_channel.py b/backend/tests/core/test_channel.py index 56110c08..1f44ba13 100644 --- a/backend/tests/core/test_channel.py +++ b/backend/tests/core/test_channel.py @@ -9,8 +9,7 @@ def test_channel_send_receive_and_gc() -> None: channel = Channel[int]() sender = Sender(channel) - receiver = Receiver(channel) - receiver._wire(threading.Event()) + receiver = Receiver(channel, threading.Event()) sender.send(1) sender.send(2) @@ -18,16 +17,14 @@ def test_channel_send_receive_and_gc() -> None: assert next(receiver) == 2 assert receiver.lag == 0 assert sender.buffer_depth == 0 - receiver._unwire() def test_channel_non_blocking_and_fast_forward() -> None: channel = Channel[int]() - receiver = Receiver(channel) + receiver = Receiver(channel, threading.Event()) sender = Sender(channel) receiver.newest = True receiver.blocking = False - receiver._wire(threading.Event()) sender.send(10) sender.send(20) @@ -35,7 +32,6 @@ def test_channel_non_blocking_and_fast_forward() -> None: assert next(receiver) == 30 assert next(receiver) is None - receiver._unwire() def test_channel_wait_stop_and_unregister_idempotent() -> None: @@ -50,22 +46,18 @@ def test_channel_wait_stop_and_unregister_idempotent() -> None: def test_receiver_iterator_stop_event_finishes() -> None: channel = Channel[str]() - receiver = Receiver(channel) stop_event = threading.Event() - receiver._wire(stop_event) + receiver = Receiver(channel, stop_event) assert iter(receiver) is receiver stop_event.set() assert next(receiver) is None - receiver._unwire() def test_sender_metrics_with_multiple_channels() -> None: c1 = Channel[TextFrame]() c2 = Channel[TextFrame]() - r1 = Receiver(c1) - r2 = Receiver(c2) - r1._wire(threading.Event()) - r2._wire(threading.Event()) + _r1 = Receiver(c1, threading.Event()) # noqa: F841 + _r2 = Receiver(c2, threading.Event()) # noqa: F841 sender = Sender(c1, c2) frame = TextFrame.new(text="x") sender.send(frame) @@ -73,16 +65,10 @@ def test_sender_metrics_with_multiple_channels() -> None: assert sender._byte_count > 0 assert sender._last_send_time > 0 assert sender.buffer_depth == 2 - r1._unwire() - r2._unwire() def test_receiver_lag_without_subscriber_or_cursor() -> None: channel = Channel[int]() - receiver = Receiver(channel) - assert receiver.lag == 0 - receiver._wire(threading.Event()) + receiver = Receiver(channel, threading.Event()) Sender(channel).send(5) assert receiver.lag == 1 - receiver._unwire() - assert receiver.lag == 0 diff --git a/backend/tests/core/test_component_graph.py b/backend/tests/core/test_component_graph.py index e43be576..2c9e4808 100644 --- a/backend/tests/core/test_component_graph.py +++ b/backend/tests/core/test_component_graph.py @@ -116,8 +116,8 @@ def test_composite_component_and_graph_manager(monkeypatch) -> None: assert gm.get_node_output("n1") assert gm.get_node_input("n2") gm.run() - assert gm.ui_senders() == {} - assert gm.ui_receivers() == {} + assert gm.ui_input_slots() == [] + assert gm.ui_output_slots() == [] gm.stop() gm.add_edge( @@ -144,8 +144,8 @@ def test_composite_component_start_stop(monkeypatch) -> None: n1 = Node(id_="a", type="DemoThread", init_args={}) sub = Graph(nodes={"a": n1}, edges=[]) comp = CompositeComponent("C", sub) - assert "a.1" in comp.get_input_types() - assert "a.1" in comp.get_output_types() + assert "1" in comp.get_input_types() + assert "1" in comp.get_output_types() comp.start((), ()) assert comp.status == Status.RUNNING comp.stop() diff --git a/backend/tests/core/test_frames_graph.py b/backend/tests/core/test_frames_graph.py index 6ec1f895..6422e012 100644 --- a/backend/tests/core/test_frames_graph.py +++ b/backend/tests/core/test_frames_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import threading import types from pathlib import Path from typing import NamedTuple @@ -185,19 +186,18 @@ def test_composite_component_additional_paths(monkeypatch) -> None: ) comp = CompositeComponent("Wrap", sub_graph) assert comp.type_ == "Wrap" - assert comp.get_input_types()["n1.maybe"] == Receiver[TextFrame] | None - assert "n1.out" not in comp.get_output_types() + assert comp.get_input_types()["maybe"] == Receiver[TextFrame] | None + assert "out" not in comp.get_output_types() startable = CompositeComponent( "WrapStart", Graph(nodes={"n1": Node(id_="n1", type="Known", init_args={})}, edges=[]), ) - recv = Receiver(Channel()) + recv = Receiver(Channel(), threading.Event()) send = Sender(Channel()) - named_inputs = {"n1.plain": recv, "n1.maybe": recv} - named_outputs = {"n1.out": send} - startable.start(named_inputs, named_outputs) + # Pass as tuple — order matches get_input_types() / get_output_types() + startable.start((recv, recv), (send,)) assert startable.status == Status.RUNNING assert startable._inner_manager is not None assert startable._inner_manager.receiver_handles()[("n1", "plain")] is recv @@ -211,10 +211,10 @@ def test_composite_component_additional_paths(monkeypatch) -> None: ) comp2.start((recv, recv), (send,)) assert comp2._inner_manager is not None - assert ("n1", "plain") not in comp2._inner_manager.receiver_handles() - assert ("n1", "maybe") not in comp2._inner_manager.receiver_handles() + # After run(), inner manager creates fresh handles — outer ones are on the node + assert ("n1", "plain") in comp2._inner_manager.receiver_handles() + assert ("n1", "maybe") in comp2._inner_manager.receiver_handles() assert ("n1", "out") in comp2._inner_manager.sender_handles() - assert comp2._inner_manager.sender_handles()[("n1", "out")] is not send comp2._status = Status.RUNNING current_manager = comp2._inner_manager @@ -241,9 +241,9 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: gm = GraphManager(Graph(nodes={}, edges=[])) with pytest.raises(ValueError): - gm.add_node("Missing", {}) + gm.add_primitive_node("Missing", {}) - added_id, _ = gm.add_node("OutputOnly", {}) + added_id, _ = gm.add_primitive_node("OutputOnly", {}) assert gm.component(added_id).type_ == "_OutputOnlyComp" running_node = Node(id_="running", type="Known", init_args={}) @@ -254,10 +254,11 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: calls: list[str] = [] gm2.stop = lambda: calls.append("stop") # type: ignore[method-assign] gm2.run = lambda: calls.append("run") # type: ignore[method-assign] - updated = gm2.update_node_init_args("running", {}) + updated, was_running = gm2.update_primitive_node_init_args("running", {}) assert updated is running_node - assert calls == ["stop", "run"] - assert gm2.update_node_init_args("missing", {}) is None + assert was_running is True + assert calls == ["stop"] + assert gm2.update_primitive_node_init_args("missing", {}) == (None, False) gm3 = GraphManager( Graph( @@ -270,7 +271,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: "registered_subclasses", classmethod(lambda cls: {"Known": _CompositeInner}), ) - assert gm3.update_node_init_args("running", {}) is None + assert gm3.update_primitive_node_init_args("running", {}) == (None, False) nodes = { "a": Node(id_="a", type="Known", init_args={}), @@ -288,7 +289,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: assert "a" not in gm4.graph.nodes assert gm4.graph.edges == [] assert gm4.components()["b"].stop_event.is_set() is True - assert gm4.components()["c"].stop_event.is_set() is True + assert gm4.components()["c"].stop_event.is_set() is False # upstream, not stopped composite_node = Node( id_="wrap", @@ -301,19 +302,23 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: gm5 = GraphManager(Graph(nodes={"wrap": composite_node}, edges=[])) assert isinstance(gm5.component("wrap"), CompositeComponent) + from src.api.ui.bridge import UIChannelBridge + ui_node = Node(id_="ui", type="Known", init_args={}) gm6 = GraphManager(Graph(nodes={"ui": ui_node}, edges=[])) - gm6._sender_handles.clear() - gm6.run() - assert ("ui", "ui_text") in gm6.ui_senders() - assert ("ui", "ui_text") in gm6.ui_receivers() + gm6._channel_map.clear() + bridge = UIChannelBridge() + recv_over, send_over = bridge.wire(gm6) + gm6.run(recv_over, send_over) + assert ("ui", "ui_text") in bridge._ui_senders + assert ("ui", "ui_text") in bridge._ui_receivers inner = gm6.component("ui") assert isinstance(inner, _CompositeInner) assert inner.started_inputs is not None and inner.started_inputs.maybe is None assert registered and registered[0]["node_id"] == "ui" assert GraphManager._build_tuple(None, {"x": 1}) == () - tuple_out = GraphManager._build_tuple(tuple[int, str], {"2": "b", "1": 1}) + tuple_out = GraphManager._build_tuple(tuple[int, str], {"1": 1, "2": "b"}) assert tuple_out == (1, "b") diff --git a/backend/tests/test_channel_component_graph_more.py b/backend/tests/test_channel_component_graph_more.py index ddb4ee8c..1f418d6a 100644 --- a/backend/tests/test_channel_component_graph_more.py +++ b/backend/tests/test_channel_component_graph_more.py @@ -53,14 +53,12 @@ def test_channel_remaining_paths(monkeypatch) -> None: assert channel._items == [] stop_event = threading.Event() - cursor = Receiver(channel) - cursor._wire(stop_event) + cursor = Receiver(channel, stop_event) cursor.blocking = False assert next(cursor) is None - newest = Receiver(channel) + newest = Receiver(channel, stop_event) newest.newest = True - newest._wire(stop_event) newest.newest = False newest.newest = False assert newest._sub_id in channel._cursors @@ -68,22 +66,20 @@ def test_channel_remaining_paths(monkeypatch) -> None: channel._items = [10] channel._offset = 1 - channel._cursors[cursor._sub_id] = 0 # type: ignore[index] - assert channel._get(cursor._sub_id, stop_event, blocking=False) is None # type: ignore[arg-type] + channel._cursors[cursor._sub_id] = 0 + assert channel._get(cursor._sub_id, stop_event, blocking=False) is None waiting_channel = Channel[int]() - newest_blocking = Receiver(waiting_channel) + newest_blocking = Receiver(waiting_channel, stop_event) newest_blocking.newest = True - newest_blocking._wire(stop_event) monkeypatch.setattr( waiting_channel._condition, "wait", lambda timeout=None: stop_event.set() ) assert next(newest_blocking) is None continue_channel = Channel[int]() - continue_newest = Receiver(continue_channel) + continue_newest = Receiver(continue_channel, threading.Event()) continue_newest.newest = True - continue_newest._wire(threading.Event()) wait_calls: list[float | None] = [] def _wake_with_item(timeout: float | None = None) -> None: @@ -98,11 +94,12 @@ def _wake_with_item(timeout: float | None = None) -> None: sender.send(2) assert sender._msg_count == 1 - bad_receiver = Receiver(channel) - bad_receiver._wired = True - bad_receiver._sub_id = 999 - bad_receiver._unwire = lambda: (_ for _ in ()).throw(RuntimeError("boom")) # type: ignore[method-assign] + # Test __del__ handles exceptions gracefully + bad_receiver = Receiver(channel, threading.Event()) + old_unregister = channel._unregister + channel._unregister = lambda sub_id: (_ for _ in ()).throw(RuntimeError("boom")) # type: ignore[method-assign,assignment] bad_receiver.__del__() + channel._unregister = old_unregister # type: ignore[method-assign] def test_threaded_component_remaining_paths(monkeypatch) -> None: @@ -170,5 +167,5 @@ def test_composite_component_ui_types_and_graph_run_uses_dicts() -> None: manager.run() - assert isinstance(captured["inputs"], dict) - assert isinstance(captured["outputs"], dict) + assert isinstance(captured["inputs"], tuple) + assert isinstance(captured["outputs"], tuple) diff --git a/backend/tests/test_dart_control.py b/backend/tests/test_dart_control.py index f12acbe7..4817f129 100644 --- a/backend/tests/test_dart_control.py +++ b/backend/tests/test_dart_control.py @@ -110,20 +110,16 @@ def test_manual_start_uses_wired_receivers(monkeypatch) -> None: goal_channel = Channel[GoalFrame]() goal_sender = Sender(goal_channel) - goal_receiver = Receiver(goal_channel) + goal_receiver = Receiver(goal_channel, component.stop_event) instruction_channel = Channel[TextFrame]() instruction_sender = Sender(instruction_channel) - instruction_receiver = Receiver(instruction_channel) + instruction_receiver = Receiver(instruction_channel, component.stop_event) motion_channel = Channel() motion_sender = Sender(motion_channel) - motion_receiver = Receiver(motion_channel) - - goal_receiver._wire(component.stop_event) - instruction_receiver._wire(component.stop_event) motion_stop_event = threading.Event() - motion_receiver._wire(motion_stop_event) + motion_receiver = Receiver(motion_channel, motion_stop_event) component.start( DartControlInputs(goal=goal_receiver, instruction=instruction_receiver), diff --git a/backend/tests/test_graph_manager.py b/backend/tests/test_graph_manager.py index abd81901..e3168da8 100644 --- a/backend/tests/test_graph_manager.py +++ b/backend/tests/test_graph_manager.py @@ -7,9 +7,9 @@ Responsibility | Interface method(s) ------------------------- | ------------------------------------------- - Node lifecycle | add_node, delete_node, get_node, update_node + Node lifecycle | add_primitive_node, delete_node, get_node, update_node Edge lifecycle | add_edge, delete_edge - Hot-reload of init args | update_node_init_args + Hot-reload of init args | update_primitive_node_init_args Channel reconciliation | _reconcile (private), _group (static) Pipeline start / stop | run, stop, reset Handle construction | _build_tuple (static) @@ -104,9 +104,9 @@ class TestAddNode: manager's internal dicts must stay in sync. """ - def test_add_node(self): + def test_add_primitive_node(self): gm = GraphManager(_empty_graph()) - node_id, node = gm.add_node("StubComponent", {}) + node_id, node = gm.add_primitive_node("StubComponent", {}) assert node_id in gm.graph.nodes assert gm.graph.nodes[node_id].type == "StubComponent" @@ -126,8 +126,8 @@ class TestAddAndDeleteEdge: def test_add_and_delete_edge(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) edge = Edge( source_node=id_a, @@ -136,18 +136,23 @@ def test_add_and_delete_edge(self): target_slot="data", ) gm.add_edge(edge) + gm.run() - # After adding: receiver handle exists for the target slot + # After run: receiver handle exists for the target slot assert (id_b, "data") in gm.receiver_handles() # Sender handle for source slot should be wired to at least one channel sender = gm.sender_handles().get((id_a, "data")) assert sender is not None assert len(sender._channels) >= 1 + gm.stop() gm.delete_edge(edge) + gm.run() - # After deleting: receiver handle should be gone - assert (id_b, "data") not in gm.receiver_handles() + # After deleting and re-running: sender has no channels (unconnected) + sender = gm.sender_handles().get((id_a, "data")) + assert sender is not None + assert len(sender._channels) == 0 class TestDeleteNodeRemovesEdges: @@ -161,9 +166,9 @@ class TestDeleteNodeRemovesEdges: def test_delete_node_removes_edges(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) - id_c, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) + id_c, _ = gm.add_primitive_node("StubComponent", {}) gm.add_edge( Edge( @@ -204,9 +209,9 @@ class TestReconcileReusesChannels: def test_reconcile_reuses_channels(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) - id_c, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) + id_c, _ = gm.add_primitive_node("StubComponent", {}) edge_ab = Edge( source_node=id_a, @@ -280,7 +285,7 @@ class TestUpdateNodePosition: def test_update_node_position(self): gm = GraphManager(_empty_graph()) - node_id, _ = gm.add_node("StubComponent", {}) + node_id, _ = gm.add_primitive_node("StubComponent", {}) result = gm.update_node(node_id, x=42.0, y=99.0) @@ -335,8 +340,8 @@ class TestRunAndStopLifecycle: def test_run_and_stop(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) gm.add_edge( Edge( source_node=id_a, @@ -374,10 +379,10 @@ class TestAddNodeUnknownType: partially-initialised node. """ - def test_add_node_unknown_type(self): + def test_add_primitive_node_unknown_type(self): gm = GraphManager(_empty_graph()) with pytest.raises(ValueError, match="Unknown node type"): - gm.add_node("NonExistentComponent", {}) + gm.add_primitive_node("NonExistentComponent", {}) # Graph must remain unchanged after the failed add assert len(gm.graph.nodes) == 0 @@ -393,7 +398,7 @@ class TestDeleteNonexistentNode: def test_delete_nonexistent_node(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) # Should not raise gm.delete_node("does-not-exist") @@ -408,7 +413,7 @@ class TestUpdateNonexistentNode: Why: Same rationale as delete — the API must handle stale references gracefully. Returning None lets callers distinguish "not found" from "updated". - Edge case: update_node and update_node_init_args both return None. + Edge case: update_node and update_primitive_node_init_args both return None. """ def test_update_nonexistent_node_position(self): @@ -418,5 +423,8 @@ def test_update_nonexistent_node_position(self): def test_update_nonexistent_node_init_args(self): gm = GraphManager(_empty_graph()) - result = gm.update_node_init_args("ghost", {"key": "value"}) - assert result is None + node, was_running = gm.update_primitive_node_init_args( + "ghost", {"key": "value"} + ) + assert node is None + assert was_running is False diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c9291d0d..3aca4a23 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -444,25 +444,24 @@ function AppInner({ // Wrap node changes — detect removals and call backend const onNodesChange: OnNodesChange = useCallback( (changes) => { - const removals = changes.filter((c) => c.type === "remove"); onNodesChangeRaw(changes); - for (const r of removals) { + for (const r of changes) { if (r.type === "remove") { if (r.id.startsWith("configuring-")) continue; - setEdges((currentEdges) => { - for (const e of currentEdges) { - if (e.source === r.id || e.target === r.id) { - deleteEdgeFromReactFlow(e); - } - } - return currentEdges.filter( - (e) => e.source !== r.id && e.target !== r.id, - ); - }); apiDeleteNode(r.id) - .then(() => { + .then(() => apiFetchEdges()) + .then((backendEdges) => { + setEdges((current) => + current.filter((e) => + backendEdges.some( + (be) => + be.source_node === e.source && + be.target_node === e.target, + ), + ), + ); runTypeCheck(); triggerSave(); }) @@ -473,28 +472,28 @@ function AppInner({ [onNodesChangeRaw, setEdges, triggerSave, runTypeCheck], ); - // Wrap edge changes — detect removals and call backend + // Wrap edge changes — detect removals and call backend. + // Let ReactFlow update local edge state. Backend edge deletion is + // handled by onEdgesDelete (user-initiated) or delete_node (node removal). const onEdgesChange: OnEdgesChange = useCallback( (changes) => { - const hasRemovals = changes.some((c) => c.type === "remove"); - setEdges((currentEdges) => { - for (const c of changes) { - if (c.type === "remove") { - const edge = currentEdges.find((e) => e.id === c.id); - if (edge) { - deleteEdgeFromReactFlow(edge); - } - } - } - return currentEdges; - }); onEdgesChangeRaw(changes); - if (hasRemovals) { + if (changes.some((c) => c.type === "remove")) { runTypeCheck(); triggerSave(); } }, - [onEdgesChangeRaw, setEdges, triggerSave, runTypeCheck], + [onEdgesChangeRaw, triggerSave, runTypeCheck], + ); + + // User explicitly deletes edges (backspace/delete key on selected edges) + const onEdgesDelete = useCallback( + (edges: Edge[]) => { + for (const edge of edges) { + deleteEdgeFromReactFlow(edge); + } + }, + [], ); // Handle new edge connections @@ -730,6 +729,7 @@ function AppInner({ edges={edges} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} + onEdgesDelete={onEdgesDelete} onConnect={onConnect} onDrop={onDrop} onDragOver={onDragOver} diff --git a/frontend/src/components/graph/GraphCanvas.tsx b/frontend/src/components/graph/GraphCanvas.tsx index b10f5201..7b15b5b4 100644 --- a/frontend/src/components/graph/GraphCanvas.tsx +++ b/frontend/src/components/graph/GraphCanvas.tsx @@ -38,6 +38,7 @@ interface GraphCanvasProps { onDragOver?: (event: React.DragEvent) => void; onNodeDragStop?: OnNodeDrag; onSelectionChange?: OnSelectionChangeFunc; + onEdgesDelete?: (edges: Edge[]) => void; onNodeContextMenu?: (event: React.MouseEvent, node: Node) => void; onPaneClick?: () => void; } @@ -51,6 +52,7 @@ export function GraphCanvas({ onDrop, onDragOver, onNodeDragStop, + onEdgesDelete, onSelectionChange, onNodeContextMenu, onPaneClick, @@ -100,6 +102,7 @@ export function GraphCanvas({ onDrop={onDrop} onDragOver={onDragOver} onNodeDragStop={onNodeDragStop} + onEdgesDelete={onEdgesDelete} onSelectionChange={onSelectionChange} onNodeContextMenu={onNodeContextMenu} onPaneClick={onPaneClick} diff --git a/frontend/src/components/graph/NodeSidebar.tsx b/frontend/src/components/graph/NodeSidebar.tsx index f0c2ee90..7d5b4f54 100644 --- a/frontend/src/components/graph/NodeSidebar.tsx +++ b/frontend/src/components/graph/NodeSidebar.tsx @@ -1,7 +1,30 @@ import { useState, useRef } from "react"; -import { Mic, AudioLines, MessageSquareText, Brain, Volume2, Radio, Speaker, Video, Monitor, Play, Camera, Puzzle, FolderOpen, ChevronDown, ChevronRight, Search, X } from "lucide-react"; +import { + Mic, + AudioLines, + MessageSquareText, + Brain, + Volume2, + Radio, + Speaker, + Video, + Monitor, + Play, + Camera, + Puzzle, + FolderOpen, + ChevronDown, + ChevronRight, + Search, + X, +} from "lucide-react"; import { cn } from "@/lib/utils"; -import type { ComponentInfo, IOTag, FunctionalityTag, GPUTag } from "@/lib/types"; +import type { + ComponentInfo, + IOTag, + FunctionalityTag, + GPUTag, +} from "@/lib/types"; const iconMap: Record> = { Mic, @@ -17,6 +40,24 @@ const iconMap: Record> = { VideoStream: Monitor, }; +const funcLabels: Record = { + audio: "Audio", + vision: "Vision", + llm: "LLM", + motion: "Motion", + misc: "Misc", + other: "Other", +}; + +const funcOrder: FunctionalityTag[] = [ + "audio", + "vision", + "llm", + "motion", + "misc", + "other", +]; + const ioLabels: Record = { source: "Sources", conduit: "Conduits", @@ -25,49 +66,94 @@ const ioLabels: Record = { const ioOrder: IOTag[] = ["source", "conduit", "sink"]; -const funcLabels: Record = { - audio: "Audio", - video: "Video", - llm: "LLM", - image: "Image", - movement: "Movement", - misc: "Misc", - other: "Other", -}; - const ioAccent: Record = { source: "text-source", conduit: "text-conduit", sink: "text-sink", }; -const ioTagColors: Record = { - source: { bg: "bg-source/15", text: "text-source", border: "border-source/30" }, - conduit: { bg: "bg-conduit/15", text: "text-conduit", border: "border-conduit/30" }, +const ioTagColors: Record< + IOTag, + { bg: string; text: string; border: string } +> = { + source: { + bg: "bg-source/15", + text: "text-source", + border: "border-source/30", + }, + conduit: { + bg: "bg-conduit/15", + text: "text-conduit", + border: "border-conduit/30", + }, sink: { bg: "bg-sink/15", text: "text-sink", border: "border-sink/30" }, }; -const funcTagColors: Record = { - audio: { bg: "bg-tag-audio/15", text: "text-tag-audio", border: "border-tag-audio/30" }, - video: { bg: "bg-tag-video/15", text: "text-tag-video", border: "border-tag-video/30" }, - llm: { bg: "bg-tag-llm/15", text: "text-tag-llm", border: "border-tag-llm/30" }, - image: { bg: "bg-tag-image/15", text: "text-tag-image", border: "border-tag-image/30" }, - movement: { bg: "bg-tag-movement/15", text: "text-tag-movement", border: "border-tag-movement/30" }, - misc: { bg: "bg-tag-misc/15", text: "text-tag-misc", border: "border-tag-misc/30" }, - other: { bg: "bg-tag-other/15", text: "text-tag-other", border: "border-tag-other/30" }, +const funcTagColors: Record< + FunctionalityTag, + { bg: string; text: string; border: string } +> = { + audio: { + bg: "bg-tag-audio/15", + text: "text-tag-audio", + border: "border-tag-audio/30", + }, + vision: { + bg: "bg-tag-vision/15", + text: "text-tag-vision", + border: "border-tag-vision/30", + }, + llm: { + bg: "bg-tag-llm/15", + text: "text-tag-llm", + border: "border-tag-llm/30", + }, + motion: { + bg: "bg-tag-motion/15", + text: "text-tag-motion", + border: "border-tag-motion/30", + }, + misc: { + bg: "bg-tag-misc/15", + text: "text-tag-misc", + border: "border-tag-misc/30", + }, + other: { + bg: "bg-tag-other/15", + text: "text-tag-other", + border: "border-tag-other/30", + }, }; -const gpuTagColors: Record = { - nvidia: { bg: "bg-tag-nvidia/15", text: "text-tag-nvidia", border: "border-tag-nvidia/30" }, - apple: { bg: "bg-tag-apple/15", text: "text-tag-apple", border: "border-tag-apple/30" }, - intel: { bg: "bg-tag-intel/15", text: "text-tag-intel", border: "border-tag-intel/30" }, - amd: { bg: "bg-tag-amd/15", text: "text-tag-amd", border: "border-tag-amd/30" }, +const gpuTagColors: Record< + string, + { bg: string; text: string; border: string } +> = { + nvidia: { + bg: "bg-tag-nvidia/15", + text: "text-tag-nvidia", + border: "border-tag-nvidia/30", + }, + apple: { + bg: "bg-tag-apple/15", + text: "text-tag-apple", + border: "border-tag-apple/30", + }, + intel: { + bg: "bg-tag-intel/15", + text: "text-tag-intel", + border: "border-tag-intel/30", + }, + amd: { + bg: "bg-tag-amd/15", + text: "text-tag-amd", + border: "border-tag-amd/30", + }, }; /** Render simple inline markdown: **bold**, *italic*, `code` */ function InlineMarkdown({ text }: { text: string }) { const parts: React.ReactNode[] = []; - // Match **bold**, *italic*, `code` const regex = /(\*\*(.+?)\*\*|\*(.+?)\*|`(.+?)`)/g; let lastIndex = 0; let match: RegExpExecArray | null; @@ -78,11 +164,26 @@ function InlineMarkdown({ text }: { text: string }) { parts.push(text.slice(lastIndex, match.index)); } if (match[2]) { - parts.push({match[2]}); + parts.push( + + {match[2]} + , + ); } else if (match[3]) { - parts.push({match[3]}); + parts.push( + + {match[3]} + , + ); } else if (match[4]) { - parts.push({match[4]}); + parts.push( + + {match[4]} + , + ); } lastIndex = match.index + match[0].length; } @@ -97,12 +198,15 @@ interface NodeSidebarProps { currentProject: string; } -/** Group components by IO → Functionality. Components with no tags go into `untagged`. */ +/** Group components by Functionality → IO. */ function groupComponents(components: ComponentInfo[]) { - const groups: Record> = { - source: {} as Record, - conduit: {} as Record, - sink: {} as Record, + const groups: Record> = { + audio: {} as Record, + vision: {} as Record, + llm: {} as Record, + motion: {} as Record, + misc: {} as Record, + other: {} as Record, }; const untagged: ComponentInfo[] = []; @@ -115,10 +219,10 @@ function groupComponents(components: ComponentInfo[]) { continue; } - for (const io of ioTags) { - for (const func of funcTags) { - if (!groups[io][func]) groups[io][func] = []; - groups[io][func].push(comp); + for (const func of funcTags) { + for (const io of ioTags) { + if (!groups[func][io]) groups[func][io] = []; + groups[func][io].push(comp); } } } @@ -126,7 +230,15 @@ function groupComponents(components: ComponentInfo[]) { return { groups, untagged }; } -function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: React.RefObject; y: number }) { +function InfoPanel({ + item, + sidebarRef, + y, +}: { + item: ComponentInfo; + sidebarRef: React.RefObject; + y: number; +}) { const inputs = Object.entries(item.inputs); const outputs = Object.entries(item.outputs); const rect = sidebarRef.current?.getBoundingClientRect(); @@ -143,42 +255,77 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R )} style={{ left: rect.right + 8, top: Math.max(8, rect.top + y - 40) }} > - {/* Type */}
{item.type_}
- {/* Description */} {item.description && (
)} - {/* Tags */}
{item.tags.io.map((t) => { const colors = ioTagColors[t]; return ( - {t} + + {t} + ); })} {item.tags.functionality.map((t) => { const colors = funcTagColors[t]; + if (!colors) return null; return ( - {t} - ); - })} - {item.tags.gpu.filter((g: GPUTag) => g !== "cpu").map((t: GPUTag) => { - const colors = gpuTagColors[t] ?? { bg: "bg-white/[0.06]", text: "text-white/60", border: "border-white/10" }; - return ( - {t} + + {t} + ); })} + {item.tags.gpu + .filter((g: GPUTag) => g !== "cpu") + .map((t: GPUTag) => { + const colors = gpuTagColors[t] ?? { + bg: "bg-white/[0.06]", + text: "text-white/60", + border: "border-white/10", + }; + return ( + + {t} + + ); + })}
- {/* IO */} {inputs.length > 0 && (
-
Inputs
+
+ Inputs +
{inputs.map(([k, v]) => (
{k}: @@ -189,7 +336,9 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R )} {outputs.length > 0 && (
-
Outputs
+
+ Outputs +
{outputs.map(([k, v]) => (
{k}: @@ -199,12 +348,15 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R
)} - {/* Init params */} {Object.keys(item.init).length > 0 && (
-
Config
+
+ Config +
{Object.keys(item.init).map((k) => ( -
{k}
+
+ {k} +
))}
)} @@ -212,9 +364,15 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R ); } -export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { +export function NodeSidebar({ + components, + currentProject, +}: NodeSidebarProps) { const [collapsed, setCollapsed] = useState>({}); - const [hovered, setHovered] = useState<{ item: ComponentInfo; y: number } | null>(null); + const [hovered, setHovered] = useState<{ + item: ComponentInfo; + y: number; + } | null>(null); const [search, setSearch] = useState(""); const hoverTimeout = useRef>(); const sidebarRef = useRef(null); @@ -241,7 +399,8 @@ export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { function onItemEnter(e: React.MouseEvent, item: ComponentInfo) { clearTimeout(hoverTimeout.current); const sidebarRect = sidebarRef.current?.getBoundingClientRect(); - const y = e.currentTarget.getBoundingClientRect().top - (sidebarRect?.top ?? 0); + const y = + e.currentTarget.getBoundingClientRect().top - (sidebarRect?.top ?? 0); hoverTimeout.current = setTimeout(() => setHovered({ item, y }), 300); } @@ -251,7 +410,9 @@ export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { } const primitives = components.filter((c) => !c.is_composite); - const composites = components.filter((c) => c.is_composite && c.type_ !== currentProject); + const composites = components.filter( + (c) => c.is_composite && c.type_ !== currentProject, + ); const query = search.toLowerCase(); const filtered = query ? primitives.filter((c) => c.type_.toLowerCase().includes(query)) @@ -263,191 +424,215 @@ export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { return ( <> -
-

- Components -

-
- - setSearch(e.target.value)} - placeholder="Search..." - className={cn( - "w-full pl-7 pr-7 py-1.5 text-[11px] text-white/90 placeholder:text-muted-foreground", - "bg-white/[0.06] border border-glass-border rounded-lg", - "outline-none focus:border-white/20 transition-colors", - )} - /> - {search && ( - +
-
- {ioOrder.map((io) => { - const funcGroups = groups[io]; - const funcKeys = Object.keys(funcGroups) as FunctionalityTag[]; - if (funcKeys.length === 0) return null; - - const ioKey = `io:${io}`; - const ioCollapsed = collapsed[ioKey]; + > +

+ Components +

+
+ + setSearch(e.target.value)} + placeholder="Search..." + className={cn( + "w-full pl-7 pr-7 py-1.5 text-[11px] text-white/90 placeholder:text-muted-foreground", + "bg-white/[0.06] border border-glass-border rounded-lg", + "outline-none focus:border-white/20 transition-colors", + )} + /> + {search && ( + + )} +
+
+ {funcOrder.map((func) => { + const ioGroups = groups[func]; + const ioKeys = ioOrder.filter((io) => ioGroups[io]?.length > 0); + if (ioKeys.length === 0) return null; + + const funcKey = `func:${func}`; + const funcCollapsed = collapsed[funcKey]; + + return ( +
+ + + {!funcCollapsed && + ioKeys.map((io) => { + const items = ioGroups[io]!; + const ioKey = `${func}:${io}`; + const ioCollapsed = collapsed[ioKey]; + + return ( +
+ + + {!ioCollapsed && + items.map((item) => { + const Icon = iconMap[item.type_] ?? Puzzle; + return ( +
onDragStart(e, item)} + onMouseEnter={(e) => onItemEnter(e, item)} + onMouseLeave={onItemLeave} + className={cn( + "flex items-center gap-2.5 px-3 py-2 ml-2 cursor-grab", + "transition-all duration-200", + "hover:bg-glass-hover", + )} + > + + + {item.type_} + +
+ ); + })} +
+ ); + })} +
+ ); + })} - return ( -
+ {/* Untagged components */} + {untagged.length > 0 && ( +
- - {!ioCollapsed && funcKeys.map((func) => { - const items = funcGroups[func]!; - const funcKey = `${io}:${func}`; - const funcCollapsed = collapsed[funcKey]; - - return ( -
- - - {!funcCollapsed && items.map((item) => { - const Icon = iconMap[item.type_] ?? Puzzle; - return ( -
onDragStart(e, item)} - onMouseEnter={(e) => onItemEnter(e, item)} - onMouseLeave={onItemLeave} - className={cn( - "flex items-center gap-2.5 px-3 py-2 ml-2 cursor-grab", - "transition-all duration-200", - "hover:bg-glass-hover", - )} - > - - - {item.type_} - -
- ); - })} -
- ); - })} + + + {item.type_} + +
+ ); + })}
- ); - })} - - {/* Untagged components */} - {untagged.length > 0 && ( -
- - {!collapsed["io:untagged"] && untagged.map((item) => { - const Icon = iconMap[item.type_] ?? Puzzle; - return ( -
onDragStart(e, item)} - onMouseEnter={(e) => onItemEnter(e, item)} - onMouseLeave={onItemLeave} - className={cn( - "flex items-center gap-2.5 px-3 py-2 ml-4 cursor-grab", - "transition-all duration-200", - "hover:bg-glass-hover", - )} - > - - - {item.type_} - -
- ); - })} -
- )} + )} - {/* Projects section */} - {filteredComposites.length > 0 && ( -
-
- - {!collapsed["projects"] && filteredComposites.map((item) => ( -
onCompositeDragStart(e, item)} - onMouseEnter={(e) => onItemEnter(e, item)} - onMouseLeave={onItemLeave} - className={cn( - "flex items-center gap-2.5 px-3 py-2 ml-4 cursor-grab", - "transition-all duration-200", - "hover:bg-glass-hover", - )} + {/* Projects section */} + {filteredComposites.length > 0 && ( +
+
+
- ))} -
- )} + {collapsed["projects"] ? ( + + ) : ( + + )} + Projects + + {!collapsed["projects"] && + filteredComposites.map((item) => ( +
onCompositeDragStart(e, item)} + onMouseEnter={(e) => onItemEnter(e, item)} + onMouseLeave={onItemLeave} + className={cn( + "flex items-center gap-2.5 px-3 py-2 ml-4 cursor-grab", + "transition-all duration-200", + "hover:bg-glass-hover", + )} + > + + + {item.type_} + +
+ ))} +
+ )} +
-
- - {/* Hover info panel — rendered outside sidebar to avoid overflow clip */} - {hovered && ( -
clearTimeout(hoverTimeout.current)} - onMouseLeave={onItemLeave} - > - -
- )} - + {/* Hover info panel — rendered outside sidebar to avoid overflow clip */} + {hovered && ( +
clearTimeout(hoverTimeout.current)} + onMouseLeave={onItemLeave} + > + +
+ )} + ); } diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 8aea519d..b8b24589 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -40,6 +40,16 @@ export async function fetchIsSubtype(sub: string, sup: string): Promise return res.json(); } +export async function fetchSubtypePairs(names: string[]): Promise<[string, string][]> { + const res = await fetch(`${API_BASE}/component/subtype-pairs`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(names), + }); + if (!res.ok) throw new Error(`subtype-pairs failed: ${res.status}`); + return res.json(); +} + export async function fetchNodes(): Promise { const res = await fetch(`${API_BASE}/graph/nodes`); if (!res.ok) throw new Error(`Fetch nodes failed: ${res.status}`); diff --git a/frontend/src/lib/typecheck.ts b/frontend/src/lib/typecheck.ts index 4a4358ce..5b9e8901 100644 --- a/frontend/src/lib/typecheck.ts +++ b/frontend/src/lib/typecheck.ts @@ -4,18 +4,18 @@ // https://doi.org/10.1145/3409006 import type { Graph, SlotType } from "./types"; -import { fetchIsSubtype } from "./api"; +import { fetchSubtypePairs } from "./api"; // Module-level subtype cache: "sub:sup" → true for known subtype pairs. -// Populated by warmSubtypeCache() before solving. +// Populated by warmSubtypeCache() with a single batch request. const subtypeSet = new Set(); export async function warmSubtypeCache(concreteNames: Iterable): Promise { const names = [...concreteNames]; - const pairs = names.flatMap((a) => names.filter((b) => a !== b).map((b) => [a, b] as const)); - const results = await Promise.all(pairs.map(async ([a, b]) => [a, b, await fetchIsSubtype(a, b)] as const)); - for (const [a, b, ok] of results) { - if (ok) subtypeSet.add(`${a}:${b}`); + if (names.length === 0) return; + const pairs = await fetchSubtypePairs(names); + for (const [a, b] of pairs) { + subtypeSet.add(`${a}:${b}`); } } diff --git a/frontend/src/styles/globals.css b/frontend/src/styles/globals.css index 50195800..35a72730 100644 --- a/frontend/src/styles/globals.css +++ b/frontend/src/styles/globals.css @@ -42,10 +42,9 @@ --color-handle: var(--handle); --color-handle-border: var(--handle-border); --color-tag-audio: var(--tag-audio); - --color-tag-video: var(--tag-video); + --color-tag-vision: var(--tag-vision); --color-tag-llm: var(--tag-llm); - --color-tag-image: var(--tag-image); - --color-tag-movement: var(--tag-movement); + --color-tag-motion: var(--tag-motion); --color-tag-misc: var(--tag-misc); --color-tag-other: var(--tag-other); --color-tag-nvidia: var(--tag-nvidia); @@ -83,10 +82,9 @@ /* Functionality tag colors */ --tag-audio: #a78bfa; - --tag-video: #fb923c; + --tag-vision: #fb923c; --tag-llm: #22d3ee; - --tag-image: #f472b6; - --tag-movement: #34d399; + --tag-motion: #34d399; --tag-misc: #94a3b8; --tag-other: #6b7280; diff --git a/frontend/tests/App.test.tsx b/frontend/tests/App.test.tsx index 8ad4eb63..59186815 100644 --- a/frontend/tests/App.test.tsx +++ b/frontend/tests/App.test.tsx @@ -159,7 +159,7 @@ vi.mock("@/components/graph/GraphCanvas", () => ({ -