From 8efb590cef55acca1bb91954858d67de239748d2 Mon Sep 17 00:00:00 2001 From: Antonio Bevilacqua Date: Wed, 11 Mar 2026 15:44:08 +0100 Subject: [PATCH 1/6] fix: delete message attachments at edit --- backend/chainlit/socket.py | 32 ++++++++ backend/tests/test_socket.py | 143 +++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+) diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 740f0c276a..cadecc84b8 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -1,4 +1,5 @@ import asyncio +from concurrent.futures import thread import json from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union from urllib.parse import unquote @@ -336,6 +337,7 @@ async def edit_message(sid, payload: MessagePayload): if message.id == payload["message"]["id"]: message.content = payload["message"]["output"] await message.update() + await delete_message_children(message.id, session.thread_id) orig_message = message await context.emitter.task_start() @@ -348,6 +350,36 @@ async def edit_message(sid, payload: MessagePayload): finally: await context.emitter.task_end() +async def delete_message_children(message_id: str, thread_id: str): + data_layer = get_data_layer() + if not data_layer: + return + + thread = await data_layer.get_thread(thread_id) + if thread is None: + return + + steps = thread.get("steps", []) + elements = thread.get("elements", []) + + # Collect all descendant step IDs whose root parent is message_id + def collect_descendants(parent_id: str) -> set: + ids = set() + for step in steps: + if step.get("parentId") == parent_id: + ids.add(step["id"]) + ids |= collect_descendants(step["id"]) + return ids + + descendant_ids = collect_descendants(message_id) + + for step in steps: + if step["id"] in descendant_ids: + await data_layer.delete_step(step["id"]) + + for element in elements: + if element.get("forId") in descendant_ids: + await data_layer.delete_element(element["id"]) @sio.on("message_favorite") # pyright: ignore [reportOptionalCall] async def message_favorite(sid, payload: MessagePayload): diff --git a/backend/tests/test_socket.py b/backend/tests/test_socket.py index e45247b744..dd0ac50f84 100644 --- a/backend/tests/test_socket.py +++ b/backend/tests/test_socket.py @@ -10,6 +10,7 @@ _get_token_from_cookie, clean_session, connection_successful, + delete_message_children, load_user_env, persist_user_session, restore_existing_session, @@ -630,3 +631,145 @@ async def test_on_chat_start_not_duplicated_on_fresh_then_reconnect( await connection_successful("sid-1") assert on_chat_start.call_count == 1 + + +class TestDeleteMessageChildren: + """Test suite for delete_message_children function.""" + + @pytest.mark.asyncio + async def test_no_data_layer(self): + """Does nothing when there is no data layer.""" + with patch("chainlit.socket.get_data_layer", return_value=None): + # Should return without error + await delete_message_children("msg_1", "thread_1") + + @pytest.mark.asyncio + async def test_thread_not_found(self): + """Does nothing when the thread does not exist.""" + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = None + + with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): + await delete_message_children("msg_1", "thread_1") + + mock_data_layer.delete_step.assert_not_called() + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_no_children(self): + """Does nothing when the message has no child steps.""" + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "other_msg", "parentId": None}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): + await delete_message_children("msg_1", "thread_1") + + mock_data_layer.delete_step.assert_not_called() + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_direct_children_deleted(self): + """Deletes direct children of the message.""" + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + {"id": "child_2", "parentId": "msg_1"}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): + await delete_message_children("msg_1", "thread_1") + + deleted_ids = { + call.args[0] for call in mock_data_layer.delete_step.call_args_list + } + assert deleted_ids == {"child_1", "child_2"} + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_nested_descendants_deleted(self): + """Recursively deletes grandchildren and deeper descendants.""" + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + {"id": "grandchild_1", "parentId": "child_1"}, + {"id": "great_grandchild_1", "parentId": "grandchild_1"}, + {"id": "unrelated", "parentId": None}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): + await delete_message_children("msg_1", "thread_1") + + deleted_ids = { + call.args[0] for call in mock_data_layer.delete_step.call_args_list + } + assert deleted_ids == {"child_1", "grandchild_1", "great_grandchild_1"} + assert "msg_1" not in deleted_ids + assert "unrelated" not in deleted_ids + + @pytest.mark.asyncio + async def test_elements_linked_to_descendants_deleted(self): + """Deletes elements whose forId references a descendant step.""" + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + ], + "elements": [ + {"id": "elem_1", "forId": "child_1"}, + { + "id": "elem_2", + "forId": "msg_1", + }, # forId is the message itself — not deleted + {"id": "elem_3", "forId": "unrelated"}, + ], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): + await delete_message_children("msg_1", "thread_1") + + mock_data_layer.delete_step.assert_called_once_with("child_1") + + deleted_element_ids = { + call.args[0] for call in mock_data_layer.delete_element.call_args_list + } + assert deleted_element_ids == {"elem_1"} + + @pytest.mark.asyncio + async def test_message_itself_is_not_deleted(self): + """The root message itself is never deleted, only its descendants.""" + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): + await delete_message_children("msg_1", "thread_1") + + deleted_ids = { + call.args[0] for call in mock_data_layer.delete_step.call_args_list + } + assert "msg_1" not in deleted_ids From b074d56421e13fb50e8ef838e399b440db3e0226 Mon Sep 17 00:00:00 2001 From: Antonio Bevilacqua Date: Tue, 17 Mar 2026 12:41:53 +0100 Subject: [PATCH 2/6] fix: apply review --- backend/chainlit/message.py | 39 ++++++++++ backend/chainlit/socket.py | 33 +------- backend/tests/test_message.py | 143 ++++++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+), 32 deletions(-) diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 0700f630a7..c9310ae4c5 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -138,10 +138,49 @@ async def remove(self): raise e logger.error(f"Failed to persist message deletion: {e!s}") + await self.remove_children() await context.emitter.delete_step(step_dict) return True + async def remove_children(self): + data_layer = get_data_layer() + if not data_layer: + return + + thread = await data_layer.get_thread(self.thread_id) + if thread is None: + return + + steps = thread.get("steps", []) + elements = thread.get("elements", []) + + def collect_descendants(parent_id: str, visited: Optional[set] = None) -> list: + """Return descendant IDs in post-order (leaves first, parents last).""" + if visited is None: + visited = set() + if parent_id in visited: + return [] + visited.add(parent_id) + result = [] + for step in steps: + if step.get("parentId") == parent_id: + result.extend(collect_descendants(step["id"], visited)) + result.append(step["id"]) + return result + + # Ordered leaves-first so that referential integrity constraints are respected. + ordered_descendant_ids = collect_descendants(self.id) + descendant_id_set = set(ordered_descendant_ids) + + for step_id in ordered_descendant_ids: + await data_layer.delete_step(step_id) + + orphaned_elements = [e for e in elements if e.get("forId") in descendant_id_set] + await asyncio.gather( + *[data_layer.delete_element(e["id"]) for e in orphaned_elements] + ) + async def _create(self): step_dict = self.to_dict() data_layer = get_data_layer() diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index cadecc84b8..07c56edebd 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -337,7 +337,7 @@ async def edit_message(sid, payload: MessagePayload): if message.id == payload["message"]["id"]: message.content = payload["message"]["output"] await message.update() - await delete_message_children(message.id, session.thread_id) + await message.remove_children() orig_message = message await context.emitter.task_start() @@ -350,37 +350,6 @@ async def edit_message(sid, payload: MessagePayload): finally: await context.emitter.task_end() -async def delete_message_children(message_id: str, thread_id: str): - data_layer = get_data_layer() - if not data_layer: - return - - thread = await data_layer.get_thread(thread_id) - if thread is None: - return - - steps = thread.get("steps", []) - elements = thread.get("elements", []) - - # Collect all descendant step IDs whose root parent is message_id - def collect_descendants(parent_id: str) -> set: - ids = set() - for step in steps: - if step.get("parentId") == parent_id: - ids.add(step["id"]) - ids |= collect_descendants(step["id"]) - return ids - - descendant_ids = collect_descendants(message_id) - - for step in steps: - if step["id"] in descendant_ids: - await data_layer.delete_step(step["id"]) - - for element in elements: - if element.get("forId") in descendant_ids: - await data_layer.delete_element(element["id"]) - @sio.on("message_favorite") # pyright: ignore [reportOptionalCall] async def message_favorite(sid, payload: MessagePayload): """Handle a message favorite toggle.""" diff --git a/backend/tests/test_message.py b/backend/tests/test_message.py index 952f557414..4cc259e632 100644 --- a/backend/tests/test_message.py +++ b/backend/tests/test_message.py @@ -755,3 +755,146 @@ def test_message_to_dict_with_none_metadata(self): result = msg.to_dict() assert result["metadata"] == {} + + +class TestRemoveChildren: + """Test suite for Message.remove_children.""" + + def _make_message(self, msg_id="msg_1", thread_id="thread_1"): + with mock_chainlit_context(): + msg = Message(content="test") + msg.id = msg_id + msg.thread_id = thread_id + return msg + + @pytest.mark.asyncio + async def test_no_data_layer(self): + """Does nothing when there is no data layer.""" + msg = self._make_message() + with patch("chainlit.message.get_data_layer", return_value=None): + await msg.remove_children() + + @pytest.mark.asyncio + async def test_thread_not_found(self): + """Does nothing when the thread does not exist.""" + msg = self._make_message() + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = None + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + mock_data_layer.delete_step.assert_not_called() + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_no_children(self): + """Does nothing when the message has no child steps.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "other_msg", "parentId": None}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + mock_data_layer.delete_step.assert_not_called() + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_direct_children_deleted(self): + """Deletes direct children of the message.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + {"id": "child_2", "parentId": "msg_1"}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + deleted_ids = {call.args[0] for call in mock_data_layer.delete_step.call_args_list} + assert deleted_ids == {"child_1", "child_2"} + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_nested_descendants_deleted(self): + """Recursively deletes grandchildren and deeper descendants.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + {"id": "grandchild_1", "parentId": "child_1"}, + {"id": "great_grandchild_1", "parentId": "grandchild_1"}, + {"id": "unrelated", "parentId": None}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + deleted_ids = {call.args[0] for call in mock_data_layer.delete_step.call_args_list} + assert deleted_ids == {"child_1", "grandchild_1", "great_grandchild_1"} + assert "msg_1" not in deleted_ids + assert "unrelated" not in deleted_ids + + @pytest.mark.asyncio + async def test_elements_linked_to_descendants_deleted(self): + """Deletes elements whose forId references a descendant step.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + ], + "elements": [ + {"id": "elem_1", "forId": "child_1"}, + {"id": "elem_2", "forId": "msg_1"}, # forId is the message itself — not deleted + {"id": "elem_3", "forId": "unrelated"}, + ], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + mock_data_layer.delete_step.assert_called_once_with("child_1") + deleted_element_ids = {call.args[0] for call in mock_data_layer.delete_element.call_args_list} + assert deleted_element_ids == {"elem_1"} + + @pytest.mark.asyncio + async def test_message_itself_is_not_deleted(self): + """The root message itself is never deleted, only its descendants.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + ], + "elements": [], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + deleted_ids = {call.args[0] for call in mock_data_layer.delete_step.call_args_list} + assert "msg_1" not in deleted_ids From 1336d88346902abf8c8db1838ea70320630202ef Mon Sep 17 00:00:00 2001 From: Antonio Bevilacqua Date: Tue, 24 Mar 2026 16:08:21 +0100 Subject: [PATCH 3/6] fix: cubic review --- backend/chainlit/message.py | 8 +++++++- backend/tests/test_message.py | 29 ++++++++++++++++++++--------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index c9310ae4c5..12c4317499 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -138,7 +138,13 @@ async def remove(self): raise e logger.error(f"Failed to persist message deletion: {e!s}") - await self.remove_children() + try: + await self.remove_children() + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message children deletion: {e!s}") + await context.emitter.delete_step(step_dict) return True diff --git a/backend/tests/test_message.py b/backend/tests/test_message.py index 4cc259e632..7279bf3a11 100644 --- a/backend/tests/test_message.py +++ b/backend/tests/test_message.py @@ -825,8 +825,10 @@ async def test_direct_children_deleted(self): with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - deleted_ids = {call.args[0] for call in mock_data_layer.delete_step.call_args_list} - assert deleted_ids == {"child_1", "child_2"} + deleted_ids = [ + call.args[0] for call in mock_data_layer.delete_step.call_args_list + ] + assert deleted_ids == ["child_1", "child_2"] mock_data_layer.delete_element.assert_not_called() @pytest.mark.asyncio @@ -849,8 +851,10 @@ async def test_nested_descendants_deleted(self): with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - deleted_ids = {call.args[0] for call in mock_data_layer.delete_step.call_args_list} - assert deleted_ids == {"child_1", "grandchild_1", "great_grandchild_1"} + deleted_ids = [ + call.args[0] for call in mock_data_layer.delete_step.call_args_list + ] + assert deleted_ids == ["child_1", "grandchild_1", "great_grandchild_1"] assert "msg_1" not in deleted_ids assert "unrelated" not in deleted_ids @@ -865,7 +869,10 @@ async def test_elements_linked_to_descendants_deleted(self): ], "elements": [ {"id": "elem_1", "forId": "child_1"}, - {"id": "elem_2", "forId": "msg_1"}, # forId is the message itself — not deleted + { + "id": "elem_2", + "forId": "msg_1", + }, # forId is the message itself — not deleted {"id": "elem_3", "forId": "unrelated"}, ], } @@ -875,9 +882,11 @@ async def test_elements_linked_to_descendants_deleted(self): with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - mock_data_layer.delete_step.assert_called_once_with("child_1") - deleted_element_ids = {call.args[0] for call in mock_data_layer.delete_element.call_args_list} - assert deleted_element_ids == {"elem_1"} + mock_data_layer.delete_step.assert_awaited_once_with("child_1") + deleted_element_ids = [ + call.args[0] for call in mock_data_layer.delete_element.call_args_list + ] + assert deleted_element_ids == ["elem_1"] @pytest.mark.asyncio async def test_message_itself_is_not_deleted(self): @@ -896,5 +905,7 @@ async def test_message_itself_is_not_deleted(self): with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - deleted_ids = {call.args[0] for call in mock_data_layer.delete_step.call_args_list} + deleted_ids = [ + call.args[0] for call in mock_data_layer.delete_step.call_args_list + ] assert "msg_1" not in deleted_ids From 7c2e465f01fcaa635fc9c4db169ff1d24fbf5fc7 Mon Sep 17 00:00:00 2001 From: Antonio Bevilacqua Date: Tue, 24 Mar 2026 17:27:13 +0100 Subject: [PATCH 4/6] fix: unnecessary elements delete --- backend/chainlit/message.py | 7 ------- backend/tests/test_message.py | 32 +------------------------------- 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 12c4317499..37abdc0f34 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -159,7 +159,6 @@ async def remove_children(self): return steps = thread.get("steps", []) - elements = thread.get("elements", []) def collect_descendants(parent_id: str, visited: Optional[set] = None) -> list: """Return descendant IDs in post-order (leaves first, parents last).""" @@ -177,16 +176,10 @@ def collect_descendants(parent_id: str, visited: Optional[set] = None) -> list: # Ordered leaves-first so that referential integrity constraints are respected. ordered_descendant_ids = collect_descendants(self.id) - descendant_id_set = set(ordered_descendant_ids) for step_id in ordered_descendant_ids: await data_layer.delete_step(step_id) - orphaned_elements = [e for e in elements if e.get("forId") in descendant_id_set] - await asyncio.gather( - *[data_layer.delete_element(e["id"]) for e in orphaned_elements] - ) - async def _create(self): step_dict = self.to_dict() data_layer = get_data_layer() diff --git a/backend/tests/test_message.py b/backend/tests/test_message.py index 7279bf3a11..1730cd71ad 100644 --- a/backend/tests/test_message.py +++ b/backend/tests/test_message.py @@ -854,40 +854,10 @@ async def test_nested_descendants_deleted(self): deleted_ids = [ call.args[0] for call in mock_data_layer.delete_step.call_args_list ] - assert deleted_ids == ["child_1", "grandchild_1", "great_grandchild_1"] + assert deleted_ids == ["great_grandchild_1", "grandchild_1", "child_1"] assert "msg_1" not in deleted_ids assert "unrelated" not in deleted_ids - @pytest.mark.asyncio - async def test_elements_linked_to_descendants_deleted(self): - """Deletes elements whose forId references a descendant step.""" - msg = self._make_message() - thread = { - "steps": [ - {"id": "msg_1", "parentId": None}, - {"id": "child_1", "parentId": "msg_1"}, - ], - "elements": [ - {"id": "elem_1", "forId": "child_1"}, - { - "id": "elem_2", - "forId": "msg_1", - }, # forId is the message itself — not deleted - {"id": "elem_3", "forId": "unrelated"}, - ], - } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread - - with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): - await msg.remove_children() - - mock_data_layer.delete_step.assert_awaited_once_with("child_1") - deleted_element_ids = [ - call.args[0] for call in mock_data_layer.delete_element.call_args_list - ] - assert deleted_element_ids == ["elem_1"] - @pytest.mark.asyncio async def test_message_itself_is_not_deleted(self): """The root message itself is never deleted, only its descendants.""" From be7727025f1238e8731d05ad506d5c16b6eb71a6 Mon Sep 17 00:00:00 2001 From: Allaoua Benchikh Date: Thu, 30 Apr 2026 16:25:47 +0200 Subject: [PATCH 5/6] Delete elements and feedback for decendants steps --- backend/chainlit/message.py | 13 ++++ backend/chainlit/socket.py | 2 +- backend/tests/test_message.py | 51 +++++++++++- backend/tests/test_socket.py | 143 ---------------------------------- 4 files changed, 64 insertions(+), 145 deletions(-) diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 37abdc0f34..4c19ed900e 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -176,6 +176,19 @@ def collect_descendants(parent_id: str, visited: Optional[set] = None) -> list: # Ordered leaves-first so that referential integrity constraints are respected. ordered_descendant_ids = collect_descendants(self.id) + descendant_set = set(ordered_descendant_ids) + + for step in steps: + step_id = step.get("id") + feedback_id = (step.get("feedback") or {}).get("id") + if step_id in descendant_set and feedback_id: + await data_layer.delete_feedback(feedback_id) + + for element in thread.get("elements", []): + for_id = element.get("forId") + element_id = element.get("id") + if for_id in descendant_set and element_id: + await data_layer.delete_element(element_id, self.thread_id) for step_id in ordered_descendant_ids: await data_layer.delete_step(step_id) diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 07c56edebd..6da4d6102a 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -1,5 +1,4 @@ import asyncio -from concurrent.futures import thread import json from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union from urllib.parse import unquote @@ -350,6 +349,7 @@ async def edit_message(sid, payload: MessagePayload): finally: await context.emitter.task_end() + @sio.on("message_favorite") # pyright: ignore [reportOptionalCall] async def message_favorite(sid, payload: MessagePayload): """Handle a message favorite toggle.""" diff --git a/backend/tests/test_message.py b/backend/tests/test_message.py index 1730cd71ad..7dfad0578b 100644 --- a/backend/tests/test_message.py +++ b/backend/tests/test_message.py @@ -1,7 +1,7 @@ import asyncio import json from contextlib import contextmanager -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock, call, patch import pytest @@ -786,6 +786,7 @@ async def test_thread_not_found(self): mock_data_layer.delete_step.assert_not_called() mock_data_layer.delete_element.assert_not_called() + mock_data_layer.delete_feedback.assert_not_called() @pytest.mark.asyncio async def test_no_children(self): @@ -806,6 +807,7 @@ async def test_no_children(self): mock_data_layer.delete_step.assert_not_called() mock_data_layer.delete_element.assert_not_called() + mock_data_layer.delete_feedback.assert_not_called() @pytest.mark.asyncio async def test_direct_children_deleted(self): @@ -830,6 +832,7 @@ async def test_direct_children_deleted(self): ] assert deleted_ids == ["child_1", "child_2"] mock_data_layer.delete_element.assert_not_called() + mock_data_layer.delete_feedback.assert_not_called() @pytest.mark.asyncio async def test_nested_descendants_deleted(self): @@ -857,6 +860,50 @@ async def test_nested_descendants_deleted(self): assert deleted_ids == ["great_grandchild_1", "grandchild_1", "child_1"] assert "msg_1" not in deleted_ids assert "unrelated" not in deleted_ids + mock_data_layer.delete_feedback.assert_not_called() + mock_data_layer.delete_element.assert_not_called() + + @pytest.mark.asyncio + async def test_feedback_and_elements_deleted_before_steps(self): + """Removes feedback and elements for descendants before delete_step.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + { + "id": "child_1", + "parentId": "msg_1", + "feedback": { + "id": "fb_1", + "forId": "child_1", + "value": 1, + "comment": None, + }, + }, + {"id": "child_2", "parentId": "msg_1"}, + ], + "elements": [ + {"id": "el_1", "forId": "child_1", "threadId": "thread_1"}, + {"id": "el_other", "forId": "msg_1", "threadId": "thread_1"}, + {"id": "el_unrelated", "forId": "other_root", "threadId": "thread_1"}, + ], + } + mock_data_layer = AsyncMock() + mock_data_layer.get_thread.return_value = thread + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + mock_data_layer.assert_has_calls( + [ + call.get_thread("thread_1"), + call.delete_feedback("fb_1"), + call.delete_element("el_1", "thread_1"), + call.delete_step("child_1"), + call.delete_step("child_2"), + ], + any_order=False, + ) @pytest.mark.asyncio async def test_message_itself_is_not_deleted(self): @@ -879,3 +926,5 @@ async def test_message_itself_is_not_deleted(self): call.args[0] for call in mock_data_layer.delete_step.call_args_list ] assert "msg_1" not in deleted_ids + mock_data_layer.delete_feedback.assert_not_called() + mock_data_layer.delete_element.assert_not_called() diff --git a/backend/tests/test_socket.py b/backend/tests/test_socket.py index dd0ac50f84..e45247b744 100644 --- a/backend/tests/test_socket.py +++ b/backend/tests/test_socket.py @@ -10,7 +10,6 @@ _get_token_from_cookie, clean_session, connection_successful, - delete_message_children, load_user_env, persist_user_session, restore_existing_session, @@ -631,145 +630,3 @@ async def test_on_chat_start_not_duplicated_on_fresh_then_reconnect( await connection_successful("sid-1") assert on_chat_start.call_count == 1 - - -class TestDeleteMessageChildren: - """Test suite for delete_message_children function.""" - - @pytest.mark.asyncio - async def test_no_data_layer(self): - """Does nothing when there is no data layer.""" - with patch("chainlit.socket.get_data_layer", return_value=None): - # Should return without error - await delete_message_children("msg_1", "thread_1") - - @pytest.mark.asyncio - async def test_thread_not_found(self): - """Does nothing when the thread does not exist.""" - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = None - - with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): - await delete_message_children("msg_1", "thread_1") - - mock_data_layer.delete_step.assert_not_called() - mock_data_layer.delete_element.assert_not_called() - - @pytest.mark.asyncio - async def test_no_children(self): - """Does nothing when the message has no child steps.""" - thread = { - "steps": [ - {"id": "msg_1", "parentId": None}, - {"id": "other_msg", "parentId": None}, - ], - "elements": [], - } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread - - with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): - await delete_message_children("msg_1", "thread_1") - - mock_data_layer.delete_step.assert_not_called() - mock_data_layer.delete_element.assert_not_called() - - @pytest.mark.asyncio - async def test_direct_children_deleted(self): - """Deletes direct children of the message.""" - thread = { - "steps": [ - {"id": "msg_1", "parentId": None}, - {"id": "child_1", "parentId": "msg_1"}, - {"id": "child_2", "parentId": "msg_1"}, - ], - "elements": [], - } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread - - with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): - await delete_message_children("msg_1", "thread_1") - - deleted_ids = { - call.args[0] for call in mock_data_layer.delete_step.call_args_list - } - assert deleted_ids == {"child_1", "child_2"} - mock_data_layer.delete_element.assert_not_called() - - @pytest.mark.asyncio - async def test_nested_descendants_deleted(self): - """Recursively deletes grandchildren and deeper descendants.""" - thread = { - "steps": [ - {"id": "msg_1", "parentId": None}, - {"id": "child_1", "parentId": "msg_1"}, - {"id": "grandchild_1", "parentId": "child_1"}, - {"id": "great_grandchild_1", "parentId": "grandchild_1"}, - {"id": "unrelated", "parentId": None}, - ], - "elements": [], - } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread - - with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): - await delete_message_children("msg_1", "thread_1") - - deleted_ids = { - call.args[0] for call in mock_data_layer.delete_step.call_args_list - } - assert deleted_ids == {"child_1", "grandchild_1", "great_grandchild_1"} - assert "msg_1" not in deleted_ids - assert "unrelated" not in deleted_ids - - @pytest.mark.asyncio - async def test_elements_linked_to_descendants_deleted(self): - """Deletes elements whose forId references a descendant step.""" - thread = { - "steps": [ - {"id": "msg_1", "parentId": None}, - {"id": "child_1", "parentId": "msg_1"}, - ], - "elements": [ - {"id": "elem_1", "forId": "child_1"}, - { - "id": "elem_2", - "forId": "msg_1", - }, # forId is the message itself — not deleted - {"id": "elem_3", "forId": "unrelated"}, - ], - } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread - - with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): - await delete_message_children("msg_1", "thread_1") - - mock_data_layer.delete_step.assert_called_once_with("child_1") - - deleted_element_ids = { - call.args[0] for call in mock_data_layer.delete_element.call_args_list - } - assert deleted_element_ids == {"elem_1"} - - @pytest.mark.asyncio - async def test_message_itself_is_not_deleted(self): - """The root message itself is never deleted, only its descendants.""" - thread = { - "steps": [ - {"id": "msg_1", "parentId": None}, - {"id": "child_1", "parentId": "msg_1"}, - ], - "elements": [], - } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread - - with patch("chainlit.socket.get_data_layer", return_value=mock_data_layer): - await delete_message_children("msg_1", "thread_1") - - deleted_ids = { - call.args[0] for call in mock_data_layer.delete_step.call_args_list - } - assert "msg_1" not in deleted_ids From 6172e78e57218a7d8a52c6e0bfe07115e99844e9 Mon Sep 17 00:00:00 2001 From: Allaoua Benchikh Date: Wed, 6 May 2026 14:57:44 +0200 Subject: [PATCH 6/6] Improved tests --- backend/tests/test_message.py | 101 ++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 46 deletions(-) diff --git a/backend/tests/test_message.py b/backend/tests/test_message.py index 7dfad0578b..241e58a4ba 100644 --- a/backend/tests/test_message.py +++ b/backend/tests/test_message.py @@ -1,7 +1,7 @@ import asyncio import json from contextlib import contextmanager -from unittest.mock import AsyncMock, Mock, call, patch +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -767,6 +767,33 @@ def _make_message(self, msg_id="msg_1", thread_id="thread_1"): msg.thread_id = thread_id return msg + def _tracked_data_layer(self, get_thread_result): + """ + Data layer mock whose delete_* / get_thread record only when awaited + (misses missing-await bugs). ``events`` is the strict call sequence. + """ + events: list[tuple] = [] + + async def get_thread(thread_id): + events.append(("get_thread", thread_id)) + return get_thread_result + + async def delete_feedback(feedback_id): + events.append(("delete_feedback", feedback_id)) + + async def delete_element(element_id, thread_id=None): + events.append(("delete_element", element_id, thread_id)) + + async def delete_step(step_id): + events.append(("delete_step", step_id)) + + mock_layer = AsyncMock() + mock_layer.get_thread = AsyncMock(side_effect=get_thread) + mock_layer.delete_feedback = AsyncMock(side_effect=delete_feedback) + mock_layer.delete_element = AsyncMock(side_effect=delete_element) + mock_layer.delete_step = AsyncMock(side_effect=delete_step) + return mock_layer, events + @pytest.mark.asyncio async def test_no_data_layer(self): """Does nothing when there is no data layer.""" @@ -778,15 +805,12 @@ async def test_no_data_layer(self): async def test_thread_not_found(self): """Does nothing when the thread does not exist.""" msg = self._make_message() - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = None + mock_data_layer, events = self._tracked_data_layer(None) with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - mock_data_layer.delete_step.assert_not_called() - mock_data_layer.delete_element.assert_not_called() - mock_data_layer.delete_feedback.assert_not_called() + assert events == [("get_thread", "thread_1")] @pytest.mark.asyncio async def test_no_children(self): @@ -799,15 +823,12 @@ async def test_no_children(self): ], "elements": [], } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread + mock_data_layer, events = self._tracked_data_layer(thread) with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - mock_data_layer.delete_step.assert_not_called() - mock_data_layer.delete_element.assert_not_called() - mock_data_layer.delete_feedback.assert_not_called() + assert events == [("get_thread", "thread_1")] @pytest.mark.asyncio async def test_direct_children_deleted(self): @@ -821,18 +842,16 @@ async def test_direct_children_deleted(self): ], "elements": [], } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread + mock_data_layer, events = self._tracked_data_layer(thread) with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - deleted_ids = [ - call.args[0] for call in mock_data_layer.delete_step.call_args_list + assert events == [ + ("get_thread", "thread_1"), + ("delete_step", "child_1"), + ("delete_step", "child_2"), ] - assert deleted_ids == ["child_1", "child_2"] - mock_data_layer.delete_element.assert_not_called() - mock_data_layer.delete_feedback.assert_not_called() @pytest.mark.asyncio async def test_nested_descendants_deleted(self): @@ -848,20 +867,17 @@ async def test_nested_descendants_deleted(self): ], "elements": [], } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread + mock_data_layer, events = self._tracked_data_layer(thread) with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - deleted_ids = [ - call.args[0] for call in mock_data_layer.delete_step.call_args_list + assert events == [ + ("get_thread", "thread_1"), + ("delete_step", "great_grandchild_1"), + ("delete_step", "grandchild_1"), + ("delete_step", "child_1"), ] - assert deleted_ids == ["great_grandchild_1", "grandchild_1", "child_1"] - assert "msg_1" not in deleted_ids - assert "unrelated" not in deleted_ids - mock_data_layer.delete_feedback.assert_not_called() - mock_data_layer.delete_element.assert_not_called() @pytest.mark.asyncio async def test_feedback_and_elements_deleted_before_steps(self): @@ -888,22 +904,18 @@ async def test_feedback_and_elements_deleted_before_steps(self): {"id": "el_unrelated", "forId": "other_root", "threadId": "thread_1"}, ], } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread + mock_data_layer, events = self._tracked_data_layer(thread) with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - mock_data_layer.assert_has_calls( - [ - call.get_thread("thread_1"), - call.delete_feedback("fb_1"), - call.delete_element("el_1", "thread_1"), - call.delete_step("child_1"), - call.delete_step("child_2"), - ], - any_order=False, - ) + assert events == [ + ("get_thread", "thread_1"), + ("delete_feedback", "fb_1"), + ("delete_element", "el_1", "thread_1"), + ("delete_step", "child_1"), + ("delete_step", "child_2"), + ] @pytest.mark.asyncio async def test_message_itself_is_not_deleted(self): @@ -916,15 +928,12 @@ async def test_message_itself_is_not_deleted(self): ], "elements": [], } - mock_data_layer = AsyncMock() - mock_data_layer.get_thread.return_value = thread + mock_data_layer, events = self._tracked_data_layer(thread) with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): await msg.remove_children() - deleted_ids = [ - call.args[0] for call in mock_data_layer.delete_step.call_args_list + assert events == [ + ("get_thread", "thread_1"), + ("delete_step", "child_1"), ] - assert "msg_1" not in deleted_ids - mock_data_layer.delete_feedback.assert_not_called() - mock_data_layer.delete_element.assert_not_called()