From c57140ea035afa354944d5d0677ab831e6dc7341 Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sat, 4 Apr 2026 13:02:15 -0400 Subject: [PATCH 01/21] refactor: rename add_node/update_node_init_args, fix edge deletion on node removal - Rename add_node to add_primitive_node, node_type to type_ - Rename update_node_init_args to update_primitive_node_init_args - Backend handles edge cleanup on node deletion (not frontend) - Only stop downstream nodes on delete, not upstream - Use onEdgesDelete for user-initiated edge removal only Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/api/graph/node/controller.py | 4 +- backend/src/api/graph/node/service.py | 10 ++-- backend/src/core/graph.py | 14 +++--- backend/tests/api/test_api_controllers.py | 8 ++-- backend/tests/api/test_node_service.py | 10 ++-- backend/tests/core/test_frames_graph.py | 10 ++-- backend/tests/test_graph_manager.py | 40 ++++++++-------- frontend/src/App.tsx | 56 +++++++++++------------ 8 files changed, 75 insertions(+), 77 deletions(-) diff --git a/backend/src/api/graph/node/controller.py b/backend/src/api/graph/node/controller.py index 35df8d24..db45ab49 100644 --- a/backend/src/api/graph/node/controller.py +++ b/backend/src/api/graph/node/controller.py @@ -79,12 +79,12 @@ def update_node( @router.patch("/nodes/{node_id}/init-args") -def update_node_init_args( +def update_primitive_node_init_args( node_id: str, req: NodeInitArgsUpdateRequest, manager: GraphManager = Depends(get_manager), ) -> NodeResponse: - node = service.update_node_init_args(manager, node_id, req.init_args) + node = service.update_primitive_node_init_args(manager, node_id, req.init_args) if node is None: raise HTTPException(status_code=404, detail=f"Node not found: {node_id}") return _node_response(node_id, node, manager) diff --git a/backend/src/api/graph/node/service.py b/backend/src/api/graph/node/service.py index 6da83068..dbe52c2a 100644 --- a/backend/src/api/graph/node/service.py +++ b/backend/src/api/graph/node/service.py @@ -17,14 +17,14 @@ def get_node(manager: GraphManager, node_id: str) -> Node | None: def create_node( - manager: GraphManager, node_type: str, init_args: dict[str, Any] + manager: GraphManager, type_: str, init_args: dict[str, Any] ) -> tuple[str, Node]: try: - return manager.add_node(node_type, init_args) + return manager.add_primitive_node(type_, init_args) except ValueError: pass # Fallback: try loading as a project - return create_from_project(manager, node_type) + return create_from_project(manager, type_) def update_node( @@ -37,12 +37,12 @@ def delete_node(manager: GraphManager, node_id: str) -> None: manager.delete_node(node_id) -def update_node_init_args( +def update_primitive_node_init_args( manager: GraphManager, node_id: str, init_args: dict[str, Any], ) -> Node | None: - return manager.update_node_init_args(node_id, init_args) + return manager.update_primitive_node_init_args(node_id, init_args) def create_subgraph( diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index 535d9e06..5211314f 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -63,13 +63,13 @@ def __init__(self, graph: Graph) -> None: # --- node CRUD --- - def add_node(self, node_type: str, init_args: dict[str, Any]) -> tuple[str, Node]: + def add_primitive_node(self, type_: str, init_args: dict[str, Any]) -> tuple[str, Node]: classes = PrimitiveComponent.registered_subclasses() - cls = classes.get(node_type) + cls = classes.get(type_) if cls is None: - raise ValueError(f"Unknown node type: {node_type}") + raise ValueError(f"Unknown node type: {type_}") comp = cls.from_args(init_args) - node = Node(type=node_type, init_args=init_args) + node = Node(type=type_, init_args=init_args) self._graph.nodes[node.id_] = node self._components[node.id_] = comp return node.id_, node @@ -107,7 +107,7 @@ def update_node(self, node_id: str, x: float, y: float) -> Node | None: node.y = y return node - def update_node_init_args( + def update_primitive_node_init_args( self, node_id: str, init_args: dict[str, Any] ) -> Node | None: node = self._graph.nodes.get(node_id) @@ -142,13 +142,11 @@ def delete_node(self, node_id: str) -> None: if comp is not None: comp.stop() - # Collect connected components that need stopping + # Collect downstream components that need stopping affected: set[str] = set() for edge in self._graph.edges: if edge.source_node == node_id: affected.add(edge.target_node) - if edge.target_node == node_id: - affected.add(edge.source_node) self._graph.edges = [ e diff --git a/backend/tests/api/test_api_controllers.py b/backend/tests/api/test_api_controllers.py index c7795339..12cad08f 100644 --- a/backend/tests/api/test_api_controllers.py +++ b/backend/tests/api/test_api_controllers.py @@ -126,17 +126,17 @@ def test_node_controller_paths(monkeypatch) -> None: ) monkeypatch.setattr( - node_controller.service, "update_node_init_args", lambda *a, **k: None + node_controller.service, "update_primitive_node_init_args", lambda *a, **k: None ) with pytest.raises(HTTPException): - node_controller.update_node_init_args( + node_controller.update_primitive_node_init_args( "n1", NodeInitArgsUpdateRequest(init_args={}), manager ) monkeypatch.setattr( - node_controller.service, "update_node_init_args", lambda *a, **k: node + node_controller.service, "update_primitive_node_init_args", lambda *a, **k: node ) assert ( - node_controller.update_node_init_args( + node_controller.update_primitive_node_init_args( "n1", NodeInitArgsUpdateRequest(init_args={}), manager ).id == "n1" diff --git a/backend/tests/api/test_node_service.py b/backend/tests/api/test_node_service.py index 4913a457..d84c7002 100644 --- a/backend/tests/api/test_node_service.py +++ b/backend/tests/api/test_node_service.py @@ -35,10 +35,10 @@ def __init__(self) -> None: self._components = {"n1": _FakeComp(), "n2": _FakeComp()} self.reconciled = 0 - def add_node(self, node_type, init_args): - if node_type != "A": + def add_primitive_node(self, type_, init_args): + if type_ != "A": raise ValueError("bad") - n = Node(id_="n3", type=node_type, init_args=init_args) + n = Node(id_="n3", type=type_, init_args=init_args) self.graph.nodes["n3"] = n self._components["n3"] = _FakeComp() return "n3", n @@ -64,7 +64,7 @@ def delete_node(self, node_id): self.graph.nodes.pop(node_id, None) self._components.pop(node_id, None) - def update_node_init_args(self, node_id, init_args): + def update_primitive_node_init_args(self, node_id, init_args): n = self.graph.nodes.get(node_id) if n is None: return None @@ -88,7 +88,7 @@ def test_node_service_basic_crud() -> None: assert updated.x == 9 node_service.delete_node(m, "n2") assert "n2" not in m.graph.nodes - out = node_service.update_node_init_args(m, "n1", {"k": 1}) + out = node_service.update_primitive_node_init_args(m, "n1", {"k": 1}) assert out.init_args == {"k": 1} diff --git a/backend/tests/core/test_frames_graph.py b/backend/tests/core/test_frames_graph.py index 6ec1f895..10c26279 100644 --- a/backend/tests/core/test_frames_graph.py +++ b/backend/tests/core/test_frames_graph.py @@ -241,9 +241,9 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: gm = GraphManager(Graph(nodes={}, edges=[])) with pytest.raises(ValueError): - gm.add_node("Missing", {}) + gm.add_primitive_node("Missing", {}) - added_id, _ = gm.add_node("OutputOnly", {}) + added_id, _ = gm.add_primitive_node("OutputOnly", {}) assert gm.component(added_id).type_ == "_OutputOnlyComp" running_node = Node(id_="running", type="Known", init_args={}) @@ -254,10 +254,10 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: calls: list[str] = [] gm2.stop = lambda: calls.append("stop") # type: ignore[method-assign] gm2.run = lambda: calls.append("run") # type: ignore[method-assign] - updated = gm2.update_node_init_args("running", {}) + updated = gm2.update_primitive_node_init_args("running", {}) assert updated is running_node assert calls == ["stop", "run"] - assert gm2.update_node_init_args("missing", {}) is None + assert gm2.update_primitive_node_init_args("missing", {}) is None gm3 = GraphManager( Graph( @@ -270,7 +270,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: "registered_subclasses", classmethod(lambda cls: {"Known": _CompositeInner}), ) - assert gm3.update_node_init_args("running", {}) is None + assert gm3.update_primitive_node_init_args("running", {}) is None nodes = { "a": Node(id_="a", type="Known", init_args={}), diff --git a/backend/tests/test_graph_manager.py b/backend/tests/test_graph_manager.py index abd81901..5bdaf6dd 100644 --- a/backend/tests/test_graph_manager.py +++ b/backend/tests/test_graph_manager.py @@ -7,9 +7,9 @@ Responsibility | Interface method(s) ------------------------- | ------------------------------------------- - Node lifecycle | add_node, delete_node, get_node, update_node + Node lifecycle | add_primitive_node, delete_node, get_node, update_node Edge lifecycle | add_edge, delete_edge - Hot-reload of init args | update_node_init_args + Hot-reload of init args | update_primitive_node_init_args Channel reconciliation | _reconcile (private), _group (static) Pipeline start / stop | run, stop, reset Handle construction | _build_tuple (static) @@ -104,9 +104,9 @@ class TestAddNode: manager's internal dicts must stay in sync. """ - def test_add_node(self): + def test_add_primitive_node(self): gm = GraphManager(_empty_graph()) - node_id, node = gm.add_node("StubComponent", {}) + node_id, node = gm.add_primitive_node("StubComponent", {}) assert node_id in gm.graph.nodes assert gm.graph.nodes[node_id].type == "StubComponent" @@ -126,8 +126,8 @@ class TestAddAndDeleteEdge: def test_add_and_delete_edge(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) edge = Edge( source_node=id_a, @@ -161,9 +161,9 @@ class TestDeleteNodeRemovesEdges: def test_delete_node_removes_edges(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) - id_c, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) + id_c, _ = gm.add_primitive_node("StubComponent", {}) gm.add_edge( Edge( @@ -204,9 +204,9 @@ class TestReconcileReusesChannels: def test_reconcile_reuses_channels(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) - id_c, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) + id_c, _ = gm.add_primitive_node("StubComponent", {}) edge_ab = Edge( source_node=id_a, @@ -280,7 +280,7 @@ class TestUpdateNodePosition: def test_update_node_position(self): gm = GraphManager(_empty_graph()) - node_id, _ = gm.add_node("StubComponent", {}) + node_id, _ = gm.add_primitive_node("StubComponent", {}) result = gm.update_node(node_id, x=42.0, y=99.0) @@ -335,8 +335,8 @@ class TestRunAndStopLifecycle: def test_run_and_stop(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) - id_b, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) + id_b, _ = gm.add_primitive_node("StubComponent", {}) gm.add_edge( Edge( source_node=id_a, @@ -374,10 +374,10 @@ class TestAddNodeUnknownType: partially-initialised node. """ - def test_add_node_unknown_type(self): + def test_add_primitive_node_unknown_type(self): gm = GraphManager(_empty_graph()) with pytest.raises(ValueError, match="Unknown node type"): - gm.add_node("NonExistentComponent", {}) + gm.add_primitive_node("NonExistentComponent", {}) # Graph must remain unchanged after the failed add assert len(gm.graph.nodes) == 0 @@ -393,7 +393,7 @@ class TestDeleteNonexistentNode: def test_delete_nonexistent_node(self): gm = GraphManager(_empty_graph()) - id_a, _ = gm.add_node("StubComponent", {}) + id_a, _ = gm.add_primitive_node("StubComponent", {}) # Should not raise gm.delete_node("does-not-exist") @@ -408,7 +408,7 @@ class TestUpdateNonexistentNode: Why: Same rationale as delete — the API must handle stale references gracefully. Returning None lets callers distinguish "not found" from "updated". - Edge case: update_node and update_node_init_args both return None. + Edge case: update_node and update_primitive_node_init_args both return None. """ def test_update_nonexistent_node_position(self): @@ -418,5 +418,5 @@ def test_update_nonexistent_node_position(self): def test_update_nonexistent_node_init_args(self): gm = GraphManager(_empty_graph()) - result = gm.update_node_init_args("ghost", {"key": "value"}) + result = gm.update_primitive_node_init_args("ghost", {"key": "value"}) assert result is None diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c9291d0d..3aca4a23 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -444,25 +444,24 @@ function AppInner({ // Wrap node changes — detect removals and call backend const onNodesChange: OnNodesChange = useCallback( (changes) => { - const removals = changes.filter((c) => c.type === "remove"); onNodesChangeRaw(changes); - for (const r of removals) { + for (const r of changes) { if (r.type === "remove") { if (r.id.startsWith("configuring-")) continue; - setEdges((currentEdges) => { - for (const e of currentEdges) { - if (e.source === r.id || e.target === r.id) { - deleteEdgeFromReactFlow(e); - } - } - return currentEdges.filter( - (e) => e.source !== r.id && e.target !== r.id, - ); - }); apiDeleteNode(r.id) - .then(() => { + .then(() => apiFetchEdges()) + .then((backendEdges) => { + setEdges((current) => + current.filter((e) => + backendEdges.some( + (be) => + be.source_node === e.source && + be.target_node === e.target, + ), + ), + ); runTypeCheck(); triggerSave(); }) @@ -473,28 +472,28 @@ function AppInner({ [onNodesChangeRaw, setEdges, triggerSave, runTypeCheck], ); - // Wrap edge changes — detect removals and call backend + // Wrap edge changes — detect removals and call backend. + // Let ReactFlow update local edge state. Backend edge deletion is + // handled by onEdgesDelete (user-initiated) or delete_node (node removal). const onEdgesChange: OnEdgesChange = useCallback( (changes) => { - const hasRemovals = changes.some((c) => c.type === "remove"); - setEdges((currentEdges) => { - for (const c of changes) { - if (c.type === "remove") { - const edge = currentEdges.find((e) => e.id === c.id); - if (edge) { - deleteEdgeFromReactFlow(edge); - } - } - } - return currentEdges; - }); onEdgesChangeRaw(changes); - if (hasRemovals) { + if (changes.some((c) => c.type === "remove")) { runTypeCheck(); triggerSave(); } }, - [onEdgesChangeRaw, setEdges, triggerSave, runTypeCheck], + [onEdgesChangeRaw, triggerSave, runTypeCheck], + ); + + // User explicitly deletes edges (backspace/delete key on selected edges) + const onEdgesDelete = useCallback( + (edges: Edge[]) => { + for (const edge of edges) { + deleteEdgeFromReactFlow(edge); + } + }, + [], ); // Handle new edge connections @@ -730,6 +729,7 @@ function AppInner({ edges={edges} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} + onEdgesDelete={onEdgesDelete} onConnect={onConnect} onDrop={onDrop} onDragOver={onDragOver} From 883ee8cfda9e2c6ec4297874eb6ba760f5f294d4 Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sat, 4 Apr 2026 15:12:38 -0400 Subject: [PATCH 02/21] refactor: separate graph topology, channel topology, and wiring Three-layer separation in GraphManager: - Graph topology: nodes/edges CRUD (unchanged) - Channel topology: _reconcile() returns (sender_plan, receiver_plan) as a pure plan without touching live handles - Wiring: run() creates fresh Sender/Receiver from the plan, stores on Node.senders/Node.receivers. stop() clears them. Also: - Receiver takes (channel, stop_event) in __init__, registers immediately - Remove _wire/_unwire from Receiver - Remove _sender_handles, _receiver_handles, _ui_channels from GraphManager - sender_handles()/receiver_handles() rebuilt from nodes on the fly - CompositeComponent typed properly instead of Component[Any, Any] - CompositeComponent boundary naming uses slot name disambiguation - Only stop downstream nodes on delete_node, not upstream Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/core/channel.py | 43 ++----- backend/src/core/component.py | 83 +++++++++--- backend/src/core/graph.py | 162 ++++++++++-------------- backend/tests/core/test_frames_graph.py | 2 +- 4 files changed, 147 insertions(+), 143 deletions(-) diff --git a/backend/src/core/channel.py b/backend/src/core/channel.py index a12474da..d71bb26a 100644 --- a/backend/src/core/channel.py +++ b/backend/src/core/channel.py @@ -116,17 +116,20 @@ def buffer_depth(self) -> int: class Receiver[T]: - """Handle for receiving from a channel. Is itself the iterator.""" + """Handle for receiving from a channel. Is itself the iterator. - def __init__(self, channel: Channel[T]) -> None: + Registers with the channel on construction, unregisters on GC. + """ + + def __init__(self, channel: Channel[T], stop_event: threading.Event) -> None: self._channel = channel + self._stop_event = stop_event + self._sub_id: int = id(self) self._msg_count: int = 0 self._byte_count: int = 0 - self._sub_id: int | None = None - self._stop_event: threading.Event | None = None - self._wired: bool = False self._newest: bool = False self.blocking: bool = True + self._channel._register(self._sub_id) @property def newest(self) -> bool: @@ -137,32 +140,15 @@ def newest(self, value: bool) -> None: if self._newest == value: return self._newest = value - if self._wired: - self._channel._reregister(self._sub_id, newest=value) # type: ignore[arg-type] - - def _wire(self, stop_event: threading.Event) -> None: - """Register with the channel. Called by GraphManager.run().""" - if self._wired: - return - self._sub_id = id(self) - self._stop_event = stop_event - self._channel._register(self._sub_id, newest=self._newest) - self._wired = True - - def _unwire(self) -> None: - """Unregister from the channel. Idempotent.""" - if not self._wired: - return - self._channel._unregister(self._sub_id) # type: ignore[arg-type] - self._wired = False + self._channel._reregister(self._sub_id, newest=value) def __iter__(self) -> Iterator[T | None]: return self def __next__(self) -> T | None: item = self._channel._get( - self._sub_id, # type: ignore[arg-type] - self._stop_event, # type: ignore[arg-type] + self._sub_id, + self._stop_event, blocking=self.blocking, ) if item is not None: @@ -172,18 +158,17 @@ def __next__(self) -> T | None: def __del__(self) -> None: try: - self._unwire() + self._channel._unregister(self._sub_id) except Exception: pass @property def lag(self) -> int: - sub_id = self._sub_id - if sub_id is None or self._newest: + if self._newest: return 0 ch = self._channel with ch._condition: - cursor = ch._cursors.get(sub_id) + cursor = ch._cursors.get(self._sub_id) if cursor is None: return 0 head = ch._offset + len(ch._items) diff --git a/backend/src/core/component.py b/backend/src/core/component.py index 58c1c474..6f0733a2 100644 --- a/backend/src/core/component.py +++ b/backend/src/core/component.py @@ -395,8 +395,13 @@ def emit(self, outputs: E) -> None: ... # --------------------------------------------------------------------------- -class CompositeComponent(Component[Any, Any]): - """A composite morphism: its interface is derived from unmatched ports in the subgraph.""" +class CompositeComponent(Component[tuple[Receiver[Any] | None, ...], tuple[Sender[Any] | None, ...]]): + """A composite component wrapping a subgraph. + + Its interface is derived from unmatched ports in the subgraph. + start() receives tuples whose element order matches + get_input_types() / get_output_types() key order. + """ _registerable = False @@ -429,21 +434,61 @@ def _compute_boundary( connected_outputs.add((edge.source_node, edge.source_slot)) classes = PrimitiveComponent.registered_subclasses() - ext_inputs: dict[str, tuple[str, str]] = {} - ext_outputs: dict[str, tuple[str, str]] = {} + + # Collect unconnected (boundary) slots: (node_id, slot, component_type) + raw_inputs: list[tuple[str, str, str]] = [] + raw_outputs: list[tuple[str, str, str]] = [] for node_id, node in self._sub_graph.nodes.items(): cls = classes.get(node.type) if cls is None: - # Node is a CompositeComponent — nested composites not yet supported continue for slot in cls._class_input_types(): if (node_id, slot) not in connected_inputs: - ext_inputs[f"{node_id}.{slot}"] = (node_id, slot) + raw_inputs.append((node_id, slot, node.type)) for slot in cls._class_output_types(): if (node_id, slot) not in connected_outputs: - ext_outputs[f"{node_id}.{slot}"] = (node_id, slot) + raw_outputs.append((node_id, slot, node.type)) + + ext_inputs = self._disambiguate(raw_inputs) + ext_outputs = self._disambiguate(raw_outputs) return ext_inputs, ext_outputs + @staticmethod + def _disambiguate( + slots: list[tuple[str, str, str]], + ) -> dict[str, tuple[str, str]]: + """Assign unique external names to boundary slots. + + Uses the slot name alone if unique, otherwise prefixes with the + component type. If still ambiguous, appends a numeric suffix. + """ + from collections import Counter + + # Count how many times each slot name appears + slot_counts = Counter(slot for _, slot, _ in slots) + + # For duplicates, try "Type.slot" + names: list[str] = [] + for _, slot, comp_type in slots: + if slot_counts[slot] == 1: + names.append(slot) + else: + names.append(f"{comp_type}.{slot}") + + # If still duplicates, append numbers + final_counts: dict[str, int] = Counter(names) + seen: dict[str, int] = {} + result: dict[str, tuple[str, str]] = {} + for i, name in enumerate(names): + node_id, slot, _ = slots[i] + if final_counts[name] > 1: + idx = seen.get(name, 0) + 1 + seen[name] = idx + result[f"{name}.{idx}"] = (node_id, slot) + else: + result[name] = (node_id, slot) + return result + def get_input_types(self) -> dict[str, type]: classes = PrimitiveComponent.registered_subclasses() result: dict[str, type] = {} @@ -474,7 +519,7 @@ def get_ui_input_types(self) -> dict[str, type]: def get_ui_output_types(self) -> dict[str, type]: return {} - def start(self, inputs: Any, outputs: Any) -> None: + def start(self, inputs: tuple[Receiver[Any] | None, ...], outputs: tuple[Sender[Any] | None, ...]) -> None: if self.status == Status.RUNNING: return self._status = Status.SETUP @@ -483,15 +528,21 @@ def start(self, inputs: Any, outputs: Any) -> None: self._inner_manager = GraphManager(self._sub_graph) - for ext_name, (node_id, slot) in self._ext_inputs.items(): - outer_receiver = inputs.get(ext_name) if isinstance(inputs, dict) else None - if outer_receiver is not None: - self._inner_manager._receiver_handles[(node_id, slot)] = outer_receiver + for ext_name, recv in zip(self._ext_inputs, inputs): + if recv is None: + continue + node_id, slot = self._ext_inputs[ext_name] + inner_node = self._inner_manager.graph.nodes.get(node_id) + if inner_node is not None: + inner_node.receivers[slot] = recv - for ext_name, (node_id, slot) in self._ext_outputs.items(): - outer_sender = outputs.get(ext_name) if isinstance(outputs, dict) else None - if outer_sender is not None: - self._inner_manager._sender_handles[(node_id, slot)] = outer_sender + for ext_name, send in zip(self._ext_outputs, outputs): + if send is None: + continue + node_id, slot = self._ext_outputs[ext_name] + inner_node = self._inner_manager.graph.nodes.get(node_id) + if inner_node is not None: + inner_node.senders[slot] = send self._inner_manager.run() self._status = Status.RUNNING diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index 5211314f..13e4abc4 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -1,10 +1,11 @@ from __future__ import annotations import asyncio +import threading import time import uuid from collections import defaultdict -from typing import Any, get_args, get_origin +from typing import Any, get_origin from pydantic import BaseModel, Field @@ -22,12 +23,16 @@ class Node(BaseModel): + model_config = {"arbitrary_types_allowed": True} + id_: str = Field(default_factory=lambda: str(uuid.uuid4())) type: str init_args: dict[str, Any] x: float = 0.0 y: float = 0.0 sub_graph: Graph | None = None + senders: dict[str, Sender[Any] | None] = Field(default_factory=dict, exclude=True) + receivers: dict[str, Receiver[Any] | None] = Field(default_factory=dict, exclude=True) class Edge(BaseModel): @@ -51,10 +56,7 @@ def __init__(self, graph: Graph) -> None: self._graph = Graph(edges=[], nodes={}) self._components: dict[str, Component[Any, Any]] = {} self._channel_map: dict[frozenset[SenderKey], Channel[Any]] = {} - self._sender_handles: dict[SenderKey, Sender[Any]] = {} - self._receiver_handles: dict[ReceiverKey, Receiver[Any]] = {} # UI channels: keyed by (node_id, slot_name) - self._ui_channels: dict[tuple[str, str], Channel[Any]] = {} self._ui_senders: dict[tuple[str, str], Sender[Any]] = {} self._ui_receivers: dict[tuple[str, str], Receiver[Any]] = {} self._ui_version = 0 @@ -185,10 +187,20 @@ def components(self) -> dict[str, Component[Any, Any]]: return self._components def sender_handles(self) -> dict[SenderKey, Sender[Any]]: - return self._sender_handles + return { + (node_id, slot): sender + for node_id, node in self._graph.nodes.items() + for slot, sender in node.senders.items() + if sender is not None + } def receiver_handles(self) -> dict[ReceiverKey, Receiver[Any]]: - return self._receiver_handles + return { + (node_id, slot): receiver + for node_id, node in self._graph.nodes.items() + for slot, receiver in node.receivers.items() + if receiver is not None + } def ui_senders(self) -> dict[tuple[str, str], Sender[Any]]: """Server-side senders that push data into component UIReceiver slots.""" @@ -210,9 +222,6 @@ def reset(self, graph: Graph) -> None: self._graph = graph self._components.clear() self._channel_map.clear() - self._sender_handles.clear() - self._receiver_handles.clear() - self._ui_channels.clear() self._ui_senders.clear() self._ui_receivers.clear() @@ -246,8 +255,14 @@ def _group( return dict(groups) - def _reconcile(self) -> None: - """Recompute optimal channel layout and diff against existing.""" + def _reconcile( + self, + ) -> tuple[dict[SenderKey, list[Channel[Any]]], dict[ReceiverKey, Channel[Any]]]: + """Recompute channel topology from the current graph edges. + + Returns (sender_plan, receiver_plan) — the wiring blueprint. + Does not create or touch Sender/Receiver handles — that happens in run(). + """ edges: list[tuple[SenderKey, ReceiverKey]] = [ ((e.source_node, e.source_slot), (e.target_node, e.target_slot)) for e in self._graph.edges @@ -257,133 +272,81 @@ def _reconcile(self) -> None: old_keys = set(self._channel_map.keys()) new_keys = set(groups.keys()) - reuse = old_keys & new_keys - create = new_keys - old_keys - + # step 1: create/remove channels new_channel_map: dict[frozenset[SenderKey], Channel[Any]] = {} - for ckey in reuse: + for ckey in old_keys & new_keys: new_channel_map[ckey] = self._channel_map[ckey] - for ckey in create: + for ckey in new_keys - old_keys: new_channel_map[ckey] = Channel() - self._channel_map = new_channel_map - sender_channels: dict[SenderKey, list[Channel[Any]]] = defaultdict(list) + # step 2: build the wiring plan + sender_plan: dict[SenderKey, list[Channel[Any]]] = defaultdict(list) for sender_set, channel in self._channel_map.items(): for sender_key in sender_set: - sender_channels[sender_key].append(channel) + sender_plan[sender_key].append(channel) - new_sender_handles: dict[SenderKey, Sender[Any]] = {} - for skey, channels in sender_channels.items(): - old_sender = self._sender_handles.get(skey) - if old_sender is not None and set(old_sender._channels) == set(channels): - new_sender_handles[skey] = old_sender - else: - new_sender_handles[skey] = Sender(*channels) - - # Ensure every declared output slot has a sender handle, even when no edges - # are connected. This keeps per-output metrics (e.g. last_send_time) visible - # for standalone sources. - for node_id, comp in self._components.items(): - for slot in comp.get_output_types(): - output_key: SenderKey = (node_id, slot) - if output_key in new_sender_handles: - continue - old_sender = self._sender_handles.get(output_key) - if old_sender is not None and len(old_sender._channels) == 0: - new_sender_handles[output_key] = old_sender - else: - new_sender_handles[output_key] = Sender() - self._sender_handles = new_sender_handles - - new_receiver_handles: dict[ReceiverKey, Receiver[Any]] = {} + receiver_plan: dict[ReceiverKey, Channel[Any]] = {} for sender_set, recv_keys in groups.items(): channel = self._channel_map[sender_set] for recv_key in recv_keys: - old_recv: Receiver[Any] | None = self._receiver_handles.get(recv_key) # type: ignore[assignment] - if old_recv is not None and old_recv._channel is channel: - new_receiver_handles[recv_key] = old_recv - else: - new_receiver_handles[recv_key] = Receiver(channel) - self._receiver_handles = new_receiver_handles + receiver_plan[recv_key] = channel + + return dict(sender_plan), receiver_plan def run(self) -> None: - """Stop all running components, then start each with wired handles.""" + """Stop all running components, then start each with fresh handles.""" self.stop() - for sender in self._sender_handles.values(): - sender._stopped = False - self._ui_channels.clear() self._ui_senders.clear() self._ui_receivers.clear() + + sender_plan, receiver_plan = self._reconcile() + start_queue: list[tuple[str, Component[Any, Any], Any, Any]] = [] - for node_id in self._graph.nodes: + for node_id, node in self._graph.nodes.items(): comp = self._components[node_id] cls = type(comp) input_type = cls._get_type_param(0) output_type = cls._get_type_param(1) - # Use instance calls so CompositeComponent overrides take effect input_slots = comp.get_input_types() output_slots = comp.get_output_types() ui_input_slots = comp.get_ui_input_types() ui_output_slots = comp.get_ui_output_types() - from src.core.component import CompositeComponent + stop_event = comp.stop_event if isinstance(comp, ThreadedComponent) else threading.Event() - is_composite = isinstance(comp, CompositeComponent) - - input_handles: dict[str, Receiver[Any] | None] = {} - for slot, slot_type in input_slots.items(): - rkey: ReceiverKey = (node_id, slot) - if rkey in self._receiver_handles: - input_handles[slot] = self._receiver_handles[rkey] - elif type(None) in get_args(slot_type): - # Optional inputs with no edge get None so NamedTuple - # construction doesn't fail from a missing field. - input_handles[slot] = None + # Create fresh handles from plan, store on node. + # Pre-existing handles (e.g. from composite boundary wiring) are kept. + for slot in input_slots: + if (node_id, slot) in receiver_plan: + node.receivers[slot] = Receiver(receiver_plan[(node_id, slot)], stop_event) + else: + node.receivers[slot] = None - output_handles: dict[str, Sender[Any]] = {} for slot in output_slots: - skey: SenderKey = (node_id, slot) - if skey in self._sender_handles: - output_handles[slot] = self._sender_handles[skey] + if (node_id, slot) in sender_plan: + node.senders[slot] = Sender(*sender_plan[(node_id, slot)]) else: - # Unconnected output: no-op Sender (sends are discarded) - output_handles[slot] = Sender() + node.senders[slot] = None # Wire UI input channels (frontend -> component) for slot, slot_type in ui_input_slots.items(): ch: Channel[Any] = Channel() - self._ui_channels[(node_id, slot)] = ch - # Component gets a UIReceiver to read from origin = get_origin(slot_type) or slot_type - input_handles[slot] = origin(ch) - # Server keeps a Sender to push data in + node.receivers[slot] = origin(ch, stop_event) self._ui_senders[(node_id, slot)] = Sender(ch) # Wire UI output channels (component -> frontend) for slot, slot_type in ui_output_slots.items(): ch = Channel() - self._ui_channels[(node_id, slot)] = ch - # Component gets a UISender to write to origin = get_origin(slot_type) or slot_type - output_handles[slot] = origin(ch) - # Server keeps a Receiver to read from - self._ui_receivers[(node_id, slot)] = Receiver(ch) + node.senders[slot] = origin(ch) + self._ui_receivers[(node_id, slot)] = Receiver(ch, threading.Event()) - # Wire all receivers (registers cursors before setup) - if isinstance(comp, ThreadedComponent): - for handle in input_handles.values(): - if isinstance(handle, Receiver): - handle._wire(comp.stop_event) - - if is_composite: - built_inputs: Any = input_handles - built_outputs: Any = output_handles - else: - built_inputs = self._build_tuple(input_type, input_handles) - built_outputs = self._build_tuple(output_type, output_handles) + built_inputs = self._build_tuple(input_type, dict(node.receivers)) + built_outputs = self._build_tuple(output_type, {k: v for k, v in node.senders.items() if v is not None}) start_queue.append((node_id, comp, built_inputs, built_outputs)) @@ -423,12 +386,17 @@ def _build_tuple(tp: type | None, handles: dict[str, Any]) -> tuple[Any, ...]: return () if hasattr(tp, "_fields"): return tp(**handles) - return tuple(handles[k] for k in sorted(handles.keys())) + # Composite or plain tuple: preserve insertion order of handles dict + return tuple(handles.values()) def stop(self) -> None: """Stop all components and await their threads.""" - for sender in self._sender_handles.values(): - sender._stopped = True + for node in self._graph.nodes.values(): + for sender in node.senders.values(): + if sender is not None: + sender._stopped = True + node.senders = {} + node.receivers = {} for sender in self._ui_senders.values(): sender._stopped = True for comp in self._components.values(): diff --git a/backend/tests/core/test_frames_graph.py b/backend/tests/core/test_frames_graph.py index 10c26279..55d22f84 100644 --- a/backend/tests/core/test_frames_graph.py +++ b/backend/tests/core/test_frames_graph.py @@ -288,7 +288,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: assert "a" not in gm4.graph.nodes assert gm4.graph.edges == [] assert gm4.components()["b"].stop_event.is_set() is True - assert gm4.components()["c"].stop_event.is_set() is True + assert gm4.components()["c"].stop_event.is_set() is False # upstream, not stopped composite_node = Node( id_="wrap", From 0d8fb6868753d383f05a66a65e5111b4d4407e2e Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sat, 4 Apr 2026 15:45:01 -0400 Subject: [PATCH 03/21] refactor: three-layer separation of graph topology, channel topology, and wiring - _reconcile() returns (sender_plan, receiver_plan) as a pure plan - run() creates fresh Sender/Receiver from the plan, stores on Node - stop() marks senders as stopped, doesn't clear handles - Receiver registers in __init__, unregisters in __del__ - Remove _wire/_unwire, _sender_handles, _receiver_handles, _ui_channels - run() accepts optional overrides for composite boundary wiring - CompositeComponent uses tuple I/O with proper type bounds - Composite boundary slot naming uses disambiguation (slot, Type.slot, Type.slot.N) Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/core/component.py | 27 ++++++---------- backend/src/core/graph.py | 31 +++++++++++++------ backend/tests/api/test_api_controllers.py | 3 +- backend/tests/api/test_api_services.py | 6 ++-- backend/tests/api/test_ui_controller.py | 5 +-- backend/tests/core/test_channel.py | 26 ++++------------ backend/tests/core/test_component_graph.py | 4 +-- backend/tests/core/test_frames_graph.py | 22 ++++++------- .../test_channel_component_graph_more.py | 29 ++++++++--------- backend/tests/test_dart_control.py | 10 ++---- 10 files changed, 74 insertions(+), 89 deletions(-) diff --git a/backend/src/core/component.py b/backend/src/core/component.py index 6f0733a2..c2b78d76 100644 --- a/backend/src/core/component.py +++ b/backend/src/core/component.py @@ -524,27 +524,20 @@ def start(self, inputs: tuple[Receiver[Any] | None, ...], outputs: tuple[Sender[ return self._status = Status.SETUP - from src.core.graph import GraphManager + from src.core.graph import GraphManager, ReceiverKey, SenderKey self._inner_manager = GraphManager(self._sub_graph) - for ext_name, recv in zip(self._ext_inputs, inputs): - if recv is None: - continue - node_id, slot = self._ext_inputs[ext_name] - inner_node = self._inner_manager.graph.nodes.get(node_id) - if inner_node is not None: - inner_node.receivers[slot] = recv - - for ext_name, send in zip(self._ext_outputs, outputs): - if send is None: - continue - node_id, slot = self._ext_outputs[ext_name] - inner_node = self._inner_manager.graph.nodes.get(node_id) - if inner_node is not None: - inner_node.senders[slot] = send + recv_overrides: dict[ReceiverKey, Receiver[Any] | None] = { + self._ext_inputs[ext_name]: recv + for ext_name, recv in zip(self._ext_inputs, inputs) + } + send_overrides: dict[SenderKey, Sender[Any] | None] = { + self._ext_outputs[ext_name]: send + for ext_name, send in zip(self._ext_outputs, outputs) + } - self._inner_manager.run() + self._inner_manager.run(recv_overrides, send_overrides) self._status = Status.RUNNING def stop(self) -> None: diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index 13e4abc4..e3928357 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -294,13 +294,23 @@ def _reconcile( return dict(sender_plan), receiver_plan - def run(self) -> None: - """Stop all running components, then start each with fresh handles.""" + def run( + self, + receiver_overrides: dict[ReceiverKey, Receiver[Any] | None] | None = None, + sender_overrides: dict[SenderKey, Sender[Any] | None] | None = None, + ) -> None: + """Stop all running components, then start each with fresh handles. + + Optional overrides let callers (e.g. CompositeComponent) inject + pre-built handles for specific slots. + """ self.stop() self._ui_senders.clear() self._ui_receivers.clear() sender_plan, receiver_plan = self._reconcile() + _recv_over = receiver_overrides or {} + _send_over = sender_overrides or {} start_queue: list[tuple[str, Component[Any, Any], Any, Any]] = [] for node_id, node in self._graph.nodes.items(): @@ -318,18 +328,23 @@ def run(self) -> None: stop_event = comp.stop_event if isinstance(comp, ThreadedComponent) else threading.Event() # Create fresh handles from plan, store on node. - # Pre-existing handles (e.g. from composite boundary wiring) are kept. + # Overrides take priority over the plan. for slot in input_slots: - if (node_id, slot) in receiver_plan: + if (node_id, slot) in _recv_over: + node.receivers[slot] = _recv_over[(node_id, slot)] + elif (node_id, slot) in receiver_plan: node.receivers[slot] = Receiver(receiver_plan[(node_id, slot)], stop_event) else: node.receivers[slot] = None for slot in output_slots: - if (node_id, slot) in sender_plan: + if (node_id, slot) in _send_over: + node.senders[slot] = _send_over[(node_id, slot)] or Sender() + elif (node_id, slot) in sender_plan: node.senders[slot] = Sender(*sender_plan[(node_id, slot)]) else: - node.senders[slot] = None + # Unconnected output: no-op sender (sends are discarded) + node.senders[slot] = Sender() # Wire UI input channels (frontend -> component) for slot, slot_type in ui_input_slots.items(): @@ -346,7 +361,7 @@ def run(self) -> None: self._ui_receivers[(node_id, slot)] = Receiver(ch, threading.Event()) built_inputs = self._build_tuple(input_type, dict(node.receivers)) - built_outputs = self._build_tuple(output_type, {k: v for k, v in node.senders.items() if v is not None}) + built_outputs = self._build_tuple(output_type, dict(node.senders)) start_queue.append((node_id, comp, built_inputs, built_outputs)) @@ -395,8 +410,6 @@ def stop(self) -> None: for sender in node.senders.values(): if sender is not None: sender._stopped = True - node.senders = {} - node.receivers = {} for sender in self._ui_senders.values(): sender._stopped = True for comp in self._components.values(): diff --git a/backend/tests/api/test_api_controllers.py b/backend/tests/api/test_api_controllers.py index 12cad08f..4410e558 100644 --- a/backend/tests/api/test_api_controllers.py +++ b/backend/tests/api/test_api_controllers.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import threading import types import pytest @@ -29,7 +30,7 @@ def __init__(self) -> None: node = Node(id_="n1", type="A", init_args={}, x=1, y=2) self.graph = Graph(nodes={"n1": node}, edges=[]) self._sender = Sender(Channel()) - self._receiver = Receiver(Channel()) + self._receiver = Receiver(Channel(), threading.Event()) def component(self, _node_id: str): return types.SimpleNamespace(status=types.SimpleNamespace(value="running")) diff --git a/backend/tests/api/test_api_services.py b/backend/tests/api/test_api_services.py index ca0fb75e..a14e8703 100644 --- a/backend/tests/api/test_api_services.py +++ b/backend/tests/api/test_api_services.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import threading import types import pytest @@ -33,7 +34,7 @@ def __init__(self) -> None: self._reset_called = False channel = Channel() self._sender = Sender(channel) - self._receiver = Receiver(channel) + self._receiver = Receiver(channel, threading.Event()) def get_node(self, node_id: str): return self.graph.nodes.get(node_id) @@ -224,9 +225,7 @@ def test_logs_controller(monkeypatch) -> None: def test_metrics_collector_collect() -> None: manager = FakeManager() frame = TextFrame.new(text="hello") - stop_event = types.SimpleNamespace(is_set=lambda: False) manager._receiver.blocking = False - manager._receiver._wire(stop_event) manager._sender.send(frame) next(manager._receiver) @@ -235,7 +234,6 @@ def test_metrics_collector_collect() -> None: second = collector.collect(manager) assert "a" in first.nodes assert second.nodes["a"].senders["out"].msg_count_delta == 0 - manager._receiver._unwire() def test_project_service(monkeypatch, tmp_path) -> None: diff --git a/backend/tests/api/test_ui_controller.py b/backend/tests/api/test_ui_controller.py index 0c793d0d..90524f84 100644 --- a/backend/tests/api/test_ui_controller.py +++ b/backend/tests/api/test_ui_controller.py @@ -3,6 +3,7 @@ import asyncio import json import struct +import threading import types from fastapi import WebSocketDisconnect @@ -64,7 +65,7 @@ def __init__(self) -> None: self._component = _FakeComponent() self._ui_version = 0 self._ui_changed = asyncio.Event() - self._receiver = Receiver(Channel()) + self._receiver = Receiver(Channel(), threading.Event()) self._ui_receivers = {("node", "video"): self._receiver} self.sent_payloads: list[object] = [] self._ui_senders = { @@ -162,7 +163,7 @@ async def run_case(item: object, inner_type: type | None, failing: bool = False) stop_event = asyncio.Event() channel = Channel() sender = Sender(channel) - receiver = Receiver(channel) + receiver = Receiver(channel, threading.Event()) ws = _FakeWebSocket(_FakeManager()) if failing: diff --git a/backend/tests/core/test_channel.py b/backend/tests/core/test_channel.py index 56110c08..1f44ba13 100644 --- a/backend/tests/core/test_channel.py +++ b/backend/tests/core/test_channel.py @@ -9,8 +9,7 @@ def test_channel_send_receive_and_gc() -> None: channel = Channel[int]() sender = Sender(channel) - receiver = Receiver(channel) - receiver._wire(threading.Event()) + receiver = Receiver(channel, threading.Event()) sender.send(1) sender.send(2) @@ -18,16 +17,14 @@ def test_channel_send_receive_and_gc() -> None: assert next(receiver) == 2 assert receiver.lag == 0 assert sender.buffer_depth == 0 - receiver._unwire() def test_channel_non_blocking_and_fast_forward() -> None: channel = Channel[int]() - receiver = Receiver(channel) + receiver = Receiver(channel, threading.Event()) sender = Sender(channel) receiver.newest = True receiver.blocking = False - receiver._wire(threading.Event()) sender.send(10) sender.send(20) @@ -35,7 +32,6 @@ def test_channel_non_blocking_and_fast_forward() -> None: assert next(receiver) == 30 assert next(receiver) is None - receiver._unwire() def test_channel_wait_stop_and_unregister_idempotent() -> None: @@ -50,22 +46,18 @@ def test_channel_wait_stop_and_unregister_idempotent() -> None: def test_receiver_iterator_stop_event_finishes() -> None: channel = Channel[str]() - receiver = Receiver(channel) stop_event = threading.Event() - receiver._wire(stop_event) + receiver = Receiver(channel, stop_event) assert iter(receiver) is receiver stop_event.set() assert next(receiver) is None - receiver._unwire() def test_sender_metrics_with_multiple_channels() -> None: c1 = Channel[TextFrame]() c2 = Channel[TextFrame]() - r1 = Receiver(c1) - r2 = Receiver(c2) - r1._wire(threading.Event()) - r2._wire(threading.Event()) + _r1 = Receiver(c1, threading.Event()) # noqa: F841 + _r2 = Receiver(c2, threading.Event()) # noqa: F841 sender = Sender(c1, c2) frame = TextFrame.new(text="x") sender.send(frame) @@ -73,16 +65,10 @@ def test_sender_metrics_with_multiple_channels() -> None: assert sender._byte_count > 0 assert sender._last_send_time > 0 assert sender.buffer_depth == 2 - r1._unwire() - r2._unwire() def test_receiver_lag_without_subscriber_or_cursor() -> None: channel = Channel[int]() - receiver = Receiver(channel) - assert receiver.lag == 0 - receiver._wire(threading.Event()) + receiver = Receiver(channel, threading.Event()) Sender(channel).send(5) assert receiver.lag == 1 - receiver._unwire() - assert receiver.lag == 0 diff --git a/backend/tests/core/test_component_graph.py b/backend/tests/core/test_component_graph.py index e43be576..b9bb25ff 100644 --- a/backend/tests/core/test_component_graph.py +++ b/backend/tests/core/test_component_graph.py @@ -144,8 +144,8 @@ def test_composite_component_start_stop(monkeypatch) -> None: n1 = Node(id_="a", type="DemoThread", init_args={}) sub = Graph(nodes={"a": n1}, edges=[]) comp = CompositeComponent("C", sub) - assert "a.1" in comp.get_input_types() - assert "a.1" in comp.get_output_types() + assert "1" in comp.get_input_types() + assert "1" in comp.get_output_types() comp.start((), ()) assert comp.status == Status.RUNNING comp.stop() diff --git a/backend/tests/core/test_frames_graph.py b/backend/tests/core/test_frames_graph.py index 55d22f84..7d9a2b21 100644 --- a/backend/tests/core/test_frames_graph.py +++ b/backend/tests/core/test_frames_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import threading import types from pathlib import Path from typing import NamedTuple @@ -185,19 +186,18 @@ def test_composite_component_additional_paths(monkeypatch) -> None: ) comp = CompositeComponent("Wrap", sub_graph) assert comp.type_ == "Wrap" - assert comp.get_input_types()["n1.maybe"] == Receiver[TextFrame] | None - assert "n1.out" not in comp.get_output_types() + assert comp.get_input_types()["maybe"] == Receiver[TextFrame] | None + assert "out" not in comp.get_output_types() startable = CompositeComponent( "WrapStart", Graph(nodes={"n1": Node(id_="n1", type="Known", init_args={})}, edges=[]), ) - recv = Receiver(Channel()) + recv = Receiver(Channel(), threading.Event()) send = Sender(Channel()) - named_inputs = {"n1.plain": recv, "n1.maybe": recv} - named_outputs = {"n1.out": send} - startable.start(named_inputs, named_outputs) + # Pass as tuple — order matches get_input_types() / get_output_types() + startable.start((recv, recv), (send,)) assert startable.status == Status.RUNNING assert startable._inner_manager is not None assert startable._inner_manager.receiver_handles()[("n1", "plain")] is recv @@ -211,10 +211,10 @@ def test_composite_component_additional_paths(monkeypatch) -> None: ) comp2.start((recv, recv), (send,)) assert comp2._inner_manager is not None - assert ("n1", "plain") not in comp2._inner_manager.receiver_handles() - assert ("n1", "maybe") not in comp2._inner_manager.receiver_handles() + # After run(), inner manager creates fresh handles — outer ones are on the node + assert ("n1", "plain") in comp2._inner_manager.receiver_handles() + assert ("n1", "maybe") in comp2._inner_manager.receiver_handles() assert ("n1", "out") in comp2._inner_manager.sender_handles() - assert comp2._inner_manager.sender_handles()[("n1", "out")] is not send comp2._status = Status.RUNNING current_manager = comp2._inner_manager @@ -303,7 +303,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: ui_node = Node(id_="ui", type="Known", init_args={}) gm6 = GraphManager(Graph(nodes={"ui": ui_node}, edges=[])) - gm6._sender_handles.clear() + gm6._channel_map.clear() gm6.run() assert ("ui", "ui_text") in gm6.ui_senders() assert ("ui", "ui_text") in gm6.ui_receivers() @@ -313,7 +313,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: assert registered and registered[0]["node_id"] == "ui" assert GraphManager._build_tuple(None, {"x": 1}) == () - tuple_out = GraphManager._build_tuple(tuple[int, str], {"2": "b", "1": 1}) + tuple_out = GraphManager._build_tuple(tuple[int, str], {"1": 1, "2": "b"}) assert tuple_out == (1, "b") diff --git a/backend/tests/test_channel_component_graph_more.py b/backend/tests/test_channel_component_graph_more.py index ddb4ee8c..1f418d6a 100644 --- a/backend/tests/test_channel_component_graph_more.py +++ b/backend/tests/test_channel_component_graph_more.py @@ -53,14 +53,12 @@ def test_channel_remaining_paths(monkeypatch) -> None: assert channel._items == [] stop_event = threading.Event() - cursor = Receiver(channel) - cursor._wire(stop_event) + cursor = Receiver(channel, stop_event) cursor.blocking = False assert next(cursor) is None - newest = Receiver(channel) + newest = Receiver(channel, stop_event) newest.newest = True - newest._wire(stop_event) newest.newest = False newest.newest = False assert newest._sub_id in channel._cursors @@ -68,22 +66,20 @@ def test_channel_remaining_paths(monkeypatch) -> None: channel._items = [10] channel._offset = 1 - channel._cursors[cursor._sub_id] = 0 # type: ignore[index] - assert channel._get(cursor._sub_id, stop_event, blocking=False) is None # type: ignore[arg-type] + channel._cursors[cursor._sub_id] = 0 + assert channel._get(cursor._sub_id, stop_event, blocking=False) is None waiting_channel = Channel[int]() - newest_blocking = Receiver(waiting_channel) + newest_blocking = Receiver(waiting_channel, stop_event) newest_blocking.newest = True - newest_blocking._wire(stop_event) monkeypatch.setattr( waiting_channel._condition, "wait", lambda timeout=None: stop_event.set() ) assert next(newest_blocking) is None continue_channel = Channel[int]() - continue_newest = Receiver(continue_channel) + continue_newest = Receiver(continue_channel, threading.Event()) continue_newest.newest = True - continue_newest._wire(threading.Event()) wait_calls: list[float | None] = [] def _wake_with_item(timeout: float | None = None) -> None: @@ -98,11 +94,12 @@ def _wake_with_item(timeout: float | None = None) -> None: sender.send(2) assert sender._msg_count == 1 - bad_receiver = Receiver(channel) - bad_receiver._wired = True - bad_receiver._sub_id = 999 - bad_receiver._unwire = lambda: (_ for _ in ()).throw(RuntimeError("boom")) # type: ignore[method-assign] + # Test __del__ handles exceptions gracefully + bad_receiver = Receiver(channel, threading.Event()) + old_unregister = channel._unregister + channel._unregister = lambda sub_id: (_ for _ in ()).throw(RuntimeError("boom")) # type: ignore[method-assign,assignment] bad_receiver.__del__() + channel._unregister = old_unregister # type: ignore[method-assign] def test_threaded_component_remaining_paths(monkeypatch) -> None: @@ -170,5 +167,5 @@ def test_composite_component_ui_types_and_graph_run_uses_dicts() -> None: manager.run() - assert isinstance(captured["inputs"], dict) - assert isinstance(captured["outputs"], dict) + assert isinstance(captured["inputs"], tuple) + assert isinstance(captured["outputs"], tuple) diff --git a/backend/tests/test_dart_control.py b/backend/tests/test_dart_control.py index f12acbe7..4817f129 100644 --- a/backend/tests/test_dart_control.py +++ b/backend/tests/test_dart_control.py @@ -110,20 +110,16 @@ def test_manual_start_uses_wired_receivers(monkeypatch) -> None: goal_channel = Channel[GoalFrame]() goal_sender = Sender(goal_channel) - goal_receiver = Receiver(goal_channel) + goal_receiver = Receiver(goal_channel, component.stop_event) instruction_channel = Channel[TextFrame]() instruction_sender = Sender(instruction_channel) - instruction_receiver = Receiver(instruction_channel) + instruction_receiver = Receiver(instruction_channel, component.stop_event) motion_channel = Channel() motion_sender = Sender(motion_channel) - motion_receiver = Receiver(motion_channel) - - goal_receiver._wire(component.stop_event) - instruction_receiver._wire(component.stop_event) motion_stop_event = threading.Event() - motion_receiver._wire(motion_stop_event) + motion_receiver = Receiver(motion_channel, motion_stop_event) component.start( DartControlInputs(goal=goal_receiver, instruction=instruction_receiver), From afff429b7487a5c57f6365de199ad60963395cc5 Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sat, 4 Apr 2026 17:12:41 -0400 Subject: [PATCH 04/21] refactor: extract UI channels into UIChannelBridge, simplify WS controller - UIChannelBridge owns WebSocket lifecycle via run(ws) - wire() creates UI channels and returns overrides for GraphManager.run() - Paired encode/decode functions for binary and JSON wire formats - One blocking task per UI output using Receiver properly - Controller reduced to accept + bridge.run() - Move SenderKey/ReceiverKey to utils - CompositeComponent uses tuple I/O with proper type bounds Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/api/dep.py | 5 + backend/src/api/graph/node/controller.py | 6 +- backend/src/api/graph/node/service.py | 8 +- backend/src/api/graph/run/controller.py | 10 +- backend/src/api/graph/run/service.py | 6 +- backend/src/api/metrics/service.py | 3 +- backend/src/api/ui/bridge.py | 232 +++++++++++++++++++++++ backend/src/api/ui/controller.py | 230 +--------------------- backend/src/core/component.py | 3 +- backend/src/core/graph.py | 79 +++----- backend/src/core/utils.py | 3 + backend/src/main.py | 2 + 12 files changed, 297 insertions(+), 290 deletions(-) create mode 100644 backend/src/api/ui/bridge.py diff --git a/backend/src/api/dep.py b/backend/src/api/dep.py index 299ef487..a30200e0 100644 --- a/backend/src/api/dep.py +++ b/backend/src/api/dep.py @@ -2,8 +2,13 @@ from fastapi import Request +from src.api.ui.bridge import UIChannelBridge from src.core.graph import GraphManager def get_manager(request: Request) -> GraphManager: return request.app.state.manager + + +def get_ui_bridge(request: Request) -> UIChannelBridge: + return request.app.state.ui_bridge diff --git a/backend/src/api/graph/node/controller.py b/backend/src/api/graph/node/controller.py index db45ab49..49821baa 100644 --- a/backend/src/api/graph/node/controller.py +++ b/backend/src/api/graph/node/controller.py @@ -2,7 +2,8 @@ from fastapi import APIRouter, Depends, HTTPException -from src.api.dep import get_manager +from src.api.dep import get_manager, get_ui_bridge +from src.api.ui.bridge import UIChannelBridge from src.api.graph.node.dto import ( NodeInitArgsUpdateRequest, NodeCreateRequest, @@ -83,8 +84,9 @@ def update_primitive_node_init_args( node_id: str, req: NodeInitArgsUpdateRequest, manager: GraphManager = Depends(get_manager), + ui_bridge: UIChannelBridge = Depends(get_ui_bridge), ) -> NodeResponse: - node = service.update_primitive_node_init_args(manager, node_id, req.init_args) + node = service.update_primitive_node_init_args(manager, ui_bridge, node_id, req.init_args) if node is None: raise HTTPException(status_code=404, detail=f"Node not found: {node_id}") return _node_response(node_id, node, manager) diff --git a/backend/src/api/graph/node/service.py b/backend/src/api/graph/node/service.py index dbe52c2a..017960f3 100644 --- a/backend/src/api/graph/node/service.py +++ b/backend/src/api/graph/node/service.py @@ -4,6 +4,7 @@ from typing import Any from src.api.graph.node.dto import NodeUpdateRequest +from src.api.ui.bridge import UIChannelBridge from src.core.config import PROJECTS_DIR from src.core.graph import Edge, Graph, GraphManager, Node @@ -39,10 +40,15 @@ def delete_node(manager: GraphManager, node_id: str) -> None: def update_primitive_node_init_args( manager: GraphManager, + ui_bridge: UIChannelBridge, node_id: str, init_args: dict[str, Any], ) -> Node | None: - return manager.update_primitive_node_init_args(node_id, init_args) + node, was_running = manager.update_primitive_node_init_args(node_id, init_args) + if was_running and node is not None: + recv_overrides, send_overrides = ui_bridge.wire(manager) + manager.run(recv_overrides, send_overrides) + return node def create_subgraph( diff --git a/backend/src/api/graph/run/controller.py b/backend/src/api/graph/run/controller.py index 4ab354a2..e50af66b 100644 --- a/backend/src/api/graph/run/controller.py +++ b/backend/src/api/graph/run/controller.py @@ -2,16 +2,20 @@ from fastapi import APIRouter, Depends -from src.api.dep import get_manager +from src.api.dep import get_manager, get_ui_bridge from src.api.graph.run import service +from src.api.ui.bridge import UIChannelBridge from src.core.graph import GraphManager router = APIRouter(prefix="/graph") @router.post("/start", status_code=204) -def start_all(manager: GraphManager = Depends(get_manager)) -> None: - service.start_all(manager) +def start_all( + manager: GraphManager = Depends(get_manager), + ui_bridge: UIChannelBridge = Depends(get_ui_bridge), +) -> None: + service.start_all(manager, ui_bridge) @router.post("/stop", status_code=204) diff --git a/backend/src/api/graph/run/service.py b/backend/src/api/graph/run/service.py index 9d526204..b84e5c4b 100644 --- a/backend/src/api/graph/run/service.py +++ b/backend/src/api/graph/run/service.py @@ -1,10 +1,12 @@ from __future__ import annotations +from src.api.ui.bridge import UIChannelBridge from src.core.graph import GraphManager -def start_all(manager: GraphManager) -> None: - manager.run() +def start_all(manager: GraphManager, ui_bridge: UIChannelBridge) -> None: + recv_overrides, send_overrides = ui_bridge.wire(manager) + manager.run(recv_overrides, send_overrides) def stop_all(manager: GraphManager) -> None: diff --git a/backend/src/api/metrics/service.py b/backend/src/api/metrics/service.py index 19fc7215..67803d3b 100644 --- a/backend/src/api/metrics/service.py +++ b/backend/src/api/metrics/service.py @@ -9,7 +9,8 @@ ReceiverSnapshot, SenderSnapshot, ) -from src.core.graph import GraphManager, ReceiverKey, SenderKey +from src.core.graph import GraphManager +from src.core.utils import ReceiverKey, SenderKey class MetricsCollector: diff --git a/backend/src/api/ui/bridge.py b/backend/src/api/ui/bridge.py new file mode 100644 index 00000000..037c6d00 --- /dev/null +++ b/backend/src/api/ui/bridge.py @@ -0,0 +1,232 @@ +"""UI channel bridge: creates UI channels and manages the WebSocket connection. + +The bridge inspects components for UI slots, creates channels and handles, +passes component-side handles as overrides to GraphManager.run(), and +owns the WebSocket lifecycle for bidirectional UI communication. + +Wire format: + Binary frames: 2-byte header length (big-endian) + JSON header + raw payload + Text frames: JSON {"type": "ui_input"|"ui_output", "node_id", "channel", "payload"} +""" + +from __future__ import annotations + +import asyncio +import json +import struct +import threading +from typing import Any, get_args, get_origin + +from fastapi import WebSocket, WebSocketDisconnect +from pydantic import BaseModel +from src.core.channel import Channel, Receiver, Sender, UISender, UIReceiver +from src.core.frames import TextFrame +from src.core.graph import GraphManager +from src.core.utils import ReceiverKey, SenderKey + + +# -- Wire format: encode/decode pairs -- + +def encode_binary(node_id: str, slot: str, payload: bytes) -> bytes: + """Pack a binary UI message: 2-byte header len + JSON header + payload.""" + header = json.dumps({"type": "ui_output", "node_id": node_id, "channel": slot}).encode() + return struct.pack(">H", len(header)) + header + payload + + +def decode_binary(buf: bytes) -> tuple[SenderKey | None, bytes]: + """Unpack a binary UI message. Returns (key, payload).""" + if len(buf) < 2: + return None, b"" + header_len = struct.unpack(">H", buf[:2])[0] + header = json.loads(buf[2 : 2 + header_len].decode("utf-8")) + payload = buf[2 + header_len :] + return (header.get("node_id"), header.get("channel")), payload + + +def encode_json(node_id: str, slot: str, item: Any) -> dict[str, Any]: + """Build a JSON UI output envelope.""" + payload: Any + if isinstance(item, BaseModel): + payload = item.model_dump() + elif isinstance(item, TextFrame): + payload = item.text + else: + payload = item + return {"type": "ui_output", "node_id": node_id, "channel": slot, "payload": payload} + + +def decode_json(text: str) -> tuple[SenderKey, Any] | None: + """Parse a JSON UI input message. Returns (key, raw_payload) or None.""" + msg = json.loads(text) + if msg.get("type") != "ui_input": + return None + return (msg["node_id"], msg["channel"]), msg.get("payload", "") + + +def deserialize_payload(payload: Any, inner_type: type | None) -> Any: + """Convert a raw JSON payload to the expected type.""" + if inner_type is not None and issubclass(inner_type, BaseModel) and isinstance(payload, dict): + return inner_type.model_validate(payload) + if inner_type is not None and hasattr(inner_type, "new") and isinstance(payload, str): + return inner_type.new(text=payload) + return payload + + +# -- Bridge -- + +class UIChannelBridge: + def __init__(self) -> None: + self._ui_senders: dict[SenderKey, Sender[Any]] = {} + self._ui_receivers: dict[ReceiverKey, Receiver[Any]] = {} + self._manager: GraphManager | None = None + self._ws: WebSocket | None = None + self._stop_event: asyncio.Event = asyncio.Event() + self._send_tasks: dict[tuple[str, str], asyncio.Task[None]] = {} + + def wire( + self, manager: GraphManager + ) -> tuple[ + dict[ReceiverKey, Receiver[Any] | None], + dict[SenderKey, Sender[Any] | None], + ]: + """Create UI channels for all components and return overrides for run().""" + self._manager = manager + self._ui_senders.clear() + self._ui_receivers.clear() + + recv_overrides: dict[ReceiverKey, Receiver[Any] | None] = {} + send_overrides: dict[SenderKey, Sender[Any] | None] = {} + + for node_id, comp in manager.components().items(): + stop_event = ( + comp.stop_event + if hasattr(comp, "stop_event") + else threading.Event() + ) + + for slot, slot_type in comp.get_ui_input_types().items(): + ch: Channel[Any] = Channel() + origin = get_origin(slot_type) or slot_type + recv_overrides[(node_id, slot)] = origin(ch, stop_event) + self._ui_senders[(node_id, slot)] = Sender(ch) + + for slot, slot_type in comp.get_ui_output_types().items(): + ch = Channel() + origin = get_origin(slot_type) or slot_type + send_overrides[(node_id, slot)] = origin(ch) + self._ui_receivers[(node_id, slot)] = Receiver(ch, threading.Event()) + + # Restart outbound tasks if WS is connected + if self._ws is not None: + self._send_msgs(self._ws) + + return recv_overrides, send_overrides + + async def run(self, ws: WebSocket) -> None: + """Own the WebSocket: read inbound, send outbound via tasks.""" + self._ws = ws + self._stop_event.clear() + self._send_msgs(ws) + try: + await self._recv_msgs(ws) + finally: + self._ws = None + self._stop_event.set() + for t in self._send_tasks.values(): + t.cancel() + await asyncio.gather(*self._send_tasks.values(), return_exceptions=True) + self._send_tasks.clear() + + # -- Outbound: component → frontend -- + + def _send_msgs(self, ws: WebSocket) -> None: + for key in list(self._send_tasks): + self._send_tasks.pop(key).cancel() + for key, receiver in self._ui_receivers.items(): + node_id, slot = key + inner_type = self._resolve_ui_output_type(node_id, slot) + self._send_tasks[key] = asyncio.create_task( + self._send_msg_task(ws, node_id, slot, receiver, inner_type) + ) + + async def _send_msg_task( + self, ws: WebSocket, node_id: str, slot: str, + receiver: Receiver[Any], inner_type: type | None, + ) -> None: + try: + while not self._stop_event.is_set(): + item = await asyncio.to_thread(next, receiver) + if item is None: + break + if inner_type is not None and issubclass(inner_type, bytes): + await ws.send_bytes(encode_binary(node_id, slot, item)) + else: + await ws.send_json(encode_json(node_id, slot, item)) + except (WebSocketDisconnect, RuntimeError, asyncio.CancelledError): + pass + + # -- Inbound: frontend → component -- + + async def _recv_msgs(self, ws: WebSocket) -> None: + while True: + ws_msg = await ws.receive() + if "bytes" in ws_msg: + key, payload = decode_binary(ws_msg["bytes"]) + if key is not None: + sender = self._ui_senders.get(key) + if sender is not None: + sender.send(payload) + elif "text" in ws_msg: + result = decode_json(ws_msg["text"]) + if result is not None: + key, payload = result + sender = self._ui_senders.get(key) + if sender is not None: + inner_type = self._resolve_ui_input_type(*key) + sender.send(deserialize_payload(payload, inner_type)) + + # -- Type resolution helpers -- + + def _resolve_ui_output_type(self, node_id: str, slot: str) -> type | None: + if self._manager is None: + return None + comp = self._manager.components().get(node_id) + if comp is None: + return None + slot_type = comp.get_ui_output_types().get(slot) + if slot_type is None: + return None + args = get_args(slot_type) + if args: + return args[0] + origin = get_origin(slot_type) or slot_type + if isinstance(origin, type): + for base in getattr(origin, "__orig_bases__", ()): + base_origin = get_origin(base) + if base_origin is not None and issubclass(base_origin, UISender): + base_args = get_args(base) + if base_args: + return base_args[0] + return None + + def _resolve_ui_input_type(self, node_id: str, slot: str) -> type | None: + if self._manager is None: + return None + comp = self._manager.components().get(node_id) + if comp is None: + return None + slot_type = comp.get_ui_input_types().get(slot) + if slot_type is None: + return None + args = get_args(slot_type) + if args: + return args[0] + origin = get_origin(slot_type) or slot_type + if isinstance(origin, type): + for base in getattr(origin, "__orig_bases__", ()): + base_origin = get_origin(base) + if base_origin is not None and issubclass(base_origin, UIReceiver): + base_args = get_args(base) + if base_args: + return base_args[0] + return None diff --git a/backend/src/api/ui/controller.py b/backend/src/api/ui/controller.py index 0c0861eb..759f5c54 100644 --- a/backend/src/api/ui/controller.py +++ b/backend/src/api/ui/controller.py @@ -1,236 +1,14 @@ -"""Single WebSocket that bridges all UI channels between components and frontend. - -Inbound (text JSON): - {"type": "ui_input", "node_id": "...", "channel": "...", "payload": "..."} - -> finds the Sender for (node_id, channel) and calls .send() - -Outbound: - bytes channels: binary WS frame = 2-byte header length (big-endian) + JSON header + raw bytes - BaseModel channels: JSON {"type": "ui_output", ..., "payload": {model dict}} - Legacy (.get()) channels: JSON {"type": "ui_output", ..., "payload": "string"} -""" - from __future__ import annotations -import asyncio -import itertools -import json -import struct -import threading -from typing import Any, get_args, get_origin +from fastapi import APIRouter, WebSocket -from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from pydantic import BaseModel -from src.core.channel import Receiver, UISender, UIReceiver -from src.core.graph import GraphManager +from src.api.ui.bridge import UIChannelBridge router = APIRouter(prefix="/ui") -_sub_id_counter = itertools.count() - - -def _resolve_ui_output_type( - manager: GraphManager, node_id: str, slot: str -) -> type | None: - """Extract the inner T from UISender[T] for the given output slot.""" - comp = manager.components().get(node_id) - if comp is None: - return None - slot_type = comp.get_ui_output_types().get(slot) - if slot_type is None: - return None - # slot_type is e.g. UISender[bytes], UIVideoSender (which is UISender[bytes]), etc. - # Try get_args on the slot_type first (for generic aliases like UISender[MyModel]) - args = get_args(slot_type) - if args: - return args[0] - # For concrete subclasses like UIVideoSender, walk __orig_bases__ - origin = get_origin(slot_type) or slot_type - if isinstance(origin, type): - for base in getattr(origin, "__orig_bases__", ()): - base_origin = get_origin(base) - if base_origin is not None and issubclass(base_origin, UISender): - base_args = get_args(base) - if base_args: - return base_args[0] - return None - - -def _resolve_ui_input_type( - manager: GraphManager, node_id: str, slot: str -) -> type | None: - """Extract the inner T from UIReceiver[T] for the given input slot.""" - comp = manager.components().get(node_id) - if comp is None: - return None - slot_type = comp.get_ui_input_types().get(slot) - if slot_type is None: - return None - args = get_args(slot_type) - if args: - return args[0] - origin = get_origin(slot_type) or slot_type - if isinstance(origin, type): - for base in getattr(origin, "__orig_bases__", ()): - base_origin = get_origin(base) - if base_origin is not None and issubclass(base_origin, UIReceiver): - base_args = get_args(base) - if base_args: - return base_args[0] - return None - - -async def _read_ui_output( - ws: WebSocket, - node_id: str, - slot: str, - receiver: Receiver[Any], - inner_type: type | None, - stop_event: asyncio.Event, -) -> None: - """Async task: reads from a blocking Receiver and sends to the WebSocket.""" - thread_stop = threading.Event() - sub_id = next(_sub_id_counter) - receiver._channel._register(sub_id) - - try: - while not stop_event.is_set(): - item = await asyncio.to_thread(receiver._channel._get, sub_id, thread_stop) - if item is None: - break - try: - if inner_type is not None and issubclass(inner_type, bytes): - # Binary path (video frames, etc.) - header = json.dumps( - {"type": "ui_output", "node_id": node_id, "channel": slot} - ).encode() - prefix = struct.pack(">H", len(header)) - await ws.send_bytes(prefix + header + item) - elif isinstance(item, BaseModel): - # Pydantic model → JSON dict payload - await ws.send_json( - { - "type": "ui_output", - "node_id": node_id, - "channel": slot, - "payload": item.model_dump(), - } - ) - elif hasattr(item, "get"): - # Legacy TextFrame-style with .get() - await ws.send_json( - { - "type": "ui_output", - "node_id": node_id, - "channel": slot, - "payload": item.get(), - } - ) - else: - await ws.send_json( - { - "type": "ui_output", - "node_id": node_id, - "channel": slot, - "payload": item, - } - ) - except (WebSocketDisconnect, RuntimeError): - break - finally: - receiver._channel._unregister(sub_id) - thread_stop.set() - - -async def _watch_ui_channels( - ws: WebSocket, - manager: GraphManager, - stop_event: asyncio.Event, - tasks: dict[tuple[str, str], asyncio.Task[None]], -) -> None: - """Wait for GraphManager.run() to signal new UI channels, then spawn readers.""" - last_version = manager._ui_version - while not stop_event.is_set(): - # Spawn tasks for any receivers we haven't seen yet - for key, receiver in manager.ui_receivers().items(): - if key not in tasks: - node_id, slot = key - inner_type = _resolve_ui_output_type(manager, node_id, slot) - tasks[key] = asyncio.create_task( - _read_ui_output(ws, node_id, slot, receiver, inner_type, stop_event) - ) - - if manager._ui_version == last_version: - # Wait until run() fires the event - try: - await manager._ui_changed.wait() - except asyncio.CancelledError: - return - - # Cancel all stale tasks — even if the same keys exist, the - # underlying Receiver objects point to new channels after a restart. - for key in list(tasks): - tasks.pop(key).cancel() - - last_version = manager._ui_version - @router.websocket("/ws") async def ui_ws(ws: WebSocket) -> None: - print("[ui_ws] Got connection request from:", ws.client) - manager: GraphManager = ws.app.state.manager + bridge: UIChannelBridge = ws.app.state.ui_bridge await ws.accept() - print("[ui_ws] Accepted connection!!!") - - stop_event = asyncio.Event() - tasks: dict[tuple[str, str], asyncio.Task[None]] = {} - - watcher = asyncio.create_task(_watch_ui_channels(ws, manager, stop_event, tasks)) - - try: - while True: - ws_msg = await ws.receive() - if "bytes" in ws_msg: - # Binary frame: 2-byte header length + JSON header + payload - buf = ws_msg["bytes"] - if len(buf) >= 2: - header_len = struct.unpack(">H", buf[:2])[0] - header = json.loads(buf[2 : 2 + header_len].decode("utf-8")) - payload = buf[2 + header_len :] - - node_id = header.get("node_id") - channel = header.get("channel") - sender = manager.ui_senders().get((node_id, channel)) - if sender is not None: - sender.send(payload) - elif "text" in ws_msg: - msg = json.loads(ws_msg["text"]) - if msg.get("type") == "ui_input": - node_id = msg["node_id"] - channel = msg["channel"] - payload = msg.get("payload", "") - sender = manager.ui_senders().get((node_id, channel)) - if sender is not None: - inner_type = _resolve_ui_input_type(manager, node_id, channel) - if ( - inner_type is not None - and issubclass(inner_type, BaseModel) - and isinstance(payload, dict) - ): - sender.send(inner_type.model_validate(payload)) - elif ( - inner_type is not None - and hasattr(inner_type, "new") - and isinstance(payload, str) - ): - sender.send(inner_type.new(text=payload)) - else: - sender.send(payload) - except (WebSocketDisconnect, RuntimeError): - pass - finally: - stop_event.set() - watcher.cancel() - for t in tasks.values(): - t.cancel() - await asyncio.gather(watcher, *tasks.values(), return_exceptions=True) + await bridge.run(ws) diff --git a/backend/src/core/component.py b/backend/src/core/component.py index c2b78d76..326d625d 100644 --- a/backend/src/core/component.py +++ b/backend/src/core/component.py @@ -524,7 +524,8 @@ def start(self, inputs: tuple[Receiver[Any] | None, ...], outputs: tuple[Sender[ return self._status = Status.SETUP - from src.core.graph import GraphManager, ReceiverKey, SenderKey + from src.core.graph import GraphManager + from src.core.utils import ReceiverKey, SenderKey self._inner_manager = GraphManager(self._sub_graph) diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index e3928357..4a06d2b6 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -1,11 +1,10 @@ from __future__ import annotations -import asyncio import threading import time import uuid from collections import defaultdict -from typing import Any, get_origin +from typing import Any from pydantic import BaseModel, Field @@ -18,8 +17,7 @@ from src.core.log_capture import get_log_store -SenderKey = tuple[str, str] # (node_id, slot_name) -ReceiverKey = tuple[str, str] # (node_id, slot_name) +from src.core.utils import SenderKey, ReceiverKey class Node(BaseModel): @@ -56,11 +54,6 @@ def __init__(self, graph: Graph) -> None: self._graph = Graph(edges=[], nodes={}) self._components: dict[str, Component[Any, Any]] = {} self._channel_map: dict[frozenset[SenderKey], Channel[Any]] = {} - # UI channels: keyed by (node_id, slot_name) - self._ui_senders: dict[tuple[str, str], Sender[Any]] = {} - self._ui_receivers: dict[tuple[str, str], Receiver[Any]] = {} - self._ui_version = 0 - self._ui_changed = asyncio.Event() self.reset(graph) # --- node CRUD --- @@ -111,15 +104,20 @@ def update_node(self, node_id: str, x: float, y: float) -> Node | None: def update_primitive_node_init_args( self, node_id: str, init_args: dict[str, Any] - ) -> Node | None: + ) -> tuple[Node | None, bool]: + """Replace a node's init_args and recreate its component. + + Returns (node, was_running). The caller is responsible for + calling run() with UI channel overrides if was_running is True. + """ node = self._graph.nodes.get(node_id) if node is None: - return None + return None, False classes = PrimitiveComponent.registered_subclasses() cls = classes.get(node.type) if cls is None: - return None + return None, False was_running = any( c.status.value == "running" for c in self._components.values() @@ -130,10 +128,7 @@ def update_primitive_node_init_args( node.init_args = init_args self._components[node_id] = cls.from_args(init_args) self._reconcile() - - if was_running: - self.run() - return node + return node, was_running def delete_node(self, node_id: str) -> None: node = self._graph.nodes.get(node_id) @@ -202,13 +197,21 @@ def receiver_handles(self) -> dict[ReceiverKey, Receiver[Any]]: if receiver is not None } - def ui_senders(self) -> dict[tuple[str, str], Sender[Any]]: - """Server-side senders that push data into component UIReceiver slots.""" - return self._ui_senders + def ui_input_slots(self) -> list[ReceiverKey]: + """All (node_id, slot) pairs for UI input slots across the graph.""" + return [ + (node_id, slot) + for node_id, comp in self._components.items() + for slot in comp.get_ui_input_types() + ] - def ui_receivers(self) -> dict[tuple[str, str], Receiver[Any]]: - """Server-side receivers that read from component UISender slots.""" - return self._ui_receivers + def ui_output_slots(self) -> list[SenderKey]: + """All (node_id, slot) pairs for UI output slots across the graph.""" + return [ + (node_id, slot) + for node_id, comp in self._components.items() + for slot in comp.get_ui_output_types() + ] def get_node_output(self, node_id: str) -> dict[str, type]: return self._components[node_id].get_output_types() @@ -222,8 +225,6 @@ def reset(self, graph: Graph) -> None: self._graph = graph self._components.clear() self._channel_map.clear() - self._ui_senders.clear() - self._ui_receivers.clear() classes = PrimitiveComponent.registered_subclasses() for node_id, node in self._graph.nodes.items(): @@ -305,8 +306,6 @@ def run( pre-built handles for specific slots. """ self.stop() - self._ui_senders.clear() - self._ui_receivers.clear() sender_plan, receiver_plan = self._reconcile() _recv_over = receiver_overrides or {} @@ -322,8 +321,6 @@ def run( input_slots = comp.get_input_types() output_slots = comp.get_output_types() - ui_input_slots = comp.get_ui_input_types() - ui_output_slots = comp.get_ui_output_types() stop_event = comp.stop_event if isinstance(comp, ThreadedComponent) else threading.Event() @@ -346,20 +343,6 @@ def run( # Unconnected output: no-op sender (sends are discarded) node.senders[slot] = Sender() - # Wire UI input channels (frontend -> component) - for slot, slot_type in ui_input_slots.items(): - ch: Channel[Any] = Channel() - origin = get_origin(slot_type) or slot_type - node.receivers[slot] = origin(ch, stop_event) - self._ui_senders[(node_id, slot)] = Sender(ch) - - # Wire UI output channels (component -> frontend) - for slot, slot_type in ui_output_slots.items(): - ch = Channel() - origin = get_origin(slot_type) or slot_type - node.senders[slot] = origin(ch) - self._ui_receivers[(node_id, slot)] = Receiver(ch, threading.Event()) - built_inputs = self._build_tuple(input_type, dict(node.receivers)) built_outputs = self._build_tuple(output_type, dict(node.senders)) @@ -376,7 +359,6 @@ def run( if isinstance(comp, ThreadedComponent): ident = comp.get_ident() if ident is None: - # Thread ident may lag briefly after start(). for _ in range(10): time.sleep(0.005) ident = comp.get_ident() @@ -385,15 +367,6 @@ def run( if ident is not None: get_log_store().register_thread(node_id=node_id, ident=ident) - # Notify WS listeners that UI channels are ready. - # Save old event, replace with fresh one, *then* wake waiters. - # This avoids a race where a coroutine calling wait() between - # set() and reassignment would block on the now-orphaned event. - self._ui_version += 1 - old_event = self._ui_changed - self._ui_changed = asyncio.Event() - old_event.set() - @staticmethod def _build_tuple(tp: type | None, handles: dict[str, Any]) -> tuple[Any, ...]: """Build a NamedTuple (keyword) or plain tuple (positional) from handles.""" @@ -410,8 +383,6 @@ def stop(self) -> None: for sender in node.senders.values(): if sender is not None: sender._stopped = True - for sender in self._ui_senders.values(): - sender._stopped = True for comp in self._components.values(): comp.stop() for comp in self._components.values(): diff --git a/backend/src/core/utils.py b/backend/src/core/utils.py index b12aebe3..16304369 100644 --- a/backend/src/core/utils.py +++ b/backend/src/core/utils.py @@ -13,6 +13,9 @@ import numpy as np +SenderKey = tuple[str, str] # (node_id, slot_name) +ReceiverKey = tuple[str, str] # (node_id, slot_name) + _COUNTS: collections.defaultdict[str, itertools.count[int]] = collections.defaultdict( itertools.count ) diff --git a/backend/src/main.py b/backend/src/main.py index 48d6b07c..3ef16f3c 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -16,6 +16,7 @@ from src.api.project.controller import router as project_router from src.api.ui.controller import router as ui_router from src.api.env.controller import router as env_router +from src.api.ui.bridge import UIChannelBridge from src.core.graph import Graph from src.core.config import PROJECTS_DIR, PRESETS_DIR, AppConfig from src.core.graph import GraphManager @@ -47,6 +48,7 @@ async def lifespan(app: FastAPI): app.state.current_project = config.current_project app.state.manager = GraphManager(Graph(edges=[], nodes={})) + app.state.ui_bridge = UIChannelBridge() yield From 60daaa30d757ae29ae89fc4b56c439282ee3d76e Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sat, 4 Apr 2026 19:03:59 -0400 Subject: [PATCH 05/21] refactor: extract UI channels into UIChannelBridge, simplify WS controller - UIChannelBridge owns WebSocket lifecycle via run(ws) - wire() creates UI channels and returns overrides for GraphManager.run() - Paired encode/decode functions for binary and JSON wire formats - One blocking task per UI output using Receiver properly - Controller reduced to accept + bridge.run() - Move SenderKey/ReceiverKey to utils - Update all tests for new Receiver(channel, stop_event) signature - Update tests for update_primitive_node_init_args returning (node, was_running) Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/api/ui/bridge.py | 2 +- backend/src/core/component.py | 8 +- backend/src/core/graph.py | 16 +- backend/tests/api/test_api_controllers.py | 6 +- backend/tests/api/test_api_services.py | 10 +- backend/tests/api/test_node_service.py | 7 +- backend/tests/api/test_ui_controller.py | 348 ++++++++------------- backend/tests/core/test_component_graph.py | 4 +- backend/tests/core/test_frames_graph.py | 19 +- backend/tests/test_graph_manager.py | 16 +- 10 files changed, 185 insertions(+), 251 deletions(-) diff --git a/backend/src/api/ui/bridge.py b/backend/src/api/ui/bridge.py index 037c6d00..c10cba70 100644 --- a/backend/src/api/ui/bridge.py +++ b/backend/src/api/ui/bridge.py @@ -81,7 +81,7 @@ def __init__(self) -> None: self._manager: GraphManager | None = None self._ws: WebSocket | None = None self._stop_event: asyncio.Event = asyncio.Event() - self._send_tasks: dict[tuple[str, str], asyncio.Task[None]] = {} + self._send_tasks: dict[SenderKey, asyncio.Task[None]] = {} def wire( self, manager: GraphManager diff --git a/backend/src/core/component.py b/backend/src/core/component.py index 326d625d..8e555b46 100644 --- a/backend/src/core/component.py +++ b/backend/src/core/component.py @@ -12,6 +12,7 @@ from src.core.channel import Receiver, Sender from src.core.channel import UIReceiver, UISender from src.core.log_capture import get_log_store +from src.core.utils import ReceiverKey, SenderKey if TYPE_CHECKING: from src.core.graph import Graph, GraphManager @@ -426,9 +427,9 @@ def type_(self) -> str: def _compute_boundary( self, - ) -> tuple[dict[str, tuple[str, str]], dict[str, tuple[str, str]]]: - connected_inputs: set[tuple[str, str]] = set() - connected_outputs: set[tuple[str, str]] = set() + ) -> tuple[dict[str, ReceiverKey], dict[str, SenderKey]]: + connected_inputs: set[ReceiverKey] = set() + connected_outputs: set[SenderKey] = set() for edge in self._sub_graph.edges: connected_inputs.add((edge.target_node, edge.target_slot)) connected_outputs.add((edge.source_node, edge.source_slot)) @@ -525,7 +526,6 @@ def start(self, inputs: tuple[Receiver[Any] | None, ...], outputs: tuple[Sender[ self._status = Status.SETUP from src.core.graph import GraphManager - from src.core.utils import ReceiverKey, SenderKey self._inner_manager = GraphManager(self._sub_graph) diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index 4a06d2b6..bed3cf43 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -321,6 +321,8 @@ def run( input_slots = comp.get_input_types() output_slots = comp.get_output_types() + ui_input_slots = comp.get_ui_input_types() + ui_output_slots = comp.get_ui_output_types() stop_event = comp.stop_event if isinstance(comp, ThreadedComponent) else threading.Event() @@ -340,7 +342,19 @@ def run( elif (node_id, slot) in sender_plan: node.senders[slot] = Sender(*sender_plan[(node_id, slot)]) else: - # Unconnected output: no-op sender (sends are discarded) + node.senders[slot] = Sender() + + # UI slots: use overrides if provided, otherwise dummy handles + for slot in ui_input_slots: + if (node_id, slot) in _recv_over: + node.receivers[slot] = _recv_over[(node_id, slot)] + else: + node.receivers[slot] = None + + for slot in ui_output_slots: + if (node_id, slot) in _send_over: + node.senders[slot] = _send_over[(node_id, slot)] or Sender() + else: node.senders[slot] = Sender() built_inputs = self._build_tuple(input_type, dict(node.receivers)) diff --git a/backend/tests/api/test_api_controllers.py b/backend/tests/api/test_api_controllers.py index 4410e558..082bdeba 100644 --- a/backend/tests/api/test_api_controllers.py +++ b/backend/tests/api/test_api_controllers.py @@ -196,12 +196,14 @@ def test_run_save_metrics_project_controllers(monkeypatch, tmp_path) -> None: called = {} monkeypatch.setattr( - run_controller.service, "start_all", lambda m: called.setdefault("start", True) + run_controller.service, "start_all", lambda m, b: called.setdefault("start", True) ) monkeypatch.setattr( run_controller.service, "stop_all", lambda m: called.setdefault("stop", True) ) - run_controller.start_all(manager) + from src.api.ui.bridge import UIChannelBridge + bridge = UIChannelBridge() + run_controller.start_all(manager, bridge) run_controller.stop_all(manager) assert called == {"start": True, "stop": True} diff --git a/backend/tests/api/test_api_services.py b/backend/tests/api/test_api_services.py index a14e8703..5acbc7f8 100644 --- a/backend/tests/api/test_api_services.py +++ b/backend/tests/api/test_api_services.py @@ -51,7 +51,7 @@ def add_edge(self, edge: Edge) -> None: def delete_edge(self, edge: Edge) -> None: self.graph.edges.remove(edge) - def run(self) -> None: + def run(self, receiver_overrides=None, sender_overrides=None) -> None: self._run_called = True def stop(self) -> None: @@ -61,9 +61,11 @@ def reset(self, _graph: Graph) -> None: self._reset_called = True def components(self): + _no_ui = lambda: {} # noqa: E731 return { "a": types.SimpleNamespace( - type_="A", status=types.SimpleNamespace(value="running") + type_="A", status=types.SimpleNamespace(value="running"), + get_ui_input_types=_no_ui, get_ui_output_types=_no_ui ) } @@ -199,7 +201,9 @@ def test_edge_run_save_services(tmp_path, monkeypatch) -> None: with pytest.raises(KeyError): edge_service.delete_edge(manager, "a", "out", "b", "in") - run_service.start_all(manager) + from src.api.ui.bridge import UIChannelBridge + bridge = UIChannelBridge() + run_service.start_all(manager, bridge) run_service.stop_all(manager) assert manager._run_called is True assert manager._stop_called is True diff --git a/backend/tests/api/test_node_service.py b/backend/tests/api/test_node_service.py index d84c7002..c06ef478 100644 --- a/backend/tests/api/test_node_service.py +++ b/backend/tests/api/test_node_service.py @@ -67,9 +67,9 @@ def delete_node(self, node_id): def update_primitive_node_init_args(self, node_id, init_args): n = self.graph.nodes.get(node_id) if n is None: - return None + return None, False n.init_args = init_args - return n + return n, False def components(self): return self._components @@ -88,7 +88,8 @@ def test_node_service_basic_crud() -> None: assert updated.x == 9 node_service.delete_node(m, "n2") assert "n2" not in m.graph.nodes - out = node_service.update_primitive_node_init_args(m, "n1", {"k": 1}) + from src.api.ui.bridge import UIChannelBridge + out = node_service.update_primitive_node_init_args(m, UIChannelBridge(), "n1", {"k": 1}) assert out.init_args == {"k": 1} diff --git a/backend/tests/api/test_ui_controller.py b/backend/tests/api/test_ui_controller.py index 90524f84..dc993228 100644 --- a/backend/tests/api/test_ui_controller.py +++ b/backend/tests/api/test_ui_controller.py @@ -2,18 +2,21 @@ import asyncio import json -import struct import threading import types from fastapi import WebSocketDisconnect from pydantic import BaseModel -from src.api.ui import controller as ui_controller +from src.api.ui.bridge import ( + UIChannelBridge, + encode_binary, + encode_json, + decode_binary, + decode_json, + deserialize_payload, +) from src.core.channel import ( - Channel, - Receiver, - Sender, UIReceiver, UISender, UITextReceiver, @@ -26,15 +29,22 @@ class _Payload(BaseModel): value: int -class _PayloadSender(UISender[_Payload]): +class _PayloadReceiver(UIReceiver[_Payload]): pass -class _PayloadReceiver(UIReceiver[_Payload]): +class _MissingSender(UISender): + pass + + +class _MissingReceiver(UIReceiver): pass class _FakeComponent: + def __init__(self) -> None: + self.stop_event = threading.Event() + def get_ui_output_types(self): return { "video": UIVideoSender, @@ -52,49 +62,17 @@ def get_ui_input_types(self): } -class _MissingSender(UISender): - pass - - -class _MissingReceiver(UIReceiver): - pass - - class _FakeManager: def __init__(self) -> None: self._component = _FakeComponent() - self._ui_version = 0 - self._ui_changed = asyncio.Event() - self._receiver = Receiver(Channel(), threading.Event()) - self._ui_receivers = {("node", "video"): self._receiver} - self.sent_payloads: list[object] = [] - self._ui_senders = { - ("node", "payload_in"): types.SimpleNamespace( - send=lambda value: self.sent_payloads.append(value) - ), - ("node", "text_in"): types.SimpleNamespace( - send=lambda value: self.sent_payloads.append(value) - ), - ("node", "raw_in"): types.SimpleNamespace( - send=lambda value: self.sent_payloads.append(value) - ), - } def components(self): return {"node": self._component} - def ui_receivers(self): - return self._ui_receivers - - def ui_senders(self): - return self._ui_senders - class _FakeWebSocket: - def __init__( - self, manager: _FakeManager, messages: list[str] | None = None - ) -> None: - self.app = types.SimpleNamespace(state=types.SimpleNamespace(manager=manager)) + def __init__(self, messages: list[str | bytes] | None = None) -> None: + self.app = types.SimpleNamespace(state=types.SimpleNamespace()) self.client = ("127.0.0.1", 0) self._messages = iter(messages or []) self.accepted = False @@ -104,15 +82,12 @@ def __init__( async def accept(self) -> None: self.accepted = True - async def receive_text(self) -> str: - try: - return next(self._messages) - except StopIteration as exc: - raise WebSocketDisconnect() from exc - - async def receive(self) -> dict[str, str]: + async def receive(self) -> dict[str, str | bytes]: try: - return {"text": next(self._messages)} + msg = next(self._messages) + if isinstance(msg, bytes): + return {"bytes": msg} + return {"text": msg} except StopIteration as exc: raise WebSocketDisconnect() from exc @@ -123,201 +98,128 @@ async def send_bytes(self, payload: bytes) -> None: self.byte_messages.append(payload) -def test_ui_type_resolution_helpers() -> None: - manager = _FakeManager() - manager._component.get_ui_output_types = lambda: { - "video": UIVideoSender, - "payload": UISender[_Payload], - "missing": _MissingSender, - } - manager._component.get_ui_input_types = lambda: { - "payload_in": UIReceiver[_Payload], - "text_in": UITextReceiver, - "missing_args": _MissingReceiver, +def test_wire_format_encode_decode() -> None: + # Binary round-trip + encoded = encode_binary("n1", "video", b"\x00\x01") + key, payload = decode_binary(encoded) + assert key == ("n1", "video") + assert payload == b"\x00\x01" + + # Binary too short + assert decode_binary(b"\x00")[0] is None + + # JSON round-trip + envelope = encode_json("n1", "text", "hello") + assert envelope == { + "type": "ui_output", "node_id": "n1", "channel": "text", "payload": "hello" } - assert ( - ui_controller._resolve_ui_output_type(manager, "missing-node", "video") is None - ) - assert ( - ui_controller._resolve_ui_output_type(manager, "node", "missing-slot") is None - ) - assert ui_controller._resolve_ui_output_type(manager, "node", "video") is bytes - assert ui_controller._resolve_ui_output_type(manager, "node", "payload") is _Payload - assert ui_controller._resolve_ui_output_type(manager, "node", "missing") is None - - assert ( - ui_controller._resolve_ui_input_type(manager, "missing-node", "payload_in") - is None - ) - assert ui_controller._resolve_ui_input_type(manager, "node", "missing-slot") is None - assert ( - ui_controller._resolve_ui_input_type(manager, "node", "payload_in") is _Payload - ) - assert ui_controller._resolve_ui_input_type(manager, "node", "text_in") is TextFrame - assert ui_controller._resolve_ui_input_type(manager, "node", "missing_args") is None - - -def test_read_ui_output_variants() -> None: - async def run_case(item: object, inner_type: type | None, failing: bool = False): - stop_event = asyncio.Event() - channel = Channel() - sender = Sender(channel) - receiver = Receiver(channel, threading.Event()) - ws = _FakeWebSocket(_FakeManager()) - - if failing: - - async def send_json(_payload: object) -> None: - raise RuntimeError("boom") - - ws.send_json = send_json # type: ignore[method-assign] - - task = asyncio.create_task( - ui_controller._read_ui_output( - ws, "node", "slot", receiver, inner_type, stop_event - ) - ) - await asyncio.sleep(0) - sender.send(item) - sender.send(None) - await task - return ws, channel - - async def run_all() -> None: - ws_bytes, _ = await run_case(b"\x00\x01", bytes) - header_len = struct.unpack(">H", ws_bytes.byte_messages[0][:2])[0] - header = json.loads(ws_bytes.byte_messages[0][2 : 2 + header_len].decode()) - assert header == {"type": "ui_output", "node_id": "node", "channel": "slot"} - assert ws_bytes.byte_messages[0][2 + header_len :] == b"\x00\x01" - - ws_model, _ = await run_case(_Payload(value=7), _Payload) - assert ws_model.json_messages[0] == { - "type": "ui_output", - "node_id": "node", - "channel": "slot", - "payload": {"value": 7}, - } + # JSON with BaseModel + env_model = encode_json("n1", "data", _Payload(value=5)) + assert env_model["payload"] == {"value": 5} - ws_legacy, _ = await run_case(TextFrame.new(text="hello"), TextFrame) - assert ws_legacy.json_messages[0]["payload"] == "hello" + # JSON with TextFrame + env_text = encode_json("n1", "data", TextFrame.new(text="hi")) + assert env_text["payload"] == "hi" - ws_plain, _ = await run_case(5, int) - assert ws_plain.json_messages[0]["payload"] == 5 + # decode_json + result = decode_json(json.dumps({"type": "ui_input", "node_id": "n1", "channel": "c", "payload": "x"})) + assert result == (("n1", "c"), "x") + assert decode_json(json.dumps({"type": "ignored"})) is None - ws_fail, channel = await run_case(9, int, failing=True) - assert ws_fail.json_messages == [] - assert channel._cursors == {} + # deserialize_payload + assert deserialize_payload({"value": 3}, _Payload) == _Payload(value=3) + assert isinstance(deserialize_payload("hi", TextFrame), TextFrame) + assert deserialize_payload(42, None) == 42 - asyncio.run(run_all()) +def test_type_resolution() -> None: + bridge = UIChannelBridge() + manager = _FakeManager() + bridge._manager = manager -def test_watch_ui_channels_and_ui_ws(monkeypatch) -> None: - async def run_watch_ui_channels() -> None: - manager = _FakeManager() - ws = _FakeWebSocket(manager) - stop_event = asyncio.Event() - tasks: dict[tuple[str, str], asyncio.Task[None]] = {} - calls: list[tuple[str, str, type | None]] = [] + assert bridge._resolve_ui_output_type("missing-node", "video") is None + assert bridge._resolve_ui_output_type("node", "missing-slot") is None + assert bridge._resolve_ui_output_type("node", "video") is bytes + assert bridge._resolve_ui_output_type("node", "payload") is _Payload - async def fake_read_ui_output( - ws, node_id, slot, receiver, inner_type, stop_event - ): - calls.append((node_id, slot, inner_type)) - await stop_event.wait() + assert bridge._resolve_ui_input_type("missing-node", "payload_in") is None + assert bridge._resolve_ui_input_type("node", "missing-slot") is None + assert bridge._resolve_ui_input_type("node", "payload_in") is _Payload + assert bridge._resolve_ui_input_type("node", "text_in") is TextFrame + assert bridge._resolve_ui_input_type("node", "missing_args") is None - monkeypatch.setattr(ui_controller, "_read_ui_output", fake_read_ui_output) - watcher = asyncio.create_task( - ui_controller._watch_ui_channels(ws, manager, stop_event, tasks) - ) - await asyncio.sleep(0.01) - assert calls == [("node", "video", bytes)] - assert ("node", "video") in tasks - - old_event = manager._ui_changed - manager._ui_version = 1 - manager._ui_changed = asyncio.Event() - old_event.set() - await asyncio.sleep(0.01) - assert calls[-1] == ("node", "video", bytes) - assert ("node", "video") in tasks - - stop_event.set() - watcher.cancel() - await asyncio.gather(watcher, return_exceptions=True) - - async def run_ui_ws() -> None: +def test_bridge_wire() -> None: + bridge = UIChannelBridge() + manager = _FakeManager() + recv_overrides, send_overrides = bridge.wire(manager) + + # UI input slots should have receiver overrides + server senders + assert ("node", "payload_in") in recv_overrides + assert ("node", "text_in") in recv_overrides + assert ("node", "payload_in") in bridge._ui_senders + assert ("node", "text_in") in bridge._ui_senders + + # UI output slots should have sender overrides + server receivers + assert ("node", "video") in send_overrides + assert ("node", "payload") in send_overrides + assert ("node", "video") in bridge._ui_receivers + assert ("node", "payload") in bridge._ui_receivers + + +def test_bridge_recv_msgs() -> None: + async def run() -> None: + bridge = UIChannelBridge() manager = _FakeManager() + bridge.wire(manager) + + sent: list[object] = [] + for sender in bridge._ui_senders.values(): + original_send = sender.send + sender.send = lambda item, _orig=original_send: (sent.append(item), _orig(item)) # type: ignore[method-assign] + ws = _FakeWebSocket( - manager, messages=[ - json.dumps( - { - "type": "ui_input", - "node_id": "node", - "channel": "payload_in", - "payload": {"value": 5}, - } - ), - json.dumps( - { - "type": "ui_input", - "node_id": "node", - "channel": "text_in", - "payload": "hello", - } - ), - json.dumps( - { - "type": "ui_input", - "node_id": "node", - "channel": "raw_in", - "payload": 3, - } - ), - json.dumps( - { - "type": "ui_input", - "node_id": "node", - "channel": "missing", - "payload": "skip", - } - ), + json.dumps({ + "type": "ui_input", "node_id": "node", + "channel": "text_in", "payload": "hello", + }), + json.dumps({ + "type": "ui_input", "node_id": "node", + "channel": "payload_in", "payload": {"value": 5}, + }), json.dumps({"type": "ignored"}), ], ) - cancelled = {"watcher": False, "task": False} - original_receive = ws.receive - - async def delayed_receive() -> dict[str, str]: - await asyncio.sleep(0) - return await original_receive() - - ws.receive = delayed_receive # type: ignore[method-assign] + bridge._ws = ws + # _recv_msgs raises WebSocketDisconnect when messages run out + try: + await bridge._recv_msgs(ws) # type: ignore[arg-type] + except WebSocketDisconnect: + pass - async def fake_watch_ui_channels(ws, manager, stop_event, tasks): - class _Task: - def cancel(self) -> None: - cancelled["task"] = True + assert any(isinstance(s, TextFrame) and s.text == "hello" for s in sent) + assert any(isinstance(s, _Payload) and s.value == 5 for s in sent) - def __await__(self): - return stop_event.wait().__await__() + asyncio.run(run()) - tasks[("node", "payload")] = _Task() - await stop_event.wait() - monkeypatch.setattr(ui_controller, "_watch_ui_channels", fake_watch_ui_channels) +def test_encode_output_formats() -> None: + # bytes → encode_binary + binary = encode_binary("n1", "video", b"\xff") + key, payload = decode_binary(binary) + assert key == ("n1", "video") + assert payload == b"\xff" - await ui_controller.ui_ws(ws) + # BaseModel → encode_json + env = encode_json("n1", "data", _Payload(value=9)) + assert env["payload"] == {"value": 9} - assert ws.accepted is True - assert isinstance(manager.sent_payloads[0], _Payload) - assert manager.sent_payloads[0].value == 5 - assert isinstance(manager.sent_payloads[1], TextFrame) - assert manager.sent_payloads[1].get() == "hello" - assert manager.sent_payloads[2] == 3 - assert cancelled["task"] is True + # TextFrame → encode_json + env2 = encode_json("n1", "text", TextFrame.new(text="hi")) + assert env2["payload"] == "hi" - asyncio.run(run_watch_ui_channels()) - asyncio.run(run_ui_ws()) + # plain → encode_json + env3 = encode_json("n1", "num", 42) + assert env3["payload"] == 42 diff --git a/backend/tests/core/test_component_graph.py b/backend/tests/core/test_component_graph.py index b9bb25ff..2c9e4808 100644 --- a/backend/tests/core/test_component_graph.py +++ b/backend/tests/core/test_component_graph.py @@ -116,8 +116,8 @@ def test_composite_component_and_graph_manager(monkeypatch) -> None: assert gm.get_node_output("n1") assert gm.get_node_input("n2") gm.run() - assert gm.ui_senders() == {} - assert gm.ui_receivers() == {} + assert gm.ui_input_slots() == [] + assert gm.ui_output_slots() == [] gm.stop() gm.add_edge( diff --git a/backend/tests/core/test_frames_graph.py b/backend/tests/core/test_frames_graph.py index 7d9a2b21..6422e012 100644 --- a/backend/tests/core/test_frames_graph.py +++ b/backend/tests/core/test_frames_graph.py @@ -254,10 +254,11 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: calls: list[str] = [] gm2.stop = lambda: calls.append("stop") # type: ignore[method-assign] gm2.run = lambda: calls.append("run") # type: ignore[method-assign] - updated = gm2.update_primitive_node_init_args("running", {}) + updated, was_running = gm2.update_primitive_node_init_args("running", {}) assert updated is running_node - assert calls == ["stop", "run"] - assert gm2.update_primitive_node_init_args("missing", {}) is None + assert was_running is True + assert calls == ["stop"] + assert gm2.update_primitive_node_init_args("missing", {}) == (None, False) gm3 = GraphManager( Graph( @@ -270,7 +271,7 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: "registered_subclasses", classmethod(lambda cls: {"Known": _CompositeInner}), ) - assert gm3.update_primitive_node_init_args("running", {}) is None + assert gm3.update_primitive_node_init_args("running", {}) == (None, False) nodes = { "a": Node(id_="a", type="Known", init_args={}), @@ -301,12 +302,16 @@ def test_graph_manager_additional_paths(monkeypatch) -> None: gm5 = GraphManager(Graph(nodes={"wrap": composite_node}, edges=[])) assert isinstance(gm5.component("wrap"), CompositeComponent) + from src.api.ui.bridge import UIChannelBridge + ui_node = Node(id_="ui", type="Known", init_args={}) gm6 = GraphManager(Graph(nodes={"ui": ui_node}, edges=[])) gm6._channel_map.clear() - gm6.run() - assert ("ui", "ui_text") in gm6.ui_senders() - assert ("ui", "ui_text") in gm6.ui_receivers() + bridge = UIChannelBridge() + recv_over, send_over = bridge.wire(gm6) + gm6.run(recv_over, send_over) + assert ("ui", "ui_text") in bridge._ui_senders + assert ("ui", "ui_text") in bridge._ui_receivers inner = gm6.component("ui") assert isinstance(inner, _CompositeInner) assert inner.started_inputs is not None and inner.started_inputs.maybe is None diff --git a/backend/tests/test_graph_manager.py b/backend/tests/test_graph_manager.py index 5bdaf6dd..78fa268e 100644 --- a/backend/tests/test_graph_manager.py +++ b/backend/tests/test_graph_manager.py @@ -136,18 +136,23 @@ def test_add_and_delete_edge(self): target_slot="data", ) gm.add_edge(edge) + gm.run() - # After adding: receiver handle exists for the target slot + # After run: receiver handle exists for the target slot assert (id_b, "data") in gm.receiver_handles() # Sender handle for source slot should be wired to at least one channel sender = gm.sender_handles().get((id_a, "data")) assert sender is not None assert len(sender._channels) >= 1 + gm.stop() gm.delete_edge(edge) + gm.run() - # After deleting: receiver handle should be gone - assert (id_b, "data") not in gm.receiver_handles() + # After deleting and re-running: sender has no channels (unconnected) + sender = gm.sender_handles().get((id_a, "data")) + assert sender is not None + assert len(sender._channels) == 0 class TestDeleteNodeRemovesEdges: @@ -418,5 +423,6 @@ def test_update_nonexistent_node_position(self): def test_update_nonexistent_node_init_args(self): gm = GraphManager(_empty_graph()) - result = gm.update_primitive_node_init_args("ghost", {"key": "value"}) - assert result is None + node, was_running = gm.update_primitive_node_init_args("ghost", {"key": "value"}) + assert node is None + assert was_running is False From 8534dc33c735f58fc9578230b3f408600693c73e Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sat, 4 Apr 2026 19:28:40 -0400 Subject: [PATCH 06/21] style: ruff format Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/api/graph/node/controller.py | 4 ++- backend/src/api/ui/bridge.py | 37 ++++++++++++++++----- backend/src/core/component.py | 10 ++++-- backend/src/core/graph.py | 18 +++++++--- backend/src/core/utils.py | 2 +- backend/tests/api/test_api_controllers.py | 5 ++- backend/tests/api/test_api_services.py | 7 ++-- backend/tests/api/test_node_service.py | 5 ++- backend/tests/api/test_ui_controller.py | 40 ++++++++++++++++------- backend/tests/test_graph_manager.py | 4 ++- 10 files changed, 99 insertions(+), 33 deletions(-) diff --git a/backend/src/api/graph/node/controller.py b/backend/src/api/graph/node/controller.py index 49821baa..6469e02e 100644 --- a/backend/src/api/graph/node/controller.py +++ b/backend/src/api/graph/node/controller.py @@ -86,7 +86,9 @@ def update_primitive_node_init_args( manager: GraphManager = Depends(get_manager), ui_bridge: UIChannelBridge = Depends(get_ui_bridge), ) -> NodeResponse: - node = service.update_primitive_node_init_args(manager, ui_bridge, node_id, req.init_args) + node = service.update_primitive_node_init_args( + manager, ui_bridge, node_id, req.init_args + ) if node is None: raise HTTPException(status_code=404, detail=f"Node not found: {node_id}") return _node_response(node_id, node, manager) diff --git a/backend/src/api/ui/bridge.py b/backend/src/api/ui/bridge.py index c10cba70..54e74910 100644 --- a/backend/src/api/ui/bridge.py +++ b/backend/src/api/ui/bridge.py @@ -27,9 +27,12 @@ # -- Wire format: encode/decode pairs -- + def encode_binary(node_id: str, slot: str, payload: bytes) -> bytes: """Pack a binary UI message: 2-byte header len + JSON header + payload.""" - header = json.dumps({"type": "ui_output", "node_id": node_id, "channel": slot}).encode() + header = json.dumps( + {"type": "ui_output", "node_id": node_id, "channel": slot} + ).encode() return struct.pack(">H", len(header)) + header + payload @@ -52,7 +55,12 @@ def encode_json(node_id: str, slot: str, item: Any) -> dict[str, Any]: payload = item.text else: payload = item - return {"type": "ui_output", "node_id": node_id, "channel": slot, "payload": payload} + return { + "type": "ui_output", + "node_id": node_id, + "channel": slot, + "payload": payload, + } def decode_json(text: str) -> tuple[SenderKey, Any] | None: @@ -65,15 +73,24 @@ def decode_json(text: str) -> tuple[SenderKey, Any] | None: def deserialize_payload(payload: Any, inner_type: type | None) -> Any: """Convert a raw JSON payload to the expected type.""" - if inner_type is not None and issubclass(inner_type, BaseModel) and isinstance(payload, dict): + if ( + inner_type is not None + and issubclass(inner_type, BaseModel) + and isinstance(payload, dict) + ): return inner_type.model_validate(payload) - if inner_type is not None and hasattr(inner_type, "new") and isinstance(payload, str): + if ( + inner_type is not None + and hasattr(inner_type, "new") + and isinstance(payload, str) + ): return inner_type.new(text=payload) return payload # -- Bridge -- + class UIChannelBridge: def __init__(self) -> None: self._ui_senders: dict[SenderKey, Sender[Any]] = {} @@ -99,9 +116,7 @@ def wire( for node_id, comp in manager.components().items(): stop_event = ( - comp.stop_event - if hasattr(comp, "stop_event") - else threading.Event() + comp.stop_event if hasattr(comp, "stop_event") else threading.Event() ) for slot, slot_type in comp.get_ui_input_types().items(): @@ -150,8 +165,12 @@ def _send_msgs(self, ws: WebSocket) -> None: ) async def _send_msg_task( - self, ws: WebSocket, node_id: str, slot: str, - receiver: Receiver[Any], inner_type: type | None, + self, + ws: WebSocket, + node_id: str, + slot: str, + receiver: Receiver[Any], + inner_type: type | None, ) -> None: try: while not self._stop_event.is_set(): diff --git a/backend/src/core/component.py b/backend/src/core/component.py index 8e555b46..e56ee5ea 100644 --- a/backend/src/core/component.py +++ b/backend/src/core/component.py @@ -396,7 +396,9 @@ def emit(self, outputs: E) -> None: ... # --------------------------------------------------------------------------- -class CompositeComponent(Component[tuple[Receiver[Any] | None, ...], tuple[Sender[Any] | None, ...]]): +class CompositeComponent( + Component[tuple[Receiver[Any] | None, ...], tuple[Sender[Any] | None, ...]] +): """A composite component wrapping a subgraph. Its interface is derived from unmatched ports in the subgraph. @@ -520,7 +522,11 @@ def get_ui_input_types(self) -> dict[str, type]: def get_ui_output_types(self) -> dict[str, type]: return {} - def start(self, inputs: tuple[Receiver[Any] | None, ...], outputs: tuple[Sender[Any] | None, ...]) -> None: + def start( + self, + inputs: tuple[Receiver[Any] | None, ...], + outputs: tuple[Sender[Any] | None, ...], + ) -> None: if self.status == Status.RUNNING: return self._status = Status.SETUP diff --git a/backend/src/core/graph.py b/backend/src/core/graph.py index bed3cf43..b0549a59 100644 --- a/backend/src/core/graph.py +++ b/backend/src/core/graph.py @@ -30,7 +30,9 @@ class Node(BaseModel): y: float = 0.0 sub_graph: Graph | None = None senders: dict[str, Sender[Any] | None] = Field(default_factory=dict, exclude=True) - receivers: dict[str, Receiver[Any] | None] = Field(default_factory=dict, exclude=True) + receivers: dict[str, Receiver[Any] | None] = Field( + default_factory=dict, exclude=True + ) class Edge(BaseModel): @@ -58,7 +60,9 @@ def __init__(self, graph: Graph) -> None: # --- node CRUD --- - def add_primitive_node(self, type_: str, init_args: dict[str, Any]) -> tuple[str, Node]: + def add_primitive_node( + self, type_: str, init_args: dict[str, Any] + ) -> tuple[str, Node]: classes = PrimitiveComponent.registered_subclasses() cls = classes.get(type_) if cls is None: @@ -324,7 +328,11 @@ def run( ui_input_slots = comp.get_ui_input_types() ui_output_slots = comp.get_ui_output_types() - stop_event = comp.stop_event if isinstance(comp, ThreadedComponent) else threading.Event() + stop_event = ( + comp.stop_event + if isinstance(comp, ThreadedComponent) + else threading.Event() + ) # Create fresh handles from plan, store on node. # Overrides take priority over the plan. @@ -332,7 +340,9 @@ def run( if (node_id, slot) in _recv_over: node.receivers[slot] = _recv_over[(node_id, slot)] elif (node_id, slot) in receiver_plan: - node.receivers[slot] = Receiver(receiver_plan[(node_id, slot)], stop_event) + node.receivers[slot] = Receiver( + receiver_plan[(node_id, slot)], stop_event + ) else: node.receivers[slot] = None diff --git a/backend/src/core/utils.py b/backend/src/core/utils.py index 16304369..8a3876ea 100644 --- a/backend/src/core/utils.py +++ b/backend/src/core/utils.py @@ -13,7 +13,7 @@ import numpy as np -SenderKey = tuple[str, str] # (node_id, slot_name) +SenderKey = tuple[str, str] # (node_id, slot_name) ReceiverKey = tuple[str, str] # (node_id, slot_name) _COUNTS: collections.defaultdict[str, itertools.count[int]] = collections.defaultdict( diff --git a/backend/tests/api/test_api_controllers.py b/backend/tests/api/test_api_controllers.py index 082bdeba..669416af 100644 --- a/backend/tests/api/test_api_controllers.py +++ b/backend/tests/api/test_api_controllers.py @@ -196,12 +196,15 @@ def test_run_save_metrics_project_controllers(monkeypatch, tmp_path) -> None: called = {} monkeypatch.setattr( - run_controller.service, "start_all", lambda m, b: called.setdefault("start", True) + run_controller.service, + "start_all", + lambda m, b: called.setdefault("start", True), ) monkeypatch.setattr( run_controller.service, "stop_all", lambda m: called.setdefault("stop", True) ) from src.api.ui.bridge import UIChannelBridge + bridge = UIChannelBridge() run_controller.start_all(manager, bridge) run_controller.stop_all(manager) diff --git a/backend/tests/api/test_api_services.py b/backend/tests/api/test_api_services.py index 5acbc7f8..1cd3c9ce 100644 --- a/backend/tests/api/test_api_services.py +++ b/backend/tests/api/test_api_services.py @@ -64,8 +64,10 @@ def components(self): _no_ui = lambda: {} # noqa: E731 return { "a": types.SimpleNamespace( - type_="A", status=types.SimpleNamespace(value="running"), - get_ui_input_types=_no_ui, get_ui_output_types=_no_ui + type_="A", + status=types.SimpleNamespace(value="running"), + get_ui_input_types=_no_ui, + get_ui_output_types=_no_ui, ) } @@ -202,6 +204,7 @@ def test_edge_run_save_services(tmp_path, monkeypatch) -> None: edge_service.delete_edge(manager, "a", "out", "b", "in") from src.api.ui.bridge import UIChannelBridge + bridge = UIChannelBridge() run_service.start_all(manager, bridge) run_service.stop_all(manager) diff --git a/backend/tests/api/test_node_service.py b/backend/tests/api/test_node_service.py index c06ef478..34ea67e1 100644 --- a/backend/tests/api/test_node_service.py +++ b/backend/tests/api/test_node_service.py @@ -89,7 +89,10 @@ def test_node_service_basic_crud() -> None: node_service.delete_node(m, "n2") assert "n2" not in m.graph.nodes from src.api.ui.bridge import UIChannelBridge - out = node_service.update_primitive_node_init_args(m, UIChannelBridge(), "n1", {"k": 1}) + + out = node_service.update_primitive_node_init_args( + m, UIChannelBridge(), "n1", {"k": 1} + ) assert out.init_args == {"k": 1} diff --git a/backend/tests/api/test_ui_controller.py b/backend/tests/api/test_ui_controller.py index dc993228..29a93b87 100644 --- a/backend/tests/api/test_ui_controller.py +++ b/backend/tests/api/test_ui_controller.py @@ -111,7 +111,10 @@ def test_wire_format_encode_decode() -> None: # JSON round-trip envelope = encode_json("n1", "text", "hello") assert envelope == { - "type": "ui_output", "node_id": "n1", "channel": "text", "payload": "hello" + "type": "ui_output", + "node_id": "n1", + "channel": "text", + "payload": "hello", } # JSON with BaseModel @@ -123,7 +126,11 @@ def test_wire_format_encode_decode() -> None: assert env_text["payload"] == "hi" # decode_json - result = decode_json(json.dumps({"type": "ui_input", "node_id": "n1", "channel": "c", "payload": "x"})) + result = decode_json( + json.dumps( + {"type": "ui_input", "node_id": "n1", "channel": "c", "payload": "x"} + ) + ) assert result == (("n1", "c"), "x") assert decode_json(json.dumps({"type": "ignored"})) is None @@ -177,18 +184,29 @@ async def run() -> None: sent: list[object] = [] for sender in bridge._ui_senders.values(): original_send = sender.send - sender.send = lambda item, _orig=original_send: (sent.append(item), _orig(item)) # type: ignore[method-assign] + sender.send = lambda item, _orig=original_send: ( + sent.append(item), + _orig(item), + ) # type: ignore[method-assign] ws = _FakeWebSocket( messages=[ - json.dumps({ - "type": "ui_input", "node_id": "node", - "channel": "text_in", "payload": "hello", - }), - json.dumps({ - "type": "ui_input", "node_id": "node", - "channel": "payload_in", "payload": {"value": 5}, - }), + json.dumps( + { + "type": "ui_input", + "node_id": "node", + "channel": "text_in", + "payload": "hello", + } + ), + json.dumps( + { + "type": "ui_input", + "node_id": "node", + "channel": "payload_in", + "payload": {"value": 5}, + } + ), json.dumps({"type": "ignored"}), ], ) diff --git a/backend/tests/test_graph_manager.py b/backend/tests/test_graph_manager.py index 78fa268e..e3168da8 100644 --- a/backend/tests/test_graph_manager.py +++ b/backend/tests/test_graph_manager.py @@ -423,6 +423,8 @@ def test_update_nonexistent_node_position(self): def test_update_nonexistent_node_init_args(self): gm = GraphManager(_empty_graph()) - node, was_running = gm.update_primitive_node_init_args("ghost", {"key": "value"}) + node, was_running = gm.update_primitive_node_init_args( + "ghost", {"key": "value"} + ) assert node is None assert was_running is False From 5e3619f06ba472cafd070922e2d8b6c4e53a213d Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 00:12:19 -0400 Subject: [PATCH 07/21] fix: bridge.run() stores event loop, wire() schedules tasks via call_soon_threadsafe - run(ws) blocks for WS lifetime, stores loop for cross-thread scheduling - wire() creates channels from sync thread, schedules task spawning on event loop - Handles all orderings: WS before start, start before WS, config change while running - WS disconnect cleans up all tasks in run()'s finally block Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/api/ui/bridge.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/backend/src/api/ui/bridge.py b/backend/src/api/ui/bridge.py index 54e74910..cb8bde80 100644 --- a/backend/src/api/ui/bridge.py +++ b/backend/src/api/ui/bridge.py @@ -97,7 +97,7 @@ def __init__(self) -> None: self._ui_receivers: dict[ReceiverKey, Receiver[Any]] = {} self._manager: GraphManager | None = None self._ws: WebSocket | None = None - self._stop_event: asyncio.Event = asyncio.Event() + self._loop: asyncio.AbstractEventLoop | None = None self._send_tasks: dict[SenderKey, asyncio.Task[None]] = {} def wire( @@ -106,7 +106,11 @@ def wire( dict[ReceiverKey, Receiver[Any] | None], dict[SenderKey, Sender[Any] | None], ]: - """Create UI channels for all components and return overrides for run().""" + """Create UI channels for all components and return overrides for run(). + + Called from a sync thread (start endpoint). If a WebSocket is + connected, schedules task (re)spawning on the event loop. + """ self._manager = manager self._ui_senders.clear() self._ui_receivers.clear() @@ -131,22 +135,25 @@ def wire( send_overrides[(node_id, slot)] = origin(ch) self._ui_receivers[(node_id, slot)] = Receiver(ch, threading.Event()) - # Restart outbound tasks if WS is connected - if self._ws is not None: - self._send_msgs(self._ws) + if self._ws is not None and self._loop is not None: + self._loop.call_soon_threadsafe(self._start_send_tasks) return recv_overrides, send_overrides async def run(self, ws: WebSocket) -> None: - """Own the WebSocket: read inbound, send outbound via tasks.""" + """Own the WebSocket lifecycle. Blocks until disconnect.""" self._ws = ws - self._stop_event.clear() - self._send_msgs(ws) + self._loop = asyncio.get_running_loop() + + # If pipeline already started (receivers exist), start tasks now + if self._ui_receivers: + self._start_send_tasks() + try: await self._recv_msgs(ws) finally: self._ws = None - self._stop_event.set() + self._loop = None for t in self._send_tasks.values(): t.cancel() await asyncio.gather(*self._send_tasks.values(), return_exceptions=True) @@ -154,9 +161,13 @@ async def run(self, ws: WebSocket) -> None: # -- Outbound: component → frontend -- - def _send_msgs(self, ws: WebSocket) -> None: + def _start_send_tasks(self) -> None: + """Spawn one task per UI output receiver. Must run in the event loop.""" for key in list(self._send_tasks): self._send_tasks.pop(key).cancel() + ws = self._ws + if ws is None: + return for key, receiver in self._ui_receivers.items(): node_id, slot = key inner_type = self._resolve_ui_output_type(node_id, slot) @@ -173,7 +184,7 @@ async def _send_msg_task( inner_type: type | None, ) -> None: try: - while not self._stop_event.is_set(): + while True: item = await asyncio.to_thread(next, receiver) if item is None: break From ec43e41547f11bf2b8fd9e1803fe6a28f9dc5a3b Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 00:27:36 -0400 Subject: [PATCH 08/21] fix: handle WS disconnect in recv loop, fix event loop issue in wire() Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/src/api/ui/bridge.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/src/api/ui/bridge.py b/backend/src/api/ui/bridge.py index cb8bde80..d999f8cd 100644 --- a/backend/src/api/ui/bridge.py +++ b/backend/src/api/ui/bridge.py @@ -200,6 +200,8 @@ async def _send_msg_task( async def _recv_msgs(self, ws: WebSocket) -> None: while True: ws_msg = await ws.receive() + if ws_msg.get("type") == "websocket.disconnect": + break if "bytes" in ws_msg: key, payload = decode_binary(ws_msg["bytes"]) if key is not None: From 47897b16d3614045ad1b9ed935fdab4a6a5ce54d Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 00:30:04 -0400 Subject: [PATCH 09/21] fix: sidebar groups by functionality then IO, fix tag names (vision/motion) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Sidebar now groups by functionality (Audio/Vision/LLM/Motion/Misc) at top level - IO tags (Sources/Conduits/Sinks) shown as subsections within each group - Fix tag names: video→vision, image/movement→motion to match backend - Add null guard for unknown functionality tags in InfoPanel Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/src/components/graph/NodeSidebar.tsx | 661 +++++++++++------- frontend/src/styles/globals.css | 10 +- 2 files changed, 427 insertions(+), 244 deletions(-) diff --git a/frontend/src/components/graph/NodeSidebar.tsx b/frontend/src/components/graph/NodeSidebar.tsx index f0c2ee90..7d5b4f54 100644 --- a/frontend/src/components/graph/NodeSidebar.tsx +++ b/frontend/src/components/graph/NodeSidebar.tsx @@ -1,7 +1,30 @@ import { useState, useRef } from "react"; -import { Mic, AudioLines, MessageSquareText, Brain, Volume2, Radio, Speaker, Video, Monitor, Play, Camera, Puzzle, FolderOpen, ChevronDown, ChevronRight, Search, X } from "lucide-react"; +import { + Mic, + AudioLines, + MessageSquareText, + Brain, + Volume2, + Radio, + Speaker, + Video, + Monitor, + Play, + Camera, + Puzzle, + FolderOpen, + ChevronDown, + ChevronRight, + Search, + X, +} from "lucide-react"; import { cn } from "@/lib/utils"; -import type { ComponentInfo, IOTag, FunctionalityTag, GPUTag } from "@/lib/types"; +import type { + ComponentInfo, + IOTag, + FunctionalityTag, + GPUTag, +} from "@/lib/types"; const iconMap: Record> = { Mic, @@ -17,6 +40,24 @@ const iconMap: Record> = { VideoStream: Monitor, }; +const funcLabels: Record = { + audio: "Audio", + vision: "Vision", + llm: "LLM", + motion: "Motion", + misc: "Misc", + other: "Other", +}; + +const funcOrder: FunctionalityTag[] = [ + "audio", + "vision", + "llm", + "motion", + "misc", + "other", +]; + const ioLabels: Record = { source: "Sources", conduit: "Conduits", @@ -25,49 +66,94 @@ const ioLabels: Record = { const ioOrder: IOTag[] = ["source", "conduit", "sink"]; -const funcLabels: Record = { - audio: "Audio", - video: "Video", - llm: "LLM", - image: "Image", - movement: "Movement", - misc: "Misc", - other: "Other", -}; - const ioAccent: Record = { source: "text-source", conduit: "text-conduit", sink: "text-sink", }; -const ioTagColors: Record = { - source: { bg: "bg-source/15", text: "text-source", border: "border-source/30" }, - conduit: { bg: "bg-conduit/15", text: "text-conduit", border: "border-conduit/30" }, +const ioTagColors: Record< + IOTag, + { bg: string; text: string; border: string } +> = { + source: { + bg: "bg-source/15", + text: "text-source", + border: "border-source/30", + }, + conduit: { + bg: "bg-conduit/15", + text: "text-conduit", + border: "border-conduit/30", + }, sink: { bg: "bg-sink/15", text: "text-sink", border: "border-sink/30" }, }; -const funcTagColors: Record = { - audio: { bg: "bg-tag-audio/15", text: "text-tag-audio", border: "border-tag-audio/30" }, - video: { bg: "bg-tag-video/15", text: "text-tag-video", border: "border-tag-video/30" }, - llm: { bg: "bg-tag-llm/15", text: "text-tag-llm", border: "border-tag-llm/30" }, - image: { bg: "bg-tag-image/15", text: "text-tag-image", border: "border-tag-image/30" }, - movement: { bg: "bg-tag-movement/15", text: "text-tag-movement", border: "border-tag-movement/30" }, - misc: { bg: "bg-tag-misc/15", text: "text-tag-misc", border: "border-tag-misc/30" }, - other: { bg: "bg-tag-other/15", text: "text-tag-other", border: "border-tag-other/30" }, +const funcTagColors: Record< + FunctionalityTag, + { bg: string; text: string; border: string } +> = { + audio: { + bg: "bg-tag-audio/15", + text: "text-tag-audio", + border: "border-tag-audio/30", + }, + vision: { + bg: "bg-tag-vision/15", + text: "text-tag-vision", + border: "border-tag-vision/30", + }, + llm: { + bg: "bg-tag-llm/15", + text: "text-tag-llm", + border: "border-tag-llm/30", + }, + motion: { + bg: "bg-tag-motion/15", + text: "text-tag-motion", + border: "border-tag-motion/30", + }, + misc: { + bg: "bg-tag-misc/15", + text: "text-tag-misc", + border: "border-tag-misc/30", + }, + other: { + bg: "bg-tag-other/15", + text: "text-tag-other", + border: "border-tag-other/30", + }, }; -const gpuTagColors: Record = { - nvidia: { bg: "bg-tag-nvidia/15", text: "text-tag-nvidia", border: "border-tag-nvidia/30" }, - apple: { bg: "bg-tag-apple/15", text: "text-tag-apple", border: "border-tag-apple/30" }, - intel: { bg: "bg-tag-intel/15", text: "text-tag-intel", border: "border-tag-intel/30" }, - amd: { bg: "bg-tag-amd/15", text: "text-tag-amd", border: "border-tag-amd/30" }, +const gpuTagColors: Record< + string, + { bg: string; text: string; border: string } +> = { + nvidia: { + bg: "bg-tag-nvidia/15", + text: "text-tag-nvidia", + border: "border-tag-nvidia/30", + }, + apple: { + bg: "bg-tag-apple/15", + text: "text-tag-apple", + border: "border-tag-apple/30", + }, + intel: { + bg: "bg-tag-intel/15", + text: "text-tag-intel", + border: "border-tag-intel/30", + }, + amd: { + bg: "bg-tag-amd/15", + text: "text-tag-amd", + border: "border-tag-amd/30", + }, }; /** Render simple inline markdown: **bold**, *italic*, `code` */ function InlineMarkdown({ text }: { text: string }) { const parts: React.ReactNode[] = []; - // Match **bold**, *italic*, `code` const regex = /(\*\*(.+?)\*\*|\*(.+?)\*|`(.+?)`)/g; let lastIndex = 0; let match: RegExpExecArray | null; @@ -78,11 +164,26 @@ function InlineMarkdown({ text }: { text: string }) { parts.push(text.slice(lastIndex, match.index)); } if (match[2]) { - parts.push({match[2]}); + parts.push( + + {match[2]} + , + ); } else if (match[3]) { - parts.push({match[3]}); + parts.push( + + {match[3]} + , + ); } else if (match[4]) { - parts.push({match[4]}); + parts.push( + + {match[4]} + , + ); } lastIndex = match.index + match[0].length; } @@ -97,12 +198,15 @@ interface NodeSidebarProps { currentProject: string; } -/** Group components by IO → Functionality. Components with no tags go into `untagged`. */ +/** Group components by Functionality → IO. */ function groupComponents(components: ComponentInfo[]) { - const groups: Record> = { - source: {} as Record, - conduit: {} as Record, - sink: {} as Record, + const groups: Record> = { + audio: {} as Record, + vision: {} as Record, + llm: {} as Record, + motion: {} as Record, + misc: {} as Record, + other: {} as Record, }; const untagged: ComponentInfo[] = []; @@ -115,10 +219,10 @@ function groupComponents(components: ComponentInfo[]) { continue; } - for (const io of ioTags) { - for (const func of funcTags) { - if (!groups[io][func]) groups[io][func] = []; - groups[io][func].push(comp); + for (const func of funcTags) { + for (const io of ioTags) { + if (!groups[func][io]) groups[func][io] = []; + groups[func][io].push(comp); } } } @@ -126,7 +230,15 @@ function groupComponents(components: ComponentInfo[]) { return { groups, untagged }; } -function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: React.RefObject; y: number }) { +function InfoPanel({ + item, + sidebarRef, + y, +}: { + item: ComponentInfo; + sidebarRef: React.RefObject; + y: number; +}) { const inputs = Object.entries(item.inputs); const outputs = Object.entries(item.outputs); const rect = sidebarRef.current?.getBoundingClientRect(); @@ -143,42 +255,77 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R )} style={{ left: rect.right + 8, top: Math.max(8, rect.top + y - 40) }} > - {/* Type */}
{item.type_}
- {/* Description */} {item.description && (
)} - {/* Tags */}
{item.tags.io.map((t) => { const colors = ioTagColors[t]; return ( - {t} + + {t} + ); })} {item.tags.functionality.map((t) => { const colors = funcTagColors[t]; + if (!colors) return null; return ( - {t} - ); - })} - {item.tags.gpu.filter((g: GPUTag) => g !== "cpu").map((t: GPUTag) => { - const colors = gpuTagColors[t] ?? { bg: "bg-white/[0.06]", text: "text-white/60", border: "border-white/10" }; - return ( - {t} + + {t} + ); })} + {item.tags.gpu + .filter((g: GPUTag) => g !== "cpu") + .map((t: GPUTag) => { + const colors = gpuTagColors[t] ?? { + bg: "bg-white/[0.06]", + text: "text-white/60", + border: "border-white/10", + }; + return ( + + {t} + + ); + })}
- {/* IO */} {inputs.length > 0 && (
-
Inputs
+
+ Inputs +
{inputs.map(([k, v]) => (
{k}: @@ -189,7 +336,9 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R )} {outputs.length > 0 && (
-
Outputs
+
+ Outputs +
{outputs.map(([k, v]) => (
{k}: @@ -199,12 +348,15 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R
)} - {/* Init params */} {Object.keys(item.init).length > 0 && (
-
Config
+
+ Config +
{Object.keys(item.init).map((k) => ( -
{k}
+
+ {k} +
))}
)} @@ -212,9 +364,15 @@ function InfoPanel({ item, sidebarRef, y }: { item: ComponentInfo; sidebarRef: R ); } -export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { +export function NodeSidebar({ + components, + currentProject, +}: NodeSidebarProps) { const [collapsed, setCollapsed] = useState>({}); - const [hovered, setHovered] = useState<{ item: ComponentInfo; y: number } | null>(null); + const [hovered, setHovered] = useState<{ + item: ComponentInfo; + y: number; + } | null>(null); const [search, setSearch] = useState(""); const hoverTimeout = useRef>(); const sidebarRef = useRef(null); @@ -241,7 +399,8 @@ export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { function onItemEnter(e: React.MouseEvent, item: ComponentInfo) { clearTimeout(hoverTimeout.current); const sidebarRect = sidebarRef.current?.getBoundingClientRect(); - const y = e.currentTarget.getBoundingClientRect().top - (sidebarRect?.top ?? 0); + const y = + e.currentTarget.getBoundingClientRect().top - (sidebarRect?.top ?? 0); hoverTimeout.current = setTimeout(() => setHovered({ item, y }), 300); } @@ -251,7 +410,9 @@ export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { } const primitives = components.filter((c) => !c.is_composite); - const composites = components.filter((c) => c.is_composite && c.type_ !== currentProject); + const composites = components.filter( + (c) => c.is_composite && c.type_ !== currentProject, + ); const query = search.toLowerCase(); const filtered = query ? primitives.filter((c) => c.type_.toLowerCase().includes(query)) @@ -263,191 +424,215 @@ export function NodeSidebar({ components, currentProject }: NodeSidebarProps) { return ( <> -
-

- Components -

-
- - setSearch(e.target.value)} - placeholder="Search..." - className={cn( - "w-full pl-7 pr-7 py-1.5 text-[11px] text-white/90 placeholder:text-muted-foreground", - "bg-white/[0.06] border border-glass-border rounded-lg", - "outline-none focus:border-white/20 transition-colors", - )} - /> - {search && ( - +
-
- {ioOrder.map((io) => { - const funcGroups = groups[io]; - const funcKeys = Object.keys(funcGroups) as FunctionalityTag[]; - if (funcKeys.length === 0) return null; - - const ioKey = `io:${io}`; - const ioCollapsed = collapsed[ioKey]; + > +

+ Components +

+
+ + setSearch(e.target.value)} + placeholder="Search..." + className={cn( + "w-full pl-7 pr-7 py-1.5 text-[11px] text-white/90 placeholder:text-muted-foreground", + "bg-white/[0.06] border border-glass-border rounded-lg", + "outline-none focus:border-white/20 transition-colors", + )} + /> + {search && ( + + )} +
+
+ {funcOrder.map((func) => { + const ioGroups = groups[func]; + const ioKeys = ioOrder.filter((io) => ioGroups[io]?.length > 0); + if (ioKeys.length === 0) return null; + + const funcKey = `func:${func}`; + const funcCollapsed = collapsed[funcKey]; + + return ( +
+ + + {!funcCollapsed && + ioKeys.map((io) => { + const items = ioGroups[io]!; + const ioKey = `${func}:${io}`; + const ioCollapsed = collapsed[ioKey]; + + return ( +
+ + + {!ioCollapsed && + items.map((item) => { + const Icon = iconMap[item.type_] ?? Puzzle; + return ( +
onDragStart(e, item)} + onMouseEnter={(e) => onItemEnter(e, item)} + onMouseLeave={onItemLeave} + className={cn( + "flex items-center gap-2.5 px-3 py-2 ml-2 cursor-grab", + "transition-all duration-200", + "hover:bg-glass-hover", + )} + > + + + {item.type_} + +
+ ); + })} +
+ ); + })} +
+ ); + })} - return ( -
+ {/* Untagged components */} + {untagged.length > 0 && ( +
- - {!ioCollapsed && funcKeys.map((func) => { - const items = funcGroups[func]!; - const funcKey = `${io}:${func}`; - const funcCollapsed = collapsed[funcKey]; - - return ( -
- - - {!funcCollapsed && items.map((item) => { - const Icon = iconMap[item.type_] ?? Puzzle; - return ( -
onDragStart(e, item)} - onMouseEnter={(e) => onItemEnter(e, item)} - onMouseLeave={onItemLeave} - className={cn( - "flex items-center gap-2.5 px-3 py-2 ml-2 cursor-grab", - "transition-all duration-200", - "hover:bg-glass-hover", - )} - > - - - {item.type_} - -
- ); - })} -
- ); - })} + + + {item.type_} + +
+ ); + })}
- ); - })} - - {/* Untagged components */} - {untagged.length > 0 && ( -
- - {!collapsed["io:untagged"] && untagged.map((item) => { - const Icon = iconMap[item.type_] ?? Puzzle; - return ( -
onDragStart(e, item)} - onMouseEnter={(e) => onItemEnter(e, item)} - onMouseLeave={onItemLeave} - className={cn( - "flex items-center gap-2.5 px-3 py-2 ml-4 cursor-grab", - "transition-all duration-200", - "hover:bg-glass-hover", - )} - > - - - {item.type_} - -
- ); - })} -
- )} + )} - {/* Projects section */} - {filteredComposites.length > 0 && ( -
-
- - {!collapsed["projects"] && filteredComposites.map((item) => ( -
onCompositeDragStart(e, item)} - onMouseEnter={(e) => onItemEnter(e, item)} - onMouseLeave={onItemLeave} - className={cn( - "flex items-center gap-2.5 px-3 py-2 ml-4 cursor-grab", - "transition-all duration-200", - "hover:bg-glass-hover", - )} + {/* Projects section */} + {filteredComposites.length > 0 && ( +
+
+
- ))} -
- )} + {collapsed["projects"] ? ( + + ) : ( + + )} + Projects + + {!collapsed["projects"] && + filteredComposites.map((item) => ( +
onCompositeDragStart(e, item)} + onMouseEnter={(e) => onItemEnter(e, item)} + onMouseLeave={onItemLeave} + className={cn( + "flex items-center gap-2.5 px-3 py-2 ml-4 cursor-grab", + "transition-all duration-200", + "hover:bg-glass-hover", + )} + > + + + {item.type_} + +
+ ))} +
+ )} +
-
- - {/* Hover info panel — rendered outside sidebar to avoid overflow clip */} - {hovered && ( -
clearTimeout(hoverTimeout.current)} - onMouseLeave={onItemLeave} - > - -
- )} - + {/* Hover info panel — rendered outside sidebar to avoid overflow clip */} + {hovered && ( +
clearTimeout(hoverTimeout.current)} + onMouseLeave={onItemLeave} + > + +
+ )} + ); } diff --git a/frontend/src/styles/globals.css b/frontend/src/styles/globals.css index 50195800..35a72730 100644 --- a/frontend/src/styles/globals.css +++ b/frontend/src/styles/globals.css @@ -42,10 +42,9 @@ --color-handle: var(--handle); --color-handle-border: var(--handle-border); --color-tag-audio: var(--tag-audio); - --color-tag-video: var(--tag-video); + --color-tag-vision: var(--tag-vision); --color-tag-llm: var(--tag-llm); - --color-tag-image: var(--tag-image); - --color-tag-movement: var(--tag-movement); + --color-tag-motion: var(--tag-motion); --color-tag-misc: var(--tag-misc); --color-tag-other: var(--tag-other); --color-tag-nvidia: var(--tag-nvidia); @@ -83,10 +82,9 @@ /* Functionality tag colors */ --tag-audio: #a78bfa; - --tag-video: #fb923c; + --tag-vision: #fb923c; --tag-llm: #22d3ee; - --tag-image: #f472b6; - --tag-movement: #34d399; + --tag-motion: #34d399; --tag-misc: #94a3b8; --tag-other: #6b7280; From c9baf40735dc4853693714736e2acc5f431780f5 Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 00:39:54 -0400 Subject: [PATCH 10/21] fix: pass onEdgesDelete through GraphCanvas to ReactFlow Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/src/components/graph/GraphCanvas.tsx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/src/components/graph/GraphCanvas.tsx b/frontend/src/components/graph/GraphCanvas.tsx index b10f5201..7b15b5b4 100644 --- a/frontend/src/components/graph/GraphCanvas.tsx +++ b/frontend/src/components/graph/GraphCanvas.tsx @@ -38,6 +38,7 @@ interface GraphCanvasProps { onDragOver?: (event: React.DragEvent) => void; onNodeDragStop?: OnNodeDrag; onSelectionChange?: OnSelectionChangeFunc; + onEdgesDelete?: (edges: Edge[]) => void; onNodeContextMenu?: (event: React.MouseEvent, node: Node) => void; onPaneClick?: () => void; } @@ -51,6 +52,7 @@ export function GraphCanvas({ onDrop, onDragOver, onNodeDragStop, + onEdgesDelete, onSelectionChange, onNodeContextMenu, onPaneClick, @@ -100,6 +102,7 @@ export function GraphCanvas({ onDrop={onDrop} onDragOver={onDragOver} onNodeDragStop={onNodeDragStop} + onEdgesDelete={onEdgesDelete} onSelectionChange={onSelectionChange} onNodeContextMenu={onNodeContextMenu} onPaneClick={onPaneClick} From 3f887f10b4ffd401651142443a711e8bea453f15 Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 00:47:50 -0400 Subject: [PATCH 11/21] fix: update profiling scripts for Receiver(channel, stop_event) signature Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/profiling/pipeline_hop_test.py | 8 ++++---- backend/profiling/ttfa_profile.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/backend/profiling/pipeline_hop_test.py b/backend/profiling/pipeline_hop_test.py index bfcca22f..fc49d386 100644 --- a/backend/profiling/pipeline_hop_test.py +++ b/backend/profiling/pipeline_hop_test.py @@ -208,20 +208,20 @@ def run_test(model: str): # Wire: FileSource -> VAD -> ASR -> Adapter -> LLM -> NullSink ch1: Channel[AudioFrame] = Channel() s1 = Sender(ch1) - r1 = Receiver(ch1) + r1 = Receiver(ch1, threading.Event()) ch2: Channel[AudioFrame] = Channel() s2 = Sender(ch2) - r2 = Receiver(ch2) + r2 = Receiver(ch2, threading.Event()) ch3: Channel[TextFrame] = Channel() ch4: Channel[list[MessageFrame]] = Channel() s4 = Sender(ch4) - r4 = Receiver(ch4) + r4 = Receiver(ch4, threading.Event()) # Adapter thread def adapter(): comp = _Adapter() comp._stop_event = threading.Event() - for text_frame in Receiver(ch3)(comp): + for text_frame in Receiver(ch3, threading.Event()): if text_frame is None: break if comp.stop_event.is_set(): diff --git a/backend/profiling/ttfa_profile.py b/backend/profiling/ttfa_profile.py index 2977e3f3..1f825956 100644 --- a/backend/profiling/ttfa_profile.py +++ b/backend/profiling/ttfa_profile.py @@ -585,7 +585,7 @@ def wrap_audio(item): def adapter3(): comp = _Stub() comp._stop_event = threading.Event() - for tf in Receiver(ch3)(comp): + for tf in Receiver(ch3, threading.Event()): if tf is None: break msgs = [ @@ -607,19 +607,19 @@ def adapter3(): vad_2 = VAD(config=VADConfig()) file_source_2 = FileSource(wav_path=wav_path) - null_sink_2.start(NullSinkInputs(audio=Receiver(ch6)), ()) + null_sink_2.start(NullSinkInputs(audio=Receiver(ch6, threading.Event())), ()) tts_2.start( - tts_mod.TTSInputs(text=Receiver(ch5)), + tts_mod.TTSInputs(text=Receiver(ch5, threading.Event())), tts_mod.TTSOutputs(audio=Sender(ch6), text=Sender()), ) llm_2.start( - llm_mod.LLMInputs(messages=Receiver(ch4)), llm_mod.LLMOutputs(token=Sender(ch5)) + llm_mod.LLMInputs(messages=Receiver(ch4, threading.Event())), llm_mod.LLMOutputs(token=Sender(ch5)) ) asr_2.start( - asr_mod.ASRInputs(audio=Receiver(ch2)), asr_mod.ASROutputs(text=Sender(ch3)) + asr_mod.ASRInputs(audio=Receiver(ch2, threading.Event())), asr_mod.ASROutputs(text=Sender(ch3)) ) vad_2.start( - vad_mod.VADInputs(audio=Receiver(ch1)), vad_mod.VADOutputs(audio=Sender(ch2)) + vad_mod.VADInputs(audio=Receiver(ch1, threading.Event())), vad_mod.VADOutputs(audio=Sender(ch2)) ) file_source_2.start((), FileSourceOutputs(audio=Sender(ch1))) From 4347722c814cf53256c90e56328cebb320ec21bf Mon Sep 17 00:00:00 2001 From: xingjianll <4396kevinliu@gmail.com> Date: Sun, 5 Apr 2026 00:49:32 -0400 Subject: [PATCH 12/21] fix: update frontend tests for sidebar grouping and edge deletion changes Co-Authored-By: Claude Opus 4.6 (1M context) --- frontend/tests/App.test.tsx | 2 +- .../tests/components/graph/NodeSidebar.test.tsx | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/frontend/tests/App.test.tsx b/frontend/tests/App.test.tsx index 8ad4eb63..0c44a85d 100644 --- a/frontend/tests/App.test.tsx +++ b/frontend/tests/App.test.tsx @@ -159,7 +159,7 @@ vi.mock("@/components/graph/GraphCanvas", () => ({ -