Skip to content

Commit 53c3386

Browse files
Locking cleanup (#429)
* Rename _lock to _ws_clients_lock for clarity * Rename _lock to _queue_lock for clarity and fixed the wrong usage of _queue_lock _queue_lock is a lock guarding the queues and should not be used to guard the internal client pointer * Null check _client ptr before using it * Protect _client ptr usage in cases where a concurrent call to onDisconnnect could set it to null
1 parent 5e902c7 commit 53c3386

2 files changed

Lines changed: 74 additions & 60 deletions

File tree

src/AsyncWebSocket.cpp

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ AsyncWebSocketClient::AsyncWebSocketClient(AsyncClient *client, AsyncWebSocket *
297297

298298
AsyncWebSocketClient::~AsyncWebSocketClient() {
299299
{
300-
asyncsrv::lock_guard_type lock(_lock);
300+
asyncsrv::lock_guard_type lock(_queue_lock);
301301
_messageQueue.clear();
302302
_controlQueue.clear();
303303
}
@@ -313,7 +313,7 @@ void AsyncWebSocketClient::_clearQueue() {
313313
void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
314314
_lastMessageTime = millis();
315315

316-
asyncsrv::unique_lock_type lock(_lock);
316+
asyncsrv::unique_lock_type lock(_queue_lock);
317317

318318
async_ws_log_v("[%s][%" PRIu32 "] START ACK(%u, %" PRIu32 ") Q:%u", _server->url(), _clientId, len, time, _messageQueue.size());
319319

@@ -325,14 +325,12 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
325325
_controlQueue.pop_front();
326326
_status = WS_DISCONNECTED;
327327
async_ws_log_v("[%s][%" PRIu32 "] ACK WS_DISCONNECTED", _server->url(), _clientId);
328-
if (_client) {
329-
/*
330-
Unlocking has to be called before return execution otherwise std::unique_lock ::~unique_lock() will get an exception pthread_mutex_unlock.
331-
Due to _client->close() shall call the callback function _onDisconnect()
332-
The calling flow _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient()
333-
*/
328+
// Capture _client before unlocking: _client->close() triggers the _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient() chain,
329+
// so we must not access any member after unlock.
330+
AsyncClient *c = _client;
331+
if (c) {
334332
lock.unlock();
335-
_client->close();
333+
c->close();
336334
}
337335
return;
338336
}
@@ -357,7 +355,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) {
357355
}
358356

359357
void AsyncWebSocketClient::_onPoll() {
360-
asyncsrv::unique_lock_type lock(_lock);
358+
asyncsrv::unique_lock_type lock(_queue_lock);
361359

362360
if (!_client) {
363361
return;
@@ -430,22 +428,22 @@ void AsyncWebSocketClient::_runQueue() {
430428
}
431429

432430
bool AsyncWebSocketClient::queueIsFull() const {
433-
asyncsrv::lock_guard_type lock(_lock);
431+
asyncsrv::lock_guard_type lock(_queue_lock);
434432
return (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED);
435433
}
436434

437435
size_t AsyncWebSocketClient::queueLen() const {
438-
asyncsrv::lock_guard_type lock(_lock);
436+
asyncsrv::lock_guard_type lock(_queue_lock);
439437
return _messageQueue.size();
440438
}
441439

442440
bool AsyncWebSocketClient::canSend() const {
443-
asyncsrv::lock_guard_type lock(_lock);
441+
asyncsrv::lock_guard_type lock(_queue_lock);
444442
return _messageQueue.size() < WS_MAX_QUEUED_MESSAGES;
445443
}
446444

447445
bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) {
448-
asyncsrv::lock_guard_type lock(_lock);
446+
asyncsrv::lock_guard_type lock(_queue_lock);
449447

450448
if (!_client) {
451449
return false;
@@ -462,7 +460,7 @@ bool AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si
462460
}
463461

464462
bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint8_t opcode, bool mask) {
465-
asyncsrv::unique_lock_type lock(_lock);
463+
asyncsrv::unique_lock_type lock(_queue_lock);
466464

467465
if (!_client || !buffer || buffer->empty() || _status != WS_CONNECTED) {
468466
return false;
@@ -472,18 +470,16 @@ bool AsyncWebSocketClient::_queueMessage(AsyncWebSocketSharedBuffer buffer, uint
472470
if (closeWhenFull) {
473471
_status = WS_DISCONNECTED;
474472

475-
if (_client) {
476-
/*
477-
Unlocking has to be called before return execution otherwise std::unique_lock ::~unique_lock() will get an exception pthread_mutex_unlock.
478-
Due to _client->close() shall call the callback function _onDisconnect()
479-
The calling flow _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient()
480-
*/
473+
async_ws_log_w("[%s][%" PRIu32 "] Too many messages queued: closing connection", _server->url(), _clientId);
474+
475+
// Capture _client before unlocking: _client->close() triggers the _onDisconnect() --> _handleDisconnect() --> ~AsyncWebSocketClient() chain,
476+
// so we must not access any member after unlock.
477+
AsyncClient *c = _client;
478+
if (c) {
481479
lock.unlock();
482-
_client->close();
480+
c->close();
483481
}
484482

485-
async_ws_log_w("[%s][%" PRIu32 "] Too many messages queued: closing connection", _server->url(), _clientId);
486-
487483
} else {
488484
async_ws_log_w("[%s][%" PRIu32 "] Too many messages queued: discarding new message", _server->url(), _clientId);
489485
}
@@ -531,7 +527,14 @@ void AsyncWebSocketClient::close(uint16_t code, const char *message) {
531527
return;
532528
} else {
533529
async_ws_log_e("Failed to allocate");
534-
_client->abort();
530+
// Reads _client, then dereference it without any lock.
531+
// A concurrent _onDisconnect could null + delete the client between the check and the use.
532+
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
533+
// (TOCTOU)
534+
AsyncClient *c = _client;
535+
if (c) {
536+
c->abort();
537+
}
535538
}
536539
}
537540
_queueControl(WS_DISCONNECT);
@@ -546,19 +549,27 @@ void AsyncWebSocketClient::_onError(int8_t err) {
546549
}
547550

548551
void AsyncWebSocketClient::_onTimeout(uint32_t time) {
549-
asyncsrv::lock_guard_type lock(_lock);
550-
if (!_client) {
552+
// Reads _client, then dereference it without any lock.
553+
// A concurrent _onDisconnect could null + delete the client between the check and the use.
554+
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
555+
// (TOCTOU)
556+
AsyncClient *c = _client;
557+
if (!c) {
551558
return;
552559
}
553560
async_ws_log_v("[%s][%" PRIu32 "] TIMEOUT %" PRIu32, _server->url(), _clientId, time);
554-
_client->close();
561+
c->close();
555562
}
556563

557564
void AsyncWebSocketClient::_onDisconnect() {
558-
asyncsrv::lock_guard_type lock(_lock);
559565
async_ws_log_v("[%s][%" PRIu32 "] DISCONNECT", _server->url(), _clientId);
560566
_status = WS_DISCONNECTED;
561-
_client = nullptr;
567+
{
568+
// Every queue method (_queueControl, _queueMessage, _runQueue, _onPoll, _onAck) reads _client while holding _queue_lock.
569+
// For those guarded reads to be meaningful, the write must also be synchronized. This doesn't change _queue_lock's purpose — it still guards queue integrity — but ensures the "is client alive?" checks that protect queue operations see a consistent value.
570+
asyncsrv::lock_guard_type lock(_queue_lock);
571+
_client = nullptr;
572+
}
562573
_server->_handleDisconnect(this);
563574
}
564575

@@ -951,23 +962,27 @@ bool AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) {
951962
#endif
952963

953964
IPAddress AsyncWebSocketClient::remoteIP() const {
954-
asyncsrv::lock_guard_type lock(_lock);
955-
956-
if (!_client) {
965+
// Reads _client, then dereference it without any lock.
966+
// A concurrent _onDisconnect could null + delete the client between the check and the use.
967+
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
968+
// (TOCTOU)
969+
AsyncClient *c = _client;
970+
if (!c) {
957971
return IPAddress((uint32_t)0U);
958972
}
959-
960-
return _client->remoteIP();
973+
return c->remoteIP();
961974
}
962975

963976
uint16_t AsyncWebSocketClient::remotePort() const {
964-
asyncsrv::lock_guard_type lock(_lock);
965-
966-
if (!_client) {
977+
// Reads _client, then dereference it without any lock.
978+
// A concurrent _onDisconnect could null + delete the client between the check and the use.
979+
// Local capture ensures the pointer is read exactly once, eliminating the null-dereference.
980+
// (TOCTOU)
981+
AsyncClient *c = _client;
982+
if (!c) {
967983
return 0;
968984
}
969-
970-
return _client->remotePort();
985+
return c->remotePort();
971986
}
972987

973988
/*
@@ -981,7 +996,7 @@ void AsyncWebSocket::_handleEvent(AsyncWebSocketClient *client, AwsEventType typ
981996
}
982997

983998
AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) {
984-
asyncsrv::lock_guard_type lock(_lock);
999+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
9851000
_clients.emplace_back(request, this);
9861001
// we've just detached AsyncTCP client from AsyncWebServerRequest
9871002
_handleEvent(&_clients.back(), WS_EVT_CONNECT, request, NULL, 0);
@@ -991,7 +1006,7 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
9911006
}
9921007

9931008
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
994-
asyncsrv::lock_guard_type lock(_lock);
1009+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
9951010
const auto client_id = client->id();
9961011
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [client_id](const AsyncWebSocketClient &c) {
9971012
return c.id() == client_id;
@@ -1002,14 +1017,14 @@ void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient *client) {
10021017
}
10031018

10041019
bool AsyncWebSocket::availableForWriteAll() {
1005-
asyncsrv::lock_guard_type lock(_lock);
1020+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10061021
return std::none_of(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
10071022
return c.queueIsFull();
10081023
});
10091024
}
10101025

10111026
bool AsyncWebSocket::availableForWrite(uint32_t id) {
1012-
asyncsrv::lock_guard_type lock(_lock);
1027+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10131028
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), [id](const AsyncWebSocketClient &c) {
10141029
return c.id() == id;
10151030
});
@@ -1020,14 +1035,14 @@ bool AsyncWebSocket::availableForWrite(uint32_t id) {
10201035
}
10211036

10221037
size_t AsyncWebSocket::count() const {
1023-
asyncsrv::lock_guard_type lock(_lock);
1038+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10241039
return std::count_if(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c) {
10251040
return c.status() == WS_CONNECTED;
10261041
});
10271042
}
10281043

10291044
AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
1030-
asyncsrv::lock_guard_type lock(_lock);
1045+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10311046
const auto iter = std::find_if(_clients.begin(), _clients.end(), [id](const AsyncWebSocketClient &c) {
10321047
return c.id() == id && c.status() == WS_CONNECTED;
10331048
});
@@ -1039,14 +1054,14 @@ AsyncWebSocketClient *AsyncWebSocket::client(uint32_t id) {
10391054
}
10401055

10411056
void AsyncWebSocket::close(uint32_t id, uint16_t code, const char *message) {
1042-
asyncsrv::lock_guard_type lock(_lock);
1057+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10431058
if (AsyncWebSocketClient *c = client(id)) {
10441059
c->close(code, message);
10451060
}
10461061
}
10471062

10481063
void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
1049-
asyncsrv::lock_guard_type lock(_lock);
1064+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10501065
for (auto &c : _clients) {
10511066
if (c.status() == WS_CONNECTED) {
10521067
c.close(code, message);
@@ -1055,7 +1070,7 @@ void AsyncWebSocket::closeAll(uint16_t code, const char *message) {
10551070
}
10561071

10571072
void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
1058-
asyncsrv::lock_guard_type lock(_lock);
1073+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10591074
const size_t c = count();
10601075
if (c > maxClients) {
10611076
async_ws_log_v("[%s] CLEANUP %" PRIu32 " (%u/%" PRIu16 ")", _url.c_str(), _clients.front().id(), c, maxClients);
@@ -1071,13 +1086,13 @@ void AsyncWebSocket::cleanupClients(uint16_t maxClients) {
10711086
}
10721087

10731088
bool AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) {
1074-
asyncsrv::lock_guard_type lock(_lock);
1089+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10751090
AsyncWebSocketClient *c = client(id);
10761091
return c && c->ping(data, len);
10771092
}
10781093

10791094
AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t len) {
1080-
asyncsrv::lock_guard_type lock(_lock);
1095+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10811096
size_t hit = 0;
10821097
size_t miss = 0;
10831098
for (auto &c : _clients) {
@@ -1091,7 +1106,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::pingAll(const uint8_t *data, size_t l
10911106
}
10921107

10931108
bool AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) {
1094-
asyncsrv::lock_guard_type lock(_lock);
1109+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
10951110
AsyncWebSocketClient *c = client(id);
10961111
return c && c->text(makeSharedBuffer(message, len));
10971112
}
@@ -1138,7 +1153,7 @@ bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
11381153
return enqueued;
11391154
}
11401155
bool AsyncWebSocket::text(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
1141-
asyncsrv::lock_guard_type lock(_lock);
1156+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
11421157
AsyncWebSocketClient *c = client(id);
11431158
return c && c->text(buffer);
11441159
}
@@ -1188,7 +1203,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer *
11881203
}
11891204

11901205
AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer buffer) {
1191-
asyncsrv::lock_guard_type lock(_lock);
1206+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
11921207
size_t hit = 0;
11931208
size_t miss = 0;
11941209
for (auto &c : _clients) {
@@ -1202,7 +1217,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::textAll(AsyncWebSocketSharedBuffer bu
12021217
}
12031218

12041219
bool AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) {
1205-
asyncsrv::lock_guard_type lock(_lock);
1220+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
12061221
AsyncWebSocketClient *c = client(id);
12071222
return c && c->binary(makeSharedBuffer(message, len));
12081223
}
@@ -1239,7 +1254,7 @@ bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketMessageBuffer *buffer) {
12391254
return enqueued;
12401255
}
12411256
bool AsyncWebSocket::binary(uint32_t id, AsyncWebSocketSharedBuffer buffer) {
1242-
asyncsrv::lock_guard_type lock(_lock);
1257+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
12431258
AsyncWebSocketClient *c = client(id);
12441259
return c && c->binary(buffer);
12451260
}
@@ -1280,7 +1295,7 @@ AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer
12801295
return status;
12811296
}
12821297
AsyncWebSocket::SendStatus AsyncWebSocket::binaryAll(AsyncWebSocketSharedBuffer buffer) {
1283-
asyncsrv::lock_guard_type lock(_lock);
1298+
asyncsrv::lock_guard_type lock(_ws_clients_lock);
12841299
size_t hit = 0;
12851300
size_t miss = 0;
12861301
for (auto &c : _clients) {

src/AsyncWebSocket.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class AsyncWebSocketClient {
222222
uint8_t _pstate;
223223
uint32_t _lastMessageTime;
224224
uint32_t _keepAlivePeriod;
225-
mutable asyncsrv::mutex_type _lock;
225+
mutable asyncsrv::mutex_type _queue_lock;
226226
std::deque<AsyncWebSocketControl> _controlQueue;
227227
std::deque<AsyncWebSocketMessage> _messageQueue;
228228
bool closeWhenFull = true;
@@ -303,7 +303,6 @@ class AsyncWebSocketClient {
303303
uint16_t remotePort() const;
304304

305305
bool shouldBeDeleted() const {
306-
asyncsrv::lock_guard_type lock(_lock);
307306
return !_client;
308307
}
309308

@@ -371,7 +370,7 @@ class AsyncWebSocket : public AsyncWebHandler {
371370
AwsEventHandler _eventHandler;
372371
AwsHandshakeHandler _handshakeHandler;
373372
bool _enabled;
374-
mutable asyncsrv::mutex_type _lock;
373+
mutable asyncsrv::mutex_type _ws_clients_lock;
375374

376375
public:
377376
typedef enum {

0 commit comments

Comments
 (0)