1616#include < functional>
1717#include < future>
1818#include < memory>
19+ #include < mutex>
1920#include < string>
2021#include < thread>
2122#include < unordered_map>
@@ -44,6 +45,28 @@ class ITransport
4445 virtual fastmcpp::Json request (const std::string& route, const fastmcpp::Json& payload) = 0;
4546};
4647
48+ // / Optional transport interface: some transports support explicit session reset/disconnect.
49+ class IResettableTransport
50+ {
51+ public:
52+ virtual ~IResettableTransport () = default ;
53+
54+ // / Reset connection/session state. Semantics are transport-specific.
55+ // / @param full If true, reset any additional internal state beyond the session identifier.
56+ virtual void reset (bool full = false ) = 0;
57+ };
58+
59+ using ServerRequestHandler =
60+ std::function<fastmcpp::Json(const std::string& method, const fastmcpp::Json& params)>;
61+
62+ // / Optional transport interface: some transports can accept server-initiated requests and send responses.
63+ class IServerRequestTransport
64+ {
65+ public:
66+ virtual ~IServerRequestTransport () = default ;
67+ virtual void set_server_request_handler (ServerRequestHandler handler) = 0;
68+ };
69+
4770// / Loopback transport for in-process server testing
4871class LoopbackTransport : public ITransport
4972{
@@ -146,17 +169,24 @@ struct CallToolOptions
146169// / @endcode
147170class Client
148171{
172+ struct CallbackState ;
173+
149174 public:
150- Client () = default ;
175+ Client () : callbacks_(std::make_shared<CallbackState>()) {}
151176 explicit Client (std::unique_ptr<ITransport> t)
152- : transport_(std::shared_ptr<ITransport>(std::move(t)))
177+ : transport_(std::shared_ptr<ITransport>(std::move(t))),
178+ callbacks_(std::make_shared<CallbackState>())
153179 {
180+ configure_transport_callbacks ();
154181 }
155182
156183 // / Set the transport (for deferred initialization)
157184 void set_transport (std::unique_ptr<ITransport> t)
158185 {
159186 transport_ = std::shared_ptr<ITransport>(std::move (t));
187+ if (!callbacks_)
188+ callbacks_ = std::make_shared<CallbackState>();
189+ configure_transport_callbacks ();
160190 }
161191
162192 // / Check if transport is connected
@@ -512,8 +542,20 @@ class Client
512542 // / Initialize the session with the server
513543 InitializeResult initialize (std::chrono::milliseconds timeout = std::chrono::milliseconds{0 })
514544 {
545+ fastmcpp::Json caps = fastmcpp::Json::object ();
546+ if (get_sampling_callback ())
547+ {
548+ caps[" sampling" ] = fastmcpp::Json::object ();
549+ // Optimistically advertise tools support when a sampling callback is present.
550+ caps[" sampling" ][" tools" ] = fastmcpp::Json::object ();
551+ }
552+ if (get_elicitation_callback ())
553+ caps[" elicitation" ] = fastmcpp::Json::object ();
554+ if (get_roots_callback ())
555+ caps[" roots" ] = fastmcpp::Json::object ();
556+
515557 fastmcpp::Json payload = {{" protocolVersion" , " 2024-11-05" },
516- {" capabilities" , fastmcpp::Json::object ( )},
558+ {" capabilities" , std::move (caps )},
517559 {" clientInfo" , {{" name" , " fastmcpp" }, {" version" , " 2.14.0" }}}};
518560
519561 auto response = call (" initialize" , payload);
@@ -543,6 +585,15 @@ class Client
543585 call (" notifications/cancelled" , payload);
544586 }
545587
588+ // / Reset transport session/connection state when supported (best-effort).
589+ void disconnect (bool full = false )
590+ {
591+ if (!transport_)
592+ return ;
593+ if (auto * resettable = dynamic_cast <IResettableTransport*>(transport_.get ()))
594+ resettable->reset (full);
595+ }
596+
546597 // / Send a progress notification
547598 void progress (const std::string& progress_token, float progress_value,
548599 std::optional<float > total = std::nullopt , const std::string& message = " " )
@@ -567,20 +618,33 @@ class Client
567618 void send_roots_list_changed ()
568619 {
569620 fastmcpp::Json payload = fastmcpp::Json::object ();
570- if (roots_callback_)
571- payload[" roots" ] = roots_callback_ ();
621+ auto cb = get_roots_callback ();
622+ if (cb)
623+ payload[" roots" ] = cb ();
572624 call (" roots/list_changed" , payload);
573625 }
574626
575627 // / Handle server notifications that target client callbacks (sampling/elicitation/roots)
576628 fastmcpp::Json handle_notification (const std::string& method, const fastmcpp::Json& params)
577629 {
578- if (method == " sampling/request" && sampling_callback_)
579- return sampling_callback_ (params);
580- if (method == " elicitation/request" && elicitation_callback_)
581- return elicitation_callback_ (params);
582- if (method == " roots/list" && roots_callback_)
583- return roots_callback_ ();
630+ if (method == " sampling/request" )
631+ {
632+ auto cb = get_sampling_callback ();
633+ if (cb)
634+ return cb (params);
635+ }
636+ if (method == " elicitation/request" )
637+ {
638+ auto cb = get_elicitation_callback ();
639+ if (cb)
640+ return cb (params);
641+ }
642+ if (method == " roots/list" )
643+ {
644+ auto cb = get_roots_callback ();
645+ if (cb)
646+ return cb ();
647+ }
584648 throw fastmcpp::Error (" Unsupported notification method: " + method);
585649 }
586650
@@ -589,7 +653,7 @@ class Client
589653 {
590654 if (!transport_)
591655 throw fastmcpp::Error (" Cannot clone client without transport" );
592- return Client (transport_, true );
656+ return Client (transport_, callbacks_, true );
593657 }
594658
595659 // / Python-friendly alias for cloning
@@ -599,17 +663,17 @@ class Client
599663 }
600664
601665 // / Register roots/sampling/elicitation callbacks (placeholders for parity)
602- void set_roots_callback (const std::function<fastmcpp::Json()>& cb)
666+ void set_roots_callback (const std::function<fastmcpp::Json()>& cb)
603667 {
604- roots_callback_ = cb ;
668+ set_roots_callback_impl (cb) ;
605669 }
606670 void set_sampling_callback (const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
607671 {
608- sampling_callback_ = cb ;
672+ set_sampling_callback_impl (cb) ;
609673 }
610674 void set_elicitation_callback (const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
611675 {
612- elicitation_callback_ = cb ;
676+ set_elicitation_callback_impl (cb) ;
613677 }
614678
615679 // / Poll server notifications and dispatch to callbacks (sampling/elicitation/roots)
@@ -641,13 +705,117 @@ class Client
641705 friend class ResourceTask ;
642706
643707 std::shared_ptr<ITransport> transport_;
644- std::function<fastmcpp::Json()> roots_callback_;
645- std::function<fastmcpp::Json(const fastmcpp::Json&)> sampling_callback_;
646- std::function<fastmcpp::Json(const fastmcpp::Json&)> elicitation_callback_;
647- std::unordered_map<std::string, fastmcpp::Json> tool_output_schemas_;
708+ struct CallbackState
709+ {
710+ std::mutex mutex;
711+ std::function<fastmcpp::Json()> roots_callback;
712+ std::function<fastmcpp::Json(const fastmcpp::Json&)> sampling_callback;
713+ std::function<fastmcpp::Json(const fastmcpp::Json&)> elicitation_callback;
714+ };
715+
716+ std::shared_ptr<CallbackState> callbacks_;
717+ std::unordered_map<std::string, fastmcpp::Json> tool_output_schemas_;
718+
719+ std::function<fastmcpp::Json()> get_roots_callback () const
720+ {
721+ if (!callbacks_)
722+ return {};
723+ std::lock_guard<std::mutex> lock (callbacks_->mutex );
724+ return callbacks_->roots_callback ;
725+ }
726+ std::function<fastmcpp::Json(const fastmcpp::Json&)> get_sampling_callback () const
727+ {
728+ if (!callbacks_)
729+ return {};
730+ std::lock_guard<std::mutex> lock (callbacks_->mutex );
731+ return callbacks_->sampling_callback ;
732+ }
733+ std::function<fastmcpp::Json(const fastmcpp::Json&)> get_elicitation_callback () const
734+ {
735+ if (!callbacks_)
736+ return {};
737+ std::lock_guard<std::mutex> lock (callbacks_->mutex );
738+ return callbacks_->elicitation_callback ;
739+ }
740+
741+ void set_roots_callback_impl (const std::function<fastmcpp::Json()>& cb)
742+ {
743+ if (!callbacks_)
744+ callbacks_ = std::make_shared<CallbackState>();
745+ std::lock_guard<std::mutex> lock (callbacks_->mutex );
746+ callbacks_->roots_callback = cb;
747+ }
748+ void set_sampling_callback_impl (
749+ const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
750+ {
751+ if (!callbacks_)
752+ callbacks_ = std::make_shared<CallbackState>();
753+ std::lock_guard<std::mutex> lock (callbacks_->mutex );
754+ callbacks_->sampling_callback = cb;
755+ }
756+ void set_elicitation_callback_impl (
757+ const std::function<fastmcpp::Json(const fastmcpp::Json&)>& cb)
758+ {
759+ if (!callbacks_)
760+ callbacks_ = std::make_shared<CallbackState>();
761+ std::lock_guard<std::mutex> lock (callbacks_->mutex );
762+ callbacks_->elicitation_callback = cb;
763+ }
764+
765+ void configure_transport_callbacks ()
766+ {
767+ if (!transport_ || !callbacks_)
768+ return ;
769+ if (auto * req_transport = dynamic_cast <IServerRequestTransport*>(transport_.get ()))
770+ {
771+ std::weak_ptr<CallbackState> weak = callbacks_;
772+ req_transport->set_server_request_handler (
773+ [weak](const std::string& method, const fastmcpp::Json& params) -> fastmcpp::Json
774+ {
775+ auto state = weak.lock ();
776+ if (!state)
777+ throw fastmcpp::Error (" Client callbacks expired" );
778+
779+ std::function<fastmcpp::Json ()> roots_cb;
780+ std::function<fastmcpp::Json (const fastmcpp::Json&)> sampling_cb;
781+ std::function<fastmcpp::Json (const fastmcpp::Json&)> elicitation_cb;
782+ {
783+ std::lock_guard<std::mutex> lock (state->mutex );
784+ roots_cb = state->roots_callback ;
785+ sampling_cb = state->sampling_callback ;
786+ elicitation_cb = state->elicitation_callback ;
787+ }
788+
789+ if (method == " sampling/createMessage" )
790+ {
791+ if (!sampling_cb)
792+ throw fastmcpp::Error (" No sampling handler configured" );
793+ return sampling_cb (params);
794+ }
795+ if (method == " elicitation/request" )
796+ {
797+ if (!elicitation_cb)
798+ throw fastmcpp::Error (" No elicitation handler configured" );
799+ return elicitation_cb (params);
800+ }
801+ if (method == " roots/list" )
802+ {
803+ if (!roots_cb)
804+ throw fastmcpp::Error (" No roots handler configured" );
805+ return roots_cb ();
806+ }
807+
808+ throw fastmcpp::Error (" Unsupported server request method: " + method);
809+ });
810+ }
811+ }
648812
649813 // Internal constructor for cloning
650- Client (std::shared_ptr<ITransport> t, bool /* internal*/ ) : transport_(std::move(t)) {}
814+ Client (std::shared_ptr<ITransport> t, std::shared_ptr<CallbackState> callbacks, bool /* internal*/ )
815+ : transport_(std::move(t)), callbacks_(std::move(callbacks))
816+ {
817+ configure_transport_callbacks ();
818+ }
651819
652820 // ==========================================================================
653821 // Response Parsers
0 commit comments