Skip to content

Commit 6122a3a

Browse files
committed
Exposing conn_id to hooks to recognize user sessions
1 parent c9f6b13 commit 6122a3a

10 files changed

Lines changed: 56 additions & 22 deletions

rust/loro-websocket-server/src/lib.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,28 @@ type LoadFuture<DocCtx> =
104104
type SaveFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
105105
type LoadFn<DocCtx> = Arc<dyn Fn(LoadDocArgs) -> LoadFuture<DocCtx> + Send + Sync>;
106106
type SaveFn<DocCtx> = Arc<dyn Fn(SaveDocArgs<DocCtx>) -> SaveFuture + Send + Sync>;
107+
108+
/// Arguments provided to `authenticate`.
109+
pub struct AuthArgs {
110+
pub room: String,
111+
pub crdt: CrdtType,
112+
pub auth: Vec<u8>,
113+
pub conn_id: u64,
114+
}
115+
107116
type AuthFuture =
108117
Pin<Box<dyn Future<Output = Result<Option<Permission>, String>> + Send + 'static>>;
109-
type AuthFn = Arc<dyn Fn(String, CrdtType, Vec<u8>) -> AuthFuture + Send + Sync>;
118+
type AuthFn = Arc<dyn Fn(AuthArgs) -> AuthFuture + Send + Sync>;
119+
120+
/// Arguments provided to `handshake_auth`.
121+
pub struct HandshakeAuthArgs<'a> {
122+
pub workspace: &'a str,
123+
pub token: Option<&'a str>,
124+
pub request: &'a tungstenite::handshake::server::Request,
125+
pub conn_id: u64,
126+
}
110127

111-
type HandshakeAuthFn = dyn Fn(&str, Option<&str>, &tungstenite::handshake::server::Request) -> bool + Send + Sync;
128+
type HandshakeAuthFn = dyn Fn(HandshakeAuthArgs) -> bool + Send + Sync;
112129

113130
#[derive(Clone)]
114131
pub struct ServerConfig<DocCtx = ()> {
@@ -123,6 +140,7 @@ pub struct ServerConfig<DocCtx = ()> {
123140
/// - `workspace_id`: extracted from request path `/{workspace}` (empty if missing)
124141
/// - `token`: `token` query parameter if present
125142
/// - `request`: the full HTTP request (headers, uri, etc)
143+
/// - `conn_id`: the connection id
126144
///
127145
/// Return true to accept, false to reject with 401.
128146
pub handshake_auth: Option<Arc<HandshakeAuthFn>>,
@@ -885,12 +903,17 @@ async fn handle_conn<DocCtx>(
885903
where
886904
DocCtx: Clone + Send + Sync + 'static,
887905
{
906+
907+
// Generate a connection id
908+
let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
909+
888910
// Capture config outside of non-async closure
889911
let handshake_auth = registry.config.handshake_auth.clone();
890912
let workspace_holder: Arc<std::sync::Mutex<Option<String>>> =
891913
Arc::new(std::sync::Mutex::new(None));
892914
let workspace_holder_c = workspace_holder.clone();
893915

916+
894917
let ws = accept_hdr_async(
895918
stream,
896919
move |req: &tungstenite::handshake::server::Request,
@@ -926,7 +949,12 @@ where
926949
None
927950
});
928951

929-
let allowed = (check)(workspace_id, token, req);
952+
let allowed = (check)(HandshakeAuthArgs {
953+
workspace: workspace_id,
954+
token,
955+
request: req,
956+
conn_id,
957+
});
930958
if !allowed {
931959
warn!(workspace=%workspace_id, token=?token, "handshake auth denied");
932960
// Build a 401 Unauthorized response
@@ -972,7 +1000,6 @@ where
9721000
}
9731001
});
9741002

975-
let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
9761003
let mut joined_rooms: HashSet<RoomKey> = HashSet::new();
9771004

9781005
while let Some(msg) = stream.next().await {
@@ -1002,7 +1029,14 @@ where
10021029
let mut permission = h.config.default_permission;
10031030
if let Some(auth_fn) = &h.config.authenticate {
10041031
let room_str = room.room.clone();
1005-
match (auth_fn)(room_str, room.crdt, auth.clone()).await {
1032+
match (auth_fn)(AuthArgs {
1033+
room: room_str,
1034+
crdt: room.crdt,
1035+
auth: auth.clone(),
1036+
conn_id,
1037+
})
1038+
.await
1039+
{
10061040
Ok(Some(p)) => {
10071041
permission = p;
10081042
}

rust/loro-websocket-server/tests/e2e.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ async fn e2e_sync_two_clients_docupdate_roundtrip() {
1414
let addr = listener.local_addr().unwrap();
1515
let server_task = tokio::spawn(async move {
1616
let cfg: Cfg = server::ServerConfig {
17-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
17+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
1818
..Default::default()
1919
};
2020
server::serve_incoming_with_config(listener, cfg)
@@ -65,7 +65,7 @@ async fn workspaces_are_isolated() {
6565
let addr = listener.local_addr().unwrap();
6666
let server_task = tokio::spawn(async move {
6767
let cfg: Cfg = server::ServerConfig {
68-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
68+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
6969
..Default::default()
7070
};
7171
server::serve_incoming_with_config(listener, cfg)
@@ -104,7 +104,7 @@ async fn e2e_sync_two_clients_loro_adaptor_roundtrip() {
104104
let addr = listener.local_addr().unwrap();
105105
let server_task = tokio::spawn(async move {
106106
let cfg: Cfg = server::ServerConfig {
107-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
107+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
108108
..Default::default()
109109
};
110110
server::serve_incoming_with_config(listener, cfg)
@@ -154,7 +154,7 @@ async fn e2e_sync_two_clients_elo_adaptor_roundtrip() {
154154
let addr = listener.local_addr().unwrap();
155155
let server_task = tokio::spawn(async move {
156156
let cfg: Cfg = server::ServerConfig {
157-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
157+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
158158
..Default::default()
159159
};
160160
server::serve_incoming_with_config(listener, cfg)

rust/loro-websocket-server/tests/elo_accept_broadcast.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ async fn elo_accepts_join_and_broadcasts_updates() {
1111
let addr = listener.local_addr().unwrap();
1212
let server_task = tokio::spawn(async move {
1313
let cfg: server::ServerConfig<()> = server::ServerConfig {
14-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
14+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
1515
..Default::default()
1616
};
1717
server::serve_incoming_with_config(listener, cfg)

rust/loro-websocket-server/tests/elo_fragment_reassembly.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async fn elo_fragment_reassembly_broadcasts_original_frames() {
1919
let addr = listener.local_addr().unwrap();
2020
let server_task = tokio::spawn(async move {
2121
let cfg: server::ServerConfig<()> = server::ServerConfig {
22-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
22+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
2323
..Default::default()
2424
};
2525
server::serve_incoming_with_config(listener, cfg)

rust/loro-websocket-server/tests/handshake_auth.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ async fn handshake_rejects_invalid_token_with_401() {
99
let addr = listener.local_addr().unwrap();
1010
let server_task = tokio::spawn(async move {
1111
let cfg: server::ServerConfig<()> = server::ServerConfig {
12-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
12+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
1313
..Default::default()
1414
};
1515
server::serve_incoming_with_config(listener, cfg)

rust/loro-websocket-server/tests/handshake_cookies.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ async fn handshake_auth_can_read_cookies() {
1010
let addr = listener.local_addr().unwrap();
1111
let server_task = tokio::spawn(async move {
1212
let cfg: server::ServerConfig<()> = server::ServerConfig {
13-
handshake_auth: Some(Arc::new(|_ws, _token, req| {
14-
if let Some(header) = req.headers().get("Cookie") {
13+
handshake_auth: Some(Arc::new(|args| {
14+
if let Some(header) = args.request.headers().get("Cookie") {
1515
if let Ok(s) = header.to_str() {
1616
for cookie in cookie::Cookie::split_parse(s) {
1717
if let Ok(c) = cookie {

rust/loro-websocket-server/tests/join_denied.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ async fn join_denied_returns_error() {
1111

1212
// Server with auth that always denies
1313
let cfg: server::ServerConfig<()> = server::ServerConfig {
14-
authenticate: Some(Arc::new(|_room, _crdt, _auth| Box::pin(async { Ok(None) }))),
14+
authenticate: Some(Arc::new(|_args| Box::pin(async { Ok(None) }))),
1515
default_permission: Permission::Write,
16-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
16+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
1717
..Default::default()
1818
};
1919
let server_task = tokio::spawn(async move {

rust/loro-websocket-server/tests/join_snapshot_load.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ async fn join_sends_snapshot_from_loader() {
2424
})
2525
})
2626
})),
27-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
27+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
2828
..Default::default()
2929
};
3030
let server_task = tokio::spawn(async move {

rust/loro-websocket-server/tests/readonly_receive.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@ async fn readonly_receives_updates_writer_sends() {
1212
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1313
let addr = listener.local_addr().unwrap();
1414
let cfg: server::ServerConfig<()> = server::ServerConfig {
15-
authenticate: Some(Arc::new(|_room, _crdt, auth| {
15+
authenticate: Some(Arc::new(|args| {
1616
Box::pin(async move {
17-
if auth == b"writer" {
17+
if args.auth == b"writer" {
1818
Ok(Some(Permission::Write))
19-
} else if auth == b"reader" {
19+
} else if args.auth == b"reader" {
2020
Ok(Some(Permission::Read))
2121
} else {
2222
Ok(None)
2323
}
2424
})
2525
})),
2626
default_permission: Permission::Write,
27-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
27+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
2828
..Default::default()
2929
};
3030
let server_task = tokio::spawn(async move {

rust/loro-websocket-server/tests/reject_update_without_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ async fn reject_update_without_join() {
99
let addr = listener.local_addr().unwrap();
1010
let server_task = tokio::spawn(async move {
1111
let cfg: server::ServerConfig<()> = server::ServerConfig {
12-
handshake_auth: Some(Arc::new(|_ws, token, _req| token == Some("secret"))),
12+
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
1313
..Default::default()
1414
};
1515
server::serve_incoming_with_config(listener, cfg)

0 commit comments

Comments
 (0)