Skip to content

Commit 01954ea

Browse files
committed
Async WSocketServer can bind to multiple URLs
a single instance of WSocketServer can serve multiple websocket URLs connection URL is hashed to 32 bit and kept as a member of respective WSocketClient struct a set of methods are provided to get/set/check server and client's URL
1 parent d8d95b6 commit 01954ea

2 files changed

Lines changed: 228 additions & 38 deletions

File tree

src/AsyncWSocket.cpp

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// We target C++17 capable toolchain
77
#if __cplusplus >= 201703L
88
#include "AsyncWSocket.h"
9+
#if defined(ESP32) && (ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0))
910
#include "literals.h"
1011

1112
#define WS_MAX_HEADER_SIZE 16
@@ -170,6 +171,9 @@ WSocketClient::WSocketClient(uint32_t id, AsyncWebServerRequest *request, WSocke
170171
_client->onData( [](void *r, AsyncClient *c, void *buf, size_t len) { (void)c; reinterpret_cast<WSocketClient*>(r)->_onData(buf, len); }, this );
171172
_client->onPoll( [](void *r, AsyncClient *c) { (void)c; reinterpret_cast<WSocketClient*>(r)->_keepalive(); reinterpret_cast<WSocketClient*>(r)->_clientSend(); }, this );
172173
_client->onError( [](void *r, AsyncClient *c, int8_t error) { (void)c; log_e("err:%d", error); }, this );
174+
// bind URL hash
175+
setURLHash(request->url().c_str());
176+
173177
delete request;
174178
}
175179

@@ -232,7 +236,7 @@ void WSocketClient::_clientSend(size_t acked_bytes){
232236
return;
233237

234238
auto sock_space = _client->space();
235-
log_d("no ack infl:%u, space:%u, data pending:%u", _in_flight, sock_space, (uint32_t)(_outFrame.len - _outFrame.index));
239+
//log_d("no ack infl:%u, space:%u, data pending:%u", _in_flight, sock_space, (uint32_t)(_outFrame.len - _outFrame.index));
236240
//
237241
if (!sock_space)
238242
return;
@@ -702,10 +706,11 @@ bool WSocketServer::newClient(AsyncWebServerRequest *request){
702706
else
703707
c->dequeueMessage(); }, // silently discard incoming messages when there is no callback set
704708
msgsize, qcap);
709+
_clients.back().setOverflowPolicy(_overflow_policy);
710+
_clients.back().setKeepAlive(_keepAlivePeriod);
705711
}
706-
_clients.back().setOverflowPolicy(_overflow_policy);
707-
_clients.back().setKeepAlive(_keepAlivePeriod);
708-
if (eventHandler) eventHandler(&_clients.back(), WSocketClient::event_t::connect);
712+
if (eventHandler)
713+
eventHandler(&_clients.back(), WSocketClient::event_t::connect);
709714
return true;
710715
}
711716

@@ -714,6 +719,7 @@ void WSocketServer::handleRequest(AsyncWebServerRequest *request) {
714719
request->send(400);
715720
return;
716721
}
722+
717723
if (_handshakeHandler != nullptr) {
718724
if (!_handshakeHandler(request)) {
719725
request->send(401);
@@ -741,9 +747,19 @@ void WSocketServer::handleRequest(AsyncWebServerRequest *request) {
741747
// ToDo: check protocol
742748
response->addHeader(WS_STR_PROTOCOL, protocol->value());
743749
}
750+
744751
request->send(response);
745752
}
746753

754+
bool WSocketServer::canHandle(AsyncWebServerRequest *request) const {
755+
if (request->isWebSocketUpgrade()){
756+
auto url = request->url().c_str();
757+
auto i = std::find_if(_urlhashes.cbegin(), _urlhashes.cend(), [url](auto const &h){ return h == asyncsrv::hash_djb2a(url); });
758+
return (i != _urlhashes.cend());
759+
}
760+
return false;
761+
};
762+
747763
WSocketClient* WSocketServer::getClient(uint32_t id) {
748764
auto iter = std::find_if(_clients.begin(), _clients.end(), [id](const WSocketClient &c) { return c.id == id; });
749765
if (iter != std::end(_clients))
@@ -784,8 +800,8 @@ WSocketServer::msgall_err_t WSocketServer::pingAll(const char *data, size_t len)
784800
return cnt == _clients.size() ? msgall_err_t::ok : msgall_err_t::partial;
785801
}
786802

787-
WSocketClient::err_t WSocketServer::message(uint32_t id, WSMessagePtr m){
788-
if (WSocketClient *c = getClient(id))
803+
WSocketClient::err_t WSocketServer::message(uint32_t clientid, WSMessagePtr m){
804+
if (WSocketClient *c = getClient(clientid))
789805
return c->enqueueMessage(std::move(m));
790806
else
791807
return WSocketClient::err_t::disconnected;
@@ -802,6 +818,20 @@ WSocketServer::msgall_err_t WSocketServer::messageAll(WSMessagePtr m){
802818
return cnt == _clients.size() ? msgall_err_t::ok : msgall_err_t::partial;
803819
}
804820

821+
WSocketServer::msgall_err_t WSocketServer::messageToEndpoint(uint32_t hash, WSMessagePtr m){
822+
size_t cnt{0}, cntt{0};
823+
for (auto &c : _clients){
824+
if (c.getURLHash() == hash){
825+
++cntt;
826+
if ( c.enqueueMessage(m) == WSocketClient::err_t::ok)
827+
++cnt;
828+
}
829+
}
830+
if (!cnt)
831+
return msgall_err_t::none;
832+
return cnt == cntt ? msgall_err_t::ok : msgall_err_t::partial;
833+
}
834+
805835
void WSocketServer::_purgeClients(){
806836
log_d("purging clients");
807837
std::lock_guard lock(clientslock);
@@ -810,8 +840,16 @@ void WSocketServer::_purgeClients(){
810840
}
811841

812842
size_t WSocketServer::activeClientsCount() const {
813-
return std::count_if(std::begin(_clients), std::end(_clients), [](const WSocketClient &c) { return c.connection() == WSocketClient::conn_state_t::connected; });
814-
};
843+
return std::count_if(std::begin(_clients), std::end(_clients),
844+
[](const WSocketClient &c) { return c.connection() == WSocketClient::conn_state_t::connected; }
845+
);
846+
}
847+
848+
size_t WSocketServer::activeEndpointClientsCount(uint32_t hash) const {
849+
return std::count_if(std::begin(_clients), std::end(_clients),
850+
[hash](const WSocketClient &c) { return c.connection() == WSocketClient::conn_state_t::connected && c.getURLHash() == hash; }
851+
);
852+
}
815853

816854
void WSocketServer::serverEcho(WSocketClient *c){
817855
if (!_serverEcho) return;
@@ -826,6 +864,10 @@ void WSocketServer::serverEcho(WSocketClient *c){
826864
}
827865
}
828866

867+
void WSocketServer::removeURLendpoint(std::string_view url){
868+
_urlhashes.erase(remove_if(_urlhashes.begin(), _urlhashes.end(), [url](auto const &v){ return v == asyncsrv::hash_djb2a(url); }), _urlhashes.end());
869+
}
870+
829871

830872
// ***** WSMessageClose implementation *****
831873

@@ -851,13 +893,14 @@ bool WSocketServerWorker::newClient(AsyncWebServerRequest *request){
851893
if (_task_hndlr) xTaskNotifyGive(_task_hndlr);
852894
},
853895
msgsize, qcap);
854-
}
855896

856-
// create events group where we'll pick events
857-
_clients.back().createEventGroupHandle();
858-
_clients.back().setOverflowPolicy(getOverflowPolicy());
859-
_clients.back().setKeepAlive(_keepAlivePeriod);
860-
xEventGroupSetBits(_clients.back().getEventGroupHandle(), enum2uint32(WSocketClient::event_t::connect));
897+
// create events group where we'll pick events
898+
_clients.back().createEventGroupHandle();
899+
_clients.back().setOverflowPolicy(getOverflowPolicy());
900+
_clients.back().setKeepAlive(_keepAlivePeriod);
901+
_clients.back().setURLHash(request->url().c_str());
902+
xEventGroupSetBits(_clients.back().getEventGroupHandle(), enum2uint32(WSocketClient::event_t::connect));
903+
}
861904
if (_task_hndlr)
862905
xTaskNotifyGive(_task_hndlr);
863906
return true;
@@ -895,19 +938,19 @@ void WSocketServerWorker::_taskRunner(){
895938
// check if this a new client
896939
uxBits = xEventGroupClearBits(it->getEventGroupHandle(), enum2uint32(WSocketClient::event_t::connect) );
897940
if ( uxBits & enum2uint32(WSocketClient::event_t::connect) ){
898-
_ecb(WSocketClient::event_t::connect, it->id);
941+
_ecb(&(*it), WSocketClient::event_t::connect);
899942
}
900943

901944
// check if 'inbound Q full' flag set
902945
uxBits = xEventGroupClearBits(it->getEventGroupHandle(), enum2uint32(WSocketClient::event_t::inQfull) );
903946
if ( uxBits & enum2uint32(WSocketClient::event_t::inQfull) ){
904-
_ecb(WSocketClient::event_t::inQfull, it->id);
947+
_ecb(&(*it), WSocketClient::event_t::inQfull);
905948
}
906949

907950
// check for dropped messages flag
908951
uxBits = xEventGroupClearBits(it->getEventGroupHandle(), enum2uint32(WSocketClient::event_t::msgDropped) );
909952
if ( uxBits & enum2uint32(WSocketClient::event_t::msgDropped) ){
910-
_ecb(WSocketClient::event_t::msgDropped, it->id);
953+
_ecb(&(*it), WSocketClient::event_t::msgDropped);
911954
}
912955

913956
// process all the messages from inbound Q
@@ -918,15 +961,14 @@ void WSocketServerWorker::_taskRunner(){
918961

919962
// check for disconnected client - do not care for group bits, cause if it's deleted, we will destruct the client object
920963
if (it->connection() == WSocketClient::conn_state_t::disconnected){
921-
auto id = it->id;
964+
// run a callback
965+
_ecb(&(*it), WSocketClient::event_t::disconnect);
922966
{
923967
#ifdef ESP32
924968
std::lock_guard<std::mutex> lock (clientslock);
925969
#endif
926970
it = _clients.erase(it);
927971
}
928-
// run a callback
929-
_ecb(WSocketClient::event_t::disconnect, id);
930972
} else {
931973
// advance iterator
932974
++it;
@@ -941,4 +983,5 @@ void WSocketServerWorker::_taskRunner(){
941983
vTaskDelete(NULL);
942984
}
943985

986+
#endif // ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0)
944987
#endif // __cplusplus >= 201703L

0 commit comments

Comments
 (0)