Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4939e5d
refactor(service): remove uses of tokio::spawn
AadamZ5 Feb 28, 2026
3b79366
refactor(operation-processor): remove uses of tokio::spawn
AadamZ5 Feb 28, 2026
8d6655b
refactor(progress): remove need for spawning on drop
AadamZ5 Mar 1, 2026
dc4204c
refactor(child-process): wip experiment with new child process transport
AadamZ5 Mar 1, 2026
1db16c8
refactor(child-process): implement tokio child process and use in test
AadamZ5 Mar 1, 2026
d0bd6ca
refactor(child-process): continue to build command abstraction
AadamZ5 Mar 1, 2026
02b53cd
refactor(child-process): add env to command, move builder to separate…
AadamZ5 Mar 1, 2026
86977b6
refactor(example): fix example compilation
AadamZ5 Mar 1, 2026
b746384
refactor(child-process): rename module back to "child-process"
AadamZ5 Mar 1, 2026
0adab25
refactor(test): re-introduce tests for child process dropping
AadamZ5 Mar 1, 2026
007dd92
refactor: revert some unnecessary module visibility changes
AadamZ5 Mar 1, 2026
319a77a
refactor(tests): update calls to serve in all unit tests
AadamZ5 Mar 3, 2026
aee0d4f
refactor(test,examples): update remaining calls to `serve(...)`
AadamZ5 Mar 3, 2026
5482545
refactor(docs): update docs to new call convention for serving
AadamZ5 Mar 3, 2026
a35067a
refactor(http): change to futures unordered
AadamZ5 Mar 4, 2026
a0724d3
refactor(worker): remove spawn from worker
AadamZ5 Mar 4, 2026
d85b6df
refactor(http): explicitly bubble work task up to be spawned
AadamZ5 Mar 5, 2026
7fa7a3f
fix(docs): fix doc examples so they compile
AadamZ5 Mar 5, 2026
d526a9b
refactor(http): use different timeout API for futures
AadamZ5 Mar 5, 2026
595cb6f
Merge branch 'main' into dev/remove-tokio-rt
AadamZ5 Mar 5, 2026
4732493
chore(examples): cleanup examples so they compile
AadamZ5 Mar 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async-trait = "0.1.89"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
thiserror = "2"
tokio = { version = "1", features = ["sync", "macros", "rt", "time"] }
tokio = { version = "1", features = ["sync", "macros", "time"] }
futures = "0.3"
tracing = { version = "0.1" }
tokio-util = { version = "0.7" }
Expand Down Expand Up @@ -56,7 +56,7 @@ process-wrap = { version = "9.0", features = ["tokio1"], optional = true }

# for http-server transport
rand = { version = "0.10", optional = true }
tokio-stream = { version = "0.1", optional = true }
tokio-stream = { version = "0.1", optional = true, features = ["sync"] }
uuid = { version = "1", features = ["v4"], optional = true }
http-body = { version = "1", optional = true }
http-body-util = { version = "0.1", optional = true }
Expand All @@ -74,10 +74,17 @@ chrono = { version = "0.4.38", default-features = false, features = [
"oldtime",
] }

[target.'cfg(test)']

[features]
default = ["base64", "macros", "server"]
client = ["dep:tokio-stream"]
server = ["transport-async-rw", "dep:schemars", "dep:pastey"]
server = [
"transport-async-rw",
"dep:schemars",
"dep:pastey",
"dep:tokio-stream",
]
macros = ["dep:rmcp-macros", "dep:pastey"]
elicitation = ["dep:url"]

Expand Down Expand Up @@ -109,14 +116,18 @@ client-side-sse = ["dep:sse-stream", "dep:http"]

# Streamable HTTP client
transport-streamable-http-client = ["client-side-sse", "transport-worker"]
transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"]
transport-streamable-http-client-reqwest = [
"transport-streamable-http-client",
"__reqwest",
]

transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
transport-async-rw = ["tokio/io-util", "tokio-util/codec", "tokio-util/compat"]
transport-io = ["transport-async-rw", "tokio/io-std"]
transport-child-process = [
transport-child-process = ["transport-async-rw", "tokio/process"]
transport-child-process-tokio = [
"transport-async-rw",
"tokio/process",
"dep:process-wrap",
"tokio/rt",
]
transport-streamable-http-server = [
"transport-streamable-http-server-session",
Expand All @@ -135,7 +146,10 @@ schemars = ["dep:schemars"]
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
schemars = { version = "1.1.0", features = ["chrono04"] }
axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] }
axum = { version = "0.8", default-features = false, features = [
"http1",
"tokio",
] }
anyhow = "1.0"
tracing-subscriber = { version = "0.3", features = [
"env-filter",
Expand All @@ -155,6 +169,7 @@ required-features = [
"server",
"client",
"transport-child-process",
"transport-child-process-tokio",
]
path = "tests/test_with_python.rs"

Expand All @@ -164,6 +179,7 @@ required-features = [
"server",
"client",
"transport-child-process",
"transport-child-process-tokio",
"transport-streamable-http-server",
"transport-streamable-http-client",
"__reqwest",
Expand Down Expand Up @@ -207,12 +223,22 @@ path = "tests/test_task.rs"

[[test]]
name = "test_streamable_http_priming"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
required-features = [
"server",
"client",
"transport-streamable-http-server",
"reqwest",
]
path = "tests/test_streamable_http_priming.rs"

[[test]]
name = "test_streamable_http_json_response"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
required-features = [
"server",
"client",
"transport-streamable-http-server",
"reqwest",
]
path = "tests/test_streamable_http_json_response.rs"


Expand Down Expand Up @@ -249,5 +275,11 @@ path = "tests/test_custom_headers.rs"

[[test]]
name = "test_sse_concurrent_streams"
required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"]
required-features = [
"server",
"client",
"transport-streamable-http-server",
"transport-streamable-http-client",
"reqwest",
]
path = "tests/test_sse_concurrent_streams.rs"
2 changes: 0 additions & 2 deletions crates/rmcp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ pub enum RmcpError {
#[cfg(feature = "server")]
#[error("Server initialization error: {0}")]
ServerInitialize(#[from] crate::service::ServerInitializeError),
#[error("Runtime error: {0}")]
Runtime(#[from] tokio::task::JoinError),
#[error("Transport creation error: {error}")]
// TODO: Maybe we can introduce something like `TryIntoTransport` to auto wrap transport type,
// but it could be an breaking change, so we could do it in the future.
Expand Down
153 changes: 112 additions & 41 deletions crates/rmcp/src/handler/client/progress.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,53 @@
use std::{collections::HashMap, sync::Arc};

use futures::{Stream, StreamExt};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;

use crate::model::{ProgressNotificationParam, ProgressToken};
type Dispatcher =
Arc<RwLock<HashMap<ProgressToken, tokio::sync::mpsc::Sender<ProgressNotificationParam>>>>;
use crate::{
model::{ProgressNotificationParam, ProgressToken},
util::PinnedStream,
};

/// A dispatcher for progress notifications.
#[derive(Debug, Clone, Default)]
///
/// See [ProgressNotificationParam] and [ProgressToken] for more details on
/// how progress is dispatched to a particular listener.
#[derive(Debug, Clone)]
pub struct ProgressDispatcher {
pub(crate) dispatcher: Dispatcher,
/// A channel of any progress notification. Subscribers will filter
/// on this channel.
pub(crate) any_progress_notification_tx: broadcast::Sender<ProgressNotificationParam>,
pub(crate) unsubscribe_tx: broadcast::Sender<ProgressToken>,
pub(crate) unsubscribe_all_tx: broadcast::Sender<()>,
}

impl ProgressDispatcher {
const CHANNEL_SIZE: usize = 16;
pub fn new() -> Self {
Self::default()
// Note that channel size is per-receiver for broadcast channel. It is up to the receiver to
// keep up with the notifications to avoid missing any (via propper polling)
let (any_progress_notification_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
let (unsubscribe_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
let (unsubscribe_all_tx, _) = broadcast::channel(Self::CHANNEL_SIZE);
Self {
any_progress_notification_tx,
unsubscribe_tx,
unsubscribe_all_tx,
}
}

/// Handle a progress notification by sending it to the appropriate subscriber
pub async fn handle_notification(&self, notification: ProgressNotificationParam) {
let token = &notification.progress_token;
if let Some(sender) = self.dispatcher.read().await.get(token).cloned() {
let send_result = sender.send(notification).await;
if let Err(e) = send_result {
tracing::warn!("Failed to send progress notification: {e}");
// Broadcast the notification to all subscribers. Interested subscribers
// will filter on their end.
// ! Note that this implementaiton is very stateless and simple, we cannot
// ! easily inspect which subscribers are interested in which notifications.
// ! However, the stateless-ness and simplicity is also a plus!
// ! Cleanup becomes much easier. Just drop the `ProgressSubscriber`.
match self.any_progress_notification_tx.send(notification) {
Ok(_) => {}
Err(_) => {
// This error only happens if there are no active receivers of the `broadcast` channel.
// Silent error.
}
}
}
Expand All @@ -35,35 +56,97 @@ impl ProgressDispatcher {
///
/// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token.
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber {
let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE);
self.dispatcher
.write()
.await
.insert(progress_token.clone(), sender);
let receiver = ReceiverStream::new(receiver);
// First, set up the unsubscribe listeners. This will fuse the notifiaction stream below.
let progress_token_clone = progress_token.clone();
let unsub_this_token_rx = BroadcastStream::new(self.unsubscribe_tx.subscribe()).filter_map(
move |token| {
let progress_token_clone = progress_token_clone.clone();
async move {
match token {
Ok(token) => {
if token == progress_token_clone {
Some(())
} else {
None
}
}
Err(e) => {
// An error here means the broadcast stream did not receive values quick enough and
// and we missed some notification. This implies there are notifications
// we missed, but we cannot assume they were for us :(
tracing::warn!(
"Error receiving unsubscribe notification from broadcast channel: {e}"
);
None
}
}
}
},
);
let unsub_any_token_tx =
BroadcastStream::new(self.unsubscribe_all_tx.subscribe()).map(|_| {
// Any reception of a result here indicates we should unsubscribe,
// regardless of if we received an `Ok(())` or an `Err(_)` (which
// indicates the broadcast receiver lagged behind)
()
});
let unsub_fut = futures::stream::select(unsub_this_token_rx, unsub_any_token_tx)
.boxed()
.into_future(); // If the unsub streams end, this will cause unsubscription from the subscriber below.

// Now setup the notification stream. We will receive all notifications and only forward progress notifications
// for the token we're interested in.
let progress_token_clone = progress_token.clone();
let receiver = BroadcastStream::new(self.any_progress_notification_tx.subscribe())
.filter_map(move |notification| {
let progress_token_clone = progress_token_clone.clone();
async move {
// We need to kneed-out the broadcast receive error type here.
match notification {
Ok(notification) => {
let token = notification.progress_token.clone();
if token == progress_token_clone {
Some(notification)
} else {
None
}
}
Err(e) => {
tracing::warn!(
"Error receiving progress notification from broadcast channel: {e}"
);
None
}
}
}
})
// Fuse this stream so it stops once we receive an unsubscribe notification from the stream
// created above
.take_until(unsub_fut)
.boxed();

ProgressSubscriber {
progress_token,
receiver,
dispatcher: self.dispatcher.clone(),
}
}

/// Unsubscribe from progress notifications for a specific token.
pub async fn unsubscribe(&self, token: &ProgressToken) {
self.dispatcher.write().await.remove(token);
pub fn unsubscribe(&self, token: ProgressToken) {
// The only error defined is if there are no listeners, which is fine. Ignore the result.
let _ = self.unsubscribe_tx.send(token);
}

/// Clear all dispatcher.
pub async fn clear(&self) {
let mut dispatcher = self.dispatcher.write().await;
dispatcher.clear();
pub fn clear(&self) {
// The only error defined is if there are no listeners, which is fine. Ignore the result.
let _ = self.unsubscribe_all_tx.send(());
}
}

pub struct ProgressSubscriber {
pub(crate) progress_token: ProgressToken,
pub(crate) receiver: ReceiverStream<ProgressNotificationParam>,
pub(crate) dispatcher: Dispatcher,
pub(crate) receiver: PinnedStream<'static, ProgressNotificationParam>,
}

impl ProgressSubscriber {
Expand All @@ -86,15 +169,3 @@ impl Stream for ProgressSubscriber {
self.receiver.size_hint()
}
}

impl Drop for ProgressSubscriber {
fn drop(&mut self) {
let token = self.progress_token.clone();
self.receiver.close();
let dispatcher = self.dispatcher.clone();
tokio::spawn(async move {
let mut dispatcher = dispatcher.write_owned().await;
dispatcher.remove(&token);
});
}
}
2 changes: 2 additions & 0 deletions crates/rmcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#![doc = include_str!("../README.md")]

mod error;
mod util;

#[allow(deprecated)]
pub use error::{Error, ErrorData, RmcpError};

Expand Down
Loading