Skip to content

Commit 81d5c17

Browse files
committed
Async WSocketServer can bind to multiple URLs
a single instance of WSocketServer can serve multiple websocket URLs connection URL is hashed to 32 bit and kept as a member of respective WSocketClient struct a set of methods are provided to get/set/check server and client's URL
1 parent 8c77eef commit 81d5c17

2 files changed

Lines changed: 228 additions & 38 deletions

File tree

src/AsyncWSocket.cpp

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// We target C++17 capable toolchain
77
#if __cplusplus >= 201703L
88
#include "AsyncWSocket.h"
9+
#if defined(ESP32) && (ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0))
910
#include "literals.h"
1011

1112
#define WS_MAX_HEADER_SIZE 16
@@ -169,6 +170,9 @@ WSocketClient::WSocketClient(uint32_t id, AsyncWebServerRequest *request, WSocke
169170
_client->onData( [](void *r, AsyncClient *c, void *buf, size_t len) { (void)c; reinterpret_cast<WSocketClient*>(r)->_onData(buf, len); }, this );
170171
_client->onPoll( [](void *r, AsyncClient *c) { (void)c; reinterpret_cast<WSocketClient*>(r)->_keepalive(); reinterpret_cast<WSocketClient*>(r)->_clientSend(); }, this );
171172
_client->onError( [](void *r, AsyncClient *c, int8_t error) { (void)c; log_e("err:%d", error); }, this );
173+
// bind URL hash
174+
setURLHash(request->url().c_str());
175+
172176
delete request;
173177
}
174178

@@ -231,7 +235,7 @@ void WSocketClient::_clientSend(size_t acked_bytes){
231235
return;
232236

233237
auto sock_space = _client->space();
234-
log_d("no ack infl:%u, space:%u, data pending:%u", _in_flight, sock_space, (uint32_t)(_outFrame.len - _outFrame.index));
238+
//log_d("no ack infl:%u, space:%u, data pending:%u", _in_flight, sock_space, (uint32_t)(_outFrame.len - _outFrame.index));
235239
//
236240
if (!sock_space)
237241
return;
@@ -700,10 +704,11 @@ bool WSocketServer::newClient(AsyncWebServerRequest *request){
700704
else
701705
c->dequeueMessage(); }, // silently discard incoming messages when there is no callback set
702706
msgsize, qcap);
707+
_clients.back().setOverflowPolicy(_overflow_policy);
708+
_clients.back().setKeepAlive(_keepAlivePeriod);
703709
}
704-
_clients.back().setOverflowPolicy(_overflow_policy);
705-
_clients.back().setKeepAlive(_keepAlivePeriod);
706-
if (eventHandler) eventHandler(&_clients.back(), WSocketClient::event_t::connect);
710+
if (eventHandler)
711+
eventHandler(&_clients.back(), WSocketClient::event_t::connect);
707712
return true;
708713
}
709714

@@ -712,6 +717,7 @@ void WSocketServer::handleRequest(AsyncWebServerRequest *request) {
712717
request->send(400);
713718
return;
714719
}
720+
715721
if (_handshakeHandler != nullptr) {
716722
if (!_handshakeHandler(request)) {
717723
request->send(401);
@@ -739,9 +745,19 @@ void WSocketServer::handleRequest(AsyncWebServerRequest *request) {
739745
// ToDo: check protocol
740746
response->addHeader(WS_STR_PROTOCOL, protocol->value());
741747
}
748+
742749
request->send(response);
743750
}
744751

752+
bool WSocketServer::canHandle(AsyncWebServerRequest *request) const {
753+
if (request->isWebSocketUpgrade()){
754+
auto url = request->url().c_str();
755+
auto i = std::find_if(_urlhashes.cbegin(), _urlhashes.cend(), [url](auto const &h){ return h == asyncsrv::hash_djb2a(url); });
756+
return (i != _urlhashes.cend());
757+
}
758+
return false;
759+
};
760+
745761
WSocketClient* WSocketServer::getClient(uint32_t id) {
746762
auto iter = std::find_if(_clients.begin(), _clients.end(), [id](const WSocketClient &c) { return c.id == id; });
747763
if (iter != std::end(_clients))
@@ -782,8 +798,8 @@ WSocketServer::msgall_err_t WSocketServer::pingAll(const char *data, size_t len)
782798
return cnt == _clients.size() ? msgall_err_t::ok : msgall_err_t::partial;
783799
}
784800

785-
WSocketClient::err_t WSocketServer::message(uint32_t id, WSMessagePtr m){
786-
if (WSocketClient *c = getClient(id))
801+
WSocketClient::err_t WSocketServer::message(uint32_t clientid, WSMessagePtr m){
802+
if (WSocketClient *c = getClient(clientid))
787803
return c->enqueueMessage(std::move(m));
788804
else
789805
return WSocketClient::err_t::disconnected;
@@ -800,6 +816,20 @@ WSocketServer::msgall_err_t WSocketServer::messageAll(WSMessagePtr m){
800816
return cnt == _clients.size() ? msgall_err_t::ok : msgall_err_t::partial;
801817
}
802818

819+
WSocketServer::msgall_err_t WSocketServer::messageToEndpoint(uint32_t hash, WSMessagePtr m){
820+
size_t cnt{0}, cntt{0};
821+
for (auto &c : _clients){
822+
if (c.getURLHash() == hash){
823+
++cntt;
824+
if ( c.enqueueMessage(m) == WSocketClient::err_t::ok)
825+
++cnt;
826+
}
827+
}
828+
if (!cnt)
829+
return msgall_err_t::none;
830+
return cnt == cntt ? msgall_err_t::ok : msgall_err_t::partial;
831+
}
832+
803833
void WSocketServer::_purgeClients(){
804834
log_d("purging clients");
805835
std::lock_guard lock(clientslock);
@@ -808,8 +838,16 @@ void WSocketServer::_purgeClients(){
808838
}
809839

810840
size_t WSocketServer::activeClientsCount() const {
811-
return std::count_if(std::begin(_clients), std::end(_clients), [](const WSocketClient &c) { return c.connection() == WSocketClient::conn_state_t::connected; });
812-
};
841+
return std::count_if(std::begin(_clients), std::end(_clients),
842+
[](const WSocketClient &c) { return c.connection() == WSocketClient::conn_state_t::connected; }
843+
);
844+
}
845+
846+
size_t WSocketServer::activeEndpointClientsCount(uint32_t hash) const {
847+
return std::count_if(std::begin(_clients), std::end(_clients),
848+
[hash](const WSocketClient &c) { return c.connection() == WSocketClient::conn_state_t::connected && c.getURLHash() == hash; }
849+
);
850+
}
813851

814852
void WSocketServer::serverEcho(WSocketClient *c){
815853
if (!_serverEcho) return;
@@ -824,6 +862,10 @@ void WSocketServer::serverEcho(WSocketClient *c){
824862
}
825863
}
826864

865+
void WSocketServer::removeURLendpoint(std::string_view url){
866+
_urlhashes.erase(remove_if(_urlhashes.begin(), _urlhashes.end(), [url](auto const &v){ return v == asyncsrv::hash_djb2a(url); }), _urlhashes.end());
867+
}
868+
827869

828870
// ***** WSMessageClose implementation *****
829871

@@ -849,13 +891,14 @@ bool WSocketServerWorker::newClient(AsyncWebServerRequest *request){
849891
if (_task_hndlr) xTaskNotifyGive(_task_hndlr);
850892
},
851893
msgsize, qcap);
852-
}
853894

854-
// create events group where we'll pick events
855-
_clients.back().createEventGroupHandle();
856-
_clients.back().setOverflowPolicy(getOverflowPolicy());
857-
_clients.back().setKeepAlive(_keepAlivePeriod);
858-
xEventGroupSetBits(_clients.back().getEventGroupHandle(), enum2uint32(WSocketClient::event_t::connect));
895+
// create events group where we'll pick events
896+
_clients.back().createEventGroupHandle();
897+
_clients.back().setOverflowPolicy(getOverflowPolicy());
898+
_clients.back().setKeepAlive(_keepAlivePeriod);
899+
_clients.back().setURLHash(request->url().c_str());
900+
xEventGroupSetBits(_clients.back().getEventGroupHandle(), enum2uint32(WSocketClient::event_t::connect));
901+
}
859902
if (_task_hndlr)
860903
xTaskNotifyGive(_task_hndlr);
861904
return true;
@@ -893,19 +936,19 @@ void WSocketServerWorker::_taskRunner(){
893936
// check if this a new client
894937
uxBits = xEventGroupClearBits(it->getEventGroupHandle(), enum2uint32(WSocketClient::event_t::connect) );
895938
if ( uxBits & enum2uint32(WSocketClient::event_t::connect) ){
896-
_ecb(WSocketClient::event_t::connect, it->id);
939+
_ecb(&(*it), WSocketClient::event_t::connect);
897940
}
898941

899942
// check if 'inbound Q full' flag set
900943
uxBits = xEventGroupClearBits(it->getEventGroupHandle(), enum2uint32(WSocketClient::event_t::inQfull) );
901944
if ( uxBits & enum2uint32(WSocketClient::event_t::inQfull) ){
902-
_ecb(WSocketClient::event_t::inQfull, it->id);
945+
_ecb(&(*it), WSocketClient::event_t::inQfull);
903946
}
904947

905948
// check for dropped messages flag
906949
uxBits = xEventGroupClearBits(it->getEventGroupHandle(), enum2uint32(WSocketClient::event_t::msgDropped) );
907950
if ( uxBits & enum2uint32(WSocketClient::event_t::msgDropped) ){
908-
_ecb(WSocketClient::event_t::msgDropped, it->id);
951+
_ecb(&(*it), WSocketClient::event_t::msgDropped);
909952
}
910953

911954
// process all the messages from inbound Q
@@ -916,15 +959,14 @@ void WSocketServerWorker::_taskRunner(){
916959

917960
// check for disconnected client - do not care for group bits, cause if it's deleted, we will destruct the client object
918961
if (it->connection() == WSocketClient::conn_state_t::disconnected){
919-
auto id = it->id;
962+
// run a callback
963+
_ecb(&(*it), WSocketClient::event_t::disconnect);
920964
{
921965
#ifdef ESP32
922966
std::lock_guard<std::mutex> lock (clientslock);
923967
#endif
924968
it = _clients.erase(it);
925969
}
926-
// run a callback
927-
_ecb(WSocketClient::event_t::disconnect, id);
928970
} else {
929971
// advance iterator
930972
++it;
@@ -939,4 +981,5 @@ void WSocketServerWorker::_taskRunner(){
939981
vTaskDelete(NULL);
940982
}
941983

984+
#endif // ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0)
942985
#endif // __cplusplus >= 201703L

0 commit comments

Comments
 (0)