Skip to content

Commit 4d2c503

Browse files
committed
feat: SSE server-initiated requests, sampling helpers, and CLI tasks help
1 parent 34b1f0c commit 4d2c503

10 files changed

Lines changed: 707 additions & 52 deletions

File tree

CMakeLists.txt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,11 @@ if(FASTMCPP_BUILD_TESTS)
171171
add_test(NAME fastmcpp_integration COMMAND fastmcpp_integration)
172172

173173
add_test(NAME fastmcpp_cli_sum COMMAND fastmcpp client sum 2 3)
174+
add_test(NAME fastmcpp_cli_tasks_help COMMAND fastmcpp tasks --help)
174175

175-
add_executable(fastmcpp_http_integration tests/server/http_integration.cpp)
176-
target_link_libraries(fastmcpp_http_integration PRIVATE fastmcpp_core)
177-
add_test(NAME fastmcpp_http_integration COMMAND fastmcpp_http_integration)
176+
add_executable(fastmcpp_http_integration tests/server/http_integration.cpp)
177+
target_link_libraries(fastmcpp_http_integration PRIVATE fastmcpp_core)
178+
add_test(NAME fastmcpp_http_integration COMMAND fastmcpp_http_integration)
178179

179180
add_executable(fastmcpp_json_schema tests/schema/json_schema.cpp)
180181
target_link_libraries(fastmcpp_json_schema PRIVATE fastmcpp_core)
@@ -324,6 +325,11 @@ if(FASTMCPP_BUILD_TESTS)
324325
target_link_libraries(fastmcpp_server_sse_http_integration PRIVATE fastmcpp_core)
325326
add_test(NAME fastmcpp_server_sse_http_integration COMMAND fastmcpp_server_sse_http_integration)
326327

328+
# SSE bidirectional server-initiated requests (ServerSession -> client -> response)
329+
add_executable(fastmcpp_server_sse_bidirectional_requests tests/server/sse_bidirectional_requests.cpp)
330+
target_link_libraries(fastmcpp_server_sse_bidirectional_requests PRIVATE fastmcpp_core)
331+
add_test(NAME fastmcpp_server_sse_bidirectional_requests COMMAND fastmcpp_server_sse_bidirectional_requests)
332+
327333
# Streamable HTTP integration (MCP spec 2025-03-26)
328334
add_executable(fastmcpp_server_streamable_http_integration tests/server/streamable_http_integration.cpp)
329335
target_link_libraries(fastmcpp_server_streamable_http_integration PRIVATE fastmcpp_core)

include/fastmcpp/client/client.hpp

Lines changed: 189 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <functional>
1717
#include <future>
1818
#include <memory>
19+
#include <mutex>
1920
#include <string>
2021
#include <thread>
2122
#include <unordered_map>
@@ -44,6 +45,28 @@ class ITransport
4445
virtual fastmcpp::Json request(const std::string& route, const fastmcpp::Json& payload) = 0;
4546
};
4647

48+
/// Optional transport interface: some transports support explicit session reset/disconnect.
49+
class IResettableTransport
50+
{
51+
public:
52+
virtual ~IResettableTransport() = default;
53+
54+
/// Reset connection/session state. Semantics are transport-specific.
55+
/// @param full If true, reset any additional internal state beyond the session identifier.
56+
virtual void reset(bool full = false) = 0;
57+
};
58+
59+
using ServerRequestHandler =
60+
std::function<fastmcpp::Json(const std::string& method, const fastmcpp::Json& params)>;
61+
62+
/// Optional transport interface: some transports can accept server-initiated requests and send responses.
63+
class IServerRequestTransport
64+
{
65+
public:
66+
virtual ~IServerRequestTransport() = default;
67+
virtual void set_server_request_handler(ServerRequestHandler handler) = 0;
68+
};
69+
4770
/// Loopback transport for in-process server testing
4871
class LoopbackTransport : public ITransport
4972
{
@@ -146,17 +169,24 @@ struct CallToolOptions
146169
/// @endcode
147170
class Client
148171
{
172+
struct CallbackState;
173+
149174
public:
150-
Client() = default;
175+
Client() : callbacks_(std::make_shared<CallbackState>()) {}
151176
explicit Client(std::unique_ptr<ITransport> t)
152-
: transport_(std::shared_ptr<ITransport>(std::move(t)))
177+
: transport_(std::shared_ptr<ITransport>(std::move(t))),
178+
callbacks_(std::make_shared<CallbackState>())
153179
{
180+
configure_transport_callbacks();
154181
}
155182

156183
/// Set the transport (for deferred initialization)
157184
void set_transport(std::unique_ptr<ITransport> t)
158185
{
159186
transport_ = std::shared_ptr<ITransport>(std::move(t));
187+
if (!callbacks_)
188+
callbacks_ = std::make_shared<CallbackState>();
189+
configure_transport_callbacks();
160190
}
161191

162192
/// Check if transport is connected
@@ -512,8 +542,20 @@ class Client
512542
/// Initialize the session with the server
513543
InitializeResult initialize(std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
514544
{
545+
fastmcpp::Json caps = fastmcpp::Json::object();
546+
if (get_sampling_callback())
547+
{
548+
caps["sampling"] = fastmcpp::Json::object();
549+
// Optimistically advertise tools support when a sampling callback is present.
550+
caps["sampling"]["tools"] = fastmcpp::Json::object();
551+
}
552+
if (get_elicitation_callback())
553+
caps["elicitation"] = fastmcpp::Json::object();
554+
if (get_roots_callback())
555+
caps["roots"] = fastmcpp::Json::object();
556+
515557
fastmcpp::Json payload = {{"protocolVersion", "2024-11-05"},
516-
{"capabilities", fastmcpp::Json::object()},
558+
{"capabilities", std::move(caps)},
517559
{"clientInfo", {{"name", "fastmcpp"}, {"version", "2.14.0"}}}};
518560

519561
auto response = call("initialize", payload);
@@ -543,6 +585,15 @@ class Client
543585
call("notifications/cancelled", payload);
544586
}
545587

588+
/// Reset transport session/connection state when supported (best-effort).
589+
void disconnect(bool full = false)
590+
{
591+
if (!transport_)
592+
return;
593+
if (auto* resettable = dynamic_cast<IResettableTransport*>(transport_.get()))
594+
resettable->reset(full);
595+
}
596+
546597
/// Send a progress notification
547598
void progress(const std::string& progress_token, float progress_value,
548599
std::optional<float> total = std::nullopt, const std::string& message = "")
@@ -567,20 +618,33 @@ class Client
567618
void send_roots_list_changed()
568619
{
569620
fastmcpp::Json payload = fastmcpp::Json::object();
570-
if (roots_callback_)
571-
payload["roots"] = roots_callback_();
621+
auto cb = get_roots_callback();
622+
if (cb)
623+
payload["roots"] = cb();
572624
call("roots/list_changed", payload);
573625
}
574626

575627
/// Handle server notifications that target client callbacks (sampling/elicitation/roots)
576628
fastmcpp::Json handle_notification(const std::string& method, const fastmcpp::Json& params)
577629
{
578-
if (method == "sampling/request" && sampling_callback_)
579-
return sampling_callback_(params);
580-
if (method == "elicitation/request" && elicitation_callback_)
581-
return elicitation_callback_(params);
582-
if (method == "roots/list" && roots_callback_)
583-
return roots_callback_();
630+
if (method == "sampling/request")
631+
{
632+
auto cb = get_sampling_callback();
633+
if (cb)
634+
return cb(params);
635+
}
636+
if (method == "elicitation/request")
637+
{
638+
auto cb = get_elicitation_callback();
639+
if (cb)
640+
return cb(params);
641+
}
642+
if (method == "roots/list")
643+
{
644+
auto cb = get_roots_callback();
645+
if (cb)
646+
return cb();
647+
}
584648
throw fastmcpp::Error("Unsupported notification method: " + method);
585649
}
586650

@@ -589,7 +653,7 @@ class Client
589653
{
590654
if (!transport_)
591655
throw fastmcpp::Error("Cannot clone client without transport");
592-
return Client(transport_, true);
656+
return Client(transport_, callbacks_, true);
593657
}
594658

595659
/// Python-friendly alias for cloning
@@ -599,17 +663,17 @@ class Client
599663
}
600664

601665
/// Register roots/sampling/elicitation callbacks (placeholders for parity)
602-
void set_roots_callback(const std::function<fastmcpp::Json()>& cb)
666+
void set_roots_callback(const std::function<fastmcpp::Json()>& cb)
603667
{
604-
roots_callback_ = cb;
668+
set_roots_callback_impl(cb);
605669
}
606670
void set_sampling_callback(const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
607671
{
608-
sampling_callback_ = cb;
672+
set_sampling_callback_impl(cb);
609673
}
610674
void set_elicitation_callback(const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
611675
{
612-
elicitation_callback_ = cb;
676+
set_elicitation_callback_impl(cb);
613677
}
614678

615679
/// Poll server notifications and dispatch to callbacks (sampling/elicitation/roots)
@@ -641,13 +705,117 @@ class Client
641705
friend class ResourceTask;
642706

643707
std::shared_ptr<ITransport> transport_;
644-
std::function<fastmcpp::Json()> roots_callback_;
645-
std::function<fastmcpp::Json(const fastmcpp::Json&)> sampling_callback_;
646-
std::function<fastmcpp::Json(const fastmcpp::Json&)> elicitation_callback_;
647-
std::unordered_map<std::string, fastmcpp::Json> tool_output_schemas_;
708+
struct CallbackState
709+
{
710+
std::mutex mutex;
711+
std::function<fastmcpp::Json()> roots_callback;
712+
std::function<fastmcpp::Json(const fastmcpp::Json&)> sampling_callback;
713+
std::function<fastmcpp::Json(const fastmcpp::Json&)> elicitation_callback;
714+
};
715+
716+
std::shared_ptr<CallbackState> callbacks_;
717+
std::unordered_map<std::string, fastmcpp::Json> tool_output_schemas_;
718+
719+
std::function<fastmcpp::Json()> get_roots_callback() const
720+
{
721+
if (!callbacks_)
722+
return {};
723+
std::lock_guard<std::mutex> lock(callbacks_->mutex);
724+
return callbacks_->roots_callback;
725+
}
726+
std::function<fastmcpp::Json(const fastmcpp::Json&)> get_sampling_callback() const
727+
{
728+
if (!callbacks_)
729+
return {};
730+
std::lock_guard<std::mutex> lock(callbacks_->mutex);
731+
return callbacks_->sampling_callback;
732+
}
733+
std::function<fastmcpp::Json(const fastmcpp::Json&)> get_elicitation_callback() const
734+
{
735+
if (!callbacks_)
736+
return {};
737+
std::lock_guard<std::mutex> lock(callbacks_->mutex);
738+
return callbacks_->elicitation_callback;
739+
}
740+
741+
void set_roots_callback_impl(const std::function<fastmcpp::Json()>& cb)
742+
{
743+
if (!callbacks_)
744+
callbacks_ = std::make_shared<CallbackState>();
745+
std::lock_guard<std::mutex> lock(callbacks_->mutex);
746+
callbacks_->roots_callback = cb;
747+
}
748+
void set_sampling_callback_impl(
749+
const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
750+
{
751+
if (!callbacks_)
752+
callbacks_ = std::make_shared<CallbackState>();
753+
std::lock_guard<std::mutex> lock(callbacks_->mutex);
754+
callbacks_->sampling_callback = cb;
755+
}
756+
void set_elicitation_callback_impl(
757+
const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
758+
{
759+
if (!callbacks_)
760+
callbacks_ = std::make_shared<CallbackState>();
761+
std::lock_guard<std::mutex> lock(callbacks_->mutex);
762+
callbacks_->elicitation_callback = cb;
763+
}
764+
765+
void configure_transport_callbacks()
766+
{
767+
if (!transport_ || !callbacks_)
768+
return;
769+
if (auto* req_transport = dynamic_cast<IServerRequestTransport*>(transport_.get()))
770+
{
771+
std::weak_ptr<CallbackState> weak = callbacks_;
772+
req_transport->set_server_request_handler(
773+
[weak](const std::string& method, const fastmcpp::Json& params) -> fastmcpp::Json
774+
{
775+
auto state = weak.lock();
776+
if (!state)
777+
throw fastmcpp::Error("Client callbacks expired");
778+
779+
std::function<fastmcpp::Json()> roots_cb;
780+
std::function<fastmcpp::Json(const fastmcpp::Json&)> sampling_cb;
781+
std::function<fastmcpp::Json(const fastmcpp::Json&)> elicitation_cb;
782+
{
783+
std::lock_guard<std::mutex> lock(state->mutex);
784+
roots_cb = state->roots_callback;
785+
sampling_cb = state->sampling_callback;
786+
elicitation_cb = state->elicitation_callback;
787+
}
788+
789+
if (method == "sampling/createMessage")
790+
{
791+
if (!sampling_cb)
792+
throw fastmcpp::Error("No sampling handler configured");
793+
return sampling_cb(params);
794+
}
795+
if (method == "elicitation/request")
796+
{
797+
if (!elicitation_cb)
798+
throw fastmcpp::Error("No elicitation handler configured");
799+
return elicitation_cb(params);
800+
}
801+
if (method == "roots/list")
802+
{
803+
if (!roots_cb)
804+
throw fastmcpp::Error("No roots handler configured");
805+
return roots_cb();
806+
}
807+
808+
throw fastmcpp::Error("Unsupported server request method: " + method);
809+
});
810+
}
811+
}
648812

649813
// Internal constructor for cloning
650-
Client(std::shared_ptr<ITransport> t, bool /*internal*/) : transport_(std::move(t)) {}
814+
Client(std::shared_ptr<ITransport> t, std::shared_ptr<CallbackState> callbacks, bool /*internal*/)
815+
: transport_(std::move(t)), callbacks_(std::move(callbacks))
816+
{
817+
configure_transport_callbacks();
818+
}
651819

652820
// ==========================================================================
653821
// Response Parsers
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/// @file client/sampling.hpp
2+
/// @brief Small helpers for MCP sampling/createMessage client callbacks.
3+
4+
#pragma once
5+
6+
#include "fastmcpp/types.hpp"
7+
8+
#include <functional>
9+
#include <string>
10+
#include <utility>
11+
#include <variant>
12+
13+
namespace fastmcpp::client::sampling
14+
{
15+
16+
/// Result type a sampling handler can return.
17+
/// - std::string: treated as an assistant text message
18+
/// - fastmcpp::Json: treated as a full MCP CreateMessageResult(+WithTools) object
19+
using SamplingHandlerResult = std::variant<std::string, fastmcpp::Json>;
20+
21+
/// Handler signature used by create_sampling_callback().
22+
using SamplingHandler = std::function<SamplingHandlerResult(const fastmcpp::Json& params)>;
23+
24+
/// Build a minimal MCP CreateMessageResult with a single text content block.
25+
inline fastmcpp::Json make_text_result(std::string text,
26+
std::string model = "fastmcpp-client",
27+
std::string role = "assistant")
28+
{
29+
return fastmcpp::Json{
30+
{"role", std::move(role)},
31+
{"model", std::move(model)},
32+
{"content", fastmcpp::Json::array(
33+
{fastmcpp::Json{{"type", "text"}, {"text", std::move(text)}}})},
34+
};
35+
}
36+
37+
/// Wrap a handler so it can be registered via Client::set_sampling_callback.
38+
/// Exceptions propagate and are converted into JSON-RPC errors by the transport.
39+
inline std::function<fastmcpp::Json(const fastmcpp::Json&)>
40+
create_sampling_callback(SamplingHandler handler)
41+
{
42+
return [handler = std::move(handler)](const fastmcpp::Json& params) -> fastmcpp::Json
43+
{
44+
SamplingHandlerResult result = handler(params);
45+
if (std::holds_alternative<std::string>(result))
46+
return make_text_result(std::get<std::string>(std::move(result)));
47+
return std::get<fastmcpp::Json>(std::move(result));
48+
};
49+
}
50+
51+
} // namespace fastmcpp::client::sampling
52+

0 commit comments

Comments
 (0)