Skip to content

Commit dc8f2e6

Browse files
author
lhh
committed
Fix span lifecycle with smart pointers to prevent use-after-free in async RPC callbacks (#3068)
1 parent 0565d8d commit dc8f2e6

32 files changed

Lines changed: 684 additions & 283 deletions

src/brpc/builtin/rpcz_service.cpp

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,35 @@ static void PrintElapse(std::ostream& os, int64_t cur_time,
185185

186186
static void PrintAnnotations(
187187
std::ostream& os, int64_t cur_time, int64_t* last_time,
188-
SpanInfoExtractor** extractors, int num_extr) {
188+
SpanInfoExtractor** extractors, int num_extr, const RpczSpan* span) {
189189
int64_t anno_time;
190190
std::string a;
191+
const char* span_type_str = "Span";
192+
if (span) {
193+
switch (span->type()) {
194+
case SPAN_TYPE_SERVER:
195+
span_type_str = "ServerSpan";
196+
break;
197+
case SPAN_TYPE_CLIENT:
198+
span_type_str = "ClientSpan";
199+
break;
200+
case SPAN_TYPE_BTHREAD:
201+
span_type_str = "BthreadSpan";
202+
break;
203+
}
204+
}
205+
191206
// TODO: Going through all extractors is not strictly correct because
192207
// later extractors may have earlier annotations.
193208
for (int i = 0; i < num_extr; ++i) {
194209
while (extractors[i]->PopAnnotation(cur_time, &anno_time, &a)) {
195210
PrintRealTime(os, anno_time);
196211
PrintElapse(os, anno_time, last_time);
197-
os << ' ' << WebEscape(a);
212+
os << ' ';
213+
if (span) {
214+
os << '[' << span_type_str << ' ' << SPAN_ID_STR << '=' << Hex(span->span_id()) << "] ";
215+
}
216+
os << WebEscape(a);
198217
if (a.empty() || butil::back_char(a) != '\n') {
199218
os << '\n';
200219
}
@@ -204,12 +223,12 @@ static void PrintAnnotations(
204223

205224
static bool PrintAnnotationsAndRealTimeSpan(
206225
std::ostream& os, int64_t cur_time, int64_t* last_time,
207-
SpanInfoExtractor** extr, int num_extr) {
226+
SpanInfoExtractor** extr, int num_extr, const RpczSpan* span) {
208227
if (cur_time == 0) {
209228
// the field was not set.
210229
return false;
211230
}
212-
PrintAnnotations(os, cur_time, last_time, extr, num_extr);
231+
PrintAnnotations(os, cur_time, last_time, extr, num_extr, span);
213232
PrintRealTime(os, cur_time);
214233
PrintElapse(os, cur_time, last_time);
215234
return true;
@@ -239,9 +258,10 @@ static void PrintClientSpan(
239258
extr[num_extr++] = server_extr;
240259
}
241260
extr[num_extr++] = &client_extr;
242-
// start_send_us is always set for client spans.
243-
CHECK(PrintAnnotationsAndRealTimeSpan(os, span.start_send_real_us(),
244-
last_time, extr, num_extr));
261+
if (!PrintAnnotationsAndRealTimeSpan(os, span.start_send_real_us(),
262+
last_time, extr, num_extr, &span)) {
263+
os << " start_send_real_us:not-set";
264+
}
245265
const Protocol* protocol = FindProtocol(span.protocol());
246266
const char* protocol_name = (protocol ? protocol->name : "Unknown");
247267
const butil::EndPoint remote_side(butil::int2ip(span.remote_ip()), span.remote_port());
@@ -271,12 +291,12 @@ static void PrintClientSpan(
271291
os << std::endl;
272292

273293
if (PrintAnnotationsAndRealTimeSpan(os, span.sent_real_us(),
274-
last_time, extr, num_extr)) {
275-
os << " Requested(" << span.request_size() << ") [1]" << std::endl;
294+
last_time, extr, num_extr, &span)) {
295+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Requested(" << span.request_size() << ") [1]" << std::endl;
276296
}
277297
if (PrintAnnotationsAndRealTimeSpan(os, span.received_real_us(),
278-
last_time, extr, num_extr)) {
279-
os << " Received response(" << span.response_size() << ")";
298+
last_time, extr, num_extr, &span)) {
299+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Received response(" << span.response_size() << ")";
280300
if (span.base_cid() != 0 && span.ending_cid() != 0) {
281301
int64_t ver = span.ending_cid() - span.base_cid();
282302
if (ver >= 1) {
@@ -289,18 +309,18 @@ static void PrintClientSpan(
289309
}
290310

291311
if (PrintAnnotationsAndRealTimeSpan(os, span.start_parse_real_us(),
292-
last_time, extr, num_extr)) {
293-
os << " Processing the response in a new bthread" << std::endl;
312+
last_time, extr, num_extr, &span)) {
313+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Processing the response in a new bthread" << std::endl;
294314
}
295315

296316
if (PrintAnnotationsAndRealTimeSpan(
297317
os, span.start_callback_real_us(),
298-
last_time, extr, num_extr)) {
299-
os << (span.async() ? " Enter user's done" : " Back to user's callsite") << std::endl;
318+
last_time, extr, num_extr, &span)) {
319+
os << " [ClientSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] " << (span.async() ? " Enter user's done" : " Back to user's callsite") << std::endl;
300320
}
301321

302322
PrintAnnotations(os, std::numeric_limits<int64_t>::max(),
303-
last_time, extr, num_extr);
323+
last_time, extr, num_extr, &span);
304324
}
305325

306326
static void PrintClientSpan(std::ostream& os,const RpczSpan& span,
@@ -318,7 +338,15 @@ static void PrintBthreadSpan(std::ostream& os, const RpczSpan& span, int64_t* la
318338
extr[num_extr++] = server_extr;
319339
}
320340
extr[num_extr++] = &client_extr;
321-
PrintAnnotations(os, std::numeric_limits<int64_t>::max(), last_time, extr, num_extr);
341+
342+
// Print span id for bthread span context identification
343+
os << " [BthreadSpan " << SPAN_ID_STR << '=' << Hex(span.span_id());
344+
if (span.parent_span_id() != 0) {
345+
os << " parent_span=" << Hex(span.parent_span_id());
346+
}
347+
os << "] ";
348+
349+
PrintAnnotations(os, std::numeric_limits<int64_t>::max(), last_time, extr, num_extr, &span);
322350
}
323351

324352
static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
@@ -348,16 +376,16 @@ static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
348376
os << std::endl;
349377
if (PrintAnnotationsAndRealTimeSpan(
350378
os, span.start_parse_real_us(),
351-
&last_time, extr, ARRAY_SIZE(extr))) {
352-
os << " Processing the request in a new bthread" << std::endl;
379+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
380+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Processing the request in a new bthread" << std::endl;
353381
}
354382

355383
bool entered_user_method = false;
356384
if (PrintAnnotationsAndRealTimeSpan(
357385
os, span.start_callback_real_us(),
358-
&last_time, extr, ARRAY_SIZE(extr))) {
386+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
359387
entered_user_method = true;
360-
os << " Enter " << WebEscape(span.full_method_name()) << std::endl;
388+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Enter " << WebEscape(span.full_method_name()) << std::endl;
361389
}
362390

363391
const int nclient = span.client_spans_size();
@@ -372,22 +400,22 @@ static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
372400

373401
if (PrintAnnotationsAndRealTimeSpan(
374402
os, span.start_send_real_us(),
375-
&last_time, extr, ARRAY_SIZE(extr))) {
403+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
376404
if (entered_user_method) {
377-
os << " Leave " << WebEscape(span.full_method_name()) << std::endl;
405+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Leave " << WebEscape(span.full_method_name()) << std::endl;
378406
} else {
379-
os << " Responding" << std::endl;
407+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Responding" << std::endl;
380408
}
381409
}
382410

383411
if (PrintAnnotationsAndRealTimeSpan(
384412
os, span.sent_real_us(),
385-
&last_time, extr, ARRAY_SIZE(extr))) {
386-
os << " Responded(" << span.response_size() << ')' << std::endl;
413+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
414+
os << " [ServerSpan " << SPAN_ID_STR << '=' << Hex(span.span_id()) << "] Responded(" << span.response_size() << ')' << std::endl;
387415
}
388416

389417
PrintAnnotations(os, std::numeric_limits<int64_t>::max(),
390-
&last_time, extr, ARRAY_SIZE(extr));
418+
&last_time, extr, ARRAY_SIZE(extr), &span);
391419
}
392420

393421
class RpczSpanFilter : public SpanFilter {

src/brpc/channel.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "brpc/rdma/rdma_helper.h"
3939
#include "brpc/policy/esp_authenticator.h"
4040
#include "brpc/transport_factory.h"
41+
#include "brpc/details/controller_private_accessor.h"
4142

4243
namespace brpc {
4344

@@ -502,7 +503,7 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
502503
}
503504
cntl->set_used_by_rpc();
504505

505-
if (cntl->_sender == NULL && IsTraceable(Span::tls_parent())) {
506+
if (cntl->_sender == NULL && IsTraceable(Span::tls_parent().get())) {
506507
const int64_t start_send_us = butil::cpuwide_time_us();
507508
std::string method_name;
508509
if (_get_method_name) {
@@ -513,13 +514,16 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
513514
const static std::string NULL_METHOD_STR = "null-method";
514515
method_name = NULL_METHOD_STR;
515516
}
516-
Span* span = Span::CreateClientSpan(
517-
method_name, start_send_real_us - start_send_us);
518-
span->set_log_id(cntl->log_id());
519-
span->set_base_cid(correlation_id);
520-
span->set_protocol(_options.protocol);
521-
span->set_start_send_us(start_send_us);
522-
cntl->_span = span;
517+
std::shared_ptr<Span> span = Span::CreateClientSpan(
518+
method_name, start_send_real_us - start_send_us);
519+
if (span) {
520+
ControllerPrivateAccessor accessor(cntl);
521+
span->set_log_id(cntl->log_id());
522+
span->set_base_cid(correlation_id);
523+
span->set_protocol(_options.protocol);
524+
span->set_start_send_us(start_send_us);
525+
accessor.set_span(span);
526+
}
523527
}
524528
// Override some options if they haven't been set by Controller
525529
if (cntl->timeout_ms() == UNSET_MAGIC_NUM) {
@@ -620,9 +624,7 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
620624
// be woken up by callback when RPC finishes (succeeds or still
621625
// fails after retry)
622626
Join(correlation_id);
623-
if (cntl->_span) {
624-
cntl->SubmitSpan();
625-
}
627+
cntl->SubmitSpan();
626628
cntl->OnRPCEnd(butil::gettimeofday_us());
627629
}
628630
}

src/brpc/controller.cpp

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ static void CreateIgnoreAllRead() { s_ignore_all_read = new IgnoreAllRead; }
183183
// you don't have to set the fields to initial state after deletion since
184184
// they'll be set uniformly after this method is called.
185185
void Controller::ResetNonPods() {
186-
if (_span) {
187-
Span::Submit(_span, butil::cpuwide_time_us());
186+
if (auto span = _span.lock()) {
187+
Span::Submit(span, butil::cpuwide_time_us());
188188
}
189189
_error_text.clear();
190190
_remote_side = butil::EndPoint();
@@ -240,7 +240,7 @@ void Controller::ResetNonPods() {
240240
void Controller::ResetPods() {
241241
// NOTE: Make the sequence of assignments same with the order that they're
242242
// defined in header. Better for cpu cache and faster for lookup.
243-
_span = NULL;
243+
_span.reset();
244244
_flags = 0;
245245
#ifndef BAIDU_INTERNAL
246246
set_pb_bytes_to_base64(true);
@@ -458,9 +458,9 @@ void Controller::SetFailed(const std::string& reason) {
458458
AppendServerIdentiy();
459459
}
460460
_error_text.append(reason);
461-
if (_span) {
462-
_span->set_error_code(_error_code);
463-
_span->Annotate(reason);
461+
if (auto span = _span.lock()) {
462+
span->set_error_code(_error_code);
463+
span->Annotate(reason);
464464
}
465465
UpdateResponseHeader(this);
466466
}
@@ -487,9 +487,9 @@ void Controller::SetFailed(int error_code, const char* reason_fmt, ...) {
487487
va_start(ap, reason_fmt);
488488
butil::string_vappendf(&_error_text, reason_fmt, ap);
489489
va_end(ap);
490-
if (_span) {
491-
_span->set_error_code(_error_code);
492-
_span->AnnotateCStr(_error_text.c_str() + old_size, 0);
490+
if (auto span = _span.lock()) {
491+
span->set_error_code(_error_code);
492+
span->AnnotateCStr(_error_text.c_str() + old_size, 0);
493493
}
494494
UpdateResponseHeader(this);
495495
}
@@ -515,9 +515,9 @@ void Controller::CloseConnection(const char* reason_fmt, ...) {
515515
va_start(ap, reason_fmt);
516516
butil::string_vappendf(&_error_text, reason_fmt, ap);
517517
va_end(ap);
518-
if (_span) {
519-
_span->set_error_code(_error_code);
520-
_span->AnnotateCStr(_error_text.c_str() + old_size, 0);
518+
if (auto span = _span.lock()) {
519+
span->set_error_code(_error_code);
520+
span->AnnotateCStr(_error_text.c_str() + old_size, 0);
521521
}
522522
UpdateResponseHeader(this);
523523
}
@@ -952,9 +952,9 @@ void Controller::EndRPC(const CompletionInfo& info) {
952952
}
953953
// RPC finished, now it's safe to release `LoadBalancerWithNaming'
954954
_lb.reset();
955-
if (_span) {
956-
_span->set_ending_cid(info.id);
957-
_span->set_async(_done);
955+
if (auto span = _span.lock()) {
956+
span->set_ending_cid(info.id);
957+
span->set_async(_done);
958958
// Submit the span if we're in async RPC. For sync RPC, the span
959959
// is submitted after Join() to get a more accurate resuming timestamp.
960960
if (_done) {
@@ -1028,12 +1028,16 @@ void Controller::DoneInBackupThread() {
10281028

10291029
void Controller::SubmitSpan() {
10301030
const int64_t now = butil::cpuwide_time_us();
1031-
_span->set_start_callback_us(now);
1032-
if (_span->local_parent()) {
1033-
_span->local_parent()->AsParent();
1031+
if (auto span = _span.lock()) {
1032+
span->set_start_callback_us(now);
1033+
if (auto parent_span = span->local_parent().lock()) {
1034+
if (parent_span->is_active()) {
1035+
parent_span->AsParent();
1036+
}
1037+
}
1038+
Span::Submit(span, now);
1039+
_span.reset();
10341040
}
1035-
Span::Submit(_span, now);
1036-
_span = NULL;
10371041
}
10381042

10391043
void Controller::HandleSendFailed() {
@@ -1131,8 +1135,7 @@ void Controller::IssueRPC(int64_t start_realtime_us) {
11311135
CHECK_EQ(_remote_side, tmp_sock->remote_side());
11321136
}
11331137

1134-
Span* span = _span;
1135-
if (span) {
1138+
if (auto span = _span.lock()) {
11361139
if (_current_call.nretry == 0) {
11371140
span->set_remote_side(_remote_side);
11381141
} else {
@@ -1244,15 +1247,15 @@ void Controller::IssueRPC(int64_t start_realtime_us) {
12441247
int rc;
12451248
size_t packet_size = 0;
12461249
if (user_packet_guard) {
1247-
if (span) {
1250+
if (auto span = _span.lock()) {
12481251
packet_size = user_packet_guard->EstimatedByteSize();
12491252
}
12501253
rc = _current_call.sending_sock->Write(user_packet_guard, &wopt);
12511254
} else {
12521255
packet_size = packet.size();
12531256
rc = _current_call.sending_sock->Write(&packet, &wopt);
12541257
}
1255-
if (span) {
1258+
if (auto span = _span.lock()) {
12561259
if (_current_call.nretry == 0) {
12571260
span->set_sent_us(butil::cpuwide_time_us());
12581261
span->set_request_size(packet_size);
@@ -1396,8 +1399,18 @@ const Controller* Controller::sub(int index) const {
13961399
return NULL;
13971400
}
13981401

1399-
uint64_t Controller::trace_id() const { return _span ? _span->trace_id() : 0; }
1400-
uint64_t Controller::span_id() const { return _span ? _span->span_id() : 0; }
1402+
uint64_t Controller::trace_id() const {
1403+
if (auto span = _span.lock()) {
1404+
return span->trace_id();
1405+
}
1406+
return 0;
1407+
}
1408+
uint64_t Controller::span_id() const {
1409+
if (auto span = _span.lock()) {
1410+
return span->span_id();
1411+
}
1412+
return 0;
1413+
}
14011414

14021415
void* Controller::session_local_data() {
14031416
if (_session_local_data) {

src/brpc/controller.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <functional> // std::function
2626
#include <gflags/gflags.h> // Users often need gflags
2727
#include <string>
28+
#include <memory>
2829
#include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr
2930
#include "bthread/errno.h" // Redefine errno
3031
#include "butil/endpoint.h" // butil::EndPoint
@@ -803,7 +804,7 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);
803804
private:
804805
// NOTE: align and group fields to make Controller as compact as possible.
805806

806-
Span* _span;
807+
std::weak_ptr<Span> _span;
807808
uint32_t _flags; // all boolean fields inside Controller
808809
int32_t _error_code;
809810
std::string _error_text;

0 commit comments

Comments
 (0)