From 123b1d896539b1c687666cff025376dbea434acd Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Thu, 2 Apr 2026 14:07:42 +0200 Subject: [PATCH] Remove redundant Box::pin calls from async code Yesterday, Clippy complained a lot about large futures. I couldn't figure out what the root cause was and just added them everywhere clippy complained. However, when adding to Zed I was getting stack overflows and was able to trace it to the Builder::into_connection_and_future where the futures with the connection were actually the root cause. Adding Box::pin there means we can remove them from all of the tests/examples/etc, which makes things look nice and beautiful again! --- Cargo.toml | 2 +- .../src/conductor/mcp_bridge/actor.rs | 6 +- .../src/lib.rs | 18 +- .../tests/arrow_proxy_eliza.rs | 22 +- .../tests/empty_conductor_eliza.rs | 22 +- .../tests/initialization_sequence.rs | 81 +- .../tests/mcp-integration.rs | 138 ++-- .../tests/mcp_server_handler_chain.rs | 22 +- .../tests/nested_arrow_proxy.rs | 26 +- .../tests/nested_conductor.rs | 44 +- .../tests/scoped_mcp_server.rs | 8 +- .../tests/standalone_mcp_server.rs | 20 +- .../tests/test_mcp_tool_output_types.rs | 8 +- .../tests/test_session_id_in_mcp_tools.rs | 8 +- .../tests/test_tool_enable_disable.rs | 24 +- .../tests/test_tool_fn.rs | 4 +- .../tests/trace_client_mcp_server.rs | 92 ++- .../tests/trace_generation.rs | 26 +- .../tests/trace_mcp_tool_call.rs | 99 ++- .../tests/trace_snapshot.rs | 24 +- .../examples/simple_agent.rs | 6 +- .../examples/yolo_one_shot_client.rs | 6 +- src/agent-client-protocol-core/src/jsonrpc.rs | 3 +- .../src/mcp_server/server.rs | 6 +- .../tests/jsonrpc_advanced.rs | 337 ++++----- .../tests/jsonrpc_connection_builder.rs | 694 +++++++++--------- .../tests/jsonrpc_edge_cases.rs | 174 ++--- .../tests/jsonrpc_error_handling.rs | 292 ++++---- .../tests/jsonrpc_hello.rs | 450 ++++++------ .../examples/with_mcp_server.rs | 6 +- .../examples/arrow_proxy.rs | 5 +- .../src/bin/testy.rs | 4 +- .../tests/debug_logging.rs | 6 +- src/yopo/src/main.rs | 10 +- 34 files changed, 1338 insertions(+), 1355 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7718d3d..c7958e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,7 @@ async-trait = "0.1" boxfnonce = "0.1.1" chrono = "0.4" derive_more = { version = "2", features = ["from"] } -futures = "0.3.31" +futures = "0.3.32" futures-concurrency = "7.6.3" fxhash = "0.2.1" jsonrpcmsg = "0.1.2" diff --git a/src/agent-client-protocol-conductor/src/conductor/mcp_bridge/actor.rs b/src/agent-client-protocol-conductor/src/conductor/mcp_bridge/actor.rs index b1a4c71..2c15ad9 100644 --- a/src/agent-client-protocol-conductor/src/conductor/mcp_bridge/actor.rs +++ b/src/agent-client-protocol-conductor/src/conductor/mcp_bridge/actor.rs @@ -43,7 +43,7 @@ impl McpBridgeConnectionActor { to_mcp_client_rx, } = self; - let client = mcp::Client + let result = mcp::Client .builder() .name(format!("mpc-client-to-conductor({connection_id})")) // When we receive a message from the MCP client, forward it to the conductor @@ -70,8 +70,8 @@ impl McpBridgeConnectionActor { mcp_connection_to_client.send_proxied_message(message)?; } Ok(()) - }); - let result = Box::pin(client).await; + }) + .await; conductor_tx .send(ConductorMessage::McpConnectionDisconnected { diff --git a/src/agent-client-protocol-conductor/src/lib.rs b/src/agent-client-protocol-conductor/src/lib.rs index 5c8f0c5..f1ecba0 100644 --- a/src/agent-client-protocol-conductor/src/lib.rs +++ b/src/agent-client-protocol-conductor/src/lib.rs @@ -319,12 +319,10 @@ impl ConductorArgs { (None, false) => (None, None), }; - Box::pin( - self.run(debug_logger.as_ref(), trace_writer) - .instrument(tracing::info_span!("conductor", pid = %pid, cwd = %cwd)), - ) - .await - .map_err(|err| anyhow::anyhow!("{err}")) + self.run(debug_logger.as_ref(), trace_writer) + .instrument(tracing::info_span!("conductor", pid = %pid, cwd = %cwd)) + .await + .map_err(|err| anyhow::anyhow!("{err}")) } async fn run( @@ -334,23 +332,23 @@ impl ConductorArgs { ) -> Result<(), agent_client_protocol_core::Error> { match self.command { ConductorCommand::Agent { name, components } => { - Box::pin(initialize_conductor( + initialize_conductor( debug_logger, trace_writer, name, components, ConductorImpl::new_agent, - )) + ) .await } ConductorCommand::Proxy { name, proxies } => { - Box::pin(initialize_conductor( + initialize_conductor( debug_logger, trace_writer, name, proxies, ConductorImpl::new_proxy, - )) + ) .await } ConductorCommand::Mcp { port } => mcp_bridge::run_mcp_bridge(port).await, diff --git a/src/agent-client-protocol-conductor/tests/arrow_proxy_eliza.rs b/src/agent-client-protocol-conductor/tests/arrow_proxy_eliza.rs index d2231a4..5f6328d 100644 --- a/src/agent-client-protocol-conductor/tests/arrow_proxy_eliza.rs +++ b/src/agent-client-protocol-conductor/tests/arrow_proxy_eliza.rs @@ -29,29 +29,27 @@ async fn test_conductor_with_arrow_proxy_and_test_agent() // Spawn the conductor let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(test_agent).proxy(arrow_proxy_agent), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(test_agent).proxy(arrow_proxy_agent), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Wait for editor to complete and get the result let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await?; tracing::debug!(?result, "Received response from arrow proxy chain"); diff --git a/src/agent-client-protocol-conductor/tests/empty_conductor_eliza.rs b/src/agent-client-protocol-conductor/tests/empty_conductor_eliza.rs index a1f500e..495faf4 100644 --- a/src/agent-client-protocol-conductor/tests/empty_conductor_eliza.rs +++ b/src/agent-client-protocol-conductor/tests/empty_conductor_eliza.rs @@ -54,29 +54,27 @@ async fn test_conductor_with_empty_conductor_and_test_agent() // Spawn the conductor let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "outer-conductor".to_string(), - ProxiesAndAgent::new(Testy::new()).proxy(MockEmptyConductor), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "outer-conductor".to_string(), + ProxiesAndAgent::new(Testy::new()).proxy(MockEmptyConductor), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Wait for editor to complete and get the result let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await?; tracing::debug!(?result, "Received response from empty conductor chain"); diff --git a/src/agent-client-protocol-conductor/tests/initialization_sequence.rs b/src/agent-client-protocol-conductor/tests/initialization_sequence.rs index d7c1126..004c7ad 100644 --- a/src/agent-client-protocol-conductor/tests/initialization_sequence.rs +++ b/src/agent-client-protocol-conductor/tests/initialization_sequence.rs @@ -134,17 +134,15 @@ async fn run_test_with_components( .builder() .name("editor-to-connector") .with_spawned(|_cx| async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(Testy::new()).proxies(proxies), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_out.compat_write(), - conductor_in.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(Testy::new()).proxies(proxies), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_out.compat_write(), + conductor_in.compat(), + )) .await }) .connect_with(transport, editor_task) @@ -156,22 +154,19 @@ async fn test_single_component_gets_initialize_request() -> Result<(), agent_client_protocol_core::Error> { // Single component (agent) should receive InitializeRequest - we use ElizaAgent // which properly handles InitializeRequest - Box::pin(run_test_with_components( - vec![], - async |connection_to_editor| { - let init_response = recv( - connection_to_editor.send_request(InitializeRequest::new(ProtocolVersion::LATEST)), - ) - .await; - - assert!( - init_response.is_ok(), - "Initialize should succeed: {init_response:?}" - ); - - Ok::<(), agent_client_protocol_core::Error>(()) - }, - )) + run_test_with_components(vec![], async |connection_to_editor| { + let init_response = recv( + connection_to_editor.send_request(InitializeRequest::new(ProtocolVersion::LATEST)), + ) + .await; + + assert!( + init_response.is_ok(), + "Initialize should succeed: {init_response:?}" + ); + + Ok::<(), agent_client_protocol_core::Error>(()) + }) .await?; Ok(()) @@ -184,7 +179,7 @@ async fn test_two_components_proxy_gets_initialize_proxy() // Second component (agent, ElizaAgent) gets InitializeRequest let component1 = InitConfig::new(); - Box::pin(run_test_with_components( + run_test_with_components( vec![InitComponent::new(&component1)], async |connection_to_editor| { let init_response = recv( @@ -199,7 +194,7 @@ async fn test_two_components_proxy_gets_initialize_proxy() Ok::<(), agent_client_protocol_core::Error>(()) }, - )) + ) .await?; // First component (proxy) should receive InitializeProxyRequest @@ -222,7 +217,7 @@ async fn test_three_components_all_proxies_get_initialize_proxy() let component1 = InitConfig::new(); let component2 = InitConfig::new(); - Box::pin(run_test_with_components( + run_test_with_components( vec![ InitComponent::new(&component1), InitComponent::new(&component2), @@ -240,7 +235,7 @@ async fn test_three_components_all_proxies_get_initialize_proxy() Ok::<(), agent_client_protocol_core::Error>(()) }, - )) + ) .await?; // First two components (proxies) should receive InitializeProxyRequest @@ -307,17 +302,15 @@ async fn run_bad_proxy_test( .builder() .name("editor-to-connector") .with_spawned(|_cx| async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(agent).proxies(proxies), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_out.compat_write(), - conductor_in.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(agent).proxies(proxies), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_out.compat_write(), + conductor_in.compat(), + )) .await }) .connect_with(transport, editor_task) @@ -329,7 +322,7 @@ async fn test_conductor_rejects_initialize_proxy_forwarded_to_agent() -> Result<(), agent_client_protocol_core::Error> { // BadProxy incorrectly forwards InitializeProxyRequest to the agent. // The conductor should reject this with an error. - let result = Box::pin(run_bad_proxy_test( + let result = run_bad_proxy_test( vec![DynConnectTo::new(BadProxy)], DynConnectTo::new(Testy::new()), async |connection_to_editor| { @@ -347,7 +340,7 @@ async fn test_conductor_rejects_initialize_proxy_forwarded_to_agent() Ok::<(), agent_client_protocol_core::Error>(()) }, - )) + ) .await; match result { @@ -368,7 +361,7 @@ async fn test_conductor_rejects_initialize_proxy_forwarded_to_proxy() -> Result<(), agent_client_protocol_core::Error> { // BadProxy incorrectly forwards InitializeProxyRequest to another proxy. // The conductor should reject this with an error. - let result = Box::pin(run_bad_proxy_test( + let result = run_bad_proxy_test( vec![ DynConnectTo::new(BadProxy), DynConnectTo::new(InitComponent::new(&InitConfig::new())), // This proxy will receive the bad request @@ -390,7 +383,7 @@ async fn test_conductor_rejects_initialize_proxy_forwarded_to_proxy() Ok::<(), agent_client_protocol_core::Error>(()) }, - )) + ) .await; // The error might bubble up through run_test_with_components instead diff --git a/src/agent-client-protocol-conductor/tests/mcp-integration.rs b/src/agent-client-protocol-conductor/tests/mcp-integration.rs index 425ba29..2ab5407 100644 --- a/src/agent-client-protocol-conductor/tests/mcp-integration.rs +++ b/src/agent-client-protocol-conductor/tests/mcp-integration.rs @@ -64,15 +64,12 @@ async fn run_test_with_mode( .builder() .name("editor-to-connector") .with_spawned(|_cx| async move { - Box::pin( - ConductorImpl::new_agent("conductor".to_string(), components, mode).run( - agent_client_protocol_core::ByteStreams::new( - conductor_out.compat_write(), - conductor_in.compat(), - ), - ), - ) - .await + ConductorImpl::new_agent("conductor".to_string(), components, mode) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_out.compat_write(), + conductor_in.compat(), + )) + .await }) .connect_with(transport, editor_task) .await @@ -81,7 +78,7 @@ async fn run_test_with_mode( /// Test that proxy-provided MCP tools work with stdio bridge mode #[tokio::test] async fn test_proxy_provides_mcp_tools_stdio() -> Result<(), agent_client_protocol_core::Error> { - Box::pin(run_test_with_mode( + run_test_with_mode( McpBridgeMode::Stdio { conductor_command: conductor_command(), }, @@ -116,7 +113,7 @@ async fn test_proxy_provides_mcp_tools_stdio() -> Result<(), agent_client_protoc Ok(()) }, - )) + ) .await?; Ok(()) @@ -125,7 +122,7 @@ async fn test_proxy_provides_mcp_tools_stdio() -> Result<(), agent_client_protoc /// Test that proxy-provided MCP tools work with HTTP bridge mode #[tokio::test] async fn test_proxy_provides_mcp_tools_http() -> Result<(), agent_client_protocol_core::Error> { - Box::pin(run_test_with_mode( + run_test_with_mode( McpBridgeMode::Http, ProxiesAndAgent::new(Testy::new()).proxy(mcp_integration::proxy::ProxyComponent), async |connection_to_editor| { @@ -158,7 +155,7 @@ async fn test_proxy_provides_mcp_tools_http() -> Result<(), agent_client_protoco Ok(()) }, - )) + ) .await?; Ok(()) @@ -183,29 +180,26 @@ async fn test_agent_handles_prompt() -> Result<(), agent_client_protocol_core::E // Spawn the conductor in a background task let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "mcp-integration-conductor".to_string(), - ProxiesAndAgent::new(Testy::new()).proxy(mcp_integration::proxy::ProxyComponent), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "mcp-integration-conductor".to_string(), + ProxiesAndAgent::new(Testy::new()).proxy(mcp_integration::proxy::ProxyComponent), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Run the client - let result = Box::pin( - agent_client_protocol_core::Client - .builder() - .name("editor-to-connector") - .on_receive_notification( - { - let mut log_tx = log_tx.clone(); - async move |notification: SessionNotification, + let result = agent_client_protocol_core::Client + .builder() + .name("editor-to-connector") + .on_receive_notification( + { + let mut log_tx = log_tx.clone(); + async move |notification: SessionNotification, _cx: agent_client_protocol_core::ConnectionTo| { // Log the notification in debug format log_tx @@ -213,54 +207,52 @@ async fn test_agent_handles_prompt() -> Result<(), agent_client_protocol_core::E .await .map_err(|_| agent_client_protocol_core::Error::internal_error()) } - }, - agent_client_protocol_core::on_receive_notification!(), - ) - .connect_with( - agent_client_protocol_core::ByteStreams::new( - client_write.compat_write(), - client_read.compat(), - ), - async |connection_to_editor| { - // Initialize - recv( - connection_to_editor - .send_request(InitializeRequest::new(ProtocolVersion::LATEST)), - ) - .await?; - - // Create session - let session = recv( - connection_to_editor - .send_request(NewSessionRequest::new(std::path::PathBuf::from("/"))), - ) - .await?; - - tracing::debug!(session_id = %session.session_id.0, "Session created"); - - // Send a prompt to call the echo tool - let prompt_response = - recv(connection_to_editor.send_request(PromptRequest::new( - session.session_id.clone(), - vec![ContentBlock::Text(TextContent::new(TestyCommand::CallTool { + }, + agent_client_protocol_core::on_receive_notification!(), + ) + .connect_with( + agent_client_protocol_core::ByteStreams::new( + client_write.compat_write(), + client_read.compat(), + ), + async |connection_to_editor| { + // Initialize + recv( + connection_to_editor + .send_request(InitializeRequest::new(ProtocolVersion::LATEST)), + ) + .await?; + + // Create session + let session = recv( + connection_to_editor + .send_request(NewSessionRequest::new(std::path::PathBuf::from("/"))), + ) + .await?; + + tracing::debug!(session_id = %session.session_id.0, "Session created"); + + // Send a prompt to call the echo tool + let prompt_response = recv(connection_to_editor.send_request(PromptRequest::new( + session.session_id.clone(), + vec![ContentBlock::Text(TextContent::new(TestyCommand::CallTool { server: "test".to_string(), tool: "echo".to_string(), params: serde_json::json!({"message": "Hello from the test!"}), }.to_prompt()))], - ))) - .await?; + ))) + .await?; - // Log the response - log_tx - .send(format!("{prompt_response:?}")) - .await - .map_err(|_| agent_client_protocol_core::Error::internal_error())?; + // Log the response + log_tx + .send(format!("{prompt_response:?}")) + .await + .map_err(|_| agent_client_protocol_core::Error::internal_error())?; - Ok(()) - }, - ), - ) - .await; + Ok(()) + }, + ) + .await; conductor_handle.abort(); result?; diff --git a/src/agent-client-protocol-conductor/tests/mcp_server_handler_chain.rs b/src/agent-client-protocol-conductor/tests/mcp_server_handler_chain.rs index 78557bd..3ddd232 100644 --- a/src/agent-client-protocol-conductor/tests/mcp_server_handler_chain.rs +++ b/src/agent-client-protocol-conductor/tests/mcp_server_handler_chain.rs @@ -169,17 +169,15 @@ async fn run_test( .builder() .name("editor-to-conductor") .with_spawned(|_cx| async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(agent).proxies(proxies), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_out.compat_write(), - conductor_in.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(agent).proxies(proxies), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_out.compat_write(), + conductor_in.compat(), + )) .await }) .connect_with(transport, editor_task) @@ -198,7 +196,7 @@ async fn test_new_session_handler_invoked_with_mcp_server() }); let agent = DynConnectTo::::new(SimpleAgent); - Box::pin(run_test(vec![proxy], agent, async |connection_to_editor| { + run_test(vec![proxy], agent, async |connection_to_editor| { // Initialize first let _init_response = recv( connection_to_editor.send_request(InitializeRequest::new(ProtocolVersion::LATEST)), @@ -217,7 +215,7 @@ async fn test_new_session_handler_invoked_with_mcp_server() ); Ok::<(), agent_client_protocol_core::Error>(()) - })) + }) .await?; // THE KEY ASSERTION: verify the handler was actually called diff --git a/src/agent-client-protocol-conductor/tests/nested_arrow_proxy.rs b/src/agent-client-protocol-conductor/tests/nested_arrow_proxy.rs index 8a9395b..3bb91ed 100644 --- a/src/agent-client-protocol-conductor/tests/nested_arrow_proxy.rs +++ b/src/agent-client-protocol-conductor/tests/nested_arrow_proxy.rs @@ -36,31 +36,29 @@ async fn test_conductor_with_two_external_arrow_proxies() // Spawn the conductor with three components let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "test-conductor".to_string(), - ProxiesAndAgent::new(agent) - .proxy(arrow_proxy1) - .proxy(arrow_proxy2), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "test-conductor".to_string(), + ProxiesAndAgent::new(agent) + .proxy(arrow_proxy1) + .proxy(arrow_proxy2), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Wait for editor to complete and get the result let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await?; expect_test::expect![[r#" diff --git a/src/agent-client-protocol-conductor/tests/nested_conductor.rs b/src/agent-client-protocol-conductor/tests/nested_conductor.rs index 9de033e..1a53620 100644 --- a/src/agent-client-protocol-conductor/tests/nested_conductor.rs +++ b/src/agent-client-protocol-conductor/tests/nested_conductor.rs @@ -91,29 +91,27 @@ async fn test_nested_conductor_with_arrow_proxies() -> Result<(), agent_client_p // Spawn the outer conductor with the inner conductor and eliza let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "outer-conductor".to_string(), - ProxiesAndAgent::new(Testy::new()).proxy(MockInnerConductor::new(2)), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "outer-conductor".to_string(), + ProxiesAndAgent::new(Testy::new()).proxy(MockInnerConductor::new(2)), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Wait for editor to complete and get the result let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await?; tracing::debug!(?result, "Received response from nested conductor chain"); @@ -162,29 +160,27 @@ async fn test_nested_conductor_with_external_arrow_proxies() // Spawn the outer conductor with the inner conductor and eliza as external processes let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "outer-conductor".to_string(), - ProxiesAndAgent::new(agent).proxy(inner_conductor), - McpBridgeMode::default(), - ) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "outer-conductor".to_string(), + ProxiesAndAgent::new(agent).proxy(inner_conductor), + McpBridgeMode::default(), ) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Wait for editor to complete and get the result let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await?; tracing::debug!(?result, "Received response from nested conductor chain"); diff --git a/src/agent-client-protocol-conductor/tests/scoped_mcp_server.rs b/src/agent-client-protocol-conductor/tests/scoped_mcp_server.rs index 6be1552..dc2836a 100644 --- a/src/agent-client-protocol-conductor/tests/scoped_mcp_server.rs +++ b/src/agent-client-protocol-conductor/tests/scoped_mcp_server.rs @@ -24,7 +24,7 @@ async fn test_scoped_mcp_server_through_proxy() -> Result<(), agent_client_proto McpBridgeMode::default(), ); - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( conductor, TestyCommand::CallTool { server: "test".to_string(), @@ -32,7 +32,7 @@ async fn test_scoped_mcp_server_through_proxy() -> Result<(), agent_client_proto params: serde_json::json!({"elements": ["Hello", "world"]}), } .to_prompt(), - )) + ) .await?; expect_test::expect![[r#" @@ -50,7 +50,7 @@ async fn test_scoped_mcp_server_through_proxy() -> Result<(), agent_client_proto #[tokio::test] async fn test_scoped_mcp_server_through_session() -> Result<(), agent_client_protocol_core::Error> { // Run the client - Box::pin(agent_client_protocol_core::Client.builder() + agent_client_protocol_core::Client.builder() .connect_with( ConductorImpl::new_agent( "conductor".to_string(), @@ -87,7 +87,7 @@ async fn test_scoped_mcp_server_through_session() -> Result<(), agent_client_pro Ok(()) }, - )) + ) .await?; Ok(()) diff --git a/src/agent-client-protocol-conductor/tests/standalone_mcp_server.rs b/src/agent-client-protocol-conductor/tests/standalone_mcp_server.rs index 42c0712..77a7b90 100644 --- a/src/agent-client-protocol-conductor/tests/standalone_mcp_server.rs +++ b/src/agent-client-protocol-conductor/tests/standalone_mcp_server.rs @@ -76,7 +76,7 @@ async fn test_standalone_server_list_tools() -> Result<(), agent_client_protocol // Wrap client side as ByteStreams (this is what the MCP server will talk to) let client_as_component = ByteStreams::new(client_write.compat_write(), client_read.compat()); - Box::pin(run_until( + run_until( ConnectTo::::connect_to(server, client_as_component), async move { // Create rmcp client on the server side of the duplex (the "other end") @@ -106,7 +106,7 @@ async fn test_standalone_server_list_tools() -> Result<(), agent_client_protocol .map_err(agent_client_protocol_core::util::internal_error)?; Ok(()) }, - )) + ) .await } @@ -119,7 +119,7 @@ async fn test_standalone_server_call_echo_tool() -> Result<(), agent_client_prot let server = create_test_server(); let client_as_component = ByteStreams::new(client_write.compat_write(), client_read.compat()); - Box::pin(run_until( + run_until( ConnectTo::::connect_to(server, client_as_component), async move { let client = MinimalClientHandler @@ -156,7 +156,7 @@ async fn test_standalone_server_call_echo_tool() -> Result<(), agent_client_prot .map_err(agent_client_protocol_core::util::internal_error)?; Ok(()) }, - )) + ) .await } @@ -169,7 +169,7 @@ async fn test_standalone_server_call_add_tool() -> Result<(), agent_client_proto let server = create_test_server(); let client_as_component = ByteStreams::new(client_write.compat_write(), client_read.compat()); - Box::pin(run_until( + run_until( ConnectTo::::connect_to(server, client_as_component), async move { let client = MinimalClientHandler @@ -209,7 +209,7 @@ async fn test_standalone_server_call_add_tool() -> Result<(), agent_client_proto .map_err(agent_client_protocol_core::util::internal_error)?; Ok(()) }, - )) + ) .await } @@ -222,7 +222,7 @@ async fn test_standalone_server_tool_not_found() -> Result<(), agent_client_prot let server = create_test_server(); let client_as_component = ByteStreams::new(client_write.compat_write(), client_read.compat()); - Box::pin(run_until( + run_until( ConnectTo::::connect_to(server, client_as_component), async move { let client = MinimalClientHandler @@ -244,7 +244,7 @@ async fn test_standalone_server_tool_not_found() -> Result<(), agent_client_prot .map_err(agent_client_protocol_core::util::internal_error)?; Ok(()) }, - )) + ) .await } @@ -278,7 +278,7 @@ async fn test_standalone_server_with_disabled_tools() let client_as_component = ByteStreams::new(client_write.compat_write(), client_read.compat()); - Box::pin(run_until( + run_until( ConnectTo::::connect_to(server, client_as_component), async move { let client = MinimalClientHandler @@ -314,6 +314,6 @@ async fn test_standalone_server_with_disabled_tools() .map_err(agent_client_protocol_core::util::internal_error)?; Ok(()) }, - )) + ) .await } diff --git a/src/agent-client-protocol-conductor/tests/test_mcp_tool_output_types.rs b/src/agent-client-protocol-conductor/tests/test_mcp_tool_output_types.rs index 3efd32c..54c3846 100644 --- a/src/agent-client-protocol-conductor/tests/test_mcp_tool_output_types.rs +++ b/src/agent-client-protocol-conductor/tests/test_mcp_tool_output_types.rs @@ -57,7 +57,7 @@ impl + 'static + Send> ConnectTo #[tokio::test] async fn test_tool_returning_string() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_test_proxy()), @@ -69,7 +69,7 @@ async fn test_tool_returning_string() -> Result<(), agent_client_protocol_core:: params: serde_json::json!({}), } .to_prompt(), - )) + ) .await?; // The result should contain "hello world" somewhere @@ -83,7 +83,7 @@ async fn test_tool_returning_string() -> Result<(), agent_client_protocol_core:: #[tokio::test] async fn test_tool_returning_integer() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_test_proxy()), @@ -95,7 +95,7 @@ async fn test_tool_returning_integer() -> Result<(), agent_client_protocol_core: params: serde_json::json!({}), } .to_prompt(), - )) + ) .await?; // The result should contain "42" somewhere diff --git a/src/agent-client-protocol-conductor/tests/test_session_id_in_mcp_tools.rs b/src/agent-client-protocol-conductor/tests/test_session_id_in_mcp_tools.rs index 0a308f4..dc464fb 100644 --- a/src/agent-client-protocol-conductor/tests/test_session_id_in_mcp_tools.rs +++ b/src/agent-client-protocol-conductor/tests/test_session_id_in_mcp_tools.rs @@ -71,7 +71,7 @@ impl + 'static + Send> ConnectTo async fn test_list_tools_from_mcp_server() -> Result<(), agent_client_protocol_core::Error> { use expect_test::expect; - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_echo_proxy()), @@ -81,7 +81,7 @@ async fn test_list_tools_from_mcp_server() -> Result<(), agent_client_protocol_c server: "echo_server".to_string(), } .to_prompt(), - )) + ) .await?; // Check the response using expect_test @@ -95,7 +95,7 @@ async fn test_list_tools_from_mcp_server() -> Result<(), agent_client_protocol_c #[tokio::test] async fn test_session_id_delivered_to_mcp_tools() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_echo_proxy()), @@ -107,7 +107,7 @@ async fn test_session_id_delivered_to_mcp_tools() -> Result<(), agent_client_pro params: serde_json::json!({}), } .to_prompt(), - )) + ) .await?; let pattern = regex::Regex::new(r#""acp_url":\s*String\("acp:[0-9a-f-]+"\)"#).unwrap(); diff --git a/src/agent-client-protocol-conductor/tests/test_tool_enable_disable.rs b/src/agent-client-protocol-conductor/tests/test_tool_enable_disable.rs index 48b5bb4..60d8ac1 100644 --- a/src/agent-client-protocol-conductor/tests/test_tool_enable_disable.rs +++ b/src/agent-client-protocol-conductor/tests/test_tool_enable_disable.rs @@ -109,7 +109,7 @@ impl + 'static + Send> ConnectTo fo #[tokio::test] async fn test_list_tools_excludes_disabled() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_proxy_with_disabled_tool()?), @@ -119,7 +119,7 @@ async fn test_list_tools_excludes_disabled() -> Result<(), agent_client_protocol server: "test_server".to_string(), } .to_prompt(), - )) + ) .await?; // Should contain echo and greet, but NOT secret @@ -135,7 +135,7 @@ async fn test_list_tools_excludes_disabled() -> Result<(), agent_client_protocol #[tokio::test] async fn test_enabled_tool_can_be_called() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_proxy_with_disabled_tool()?), @@ -147,7 +147,7 @@ async fn test_enabled_tool_can_be_called() -> Result<(), agent_client_protocol_c params: serde_json::json!({"message": "hello"}), } .to_prompt(), - )) + ) .await?; assert!( @@ -160,7 +160,7 @@ async fn test_enabled_tool_can_be_called() -> Result<(), agent_client_protocol_c #[tokio::test] async fn test_disabled_tool_returns_not_found() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_proxy_with_disabled_tool()?), @@ -172,7 +172,7 @@ async fn test_disabled_tool_returns_not_found() -> Result<(), agent_client_proto params: serde_json::json!({}), } .to_prompt(), - )) + ) .await?; // Should get an error about tool not found @@ -191,7 +191,7 @@ async fn test_disabled_tool_returns_not_found() -> Result<(), agent_client_proto #[tokio::test] async fn test_allowlist_only_shows_enabled_tools() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_proxy_with_allowlist()?), @@ -201,7 +201,7 @@ async fn test_allowlist_only_shows_enabled_tools() -> Result<(), agent_client_pr server: "allowlist_server".to_string(), } .to_prompt(), - )) + ) .await?; // Should only contain echo @@ -220,7 +220,7 @@ async fn test_allowlist_only_shows_enabled_tools() -> Result<(), agent_client_pr #[tokio::test] async fn test_allowlist_enabled_tool_works() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_proxy_with_allowlist()?), @@ -232,7 +232,7 @@ async fn test_allowlist_enabled_tool_works() -> Result<(), agent_client_protocol params: serde_json::json!({"message": "allowed"}), } .to_prompt(), - )) + ) .await?; assert!( @@ -246,7 +246,7 @@ async fn test_allowlist_enabled_tool_works() -> Result<(), agent_client_protocol #[tokio::test] async fn test_allowlist_non_enabled_tool_returns_not_found() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_proxy_with_allowlist()?), @@ -258,7 +258,7 @@ async fn test_allowlist_non_enabled_tool_returns_not_found() params: serde_json::json!({"name": "World"}), } .to_prompt(), - )) + ) .await?; // greet is registered but not enabled, should error diff --git a/src/agent-client-protocol-conductor/tests/test_tool_fn.rs b/src/agent-client-protocol-conductor/tests/test_tool_fn.rs index 26eff4a..e74fd7e 100644 --- a/src/agent-client-protocol-conductor/tests/test_tool_fn.rs +++ b/src/agent-client-protocol-conductor/tests/test_tool_fn.rs @@ -55,7 +55,7 @@ impl + 'static + Send> ConnectTo #[tokio::test] async fn test_tool_fn_greet() -> Result<(), agent_client_protocol_core::Error> { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( ConductorImpl::new_agent( "test-conductor".to_string(), ProxiesAndAgent::new(Testy::new()).proxy(create_greet_proxy()), @@ -67,7 +67,7 @@ async fn test_tool_fn_greet() -> Result<(), agent_client_protocol_core::Error> { params: serde_json::json!({"name": "World"}), } .to_prompt(), - )) + ) .await?; expect_test::expect![[r#" diff --git a/src/agent-client-protocol-conductor/tests/trace_client_mcp_server.rs b/src/agent-client-protocol-conductor/tests/trace_client_mcp_server.rs index 964e765..77527c1 100644 --- a/src/agent-client-protocol-conductor/tests/trace_client_mcp_server.rs +++ b/src/agent-client-protocol-conductor/tests/trace_client_mcp_server.rs @@ -227,69 +227,65 @@ async fn test_trace_client_mcp_server() -> Result<(), agent_client_protocol_core // Spawn the conductor with ElizaAgent (no proxies - simple setup) let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(Testy::new()), - McpBridgeMode::default(), - ) - .trace_to(trace_tx) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(Testy::new()), + McpBridgeMode::default(), ) + .trace_to(trace_tx) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Run the client with a client-hosted MCP server let test_result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - Box::pin( - agent_client_protocol_core::Client - .builder() - .name("test-client") - .connect_with( - agent_client_protocol_core::ByteStreams::new( - client_write.compat_write(), - client_read.compat(), - ), - async |cx| { - // Initialize - cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) - .block_task() - .await?; + agent_client_protocol_core::Client + .builder() + .name("test-client") + .connect_with( + agent_client_protocol_core::ByteStreams::new( + client_write.compat_write(), + client_read.compat(), + ), + async |cx| { + // Initialize + cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST)) + .block_task() + .await?; - // Stack-local state that the MCP tool will modify - let call_count = Mutex::new(0usize); + // Stack-local state that the MCP tool will modify + let call_count = Mutex::new(0usize); - // Build session with client-hosted MCP server - let result = cx - .build_session(".") - .with_mcp_server(make_echo_mcp_server::(&call_count))? - .block_task() - .run_until(async |mut session| { - // Send prompt that triggers MCP tool call - // The tool call will travel: agent → conductor → client - session.send_prompt(TestyCommand::CallTool { + // Build session with client-hosted MCP server + let result = cx + .build_session(".") + .with_mcp_server(make_echo_mcp_server::(&call_count))? + .block_task() + .run_until(async |mut session| { + // Send prompt that triggers MCP tool call + // The tool call will travel: agent → conductor → client + session.send_prompt(TestyCommand::CallTool { server: "echo-server".to_string(), tool: "echo".to_string(), params: serde_json::json!({"message": "Hello from client test!"}), }.to_prompt())?; - session.read_to_string().await - }) - .await?; + session.read_to_string().await + }) + .await?; - // Verify the tool was called - assert_eq!(*call_count.lock().unwrap(), 1); + // Verify the tool was called + assert_eq!(*call_count.lock().unwrap(), 1); - // Verify the response contains our echo - assert!(result.contains("Client echoes: Hello from client test!")); + // Verify the response contains our echo + assert!(result.contains("Client echoes: Hello from client test!")); - Ok(()) - }, - ), - ) - .await + Ok(()) + }, + ) + .await }) .await .expect("Test timed out"); diff --git a/src/agent-client-protocol-conductor/tests/trace_generation.rs b/src/agent-client-protocol-conductor/tests/trace_generation.rs index 9ba3e32..6b93d22 100644 --- a/src/agent-client-protocol-conductor/tests/trace_generation.rs +++ b/src/agent-client-protocol-conductor/tests/trace_generation.rs @@ -40,31 +40,29 @@ async fn test_trace_generation() -> Result<(), agent_client_protocol_core::Error // Spawn the conductor with tracing enabled let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(eliza_agent).proxy(arrow_proxy_agent), - McpBridgeMode::default(), - ) - .trace_to_path(&trace_path_clone) - .expect("Failed to create trace writer") - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(eliza_agent).proxy(arrow_proxy_agent), + McpBridgeMode::default(), ) + .trace_to_path(&trace_path_clone) + .expect("Failed to create trace writer") + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Run a simple prompt through the conductor let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - let result = Box::pin(yopo::prompt( + let result = yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await?; Ok::(result) diff --git a/src/agent-client-protocol-conductor/tests/trace_mcp_tool_call.rs b/src/agent-client-protocol-conductor/tests/trace_mcp_tool_call.rs index 83d7e14..f2ea4af 100644 --- a/src/agent-client-protocol-conductor/tests/trace_mcp_tool_call.rs +++ b/src/agent-client-protocol-conductor/tests/trace_mcp_tool_call.rs @@ -212,71 +212,66 @@ async fn test_trace_mcp_tool_call() -> Result<(), agent_client_protocol_core::Er // - ProxyComponent that provides the "test" MCP server with echo tool // - Tracing enabled to capture events let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(Testy::new()).proxy(mcp_integration::proxy::ProxyComponent), - McpBridgeMode::default(), - ) - .trace_to(trace_tx) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(Testy::new()).proxy(mcp_integration::proxy::ProxyComponent), + McpBridgeMode::default(), ) + .trace_to(trace_tx) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Run the client interaction let test_result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - Box::pin( - agent_client_protocol_core::Client - .builder() - .name("test-client") - .on_receive_notification( - { - let mut notif_tx = notif_tx; - async move |notification: SessionNotification, _cx| { - notif_tx - .send(notification) - .await - .map_err(|_| agent_client_protocol_core::Error::internal_error()) - } - }, - agent_client_protocol_core::on_receive_notification!(), - ) - .connect_with( - agent_client_protocol_core::ByteStreams::new( - client_write.compat_write(), - client_read.compat(), - ), - async |cx| { - // Initialize - recv(cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST))) - .await?; + agent_client_protocol_core::Client + .builder() + .name("test-client") + .on_receive_notification( + { + let mut notif_tx = notif_tx; + async move |notification: SessionNotification, _cx| { + notif_tx + .send(notification) + .await + .map_err(|_| agent_client_protocol_core::Error::internal_error()) + } + }, + agent_client_protocol_core::on_receive_notification!(), + ) + .connect_with( + agent_client_protocol_core::ByteStreams::new( + client_write.compat_write(), + client_read.compat(), + ), + async |cx| { + // Initialize + recv(cx.send_request(InitializeRequest::new(ProtocolVersion::LATEST))).await?; - // Create session - let session = recv( - cx.send_request(NewSessionRequest::new(std::path::PathBuf::from("/"))), - ) - .await?; + // Create session + let session = recv( + cx.send_request(NewSessionRequest::new(std::path::PathBuf::from("/"))), + ) + .await?; - // Send prompt that triggers MCP tool call - recv(cx.send_request(PromptRequest::new( - session.session_id.clone(), - vec![ContentBlock::Text(TextContent::new(TestyCommand::CallTool { + // Send prompt that triggers MCP tool call + recv(cx.send_request(PromptRequest::new( + session.session_id.clone(), + vec![ContentBlock::Text(TextContent::new(TestyCommand::CallTool { server: "test".to_string(), tool: "echo".to_string(), params: serde_json::json!({"message": "Hello from trace test!"}), }.to_prompt()))], - ))) - .await?; + ))) + .await?; - Ok(()) - }, - ), - ) - .await + Ok(()) + }, + ) + .await }) .await .expect("Test timed out"); diff --git a/src/agent-client-protocol-conductor/tests/trace_snapshot.rs b/src/agent-client-protocol-conductor/tests/trace_snapshot.rs index b57ea6b..5c90706 100644 --- a/src/agent-client-protocol-conductor/tests/trace_snapshot.rs +++ b/src/agent-client-protocol-conductor/tests/trace_snapshot.rs @@ -144,30 +144,28 @@ async fn test_trace_snapshot() -> Result<(), agent_client_protocol_core::Error> // Spawn the conductor with tracing to the channel let conductor_handle = tokio::spawn(async move { - Box::pin( - ConductorImpl::new_agent( - "conductor".to_string(), - ProxiesAndAgent::new(eliza_agent).proxy(arrow_proxy_agent), - McpBridgeMode::default(), - ) - .trace_to(tx) - .run(agent_client_protocol_core::ByteStreams::new( - conductor_write.compat_write(), - conductor_read.compat(), - )), + ConductorImpl::new_agent( + "conductor".to_string(), + ProxiesAndAgent::new(eliza_agent).proxy(arrow_proxy_agent), + McpBridgeMode::default(), ) + .trace_to(tx) + .run(agent_client_protocol_core::ByteStreams::new( + conductor_write.compat_write(), + conductor_read.compat(), + )) .await }); // Run a simple prompt through the conductor let result = tokio::time::timeout(std::time::Duration::from_secs(30), async move { - Box::pin(yopo::prompt( + yopo::prompt( agent_client_protocol_core::ByteStreams::new( editor_write.compat_write(), editor_read.compat(), ), TestyCommand::Greet.to_prompt(), - )) + ) .await }) .await diff --git a/src/agent-client-protocol-core/examples/simple_agent.rs b/src/agent-client-protocol-core/examples/simple_agent.rs index 6dc9e5b..e40ec9c 100644 --- a/src/agent-client-protocol-core/examples/simple_agent.rs +++ b/src/agent-client-protocol-core/examples/simple_agent.rs @@ -6,7 +6,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; #[tokio::main] async fn main() -> Result<(), agent_client_protocol_core::Error> { - let agent = Agent + Agent .builder() .name("my-agent") // for debugging .on_receive_request( @@ -32,6 +32,6 @@ async fn main() -> Result<(), agent_client_protocol_core::Error> { .connect_to(agent_client_protocol_core::ByteStreams::new( tokio::io::stdout().compat_write(), tokio::io::stdin().compat(), - )); - Box::pin(agent).await + )) + .await } diff --git a/src/agent-client-protocol-core/examples/yolo_one_shot_client.rs b/src/agent-client-protocol-core/examples/yolo_one_shot_client.rs index 2acbc5a..85f3759 100644 --- a/src/agent-client-protocol-core/examples/yolo_one_shot_client.rs +++ b/src/agent-client-protocol-core/examples/yolo_one_shot_client.rs @@ -88,7 +88,7 @@ async fn main() -> Result<(), Box> { ); // Run the client - let client = Client + Client .builder() .on_receive_notification( async move |notification: SessionNotification, _cx| { @@ -152,8 +152,8 @@ async fn main() -> Result<(), Box> { eprintln!("Stop reason: {:?}", prompt_response.stop_reason); Ok(()) - }); - Box::pin(client).await?; + }) + .await?; // Kill the child process when done drop(child.kill().await); diff --git a/src/agent-client-protocol-core/src/jsonrpc.rs b/src/agent-client-protocol-core/src/jsonrpc.rs index a6ea537..e6ac431 100644 --- a/src/agent-client-protocol-core/src/jsonrpc.rs +++ b/src/agent-client-protocol-core/src/jsonrpc.rs @@ -1239,7 +1239,8 @@ impl< Ok(()) }; - crate::util::run_until(background, main_fn(connection.clone())).await + crate::util::run_until(Box::pin(background), Box::pin(main_fn(connection.clone()))) + .await } }); diff --git a/src/agent-client-protocol-core/src/mcp_server/server.rs b/src/agent-client-protocol-core/src/mcp_server/server.rs index bdac9f8..5d5fd3f 100644 --- a/src/agent-client-protocol-core/src/mcp_server/server.rs +++ b/src/agent-client-protocol-core/src/mcp_server/server.rs @@ -233,7 +233,7 @@ where connection: connection_to_client.clone(), }); - let client = role::mcp::Client + role::mcp::Client .builder() .on_receive_dispatch( async |message_from_server: Dispatch, _| { @@ -247,8 +247,8 @@ where connection_to_server.send_proxied_message(message_from_client)?; } Ok(()) - }); - Box::pin(client).await + }) + .await }) .connect_to(client) .await diff --git a/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs b/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs index eeee00b..cdec8e9 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_advanced.rs @@ -171,53 +171,54 @@ async fn test_bidirectional_communication() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - // Set up two connections that are symmetric - both can send and receive - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let side_a_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let side_a = UntypedRole.builder().on_receive_request( - async |request: PingRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(PongResponse { - value: request.value + 1, - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let side_b_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - - // Spawn side_a as server - tokio::task::spawn_local(async move { - Box::pin(side_a.connect_to(side_a_transport)).await.ok(); - }); - - // Use side_b as client - let result = UntypedRole - .builder() - .connect_with( - side_b_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - let request = PingRequest { value: 10 }; - let response_future = recv(cx.send_request(request)); - let response: Result = response_future.await; - - assert!(response.is_ok()); - if let Ok(resp) = response { - assert_eq!(resp.value, 11); - } - Ok(()) + local + .run_until(async { + // Set up two connections that are symmetric - both can send and receive + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let side_a_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let side_a = UntypedRole.builder().on_receive_request( + async |request: PingRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(PongResponse { + value: request.value + 1, + }) }, - ) - .await; - - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let side_b_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + + // Spawn side_a as server + tokio::task::spawn_local(async move { + side_a.connect_to(side_a_transport).await.ok(); + }); + + // Use side_b as client + let result = UntypedRole + .builder() + .connect_with( + side_b_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + let request = PingRequest { value: 10 }; + let response_future = recv(cx.send_request(request)); + let response: Result = response_future.await; + + assert!(response.is_ok()); + if let Ok(resp) = response { + assert_eq!(resp.value, 11); + } + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -230,60 +231,61 @@ async fn test_request_ids() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |request: PingRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(PongResponse { - value: request.value + 1, - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - // Send multiple requests and verify responses match - let req1 = PingRequest { value: 1 }; - let req2 = PingRequest { value: 2 }; - let req3 = PingRequest { value: 3 }; - - let resp1_future = recv(cx.send_request(req1)); - let resp2_future = recv(cx.send_request(req2)); - let resp3_future = recv(cx.send_request(req3)); - - let resp1: Result = resp1_future.await; - let resp2: Result = resp2_future.await; - let resp3: Result = resp3_future.await; - - // Verify each response corresponds to its request - assert_eq!(resp1.unwrap().value, 2); // 1 + 1 - assert_eq!(resp2.unwrap().value, 3); // 2 + 1 - assert_eq!(resp3.unwrap().value, 4); // 3 + 1 - - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: PingRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(PongResponse { + value: request.value + 1, + }) }, - ) - .await; - - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + // Send multiple requests and verify responses match + let req1 = PingRequest { value: 1 }; + let req2 = PingRequest { value: 2 }; + let req3 = PingRequest { value: 3 }; + + let resp1_future = recv(cx.send_request(req1)); + let resp2_future = recv(cx.send_request(req2)); + let resp3_future = recv(cx.send_request(req3)); + + let resp1: Result = resp1_future.await; + let resp2: Result = resp2_future.await; + let resp3: Result = resp3_future.await; + + // Verify each response corresponds to its request + assert_eq!(resp1.unwrap().value, 2); // 1 + 1 + assert_eq!(resp2.unwrap().value, 3); // 2 + 1 + assert_eq!(resp3.unwrap().value, 4); // 3 + 1 + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -296,73 +298,74 @@ async fn test_out_of_order_responses() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |request: SlowRequest, - responder: Responder, - _connection: ConnectionTo| { - // Simulate delay - tokio::time::sleep(tokio::time::Duration::from_millis(request.delay_ms)).await; - responder.respond(SlowResponse { id: request.id }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - // Send requests with different delays - // Request 1: 100ms delay - // Request 2: 50ms delay - // Request 3: 10ms delay - // Responses should arrive in order: 3, 2, 1 - - let req1 = SlowRequest { - delay_ms: 100, - id: 1, - }; - let req2 = SlowRequest { - delay_ms: 50, - id: 2, - }; - let req3 = SlowRequest { - delay_ms: 10, - id: 3, - }; - - let resp1_future = recv(cx.send_request(req1)); - let resp2_future = recv(cx.send_request(req2)); - let resp3_future = recv(cx.send_request(req3)); - - // Wait for all responses - let resp1: Result = resp1_future.await; - let resp2: Result = resp2_future.await; - let resp3: Result = resp3_future.await; - - // Verify each future got the correct response despite out-of-order arrival - assert_eq!(resp1.unwrap().id, 1); - assert_eq!(resp2.unwrap().id, 2); - assert_eq!(resp3.unwrap().id, 3); - - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: SlowRequest, + responder: Responder, + _connection: ConnectionTo| { + // Simulate delay + tokio::time::sleep(tokio::time::Duration::from_millis(request.delay_ms)).await; + responder.respond(SlowResponse { id: request.id }) }, - ) - .await; - - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + // Send requests with different delays + // Request 1: 100ms delay + // Request 2: 50ms delay + // Request 3: 10ms delay + // Responses should arrive in order: 3, 2, 1 + + let req1 = SlowRequest { + delay_ms: 100, + id: 1, + }; + let req2 = SlowRequest { + delay_ms: 50, + id: 2, + }; + let req3 = SlowRequest { + delay_ms: 10, + id: 3, + }; + + let resp1_future = recv(cx.send_request(req1)); + let resp2_future = recv(cx.send_request(req2)); + let resp3_future = recv(cx.send_request(req3)); + + // Wait for all responses + let resp1: Result = resp1_future.await; + let resp2: Result = resp2_future.await; + let resp3: Result = resp3_future.await; + + // Verify each future got the correct response despite out-of-order arrival + assert_eq!(resp1.unwrap().id, 1); + assert_eq!(resp2.unwrap().id, 2); + assert_eq!(resp3.unwrap().id, 3); + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } diff --git a/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs b/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs index a0fb458..8a093bf 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_connection_builder.rs @@ -151,86 +151,91 @@ async fn test_multiple_handlers_different_methods() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (client_writer, server_reader) = tokio::io::duplex(1024); - let (server_writer, client_reader) = tokio::io::duplex(1024); - - let server_reader = server_reader.compat(); - let server_writer = server_writer.compat_write(); - let client_reader = client_reader.compat(); - let client_writer = client_writer.compat_write(); - - // Chain both handlers - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole - .builder() - .on_receive_request( - async |request: FooRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(FooResponse { - result: format!("foo: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ) - .on_receive_request( - async |request: BarRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(BarResponse { - result: format!("bar: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - // Test foo request - let foo_response = recv(cx.send_request(FooRequest { - value: "test1".to_string(), - })) - .await - .map_err(|e| -> agent_client_protocol_core::Error { - agent_client_protocol_core::util::internal_error(format!( - "Foo request failed: {e:?}" - )) - })?; - assert_eq!(foo_response.result, "foo: test1"); - - // Test bar request - let bar_response = recv(cx.send_request(BarRequest { - value: "test2".to_string(), - })) - .await - .map_err(|e| -> agent_client_protocol_core::Error { - agent_client_protocol_core::util::internal_error(format!( - "Bar request failed: {e:?}" - )) - })?; - assert_eq!(bar_response.result, "bar: test2"); - - Ok(()) - }, - ) - .await; - - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + local + .run_until(async { + let (client_writer, server_reader) = tokio::io::duplex(1024); + let (server_writer, client_reader) = tokio::io::duplex(1024); + + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); + + // Chain both handlers + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async |request: FooRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(FooResponse { + result: format!("foo: {}", request.value), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ) + .on_receive_request( + async |request: BarRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(BarResponse { + result: format!("bar: {}", request.value), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + // Test foo request + let foo_response = recv(cx.send_request(FooRequest { + value: "test1".to_string(), + })) + .await + .map_err( + |e| -> agent_client_protocol_core::Error { + agent_client_protocol_core::util::internal_error(format!( + "Foo request failed: {e:?}" + )) + }, + )?; + assert_eq!(foo_response.result, "foo: test1"); + + // Test bar request + let bar_response = recv(cx.send_request(BarRequest { + value: "test2".to_string(), + })) + .await + .map_err( + |e| -> agent_client_protocol_core::Error { + agent_client_protocol_core::util::internal_error(format!( + "Bar request failed: {e:?}" + )) + }, + )?; + assert_eq!(bar_response.result, "bar: test2"); + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -278,86 +283,87 @@ async fn test_handler_priority_ordering() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let handled = Arc::new(Mutex::new(Vec::new())); - - let (client_writer, server_reader) = tokio::io::duplex(1024); - let (server_writer, client_reader) = tokio::io::duplex(1024); - - let server_reader = server_reader.compat(); - let server_writer = server_writer.compat_write(); - let client_reader = client_reader.compat(); - let client_writer = client_writer.compat_write(); - - // First handler in chain should get first chance - let handled_clone1 = handled.clone(); - let handled_clone2 = handled.clone(); - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole - .builder() - .on_receive_request( - async move |request: TrackRequest, - responder: Responder, - _connection: ConnectionTo| { - handled_clone1.lock().unwrap().push("handler1".to_string()); - responder.respond(FooResponse { - result: format!("handler1: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ) - .on_receive_request( - async move |request: TrackRequest, - responder: Responder, - _connection: ConnectionTo| { - handled_clone2.lock().unwrap().push("handler2".to_string()); - responder.respond(FooResponse { - result: format!("handler2: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - let response = recv(cx.send_request(TrackRequest { - value: "test".to_string(), - })) - .await - .map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Track request failed: {e:?}" - )) - })?; - - // First handler should have handled it - assert_eq!(response.result, "handler1: test"); + local + .run_until(async { + let handled = Arc::new(Mutex::new(Vec::new())); - Ok(()) - }, - ) - .await; + let (client_writer, server_reader) = tokio::io::duplex(1024); + let (server_writer, client_reader) = tokio::io::duplex(1024); - assert!(result.is_ok(), "Test failed: {result:?}"); + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); - // Verify only handler1 was invoked - let handled_by = handled.lock().unwrap(); - assert_eq!(handled_by.len(), 1); - assert_eq!(handled_by[0], "handler1"); - })) - .await; + // First handler in chain should get first chance + let handled_clone1 = handled.clone(); + let handled_clone2 = handled.clone(); + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async move |request: TrackRequest, + responder: Responder, + _connection: ConnectionTo| { + handled_clone1.lock().unwrap().push("handler1".to_string()); + responder.respond(FooResponse { + result: format!("handler1: {}", request.value), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ) + .on_receive_request( + async move |request: TrackRequest, + responder: Responder, + _connection: ConnectionTo| { + handled_clone2.lock().unwrap().push("handler2".to_string()); + responder.respond(FooResponse { + result: format!("handler2: {}", request.value), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + let response = recv(cx.send_request(TrackRequest { + value: "test".to_string(), + })) + .await + .map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Track request failed: {e:?}" + )) + })?; + + // First handler should have handled it + assert_eq!(response.result, "handler1: test"); + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + + // Verify only handler1 was invoked + let handled_by = handled.lock().unwrap(); + assert_eq!(handled_by.len(), 1); + assert_eq!(handled_by[0], "handler1"); + }) + .await; } // ============================================================================ @@ -440,86 +446,87 @@ async fn test_fallthrough_behavior() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let handled = Arc::new(Mutex::new(Vec::new())); - - let (client_writer, server_reader) = tokio::io::duplex(1024); - let (server_writer, client_reader) = tokio::io::duplex(1024); - - let server_reader = server_reader.compat(); - let server_writer = server_writer.compat_write(); - let client_reader = client_reader.compat(); - let client_writer = client_writer.compat_write(); - - // Handler1 only handles "method1", Handler2 only handles "method2" - let handled_clone1 = handled.clone(); - let handled_clone2 = handled.clone(); - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole - .builder() - .on_receive_request( - async move |request: Method1Request, - responder: Responder, - _connection: ConnectionTo| { - handled_clone1.lock().unwrap().push("method1".to_string()); - responder.respond(FooResponse { - result: format!("method1: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ) - .on_receive_request( - async move |request: Method2Request, - responder: Responder, - _connection: ConnectionTo| { - handled_clone2.lock().unwrap().push("method2".to_string()); - responder.respond(FooResponse { - result: format!("method2: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - // Send method2 - should fallthrough handler1 to handler2 - let response = recv(cx.send_request(Method2Request { - value: "fallthrough".to_string(), - })) - .await - .map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Method2 request failed: {e:?}" - )) - })?; - - assert_eq!(response.result, "method2: fallthrough"); + local + .run_until(async { + let handled = Arc::new(Mutex::new(Vec::new())); - Ok(()) - }, - ) - .await; + let (client_writer, server_reader) = tokio::io::duplex(1024); + let (server_writer, client_reader) = tokio::io::duplex(1024); - assert!(result.is_ok(), "Test failed: {result:?}"); + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); - // Verify only method2 was handled (handler1 passed through) - let handled_methods = handled.lock().unwrap(); - assert_eq!(handled_methods.len(), 1); - assert_eq!(handled_methods[0], "method2"); - })) - .await; + // Handler1 only handles "method1", Handler2 only handles "method2" + let handled_clone1 = handled.clone(); + let handled_clone2 = handled.clone(); + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole + .builder() + .on_receive_request( + async move |request: Method1Request, + responder: Responder, + _connection: ConnectionTo| { + handled_clone1.lock().unwrap().push("method1".to_string()); + responder.respond(FooResponse { + result: format!("method1: {}", request.value), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ) + .on_receive_request( + async move |request: Method2Request, + responder: Responder, + _connection: ConnectionTo| { + handled_clone2.lock().unwrap().push("method2".to_string()); + responder.respond(FooResponse { + result: format!("method2: {}", request.value), + }) + }, + agent_client_protocol_core::on_receive_request!(), + ); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + // Send method2 - should fallthrough handler1 to handler2 + let response = recv(cx.send_request(Method2Request { + value: "fallthrough".to_string(), + })) + .await + .map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Method2 request failed: {e:?}" + )) + })?; + + assert_eq!(response.result, "method2: fallthrough"); + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + + // Verify only method2 was handled (handler1 passed through) + let handled_methods = handled.lock().unwrap(); + assert_eq!(handled_methods.len(), 1); + assert_eq!(handled_methods[0], "method2"); + }) + .await; } // ============================================================================ @@ -532,59 +539,60 @@ async fn test_no_handler_claims() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (client_writer, server_reader) = tokio::io::duplex(1024); - let (server_writer, client_reader) = tokio::io::duplex(1024); - - let server_reader = server_reader.compat(); - let server_writer = server_writer.compat_write(); - let client_reader = client_reader.compat(); - let client_writer = client_writer.compat_write(); - - // Handler that only handles "foo" - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |request: FooRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(FooResponse { - result: format!("foo: {}", request.value), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - // Send "bar" request which no handler claims - let response_result = recv(cx.send_request(BarRequest { - value: "unclaimed".to_string(), - })) - .await; + local + .run_until(async { + let (client_writer, server_reader) = tokio::io::duplex(1024); + let (server_writer, client_reader) = tokio::io::duplex(1024); - // Should get an error (method not found) - assert!(response_result.is_err()); + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); - Ok(()) + // Handler that only handles "foo" + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: FooRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(FooResponse { + result: format!("foo: {}", request.value), + }) }, - ) - .await; - - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + // Send "bar" request which no handler claims + let response_result = recv(cx.send_request(BarRequest { + value: "unclaimed".to_string(), + })) + .await; + + // Should get an error (method not found) + assert!(response_result.is_err()); + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -630,66 +638,68 @@ async fn test_handler_claims_notification() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let events = Arc::new(Mutex::new(Vec::new())); - - let (client_writer, server_reader) = tokio::io::duplex(1024); - let (server_writer, client_reader) = tokio::io::duplex(1024); - - let server_reader = server_reader.compat(); - let server_writer = server_writer.compat_write(); - let client_reader = client_reader.compat(); - let client_writer = client_writer.compat_write(); - - // EventHandler claims notifications - let events_clone = events.clone(); - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_notification( - async move |notification: EventNotification, _connection: ConnectionTo| { - events_clone.lock().unwrap().push(notification.event); - Ok(()) - }, - agent_client_protocol_core::on_receive_notification!(), - ); - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - cx.send_notification(EventNotification { - event: "test_event".to_string(), - }) - .map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Failed to send notification: {e:?}" - )) - })?; + local + .run_until(async { + let events = Arc::new(Mutex::new(Vec::new())); + + let (client_writer, server_reader) = tokio::io::duplex(1024); + let (server_writer, client_reader) = tokio::io::duplex(1024); - // Give server time to process - tokio::time::sleep(Duration::from_millis(100)).await; + let server_reader = server_reader.compat(); + let server_writer = server_writer.compat_write(); + let client_reader = client_reader.compat(); + let client_writer = client_writer.compat_write(); + // EventHandler claims notifications + let events_clone = events.clone(); + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_notification( + async move |notification: EventNotification, + _connection: ConnectionTo| { + events_clone.lock().unwrap().push(notification.event); Ok(()) }, - ) - .await; - - assert!(result.is_ok(), "Test failed: {result:?}"); - - let received_events = events.lock().unwrap(); - assert_eq!(received_events.len(), 1); - assert_eq!(received_events[0], "test_event"); - })) - .await; + agent_client_protocol_core::on_receive_notification!(), + ); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + cx.send_notification(EventNotification { + event: "test_event".to_string(), + }) + .map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Failed to send notification: {e:?}" + )) + })?; + + // Give server time to process + tokio::time::sleep(Duration::from_millis(100)).await; + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + + let received_events = events.lock().unwrap(); + assert_eq!(received_events.len(), 1); + assert_eq!(received_events[0], "test_event"); + }) + .await; } // ============================================================================ @@ -726,7 +736,7 @@ async fn test_connection_builder_as_component() -> Result<(), agent_client_proto ); // Use Builder as a Component via run_until - Box::pin(run_until( + run_until( // This uses Component::serve on Builder ConnectTo::::connect_to(server_builder, server_transport), async move { @@ -744,6 +754,6 @@ async fn test_connection_builder_as_component() -> Result<(), agent_client_proto }) .await }, - )) + ) .await } diff --git a/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs b/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs index 052722d..785e583 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_edge_cases.rs @@ -149,51 +149,53 @@ async fn test_empty_request() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |_request: EmptyRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(SimpleResponse { - result: "Got empty request".to_string(), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - let request = EmptyRequest; - - let result: Result = recv(cx.send_request(request)).await; - - // Should succeed - assert!(result.is_ok()); - if let Ok(response) = result { - assert_eq!(response.result, "Got empty request"); - } - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |_request: EmptyRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: "Got empty request".to_string(), + }) }, - ) - .await; + agent_client_protocol_core::on_receive_request!(), + ); - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + let request = EmptyRequest; + + let result: Result = + recv(cx.send_request(request)).await; + + // Should succeed + assert!(result.is_ok()); + if let Ok(response) = result { + assert_eq!(response.result, "Got empty request"); + } + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -206,48 +208,50 @@ async fn test_null_params() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |_request: OptionalParamsRequest, - responder: Responder, - _connection: ConnectionTo| { - responder.respond(SimpleResponse { - result: "Has params: true".to_string(), - }) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - let request = OptionalParamsRequest { value: None }; - - let result: Result = recv(cx.send_request(request)).await; - - // Should succeed - handler should handle null/missing params - assert!(result.is_ok()); - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |_request: OptionalParamsRequest, + responder: Responder, + _connection: ConnectionTo| { + responder.respond(SimpleResponse { + result: "Has params: true".to_string(), + }) }, - ) - .await; + agent_client_protocol_core::on_receive_request!(), + ); - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + let request = OptionalParamsRequest { value: None }; + + let result: Result = + recv(cx.send_request(request)).await; + + // Should succeed - handler should handle null/missing params + assert!(result.is_ok()); + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -282,7 +286,7 @@ async fn test_server_shutdown() { let client = UntypedRole.builder(); let server_handle = tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); + server.connect_to(server_transport).await.ok(); }); let client_result = tokio::task::spawn_local(async move { @@ -355,7 +359,7 @@ async fn test_client_disconnect() { ); tokio::task::spawn_local(async move { - drop(Box::pin(server.connect_to(server_transport)).await); + drop(server.connect_to(server_transport).await); }); // Send partial request and then disconnect diff --git a/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs b/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs index 8f9154b..11ecbfe 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_error_handling.rs @@ -203,50 +203,52 @@ async fn test_unknown_method() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - // No handlers - all requests will be "method not found" - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder(); - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - // Spawn server - tokio::task::spawn_local(async move { - server.connect_to(server_transport).await.ok(); - }); - - // Send request from client - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - let request = SimpleRequest { - message: "test".to_string(), - }; - - let result: Result = recv(cx.send_request(request)).await; - - // Should get an error because no handler claims the method - assert!(result.is_err()); - if let Err(err) = result { - // Should be "method not found" or similar error - assert!(matches!( - err.code, - agent_client_protocol_core::ErrorCode::MethodNotFound - )); - } - Ok(()) - }, - ) - .await; + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + // No handlers - all requests will be "method not found" + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder(); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + // Spawn server + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + // Send request from client + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + let request = SimpleRequest { + message: "test".to_string(), + }; + + let result: Result = + recv(cx.send_request(request)).await; + + // Should get an error because no handler claims the method + assert!(result.is_err()); + if let Err(err) = result { + // Should be "method not found" or similar error + assert!(matches!( + err.code, + agent_client_protocol_core::ErrorCode::MethodNotFound + )); + } + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -294,55 +296,58 @@ async fn test_handler_returns_error() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |_request: ErrorRequest, - responder: Responder, - _connection: ConnectionTo| { - // Explicitly return an error - responder.respond_with_error(agent_client_protocol_core::Error::internal_error()) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - let request = ErrorRequest { - value: "trigger error".to_string(), - }; - - let result: Result = recv(cx.send_request(request)).await; - - // Should get the error the handler returned - assert!(result.is_err()); - if let Err(err) = result { - assert!(matches!( - err.code, - agent_client_protocol_core::ErrorCode::InternalError - )); - } - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |_request: ErrorRequest, + responder: Responder, + _connection: ConnectionTo| { + // Explicitly return an error + responder + .respond_with_error(agent_client_protocol_core::Error::internal_error()) }, - ) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + let request = ErrorRequest { + value: "trigger error".to_string(), + }; + + let result: Result = + recv(cx.send_request(request)).await; + + // Should get the error the handler returned + assert!(result.is_err()); + if let Err(err) = result { + assert!(matches!( + err.code, + agent_client_protocol_core::ErrorCode::InternalError + )); + } + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } // ============================================================================ @@ -388,58 +393,61 @@ async fn test_missing_required_params() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - // Handler that validates params - since EmptyRequest has no params but we're checking - // against SimpleRequest which requires a message field, this will fail - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |_request: EmptyRequest, - responder: Responder, - _connection: ConnectionTo| { - // This will be called, but EmptyRequest parsing already succeeded - // The test is actually checking if EmptyRequest (no params) fails to parse as SimpleRequest - // But with the new API, EmptyRequest parses successfully since it expects no params - // We need to manually check - but actually the parse_request for EmptyRequest - // accepts anything for "strict_method", so the error must come from somewhere else - responder.respond_with_error(agent_client_protocol_core::Error::invalid_params()) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - Box::pin(server.connect_to(server_transport)).await.ok(); - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> Result<(), agent_client_protocol_core::Error> { - // Send request with no params (EmptyRequest has no fields) - let request = EmptyRequest; - - let result: Result = recv(cx.send_request(request)).await; - - // Should get invalid_params error - assert!(result.is_err()); - if let Err(err) = result { - assert!(matches!( - err.code, - agent_client_protocol_core::ErrorCode::InvalidParams - )); // JSONRPC_INVALID_PARAMS - } - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + // Handler that validates params - since EmptyRequest has no params but we're checking + // against SimpleRequest which requires a message field, this will fail + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |_request: EmptyRequest, + responder: Responder, + _connection: ConnectionTo| { + // This will be called, but EmptyRequest parsing already succeeded + // The test is actually checking if EmptyRequest (no params) fails to parse as SimpleRequest + // But with the new API, EmptyRequest parses successfully since it expects no params + // We need to manually check - but actually the parse_request for EmptyRequest + // accepts anything for "strict_method", so the error must come from somewhere else + responder + .respond_with_error(agent_client_protocol_core::Error::invalid_params()) }, - ) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); + + tokio::task::spawn_local(async move { + server.connect_to(server_transport).await.ok(); + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> Result<(), agent_client_protocol_core::Error> { + // Send request with no params (EmptyRequest has no fields) + let request = EmptyRequest; + + let result: Result = + recv(cx.send_request(request)).await; + + // Should get invalid_params error + assert!(result.is_err()); + if let Err(err) = result { + assert!(matches!( + err.code, + agent_client_protocol_core::ErrorCode::InvalidParams + )); // JSONRPC_INVALID_PARAMS + } + Ok(()) + }, + ) + .await; - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } diff --git a/src/agent-client-protocol-core/tests/jsonrpc_hello.rs b/src/agent-client-protocol-core/tests/jsonrpc_hello.rs index 4f8d333..aa62eb6 100644 --- a/src/agent-client-protocol-core/tests/jsonrpc_hello.rs +++ b/src/agent-client-protocol-core/tests/jsonrpc_hello.rs @@ -110,59 +110,60 @@ async fn test_hello_world() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async move |request: PingRequest, - responder: Responder, - _connection: ConnectionTo| { - let pong = PongResponse { - echo: format!("pong: {}", request.message), - }; - responder.respond(pong) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - // Spawn the server in the background - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - // Use the client to send a ping and wait for a pong - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - let request = PingRequest { - message: "hello world".to_string(), + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async move |request: PingRequest, + responder: Responder, + _connection: ConnectionTo| { + let pong = PongResponse { + echo: format!("pong: {}", request.message), }; + responder.respond(pong) + }, + agent_client_protocol_core::on_receive_request!(), + ); - let response = recv(cx.send_request(request)).await.map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Request failed: {e:?}" - )) - })?; + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); - assert_eq!(response.echo, "pong: hello world"); + // Spawn the server in the background + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); - Ok(()) - }, - ) - .await; + // Use the client to send a ping and wait for a pong + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + let request = PingRequest { + message: "hello world".to_string(), + }; - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + let response = recv(cx.send_request(request)).await.map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Request failed: {e:?}" + )) + })?; + + assert_eq!(response.echo, "pong: hello world"); + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } /// A simple notification message @@ -205,74 +206,75 @@ async fn test_notification() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let logs = Arc::new(Mutex::new(Vec::new())); - let logs_clone = logs.clone(); + local + .run_until(async { + let logs = Arc::new(Mutex::new(Vec::new())); + let logs_clone = logs.clone(); + + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_notification( + { + let logs = logs_clone.clone(); + async move |notification: LogNotification, _cx: ConnectionTo| { + logs.lock().unwrap().push(notification.message); + Ok(()) + } + }, + agent_client_protocol_core::on_receive_notification!(), + ); - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_notification( - { - let logs = logs_clone.clone(); - async move |notification: LogNotification, _cx: ConnectionTo| { - logs.lock().unwrap().push(notification.message); - Ok(()) + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); } - }, - agent_client_protocol_core::on_receive_notification!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - // Send a notification (no response expected) - cx.send_notification(LogNotification { - message: "test log 1".to_string(), - }) - .map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Failed to send notification: {e:?}" - )) - })?; - - cx.send_notification(LogNotification { - message: "test log 2".to_string(), - }) - .map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Failed to send notification: {e:?}" - )) - })?; - - // Give the server time to process notifications - tokio::time::sleep(Duration::from_millis(100)).await; - - Ok(()) - }, - ) - .await; + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + // Send a notification (no response expected) + cx.send_notification(LogNotification { + message: "test log 1".to_string(), + }) + .map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Failed to send notification: {e:?}" + )) + })?; - assert!(result.is_ok(), "Test failed: {result:?}"); + cx.send_notification(LogNotification { + message: "test log 2".to_string(), + }) + .map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Failed to send notification: {e:?}" + )) + })?; + + // Give the server time to process notifications + tokio::time::sleep(Duration::from_millis(100)).await; + + Ok(()) + }, + ) + .await; - let received_logs = logs.lock().unwrap(); - assert_eq!(received_logs.len(), 2); - assert_eq!(received_logs[0], "test log 1"); - assert_eq!(received_logs[1], "test log 2"); - })) - .await; + assert!(result.is_ok(), "Test failed: {result:?}"); + + let received_logs = logs.lock().unwrap(); + assert_eq!(received_logs.len(), 2); + assert_eq!(received_logs[0], "test log 1"); + assert_eq!(received_logs[1], "test log 2"); + }) + .await; } #[tokio::test(flavor = "current_thread")] @@ -281,60 +283,61 @@ async fn test_multiple_sequential_requests() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |request: PingRequest, - responder: Responder, - _connection: ConnectionTo| { - let pong = PongResponse { - echo: format!("pong: {}", request.message), - }; - responder.respond(pong) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - // Send multiple requests sequentially - for i in 1..=5 { - let request = PingRequest { - message: format!("message {i}"), - }; - - let response = recv(cx.send_request(request)).await.map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Request {i} failed: {e:?}" - )) - })?; - - assert_eq!(response.echo, format!("pong: message {i}")); - } - - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: PingRequest, + responder: Responder, + _connection: ConnectionTo| { + let pong = PongResponse { + echo: format!("pong: {}", request.message), + }; + responder.respond(pong) }, - ) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + // Send multiple requests sequentially + for i in 1..=5 { + let request = PingRequest { + message: format!("message {i}"), + }; + + let response = recv(cx.send_request(request)).await.map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Request {i} failed: {e:?}" + )) + })?; + + assert_eq!(response.echo, format!("pong: message {i}")); + } + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } #[tokio::test(flavor = "current_thread")] @@ -343,66 +346,67 @@ async fn test_concurrent_requests() { let local = LocalSet::new(); - Box::pin(local.run_until(async { - let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); - - let server_transport = - agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); - let server = UntypedRole.builder().on_receive_request( - async |request: PingRequest, - responder: Responder, - _connection: ConnectionTo| { - let pong = PongResponse { - echo: format!("pong: {}", request.message), - }; - responder.respond(pong) - }, - agent_client_protocol_core::on_receive_request!(), - ); - - let client_transport = - agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); - let client = UntypedRole.builder(); - - tokio::task::spawn_local(async move { - if let Err(e) = Box::pin(server.connect_to(server_transport)).await { - eprintln!("Server error: {e:?}"); - } - }); - - let result = client - .connect_with( - client_transport, - async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { - // Send multiple requests concurrently - let mut responses = Vec::new(); - - for i in 1..=5 { - let request = PingRequest { - message: format!("concurrent message {i}"), - }; - - // Start all requests without awaiting - responses.push((i, cx.send_request(request))); - } - - // Now await all responses - for (i, response_future) in responses { - let response = recv(response_future).await.map_err(|e| { - agent_client_protocol_core::util::internal_error(format!( - "Request {i} failed: {e:?}" - )) - })?; - - assert_eq!(response.echo, format!("pong: concurrent message {i}")); - } - - Ok(()) + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = + agent_client_protocol_core::ByteStreams::new(server_writer, server_reader); + let server = UntypedRole.builder().on_receive_request( + async |request: PingRequest, + responder: Responder, + _connection: ConnectionTo| { + let pong = PongResponse { + echo: format!("pong: {}", request.message), + }; + responder.respond(pong) }, - ) - .await; + agent_client_protocol_core::on_receive_request!(), + ); + + let client_transport = + agent_client_protocol_core::ByteStreams::new(client_writer, client_reader); + let client = UntypedRole.builder(); - assert!(result.is_ok(), "Test failed: {result:?}"); - })) - .await; + tokio::task::spawn_local(async move { + if let Err(e) = server.connect_to(server_transport).await { + eprintln!("Server error: {e:?}"); + } + }); + + let result = client + .connect_with( + client_transport, + async |cx| -> std::result::Result<(), agent_client_protocol_core::Error> { + // Send multiple requests concurrently + let mut responses = Vec::new(); + + for i in 1..=5 { + let request = PingRequest { + message: format!("concurrent message {i}"), + }; + + // Start all requests without awaiting + responses.push((i, cx.send_request(request))); + } + + // Now await all responses + for (i, response_future) in responses { + let response = recv(response_future).await.map_err(|e| { + agent_client_protocol_core::util::internal_error(format!( + "Request {i} failed: {e:?}" + )) + })?; + + assert_eq!(response.echo, format!("pong: concurrent message {i}")); + } + + Ok(()) + }, + ) + .await; + + assert!(result.is_ok(), "Test failed: {result:?}"); + }) + .await; } diff --git a/src/agent-client-protocol-rmcp/examples/with_mcp_server.rs b/src/agent-client-protocol-rmcp/examples/with_mcp_server.rs index 171d36a..7b8139c 100644 --- a/src/agent-client-protocol-rmcp/examples/with_mcp_server.rs +++ b/src/agent-client-protocol-rmcp/examples/with_mcp_server.rs @@ -87,7 +87,7 @@ async fn main() -> Result<(), Box> { // Set up the proxy connection with our MCP server // ProxyToConductor already has proxy behavior built into its default_message_handler - let proxy = Proxy + Proxy .builder() .name("mcp-server-proxy") // Register the MCP server as a handler @@ -96,8 +96,8 @@ async fn main() -> Result<(), Box> { .connect_to(agent_client_protocol_core::ByteStreams::new( tokio::io::stdout().compat_write(), tokio::io::stdin().compat(), - )); - Box::pin(proxy).await?; + )) + .await?; Ok(()) } diff --git a/src/agent-client-protocol-test/examples/arrow_proxy.rs b/src/agent-client-protocol-test/examples/arrow_proxy.rs index 73fc6ab..3b7b2bb 100644 --- a/src/agent-client-protocol-test/examples/arrow_proxy.rs +++ b/src/agent-client-protocol-test/examples/arrow_proxy.rs @@ -20,10 +20,7 @@ async fn main() -> Result<(), Box> { let stdout = tokio::io::stdout().compat_write(); // Run the arrow proxy - Box::pin(run_arrow_proxy( - agent_client_protocol_core::ByteStreams::new(stdout, stdin), - )) - .await?; + run_arrow_proxy(agent_client_protocol_core::ByteStreams::new(stdout, stdin)).await?; Ok(()) } diff --git a/src/agent-client-protocol-test/src/bin/testy.rs b/src/agent-client-protocol-test/src/bin/testy.rs index 3a03c44..72834f9 100644 --- a/src/agent-client-protocol-test/src/bin/testy.rs +++ b/src/agent-client-protocol-test/src/bin/testy.rs @@ -6,6 +6,8 @@ async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_writer(std::io::stderr) .init(); - Box::pin(Testy::new().connect_to(agent_client_protocol_tokio::Stdio::new())).await?; + Testy::new() + .connect_to(agent_client_protocol_tokio::Stdio::new()) + .await?; Ok(()) } diff --git a/src/agent-client-protocol-tokio/tests/debug_logging.rs b/src/agent-client-protocol-tokio/tests/debug_logging.rs index a56490b..277d212 100644 --- a/src/agent-client-protocol-tokio/tests/debug_logging.rs +++ b/src/agent-client-protocol-tokio/tests/debug_logging.rs @@ -60,7 +60,7 @@ async fn test_acp_agent_debug_callback() -> Result<(), Box Result<(), Box Result<(), Box> { eprintln!("🚀 Spawning agent and running prompt..."); // Use the library function with callback to print progressively - Box::pin(yopo::prompt_with_callback( - agent, - prompt.as_str(), - |block| async move { - print!("{}", yopo::content_block_to_string(&block)); - }, - )) + yopo::prompt_with_callback(agent, prompt.as_str(), |block| async move { + print!("{}", yopo::content_block_to_string(&block)); + }) .await?; println!(); // Final newline