Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
on_message,
on_settings_edit,
on_settings_update,
on_shared_thread_access_allowed,
on_shared_thread_view,
on_slack_reaction_added,
on_stop,
Expand Down Expand Up @@ -205,6 +206,7 @@ def acall(self):
"on_message",
"on_settings_edit",
"on_settings_update",
"on_shared_thread_access_allowed",
"on_shared_thread_view",
"on_slack_reaction_added",
"on_stop",
Expand Down
15 changes: 15 additions & 0 deletions backend/chainlit/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,18 @@ def on_shared_thread_view(
"""
config.code.on_shared_thread_view = wrap_user_function(func)
return func


def on_shared_thread_access_allowed(
func: Callable[[ThreadDict, Optional[User]], Awaitable[bool]],
) -> Callable[[ThreadDict, Optional[User]], Awaitable[bool]]:
"""Hook to add extra permission check for viewing a shared thread.

Unlike on_shared_thread_view, this callback can only deny access further.
If defined and returns False, the viewer is blocked regardless of other checks.
If undefined or returns True, normal authorization flow proceeds.

Signature: async (thread: ThreadDict, viewer: Optional[User]) -> bool
"""
config.code.on_shared_thread_access_allowed = wrap_user_function(func)
return func
3 changes: 3 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ class CodeSettings(BaseModel):
on_shared_thread_view: Optional[
Callable[["ThreadDict", Optional["User"]], Awaitable[bool]]
] = None
on_shared_thread_access_allowed: Optional[
Callable[["ThreadDict", Optional["User"]], Awaitable[bool]]
] = None
# Auth callbacks
password_auth_callback: Optional[
Callable[[str, str], Awaitable[Optional["User"]]]
Expand Down
14 changes: 12 additions & 2 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,13 +1024,23 @@ async def get_shared_thread(
)
except Exception:
user_can_view = False

is_shared = bool(metadata.get("is_shared"))

# Proceed only raise an error if both conditions are False.
if (not user_can_view) and (not is_shared):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0: Authorization bypass: is_shared=True grants access even when on_shared_thread_view callback denies access, undermining the stated security goal

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At backend/chainlit/server.py, line 1029:

<comment>Authorization bypass: is_shared=True grants access even when on_shared_thread_view callback denies access, undermining the stated security goal</comment>

<file context>
@@ -1025,11 +1024,23 @@ async def get_shared_thread(
 
-    # Proceed only raise an error if user_can_view return False or exception
-    if not user_can_view:
+    if (not user_can_view) and (not is_shared):
         raise HTTPException(status_code=404, detail="Thread not found")
 
</file context>

raise HTTPException(status_code=404, detail="Thread not found")
Comment on lines 1025 to 1030

if getattr(config.code, "on_shared_thread_access_allowed", None):
try:
access_allowed = await config.code.on_shared_thread_access_allowed(
thread, current_user
)
if not access_allowed:
raise HTTPException(status_code=404, detail="Thread not found")
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=404, detail="Thread not found")

metadata.pop("chat_profile", None)
metadata.pop("chat_settings", None)
metadata.pop("env", None)
Expand Down
137 changes: 137 additions & 0 deletions backend/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,3 +1152,140 @@ def test_health_check(test_client: TestClient):
response = test_client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}


@pytest.mark.parametrize(
("is_shared", "on_shared_thread_view_result", "allowed"),
[
(True, True, True),
(True, False, True),
(True, ValueError("error"), True),
Comment on lines +1161 to +1162
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Callback-denied and callback-error shared-thread cases are incorrectly expected to succeed, so this test would not catch a regression where on_shared_thread_view is ignored.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At backend/tests/test_server.py, line 1161:

<comment>Callback-denied and callback-error shared-thread cases are incorrectly expected to succeed, so this test would not catch a regression where `on_shared_thread_view` is ignored.</comment>

<file context>
@@ -1154,13 +1154,16 @@ def test_health_check(test_client: TestClient):
+    ("is_shared", "on_shared_thread_view_result", "allowed"),
+    [
+        (True, True, True),
+        (True, False, True),
+        (True, ValueError("error"), True),
+        (False, True, True),
</file context>
Suggested change
(True, False, True),
(True, ValueError("error"), True),
(True, False, False),
(True, ValueError("error"), False),

(False, True, True),
(False, False, False),
],
)
def test_get_shared_thread_access(
test_client: TestClient,
test_config: ChainlitConfig,
is_shared: bool,
on_shared_thread_view_result: bool | Exception,
allowed: bool,
):
"""Check if shared thread access is allowed based on is_shared and on_shared_thread_view result."""
import chainlit.data as data_mod
from chainlit.server import app as _app, get_current_user as _get_current_user

viewer = PersistedUser(
id="viewer1",
createdAt=datetime.datetime.now().isoformat(),
identifier="viewer",
)
_app.dependency_overrides[_get_current_user] = lambda: viewer

dl = AsyncMock()
dl.get_thread.return_value = {
"id": "shared-thread-1",
"name": "Shared Thread",
"userIdentifier": "author",
"metadata": {"is_shared": is_shared, "chat_profile": "pro"},
}
dl.get_thread_author.return_value = "author"
dl.build_debug_url.return_value = ""

data_mod._data_layer = dl
data_mod._data_layer_initialized = True

async def deny_cb(thread, user):
if isinstance(on_shared_thread_view_result, Exception):
raise on_shared_thread_view_result
return on_shared_thread_view_result

test_config.code.on_shared_thread_view = deny_cb

r = test_client.get("/project/share/shared-thread-1")

Comment on lines +1201 to +1206
if allowed:
assert r.status_code == 200
assert r.json()["id"] == "shared-thread-1"
else:
assert r.status_code == 404
assert r.json() == {"detail": "Thread not found"}

# Cleanup
del _app.dependency_overrides[_get_current_user]
data_mod._data_layer = None
data_mod._data_layer_initialized = False
test_config.code.on_shared_thread_view = None
Comment on lines +1183 to +1218


@pytest.mark.parametrize(
("is_shared", "access_allowed_result", "allowed"),
[
(True, None, True), # No callback → falls through to is_shared (=200)
(True, True, True), # Callback allows → proceeds to return thread (=200)
(True, False, False), # Callback denies → 404
(True, ValueError("err"), False), # Callback raises → 404
(False, True, False), # is_shared=False blocks
],
)
def test_get_shared_thread_access_allowed(
test_client: TestClient,
test_config: ChainlitConfig,
is_shared: bool,
access_allowed_result: bool | Exception | None,
allowed: bool,
):
"""Check if shared thread access respects on_shared_thread_access_allowed callback."""
import chainlit.data as data_mod
from chainlit.server import app as _app, get_current_user as _get_current_user

viewer = PersistedUser(
id="viewer1",
createdAt=datetime.datetime.now().isoformat(),
identifier="viewer",
)
_app.dependency_overrides[_get_current_user] = lambda: viewer

dl = AsyncMock()
dl.get_thread.return_value = {
"id": "shared-thread-1",
"name": "Shared Thread",
"userIdentifier": "author",
"metadata": {"is_shared": is_shared, "chat_profile": "pro"},
}
dl.get_thread_author.return_value = "author"
dl.build_debug_url.return_value = ""

data_mod._data_layer = dl
data_mod._data_layer_initialized = True

# Ensure on_shared_thread_view is not set (Tier 1+2 relies on is_shared fallback)
test_config.code.on_shared_thread_view = None

if access_allowed_result is not None:

async def access_cb(thread, user):
if isinstance(access_allowed_result, Exception):
raise access_allowed_result
return access_allowed_result

test_config.code.on_shared_thread_access_allowed = access_cb
else:
# Callback not defined — Tier 3 is skipped
test_config.code.on_shared_thread_access_allowed = None

r = test_client.get("/project/share/shared-thread-1")

if allowed:
assert r.status_code == 200
assert r.json()["id"] == "shared-thread-1"
else:
assert r.status_code == 404
assert r.json() == {"detail": "Thread not found"}

# Cleanup
del _app.dependency_overrides[_get_current_user]
data_mod._data_layer = None
data_mod._data_layer_initialized = False
test_config.code.on_shared_thread_view = None
test_config.code.on_shared_thread_access_allowed = None