diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 99dd9d15b2..8b83cb0a5d 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -84,6 +84,7 @@ on_message, on_settings_edit, on_settings_update, + on_shared_thread_access_allowed, on_shared_thread_view, on_slack_reaction_added, on_stop, @@ -205,6 +206,7 @@ def acall(self): "on_message", "on_settings_edit", "on_settings_update", + "on_shared_thread_access_allowed", "on_shared_thread_view", "on_slack_reaction_added", "on_stop", diff --git a/backend/chainlit/callbacks.py b/backend/chainlit/callbacks.py index 0cdd209d3b..de021f443d 100644 --- a/backend/chainlit/callbacks.py +++ b/backend/chainlit/callbacks.py @@ -551,3 +551,18 @@ def on_shared_thread_view( """ config.code.on_shared_thread_view = wrap_user_function(func) return func + + +def on_shared_thread_access_allowed( + func: Callable[[ThreadDict, Optional[User]], Awaitable[bool]], +) -> Callable[[ThreadDict, Optional[User]], Awaitable[bool]]: + """Hook to add extra permission check for viewing a shared thread. + + Unlike on_shared_thread_view, this callback can only deny access further. + If defined and returns False, the viewer is blocked regardless of other checks. + If undefined or returns True, normal authorization flow proceeds. + + Signature: async (thread: ThreadDict, viewer: Optional[User]) -> bool + """ + config.code.on_shared_thread_access_allowed = wrap_user_function(func) + return func diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index a540752b33..b33950a6fe 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -422,6 +422,9 @@ class CodeSettings(BaseModel): on_shared_thread_view: Optional[ Callable[["ThreadDict", Optional["User"]], Awaitable[bool]] ] = None + on_shared_thread_access_allowed: Optional[ + Callable[["ThreadDict", Optional["User"]], Awaitable[bool]] + ] = None # Auth callbacks password_auth_callback: Optional[ Callable[[str, str], Awaitable[Optional["User"]]] diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index cab8fef9e6..cd7aa4ce55 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1024,13 +1024,23 @@ async def get_shared_thread( ) except Exception: user_can_view = False - is_shared = bool(metadata.get("is_shared")) - # Proceed only raise an error if both conditions are False. if (not user_can_view) and (not is_shared): raise HTTPException(status_code=404, detail="Thread not found") + if getattr(config.code, "on_shared_thread_access_allowed", None): + try: + access_allowed = await config.code.on_shared_thread_access_allowed( + thread, current_user + ) + if not access_allowed: + raise HTTPException(status_code=404, detail="Thread not found") + except HTTPException: + raise + except Exception: + raise HTTPException(status_code=404, detail="Thread not found") + metadata.pop("chat_profile", None) metadata.pop("chat_settings", None) metadata.pop("env", None) diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 1441a3d8b2..468550d173 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -1152,3 +1152,140 @@ def test_health_check(test_client: TestClient): response = test_client.get("/health") assert response.status_code == 200 assert response.json() == {"status": "ok"} + + +@pytest.mark.parametrize( + ("is_shared", "on_shared_thread_view_result", "allowed"), + [ + (True, True, True), + (True, False, True), + (True, ValueError("error"), True), + (False, True, True), + (False, False, False), + ], +) +def test_get_shared_thread_access( + test_client: TestClient, + test_config: ChainlitConfig, + is_shared: bool, + on_shared_thread_view_result: bool | Exception, + allowed: bool, +): + """Check if shared thread access is allowed based on is_shared and on_shared_thread_view result.""" + import chainlit.data as data_mod + from chainlit.server import app as _app, get_current_user as _get_current_user + + viewer = PersistedUser( + id="viewer1", + createdAt=datetime.datetime.now().isoformat(), + identifier="viewer", + ) + _app.dependency_overrides[_get_current_user] = lambda: viewer + + dl = AsyncMock() + dl.get_thread.return_value = { + "id": "shared-thread-1", + "name": "Shared Thread", + "userIdentifier": "author", + "metadata": {"is_shared": is_shared, "chat_profile": "pro"}, + } + dl.get_thread_author.return_value = "author" + dl.build_debug_url.return_value = "" + + data_mod._data_layer = dl + data_mod._data_layer_initialized = True + + async def deny_cb(thread, user): + if isinstance(on_shared_thread_view_result, Exception): + raise on_shared_thread_view_result + return on_shared_thread_view_result + + test_config.code.on_shared_thread_view = deny_cb + + r = test_client.get("/project/share/shared-thread-1") + + if allowed: + assert r.status_code == 200 + assert r.json()["id"] == "shared-thread-1" + else: + assert r.status_code == 404 + assert r.json() == {"detail": "Thread not found"} + + # Cleanup + del _app.dependency_overrides[_get_current_user] + data_mod._data_layer = None + data_mod._data_layer_initialized = False + test_config.code.on_shared_thread_view = None + + +@pytest.mark.parametrize( + ("is_shared", "access_allowed_result", "allowed"), + [ + (True, None, True), # No callback → falls through to is_shared (=200) + (True, True, True), # Callback allows → proceeds to return thread (=200) + (True, False, False), # Callback denies → 404 + (True, ValueError("err"), False), # Callback raises → 404 + (False, True, False), # is_shared=False blocks + ], +) +def test_get_shared_thread_access_allowed( + test_client: TestClient, + test_config: ChainlitConfig, + is_shared: bool, + access_allowed_result: bool | Exception | None, + allowed: bool, +): + """Check if shared thread access respects on_shared_thread_access_allowed callback.""" + import chainlit.data as data_mod + from chainlit.server import app as _app, get_current_user as _get_current_user + + viewer = PersistedUser( + id="viewer1", + createdAt=datetime.datetime.now().isoformat(), + identifier="viewer", + ) + _app.dependency_overrides[_get_current_user] = lambda: viewer + + dl = AsyncMock() + dl.get_thread.return_value = { + "id": "shared-thread-1", + "name": "Shared Thread", + "userIdentifier": "author", + "metadata": {"is_shared": is_shared, "chat_profile": "pro"}, + } + dl.get_thread_author.return_value = "author" + dl.build_debug_url.return_value = "" + + data_mod._data_layer = dl + data_mod._data_layer_initialized = True + + # Ensure on_shared_thread_view is not set (Tier 1+2 relies on is_shared fallback) + test_config.code.on_shared_thread_view = None + + if access_allowed_result is not None: + + async def access_cb(thread, user): + if isinstance(access_allowed_result, Exception): + raise access_allowed_result + return access_allowed_result + + test_config.code.on_shared_thread_access_allowed = access_cb + else: + # Callback not defined — Tier 3 is skipped + test_config.code.on_shared_thread_access_allowed = None + + r = test_client.get("/project/share/shared-thread-1") + + if allowed: + assert r.status_code == 200 + assert r.json()["id"] == "shared-thread-1" + else: + assert r.status_code == 404 + assert r.json() == {"detail": "Thread not found"} + + # Cleanup + del _app.dependency_overrides[_get_current_user] + data_mod._data_layer = None + data_mod._data_layer_initialized = False + test_config.code.on_shared_thread_view = None + test_config.code.on_shared_thread_access_allowed = None