diff --git a/.asf.yaml b/.asf.yaml index b2f56d8d29..e5a156cbb0 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -21,3 +21,16 @@ github: issues: true projects: false discussions: true + rulesets: + - name: "Default Branch Protection" + type: branch + branches: + includes: + - "~DEFAULT_BRANCH" + - "release/*" + - "rel/*" + excludes: [] + bypass_teams: + - root + restrict_deletion: true + restrict_force_push: true diff --git a/.bazelignore b/.bazelignore index c19cecab25..1559ee6ef2 100644 --- a/.bazelignore +++ b/.bazelignore @@ -1 +1,8 @@ -./example/build_with_bazel \ No newline at end of file +./example/build_with_bazel + +# `registry/` is brpc's self-maintained Bzlmod registry. Its overlay +# BUILD.bazel files reference sources from the libunwind tarball that is +# only materialized at module resolution time, not in the source tree. +# Without this entry, `bazel build //...` would try to evaluate them as +# regular packages and fail with "missing input file ...". +./registry diff --git a/.bazelrc b/.bazelrc index 2ee10ddac7..c10fb589bc 100644 --- a/.bazelrc +++ b/.bazelrc @@ -13,12 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -# +# Bazel doesn't need more than 200MB of memory for local build based on memory profiling: +# https://docs.bazel.build/versions/master/skylark/performance.html#memory-profiling +# The default JVM max heapsize is 1/4 of physical memory up to 32GB which could be large +# enough to consume all memory constrained by cgroup in large host. +# Limiting JVM heapsize here to let it do GC more when approaching the limit to +# leave room for compiler/linker. +# The number 3G is chosen heuristically to both support large VM and small VM with RBE. +# Startup options cannot be selected via config. +startup --host_jvm_args=-Xmx3g +startup --host_jvm_args="-DBAZEL_TRACK_SOURCE_DIRECTORIES=1" + # Default build options. These are applied first and unconditionally. -# common --registry=https://bcr.bazel.build common --registry=https://baidu.github.io/babylon/registry +common --registry=https://raw.githubusercontent.com/apache/brpc/master/registry +build --verbose_failures +# Keep SHT_SYMTAB in built binaries so google::Symbolize can resolve +# in-binary functions (e.g. TestBody() in test binaries) by name +# instead of falling back to "". Bazel's default +# `--strip=sometimes` strips debug/symbol sections in fastbuild mode, +# which is what `bazel test` uses unless `-c dbg` is given. +build --strip=never build --cxxopt="-std=c++17" build --copt="-fno-omit-frame-pointer" # Use gnu17 for asm keyword. @@ -33,8 +50,18 @@ build --features=per_object_debug_info # We already have absl in the build, define absl=1 to tell googletest to use absl for backtrace. build --define absl=1 -# For brpc. -test --define=BRPC_BUILD_FOR_UNITTEST=true +build:rdma --define BRPC_WITH_RDMA=true + +# For UT. +build:test --define BRPC_BUILD_FOR_UNITTEST=true +# Hide libunwind's `_Unwind_*` symbols so they don't preempt libgcc_s at +# runtime. Without this, pthread_exit / C++ exception unwinding crashes +# when libunwind.so appears earlier in the dynamic link chain. +# See registry/modules/1.8.1/overlay/BUILD.bazel for details. +build:test --define with_bthread_tracer=false +build:test --define libunwind_hide_unwind_symbols=true + +test --config=test test --test_output=streamed # Pass PATH, CC, CXX and LLVM_CONFIG variables from the environment. @@ -42,3 +69,21 @@ build --action_env=CC build --action_env=CXX build --action_env=LLVM_CONFIG build --action_env=PATH + +# Basic ASAN +build:asan --define=with_asan=true +build:asan --copt -fsanitize=address +build:asan --linkopt -fsanitize=address +# ASAN needs -O1 to get reasonable performance. +build:asan --copt -O1 +build:asan --copt -fno-optimize-sibling-calls + +# macOS ASAN +build:macos-asan --config=asan +# Workaround, see https://github.com/bazelbuild/bazel/issues/6932 +build:macos-asan --copt -Wno-macro-redefined +build:macos-asan --copt -D_FORTIFY_SOURCE=0 +# Dynamic link cause issues like: `dyld: malformed mach-o: load commands size (59272) > 32768` +build:macos-asan --dynamic_mode=off + +test:asan --test_env=ASAN_OPTIONS=detect_leaks=0:detect_stack_use_after_return=1 diff --git a/.github/actions/install-all-dependencies/action.yml b/.github/actions/install-all-dependencies/action.yml index 86d2884b97..5c1f673ff7 100644 --- a/.github/actions/install-all-dependencies/action.yml +++ b/.github/actions/install-all-dependencies/action.yml @@ -2,7 +2,7 @@ runs: using: "composite" steps: - uses: ./.github/actions/install-essential-dependencies - - run: sudo apt-get update && sudo apt-get install -y libunwind-dev libgoogle-glog-dev automake bison flex libboost-all-dev libevent-dev libtool pkg-config libibverbs1 libibverbs-dev + - run: sudo apt-get update && sudo apt-get install -y libunwind-dev libgoogle-glog-dev automake bison flex libboost-all-dev libevent-dev libtool pkg-config libibverbs-dev shell: bash - run: | wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz && tar -xf thrift-0.11.0.tar.gz && cd thrift-0.11.0/ diff --git a/.github/actions/install-essential-dependencies/action.yml b/.github/actions/install-essential-dependencies/action.yml index d6c5da96c1..5c1944d68e 100644 --- a/.github/actions/install-essential-dependencies/action.yml +++ b/.github/actions/install-essential-dependencies/action.yml @@ -3,5 +3,7 @@ runs: steps: - run: ulimit -c unlimited -S && sudo bash -c "echo 'core.%e.%p' > /proc/sys/kernel/core_pattern" shell: bash - - run: sudo apt-get update && sudo apt-get install -y git g++ make libssl-dev libgflags-dev libprotobuf-dev libprotoc-dev protobuf-compiler libleveldb-dev + - run: sudo apt-get update && sudo apt-get install -y git g++ make libssl-dev libgflags-dev libprotobuf-dev libprotoc-dev protobuf-compiler libleveldb-dev redis-server mysql-server libibverbs-dev + shell: bash + - run: redis-server --version && mysqld --version shell: bash diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index f2d6d69287..5c75104816 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -29,7 +29,9 @@ jobs: - name: gcc with all options uses: ./.github/actions/compile-with-make with: - options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=gcc --cxx=g++ --werror --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety --with-debug-lock --with-bthread-tracer --with-asan + options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=gcc --cxx=g++ --werror \ + --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety \ + --with-debug-lock --with-bthread-tracer --with-asan - name: clang with default options uses: ./.github/actions/compile-with-make @@ -39,7 +41,9 @@ jobs: - name: clang with all options uses: ./.github/actions/compile-with-make with: - options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=clang --cxx=clang++ --werror --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety --with-debug-lock --with-bthread-tracer --with-asan + options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=clang --cxx=clang++ --werror \ + --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety \ + --with-debug-lock --with-bthread-tracer --with-asan compile-with-cmake: runs-on: ubuntu-22.04 @@ -57,7 +61,9 @@ jobs: run: | export CC=gcc && export CXX=g++ mkdir gcc_build_all && cd gcc_build_all - cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. + cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON \ + -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON \ + -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. make -j ${{env.proc_num}} && make clean - name: clang with default options @@ -70,7 +76,9 @@ jobs: run: | export CC=clang && export CXX=clang++ mkdir clang_build_all && cd clang_build_all - cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. + cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_UBRING=ON \ + -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON \ + -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. make -j ${{env.proc_num}} && make clean gcc-compile-with-make-protobuf: @@ -103,33 +111,32 @@ jobs: protobuf-install-dir: /protobuf-3.21.12 config-brpc-options: --cc=gcc --cxx=g++ --werror - gcc-compile-with-bazel: + gcc-unittest-with-bazel: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - run: bazel build --verbose_failures -- //... -//example/... - - gcc-compile-with-boringssl: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v2 - - run: bazel build --verbose_failures --define with_mesalink=false --define with_glog=true --define with_thrift=true --define BRPC_WITH_BORINGSSL=true -- //... -//example/... + # Install redis-server/mysql-server so the integration tests that fork a + # real server (e.g. brpc_redis_unittest) actually run under bazel instead + # of skipping. Same shared action the make-based unittest jobs use. + - uses: ./.github/actions/install-essential-dependencies + - run: bazel test //test/... gcc-compile-with-bazel-all-options: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - run: | - bazel build --verbose_failures \ - --define with_mesalink=false \ + bazel build --define with_mesalink=false \ --define with_glog=true \ --define with_thrift=true \ + --define BRPC_WITH_BORINGSSL=true \ --define with_debug_bthread_sche_safety=true \ --define with_debug_lock=true \ --define with_asan=true \ --define with_bthread_tracer=true \ --define BRPC_WITH_NO_PTHREAD_MUTEX_HOOK=true \ - -- //... -//example/... + --define with_babylon_counter=true \ + -- //:brpc clang-compile-with-make-protobuf: runs-on: ubuntu-22.04 @@ -161,34 +168,35 @@ jobs: protobuf-install-dir: /protobuf-3.21.12 config-brpc-options: --cc=clang --cxx=clang++ --werror - clang-compile-with-bazel: + clang-unittest-with-bazel: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - - run: bazel build --verbose_failures --action_env=CC=clang -- //... -//example/... - - clang-compile-with-boringssl: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v2 - - run: bazel build --verbose_failures --action_env=CC=clang --define with_mesalink=false --define with_glog=true --define with_thrift=true --define BRPC_WITH_BORINGSSL=true -- //... -//example/... + # Install redis-server/mysql-server so the forked-server integration tests + # actually run under bazel (see gcc-unittest-with-bazel). + - uses: ./.github/actions/install-essential-dependencies + - run: | + bazel test --test_output=streamed \ + --action_env=CC=clang \ + //test/... clang-compile-with-bazel-all-options: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - run: | - bazel build --verbose_failures \ - --action_env=CC=clang \ + bazel build --action_env=CC=clang \ --define with_mesalink=false \ --define with_glog=true \ --define with_thrift=true \ + --define BRPC_WITH_BORINGSSL=true \ --define with_debug_bthread_sche_safety=true \ --define with_debug_lock=true \ --define with_asan=true \ --define with_bthread_tracer=true \ --define BRPC_WITH_NO_PTHREAD_MUTEX_HOOK=true \ - -- //... -//example/... + --define with_babylon_counter=true \ + -- //:brpc clang-unittest: runs-on: ubuntu-22.04 @@ -224,13 +232,33 @@ jobs: - name: run tests run: | cd test - sh ./run_tests.sh + # The redis integration tests (sanity/keys_with_spaces/incr_and_decr/by_components/auth) + # fork a real redis-server and connect after a fixed 50ms wait; under ASan redis starts + # too slowly, so they flake here (connection refused). Skip just those under ASan; the + # redis codec/server tests still run, and the full suite runs in clang-unittest. + GTEST_FILTER='-RedisTest.sanity:RedisTest.keys_with_spaces:RedisTest.incr_and_decr:RedisTest.by_components:RedisTest.auth' sh ./run_tests.sh - bazel-bvar-unittest: + clang-unittest-bazel-with-babylon-and-new-pb: runs-on: ubuntu-22.04 + env: + TEST_PROTOBUF_VERSION: "34.1" + # protobuf >= 34.x uses new ProtoInfo fields (option_deps, + # extension_declarations) introduced in Bazel 8.x. The repo's + # .bazelversion (7.2.1) is too old. bazelisk honors USE_BAZEL_VERSION. + USE_BAZEL_VERSION: "8.3.1" steps: - uses: actions/checkout@v2 - - run: bazel test --verbose_failures //test:bvar_test - - run: bazel test --verbose_failures --define with_babylon_counter=true //test:bvar_test - - run: bazel test --verbose_failures --action_env=CC=clang //test:bvar_test - - run: bazel test --verbose_failures --action_env=CC=clang --define with_babylon_counter=true //test:bvar_test + # Install redis-server/mysql-server so the forked-server integration tests + # actually run under bazel (see gcc-unittest-with-bazel). + - uses: ./.github/actions/install-essential-dependencies + - name: Override protobuf version for testing + run: | + sed -i -E "s/(bazel_dep\(name = ['\"]protobuf['\"], version = ['\"])[^'\"]+/\1${TEST_PROTOBUF_VERSION}/" MODULE.bazel + echo "After override:" + grep -E "bazel_dep\(name = ['\"]protobuf['\"]" MODULE.bazel + grep -qE "bazel_dep\(name = ['\"]protobuf['\"], version = ['\"]${TEST_PROTOBUF_VERSION}['\"]" MODULE.bazel \ + || { echo "ERROR: failed to override protobuf version in MODULE.bazel to ${TEST_PROTOBUF_VERSION}"; exit 1; } + - run: | + bazel test --action_env=CC=clang --config=rdma \ + --define with_babylon_counter=true \ + //test/... --test_arg=--gtest_filter=-RdmaRpcTest.* diff --git a/.github/workflows/ci-macos.yml b/.github/workflows/ci-macos.yml index 61d45ac821..1f64b18997 100644 --- a/.github/workflows/ci-macos.yml +++ b/.github/workflows/ci-macos.yml @@ -34,7 +34,7 @@ jobs: - name: compile with cmake run: | echo "CMAKE_PREFIX_PATH=$(brew --prefix protobuf@21)" - mkdir build && cd build && cmake -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_PREFIX_PATH=$(brew --prefix protobuf@21) .. + mkdir build && cd build && cmake -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DWITH_UBRING=ON -DCMAKE_PREFIX_PATH=$(brew --prefix protobuf@21) .. make -j ${{env.proc_num}} && make clean compile-with-make-cmake-protobuf29: @@ -56,7 +56,7 @@ jobs: - name: compile with cmake run: | - mkdir build && cd build && cmake -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_PREFIX_PATH=$(brew --prefix protobuf@29) .. + mkdir build && cd build && cmake -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DWITH_UBRING=ON -DCMAKE_PREFIX_PATH=$(brew --prefix protobuf@29) .. make -j ${{env.proc_num}} && make clean compile-with-bazel: diff --git a/.licenserc.yaml b/.licenserc.yaml index 525adf7257..7c8bbd54bf 100644 --- a/.licenserc.yaml +++ b/.licenserc.yaml @@ -218,4 +218,7 @@ header: # Fuzzing seed - 'test/fuzzing/fuzz_*_seed_corpus/*' + # bazel-central-registry + - 'registry/**' + comment: on-failure diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..b0fde7668c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,27 @@ + + +# Agent Guide for brpc + +This file is read by automated agents (security scanners, code +analyzers, AI assistants) operating on this repository. + +## Security + +Security model: [SECURITY.md](./SECURITY.md) + +Agents that scan this repository should consult `SECURITY.md` and the +threat model it links before reporting issues. diff --git a/BUILD.bazel b/BUILD.bazel index 138e416b10..b51ee0f6b0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_proto_library", "objc_library") +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "objc_library") +load("//bazel/tools:brpc_proto_library.bzl", "brpc_proto_library") licenses(["notice"]) # Apache v2 @@ -22,43 +22,46 @@ exports_files(["LICENSE"]) COPTS = [ "-fno-omit-frame-pointer", - "-DBTHREAD_USE_FAST_PTHREAD_MUTEX", - "-D__const__=__unused__", - "-D_GNU_SOURCE", - "-DUSE_SYMBOLIZE", - "-DNO_TCMALLOC", - "-D__STDC_FORMAT_MACROS", - "-D__STDC_LIMIT_MACROS", - "-D__STDC_CONSTANT_MACROS", ] + select({ - "//bazel/config:brpc_with_glog": ["-DBRPC_WITH_GLOG=1"], - "//conditions:default": ["-DBRPC_WITH_GLOG=0"], -}) + select({ - "//bazel/config:brpc_with_mesalink": ["-DUSE_MESALINK"], - "//conditions:default": [""], -}) + select({ - "//bazel/config:brpc_with_thrift": ["-DENABLE_THRIFT_FRAMED_PROTOCOL=1"], - "//conditions:default": [""], -}) + select({ - "//bazel/config:brpc_with_thrift_legacy_version": [], - "//conditions:default": ["-DTHRIFT_STDCXX=std"], -}) + select({ - "//bazel/config:brpc_with_rdma": ["-DBRPC_WITH_RDMA=1"], - "//conditions:default": [""], -}) + select({ - "//bazel/config:brpc_with_debug_bthread_sche_safety": ["-DBRPC_DEBUG_BTHREAD_SCHE_SAFETY=1"], - "//conditions:default": ["-DBRPC_DEBUG_BTHREAD_SCHE_SAFETY=0"], -}) + select({ - "//bazel/config:brpc_with_debug_lock": ["-DBRPC_DEBUG_LOCK=1"], - "//conditions:default": ["-DBRPC_DEBUG_LOCK=0"], -}) + select({ "//bazel/config:brpc_with_asan": ["-fsanitize=address"], - "//conditions:default": [""], -}) + select({ - "//bazel/config:brpc_with_no_pthread_mutex_hook": ["-DNO_PTHREAD_MUTEX_HOOK"], - "//conditions:default": [""], + "//conditions:default": [], }) +DEFINES = [ + "BTHREAD_USE_FAST_PTHREAD_MUTEX", + "__const__=__unused__", + "_GNU_SOURCE", + "USE_SYMBOLIZE", + "NO_TCMALLOC", + "__STDC_FORMAT_MACROS", + "__STDC_LIMIT_MACROS", + "__STDC_CONSTANT_MACROS", +] + select({ + "//bazel/config:brpc_with_glog": ["BRPC_WITH_GLOG=1"], + "//conditions:default": ["BRPC_WITH_GLOG=0"], + }) + select({ + "//bazel/config:brpc_with_mesalink": ["USE_MESALINK"], + "//conditions:default": [], + }) + select({ + "//bazel/config:brpc_with_thrift": ["ENABLE_THRIFT_FRAMED_PROTOCOL=1"], + "//conditions:default": [], + }) + select({ + "//bazel/config:brpc_with_thrift_legacy_version": [], + "//conditions:default": ["THRIFT_STDCXX=std"], + }) + select({ + "//bazel/config:brpc_with_rdma": ["BRPC_WITH_RDMA=1"], + "//conditions:default": [], + }) + select({ + "//bazel/config:brpc_with_debug_bthread_sche_safety": ["BRPC_DEBUG_BTHREAD_SCHE_SAFETY=1"], + "//conditions:default": ["BRPC_DEBUG_BTHREAD_SCHE_SAFETY=0"], + }) + select({ + "//bazel/config:brpc_with_debug_lock": ["BRPC_DEBUG_LOCK=1"], + "//conditions:default": ["BRPC_DEBUG_LOCK=0"], + }) + select({ + "//bazel/config:brpc_with_no_pthread_mutex_hook": ["NO_PTHREAD_MUTEX_HOOK"], + "//conditions:default": [], + }) + LINKOPTS = [ "-pthread", "-ldl", @@ -93,7 +96,7 @@ LINKOPTS = [ "//conditions:default": [], }) + select({ "//bazel/config:brpc_with_asan": ["-fsanitize=address"], - "//conditions:default": [""], + "//conditions:default": [], }) genrule( @@ -331,13 +334,13 @@ cc_library( "src/butil/third_party/dmg_fp/dtoa.cc", ":config_h", ], - copts = COPTS + select({ + defines = DEFINES + select({ "//bazel/config:brpc_build_for_unittest": [ - "-DBVAR_NOT_LINK_DEFAULT_VARIABLES", - "-DUNIT_TEST", + "UNIT_TEST", ], "//conditions:default": [], }), + copts = COPTS, includes = [ "src/", ], @@ -354,8 +357,14 @@ cc_library( "@bazel_tools//src/conditions:darwin": [":macos_lib"], "//conditions:default": [], }) + select({ - "//bazel/config:brpc_with_boringssl": ["@boringssl//:ssl", "@boringssl//:crypto"], - "//conditions:default": ["@openssl//:ssl", "@openssl//:crypto"], + "//bazel/config:brpc_with_boringssl": [ + "@boringssl//:ssl", + "@boringssl//:crypto" + ], + "//conditions:default": [ + "@openssl//:ssl", + "@openssl//:crypto" + ], }), ) @@ -381,14 +390,13 @@ cc_library( defines = [] + select({ "//bazel/config:with_babylon_counter": ["WITH_BABYLON_COUNTER=1"], "//conditions:default": [], - }), - copts = COPTS + select({ + }) + select({ "//bazel/config:brpc_build_for_unittest": [ - "-DBVAR_NOT_LINK_DEFAULT_VARIABLES", - "-DUNIT_TEST", + "BVAR_NOT_LINK_DEFAULT_VARIABLES", ], "//conditions:default": [], }), + copts = COPTS, includes = [ "src/", ], @@ -475,7 +483,6 @@ cc_library( deps = [ ":brpc_idl_options_cc_proto", ":butil", - "@com_google_protobuf//src/google/protobuf/compiler:code_generator", ], ) @@ -487,20 +494,11 @@ filegroup( visibility = ["//visibility:public"], ) -proto_library( - name = "brpc_idl_options_proto", - srcs = [":brpc_idl_options_proto_srcs"], - strip_import_prefix = "src", - visibility = ["//visibility:public"], - deps = [ - "@com_google_protobuf//:descriptor_proto", - ], -) - -cc_proto_library( +brpc_proto_library( name = "brpc_idl_options_cc_proto", + srcs = [":brpc_idl_options_proto_srcs"], + include = "src", visibility = ["//visibility:public"], - deps = [":brpc_idl_options_proto"], ) filegroup( @@ -508,25 +506,17 @@ filegroup( srcs = glob([ "src/brpc/*.proto", "src/brpc/policy/*.proto", + "src/brpc/rdma/*.proto", ]), visibility = ["//visibility:public"], ) -proto_library( - name = "brpc_internal_proto", - srcs = [":brpc_internal_proto_srcs"], - strip_import_prefix = "src", - visibility = ["//visibility:public"], - deps = [ - ":brpc_idl_options_proto", - "@com_google_protobuf//:descriptor_proto", - ], -) - -cc_proto_library( +brpc_proto_library( name = "brpc_internal_cc_proto", + srcs = [":brpc_internal_proto_srcs"], + include = "src", + deps = [":brpc_idl_options_cc_proto"], visibility = ["//visibility:public"], - deps = [":brpc_internal_proto"], ) cc_library( @@ -587,5 +577,6 @@ cc_binary( deps = [ ":brpc", ":brpc_idl_options_cc_proto", + "@com_google_protobuf//:protoc_lib", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 77703a4661..ff2dc018c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ option(WITH_THRIFT "With thrift framed protocol supported" OFF) option(WITH_BTHREAD_TRACER "With bthread tracer supported" OFF) option(WITH_SNAPPY "With snappy" OFF) option(WITH_RDMA "With RDMA" OFF) +option(WITH_UBRING "With UB" OFF) option(WITH_DEBUG_BTHREAD_SCHE_SAFETY "With debugging bthread sche safety" OFF) option(WITH_DEBUG_LOCK "With debugging lock" OFF) option(WITH_ASAN "With AddressSanitizer" OFF) @@ -40,7 +41,7 @@ if(POLICY CMP0042) cmake_policy(SET CMP0042 NEW) endif() -set(BRPC_VERSION 1.16.0) +set(BRPC_VERSION 1.17.0) SET(CPACK_GENERATOR "DEB") SET(CPACK_DEBIAN_PACKAGE_MAINTAINER "brpc authors") @@ -104,6 +105,11 @@ if(WITH_RDMA) set(WITH_RDMA_VAL "1") endif() +set(WITH_UBRING_VAL "0") +if(WITH_UBRING) + set(WITH_UBRING_VAL "1") +endif() + set(WITH_DEBUG_BTHREAD_SCHE_SAFETY_VAL "0") if(WITH_DEBUG_BTHREAD_SCHE_SAFETY) set(WITH_DEBUG_BTHREAD_SCHE_SAFETY_VAL "1") @@ -136,7 +142,7 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") set(CMAKE_CPP_FLAGS "${CMAKE_CPP_FLAGS} -Wno-deprecated-declarations -Wno-inconsistent-missing-override") endif() -set(CMAKE_CPP_FLAGS "${CMAKE_CPP_FLAGS} ${DEFINE_CLOCK_GETTIME} -DBRPC_WITH_GLOG=${WITH_GLOG_VAL} -DBRPC_WITH_RDMA=${WITH_RDMA_VAL} -DBRPC_DEBUG_BTHREAD_SCHE_SAFETY=${WITH_DEBUG_BTHREAD_SCHE_SAFETY_VAL} -DBRPC_DEBUG_LOCK=${WITH_DEBUG_LOCK_VAL}") +set(CMAKE_CPP_FLAGS "${CMAKE_CPP_FLAGS} ${DEFINE_CLOCK_GETTIME} -DBRPC_WITH_GLOG=${WITH_GLOG_VAL} -DBRPC_WITH_RDMA=${WITH_RDMA_VAL} -DBRPC_WITH_UBRING=${WITH_UBRING_VAL} -DBRPC_DEBUG_BTHREAD_SCHE_SAFETY=${WITH_DEBUG_BTHREAD_SCHE_SAFETY_VAL} -DBRPC_DEBUG_LOCK=${WITH_DEBUG_LOCK_VAL}") if (WITH_ASAN) set(CMAKE_CPP_FLAGS "${CMAKE_CPP_FLAGS} -fsanitize=address") set(CMAKE_C_FLAGS "${CMAKE_CPP_FLAGS} -fsanitize=address") @@ -172,7 +178,12 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-gcse") elseif((CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")) # RISC-V specific optimizations - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gc") + option(WITH_RISCV_ZBC "Enable RISC-V Zbc carry-less multiplication for CRC32C acceleration" OFF) + if(WITH_RISCV_ZBC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gc_zbc") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gc") + endif() endif() if(NOT (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-aligned-new") @@ -206,11 +217,13 @@ if(Protobuf_VERSION GREATER 4.21) absl::hash absl::layout absl::log_initialize + absl::log_globals absl::log_severity absl::memory absl::node_hash_map absl::node_hash_set - absl::optional + absl::random_distributions + absl::random_random absl::span absl::status absl::statusor @@ -322,6 +335,11 @@ if(WITH_RDMA) list(APPEND DYNAMIC_LIB ${RDMA_LIB}) endif() +if(WITH_UBRING) + message(STATUS "brpc compile with ubring") + list(APPEND DYNAMIC_LIB ${UB_LIB}) +endif() + set(BRPC_PRIVATE_LIBS "-lgflags -lprotobuf -lleveldb -lprotoc -lssl -lcrypto -ldl -lz") if(WITH_GLOG) @@ -548,7 +566,8 @@ set(PROTO_FILES idl_options.proto brpc/policy/mongo.proto brpc/trackme.proto brpc/streaming_rpc_meta.proto - brpc/proto_base.proto) + brpc/proto_base.proto + brpc/rdma/rdma_handshake.proto) file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/output/include/brpc) set(PROTOC_FLAGS ${PROTOC_FLAGS} -I${PROTOBUF_INCLUDE_DIR}) compile_proto(PROTO_HDRS PROTO_SRCS ${PROJECT_BINARY_DIR} diff --git a/LICENSE b/LICENSE index 0efd09e42e..392fd8c95a 100644 --- a/LICENSE +++ b/LICENSE @@ -959,3 +959,41 @@ copyright The Chromium Authors and licensed under the 3-clause BSD license: THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- + +registry/modules/libunwind/**: licensed under the following terms: + + Forked from the Bazel Central Registry (BCR): + https://github.com/bazelbuild/bazel-central-registry/tree/main/modules/libunwind + + The upstream files are distributed under the Apache License, Version 2.0, + the same license as brpc itself (the full text of which is reproduced at + the top of this LICENSE file). brpc has modified the forked files; per the + Apache License 2.0 §4(b), the modifications are summarized below: + + - registry/modules/libunwind//MODULE.bazel + - registry/modules/libunwind//overlay/MODULE.bazel + Append the suffix `.brpc-no-unwind` to the `version` field, marking + these as brpc's variant of the libunwind module. + + - registry/modules/libunwind//overlay/BUILD.bazel + Add the `hide_unwind_symbols` config_setting (gated by + `--define=libunwind_hide_unwind_symbols=true`); when the switch is + on, drop `src/unwind/*.c` (the GCC `_Unwind_*` ABI compatibility + layer) from `unwind_srcs` so the resulting libunwind does not export + `_Unwind_*` symbols. See docs/cn/bthread_tracer.md for the + rationale. + + - registry/modules/libunwind//source.json + Update overlay file SHA-256 hashes to match the modified + BUILD.bazel / MODULE.bazel. + + - registry/modules/libunwind/metadata.json + Replace the `versions` array with brpc's renamed versions. + + The above only governs the build/source files brpc redistributes inside + registry/modules/libunwind/. The libunwind source code itself, downloaded + by Bazel from the URL pinned in source.json, remains governed by the + libunwind project's own license; see https://github.com/libunwind/libunwind + for details. diff --git a/MODULE.bazel b/MODULE.bazel index 95f4e6b763..bd5b8e3824 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -1,6 +1,6 @@ module( name = 'brpc', - version = '1.16.0', + version = '1.17.0', compatibility_level = 1, ) @@ -16,8 +16,17 @@ bazel_dep(name = "apple_support", version = "1.17.1") bazel_dep(name = 'rules_cc', version = '0.0.1') bazel_dep(name = 'rules_proto', version = '4.0.0') bazel_dep(name = 'zlib', version = '1.3.1.bcr.5', repo_name = 'com_github_madler_zlib') -bazel_dep(name = 'libunwind', version = '1.8.1', repo_name = 'com_github_libunwind_libunwind') bazel_dep(name = 'babylon', version = '1.4.4') +# --registry=https://raw.githubusercontent.com/apache/brpc/master/registry +# `registry/modules/libunwind/` (see the `--registry=https://raw.githubusercontent.com/apache/brpc/master/registry` +# entry at the top of `.bazelrc`). The version suffix `.brpc-no-unwind` marks +# this as brpc's self-maintained variant whose `src/unwind/*.c` (the GCC +# `_Unwind_*` ABI compatibility layer) is gated by the +# `--define=libunwind_hide_unwind_symbols=true` switch. Hiding those +# `_Unwind_*` symbols is required for `--define=with_bthread_tracer=true` +# to work without crashing in pthread_exit / exception unwinding. +# See docs/cn/bthread_tracer.md for background. +bazel_dep(name = 'libunwind', version = '1.8.1.brpc-no-unwind', repo_name = 'com_github_libunwind_libunwind') # --registry=https://baidu.github.io/babylon/registry bazel_dep(name = 'leveldb', version = '1.23', repo_name = 'com_github_google_leveldb') @@ -34,6 +43,7 @@ single_version_override( bazel_dep(name = 'thrift', version = '0.21.0', repo_name = 'org_apache_thrift') # test only +bazel_dep(name = "gperftools", version = "2.18.1", dev_dependency = True) bazel_dep(name = 'googletest', version = '1.14.0.bcr.1', repo_name = 'com_google_googletest', dev_dependency = True) bazel_dep(name = 'hedron_compile_commands', dev_dependency = True) git_override( diff --git a/README.md b/README.md index 1c4f78528b..d65366fafb 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ You can use it to: * [FlatMap](docs/en/flatmap.md) * [Coroutine](docs/en/coroutine.md) * [Circuit Breaker](docs/en/circuit_breaker.md) + * [UBRing](docs/en/ubring.md) * [RDMA](docs/en/rdma.md) * [Bazel Support](docs/en/bazel_support.md) * [Wireshark baidu_std dissector plugin](docs/en/wireshark_baidu_std.md) diff --git a/README_cn.md b/README_cn.md index 6413f83fde..2cc686bd85 100644 --- a/README_cn.md +++ b/README_cn.md @@ -87,6 +87,7 @@ * [FlatMap](docs/cn/flatmap.md) * [协程](docs/cn/coroutine.md) * [熔断](docs/cn/circuit_breaker.md) + * [UBRing](docs/cn/ubring.md) * [RDMA](docs/cn/rdma.md) * [Bazel构建支持](docs/cn/bazel_support.md) * [Wireshark baidu_std协议解析插件](docs/cn/wireshark_baidu_std.md) diff --git a/RELEASE_VERSION b/RELEASE_VERSION index 15b989e398..092afa15df 100644 --- a/RELEASE_VERSION +++ b/RELEASE_VERSION @@ -1 +1 @@ -1.16.0 +1.17.0 diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000..ac8d16b936 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,13 @@ +# Security Policy + +## Reporting a Vulnerability + +`apache/brpc` follows the [Apache Software Foundation security process](https://www.apache.org/security/). Please report suspected +vulnerabilities privately to `security@apache.org`; do not open public +GitHub issues or pull requests for security reports. + +## Threat Model + +What the project treats as in scope and out of scope, the security +properties it provides and disclaims, the adversary model, and how +findings are triaged are documented in [THREAT_MODEL.md](./THREAT_MODEL.md). diff --git a/THREAT_MODEL.md b/THREAT_MODEL.md new file mode 100644 index 0000000000..e318507634 --- /dev/null +++ b/THREAT_MODEL.md @@ -0,0 +1,627 @@ +# Apache bRPC — Threat Model (v1 draft) + +## §1 Header + +- **Project:** Apache bRPC (`apache/brpc`). +- **Scope:** the `apache/brpc` repository only. The PMC confirmed on 2026-05-20 that `apache/brpc-website` is out of scope for this engagement. This draft is first authored by the ASF Security Team and then reviewed by the PMC. +- **Version binding:** based on `master` around 2026-05-21. Vulnerability reports should be triaged against the model for the corresponding version, not against `HEAD`; re-bind on each release. +- **Authors and status:** drafted by the ASF Security Team (Glasswing pre-scan), revised by PMC member Weibing Wang. **DRAFT v1** +- **Reporting entry point:** at drafting time, the repository had no `SECURITY.md`, and the Apache security project index listed only the generic `security@apache.org` address. Until project-level disclosure documentation is published, report vulnerabilities to `security@apache.org` per ASF policy. +- **Provenance legend:** + - *(documented — source)*: from repository documentation, headers, gflags, source code, or Apache governance artifacts. + - *(inferred — Qn)*: inferred from code structure, general RPC security experience, or absence of a defense, with a corresponding question in §14. + - *(maintainer)*: not yet used in this draft; to be substituted after PMC confirmation. +- **Draft confidence:** most factual descriptions in §1-§13 come from documentation and source code. Intent, default security posture, resource boundaries, the built-in service trust model, and vulnerability triage criteria still require confirmation in §14. +- **bRPC overview:** bRPC is an embeddable C++ RPC framework that supports dispatching multiple protocols on the same port via content sniffing, including `baidu_std`, HTTP/1.x, HTTP/2/gRPC, Thrift, Redis, memcached, RTMP, Mongo, the nshead family, and others. It provides naming services, load balancers, optional TLS, bthread scheduling, and HTTP built-in admin services. bRPC is not a standalone daemon; downstream applications link `libbrpc` and start `brpc::Server` in-process. + +--- + +## §2 Scope and intended use + +bRPC's primary uses are: + +1. Expose one or more `google::protobuf::Service` implementations on a TCP port, with multiple protocols sharing the same port. +2. Act as a client `brpc::Channel` to access a single server or a cluster described by a naming service plus load balancer. +3. Be embedded by a host application; deployment, TLS, network exposure, authentication, routing, and process lifecycle are the responsibility of the application or operator. +4. Use built-in services such as `/status`, `/vars`, `/flags`, `/connections`, `/rpcz`, and `/health` for debugging and monitoring. +5. Use bRPC as an HTTP/h2 server or client, including Restful URL mapping and JSON-to-Protobuf / Protobuf-to-JSON conversion. + +bRPC is not a secure-by-default managed service. Threats land in the runtime code linked by the application and in the built-in admin services the application chooses to expose. + +### Caller / actor roles + +- **Application developer:** defines `.proto` files, implements services, chooses protocols and ports, and calls clients. Trusted within the compile-time and configuration boundary. +- **Operator:** runs the binary and configures ports, `internal_port`, TLS, gflags, naming services, rate limiting, and built-in services. Trusted for that instance. +- **Client peer (RPC client):** sends bytes to a bRPC server. Untrusted by default; protocol sniffers must handle arbitrary input. +- **Server peer:** the remote server that a bRPC client connects to. Response bytes are also untrusted. +- **Naming-service backend:** `bns`, `file`, `consul`, `nacos`, DNS, and similar backends. Configured by the operator and treated by bRPC as trusted infrastructure. +- **Built-in admin-service consumer:** a person or program accessing `/status`, `/flags`, `/rpcz`, and similar endpoints. Its trust level depends entirely on whether the operator places these endpoints on an internal network or disables them. + +### Component-family table + +| Family | Representative entry point | Touches outside the process? | In model? | +| --- | --- | --- | --- | +| **C++ runtime core** | `Server::Start`, `Channel::CallMethod`, `Socket`, `InputMessenger` | sockets, files, threads, TLS, optional RDMA | **In model** | +| **Wire-protocol parsers** | `Parse*Message` | reads untrusted bytes from sockets | **In model**, highest priority | +| **HTTP / h2 (+ gRPC) stack** | `HttpServerHandler`, `H2StreamContext`, `URI` | sockets, TLS, json2pb | **In model** | +| **Built-in admin services** | `/status`, `/vars`, `/flags`, `/rpcz`, `/dir`, `/threads`, profiler, metrics | HTTP, `/proc`, filesystem, profiler | **In model**, depending on exposure | +| **TLS / SSL layer** | `ServerSSLOptions`, `ChannelSSLOptions` | OpenSSL, certificate files | **In model** | +| **Authentication hooks** | `Authenticator::VerifyCredential` | wire token | **In model**, policy is implemented by the application | +| **Naming services** | `file://`, `http(s)://`, `consul://`, `nacos://`, `dns://` | files and network | **In model**, backend treated as trusted | +| **Load balancers** | rr, wrr, random, la, consistent hashing | pure computation | **In model**, not a direct attack surface | +| **Compression layer** | gzip/snappy/zlib/lz4 | CPU/memory | **In model** | +| **RDMA transport** | `rdma://` | verbs API, pinned memory | **In model** only when enabled at build time and runtime | +| **Coroutine bridge** | `usercode_in_coroutine` | scheduling | **In model** in non-default mode | +| **bthread / butil / bvar** | scheduler, IOBuf, counters | threads, futex, files, `/proc` | **In model** | +| **mcpack2pb / json2pb** | mcpack/JSON <-> pb | pure computation | **In model** on the corresponding protocol paths | +| **Java / Python bindings** | JNI / pyBRPC | language bridge | currently unsupported; **Out of model** | +| **tools / test / example / build / community / docs** | CLI, tests, examples, build, governance, docs | non-production runtime surface | **Out of model** | + +Wire-protocol parsers are the most important surface: they directly receive untrusted bytes, are implemented in C++, cover many protocols with different maturity levels, and multi-protocol sniffing can make every registered parser reachable from the same listening port. `enabled_protocols` can restrict the sniffing set, but its implementation details need maintainer confirmation. + +--- + +## §3 Out of scope (explicit non-goals) + +The following scenarios or threats are not protected by the bRPC threat model. Matching reports are closed according to §13. + +### Use cases out of scope + +1. **A secure-by-default RPC endpoint.** bRPC provides components, not a `brpcd` with a strong default security posture. +2. **A sandbox or process-isolation boundary.** Deserialization, business handlers, and the host application run in the same process. +3. **Cryptographic primitives.** TLS security comes from OpenSSL; bRPC only wraps and configures it. +4. **An application field-validation framework.** bRPC validates wire framing, not semantic fields such as email addresses, IDs, or business ranges. +5. **A complete replacement for authenticated transport.** Bare `baidu_std`, HTTP, Thrift, Redis, and similar protocols have no default authentication; `Authenticator` and mTLS require explicit configuration. +6. **A general HTML/JSON rendering security surface.** Escaping in application HTTP handlers is the application's responsibility. +7. **A general file server.** `/dir` is disabled by default and is not modeled as a production file-serving feature. +8. **Exposing built-in services to the public internet.** bRPC built-in services should only be exposed to trusted internal consumers. + +### Threats explicitly out of scope + +1. Attackers who control the calling process, can read/write process memory, or can call arbitrary bRPC APIs. +2. Attackers who control the compiler, dependencies, build environment, or supply chain. +3. Side channels such as timing, cache, power, and RowHammer. +4. Raw socket-layer DoS: SYN flood, slowloris, half-open connection exhaustion, and kernel table exhaustion. +5. Resource exhaustion in `bvar`/`bthread`/`butil` caused purely by in-process call patterns. +6. Performance loss caused by an operator intentionally enabling expensive tracing/profiling. +7. Compromise of the naming-service backend; bRPC trusts the backend configured by the operator. +8. Vulnerabilities in application handlers, such as SQL injection or missing business authorization. + +### Code shipped in the repo but out of scope + +`test/`, `example/`, `tools/`, build and packaging files, `community/`, `docs/`, and `apache/brpc-website` are not modeled as production runtime surfaces. The core boundary is the runtime, parsers, transports, built-in services, and support libraries under `src/`. + +--- + +## §4 Trust boundaries and data flow + +bRPC has four main trust boundaries: three on the network surface and one on the operator configuration and naming-service surface. + +### Boundary 1 — the wire (server side, business RPC) + +Server-side socket bytes are untrusted by default: + +```text +socket bytes + -> Acceptor / Socket + -> InputMessenger::CutInputMessage + -> try registered protocol parsers one by one + -> check whether declared body size exceeds -max_body_size + -> protocol message structure + -> protobuf / json2pb / mcpack2pb / other decoder + -> Service::Method(controller, request, response, done) +``` + +The trust transition occurs when the framework calls the user's `Service::Method`. Crashes or memory corruption reachable from the wire in parsers, decoders, compression, and TLS handshake are in model. How the business handler processes `Request` is the application's responsibility. + +All protocol parsers are enabled by default. `ServerOptions.enabled_protocols` can restrict the protocol sniffing set. + +### Boundary 2 — the wire (server side, built-in admin services) + +HTTP/h2 requests on the same listening port may be routed to built-in services. `internal_port` can move built-in services to a separate port, and `has_builtin_services=false` can disable them entirely. `/dir` and `/threads` are disabled by default; other built-in services are enabled by default when `has_builtin_services=true`. + +Model posture: built-in services are treated as an operator-trusted surface when placed on an internal port or protected by firewall. If exposed to the public internet, they become a major risk surface. Specific triage depends on Q46/Q50 in §14. + +Even when built-in services are only served internally, they still need to avoid severe attacks such as command injection or denial of service. However, modifying gflags, enabling rpcz, profiling, and similar operations through built-in services can affect business logic or program performance; those effects are normal use and are not security vulnerabilities. + +### Boundary 3 — the wire (client side) + +A bRPC client deserializes server responses through paths symmetric to the server side: untrusted bytes enter parsers and decoders before being handed to the application. Parser/decoder vulnerabilities triggered by malicious server responses are in model. + +### Boundary 4 — the operator config + naming-service backend + +`ServerOptions`, `ChannelOptions`, gflags, flagfiles, TLS file paths, and naming-service URIs are supplied by the operator and are trusted. Server lists returned by naming services are also treated as trusted infrastructure. Backend compromise is out of scope, but crashes in reply parsers on malformed input remain in model. + +### Reachability preconditions per family + +- Parser issues are in model only when reachable from Boundary 1/3; opt-in protocols require operator enablement. +- Built-in services are triaged according to the Boundary 2 exposure rules. +- TLS issues are inside the network surface; OpenSSL CVEs are handled as external dependency issues. +- Compression, json2pb, and mcpack2pb are in model when reachable from attacker input. +- RDMA and coroutine paths are modeled only when explicitly enabled or loaded. + +### Data not crossing a boundary + +Purely internal scheduler queues, in-process `bvar` counters, internal `IOBuf` lifetimes, `socket_map` lifetimes, Wireshark dissectors under `tools/`, and similar paths are not trust transitions. + +--- + +## §5 Assumptions about the environment + +bRPC is a C++ library that normally runs on Linux/macOS with POSIX sockets, OpenSSL, and a C++ toolchain. + +### OS / runtime / hardware + +- Linux and macOS are supported; Windows is not included in the security support scope for now. +- A C++11 or newer toolchain is required. +- POSIX sockets and Linux epoll or macOS kqueue are required. +- TLS depends on OpenSSL; risks from old OpenSSL versions are managed by the operator. +- gflags, protobuf, and leveldb are regular dependencies; tcmalloc/gperftools is optional. + +### Concurrency + +- bthread is an M:N scheduler; the number of worker pthreads depends on CPU count or configuration. +- `ServerOptions.num_threads` is a global worker-count hint, not hard isolation for one server. +- bthread-local state may be reused; user destructors must be reentrant. +- `-usercode_in_pthread` moves user code to pthreads and changes concurrency and queuing semantics. + +### Memory model + +bRPC is memory-unsafe C++. All `Parse*Message`, `Read*`, `Decompress*`, JSON/mcpack/protobuf decoder paths are candidate surfaces for OOB access, UAF, double-free, integer overflow, and stack overflow. `-max_body_size` is the key outer cap; the model depends on parsers correctly checking lengths and avoiding signed/unsigned overflow. + +### Time / clock + +bRPC uses `gettimeofday`, `clock_gettime`, `cpuwide_time_us`, and similar clocks. Resistance to timing side channels is not a goal. + +### Filesystem / network / peripherals + +bRPC opens TCP/Unix sockets, pid files, TLS cert/key files, naming-service files, profile output, rpc dumps, and reads `/proc/*` for bvar. It also sets normal socket options such as keepalive, `SO_REUSEPORT`, and `TCP_USER_TIMEOUT`. + +### What bRPC does *not* do to its host (negative-claim inventory) + +By default it does not install signal handlers. Normal operation does not fork child processes unless a profiler endpoint is triggered. It reads limited environment variables and gflags. First SSL use initializes OpenSSL global state. Logs go to stdout/stderr or glog. It does not modify locale or FPU state. `fork without exec` must happen before bRPC initialization. + +--- + +## §5a Build-time and configuration variants + +bRPC's security boundary is affected by build options, gflags, `ServerOptions`, `ChannelOptions`, and SSL options. + +### Build-time options that change the security envelope + +| Knob / flag | Default | Effect on model | +| --- | --- | --- | +| `WITH_THRIFT` | off | makes compiling and enabling the Thrift parser possible | +| `WITH_GLOG` | off | changes the logging backend; not critical | +| `WITH_RDMA` | off | introduces RDMA transport and pinned memory | +| `WITH_MESALINK` | off | OpenSSL replacement implementation | +| `BUILD_UNIT_TESTS` | off | builds tests, out of model | + +*Note: the table above is only an example, not a complete list. Analyze the latest code for details.* + +### Runtime gflags that change the security envelope + +| Flag / option | Default | Effect | Maintainer stance | +| --- | --- | --- | --- | +| `-max_body_size` | 64 MiB | global body cap for all protocols; protobuf's internal cap is lifted to `INT_MAX` | critical | +| `-enable_dir_service` | false | `/dir` reads files | dangerous when enabled | +| `-enable_threads_service` | false | `/threads` dumps stacks | dangerous when enabled | +| `-immutable_flags` | false | `/flags` can modify reloadable gflags | safer when enabled | +| `-usercode_in_pthread` | false | changes the thread model for user code | changes concurrency semantics and may cause deadlocks | +| `ServerOptions.has_builtin_services` | true | global switch for built-in services | critical | +| `ServerOptions.internal_port` | -1 | built-in services are on the main port by default | critical; safer when set | + +*Note: the table above is only an example, not a complete list. Analyze the latest code for details.* + +### Insecure-default summary (the §13 routing question) + +Defaults triaged by the PMC: + +1. `-max_body_size=64 MiB` is the production default. +2. `has_builtin_services=false` or `internal_port!=-1` is operator responsibility. +3. The combination of `-immutable_flags=false` and reachable `/flags` is acceptable. +4. `ChannelSSLOptions.verify.verify_mode=NOT_SET`, causing client TLS not to verify the server certificate, is acceptable. +5. `ChannelSSLOptions.protocols` including TLS 1.0/1.1 is acceptable. +6. `force_ssl=false`, causing a TLS-configured port to still accept plaintext, is acceptable. +7. `/dir` and `/threads` are disabled by default. If an operator enables them and exposes them publicly, treat that as a dangerous non-default configuration. + +*Note: the table above is only an example, not a complete list. Analyze the latest code for details.* + +--- + +## §6 Assumptions about inputs + +bRPC receives three classes of input: socket bytes, in-process API parameters, and operator configuration. Boundary 1/3 wire bytes are the highest risk. + +### Per-parameter trust table + +| Component / function | Parameter | Attacker-controllable? | Caller / operator must enforce | +| --- | --- | --- | --- | +| `InputMessenger::CutInputMessage` | initial connection bytes | yes | sniffer must defend against arbitrary bytes | +| `ParseRpcMessage` | PRPC header, `RpcMeta`, payload | yes | `-max_body_size`, `meta_size <= body_size` | +| `ParseHttpMessage` | method, URI, headers, body | yes | HTTP grammar, body cap | +| `ParseH2Message` | frame, HPACK headers, SETTINGS | yes | frame/header/HPACK/SETTINGS caps | +| Redis / memcache / Thrift / Mongo / nshead parsers, etc. | declared length, frame, body | yes | must be uniformly constrained by `-max_body_size` or protocol caps | +| RTMP / AMF | chunk stream, AMF object | yes | chunk and nesting caps | +| `json2pb::JsonToProtoMessage` | HTTP JSON body | yes | UTF-8, depth, repeated-field expansion | +| `mcpack2pb` | mcpack body | yes | depth and size caps | +| `Decompress` | compressed payload | yes | bRPC currently does not guarantee a decompressed-size cap; decompression may fail | +| TLS handshake | cert, SNI, ALPN | yes | operator configures verify, SNI, ALPN | +| `Authenticator::VerifyCredential` | credential, client_addr | yes; more obvious when header IP is enabled | application implements authentication; operator protects proxy boundary | +| Built-in `/flags` | gflag value | yes if reachable | hide built-ins or set `-immutable_flags=true` | +| Built-in `/dir` | path | yes if enabled | do not enable on public internet | +| Built-in profiler | `seconds` and similar parameters | yes if reachable | hide built-ins and limit profiling cost | +| `ServerOptions` / `ChannelOptions` / gflags | configuration values | no | operator trusted | +| Naming service reply | backend response | trusted by transitivity | parser should still tolerate malformed replies | + +### Size, shape, rate + +- `-max_body_size=64 MiB` is the global message cap; there is no unified per-method or per-protocol cap. +- The cap for unconsumed streaming RPC bytes is `-socket_max_streams_unconsumed_bytes`; default 0 means unlimited. +- Pending write buffers per socket are limited by `-socket_max_unwritten_bytes`. +- `max_concurrency` limits in-flight requests, not connection count; built-in services are not constrained by this option. +- `idle_timeout_sec` is disabled by default; connection count and slowloris are mainly handled by operators / load balancers. +- Protobuf has a recursion limit; JSON, mcpack, and AMF depth limits need confirmation. +- The framework has no general rate limiting. + +--- + +## §7 Adversary model + +### Primary adversary — the wire peer + +The primary attacker is the socket peer, which can send arbitrary bytes, confuse inputs across protocols, declare large lengths, construct compression bombs, trigger parser bugs, or make a client parse malicious server responses. Goals include crashes, memory/CPU exhaustion, OOB read/write, framing bypass, arbitrary command execution, authentication bypass, abuse of reachable built-in services, modification of reloadable gflags, and reading information through `/dir`. + +### Secondary adversary — the authenticated peer + +A peer that passes `Authenticator` is still untrusted at the parser layer. Authentication only narrows the attacker set; it does not change the peer's ability to send bytes. + +### Tertiary adversary — the HTTP-built-in-services consumer + +When `internal_port < 0` and `has_builtin_services=true`, any HTTP client that can reach the listening port may access `/status`, `/vars`, `/flags`, profiler, metrics, and similar endpoints. Final triage depends on Q46/Q50. + +### Out-of-scope adversaries + +In-process attackers, local side-channel attackers, attackers controlling `ServerOptions`/`ChannelOptions`, attackers controlling naming-service backends, build-chain attackers, and attacks that depend on the operator failing to harden network boundaries are not objects that bRPC itself promises to defend against. + +### Adversary capabilities by transport + +- **Plain TCP:** full wire control, no default authentication or encryption. +- **TLS:** provides encryption, but client/server certificate verification and TLS-only behavior require operator configuration. +- **RDMA:** modeled only when enabled; risks come from the verbs API and memory regions. +- **Unix domain socket:** if supported, the trust model is similar to TCP, with socket file permissions managed by the operator. + +--- + +## §8 Security properties the project provides + +### Memory and process-safety properties + +**P1. Under configured caps, valid wire input should not cause memory corruption.** +Any OOB access, UAF, double-free, or heap corruption in a parser, decoder, decompressor, TLS path, or built-in handler reachable from Boundary 1/3 is a high/critical issue. + +**P2. Bounded recursive structures should not cause stack overflow.** +Protobuf, JSON, AMF, mcpack, and similar decoders should reject excessively deep input before exhausting the stack. + +### Wire-format properties + +**P3. Multi-protocol sniffing should be conservative.** +When a parser returns `TRY_OTHERS`, it should not consume bytes; the same frame should not be taken over by the wrong parser. + +**P4. `-max_body_size` should be enforced.** +Parsers with declared lengths should reject and close the connection when the cap is exceeded; they should not allocate oversized memory first. + +### TLS / transport-security properties (when enabled) + +**P5. With SSL correctly configured and OpenSSL secure, TLS provides confidentiality and integrity.** + +**P6. The server rejects SSLv3 by default.** + +**P7. SNI can be used for certificate selection; fallback behavior on no match is controlled by `strict_sni`.** + +**P8. ALPN selection should be limited to `ServerSSLOptions.alpns`.** + +### Resource-bound properties + +**P9. Single-message memory use should be bounded by `-max_body_size`.** +Linear constant-factor overhead is acceptable; superlinear expansion or allocations that bypass the cap should be treated as bugs. See §9 D5 for decompression expansion. + +**P10. Per-socket unwritten buffers should be limited by `-socket_max_unwritten_bytes`.** + +**P11. Server/method concurrency caps should take effect when configured by the operator.** + +### Concurrency properties + +**P12. The framework's own cross-connection state should be thread-safe.** Synchronization of shared state in user handlers is the user's responsibility. + +**P13. `Channel::CallMethod` may be called across threads; `Channel::Init()` is not guaranteed to be thread-safe.** + +### Built-in admin-service properties (conditional on operator exposure) + +**P14. Built-in services should be protected by `has_builtin_services`, `internal_port`, or network boundaries.** + +**P15. `/dir` and `/threads` are disabled by default.** + +--- + +## §9 Security properties the project does *not* provide + +### Disclaimed properties (downstream's job, not bRPC's) + +**D1.** Bare protocols do not have default peer authentication; `Authenticator` authenticates per connection, not per request. +**D2.** No application-layer authorization is provided; RBAC, ACLs, scopes, and rate limits are implemented by the application or `Interceptor`. +**D3.** When `-http_header_of_user_ip` is enabled, bRPC does not defend against attacker-forged headers; trusted proxies and firewalls must enforce the boundary. +**D4.** Without a configured cap, streaming does not defend against unlimited growth of unconsumed stream data. +**D5.** bRPC does not guarantee a decompressed-output size cap. Decompression failure is allowed, but crashes or memory boundary violations are not. +**D6.** bRPC does not guarantee a cap on decoded JSON/repeated-field size. Decode failure is allowed, but crashes or memory boundary violations are not. +**D7.** h2/HPACK bomb, CONTINUATION flood, and PING flood defenses are currently missing and should preferably be added. +**D8.** No constant-time guarantee is provided. +**D9.** OpenSSL weaknesses and default TLS version choices are managed by the operator. +**D10.** No anti-replay is provided; request/sequence IDs are only used for request/response matching. +**D11.** There is no confidentiality without TLS. +**D12.** bRPC does not defend against socket-layer connection floods or slowloris. +**D13.** When built-ins are exposed and `-immutable_flags=false`, bRPC does not defend against `/flags` modifying reloadable gflags. +**D14.** When built-ins are reachable, bRPC does not defend against information disclosure from `/status`, `/vars`, `/connections`, `/rpcz`, and similar endpoints. +**D15.** When built-ins are reachable, bRPC does not defend against CPU/IO cost from profiler endpoints. +**D16.** When `internal_port=-1`, user services and built-in services on the same port are not isolated. +**D17.** Cross-protocol confusion from the multi-protocol sniffer is a parser issue the framework must defend against, but the risk of not restricting the protocol set is the operator's responsibility. +**D18.** Pathological protobuf input under protobuf `TotalBytesLimit=INT_MAX` has no additional framework cap. +**D19.** When `force_ssl=false`, a TLS-configured port still accepts plaintext. + +### False-friend properties + +**F1.** `has_builtin_services=true` is an administrative surface, not just harmless debug pages. +**F2.** `Authenticator::VerifyCredential` runs once per connection, not once per request. +**F3.** `force_ssl=false` means the same port accepts both SSL and non-SSL. +**F4.** `Controller::remote_side()` may be affected by `-http_header_of_user_ip`. +**F5.** `verify_depth=0` means verification is disabled, not strict verification. +**F6.** `Authenticator` does not cover all protocols by default. +**F7.** `strict_sni=false` means an unknown SNI falls back to the default certificate. +**F8.** `-max_body_size` does not limit decompressed size, decoded JSON size, or protobuf's internal cap. +**F9.** The LZ4 enum and documented support status are not fully aligned and need confirmation. + +### Well-known attack classes left to the caller + +Compression bombs, JSON/repeated-field expansion, HPACK bombs, TLS compression oracles, TLS downgrade, authentication brute force, `/flags` weaponization, slow decoder DoS, HTTP request smuggling, cross-protocol confusion, and naming-service spoofing. Except for parser memory safety and explicit claims, most of these are mitigated by the operator or application. + +--- + +## §10 Downstream responsibilities + +### Application developer responsibilities + +1. Pin the bRPC version and evaluate against that version's model. +2. Treat all deserialized fields as untrusted application input and perform business-semantic validation. +3. Implement `Authenticator` when identity is required, remembering that it authenticates per connection. +4. Implement `Interceptor` or handler checks when per-request authorization is required. +5. Do not call `Channel::Init()` concurrently across threads; `CallMethod()` is the thread-safe path. +6. Do not cache gflag values long-term when running with reloadable flags. +7. Do not use `assert` for security checks. + +### Operator responsibilities + +8. Hide built-in services on public ports: set `internal_port`, set `has_builtin_services=false`, or use a firewall. +9. Keep `-enable_dir_service=false` and `-enable_threads_service=false` on public ports. +10. Unless truly needed, set `-immutable_flags=true` or ensure `/flags` is reachable only by trusted operators. +11. Configure TLS: server certificate, client verification, removal of TLS 1.0/1.1, and `force_ssl=true` when needed. +12. Set `max_concurrency`, per-method concurrency, `idle_timeout_sec`, and streaming caps. +13. When enabling `-http_header_of_user_ip`, allow only trusted proxies to access the backend port. +14. Do not run bRPC processes as root. +15. Use escaping when application HTTP handlers output HTML. +16. Put public endpoints behind a load balancer / reverse proxy for connection-rate control, slowloris handling, and L7 hardening. +17. Track OpenSSL patches and rotate TLS material separately. +18. If compression is enabled, limit decompressed size at the application layer or sidecar. +19. Use `enabled_protocols` to expose only required protocols. +20. Choose trusted naming-service backends and protect permissions on flagfiles, cert/key files, and pid files. + +### Both + +21. Re-read and revise this model on each major bRPC upgrade or whenever a §12 condition occurs. + +--- + +## §11 Known misuse patterns + +**M1.** Exposing a server with default `has_builtin_services=true` and `internal_port=-1` to the public internet. +**M2.** Treating `Authenticator` as per-request authentication. +**M3.** Enabling `-http_header_of_user_ip` while allowing clients to bypass the trusted proxy and connect directly. +**M4.** Calling `Channel::Init()` from multiple threads. +**M5.** Using `assert` for security checks. +**M6.** Running bRPC as root. +**M7.** Configuring SSL but forgetting `force_ssl=true`, leaving plaintext available. +**M8.** Using plaintext protocols across trust boundaries. +**M9.** Setting `-max_body_size` very large and assuming it will not affect memory. +**M10.** Assuming compressed output size is limited by the framework. +**M11.** Enabling `/dir` or `/threads` on the public internet. +**M12.** Enabling nshead/nova/public/nshead_mcpack and forgetting the sniffer will try these parsers. +**M13.** Treating sequence/correlation IDs as nonces. +**M14.** Exposing `/flags` while keeping `-immutable_flags=false`. +**M15.** Not revalidating `auth_context()` or downstream identity information. +**M16.** Combining `-usercode_in_pthread` with high concurrency without setting `max_concurrency`. +**M17.** Using an untrusted naming service while client TLS does not verify the server certificate. +**M18.** Exposing profiler endpoints to the public internet. +**M19.** Enabling `/rpcz` on traffic containing sensitive request fields. +**M20.** Enabling RDMA and exposing it to untrusted peers. + +--- + +## §11a Known non-findings (recurring false positives) + +**N1.** A parser reading length and checking it with `FLAGS_max_body_size` is by design; it is an issue only if a parser truly skips the check. +**N2.** ~~Decompression has no decompressed-size cap, per §9 D5.~~ +**N3.** Built-in services expose internal information; operators should hide them. +**N4.** `/flags` modifying reloadable gflags without authentication is pending Q46. +**N5.** `/dir` reads files: disabled by default; enabling it is a dangerous non-default configuration. +**N6.** `/threads` dumps stacks: disabled by default. +**N7.** bvar reading `/proc/loadavg` and `/proc/stat` is by design. +**N8.** Client TLS includes TLS 1.0/1.1 by default; triage pending Q56. +**N9.** Client TLS does not verify server certificates by default; triage pending Q57. +**N10.** `force_ssl=false` allowing plaintext is documented behavior. +**N11.** `Authenticator` running once per connection is documented behavior. +**N12.** Issues in test/example/tools/build/community are out of model. +**N13.** Hard-coded certificates under `test/` are out of model. +**N14.** HTTP Basic credentials in plaintext are a consequence of the operator choosing plaintext transport. +**N15.** MD5/SHA-1 used for load-balancer hash distribution is not cryptographic use. +**N16.** Wire-reachable integer overflow/OOB is a valid vulnerability class and should not be mechanically closed. +**N17.** `session_local_data_factory` lifetime is guaranteed by the user as required by documentation. +**N18.** The bthread keytable pool retaining objects is a reuse design. +**N19.** `Channel.Init()` being non-thread-safe is documented behavior. +**N20.** Runtime certificate add/delete/modify concurrency safety is judged case by case and needs maintainer confirmation. +**N21.** Built-ins being handled on the main port is default behavior; triage pending Q50. +**N22.** The sniffer trying multiple parsers on unknown bytes is by design. +**N23.** A fuzz crash in an opt-in parser can still be `VALID` when the path is enabled. +**N24.** Protobuf `TotalBytesLimit=INT_MAX` is documented design; the outer layer relies on `-max_body_size`. +**N25.** Naming-service backend replies are trusted. +**N26.** Triage for crashes on malformed DNS/consul/nacos replies must distinguish backend compromise from parser hardening. +**N27.** Brute-force authentication using many short connections is rate-limited by the operator. +**N28.** `/rpcz` may expose sampled request information; this is a design risk after enabling it. +**N29.** OpenSSL CVEs are not bRPC CVEs, but they affect deployments. +**N30.** `fork without exec` after bRPC initialization is unsupported. + +--- + +## §12 Conditions that would change this model + +The following changes should trigger model revision: + +1. A new wire protocol parser or sniffer registration. +2. A new built-in admin service, especially one with write capability. +3. Changes to security-relevant defaults in §5a: `-max_body_size`, streaming caps, `-immutable_flags`, built-in defaults, `force_ssl`, TLS verification, TLS protocols, and so on. +4. An opt-in protocol becoming enabled by default. +5. Adding a decompressed-size cap. +6. Adding h2/HPACK bomb defenses. +7. Adding a per-request authentication mode. +8. Adding `SECURITY.md` or a project security page to the repository. +9. The PMC publishing a security policy. +10. A new CVE class that cannot be triaged under §13. + +Vulnerabilities should be triaged against the model in effect for the affected version, not against the latest `HEAD`. + +--- + +## §13 Triage dispositions + +| Disposition | Meaning | Licensed by | +| --- | --- | --- | +| `VALID` | violates a property claimed in §8, with attacker, input, and component all in model | §8, §6, §7 | +| `VALID-HARDENING` | does not violate an existing claim, but the API easily leads to §11 misuse and the project may choose to harden it | §11 | +| `OUT-OF-MODEL: trusted-input` | requires controlling configuration or objects marked trusted by the model | §6 | +| `OUT-OF-MODEL: adversary-not-in-scope` | requires attacker capabilities excluded by the model | §7, §3 | +| `OUT-OF-MODEL: unsupported-component` | located in out-of-scope components such as test/tools/example/build/community/docs | §3 | +| `OUT-OF-MODEL: non-default-build` | appears only under dangerous or non-default configuration | §5a | +| `BY-DESIGN: property-disclaimed` | belongs to a property explicitly disclaimed in §9 | §9 | +| `KNOWN-NON-FINDING` | matches a recurring false-positive pattern in §11a | §11a | +| `MODEL-GAP` | cannot be classified and requires model revision | §12 | + +### Worked routing examples + +- `Parse*Message` OOB write: `VALID`. +- Private key leak under `test/`: `OUT-OF-MODEL: unsupported-component`. +- Client TLS not verifying server certificates by default: `BY-DESIGN`. +- Remote access to `/flags/max_body_size?setvalue=0`: operator responsibility. +- `/dir?path=/etc/passwd`: if it requires `-enable_dir_service=true`, `OUT-OF-MODEL: non-default-build`. +- A small gzip payload decompressing to huge output and causing OOM: `VALID`. +- h2 HPACK bomb: if defenses are missing and reachable, potentially `VALID`. +- Hostile consul backend redirecting a client: `OUT-OF-MODEL: adversary-not-in-scope`; malformed reply crashes can be judged separately. +- OpenSSL CVE: `BY-DESIGN: property-disclaimed`. +- `Channel.Init()` data race: `BY-DESIGN`, because documentation says it is not thread-safe. + +--- + +## §14 Open questions for the maintainers + +Each question includes the draft recommendation and awaits PMC confirmation, correction, or deletion. After confirmation, the corresponding *(inferred)* tags in the body should be updated to *(maintainer)*. + +--- + +### Wave 1 — scope, adversary, and the insecure-default rulings (must answer first) + +**Q1.** Is the server-side wire peer untrusted by default, so the runtime cannot assume input matches any registered protocol? Yes. +**Q2.** Is client-side deserialization of server responses in the same model scope as server-side handling of client requests? Yes. +**Q40.** Is `-max_body_size=64 MiB` a supported production default, or a baseline that operators must lower per deployment? Production default. +**Q50.** `has_builtin_services=true` + `internal_port=-1` puts built-ins on the main port by default; how should public exposure of `/flags`/`/version` be triaged? Operator responsibility. +**Q46.** `-immutable_flags=false` allows reachable `/flags` to modify reloadable gflags; should the default be changed to true? No, operator responsibility. +**Q57.** Client TLS does not verify server certificates by default; should successful MITM be `BY-DESIGN` or `VALID`/`VALID-HARDENING`? `BY-DESIGN`. +**Q56.** Is it acceptable that `ChannelSSLOptions.protocols` includes TLS 1.0/1.1 by default? `BY-DESIGN`. +**Q51.** Is `force_ssl=false`, which makes a TLS port still accept plaintext, only operator responsibility? Operator responsibility. + +### Wave 2 — what the runtime does (and does not do) to its host + +**Q9.** Is bRPC explicitly an embedded library rather than a secure-by-default daemon? Yes. +**Q10.** Does the user handler run in the same process as the deserializer, with no sandbox provided by bRPC? Yes. +**Q11.** Are TLS/SSL cryptographic guarantees entirely inherited from OpenSSL? Yes. +**Q12.** Does bRPC validate only wire framing, not application field semantics? Yes. +**Q24.** Is Windows outside the security support environment? Yes. +**Q28.** Do maintainers agree that parser/read/decompress paths are memory-unsafe C++ attack surfaces? Yes. + +### Wave 3 — wire-protocol parser surface (per-protocol hardening parity) + +**Q8.** Is the list of default-on, opt-in, and client-only protocols accurate? Yes. +**Q37.** Should `enabled_protocols` ensure parsers not listed are never tried? Yes. +**Q22.** Should vulnerabilities in opt-in parsers be triaged as `OUT-OF-MODEL: non-default-build` when not enabled? Parsers are all enabled by default, so triage as `VALID`. +**Q80.** Do all parsers with length framing uniformly check `FLAGS_max_body_size`? Yes. +**Q59.** Does h2/HPACK have defenses against bomb, CONTINUATION flood, PING flood, and SETTINGS flood? Preferably yes. +**Q68.** Does `mcpack2pb` have recursion/depth caps? Preferably yes. +**Q69.** Does `json2pb` have a JSON depth cap? Preferably yes. + +### Wave 4 — built-in admin services trust model + +**Q4.** Is the trust status of built-in admin consumers determined entirely by operator exposure? Yes. +**Q19.** Should public access to default built-ins such as `/version` be triaged as operator responsibility? Yes. +**Q45.** Because `/dir` and `/threads` are disabled by default, should issues after public enablement be triaged as non-default-build? Yes. +**Q71.** Do profiler endpoint `seconds` parameters have reasonable upper bounds? How should DoS after exposure be triaged? Preferably yes. +**Q87.** Is there no isolation between built-in services, so a bug in one handler can affect other services in the same process? This is an issue. + +### Wave 5 — TLS / SSL / authentication + +**Q52.** Is `strict_sni=false` fallback to the default certificate a supported posture? Yes. +**Q54.** Is server-side mTLS being disabled by default a supported posture? Yes. +**Q90.** Which protocols does `Authenticator` actually cover? Does it not cover Thrift/Redis/memcache/Mongo/RTMP/nshead by default? Yes. +**Q77.** Is an authenticated peer still an untrusted adversary at the parser layer? Yes. +**Q47.** Does `-http_header_of_user_ip` have no built-in trusted-proxy verification? Yes; this is not an issue. + +### Wave 6 — resource bounds, compression, streaming + +**Q42.** Should streaming unconsumed bytes being unlimited by default change default behavior or require explicit operator configuration? Yes. +**Q70.** Is it confirmed that bRPC has no decompressed-output size cap? Should `-max_decompressed_size` be added? Preferably yes. +**Q73.** Is it true that there is no per-protocol/per-method body cap and only global `-max_body_size`? There is indeed none, and this is acceptable. +**Q74.** Is there no framework-level connection-count cap? None. +**Q82.** Can the resource policy be stated as: superlinear memory growth over the wire-declared size is a bug when constrained by `-max_body_size`? Yes. +**Q93.** Has the HTTP parser been audited for request smuggling defenses? Yes. + +### Wave 7 — concurrency, properties, and remaining open items + +**Q5.** Does RDMA require explicit build-time and runtime enablement, with related vulnerabilities out of model when not enabled? Yes. +**Q7.** Are `test/`, `example/`, `tools/`, and `community/` out of model? Yes. +**Q18.** Do wire bytes remain untrusted until the user's `Service::Method` is called? Yes. +**Q20.** Is client response parsing modeled equivalently to server request parsing? Yes. +**Q23.** Can crashes on malformed naming-service replies be `VALID`, while malicious backend redirection is out of model? Yes. +**Q72.** Should consul/nacos and similar naming-service parsers defend against malformed input? Yes. +**Q79.** Does bRPC have continuous fuzzing/OSS-Fuzz; if not, does P1 mainly rely on code review and CI? No. +**Q83.** Under the per-connection bthread model, is synchronization of application shared state entirely the user's responsibility? Yes. +**Q84.** Is per-request authorization implemented by `Interceptor` or the handler, not `Authenticator`? Yes. +**Q88.** Should cross-protocol sniffer confusion be handled as a framework `VALID` issue? Yes. + +--- + +## §15 Optional: machine-readable companion + +v1 does not generate a machine-readable sidecar. After the defaults and triage decisions in §14 wave 1 stabilize, generate a derived index for tooling: entry points and parameter trust levels, in-scope/out-of-scope components, security-relevant gflag defaults, §8 properties, §9 disclaimed properties, §11a non-findings, and §13 disposition labels. This document remains the normative specification. + +--- + +## Appendix: SECURITY.md statement -> threat-model § back-map + +At drafting time, the repository had no `SECURITY.md`, and the Apache security project index provided only the generic `security@apache.org` address, with no bRPC project-level security page. + +| Source | Statement | Threat-model section | +| --- | --- | --- | +| `security.apache.org/projects/` | no project-level security content | N/A | +| future `SECURITY.md` | TBD by PMC | to be backfilled in a future version | + +Once the PMC publishes `SECURITY.md` or an official security policy, that artifact should become a higher-authority source. This document should then map back to or link to the official document. + +--- + +*End of v1 draft.* + diff --git a/bazel/config/BUILD.bazel b/bazel/config/BUILD.bazel index d08ea2ec23..eec551da8b 100644 --- a/bazel/config/BUILD.bazel +++ b/bazel/config/BUILD.bazel @@ -136,6 +136,7 @@ config_setting( define_values = { "with_bthread_tracer": "true", }, + visibility = ["//visibility:public"], ) config_setting( diff --git a/bazel/tools/BUILD.bazel b/bazel/tools/BUILD.bazel new file mode 100644 index 0000000000..3b44830568 --- /dev/null +++ b/bazel/tools/BUILD.bazel @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This BUILD file marks bazel/tools/ as a Bazel package so that loads +# of the form `//bazel/tools:xxx.bzl` are valid. The directory only +# hosts .bzl files (brpc_proto_library.bzl / proto_gen.bzl) and does +# not define any build targets. + +package(default_visibility = ["//visibility:public"]) + +exports_files(glob(["*.bzl"])) diff --git a/bazel/tools/brpc_proto_library.bzl b/bazel/tools/brpc_proto_library.bzl new file mode 100644 index 0000000000..22a3c00bee --- /dev/null +++ b/bazel/tools/brpc_proto_library.bzl @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# brpc_proto_library: a proto -> cc_library macro that is robust against +# protobuf and Bazel version churn. +# +# Background: +# - Old protobuf (< 3.21) used to expose a `cc_proto_library` macro +# via `@com_google_protobuf//:protobuf.bzl`; newer protobuf removed +# that macro. +# - Newer protobuf expects callers to use Bazel's native +# `proto_library` / `cc_proto_library` (or the `cc_proto_library` +# re-exported from `@com_google_protobuf//bazel:cc_proto_library.bzl`), +# but older protobuf does not surface a complete `ProtoInfo`, so +# the native rules cannot build it. +# - Bazel 8 additionally removed `cc_proto_library` from +# `@rules_cc//cc:defs.bzl` (the load now resolves to a stub rule +# that fails analysis), forcing every caller to either load it +# from `@com_google_protobuf` or replace it. +# - Bazel's `load()` is static and cannot be conditional on the +# protobuf version, so a single brpc_proto_library.bzl cannot +# directly satisfy both old and new protobuf by switching imports. +# +# Solution (inspired by https://github.com/trpc-group/trpc-cpp/blob/a69d2950f4a9c16b3cb40a2750c69d8940c96820/trpc/trpc.bzl): +# We implement the proto compilation rule from scratch (see +# proto_gen.bzl in this directory). Dependency information is +# propagated through Bazel's native `ProtoInfo` provider, and the +# default well-known-proto source is +# `@com_google_protobuf//:descriptor_proto` -- a `proto_library` +# whose name has been stable across protobuf 3.x / 27.x / 34.x. +load("//bazel/tools:proto_gen.bzl", "brpc_proto_gen") + +# `:descriptor_proto` is the protobuf repo's `proto_library` target; +# its label has been stable since protobuf 3.x. We consume it through +# Bazel's native `ProtoInfo` provider, which is uniform across PB +# versions. If a caller's .proto needs additional well-known protos +# (any.proto, timestamp.proto, ...), they can pass extra entries via +# the `proto_deps` argument. +_DEFAULT_PROTO_DEPS = ["@com_google_protobuf//:descriptor_proto"] + +def brpc_proto_library( + name, + srcs, + deps = [], + include = None, + proto_deps = None, + visibility = None, + testonly = 0): + """Generate a cc_library from a set of .proto files. + + Args: + name: target name. + srcs: list of .proto files, paths relative to the current package. + deps: list of other `brpc_proto_library` targets this target + depends on. The macro internally translates these labels + to the underlying `_genproto` rule so that .proto + include paths and sources propagate. + include: protoc `-I` root AND the resulting cc_library `includes` + root, relative to the current package. + When omitted, "" or None, the include root is the + current package itself (suitable for .proto files + sitting directly under the package root, as in `test/` + and `example/...`). The root `BUILD.bazel` of brpc must + pass `"src"` so that code can reference the protos as + `import "brpc/foo.proto"`. + proto_deps: list of native `proto_library` dependencies + (well-known protos or external .proto libraries). + Defaults to + `["@com_google_protobuf//:descriptor_proto"]`. + Pass `[]` explicitly to disable the default; pass + None (the default) to use it. + visibility: same semantics as cc_library. + testonly: same semantics as cc_library. + """ + + real_include = include if include != None else "" + real_proto_deps = proto_deps if proto_deps != None else _DEFAULT_PROTO_DEPS + + gen_name = name + "_genproto" + brpc_proto_gen( + name = gen_name, + srcs = srcs, + # Rewrite each `brpc_proto_library` dep to its underlying + # `_genproto` rule, which exports BrpcProtoInfo. + deps = [d + "_genproto" for d in deps], + include = real_include, + proto_deps = real_proto_deps, + visibility = visibility, + testonly = testonly, + ) + + # Split the generated files into .pb.h / .pb.cc via OutputGroupInfo + # and feed them to cc_library's hdrs / srcs separately. Bazel + # rejects declaring the same File in both hdrs and srcs, so a + # plain DefaultInfo wouldn't work here. + native.filegroup( + name = gen_name + "_hdrs", + srcs = [":" + gen_name], + output_group = "hdrs", + visibility = visibility, + testonly = testonly, + ) + native.filegroup( + name = gen_name + "_srcs", + srcs = [":" + gen_name], + output_group = "srcs", + visibility = visibility, + testonly = testonly, + ) + + native.cc_library( + name = name, + srcs = [":" + gen_name + "_srcs"], + hdrs = [":" + gen_name + "_hdrs"], + # cc_library `includes` is required, otherwise the .pb.cc + # files inside this cc_library cannot find the .pb.h headers + # they just generated (the headers live under + # bazel-bin///...). When include="" we pass + # "." to mean "the current package itself"; Bazel then exposes + # both `-I ` and `-I bazel-bin/` + # automatically to dependents. + includes = [real_include if real_include else "."], + deps = deps + ["@com_google_protobuf//:protobuf"], + visibility = visibility, + testonly = testonly, + ) diff --git a/bazel/tools/proto_gen.bzl b/bazel/tools/proto_gen.bzl new file mode 100644 index 0000000000..554d24df87 --- /dev/null +++ b/bazel/tools/proto_gen.bzl @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# brpc's self-implemented proto -> .pb.{h,cc} compilation rule. +# +# Cross-protobuf-version compatibility strategy (inspired by trpc.bzl): +# 1) Do NOT load `@com_google_protobuf//:protobuf.bzl` and do NOT +# call `native.cc_proto_library` -- this side-steps the double +# incompatibility of newer protobuf removing the macro and older +# protobuf missing a complete `ProtoInfo` provider. +# 2) Do NOT hardcode any protobuf-repo-internal filegroup names +# (such as `well_known_protos` / `descriptor_proto_srcs`) -- the +# names of those filegroups have shifted between protobuf 3.x +# and 27.x+. +# 3) Propagate dependencies via Bazel's native `ProtoInfo` provider: +# callers pass `proto_library` targets through `proto_deps` +# (defaulting to `@com_google_protobuf//:descriptor_proto`, a +# `proto_library` whose label has been stable across PB versions), +# and this rule walks `transitive_sources` / `transitive_proto_path` +# to collect .proto source files and protoc `-I` paths -- fully +# decoupled from the PB version. +# 4) Among `brpc_proto_gen` targets themselves, dependencies are +# propagated via the custom `BrpcProtoInfo` provider declared +# below. +# +# Note on minimum Bazel version: this rule itself does not call any +# Bazel-version-specific API, but it still consumes `ProtoInfo` from +# `proto_library` targets (e.g. `descriptor_proto`) defined inside +# the protobuf repo. Protobuf >= 34 calls +# `ProtoInfo(option_deps = ..., extension_declarations = ...)` in its +# own `proto_library` implementation, and those fields are only +# accepted by Bazel >= 8. In other words, the minimum Bazel version +# brpc requires is whatever the bundled protobuf version requires -- +# we just don't add any extra requirement on top of that. + +BrpcProtoInfo = provider( + doc = "Carries transitive .proto sources and protoc -I flags between brpc_proto_gen targets.", + fields = { + "transitive_srcs": "depset of all transitive .proto Files.", + "transitive_imports": "depset of strings, all -I flags needed by protoc.", + }, +) + +def _resolve_include_dir(ctx): + """Compute this target's protoc include root (workspace-relative). + + Examples: + ctx.label.package = "" + include = "src" -> "src" + ctx.label.package = "test" + include = "" -> "test" + ctx.label.package = "" + include = "" -> "." + """ + pkg = ctx.label.package + inc = ctx.attr.include.rstrip("/") + if pkg and inc: + return pkg + "/" + inc + if pkg: + return pkg + if inc: + return inc + return "." + +def _proto_gen_impl(ctx): + srcs = ctx.files.srcs + include_dir = _resolve_include_dir(ctx) + bin_root = ctx.bin_dir.path + + # `-I` flags for this target itself: the source-tree root plus + # the corresponding bin-dir root. The bin-dir entry is needed + # when a transitive dep generates .proto files into bazel-bin + # (e.g. via a custom code generator). + own_imports = ["-I" + include_dir] + if include_dir == ".": + own_imports.append("-I" + bin_root) + else: + own_imports.append("-I" + bin_root + "/" + include_dir) + + # Collect transitive info from other `brpc_proto_gen` deps. + dep_srcs_list = [d[BrpcProtoInfo].transitive_srcs for d in ctx.attr.deps] + dep_imports_list = [d[BrpcProtoInfo].transitive_imports for d in ctx.attr.deps] + + # Collect native `proto_library` deps (well-known protos or + # external .proto libraries). Key design: .proto sources and + # `-I` paths are obtained via Bazel's native `ProtoInfo`, + # avoiding any reliance on protobuf-repo-internal filegroup + # names (descriptor_proto_srcs / well_known_protos / ...). + # `transitive_proto_path` already accounts for the + # `_virtual_imports` paths created by `strip_import_prefix`, so + # protoc can consume the paths directly. + # + # Important: `strip_import_prefix` causes .proto files to live at + # virtual paths under bazel-out//bin// + # (see trpc.bzl for prior art). For every `proto_dep` repo we + # therefore expose FOUR candidate `-I` paths to be safe: + # 1) -- source root (plain proto_library) + # 2) / -- virtual root after strip_import_prefix + # 3) /src -- fallback for older PB descriptor_proto without strip + # 4) //src -- same as #3, under bin-dir + # protoc only warns on non-existent `-I` paths, so this is harmless. + proto_dep_src_depsets = [] + proto_dep_imports = [] + extra_pb_root_imports = [] + # Belt-and-suspenders: also feed each `proto_dep`'s default outputs + # (i.e. the .proto files themselves) into `inputs`, in case the + # ProtoInfo's transitive_sources does not cover everything we need. + proto_dep_src_depsets.append(depset(direct = ctx.files.proto_deps)) + for pd in ctx.attr.proto_deps: + pi = pd[ProtoInfo] + proto_dep_src_depsets.append(pi.transitive_sources) + for path in pi.transitive_proto_path.to_list(): + proto_dep_imports.append("-I" + path) + wsroot = pd.label.workspace_root + if wsroot: + extra_pb_root_imports.append("-I" + wsroot) + extra_pb_root_imports.append("-I" + bin_root + "/" + wsroot) + extra_pb_root_imports.append("-I" + wsroot + "/src") + extra_pb_root_imports.append("-I" + bin_root + "/" + wsroot + "/src") + # Deduplicate the workspace-level `-I` entries so the same repo + # is not listed multiple times when several proto_deps share it. + proto_dep_imports.extend(depset(extra_pb_root_imports).to_list()) + + all_imports = depset( + direct = own_imports + proto_dep_imports, + transitive = dep_imports_list, + ) + + # Output files: every .proto produces a .pb.h and a .pb.cc. + # NB: the `path` argument of `ctx.actions.declare_file(path)` is + # *relative to the current package*, while `src.short_path` is + # *relative to the workspace root*. The two differ by the package + # prefix; we must strip that prefix before calling declare_file, + # otherwise the outputs would land at + # bazel-bin///... (one level too deep). + pkg = ctx.label.package + pkg_prefix = pkg + "/" if pkg else "" + outs = [] + for src in srcs: + if not src.short_path.endswith(".proto"): + fail("brpc_proto_gen srcs must be .proto, got: %s" % src.short_path) + + rel = src.short_path + if pkg_prefix and rel.startswith(pkg_prefix): + rel = rel[len(pkg_prefix):] + base = rel[:-len(".proto")] + outs.append(ctx.actions.declare_file(base + ".pb.h")) + outs.append(ctx.actions.declare_file(base + ".pb.cc")) + + # protoc's --cpp_out points at the include root under bin_root. + # After protoc organizes outputs by their import-relative path, + # the .pb.{h,cc} files land exactly where declare_file declared + # them above. + if include_dir == ".": + cpp_out_dir = bin_root + else: + cpp_out_dir = bin_root + "/" + include_dir + + args = ctx.actions.args() + args.add_all(all_imports.to_list()) + args.add("--cpp_out=" + cpp_out_dir) + args.add_all([s.path for s in srcs]) + + ctx.actions.run( + executable = ctx.executable.protoc, + arguments = [args], + inputs = depset( + direct = srcs, + transitive = dep_srcs_list + proto_dep_src_depsets, + ), + outputs = outs, + mnemonic = "BrpcProtoCompile", + progress_message = "BrpcProtoCompile %s" % ctx.label, + ) + + hdr_files = [f for f in outs if f.path.endswith(".pb.h")] + src_files = [f for f in outs if f.path.endswith(".pb.cc")] + + return [ + DefaultInfo(files = depset(outs)), + # Split outputs so the brpc_proto_library macro can route + # them to cc_library.hdrs vs cc_library.srcs separately. + OutputGroupInfo( + hdrs = depset(hdr_files), + srcs = depset(src_files), + ), + BrpcProtoInfo( + transitive_srcs = depset(direct = srcs, transitive = dep_srcs_list), + transitive_imports = all_imports, + ), + ] + +brpc_proto_gen = rule( + implementation = _proto_gen_impl, + attrs = { + "srcs": attr.label_list( + allow_files = [".proto"], + mandatory = True, + ), + # Other `brpc_proto_gen` targets; deps are propagated via the + # custom `BrpcProtoInfo` provider. + "deps": attr.label_list( + providers = [BrpcProtoInfo], + default = [], + ), + # The include root, relative to the current BUILD package + # (e.g. "src" for the brpc root BUILD). + "include": attr.string(default = ""), + "protoc": attr.label( + default = Label("@com_google_protobuf//:protoc"), + executable = True, + # `"host"` is the legacy spelling of `"exec"`. Modern + # Bazel (>= 6) prefers `"exec"`, but still accepts + # `"host"` as an alias through at least Bazel 8.x. We + # keep `"host"` here for the broad range of Bazel + # versions used to build brpc; switch to `"exec"` once + # `"host"` is finally removed (likely a future major + # Bazel release). + cfg = "host", + allow_files = True, + ), + # Native `proto_library` deps (well-known protos or external + # .proto libraries). .proto sources and -I paths are obtained + # automatically through Bazel's native `ProtoInfo`, fully + # decoupled from the protobuf version. + "proto_deps": attr.label_list( + providers = [ProtoInfo], + default = [], + ), + }, +) diff --git a/cmake/CMakeLists.download_gtest.in b/cmake/CMakeLists.download_gtest.in index df020c89fc..78d99bafb7 100644 --- a/cmake/CMakeLists.download_gtest.in +++ b/cmake/CMakeLists.download_gtest.in @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 2.8.10) +cmake_minimum_required(VERSION 2.8.12) project(googletest-download NONE) diff --git a/docs/cn/bthread_tracer.md b/docs/cn/bthread_tracer.md index cc6e3a1eda..bab09ce40c 100644 --- a/docs/cn/bthread_tracer.md +++ b/docs/cn/bthread_tracer.md @@ -58,7 +58,7 @@ jump_stack是bthread挂起或者运行的必经之路,也是STB的拦截点。 # 使用方法 -1. 下载安装libunwind和abseil-cpp。 +1. 下载安装libunwind和abseil-cpp。**注意:libunwind 必须从源码编译,不要使用系统包管理器安装的 `libunwind-dev` / `libunwind-devel`**,否则会触发本文末尾「[已知问题:libunwind 与 libgcc_s 的 `_Unwind_*` 符号冲突](#已知问题libunwind-与-libgcc_s-的-_unwind_-符号冲突)」中描述的崩溃。bazel 构建可以跳过此步,直接使用 brpc 仓库自维护的 libunwind 版本。 2. 给config_brpc.sh增加`--with-bthread-tracer`选项或者给cmake增加`-DWITH_BTHREAD_TRACER=ON`选项或者给bazel(Bzlmod模式)增加`--define with_bthread_tracer=true`选项。 3. 访问服务的内置服务:`http://ip:port/bthreads/?st=1`或者代码里调用`bthread::stack_trace()`函数。 4. 如果希望追踪pthread的调用栈,在对应pthread上调用`bthread::init_for_pthread_stack_trace()`函数获取一个伪bthread_t,然后使用步骤3即可获取pthread调用栈。 @@ -73,6 +73,145 @@ jump_stack是bthread挂起或者运行的必经之路,也是STB的拦截点。 #5 0x00007fdbbfa58dc0 bthread::TaskGroup::task_runner() ``` +# 已知问题 + +## libunwind 与 libgcc_s 的 `_Unwind_*` 符号冲突 + +### 现象 + +启用 bthread tracer 后,可能在 `bthread_exit` / `pthread_exit` 或者 C++ 异常处理路径上偶发段错误,类似如下调用栈: + +```text +#0 0x0000000000000000 in ?? () +#1 0x00007fa2b5d6458a in _ULx86_64_dwarf_find_proc_info () + from /root/.cache/bazel/_bazel_root/743b333b2429a1dbd390ef66b59c771d/execroot/_main/bazel-out/k8-fastbuild/bin/test/../_solib_k8/libexternal_Slibunwind~_Slibunwind.so +#2 0x00007fa2b5d6668d in fetch_proc_info () + from /root/.cache/bazel/_bazel_root/743b333b2429a1dbd390ef66b59c771d/execroot/_main/bazel-out/k8-fastbuild/bin/test/../_solib_k8/libexternal_Slibunwind~_Slibunwind.so +#3 0x00007fa2b5d681a1 in _ULx86_64_dwarf_make_proc_info () + from /root/.cache/bazel/_bazel_root/743b333b2429a1dbd390ef66b59c771d/execroot/_main/bazel-out/k8-fastbuild/bin/test/../_solib_k8/libexternal_Slibunwind~_Slibunwind.so +#4 0x00007fa2b5d70cfd in _ULx86_64_get_proc_info () + from /root/.cache/bazel/_bazel_root/743b333b2429a1dbd390ef66b59c771d/execroot/_main/bazel-out/k8-fastbuild/bin/test/../_solib_k8/libexternal_Slibunwind~_Slibunwind.so +#5 0x00007fa2b5d6c775 in __libunwind_Unwind_GetLanguageSpecificData () + from /root/.cache/bazel/_bazel_root/743b333b2429a1dbd390ef66b59c771d/execroot/_main/bazel-out/k8-fastbuild/bin/test/../_solib_k8/libexternal_Slibunwind~_Slibunwind.so +#6 0x00007fa2b503c6df in __gxx_personality_v0 () from /lib/x86_64-linux-gnu/libstdc++.so.6 +#7 0x00007fa2b5452ce5 in ?? () from /lib/x86_64-linux-gnu/libgcc_s.so.1 +#8 0x00007fa2b54533c0 in _Unwind_ForcedUnwind () from /lib/x86_64-linux-gnu/libgcc_s.so.1 +#9 0x00007fa2b4ca57a4 in __GI___pthread_unwind (buf=) at ./nptl/unwind.c:130 +#10 0x00007fa2b4c9dd22 in __do_cancel () at ../sysdeps/nptl/pthreadP.h:271 +#11 __GI___pthread_exit (value=0x0) at ./nptl/pthread_exit.c:36 +#12 0x0000000000000000 in ?? () +``` + +### 根因 + +libunwind 的 `src/unwind/*.c` 实现了 GCC 的 `_Unwind_*` ABI 兼容层(`_Unwind_GetLanguageSpecificData`、`_Unwind_ForcedUnwind`、`_Unwind_Resume` 等),与 `libgcc_s.so.1` 提供同名的全局符号。当 libunwind 以**动态库**形式被链接,并且在最终二进制的 `DT_NEEDED` 列表中位置比 `libgcc_s.so.1` 靠前时,运行时动态链接器 `ld.so` 会把 `pthread_exit` / 异常处理触发的 `_Unwind_*` 调用解析到 libunwind 中的 DWARF 实现。该实现需要的内部上下文在 `pthread_exit` 路径上未被 bRPC 初始化好,从而触发空指针访问。 + +这是一个 ELF **运行时符号解析顺序**问题,与编译器(GCC / Clang)无关 —— Clang 默认运行时同样使用 `libstdc++ + libgcc_s`,会复现完全一致的崩溃。 + +### 解决方案 + +> **重要:不要使用系统包管理器安装的 libunwind**(例如 `apt install libunwind-dev`、`yum install libunwind-devel`)。多数发行版打包的 `libunwind.so` 仍把 `_Unwind_*` 暴露在动态符号表中,会触发本节描述的崩溃。 +> +> 必须使用**从源码编译的 libunwind**。上游 `./configure` + `make` 默认会通过 `-Wl,--version-script` 把 `_Unwind_*` 标为 local,不导出到动态符号表,从而避免冲突。 + +下表汇总了三种构建方式的推荐方案: + +| 构建方式 | 推荐方案 | +|---|---| +| `config_brpc.sh` + `make` | 从源码编译并安装 libunwind,把头文件与库目录显式传给 `config_brpc.sh` | +| `cmake` | 从源码编译并安装 libunwind,把头文件与库目录显式传给 `cmake` | +| `bazel`(Bzlmod) | 直接使用 brpc 仓库自维护的 libunwind 版本 | + +### make (config_brpc.sh) + +源码编译并安装 libunwind 到一个独立目录(避免污染系统目录),然后让 `config_brpc.sh` 显式从该目录查找 libunwind。 + +```bash +# 1) 源码编译 libunwind(推荐 v1.8.1 或以上版本) +git clone https://github.com/libunwind/libunwind.git +cd libunwind && git checkout tags/v1.8.1 +mkdir -p /opt/libunwind +autoreconf -i +./configure --prefix=/opt/libunwind +make -j$(nproc) && make install +cd .. + +# 2) 让 config_brpc.sh 使用 /libunwind 下的头文件与库(不要让它自动找到系统的 libunwind-dev) +cd brpc +sh config_brpc.sh \ + --with-bthread-tracer \ + --headers="/opt/libunwind/include /usr/include" \ + --libs="/opt/libunwind/lib /usr/lib /usr/lib64" +make -j$(nproc) +``` + +构建完成后可用以下命令确认 libunwind.so 没有导出 `_Unwind_*`: + +```bash +nm -D /libunwind/lib/libunwind.so | grep ' T _Unwind_' \ + && echo "WARN: _Unwind_* exported" \ + || echo "OK: _Unwind_* hidden" +``` + +### cmake + +[`CMakeLists.txt`](../../CMakeLists.txt:90-100) 通过 `find_library(... NAMES unwind unwind-x86_64)` 查找 libunwind。同样需要先源码编译 libunwind 到独立前缀,再用 `CMAKE_PREFIX_PATH` 让 cmake 优先在该前缀下查找: + +```bash +# 1) 源码编译 libunwind(同 make 章节) + +# 2) 让 cmake 在 /libunwind 下优先查找头文件和库 +cd brpc +mkdir build && cd build +cmake -DWITH_BTHREAD_TRACER=ON \ + -DCMAKE_PREFIX_PATH=/opt/libunwind \ + .. +make -j$(nproc) +``` + +> 提示:如果系统已经装了 `libunwind-dev`,`find_library` 仍可能优先匹配到 `/usr/lib`。可在 cmake 命令上额外指定 +> `-DLIBUNWIND_LIB=/libunwind/lib/libunwind.so -DLIBUNWIND_X86_64_LIB=/libunwind/lib/libunwind-x86_64.so -DLIBUNWIND_INCLUDE_PATH=/libunwind/include` +> 强制走自编译版本,避免系统包混入。 + +### bazel (Bzlmod) + +bRPC 仓库已经在 [`registry/modules/libunwind/`](../../registry/modules/libunwind/) 维护了一份 libunwind 的 Bzlmod overlay,并通过 [`.bazelrc`](../../.bazelrc) 中的 `--registry=https://github.com/apache/brpc/registry` 使用 bRPC 维护的 overlay。版本号采用 `.brpc-no-unwind` 后缀(例如 `1.8.3.brpc-no-unwind`),用以区别于 BCR 上的同基版本。该 overlay 增加了一个开关: + +``` +--define libunwind_hide_unwind_symbols=true +``` + +开启后,libunwind 的 `src/unwind/*.c`(即 GCC `_Unwind_*` 兼容层)整体不参与编译,等效于上游 autoconf 默认效果。bRPC 只使用 libunwind 的 `unw_*` 原生 API(`unw_getcontext`、`unw_init_local`、`unw_step` 等),不依赖 `_Unwind_*` 兼容层,因此该开关安全无副作用。 + +`.bazelrc` 已默认在 `build:test` / `test` 配置下打开此开关: + +``` +build:test --define libunwind_hide_unwind_symbols=true +test --define libunwind_hide_unwind_symbols=true +``` + +用户按文档使用方法第 2 步加上 `--define=with_bthread_tracer=true` 即可: + +```bash +# 测试场景,.bazelrc 中 test 配置已自动带上 hide 开关 +bazel test //test:bthread_unittest + +# 非测试构建(生产部署),需要显式同时带上两个 define +bazel build --define=with_bthread_tracer=true \ + --define=libunwind_hide_unwind_symbols=true \ + //... +``` + +> **特别注意**:如果在生产构建中只启用 `--define=with_bthread_tracer=true` 而漏掉 `--define=libunwind_hide_unwind_symbols=true`,binary 在 `pthread_exit` / 异常路径上会概率性崩溃。 + +构建后可用以下命令验证 libunwind 共享库没有导出 `_Unwind_*`: + +```bash +nm -D bazel-bin/external/_solib_*/libexternal*libunwind*.so 2>/dev/null \ + | grep ' T _Unwind_' || echo "OK: no _Unwind_* exported by libunwind.so" +``` + + # 相关flag - `signal_trace_timeout_ms`:信号追踪模式的超时时间,默认为50ms。 \ No newline at end of file diff --git a/docs/cn/ubring.md b/docs/cn/ubring.md new file mode 100644 index 0000000000..6519ae3f9f --- /dev/null +++ b/docs/cn/ubring.md @@ -0,0 +1,199 @@ +# UBRing: 高性能共享内存 RPC + +UBRing 是 brpc 中的高性能 RPC 实现,它利用共享内存进行进程间通信(IPC)。它支持本地共享内存(POSIX IPC)和远端共享内存(ubs-mem)两种模式,提供微秒到纳秒级的进程间通信延迟。 + +## 技术背景 + +传统的 RPC 框架通常使用网络套接字进行通信,由于内核参与、上下文切换和数据拷贝等原因,会引入显著的开销。UBRing 通过使用共享内存作为通信介质来解决这个问题,允许进程之间直接内存访问,最小化内核干预。 + +UBRing 的主要优势: + +- **超低延迟**:微秒级 RPC 延迟 +- **高吞吐量**:每秒数百万次 RPC 调用 +- **减少数据拷贝**:进程间直接内存访问 +- **跨平台支持**:支持 Linux 和 macOS + +## 支持的共享内存后端 + +UBRing 支持两种共享内存后端,通过 `ub_shm_type` 参数控制: + +### 1. POSIX IPC 共享内存 (ub\_shm\_type = 1) + +这是默认模式,使用标准 POSIX 共享内存进行本地 IPC。同一机器上的进程可以通过共享内存区域直接通信。 + +### 2. UBS-Mem 远端共享内存 (ub\_shm\_type = 2) + +此模式使用 ubs-mem(Unified Block Storage Memory),这是来自 openEuler 的开源远端共享内存框架。它支持机架内节点之间的共享内存通信,类似于 RDMA 但部署要求更简单。 + +**UBS-Mem 开源地址**: + +**所需库文件**: +- `libubsm_sdk.so` - UBS-Mem SDK 库(安装路径:`/usr/local/ubs_mem/lib/libubsm_sdk.so`) +- UBS-Mem 通过 `dlopen()` 动态加载该库,并使用 `ubsmem_initialize()`、`ubsmem_create_region()`、`ubsmem_shmem_allocate()`、`ubsmem_shmem_map()` 等函数 + +**UBS-Mem 关键函数**: +- `ubsmem_init_attributes()` - 初始化 UBS-Mem 属性 +- `ubsmem_initialize()` - 初始化 UBS-Mem 库 +- `ubsmem_finalize()` - 释放 UBS-Mem 库 +- `ubsmem_create_region()` - 创建共享内存区域 +- `ubsmem_shmem_allocate()` - 分配共享内存 +- `ubsmem_shmem_map()` - 将共享内存映射到本地地址空间 +- `ubsmem_shmem_unmap()` - 解除共享内存映射 +- `ubsmem_shmem_deallocate()` - 释放共享内存 +- `ubsmem_destroy_region()` - 销毁共享内存区域 + +### 未来扩展 + +该架构设计支持未来扩展 CXL(Compute Express Link)基于的远端共享内存,实现更灵活的分布式内存共享。 + +## 构建配置 + +### 使用 CMake 构建 + +要构建带有 UBRing 支持的 brpc,请使用以下命令: + +```bash +# 构建 brpc 并启用 UBRing 支持 +cd /path/to/brpc +cmake -B build -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DWITH_UBRING:BOOL=ON +cmake --build build -j 8 + +# 构建 ubring_performance 示例 +cd /path/to/brpc/example/ubring_performance +cmake -B build +cmake --build build -j 8 +``` + +### 使用 Bazel 构建 + +使用 Bazel 构建带有 UBRing 支持的 brpc: + +```bash +# 构建 brpc 并启用 UBRing 支持 +cd /path/to/brpc +bazel build //... --define=with_ubring=true + +# 构建 ubring_performance 示例 +bazel build //example/ubring_performance/... +``` + +### 选择共享内存后端 + +共享内存后端通过 `--ub_shm_type` 参数控制: + +```bash +# 使用 POSIX IPC(默认) +./your_program --ub_shm_type=1 + +# 使用 UBS-Mem +./your_program --ub_shm_type=2 +``` + +## 性能测试 + +### 示例: ubring\_performance + +brpc 在 `example/ubring_performance/` 目录提供了性能测试示例。 + +#### 构建示例 + +```bash +cd example/ubring_performance +mkdir -p build && cd build +cmake .. +make +``` + +#### 运行服务端 + +```bash +# 使用 POSIX IPC +./ubring_performance_server --ub_shm_type=1 + +# 使用 UBS-Mem +./ubring_performance_server --ub_shm_type=2 +``` + +#### 运行客户端 + +```bash +# 使用 POSIX IPC +./ubring_performance_client --ub_shm_type=1 --server=127.0.0.1:8000 + +# 使用 UBS-Mem +./ubring_performance_client --ub_shm_type=2 --server=:8000 +``` + +#### 测试选项 + +| 选项 | 描述 | 默认值 | +| --------------- | ------------------------- | -------------- | +| `--ub_shm_type` | 共享内存类型 (1=IPC, 2=UBS-Mem) | 1 | +| `--server` | 服务端地址 | 127.0.0.1:8000 | +| `--thread_num` | 客户端线程数 | 1 | +| `--request_num` | 每线程请求总数 | 1000000 | +| `--timeout_ms` | 请求超时时间(毫秒) | 1000 | + +## 架构概述 + +```mermaid +graph TD + subgraph 客户端进程 + A[Client] + end + + subgraph 服务端进程 + B[Server] + end + + subgraph 共享内存层 + C[SHM Manager] + D[IPC Backend] + E[UBS-Mem Backend] + end + + A -->|直接内存访问| C + B -->|直接内存访问| C + C --> D + C --> E + + style A fill:#636,color:#fff,stroke:#333,stroke-width:2px + style B fill:#369,color:#fff,stroke:#333,stroke-width:2px + style C fill:#396,color:#fff,stroke:#333,stroke-width:2px +``` + +### 架构细节 + +UBRing 架构包含以下组件: + +1. **客户端/服务端进程**: 通过共享内存通信的应用进程 +2. **SHM Manager**: 共享内存操作的中央管理器 (`shm_mgr.cpp`) +3. **IPC Backend**: 用于本地通信的 POSIX 共享内存实现 +4. **UBS-Mem Backend**: 用于跨节点通信的远端共享内存实现 + +## 实现细节 + +### 共享内存管理 + +共享内存管理器 (`shm_mgr.cpp`) 为不同的共享内存后端提供统一接口: + +- **初始化**: `ShmMgrInit()` - 初始化共享内存子系统 +- **本地分配**: `ShmLocalMalloc()` - 分配本地共享内存 +- **远端分配**: `ShmRemoteMalloc()` - 分配远程节点可访问的共享内存 +- **释放**: `ShmFree()` - 释放共享内存资源 + +### 定时器管理 + +UBRing 使用高精度定时器系统 (`timer_mgr.cpp`) 进行连接管理和超时处理,支持 epoll(Linux)和 kqueue(macOS)。 + +## 参考资料 + +- [UBRing 特性提案](https://github.com/apache/brpc/issues/3226) +- [UBRing 技术讨论](https://github.com/apache/brpc/discussions/3217) +- [UBS-Mem 开源项目](https://atomgit.com/openeuler/ubs-mem) + +## 相关文档 + +- [UB Client](ub_client.md) - 访问 UB 服务 +- [RDMA 支持](rdma.md) - 远程直接内存访问 + diff --git a/docs/en/ubring.md b/docs/en/ubring.md new file mode 100644 index 0000000000..f910facb7b --- /dev/null +++ b/docs/en/ubring.md @@ -0,0 +1,197 @@ +# UBRing: High-Performance Shared Memory RPC + +UBRing is a high-performance RPC implementation in brpc that leverages shared memory for inter-process communication (IPC). It supports both local shared memory (POSIX IPC) and remote shared memory (ubs-mem), providing ultra-low latency communication between processes. + +## Technical Background + +Traditional RPC frameworks typically use network sockets for communication, which introduces significant overhead due to kernel involvement, context switches, and data copying. UBRing addresses this by using shared memory as the communication medium, allowing direct memory access between processes with minimal kernel intervention. + +Key advantages of UBRing: +- **Ultra-low latency**: Microsecond-level RPC latency +- **High throughput**: Millions of RPC calls per second +- **Reduced data copying**: Direct memory access between processes +- **Cross-platform support**: Works on Linux and macOS + +## Supported Shared Memory Backends + +UBRing supports two types of shared memory backends, controlled by the `ub_shm_type` flag: + +### 1. POSIX IPC Shared Memory (ub_shm_type = 1) + +This is the default mode, using standard POSIX shared memory for local IPC. Processes on the same machine can communicate directly through shared memory regions. + +### 2. UBS-Mem Remote Shared Memory (ub_shm_type = 2) + +This mode uses ubs-mem (Unified Block Storage Memory), an open-source remote shared memory framework from openEuler. It enables shared memory communication across nodes in a rack, similar to RDMA but with simpler deployment requirements. + +**UBS-Mem Open Source**: https://atomgit.com/openeuler/ubs-mem + +**Required Libraries**: +- `libubsm_sdk.so` - UBS-Mem SDK library (installed at `/usr/local/ubs_mem/lib/libubsm_sdk.so`) +- UBS-Mem dynamically loads the library via `dlopen()` and uses functions like `ubsmem_initialize()`, `ubsmem_create_region()`, `ubsmem_shmem_allocate()`, `ubsmem_shmem_map()`, etc. + +**UBS-Mem Key Functions**: +- `ubsmem_init_attributes()` - Initialize UBS-Mem attributes +- `ubsmem_initialize()` - Initialize UBS-Mem library +- `ubsmem_finalize()` - Finalize UBS-Mem library +- `ubsmem_create_region()` - Create a shared memory region +- `ubsmem_shmem_allocate()` - Allocate shared memory +- `ubsmem_shmem_map()` - Map shared memory to local address space +- `ubsmem_shmem_unmap()` - Unmap shared memory +- `ubsmem_shmem_deallocate()` - Deallocate shared memory +- `ubsmem_destroy_region()` - Destroy a shared memory region + +### Future Expansion + +The architecture is designed to support CXL (Compute Express Link) based remote shared memory in the future, enabling even more flexible distributed memory sharing. + +## Build Configuration + +### Build with CMake + +To build brpc with UBRing support, use the following commands: + +```bash +# Build brpc with UBRing support +cd /path/to/brpc +cmake -B build -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DWITH_UBRING:BOOL=ON +cmake --build build -j 8 + +# Build the ubring_performance example +cd /path/to/brpc/example/ubring_performance +cmake -B build +cmake --build build -j 8 +``` + +### Build with Bazel + +To build brpc with UBRing support using Bazel: + +```bash +# Build brpc with UBRing support +cd /path/to/brpc +bazel build //... --define=with_ubring=true + +# Build the ubring_performance example +bazel build //example/ubring_performance/... +``` + +### Select Shared Memory Backend + +The shared memory backend is controlled by the `--ub_shm_type` flag: + +```bash +# Use POSIX IPC (default) +./your_program --ub_shm_type=1 + +# Use UBS-Mem +./your_program --ub_shm_type=2 +``` + +## Performance Testing + +### Example: ubring_performance + +brpc provides a performance test example at `example/ubring_performance/`. + +#### Build the Example + +```bash +cd example/ubring_performance +mkdir -p build && cd build +cmake .. +make +``` + +#### Run Server + +```bash +# Run with POSIX IPC +./ubring_performance_server --ub_shm_type=1 + +# Run with UBS-Mem +./ubring_performance_server --ub_shm_type=2 +``` + +#### Run Client + +```bash +# Run with POSIX IPC +./ubring_performance_client --ub_shm_type=1 --server=127.0.0.1:8000 + +# Run with UBS-Mem +./ubring_performance_client --ub_shm_type=2 --server=:8000 +``` + +#### Test Options + +| Option | Description | Default | +|--------|-------------|---------| +| `--ub_shm_type` | Shared memory type (1=IPC, 2=UBS-Mem) | 1 | +| `--server` | Server address | 127.0.0.1:8000 | +| `--thread_num` | Number of client threads | 1 | +| `--request_num` | Total requests per thread | 1000000 | +| `--timeout_ms` | Request timeout in milliseconds | 1000 | + +## Architecture Overview + +```mermaid +graph TD + subgraph Client Process + A[Client] + end + + subgraph Server Process + B[Server] + end + + subgraph Shared Memory + C[SHM Manager] + D[IPC Backend] + E[UBS-Mem Backend] + end + + A -->|Direct Memory Access| C + B -->|Direct Memory Access| C + C --> D + C --> E + + style A fill:#636,color:#fff,stroke:#333,stroke-width:2px + style B fill:#369,color:#fff,stroke:#333,stroke-width:2px + style C fill:#396,color:#fff,stroke:#333,stroke-width:2px +``` + +### Architecture Details + +The UBRing architecture consists of: + +1. **Client/Server Processes**: Application processes that communicate via shared memory +2. **SHM Manager**: Central manager for shared memory operations (`shm_mgr.cpp`) +3. **IPC Backend**: POSIX shared memory implementation for local communication +4. **UBS-Mem Backend**: Remote shared memory implementation for cross-node communication + +## Implementation Details + +### Shared Memory Management + +The shared memory manager (`shm_mgr.cpp`) provides a unified interface for different shared memory backends: + +- **Initialization**: `ShmMgrInit()` - Initializes the shared memory subsystem +- **Local Allocation**: `ShmLocalMalloc()` - Allocates shared memory for local use +- **Remote Allocation**: `ShmRemoteMalloc()` - Allocates shared memory accessible by remote nodes +- **Free**: `ShmFree()` - Releases shared memory resources + +### Timer Management + +UBRing uses a high-precision timer system (`timer_mgr.cpp`) for connection management and timeout handling, supporting both epoll (Linux) and kqueue (macOS). + +## References + +- [UBRing Feature Proposal](https://github.com/apache/brpc/issues/3226) +- [UBRing Technical Discussion](https://github.com/apache/brpc/discussions/3217) +- [UBS-Mem Open Source](https://atomgit.com/openeuler/ubs-mem) + +## See Also + +- [UB Client](ub_client.md) - Accessing UB services +- [RDMA Support](rdma.md) - Remote direct memory access \ No newline at end of file diff --git a/example/BUILD.bazel b/example/BUILD.bazel index df2722a4f6..4ee7cb140f 100644 --- a/example/BUILD.bazel +++ b/example/BUILD.bazel @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//bazel/tools:brpc_proto_library.bzl", "brpc_proto_library") COPTS = [ "-D__STDC_FORMAT_MACROS", @@ -36,32 +36,18 @@ COPTS = [ "//conditions:default": [""], }) -proto_library( - name = "echo_c++_proto", - srcs = [ - "echo_c++/echo.proto", - ], -) - -proto_library( - name = "rdma_performance_proto", - srcs = [ - "rdma_performance/test.proto", - ], -) - -cc_proto_library( +brpc_proto_library( name = "cc_echo_c++_proto", - deps = [ - ":echo_c++_proto", - ], + srcs = ["echo_c++/echo.proto"], + include = "echo_c++", + proto_deps = [], ) -cc_proto_library( +brpc_proto_library( name = "cc_rdma_performance_proto", - deps = [ - ":rdma_performance_proto", - ], + srcs = ["rdma_performance/test.proto"], + include = "rdma_performance", + proto_deps = [], ) cc_binary( diff --git a/example/ubring_performance/CMakeLists.txt b/example/ubring_performance/CMakeLists.txt new file mode 100644 index 0000000000..729381ccb8 --- /dev/null +++ b/example/ubring_performance/CMakeLists.txt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 2.8.10) +project(ubring_performance C CXX) + +option(LINK_SO "Whether examples are linked dynamically" OFF) + +execute_process( + COMMAND bash -c "find ${PROJECT_SOURCE_DIR}/../.. -type d -regex '.*output/include$' | head -n1 | xargs dirname | tr -d '\n'" + OUTPUT_VARIABLE OUTPUT_PATH +) + +set(CMAKE_PREFIX_PATH ${OUTPUT_PATH}) + +include(FindThreads) +include(FindProtobuf) +protobuf_generate_cpp(PROTO_SRC PROTO_HEADER test.proto) +# include PROTO_HEADER +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + +# Search for libthrift* by best effort. If it is not found and brpc is +# compiled with thrift protocol enabled, a link error would be reported. +find_library(THRIFT_LIB NAMES thrift) +if (NOT THRIFT_LIB) + set(THRIFT_LIB "") +endif() + +find_path(BRPC_INCLUDE_PATH NAMES brpc/server.h) +if(LINK_SO) + find_library(BRPC_LIB NAMES brpc) +else() + find_library(BRPC_LIB NAMES libbrpc.a brpc) +endif() +if((NOT BRPC_INCLUDE_PATH) OR (NOT BRPC_LIB)) + message(FATAL_ERROR "Fail to find brpc") +endif() +include_directories(${BRPC_INCLUDE_PATH}) + +find_path(GFLAGS_INCLUDE_PATH gflags/gflags.h) +find_library(GFLAGS_LIBRARY NAMES gflags libgflags) +if((NOT GFLAGS_INCLUDE_PATH) OR (NOT GFLAGS_LIBRARY)) + message(FATAL_ERROR "Fail to find gflags") +endif() +include_directories(${GFLAGS_INCLUDE_PATH}) + +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + include(CheckFunctionExists) + CHECK_FUNCTION_EXISTS(clock_gettime HAVE_CLOCK_GETTIME) + if(NOT HAVE_CLOCK_GETTIME) + set(DEFINE_CLOCK_GETTIME "-DNO_CLOCK_GETTIME_IN_MAC") + endif() +endif() + +set(CMAKE_CPP_FLAGS "${DEFINE_CLOCK_GETTIME} -DBRPC_WITH_UBRING=1") +set(CMAKE_CXX_FLAGS "${CMAKE_CPP_FLAGS} -DNDEBUG -O2 -D__const__=__unused__ -pipe -W -Wall -Wno-unused-parameter -fPIC -fno-omit-frame-pointer") + +if(CMAKE_VERSION VERSION_LESS "3.1.3") + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + endif() + if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + endif() +else() + set(CMAKE_CXX_STANDARD 11) + set(CMAKE_CXX_STANDARD_REQUIRED ON) +endif() + +find_path(LEVELDB_INCLUDE_PATH NAMES leveldb/db.h) +find_library(LEVELDB_LIB NAMES leveldb) +if ((NOT LEVELDB_INCLUDE_PATH) OR (NOT LEVELDB_LIB)) + message(FATAL_ERROR "Fail to find leveldb") +endif() +include_directories(${LEVELDB_INCLUDE_PATH}) + +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + set(OPENSSL_ROOT_DIR + "/usr/local/opt/openssl" # Homebrew installed OpenSSL + ) +endif() + +find_package(OpenSSL) +include_directories(${OPENSSL_INCLUDE_DIR}) + +set(DYNAMIC_LIB + ${CMAKE_THREAD_LIBS_INIT} + ${GFLAGS_LIBRARY} + ${PROTOBUF_LIBRARIES} + ${LEVELDB_LIB} + ${OPENSSL_CRYPTO_LIBRARY} + ${OPENSSL_SSL_LIBRARY} + ${THRIFT_LIB} + dl + z + ) + +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + set(DYNAMIC_LIB ${DYNAMIC_LIB} + pthread + "-framework CoreFoundation" + "-framework CoreGraphics" + "-framework CoreData" + "-framework CoreText" + "-framework Security" + "-framework Foundation" + "-Wl,-U,_MallocExtension_ReleaseFreeMemory" + "-Wl,-U,_ProfilerStart" + "-Wl,-U,_ProfilerStop" + "-Wl,-U,__Z13GetStackTracePPvii" + "-Wl,-U,_mallctl" + "-Wl,-U,_malloc_stats_print" + ) +endif() + +add_executable(ubring_performance_client client.cpp ${PROTO_SRC} ${PROTO_HEADER}) +add_executable(ubring_performance_server server.cpp ${PROTO_SRC} ${PROTO_HEADER}) + +target_link_libraries(ubring_performance_client ${BRPC_LIB} ${DYNAMIC_LIB}) +target_link_libraries(ubring_performance_server ${BRPC_LIB} ${DYNAMIC_LIB}) \ No newline at end of file diff --git a/example/ubring_performance/client.cpp b/example/ubring_performance/client.cpp new file mode 100644 index 0000000000..c14268a430 --- /dev/null +++ b/example/ubring_performance/client.cpp @@ -0,0 +1,328 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include "butil/atomicops.h" +#include "butil/fast_rand.h" +#include "butil/logging.h" +#include "brpc/server.h" +#include "brpc/channel.h" +#include "bthread/bthread.h" +#include "bvar/latency_recorder.h" +#include "bvar/variable.h" +#include "test.pb.h" + +#ifdef BRPC_WITH_UBRING + +DEFINE_int32(thread_num, 0, "How many threads are used"); +DEFINE_int32(queue_depth, 1, "How many requests can be pending in the queue"); +DEFINE_int32(expected_qps, 0, "The expected QPS"); +DEFINE_int32(max_thread_num, 16, "The max number of threads are used"); +DEFINE_int32(attachment_size, -1, "Attachment size is used (in Bytes)"); +DEFINE_bool(echo_attachment, false, "Select whether attachment should be echo"); +DEFINE_string(connection_type, "single", "Connection type of the channel"); +DEFINE_string(protocol, "baidu_std", "Protocol type."); +DEFINE_string(servers, "0.0.0.0:8002+0.0.0.0:8002", "IP Address of servers"); +DEFINE_bool(use_ubring, false, "Use UBRING or not"); +DEFINE_int32(rpc_timeout_ms, 5000, "RPC call timeout"); +DEFINE_int32(test_seconds, 20, "Test running time"); +DEFINE_int32(test_iterations, 0, "Test iterations"); +DEFINE_int32(dummy_port, 8001, "Dummy server port number"); + +bvar::LatencyRecorder g_latency_recorder("client"); +bvar::LatencyRecorder g_server_cpu_recorder("server_cpu"); +bvar::LatencyRecorder g_client_cpu_recorder("client_cpu"); +butil::atomic g_last_time(0); +butil::atomic g_total_bytes; +butil::atomic g_total_cnt; +std::vector g_servers; +int rr_index = 0; +volatile bool g_stop = false; + +butil::atomic g_token(10000); + +static void* GenerateToken(void* arg) { + int64_t start_time = butil::monotonic_time_ns(); + int64_t accumulative_token = g_token.load(butil::memory_order_relaxed); + while (!g_stop) { + bthread_usleep(100000); + int64_t now = butil::monotonic_time_ns(); + if (accumulative_token * 1000000000 / (now - start_time) < FLAGS_expected_qps) { + int64_t delta = FLAGS_expected_qps * (now - start_time) / 1000000000 - accumulative_token; + g_token.fetch_add(delta, butil::memory_order_relaxed); + accumulative_token += delta; + } + } + return NULL; +} + +class PerformanceTest { +public: + PerformanceTest(int attachment_size, bool echo_attachment) + : _addr(NULL) + , _channel(NULL) + , _start_time(0) + , _iterations(0) + , _stop(false) + { + if (attachment_size > 0) { + _addr = malloc(attachment_size); + butil::fast_rand_bytes(_addr, attachment_size); + _attachment.append(_addr, attachment_size); + } + _echo_attachment = echo_attachment; + } + + ~PerformanceTest() { + if (_addr) { + free(_addr); + } + delete _channel; + } + + inline bool IsStop() { return _stop; } + + int Init() { + brpc::ChannelOptions options; + options.socket_mode = FLAGS_use_ubring? brpc::SOCKET_MODE_UBRING : brpc::SOCKET_MODE_TCP; + options.protocol = FLAGS_protocol; + options.connection_type = FLAGS_connection_type; + options.timeout_ms = FLAGS_rpc_timeout_ms; + options.max_retry = 0; + // TODO A bug exists when the connection_group parameter is used. + // options.connection_group = std::to_string(reinterpret_cast(this)); + std::string server = g_servers[(rr_index++) % g_servers.size()]; + _channel = new brpc::Channel(); + if (_channel->Init(server.c_str(), &options) != 0) { + LOG(ERROR) << "Fail to initialize channel"; + return -1; + } + + // Add retry mechanism for RPC call + int retry = 3; + while (retry > 0) { + brpc::Controller cntl; + test::PerfTestResponse response; + test::PerfTestRequest request; + request.set_echo_attachment(_echo_attachment); + test::PerfTestService_Stub stub(_channel); + stub.Test(&cntl, &request, &response, NULL); + if (!cntl.Failed()) { + return 0; + } + LOG(WARNING) << "RPC call failed, retrying... (" << retry << " left): " << cntl.ErrorText(); + retry--; + bthread_usleep(1000000); // 100ms delay before retry + } + LOG(ERROR) << "RPC call failed after multiple retries"; + return -1; + } + + struct RespClosure { + brpc::Controller* cntl; + test::PerfTestResponse* resp; + PerformanceTest* test; + }; + + void SendRequest() { + if (FLAGS_expected_qps > 0) { + while (g_token.load(butil::memory_order_relaxed) <= 0) { + bthread_usleep(10); + } + g_token.fetch_sub(1, butil::memory_order_relaxed); + } + RespClosure* closure = new RespClosure; + test::PerfTestRequest request; + closure->resp = new test::PerfTestResponse(); + closure->cntl = new brpc::Controller(); + request.set_echo_attachment(_echo_attachment); + closure->cntl->request_attachment().append(_attachment); + closure->test = this; + google::protobuf::Closure* done = brpc::NewCallback(&HandleResponse, closure); + test::PerfTestService_Stub stub(_channel); + stub.Test(closure->cntl, &request, closure->resp, done); + } + + static void HandleResponse(RespClosure* closure) { + std::unique_ptr cntl_guard(closure->cntl); + std::unique_ptr response_guard(closure->resp); + if (closure->cntl->Failed()) { + LOG(DEBUG) << "RPC call failed: " << closure->cntl->ErrorText(); + // Don't stop the test immediately, just log the error and continue + } else { + g_latency_recorder << closure->cntl->latency_us(); + if (closure->resp->cpu_usage().size() > 0) { + g_server_cpu_recorder << atof(closure->resp->cpu_usage().c_str()) * 100; + } + g_total_bytes.fetch_add(closure->cntl->request_attachment().size(), butil::memory_order_relaxed); + g_total_cnt.fetch_add(1, butil::memory_order_relaxed); + } + + cntl_guard.reset(NULL); + response_guard.reset(NULL); + + if (closure->test->_iterations == 0 && FLAGS_test_iterations > 0) { + closure->test->_stop = true; + return; + } + --closure->test->_iterations; + uint64_t last = g_last_time.load(butil::memory_order_relaxed); + uint64_t now = butil::gettimeofday_us(); + if (now > last && now - last > 100000) { + if (g_last_time.exchange(now, butil::memory_order_relaxed) == last) { + g_client_cpu_recorder << + atof(bvar::Variable::describe_exposed("process_cpu_usage").c_str()) * 100; + } + } + if (now - closure->test->_start_time > FLAGS_test_seconds * 1000000u) { + closure->test->_stop = true; + return; + } + closure->test->SendRequest(); + } + + static void* RunTest(void* arg) { + PerformanceTest* test = (PerformanceTest*)arg; + test->_start_time = butil::gettimeofday_us(); + test->_iterations = FLAGS_test_iterations; + + for (int i = 0; i < FLAGS_queue_depth; ++i) { + test->SendRequest(); + } + + return NULL; + } + +private: + void* _addr; + brpc::Channel* _channel; + uint64_t _start_time; + uint32_t _iterations; + volatile bool _stop; + butil::IOBuf _attachment; + bool _echo_attachment; +}; + +static void* DeleteTest(void* arg) { + PerformanceTest* test = (PerformanceTest*)arg; + delete test; + return NULL; +} + +void Test(int thread_num, int attachment_size) { + std::cout << "[Threads: " << thread_num + << ", Depth: " << FLAGS_queue_depth + << ", Attachment: " << attachment_size << "B" + << ", UBRING: " << (FLAGS_use_ubring ? "yes" : "no") + << ", Echo: " << (FLAGS_echo_attachment ? "yes]" : "no]") + << std::endl; + g_total_bytes.store(0, butil::memory_order_relaxed); + g_total_cnt.store(0, butil::memory_order_relaxed); + std::vector tests; + for (int k = 0; k < thread_num; ++k) { + PerformanceTest* t = new PerformanceTest(attachment_size, FLAGS_echo_attachment); + if (t->Init() < 0) { + exit(1); + } + tests.push_back(t); + } + uint64_t start_time = butil::gettimeofday_us(); + bthread_t tid[thread_num]; + if (FLAGS_expected_qps > 0) { + bthread_t tid; + bthread_start_background(&tid, &BTHREAD_ATTR_NORMAL, GenerateToken, NULL); + } + for (int k = 0; k < thread_num; ++k) { + bthread_start_background(&tid[k], &BTHREAD_ATTR_NORMAL, + PerformanceTest::RunTest, tests[k]); + } + for (int k = 0; k < thread_num; ++k) { + while (!tests[k]->IsStop()) { + bthread_usleep(10000); + } + } + uint64_t end_time = butil::gettimeofday_us(); + double throughput = g_total_bytes / 1.048576 / (end_time - start_time); + if (FLAGS_test_iterations == 0) { + std::cout << "Avg-Latency: " << g_latency_recorder.latency(10) + << ", 90th-Latency: " << g_latency_recorder.latency_percentile(0.9) + << ", 99th-Latency: " << g_latency_recorder.latency_percentile(0.99) + << ", 99.9th-Latency: " << g_latency_recorder.latency_percentile(0.999) + << ", Throughput: " << throughput << "MB/s" + << ", QPS: " << (g_total_cnt.load(butil::memory_order_relaxed) * 1000 / (end_time - start_time)) << "k" + << ", Server CPU-utilization: " << g_server_cpu_recorder.latency(10) << "%" + << ", Client CPU-utilization: " << g_client_cpu_recorder.latency(10) << "%" + << std::endl; + } else { + std::cout << " Throughput: " << throughput << "MB/s" << std::endl; + } + g_stop = true; + for (int k = 0; k < thread_num; ++k) { + bthread_start_background(&tid[k], &BTHREAD_ATTR_NORMAL, DeleteTest, tests[k]); + } + for (int k = 0; k < thread_num; ++k) { + bthread_join(tid[k], NULL); + } +} + +int main(int argc, char* argv[]) { + GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); + + brpc::StartDummyServerAt(FLAGS_dummy_port); + + std::string::size_type pos1 = 0; + std::string::size_type pos2 = FLAGS_servers.find('+'); + while (pos2 != std::string::npos) { + g_servers.push_back(FLAGS_servers.substr(pos1, pos2 - pos1)); + pos1 = pos2 + 1; + pos2 = FLAGS_servers.find('+', pos1); + } + g_servers.push_back(FLAGS_servers.substr(pos1)); + + if (FLAGS_thread_num > 0 && FLAGS_attachment_size >= 0) { + Test(FLAGS_thread_num, FLAGS_attachment_size); + } else if (FLAGS_thread_num <= 0 && FLAGS_attachment_size >= 0) { + for (int i = 1; i <= FLAGS_max_thread_num; i *= 2) { + Test(i, FLAGS_attachment_size); + } + } else if (FLAGS_thread_num > 0 && FLAGS_attachment_size < 0) { + for (int i = 1; i <= 1024; i *= 4) { + Test(FLAGS_thread_num, i); + } + } else { + for (int j = 1; j <= 1024; j *= 4) { + for (int i = 1; i <= FLAGS_max_thread_num; i *= 2) { + Test(i, j); + } + } + } + + return 0; +} + +#else + +int main(int argc, char* argv[]) { + LOG(ERROR) << " brpc is not compiled with ubring. To enable it, please refer to the ubring documentation"; + return 0; +} + +#endif diff --git a/example/ubring_performance/server.cpp b/example/ubring_performance/server.cpp new file mode 100644 index 0000000000..b138c91c8d --- /dev/null +++ b/example/ubring_performance/server.cpp @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + +#include +#include "butil/atomicops.h" +#include "butil/logging.h" +#include "butil/time.h" +#include "brpc/server.h" +#include "bvar/variable.h" +#include "test.pb.h" + +#ifdef BRPC_WITH_UBRING + +DEFINE_int32(port, 8002, "TCP Port of this server"); +DEFINE_bool(use_ubring, false, "Use UBRING or not"); + +butil::atomic g_last_time(0); + +namespace test { +class PerfTestServiceImpl : public PerfTestService { +public: + PerfTestServiceImpl() {} + ~PerfTestServiceImpl() {} + + void Test(google::protobuf::RpcController* cntl_base, + const PerfTestRequest* request, + PerfTestResponse* response, + google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + uint64_t last = g_last_time.load(butil::memory_order_relaxed); + uint64_t now = butil::monotonic_time_us(); + if (now > last && now - last > 100000) { + if (g_last_time.exchange(now, butil::memory_order_relaxed) == last) { + response->set_cpu_usage(bvar::Variable::describe_exposed("process_cpu_usage")); + } else { + response->set_cpu_usage(""); + } + } else { + response->set_cpu_usage(""); + } + if (request->echo_attachment()) { + brpc::Controller* cntl = + static_cast(cntl_base); + cntl->response_attachment().append(cntl->request_attachment()); + } + } +}; +} + +int main(int argc, char* argv[]) { + GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); + + brpc::Server server; + test::PerfTestServiceImpl perf_test_service_impl; + + if (server.AddService(&perf_test_service_impl, + brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + LOG(ERROR) << "Fail to add service"; + return -1; + } + g_last_time.store(0, butil::memory_order_relaxed); + + brpc::ServerOptions options; + options.socket_mode = FLAGS_use_ubring? brpc::SOCKET_MODE_UBRING : brpc::SOCKET_MODE_TCP; + if (server.Start(FLAGS_port, &options) != 0) { + LOG(ERROR) << "Fail to start EchoServer"; + return -1; + } + + server.RunUntilAskedToQuit(); + return 0; +} + +#else + + +int main(int argc, char* argv[]) { + LOG(ERROR) << " brpc is not compiled with ubring. To enable it, please refer to the ubring documentation"; + return 0; +} + +#endif \ No newline at end of file diff --git a/example/ubring_performance/test.proto b/example/ubring_performance/test.proto new file mode 100644 index 0000000000..22646d113c --- /dev/null +++ b/example/ubring_performance/test.proto @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax="proto2"; +option cc_generic_services = true; + +package test; + +message PerfTestRequest { + required bool echo_attachment = 1; +}; + +message PerfTestResponse { + required string cpu_usage = 1; +}; + +service PerfTestService { + rpc Test(PerfTestRequest) returns (PerfTestResponse); +}; \ No newline at end of file diff --git a/package/rpm/brpc.spec b/package/rpm/brpc.spec index 4099ca2cbd..19c2d83b12 100644 --- a/package/rpm/brpc.spec +++ b/package/rpm/brpc.spec @@ -18,7 +18,7 @@ # Name: brpc -Version: 1.16.0 +Version: 1.17.0 Release: 1%{?dist} Summary: Industrial-grade RPC framework using C++ Language. diff --git a/registry/modules/libunwind/1.8.1.brpc-no-unwind/MODULE.bazel b/registry/modules/libunwind/1.8.1.brpc-no-unwind/MODULE.bazel new file mode 100644 index 0000000000..1540330956 --- /dev/null +++ b/registry/modules/libunwind/1.8.1.brpc-no-unwind/MODULE.bazel @@ -0,0 +1,21 @@ +# Forked from the Bazel Central Registry: +# https://github.com/bazelbuild/bazel-central-registry/blob/main/modules/libunwind/1.8.1/MODULE.bazel +# Distributed under the Apache License, Version 2.0. See the `registry/modules/ +# libunwind/**` entry near the bottom of brpc's LICENSE file. +# +# brpc modification (Apache License 2.0 §4(b)): +# - `version` suffixed with `.brpc-no-unwind` to distinguish this brpc +# variant from the upstream BCR version. + +module( + name = "libunwind", + version = "1.8.1.brpc-no-unwind", + bazel_compatibility = [">=7.2.1"], # need support for "overlay" directory + compatibility_level = 0, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.11") +bazel_dep(name = "rules_cc", version = "0.1.1") +bazel_dep(name = "xz", version = "5.4.5.bcr.5") +bazel_dep(name = "zlib", version = "1.3.1.bcr.5") diff --git a/registry/modules/libunwind/1.8.1.brpc-no-unwind/overlay/BUILD.bazel b/registry/modules/libunwind/1.8.1.brpc-no-unwind/overlay/BUILD.bazel new file mode 100644 index 0000000000..ec5ecd1f34 --- /dev/null +++ b/registry/modules/libunwind/1.8.1.brpc-no-unwind/overlay/BUILD.bazel @@ -0,0 +1,781 @@ +# This file is forked from the Bazel Central Registry: +# https://github.com/bazelbuild/bazel-central-registry/blob/main/modules/libunwind/1.8.1/overlay/BUILD.bazel +# +# The original file is distributed under the Apache License, Version 2.0 +# (https://www.apache.org/licenses/LICENSE-2.0). See the `registry/modules/ +# libunwind/**` entry near the bottom of brpc's LICENSE file for the full +# notice and the list of brpc's modifications. +# +# Modifications by brpc maintainers (Apache License 2.0 §4(b)): +# - Add the `hide_unwind_symbols` config_setting (gated by +# `--define=libunwind_hide_unwind_symbols=true`). +# - When the switch is on, drop `src/unwind/*.c` (the GCC `_Unwind_*` ABI +# compatibility layer) from `unwind_srcs` so the resulting libunwind +# does not export `_Unwind_*` symbols. See docs/cn/bthread_tracer.md +# for the rationale. + +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@rules_cc//cc:defs.bzl", "cc_library") + +# Targets in this file are based off configure.ac, Makefile.am, and +# src/Makefile.am with the default configurations. +# +# Only supports aarch64, arm, and x86_64 on Linux for now. + +### Config settings ############################################################ + +# Bazel does not support nested selects, so we need config settings with the +# various combinations of OS and CPU. + +# Switch to drop the GCC `_Unwind_*` ABI compatibility layer (src/unwind/*.c) +# from the build. These sources implement `_Unwind_GetLanguageSpecificData`, +# `_Unwind_ForcedUnwind`, `_Unwind_Resume`, etc. - the very symbols that +# libgcc_s also provides. When Bazel builds libunwind as a dylib (its default +# behavior in fastbuild), these symbols get exported and at runtime the dynamic +# loader resolves the libstdc++ / pthread_exit unwinding paths to libunwind's +# DWARF-based implementations instead of libgcc_s's, causing crashes such as: +# _Unwind_ForcedUnwind -> __gxx_personality_v0 -> +# __libunwind_Unwind_GetLanguageSpecificData -> SEGV in +# _ULx86_64_dwarf_find_proc_info. +# +# brpc only consumes libunwind's native `unw_*` API (provided by src/mi/), so +# omitting src/unwind/*.c is safe and is the same effect autoconf gets via +# `--version-script` in the upstream Makefile (the make-built libunwind.so does +# not export `_Unwind_*` either - see CI's init-ut-make-config action). +# +# Enable with: --define=libunwind_hide_unwind_symbols=true +config_setting( + name = "hide_unwind_symbols", + values = {"define": "libunwind_hide_unwind_symbols=true"}, +) + +selects.config_setting_group( + name = "linux_arm64", + match_all = [ + "@platforms//os:linux", + "@platforms//cpu:arm64", + ], +) + +selects.config_setting_group( + name = "linux_arm", + match_all = [ + "@platforms//os:linux", + "@platforms//cpu:arm", + ], +) + +selects.config_setting_group( + name = "linux_x86_64", + match_all = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], +) + +### Common defines ############################################################# + +libunwind_defines = [ + # Defaults based on configure.ac assuming we're on a modern Linux OS. + "_GNU_SOURCE", + "CONFIG_BLOCK_SIGNALS", + "CONFIG_WEAK_BACKTRACE", + "CONSERVATIVE_CHECKS", + "HAVE__BUILTIN___CLEAR_CACHE", + "HAVE__BUILTIN_UNREACHABLE", + "HAVE_ELF_H", + "HAVE_LINK_H", + "HAVE_LZMA", + "HAVE_ZLIB", +] + +### Common source files ######################################################## + +dwarf_srcs = glob(["src/dwarf/*.c"]) + +dwarf_textual_hdrs = glob(["src/dwarf/G*.c"]) + +mi_srcs = glob( + ["src/mi/*.c"], + exclude = [ + # Only included if Linux. + "src/mi/_ReadSLEB.c", + "src/mi/_ReadULEB.c", + # The Makefile does not include this, it's also broken as it can include + # Gdyn-remote.c in certain situations, which uses WSIZE, which will not + # be defined in the situations Gdyn-remote.c is included. + "src/mi/Ldyn-remote.c", + # TODO: for some reason these complain about duplicate definitions if + # included. + "src/mi/Gaddress_validator.c", + "src/mi/Gget_accessors.c", + ], +) + +mi_textual_hdrs = glob(["src/mi/G*.c"]) + +# When `--define=libunwind_hide_unwind_symbols=true`, drop src/unwind/ entirely +# so libunwind does not provide its own `_Unwind_*` implementations. See the +# comment on `:hide_unwind_symbols` config_setting above. +unwind_srcs = select({ + ":hide_unwind_symbols": [], + "//conditions:default": glob([ + "src/unwind/*.c", + "src/unwind/*.h", + ]), +}) + +### Features source files ###################################################### + +coredump_srcs = glob( + [ + "src/coredump/*.c", + "src/coredump/*.h", + ], + exclude = [ + "src/coredump/_UCD_access_reg_freebsd.c", + "src/coredump/_UCD_access_reg_linux.c", + "src/coredump/_UCD_access_reg_qnx.c", + "src/coredump/_UCD_get_mapinfo_generic.c", + "src/coredump/_UCD_get_mapinfo_linux.c", + "src/coredump/_UCD_get_mapinfo_qnx.c", + "src/coredump/_UCD_get_threadinfo_prstatus.c", + ], +) + select({ + "@platforms//os:linux": [ + "src/coredump/_UCD_access_reg_linux.c", + "src/coredump/_UCD_get_mapinfo_linux.c", + "src/coredump/_UCD_get_threadinfo_prstatus.c", + ], + "//conditions:default": [], +}) + +ptrace_srcs = glob([ + "src/ptrace/*.c", + "src/ptrace/*.h", +]) + +setjmp_srcs = glob([ + "src/setjmp/*.c", + "src/setjmp/*.h", +]) + select({ + "@platforms//cpu:aarch64": [ + "src/aarch64/longjmp.S", + "src/aarch64/siglongjmp.S", + ], + "@platforms//cpu:arm": [ + "src/arm/siglongjmp.S", + ], + "@platforms//cpu:x86_64": [ + "src/x86_64/longjmp.S", + "src/x86_64/siglongjmp.S", + ], +}) + +### Arch specific source files ################################################# + +arm64_srcs = select({ + "@platforms//cpu:aarch64": glob( + [ + "src/aarch64/*.c", + "src/aarch64/*.h", + ], + exclude = [ + "src/aarch64/Los-freebsd.c", + "src/aarch64/Los-linux.c", + "src/aarch64/Los-qnx.c", + "src/aarch64/Gos-freebsd.c", + "src/aarch64/Gos-linux.c", + "src/aarch64/Gos-qnx.c", + ], + ) + [ + # The Makefile doesn't include this and only includes it if the OS + # is FreeBSD. This is inconsistent with the other archs, not sure + # whether this is intended. + # "src/aarch64/setcontext.S", + "src/aarch64/getcontext.S", + "src/elf64.h", + ], + "//conditions:default": [], +}) + +arm64_textual_hdrs = select({ + "@platforms//cpu:aarch64": glob( + [ + "src/aarch64/G*.c", + ], + exclude = [ + "src/aarch64/Gos-freebsd.c", + "src/aarch64/Gos-linux.c", + "src/aarch64/Gos-qnx.c", + ], + ) + [ + "src/elf64.c", + ], + "//conditions:default": [], +}) + +arm_srcs = select({ + "@platforms//cpu:arm": glob( + [ + "src/arm/*.c", + "src/arm/*.h", + ], + exclude = [ + "src/arm/Los-freebsd.c", + "src/arm/Los-linux.c", + "src/arm/Los-other.c", + "src/arm/Gos-freebsd.c", + "src/arm/Gos-linux.c", + "src/arm/Gos-other.c", + ], + ) + [ + "src/arm/getcontext.S", + "src/elf32.h", + ], + "//conditions:default": [], +}) + +arm_textual_hdrs = select({ + "@platforms//cpu:arm": glob( + [ + "src/arm/G*.c", + ], + exclude = [ + "src/arm/Gos-freebsd.c", + "src/arm/Gos-linux.c", + "src/arm/Gos-other.c", + ], + ) + [ + "src/elf32.c", + ], + "//conditions:default": [], +}) + +x86_64_srcs = select({ + "@platforms//cpu:x86_64": glob( + [ + "src/x86_64/*.c", + "src/x86_64/*.h", + ], + exclude = [ + "src/x86_64/Los-freebsd.c", + "src/x86_64/Los-linux.c", + "src/x86_64/Los-qnx.c", + "src/x86_64/Los-solaris.c", + "src/x86_64/Gos-freebsd.c", + "src/x86_64/Gos-linux.c", + "src/x86_64/Gos-qnx.c", + "src/x86_64/Gos-solaris.c", + ], + ) + [ + "src/elf64.h", + "src/x86_64/getcontext.S", + "src/x86_64/setcontext.S", + ], + "//conditions:default": [], +}) + +x86_64_textual_hdrs = select({ + "@platforms//cpu:x86_64": glob( + [ + "src/x86_64/G*.c", + ], + exclude = [ + "src/x86_64/Gos-freebsd.c", + "src/x86_64/Gos-linux.c", + "src/x86_64/Gos-qnx.c", + "src/x86_64/Gos-solaris.c", + ], + ) + [ + "src/elf64.c", + ], + "//conditions:default": [], +}) + +### OS specific source files ################################################### + +linux_srcs = select({ + "@platforms//os:linux": [ + "src/dl-iterate-phdr.c", + "src/mi/_ReadSLEB.c", + "src/mi/_ReadULEB.c", + "src/os-linux.c", + "src/os-linux.h", + ], + "//conditions:default": [], +}) + select({ + ":linux_arm64": [ + "src/aarch64/Los-linux.c", + "src/aarch64/Gos-linux.c", + ], + "//conditions:default": [], +}) + select({ + ":linux_arm": [ + "src/arm/Los-linux.c", + "src/arm/Gos-linux.c", + ], + "//conditions:default": [], +}) + select({ + ":linux_x86_64": [ + "src/x86_64/Los-linux.c", + "src/x86_64/Gos-linux.c", + ], + "//conditions:default": [], +}) + +linux_textual_hdrs = select({ + ":linux_arm64": ["src/aarch64/Gos-linux.c"], + "//conditions:default": [], +}) + select({ + ":linux_arm": ["src/arm/Gos-linux.c"], + "//conditions:default": [], +}) + select({ + ":linux_x86_64": ["src/x86_64/Gos-linux.c"], + "//conditions:default": [], +}) + +### libunwind ################################################################## + +libunwind_srcs = [ + "src/elfxx.h", + "src/elfxx.c", +] + dwarf_srcs + mi_srcs + unwind_srcs + arm64_srcs + arm_srcs + x86_64_srcs + linux_srcs + +libunwind_textual_hdrs = [ + "src/elfxx.c", +] + dwarf_textual_hdrs + mi_textual_hdrs + arm64_textual_hdrs + arm_textual_hdrs + x86_64_textual_hdrs + linux_textual_hdrs + +expand_template( + name = "libunwind_common_h", + out = "include/libunwind-common.h", + substitutions = { + "@PKG_MAJOR@": "1", + "@PKG_MINOR@": "8", + "@PKG_EXTRA@": "1", + }, + template = "include/libunwind-common.h.in", +) + +cc_library( + name = "unwind", + srcs = libunwind_srcs, + hdrs = glob(["include/tdep/*.h"]) + [ + "include/compiler.h", + "include/dwarf.h", + "include/dwarf-eh.h", + "include/dwarf_i.h", + "include/libunwind.h", + "include/libunwind-dynamic.h", + "include/libunwind_i.h", + "include/mempool.h", + "include/remote.h", + "include/unwind.h", + ":libunwind_common_h", + ] + select({ + "@platforms//cpu:arm64": glob(["include/tdep-aarch64/*.h"]) + ["include/libunwind-aarch64.h"], + "@platforms//cpu:arm": glob(["include/tdep-arm/*.h"]) + ["include/libunwind-arm.h"], + "@platforms//cpu:x86_64": ["include/libunwind-x86_64.h"] + glob(["include/tdep-x86_64/*.h"]), + }), + includes = [ + "include", + "src", + ] + select({ + "@platforms//cpu:arm64": [ + "include/tdep-aarch64", + ], + "@platforms//cpu:arm": [ + "include/tdep-arm", + ], + "@platforms//cpu:x86_64": [ + "include/tdep-x86_64", + ], + }), + local_defines = libunwind_defines + ["HAVE_DL_ITERATE_PHDR"], + textual_hdrs = libunwind_textual_hdrs, + deps = [ + "@xz//:lzma", + "@zlib", + ], +) + +cc_library( + name = "unwind_coredump", + srcs = coredump_srcs + ["src/mi/init.c"], + hdrs = ["include/libunwind-coredump.h"], + includes = ["include"], + local_defines = libunwind_defines + [ + "HAVE_STRUCT_ELF_PRSTATUS", + "HAVE_SYS_PROCFS_H", + ] + select({ + "@platforms//cpu:arm64": [ + "CONFIG_DEBUG_FRAME", + "SIZEOF_OFF_T=8", + ], + "@platforms//cpu:arm": [ + "CONFIG_DEBUG_FRAME", + "SIZEOF_OFF_T=4", + ], + "@platforms//cpu:x86_64": [ + "SIZEOF_OFF_T=8", + ], + }), + deps = [ + ":unwind", + "@xz//:lzma", + "@zlib", + ], +) + +cc_library( + name = "unwind_ptrace", + srcs = ptrace_srcs + ["src/mi/init.c"], + hdrs = ["include/libunwind-ptrace.h"], + includes = ["include"], + local_defines = libunwind_defines + ["HAVE_TTRACE"], + deps = [ + ":unwind", + "@xz//:lzma", + "@zlib", + ], +) + +cc_library( + name = "unwind_setjmp", + srcs = setjmp_srcs + ["src/mi/init.c"], + local_defines = libunwind_defines, + deps = [ + ":unwind", + "@xz//:lzma", + "@zlib", + ], +) + +alias( + name = "libunwind", + actual = ":unwind", + visibility = ["//visibility:public"], +) + +alias( + name = "libunwind_coredump", + actual = ":unwind_coredump", + visibility = ["//visibility:public"], +) + +alias( + name = "libunwind_ptrace", + actual = ":unwind_ptrace", + visibility = ["//visibility:public"], +) + +alias( + name = "libunwind_setjmp", + actual = ":unwind_setjmp", + visibility = ["//visibility:public"], +) + +### Tests ##################################################################### + +# TODO: the following tests currently do not work: +# crasher.c +# forker.c +# test-coredump-unwind.c +# test-init-remote.c +# test-proc-info.c +# test-ptrace-misc.c +# test-ptrace.c +# test-setjmp.c + +# Some tests require these test helper files. +cc_library( + name = "unwind_test_helpers", + testonly = True, + srcs = [ + "tests/flush-cache.S", + "tests/ident.c", + ], + hdrs = glob(["tests/*G*.c"]) + [ + "tests/Gtest-init.cxx", + "tests/flush-cache.h", + "tests/ident.h", + ], + defines = ["_GNU_SOURCE"], +) + +cc_test( + name = "frame_record_test", + srcs = ["tests/aarch64-test-frame-record.c"], + target_compatible_with = ["@platforms//cpu:arm64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "plt_test", + srcs = ["tests/aarch64-test-plt.c"], + target_compatible_with = ["@platforms//cpu:arm64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "arm64_sve_signal_test", + srcs = ["tests/Larm64-test-sve-signal.c"], + target_compatible_with = ["@platforms//cpu:arm64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "perf_simple_test", + srcs = ["tests/Lperf-simple.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "perf_trace_test", + srcs = ["tests/Lperf-trace.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "rs_race_test", + srcs = ["tests/Lrs-race.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "bt_test", + srcs = ["tests/Ltest-bt.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "concurrent_test", + srcs = ["tests/Ltest-concurrent.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "cxx_exceptions_test", + srcs = ["tests/Ltest-cxx-exceptions.cxx"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +# TODO: fails on Debian 11 and Ubuntu 20.04. +# cc_test( +# name = "dyn1_test", +# srcs = ["tests/Ltest-dyn1.c"], +# deps = [ +# ":unwind", +# ":unwind_test_helpers", +# ], +# linkopts = ["-ldl"], +# ) + +cc_test( + name = "exc_test", + srcs = ["tests/Ltest-exc.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "init_local_signal_test", + srcs = [ + "tests/Ltest-init-local-signal.c", + "tests/Ltest-init-local-signal-lib.c", + ], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "init_test", + srcs = ["tests/Ltest-init.cxx"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "mem_validate_test", + srcs = ["tests/Ltest-mem-validate.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "nocalloc_test", + srcs = ["tests/Ltest-nocalloc.c"], + linkopts = ["-ldl"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "nomalloc_test", + srcs = ["tests/Ltest-nomalloc.c"], + linkopts = ["-ldl"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "resume_sig_rt_test", + srcs = ["tests/Ltest-resume-sig-rt.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "resume_sig_test", + srcs = ["tests/Ltest-resume-sig.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "trace_test", + srcs = ["tests/Ltest-trace.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "varargs_test", + srcs = ["tests/Ltest-varargs.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "dwarf_expressions_test", + srcs = [ + "tests/Lx64-test-dwarf-expressions.c", + "tests/x64-test-dwarf-expressions.S", + ], + target_compatible_with = ["@platforms//cpu:x86_64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "async_sig_test", + srcs = ["tests/test-async-sig.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "flush_cache_test", + srcs = ["tests/test-flush-cache.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "mem_test", + srcs = ["tests/test-mem.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "reg_state_test", + srcs = ["tests/test-reg-state.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "static_link_loc_test", + srcs = [ + "tests/test-static-link-gen.c", + "tests/test-static-link-loc.c", + ], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "strerror_test", + srcs = ["tests/test-strerror.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "unwind_badjmp_signal_frame_test", + srcs = ["tests/x64-unwind-badjmp-signal-frame.c"], + target_compatible_with = ["@platforms//cpu:x86_64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) diff --git a/registry/modules/libunwind/1.8.1.brpc-no-unwind/overlay/MODULE.bazel b/registry/modules/libunwind/1.8.1.brpc-no-unwind/overlay/MODULE.bazel new file mode 100644 index 0000000000..13035df3f6 --- /dev/null +++ b/registry/modules/libunwind/1.8.1.brpc-no-unwind/overlay/MODULE.bazel @@ -0,0 +1,21 @@ +# Forked from the Bazel Central Registry: +# https://github.com/bazelbuild/bazel-central-registry/blob/main/modules/libunwind/1.8.1/overlay/MODULE.bazel +# Distributed under the Apache License, Version 2.0. See the `registry/modules/ +# libunwind/**` entry near the bottom of brpc's LICENSE file. +# +# brpc modification (Apache License 2.0 §4(b)): +# - `version` suffixed with `.brpc-no-unwind` to distinguish this brpc +# variant from the upstream BCR version. + +module( + name = "libunwind", + version = "1.8.1.brpc-no-unwind", + bazel_compatibility = [">=7.2.1"], # need support for "overlay" directory + compatibility_level = 0, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.11") +bazel_dep(name = "rules_cc", version = "0.1.1") +bazel_dep(name = "xz", version = "5.4.5.bcr.5") +bazel_dep(name = "zlib", version = "1.3.1.bcr.5") diff --git a/registry/modules/libunwind/1.8.1.brpc-no-unwind/presubmit.yml b/registry/modules/libunwind/1.8.1.brpc-no-unwind/presubmit.yml new file mode 100644 index 0000000000..3a0544118e --- /dev/null +++ b/registry/modules/libunwind/1.8.1.brpc-no-unwind/presubmit.yml @@ -0,0 +1,18 @@ +matrix: + platform: + - debian11 + - ubuntu2004 + - ubuntu2204 + - ubuntu2404 + bazel: [7.x, 8.x, rolling] +tasks: + verify_targets: + platform: ${{ platform }} + bazel: ${{ bazel }} + build_targets: + - "@libunwind//:libunwind" + - "@libunwind//:libunwind_coredump" + - "@libunwind//:libunwind_ptrace" + - "@libunwind//:libunwind_setjmp" + test_targets: + - "@libunwind//..." diff --git a/registry/modules/libunwind/1.8.1.brpc-no-unwind/source.json b/registry/modules/libunwind/1.8.1.brpc-no-unwind/source.json new file mode 100644 index 0000000000..54bbb9f543 --- /dev/null +++ b/registry/modules/libunwind/1.8.1.brpc-no-unwind/source.json @@ -0,0 +1,9 @@ +{ + "url": "https://github.com/libunwind/libunwind/releases/download/v1.8.1/libunwind-1.8.1.tar.gz", + "integrity": "sha256-3fDjLdX6/lKDGY035L+d7Pe6F3C25+AGwz5t955qYVc=", + "strip_prefix": "libunwind-1.8.1", + "overlay": { + "BUILD.bazel": "sha256-X2B6q+gpb5U0BGoD8LJPoBPCpyIvT5QFSwsjFxZ9R+0=", + "MODULE.bazel": "sha256-Yu1CCsCC+A2NXfQ3bgzpbtQcYhM+bPPb2YAPKPojdIY=" + } +} diff --git a/registry/modules/libunwind/1.8.3.brpc-no-unwind/MODULE.bazel b/registry/modules/libunwind/1.8.3.brpc-no-unwind/MODULE.bazel new file mode 100644 index 0000000000..c85228a1d6 --- /dev/null +++ b/registry/modules/libunwind/1.8.3.brpc-no-unwind/MODULE.bazel @@ -0,0 +1,21 @@ +# Forked from the Bazel Central Registry: +# https://github.com/bazelbuild/bazel-central-registry/blob/main/modules/libunwind/1.8.3/MODULE.bazel +# Distributed under the Apache License, Version 2.0. See the `registry/modules/ +# libunwind/**` entry near the bottom of brpc's LICENSE file. +# +# brpc modification (Apache License 2.0 §4(b)): +# - `version` suffixed with `.brpc-no-unwind` to distinguish this brpc +# variant from the upstream BCR version. + +module( + name = "libunwind", + version = "1.8.3.brpc-no-unwind", + bazel_compatibility = [">=7.2.1"], # need support for "overlay" directory + compatibility_level = 0, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.11") +bazel_dep(name = "rules_cc", version = "0.1.1") +bazel_dep(name = "xz", version = "5.4.5.bcr.8") +bazel_dep(name = "zlib", version = "1.3.1.bcr.8") diff --git a/registry/modules/libunwind/1.8.3.brpc-no-unwind/overlay/BUILD.bazel b/registry/modules/libunwind/1.8.3.brpc-no-unwind/overlay/BUILD.bazel new file mode 100644 index 0000000000..f8f96c2206 --- /dev/null +++ b/registry/modules/libunwind/1.8.3.brpc-no-unwind/overlay/BUILD.bazel @@ -0,0 +1,781 @@ +# This file is forked from the Bazel Central Registry: +# https://github.com/bazelbuild/bazel-central-registry/blob/main/modules/libunwind/1.8.3/overlay/BUILD.bazel +# +# The original file is distributed under the Apache License, Version 2.0 +# (https://www.apache.org/licenses/LICENSE-2.0). See the `registry/modules/ +# libunwind/**` entry near the bottom of brpc's LICENSE file for the full +# notice and the list of brpc's modifications. +# +# Modifications by brpc maintainers (Apache License 2.0 §4(b)): +# - Add the `hide_unwind_symbols` config_setting (gated by +# `--define=libunwind_hide_unwind_symbols=true`). +# - When the switch is on, drop `src/unwind/*.c` (the GCC `_Unwind_*` ABI +# compatibility layer) from `unwind_srcs` so the resulting libunwind +# does not export `_Unwind_*` symbols. See docs/cn/bthread_tracer.md +# for the rationale. + +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +# Targets in this file are based off configure.ac, Makefile.am, and +# src/Makefile.am with the default configurations. +# +# Only supports aarch64, arm, and x86_64 on Linux for now. + +### Config settings ############################################################ + +# Bazel does not support nested selects, so we need config settings with the +# various combinations of OS and CPU. + +# Switch to drop the GCC `_Unwind_*` ABI compatibility layer (src/unwind/*.c) +# from the build. These sources implement `_Unwind_GetLanguageSpecificData`, +# `_Unwind_ForcedUnwind`, `_Unwind_Resume`, etc. - the very symbols that +# libgcc_s also provides. When Bazel builds libunwind as a dylib (its default +# behavior in fastbuild), these symbols get exported and at runtime the dynamic +# loader resolves the libstdc++ / pthread_exit unwinding paths to libunwind's +# DWARF-based implementations instead of libgcc_s's, causing crashes such as: +# _Unwind_ForcedUnwind -> __gxx_personality_v0 -> +# __libunwind_Unwind_GetLanguageSpecificData -> SEGV in +# _ULx86_64_dwarf_find_proc_info. +# +# brpc only consumes libunwind's native `unw_*` API (provided by src/mi/), so +# omitting src/unwind/*.c is safe and is the same effect autoconf gets via +# `--version-script` in the upstream Makefile (the make-built libunwind.so does +# not export `_Unwind_*` either - see CI's init-ut-make-config action). +# +# Enable with: --define=libunwind_hide_unwind_symbols=true +config_setting( + name = "hide_unwind_symbols", + values = {"define": "libunwind_hide_unwind_symbols=true"}, +) + +selects.config_setting_group( + name = "linux_arm64", + match_all = [ + "@platforms//os:linux", + "@platforms//cpu:arm64", + ], +) + +selects.config_setting_group( + name = "linux_arm", + match_all = [ + "@platforms//os:linux", + "@platforms//cpu:arm", + ], +) + +selects.config_setting_group( + name = "linux_x86_64", + match_all = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + ], +) + +### Common defines ############################################################# + +libunwind_defines = [ + # Defaults based on configure.ac assuming we're on a modern Linux OS. + "_GNU_SOURCE", + "CONFIG_BLOCK_SIGNALS", + "CONFIG_WEAK_BACKTRACE", + "CONSERVATIVE_CHECKS", + "HAVE__BUILTIN___CLEAR_CACHE", + "HAVE__BUILTIN_UNREACHABLE", + "HAVE_ELF_H", + "HAVE_LINK_H", + "HAVE_LZMA", + "HAVE_ZLIB", +] + +### Common source files ######################################################## + +dwarf_srcs = glob(["src/dwarf/*.c"]) + +dwarf_textual_hdrs = glob(["src/dwarf/G*.c"]) + +mi_srcs = glob( + ["src/mi/*.c"], + exclude = [ + # Only included if Linux. + "src/mi/_ReadSLEB.c", + "src/mi/_ReadULEB.c", + # The Makefile does not include this, it's also broken as it can include + # Gdyn-remote.c in certain situations, which uses WSIZE, which will not + # be defined in the situations Gdyn-remote.c is included. + "src/mi/Ldyn-remote.c", + # TODO: for some reason these complain about duplicate definitions if + # included. + "src/mi/Gaddress_validator.c", + "src/mi/Gget_accessors.c", + ], +) + +mi_textual_hdrs = glob(["src/mi/G*.c"]) + +# When `--define=libunwind_hide_unwind_symbols=true`, drop src/unwind/ entirely +# so libunwind does not provide its own `_Unwind_*` implementations. See the +# comment on `:hide_unwind_symbols` config_setting above. +unwind_srcs = select({ + ":hide_unwind_symbols": [], + "//conditions:default": glob([ + "src/unwind/*.c", + "src/unwind/*.h", + ]), +}) + +### Features source files ###################################################### + +coredump_srcs = glob( + [ + "src/coredump/*.c", + "src/coredump/*.h", + ], + exclude = [ + "src/coredump/_UCD_access_reg_freebsd.c", + "src/coredump/_UCD_access_reg_linux.c", + "src/coredump/_UCD_access_reg_qnx.c", + "src/coredump/_UCD_get_mapinfo_generic.c", + "src/coredump/_UCD_get_mapinfo_linux.c", + "src/coredump/_UCD_get_mapinfo_qnx.c", + "src/coredump/_UCD_get_threadinfo_prstatus.c", + ], +) + select({ + "@platforms//os:linux": [ + "src/coredump/_UCD_access_reg_linux.c", + "src/coredump/_UCD_get_mapinfo_linux.c", + "src/coredump/_UCD_get_threadinfo_prstatus.c", + ], + "//conditions:default": [], +}) + +ptrace_srcs = glob([ + "src/ptrace/*.c", + "src/ptrace/*.h", +]) + +setjmp_srcs = glob([ + "src/setjmp/*.c", + "src/setjmp/*.h", +]) + select({ + "@platforms//cpu:aarch64": [ + "src/aarch64/longjmp.S", + "src/aarch64/siglongjmp.S", + ], + "@platforms//cpu:arm": [ + "src/arm/siglongjmp.S", + ], + "@platforms//cpu:x86_64": [ + "src/x86_64/longjmp.S", + "src/x86_64/siglongjmp.S", + ], +}) + +### Arch specific source files ################################################# + +arm64_srcs = select({ + "@platforms//cpu:aarch64": glob( + [ + "src/aarch64/*.c", + "src/aarch64/*.h", + ], + exclude = [ + "src/aarch64/Los-freebsd.c", + "src/aarch64/Los-linux.c", + "src/aarch64/Los-qnx.c", + "src/aarch64/Gos-freebsd.c", + "src/aarch64/Gos-linux.c", + "src/aarch64/Gos-qnx.c", + ], + ) + [ + # The Makefile doesn't include this and only includes it if the OS + # is FreeBSD. This is inconsistent with the other archs, not sure + # whether this is intended. + # "src/aarch64/setcontext.S", + "src/aarch64/getcontext.S", + "src/elf64.h", + ], + "//conditions:default": [], +}) + +arm64_textual_hdrs = select({ + "@platforms//cpu:aarch64": glob( + [ + "src/aarch64/G*.c", + ], + exclude = [ + "src/aarch64/Gos-freebsd.c", + "src/aarch64/Gos-linux.c", + "src/aarch64/Gos-qnx.c", + ], + ) + [ + "src/elf64.c", + ], + "//conditions:default": [], +}) + +arm_srcs = select({ + "@platforms//cpu:arm": glob( + [ + "src/arm/*.c", + "src/arm/*.h", + ], + exclude = [ + "src/arm/Los-freebsd.c", + "src/arm/Los-linux.c", + "src/arm/Los-other.c", + "src/arm/Gos-freebsd.c", + "src/arm/Gos-linux.c", + "src/arm/Gos-other.c", + ], + ) + [ + "src/arm/getcontext.S", + "src/elf32.h", + ], + "//conditions:default": [], +}) + +arm_textual_hdrs = select({ + "@platforms//cpu:arm": glob( + [ + "src/arm/G*.c", + ], + exclude = [ + "src/arm/Gos-freebsd.c", + "src/arm/Gos-linux.c", + "src/arm/Gos-other.c", + ], + ) + [ + "src/elf32.c", + ], + "//conditions:default": [], +}) + +x86_64_srcs = select({ + "@platforms//cpu:x86_64": glob( + [ + "src/x86_64/*.c", + "src/x86_64/*.h", + ], + exclude = [ + "src/x86_64/Los-freebsd.c", + "src/x86_64/Los-linux.c", + "src/x86_64/Los-qnx.c", + "src/x86_64/Los-solaris.c", + "src/x86_64/Gos-freebsd.c", + "src/x86_64/Gos-linux.c", + "src/x86_64/Gos-qnx.c", + "src/x86_64/Gos-solaris.c", + ], + ) + [ + "src/elf64.h", + "src/x86_64/getcontext.S", + "src/x86_64/setcontext.S", + ], + "//conditions:default": [], +}) + +x86_64_textual_hdrs = select({ + "@platforms//cpu:x86_64": glob( + [ + "src/x86_64/G*.c", + ], + exclude = [ + "src/x86_64/Gos-freebsd.c", + "src/x86_64/Gos-linux.c", + "src/x86_64/Gos-qnx.c", + "src/x86_64/Gos-solaris.c", + ], + ) + [ + "src/elf64.c", + ], + "//conditions:default": [], +}) + +### OS specific source files ################################################### + +linux_srcs = select({ + "@platforms//os:linux": [ + "src/dl-iterate-phdr.c", + "src/mi/_ReadSLEB.c", + "src/mi/_ReadULEB.c", + "src/os-linux.c", + "src/os-linux.h", + ], + "//conditions:default": [], +}) + select({ + ":linux_arm64": [ + "src/aarch64/Los-linux.c", + "src/aarch64/Gos-linux.c", + ], + "//conditions:default": [], +}) + select({ + ":linux_arm": [ + "src/arm/Los-linux.c", + "src/arm/Gos-linux.c", + ], + "//conditions:default": [], +}) + select({ + ":linux_x86_64": [ + "src/x86_64/Los-linux.c", + "src/x86_64/Gos-linux.c", + ], + "//conditions:default": [], +}) + +linux_textual_hdrs = select({ + ":linux_arm64": ["src/aarch64/Gos-linux.c"], + "//conditions:default": [], +}) + select({ + ":linux_arm": ["src/arm/Gos-linux.c"], + "//conditions:default": [], +}) + select({ + ":linux_x86_64": ["src/x86_64/Gos-linux.c"], + "//conditions:default": [], +}) + +### libunwind ################################################################## + +libunwind_srcs = [ + "src/elfxx.h", + "src/elfxx.c", +] + dwarf_srcs + mi_srcs + unwind_srcs + arm64_srcs + arm_srcs + x86_64_srcs + linux_srcs + +libunwind_textual_hdrs = [ + "src/elfxx.c", +] + dwarf_textual_hdrs + mi_textual_hdrs + arm64_textual_hdrs + arm_textual_hdrs + x86_64_textual_hdrs + linux_textual_hdrs + +expand_template( + name = "libunwind_common_h", + out = "include/libunwind-common.h", + substitutions = { + "@PKG_MAJOR@": "1", + "@PKG_MINOR@": "8", + "@PKG_EXTRA@": "1", + }, + template = "include/libunwind-common.h.in", +) + +cc_library( + name = "unwind", + srcs = libunwind_srcs, + hdrs = glob(["include/tdep/*.h"]) + [ + "include/compiler.h", + "include/dwarf.h", + "include/dwarf-eh.h", + "include/dwarf_i.h", + "include/libunwind.h", + "include/libunwind-dynamic.h", + "include/libunwind_i.h", + "include/mempool.h", + "include/remote.h", + "include/unwind.h", + ":libunwind_common_h", + ] + select({ + "@platforms//cpu:arm64": glob(["include/tdep-aarch64/*.h"]) + ["include/libunwind-aarch64.h"], + "@platforms//cpu:arm": glob(["include/tdep-arm/*.h"]) + ["include/libunwind-arm.h"], + "@platforms//cpu:x86_64": ["include/libunwind-x86_64.h"] + glob(["include/tdep-x86_64/*.h"]), + }), + includes = [ + "include", + "src", + ] + select({ + "@platforms//cpu:arm64": [ + "include/tdep-aarch64", + ], + "@platforms//cpu:arm": [ + "include/tdep-arm", + ], + "@platforms//cpu:x86_64": [ + "include/tdep-x86_64", + ], + }), + local_defines = libunwind_defines + ["HAVE_DL_ITERATE_PHDR"], + textual_hdrs = libunwind_textual_hdrs, + deps = [ + "@xz//:lzma", + "@zlib", + ], +) + +cc_library( + name = "unwind_coredump", + srcs = coredump_srcs + ["src/mi/init.c"], + hdrs = ["include/libunwind-coredump.h"], + includes = ["include"], + local_defines = libunwind_defines + [ + "HAVE_STRUCT_ELF_PRSTATUS", + "HAVE_SYS_PROCFS_H", + ] + select({ + "@platforms//cpu:arm64": [ + "CONFIG_DEBUG_FRAME", + "SIZEOF_OFF_T=8", + ], + "@platforms//cpu:arm": [ + "CONFIG_DEBUG_FRAME", + "SIZEOF_OFF_T=4", + ], + "@platforms//cpu:x86_64": [ + "SIZEOF_OFF_T=8", + ], + }), + deps = [ + ":unwind", + "@xz//:lzma", + "@zlib", + ], +) + +cc_library( + name = "unwind_ptrace", + srcs = ptrace_srcs + ["src/mi/init.c"], + hdrs = ["include/libunwind-ptrace.h"], + includes = ["include"], + local_defines = libunwind_defines + ["HAVE_TTRACE"], + deps = [ + ":unwind", + "@xz//:lzma", + "@zlib", + ], +) + +cc_library( + name = "unwind_setjmp", + srcs = setjmp_srcs + ["src/mi/init.c"], + local_defines = libunwind_defines, + deps = [ + ":unwind", + "@xz//:lzma", + "@zlib", + ], +) + +alias( + name = "libunwind", + actual = ":unwind", + visibility = ["//visibility:public"], +) + +alias( + name = "libunwind_coredump", + actual = ":unwind_coredump", + visibility = ["//visibility:public"], +) + +alias( + name = "libunwind_ptrace", + actual = ":unwind_ptrace", + visibility = ["//visibility:public"], +) + +alias( + name = "libunwind_setjmp", + actual = ":unwind_setjmp", + visibility = ["//visibility:public"], +) + +### Tests ##################################################################### + +# TODO: the following tests currently do not work: +# crasher.c +# forker.c +# test-coredump-unwind.c +# test-init-remote.c +# test-proc-info.c +# test-ptrace-misc.c +# test-ptrace.c +# test-setjmp.c + +# Some tests require these test helper files. +cc_library( + name = "unwind_test_helpers", + testonly = True, + srcs = [ + "tests/flush-cache.S", + "tests/ident.c", + ], + hdrs = glob(["tests/*G*.c"]) + [ + "tests/Gtest-init.cxx", + "tests/flush-cache.h", + "tests/ident.h", + ], + defines = ["_GNU_SOURCE"], +) + +cc_test( + name = "frame_record_test", + srcs = ["tests/aarch64-test-frame-record.c"], + target_compatible_with = ["@platforms//cpu:arm64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "plt_test", + srcs = ["tests/aarch64-test-plt.c"], + target_compatible_with = ["@platforms//cpu:arm64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "arm64_sve_signal_test", + srcs = ["tests/Larm64-test-sve-signal.c"], + target_compatible_with = ["@platforms//cpu:arm64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "perf_simple_test", + srcs = ["tests/Lperf-simple.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "perf_trace_test", + srcs = ["tests/Lperf-trace.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "rs_race_test", + srcs = ["tests/Lrs-race.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "bt_test", + srcs = ["tests/Ltest-bt.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "concurrent_test", + srcs = ["tests/Ltest-concurrent.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "cxx_exceptions_test", + srcs = ["tests/Ltest-cxx-exceptions.cxx"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +# TODO: fails on Debian 11 and Ubuntu 20.04. +# cc_test( +# name = "dyn1_test", +# srcs = ["tests/Ltest-dyn1.c"], +# deps = [ +# ":unwind", +# ":unwind_test_helpers", +# ], +# linkopts = ["-ldl"], +# ) + +cc_test( + name = "exc_test", + srcs = ["tests/Ltest-exc.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "init_local_signal_test", + srcs = [ + "tests/Ltest-init-local-signal.c", + "tests/Ltest-init-local-signal-lib.c", + ], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "init_test", + srcs = ["tests/Ltest-init.cxx"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "mem_validate_test", + srcs = ["tests/Ltest-mem-validate.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "nocalloc_test", + srcs = ["tests/Ltest-nocalloc.c"], + linkopts = ["-ldl"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "nomalloc_test", + srcs = ["tests/Ltest-nomalloc.c"], + linkopts = ["-ldl"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "resume_sig_rt_test", + srcs = ["tests/Ltest-resume-sig-rt.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "resume_sig_test", + srcs = ["tests/Ltest-resume-sig.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "trace_test", + srcs = ["tests/Ltest-trace.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "varargs_test", + srcs = ["tests/Ltest-varargs.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "dwarf_expressions_test", + srcs = [ + "tests/Lx64-test-dwarf-expressions.c", + "tests/x64-test-dwarf-expressions.S", + ], + target_compatible_with = ["@platforms//cpu:x86_64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "async_sig_test", + srcs = ["tests/test-async-sig.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "flush_cache_test", + srcs = ["tests/test-flush-cache.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "mem_test", + srcs = ["tests/test-mem.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "reg_state_test", + srcs = ["tests/test-reg-state.c"], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "static_link_loc_test", + srcs = [ + "tests/test-static-link-gen.c", + "tests/test-static-link-loc.c", + ], + local_defines = ["UNW_LOCAL_ONLY"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "strerror_test", + srcs = ["tests/test-strerror.c"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) + +cc_test( + name = "unwind_badjmp_signal_frame_test", + srcs = ["tests/x64-unwind-badjmp-signal-frame.c"], + target_compatible_with = ["@platforms//cpu:x86_64"], + deps = [ + ":unwind", + ":unwind_test_helpers", + ], +) diff --git a/registry/modules/libunwind/1.8.3.brpc-no-unwind/overlay/MODULE.bazel b/registry/modules/libunwind/1.8.3.brpc-no-unwind/overlay/MODULE.bazel new file mode 100644 index 0000000000..948094ed2b --- /dev/null +++ b/registry/modules/libunwind/1.8.3.brpc-no-unwind/overlay/MODULE.bazel @@ -0,0 +1,21 @@ +# Forked from the Bazel Central Registry: +# https://github.com/bazelbuild/bazel-central-registry/blob/main/modules/libunwind/1.8.3/overlay/MODULE.bazel +# Distributed under the Apache License, Version 2.0. See the `registry/modules/ +# libunwind/**` entry near the bottom of brpc's LICENSE file. +# +# brpc modification (Apache License 2.0 §4(b)): +# - `version` suffixed with `.brpc-no-unwind` to distinguish this brpc +# variant from the upstream BCR version. + +module( + name = "libunwind", + version = "1.8.3.brpc-no-unwind", + bazel_compatibility = [">=7.2.1"], # need support for "overlay" directory + compatibility_level = 0, +) + +bazel_dep(name = "bazel_skylib", version = "1.7.1") +bazel_dep(name = "platforms", version = "0.0.11") +bazel_dep(name = "rules_cc", version = "0.1.1") +bazel_dep(name = "xz", version = "5.4.5.bcr.8") +bazel_dep(name = "zlib", version = "1.3.1.bcr.8") diff --git a/registry/modules/libunwind/1.8.3.brpc-no-unwind/presubmit.yml b/registry/modules/libunwind/1.8.3.brpc-no-unwind/presubmit.yml new file mode 100644 index 0000000000..3a0544118e --- /dev/null +++ b/registry/modules/libunwind/1.8.3.brpc-no-unwind/presubmit.yml @@ -0,0 +1,18 @@ +matrix: + platform: + - debian11 + - ubuntu2004 + - ubuntu2204 + - ubuntu2404 + bazel: [7.x, 8.x, rolling] +tasks: + verify_targets: + platform: ${{ platform }} + bazel: ${{ bazel }} + build_targets: + - "@libunwind//:libunwind" + - "@libunwind//:libunwind_coredump" + - "@libunwind//:libunwind_ptrace" + - "@libunwind//:libunwind_setjmp" + test_targets: + - "@libunwind//..." diff --git a/registry/modules/libunwind/1.8.3.brpc-no-unwind/source.json b/registry/modules/libunwind/1.8.3.brpc-no-unwind/source.json new file mode 100644 index 0000000000..aa6ba90503 --- /dev/null +++ b/registry/modules/libunwind/1.8.3.brpc-no-unwind/source.json @@ -0,0 +1,9 @@ +{ + "url": "https://github.com/libunwind/libunwind/releases/download/v1.8.3/libunwind-1.8.3.tar.gz", + "strip_prefix": "libunwind-1.8.3", + "overlay": { + "BUILD.bazel": "sha256-ozAABv0I3+tacve4z2m+lAh/lhJWh7t7L3unh8FSfp8=", + "MODULE.bazel": "sha256-qimDkl8soxNM/JuqndzGOOcn5ZwGvo2J6CSsNS2/IQs=" + }, + "integrity": "sha256-vjDZEOZ/WNgudTIx8TV/MmoaCIrPEmsh/3fmCqsZuQs=" +} diff --git a/registry/modules/libunwind/metadata.json b/registry/modules/libunwind/metadata.json new file mode 100644 index 0000000000..a25cccf6d4 --- /dev/null +++ b/registry/modules/libunwind/metadata.json @@ -0,0 +1,19 @@ +{ + "homepage": "https://www.nongnu.org/libunwind/", + "maintainers": [ + { + "email": "vtsao@openai.com", + "github": "vtsao-openai", + "github_user_id": 176426301, + "name": "Vincent Tsao" + } + ], + "repository": [ + "github:libunwind/libunwind" + ], + "versions": [ + "1.8.1.brpc-no-unwind", + "1.8.3.brpc-no-unwind" + ], + "yanked_versions": {} +} diff --git a/src/brpc/adaptive_max_concurrency.cpp b/src/brpc/adaptive_max_concurrency.cpp index d1be273032..6a5977bcf4 100644 --- a/src/brpc/adaptive_max_concurrency.cpp +++ b/src/brpc/adaptive_max_concurrency.cpp @@ -25,18 +25,15 @@ namespace brpc { -const std::string AdaptiveMaxConcurrency::UNLIMITED = "unlimited"; -const std::string AdaptiveMaxConcurrency::CONSTANT = "constant"; - AdaptiveMaxConcurrency::AdaptiveMaxConcurrency() - : _value(UNLIMITED) + : _value(UNLIMITED()) , _max_concurrency(0) { } AdaptiveMaxConcurrency::AdaptiveMaxConcurrency(int max_concurrency) : _max_concurrency(0) { if (max_concurrency <= 0) { - _value = UNLIMITED; + _value = UNLIMITED(); _max_concurrency = 0; } else { _value = butil::string_printf("%d", max_concurrency); @@ -83,7 +80,7 @@ void AdaptiveMaxConcurrency::operator=(const butil::StringPiece& value) { void AdaptiveMaxConcurrency::operator=(int max_concurrency) { if (max_concurrency <= 0) { - _value = UNLIMITED; + _value = UNLIMITED(); _max_concurrency = 0; } else { _value = butil::string_printf("%d", max_concurrency); @@ -105,9 +102,9 @@ void AdaptiveMaxConcurrency::operator=(const TimeoutConcurrencyConf& value) { const std::string& AdaptiveMaxConcurrency::type() const { if (_max_concurrency > 0) { - return CONSTANT; + return CONSTANT(); } else if (_max_concurrency == 0) { - return UNLIMITED; + return UNLIMITED(); } else { return _value; } diff --git a/src/brpc/adaptive_max_concurrency.h b/src/brpc/adaptive_max_concurrency.h index c90def6106..2bdcd90723 100644 --- a/src/brpc/adaptive_max_concurrency.h +++ b/src/brpc/adaptive_max_concurrency.h @@ -65,9 +65,17 @@ class AdaptiveMaxConcurrency{ // "unlimited", "constant" or "user-defined" const std::string& type() const; - // Get strings filled with "unlimited" and "constant" - static const std::string UNLIMITED;// = "unlimited"; - static const std::string CONSTANT;// = "constant"; + // Use Meyers' Singleton to avoid static initialization order fiasco: + // global AdaptiveMaxConcurrency objects in other translation units may + // depend on these strings during their construction. + static const std::string& UNLIMITED() { + static const std::string s_unlimited = "unlimited"; + return s_unlimited; + } + static const std::string& CONSTANT() { + static const std::string s_constant = "constant"; + return s_constant; + } void SetConcurrencyLimiter(ConcurrencyLimiter* cl) { _cl = cl; } diff --git a/src/brpc/input_messenger.cpp b/src/brpc/input_messenger.cpp index c249cca22c..fa05423640 100644 --- a/src/brpc/input_messenger.cpp +++ b/src/brpc/input_messenger.cpp @@ -312,7 +312,7 @@ int InputMessenger::ProcessNewMessage( // not in the bthread where the polling bthread is located, because the // method for processing messages may call synchronization primitives, // causing the polling bthread to be scheduled out. - if (m->_socket_mode == SOCKET_MODE_RDMA) { + if (m->_socket_mode == SOCKET_MODE_RDMA || m->_socket_mode == SOCKET_MODE_UBRING) { m->_transport->QueueMessage(last_msg, &num_bthread_created, true); } if (num_bthread_created) { diff --git a/src/brpc/input_messenger.h b/src/brpc/input_messenger.h index 8482c3f3fc..5203c02505 100644 --- a/src/brpc/input_messenger.h +++ b/src/brpc/input_messenger.h @@ -29,6 +29,9 @@ namespace brpc { namespace rdma { class RdmaEndpoint; } +namespace ubring { +class UBShmEndpoint; +} class TcpTransport; struct InputMessageHandler { // The callback to cut a message from `source'. @@ -93,6 +96,7 @@ class InputMessenger : public SocketUser { friend class Socket; friend class TcpTransport; friend class rdma::RdmaEndpoint; +friend class ubring::UBShmEndpoint; public: explicit InputMessenger(size_t capacity = 128); ~InputMessenger(); diff --git a/src/brpc/memcache.cpp b/src/brpc/memcache.cpp index 489d84db16..c7d6f836b1 100644 --- a/src/brpc/memcache.cpp +++ b/src/brpc/memcache.cpp @@ -17,6 +17,7 @@ #include "brpc/memcache.h" +#include #include "brpc/policy/memcache_binary_header.h" #include "brpc/proto_base.pb.h" #include "butil/logging.h" diff --git a/src/brpc/nonreflectable_message.h b/src/brpc/nonreflectable_message.h index 7f2acd78a3..089b23957d 100644 --- a/src/brpc/nonreflectable_message.h +++ b/src/brpc/nonreflectable_message.h @@ -129,7 +129,7 @@ class NonreflectableMessage : public ::google::protobuf::Message { void DiscardUnknownFields() override {} #endif -#if GOOGLE_PROTOBUF_VERSION < 5026000 +#if GOOGLE_PROTOBUF_VERSION >= 3004000 && GOOGLE_PROTOBUF_VERSION < 5026000 // Unsupported by default. size_t SpaceUsedLong() const override { return 0; @@ -163,9 +163,19 @@ class NonreflectableMessage : public ::google::protobuf::Message { #endif // Size of bytes after serialization. +#if GOOGLE_PROTOBUF_VERSION < 3004000 + virtual size_t ByteSizeLong() const { + return 0; + } + + int ByteSize() const override { + return static_cast(ByteSizeLong()); + } +#else size_t ByteSizeLong() const override { return 0; } +#endif #if GOOGLE_PROTOBUF_VERSION >= 3007000 && GOOGLE_PROTOBUF_VERSION < 3010000 void SerializeWithCachedSizes(::google::protobuf::io::CodedOutputStream*) const override {} diff --git a/src/brpc/policy/constant_concurrency_limiter.cpp b/src/brpc/policy/constant_concurrency_limiter.cpp index 425c5f34c3..7d73e2ec8c 100644 --- a/src/brpc/policy/constant_concurrency_limiter.cpp +++ b/src/brpc/policy/constant_concurrency_limiter.cpp @@ -43,7 +43,7 @@ int ConstantConcurrencyLimiter::ResetMaxConcurrency( ConstantConcurrencyLimiter* ConstantConcurrencyLimiter::New(const AdaptiveMaxConcurrency& amc) const { - CHECK_EQ(amc.type(), AdaptiveMaxConcurrency::CONSTANT); + CHECK_EQ(amc.type(), AdaptiveMaxConcurrency::CONSTANT()); return new ConstantConcurrencyLimiter(static_cast(amc)); } diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp index b03a961b52..e5e22d924a 100644 --- a/src/brpc/policy/http_rpc_protocol.cpp +++ b/src/brpc/policy/http_rpc_protocol.cpp @@ -1387,10 +1387,12 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { http_request->header().uri().path(), server, NULL); if (mp != NULL && mp->is_builtin_service && mp->service->GetDescriptor() != BadMethodService::descriptor()) { - // BuiltinService doesn't need authentication - // TODO: Fix backdoor that sends BuiltinService at first - // and then sends other requests without authentication - return true; + // Builtin services on internal_port doesn't need authentication + // Builtin services on the public listener must pass authentication + if (server->options().internal_port >= 0 && + socket->local_side().port == server->options().internal_port) { + return true; + } } const std::string *authorization diff --git a/src/brpc/rdma/block_pool.cpp b/src/brpc/rdma/block_pool.cpp index 24907a194a..36c763ec13 100644 --- a/src/brpc/rdma/block_pool.cpp +++ b/src/brpc/rdma/block_pool.cpp @@ -41,6 +41,8 @@ DEFINE_int32(rdma_memory_pool_tls_cache_num, 128, "Number of cached block in tls DEFINE_bool(rdma_memory_pool_user_specified_memory, false, "If true, the user must call UserExtendBlockPool() to extend " "memory. bRPC will not handle memory extension."); +DEFINE_string(rdma_recv_block_type, "default", "Default size type for recv WR: " + "default(8KB - 32B)/large(64KB - 32B)/huge(2MB - 32B)"); static RegisterCallback g_cb = NULL; @@ -48,8 +50,8 @@ static RegisterCallback g_cb = NULL; static const size_t BYTES_IN_MB = 1048576; static const int BLOCK_DEFAULT = 0; // 8KB -// static const int BLOCK_LARGE = 1; // 64KB -// static const int BLOCK_HUGE = 2; // 2MB +static const int BLOCK_LARGE = 1; // 64KB +static const int BLOCK_HUGE = 2; // 2MB static const int BLOCK_SIZE_COUNT = 3; static size_t g_block_size[BLOCK_SIZE_COUNT] = { 8192, 65536, 2 * BYTES_IN_MB }; @@ -183,7 +185,7 @@ static void* ExtendBlockPoolImpl(void* region_base, size_t region_size, int bloc // Extend the block pool with a new region (with different region ID) static void* ExtendBlockPool(size_t region_size, int block_type) { - if (region_size < 1) { + if (region_size < 1 || block_type < 0) { errno = EINVAL; return NULL; } @@ -325,7 +327,7 @@ bool InitBlockPool(RegisterCallback cb) { } if (ExtendBlockPool(FLAGS_rdma_memory_pool_initial_size_mb, - BLOCK_DEFAULT) != NULL) { + GetRdmaBlockType()) != NULL) { return true; } return false; @@ -541,6 +543,34 @@ size_t GetBlockSize(int type) { return g_block_size[type]; } +size_t GetRdmaBlockSize() { + if (FLAGS_rdma_recv_block_type == "default") { + return GetBlockSize(0); + } else if (FLAGS_rdma_recv_block_type == "large") { + return GetBlockSize(1); + } else if (FLAGS_rdma_recv_block_type == "huge") { + return GetBlockSize(2); + } else { + LOG(ERROR) << "rdma_recv_block_type incorrect " + << "(valid value: default/large/huge)"; + return 0; + } +} + +int GetRdmaBlockType() { + if (FLAGS_rdma_recv_block_type == "default") { + return BLOCK_DEFAULT; + } else if (FLAGS_rdma_recv_block_type == "large") { + return BLOCK_LARGE; + } else if (FLAGS_rdma_recv_block_type == "huge") { + return BLOCK_HUGE; + } else { + LOG(ERROR) << "rdma_recv_block_type incorrect " + << "(valid value: default/large/huge)"; + return -1; + } +} + void DumpMemoryPoolInfo(std::ostream& os) { if (!g_dump_mutex) { return; diff --git a/src/brpc/rdma/block_pool.h b/src/brpc/rdma/block_pool.h index f9018e5ecc..c9589fb035 100644 --- a/src/brpc/rdma/block_pool.h +++ b/src/brpc/rdma/block_pool.h @@ -95,11 +95,15 @@ int DeallocBlock(void* buf); uint32_t GetRegionId(const void* buf); // Return the block size of given block type -// type=1: BLOCK_DEFAULT(8KB) -// type=2: BLOCK_LARGE(64KB) -// type=3: BLOCK_HUGE(2MB) +// type=0: BLOCK_DEFAULT(8KB) +// type=1: BLOCK_LARGE(64KB) +// type=2: BLOCK_HUGE(2MB) size_t GetBlockSize(int type); +size_t GetRdmaBlockSize(); + +int GetRdmaBlockType(); + // Dump memory pool information void DumpMemoryPoolInfo(std::ostream& os); diff --git a/src/brpc/rdma/rdma_endpoint.cpp b/src/brpc/rdma/rdma_endpoint.cpp index c69bf8ec07..658c7a2fcc 100644 --- a/src/brpc/rdma/rdma_endpoint.cpp +++ b/src/brpc/rdma/rdma_endpoint.cpp @@ -31,6 +31,7 @@ #include "brpc/rdma/rdma_helper.h" #include "brpc/rdma/rdma_endpoint.h" #include "brpc/rdma_transport.h" +#include "brpc/rdma/rdma_handshake.h" DECLARE_int32(task_group_ntags); @@ -55,8 +56,6 @@ DEFINE_int32(rdma_sq_size, 128, "SQ size for RDMA"); DEFINE_int32(rdma_rq_size, 128, "RQ size for RDMA"); DEFINE_bool(rdma_recv_zerocopy, true, "Enable zerocopy for receive side"); DEFINE_int32(rdma_zerocopy_min_size, 512, "The minimal size for receive zerocopy"); -DEFINE_string(rdma_recv_block_type, "default", "Default size type for recv WR: " - "default(8KB - 32B)/large(64KB - 32B)/huge(2MB - 32B)"); DEFINE_int32(rdma_cqe_poll_once, 32, "The maximum of cqe number polled once."); DEFINE_int32(rdma_prepared_qp_size, 128, "SQ and RQ size for prepared QP."); DEFINE_int32(rdma_prepared_qp_cnt, 1024, "Initial count of prepared QP."); @@ -72,84 +71,30 @@ static const size_t IOBUF_BLOCK_HEADER_LEN = 32; // implementation-dependent // DO NOT change this value unless you know the safe value!!! // This is the number of reserved WRs in SQ/RQ for pure ACK. -static const size_t RESERVED_WR_NUM = 3; - -// magic string RDMA (4B) -// message length (2B) -// hello version (2B) -// impl version (2B): 0 means should use tcp -// block size (4B) -// sq size (2B) -// rq size (2B) -// GID (16B) -// QP number (4B) -static const char* MAGIC_STR = "RDMA"; -static const size_t MAGIC_STR_LEN = 4; -static const size_t HELLO_MSG_LEN_MIN = 40; -// static const size_t HELLO_MSG_LEN_MAX = 4096; -static const size_t ACK_MSG_LEN = 4; -static uint16_t g_rdma_hello_msg_len = 40; // In Byte -static uint16_t g_rdma_hello_version = 2; -static uint16_t g_rdma_impl_version = 1; -static uint32_t g_rdma_recv_block_size = 0; +extern const size_t RESERVED_WR_NUM = 3; + +// The local recv block size, set during GlobalInitialize. +uint32_t g_rdma_recv_block_size = 0; // static const uint32_t MAX_INLINE_DATA = 64; static const uint8_t MAX_HOP_LIMIT = 16; static const uint8_t TIMEOUT = 14; static const uint8_t RETRY_CNT = 7; -static const uint16_t MIN_QP_SIZE = 16; +extern const uint16_t MIN_QP_SIZE = 16; static const uint16_t MAX_QP_SIZE = 4096; -static const uint16_t MIN_BLOCK_SIZE = 1024; -static const uint32_t ACK_MSG_RDMA_OK = 0x1; +extern const uint16_t MIN_BLOCK_SIZE = 1024; + +// ACK message wire format (shared by all protocol versions): a single +// 4B big-endian flags word; bit 0 (HELLO_ACK_RDMA_OK) indicates the +// sender wants to use RDMA. The state machines in +// ProcessHandshakeAt{Client,Server} inline the corresponding 4B +// send/recv directly using ReadFromFd / WriteToFd. +static const size_t HELLO_ACK_LEN = 4; +static const uint32_t HELLO_ACK_RDMA_OK = 0x1; static butil::Mutex* g_rdma_resource_mutex = NULL; static RdmaResource* g_rdma_resource_list = NULL; -struct HelloMessage { - void Serialize(void* data) const; - void Deserialize(void* data); - - uint16_t msg_len; - uint16_t hello_ver; - uint16_t impl_ver; - uint32_t block_size; - uint16_t sq_size; - uint16_t rq_size; - uint16_t lid; - ibv_gid gid; - uint32_t qp_num; -}; - -void HelloMessage::Serialize(void* data) const { - uint16_t* current_pos = (uint16_t*)data; - *(current_pos++) = butil::HostToNet16(msg_len); - *(current_pos++) = butil::HostToNet16(hello_ver); - *(current_pos++) = butil::HostToNet16(impl_ver); - uint32_t* block_size_pos = (uint32_t*)current_pos; - *block_size_pos = butil::HostToNet32(block_size); - current_pos += 2; // move forward 4 Bytes - *(current_pos++) = butil::HostToNet16(sq_size); - *(current_pos++) = butil::HostToNet16(rq_size); - *(current_pos++) = butil::HostToNet16(lid); - memcpy(current_pos, gid.raw, 16); - uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); - *qp_num_pos = butil::HostToNet32(qp_num); -} - -void HelloMessage::Deserialize(void* data) { - uint16_t* current_pos = (uint16_t*)data; - msg_len = butil::NetToHost16(*current_pos++); - hello_ver = butil::NetToHost16(*current_pos++); - impl_ver = butil::NetToHost16(*current_pos++); - block_size = butil::NetToHost32(*(uint32_t*)current_pos); - current_pos += 2; // move forward 4 Bytes - sq_size = butil::NetToHost16(*current_pos++); - rq_size = butil::NetToHost16(*current_pos++); - lid = butil::NetToHost16(*current_pos++); - memcpy(gid.raw, current_pos, 16); - qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); -} - RdmaResource::~RdmaResource() { if (NULL != qp) { IbvDestroyQp(qp); @@ -171,6 +116,7 @@ RdmaResource::~RdmaResource() { RdmaEndpoint::RdmaEndpoint(Socket* s) : _socket(s) , _state(UNINIT) + , _handshake_version(0) , _resource(NULL) , _send_cq_events(0) , _recv_cq_events(0) @@ -350,31 +296,34 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { } } -bool HelloNegotiationValid(HelloMessage& msg) { - if (msg.hello_ver == g_rdma_hello_version && - msg.impl_ver == g_rdma_impl_version && - msg.block_size >= MIN_BLOCK_SIZE && - msg.sq_size >= MIN_QP_SIZE && - msg.rq_size >= MIN_QP_SIZE) { - // This can be modified for future compatibility - return true; - } - return false; -} - static const int WAIT_TIMEOUT_MS = 50; -int RdmaEndpoint::ReadFromFd(void* data, size_t len) { - CHECK(data != NULL); - int nr = 0; +// Drive an EAGAIN-aware read loop to completion (exactly `len` bytes). +// `read_once(offset, remaining)` performs ONE underlying read attempt: +// - returns > 0 : number of bytes consumed (added to running total); +// - returns = 0 : end-of-stream (the loop fails with EEOF); +// - returns < 0 : errno set; EAGAIN is handled here via butex_wait, +// any other errno bubbles up. +// `offset` is bytes already received in THIS call (initially 0); the +// callable uses it to choose the next write target (e.g. `(char*)buf +// + offset`). Callables that don't need offset (e.g. IOPortal append) +// can ignore it. +// +// Centralizes the EAGAIN/butex/EOF loop so the two ReadFromFd +// overloads below stay one-liners; any future read source (memory- +// mapped, scatter-vector, etc.) can plug in by passing its own +// `read_once`. +template +static int ReadFromFdLoop(butil::atomic* read_butex, + size_t len, ReadOnce&& read_once) { size_t received = 0; - do { - const int expected_val = _read_butex->load(butil::memory_order_acquire); + while (received < len) { + const int expected_val = read_butex->load(butil::memory_order_acquire); const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); - nr = read(_socket->fd(), (uint8_t*)data + received, len - received); + ssize_t nr = read_once(received, len - received); if (nr < 0) { if (errno == EAGAIN) { - if (bthread::butex_wait(_read_butex, expected_val, &duetime) < 0) { + if (bthread::butex_wait(read_butex, expected_val, &duetime) < 0) { if (errno != EWOULDBLOCK && errno != ETIMEDOUT) { return -1; } @@ -388,34 +337,89 @@ int RdmaEndpoint::ReadFromFd(void* data, size_t len) { } else { received += nr; } - } while (received < len); + } return 0; } -int RdmaEndpoint::WriteToFd(void* data, size_t len) { +int RdmaEndpoint::ReadFromFd(void* data, size_t len) { + CHECK(data != NULL); + const int fd = _socket->fd(); + return ReadFromFdLoop(_read_butex, len, + [data, fd](size_t offset, size_t remaining) { + return read(fd, (uint8_t*)data + offset, remaining); + }); +} + +int RdmaEndpoint::ReadFromFd(butil::IOPortal* data, size_t len) { CHECK(data != NULL); - int nw = 0; + const int fd = _socket->fd(); + return ReadFromFdLoop(_read_butex, len, + [data, fd](size_t /*offset*/, size_t remaining) { + return data->append_from_file_descriptor(fd, remaining); + }); +} + +// Drive an EAGAIN-aware write loop to completion (exactly `len` bytes). +// +// `write_once(offset, remaining)` performs ONE underlying write attempt: +// - returns >= 0 : number of bytes consumed (added to running total); +// - returns < 0 : errno set; EAGAIN triggers `wait_writable(duetime)`, +// any other errno bubbles up. +// `offset` is bytes already written in THIS call (initially 0); the +// callable uses it to choose the next read source (e.g. `(char*)buf +// + offset`). Callables that drain a self-tracking sink (e.g. +// IOBuf::cut_into_file_descriptor) can ignore both args. +// +// `wait_writable(duetime)` is invoked on EAGAIN to park until the fd +// becomes writable again. It returns 0 on wake-up (or ETIMEDOUT), +// non-zero on hard failure. +template +static int WriteToFdLoop(size_t len, WriteOnce&& write_once, WaitWritable&& wait_writable) { size_t written = 0; - do { + while (written < len) { const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); - nw = write(_socket->fd(), (uint8_t*)data + written, len - written); - if (nw < 0) { - if (errno == EAGAIN) { - if (_socket->WaitEpollOut(_socket->fd(), true, &duetime) < 0) { - if (errno != ETIMEDOUT) { - return -1; - } - } - } else { - return -1; - } - } else { + ssize_t nw = write_once(written, len - written); + if (nw >= 0) { written += nw; + continue; } - } while (written < len); + + if (errno != EAGAIN) { + return -1; + } + if (!wait_writable(&duetime)) { + return -1; + } + } return 0; } +int RdmaEndpoint::WriteToFd(void* data, size_t len) { + CHECK(data != NULL); + Socket* s = _socket; + const int fd = s->fd(); + return WriteToFdLoop(len, + [data, fd](size_t offset, size_t remaining) { + return write(fd, (uint8_t*)data + offset, remaining); + }, + [s, fd](const timespec* duetime) { + return s->WaitEpollOut(fd, true, duetime) == 0 || errno == ETIMEDOUT; + }); +} + +int RdmaEndpoint::WriteToFd(butil::IOBuf* data) { + CHECK(data != NULL); + Socket* s = _socket; + const int fd = s->fd(); + return WriteToFdLoop(data->size(), + [data, fd](size_t /*offset*/, size_t /*remaining*/) { + return data->cut_into_file_descriptor(fd); + }, + [s, fd](const timespec* duetime) { + return s->WaitEpollOut(fd, true, duetime) == 0 || errno == ETIMEDOUT; + }); +} + inline void RdmaEndpoint::TryReadOnTcp() { if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) { if (_state == FALLBACK_TCP) { @@ -426,19 +430,52 @@ inline void RdmaEndpoint::TryReadOnTcp() { } } +void RdmaEndpoint::ApplyRemoteHello(const ParsedHello& remote) { + _remote_recv_block_size = remote.block_size; + _local_window_capacity = + std::min(_sq_size, remote.rq_size) - RESERVED_WR_NUM; + _remote_window_capacity = + std::min(_rq_size, remote.sq_size) - RESERVED_WR_NUM; + _sq_imm_window_size = RESERVED_WR_NUM; + _remote_rq_window_size.store( + _local_window_capacity, butil::memory_order_relaxed); + _sq_window_size.store( + _local_window_capacity, butil::memory_order_relaxed); +} + +// Client-side handshake entry: the state machine. +// +// C_ALLOC_QPCQ +// | +// v +// C_HELLO_SEND (hs->SendLocalHello) +// | +// v +// C_HELLO_WAIT (hs->ReceiveAndParseRemoteHello) +// | +// v +// [negotiation: ApplyRemoteHello + C_BRINGUP_QP] +// | +// v +// C_ACK_SEND +// | +// v +// ESTABLISHED / FALLBACK_TCP void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { - RdmaEndpoint* ep = static_cast(arg); + auto ep = static_cast(arg); SocketUniquePtr s(ep->_socket); RdmaConnect::RunGuard rg((RdmaConnect*)s->_app_connect.get()); + auto rdma_transport = static_cast(s->_transport.get()); - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Start handshake on " << s->_local_side; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Start handshake on " << s->description(); - uint8_t data[g_rdma_hello_msg_len]; + std::unique_ptr handshake = CreateClientHandshake(ep); + CHECK(handshake != NULL); + ep->_handshake_version = handshake->ProtocolVersion(); - // First initialize CQ and QP resources + // First initialize CQ and QP resources. ep->_state = C_ALLOC_QPCQ; - auto* rdma_transport = static_cast(s->_transport.get()); if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; @@ -448,94 +485,40 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // Send hello message to server ep->_state = C_HELLO_SEND; - HelloMessage local_msg; - local_msg.msg_len = g_rdma_hello_msg_len; - local_msg.hello_ver = g_rdma_hello_version; - local_msg.impl_ver = g_rdma_impl_version; - local_msg.block_size = g_rdma_recv_block_size; - local_msg.sq_size = ep->_sq_size; - local_msg.rq_size = ep->_rq_size; - local_msg.lid = GetRdmaLid(); - local_msg.gid = GetRdmaGid(); - if (BAIDU_LIKELY(ep->_resource)) { - local_msg.qp_num = ep->_resource->qp->qp_num; - } else { - // Only happens in UT - local_msg.qp_num = 0; - } - memcpy(data, MAGIC_STR, 4); - local_msg.Serialize((char*)data + 4); - if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); + if (handshake->SendLocalHello() < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send hello message to server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Check magic str + // Receive and parse remote hello. ep->_state = C_HELLO_WAIT; - if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to get hello message from server:" << s->description(); - s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; - return NULL; - } - if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { - LOG(WARNING) << "Read unexpected data during handshake:" << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; - return NULL; - } - - // Read hello message from server - if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to get Hello Message from server:" << s->description(); + ParsedHello remote{}; + bool negotiated = false; + if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to receive hello from server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; - return NULL; - } - HelloMessage remote_msg; - remote_msg.Deserialize(data); - if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { - LOG(WARNING) << "Fail to parse Hello Message length from server:" - << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { - // TODO: Read Hello Message customized data - // Just for future use, should not happen now - } - - if (!HelloNegotiationValid(remote_msg)) { + if (!negotiated) { LOG(WARNING) << "Fail to negotiate with server, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { - ep->_remote_recv_block_size = remote_msg.block_size; - ep->_local_window_capacity = - std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; - ep->_remote_window_capacity = - std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM; - ep->_sq_imm_window_size = RESERVED_WR_NUM; - ep->_remote_rq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - ep->_sq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - + ep->ApplyRemoteHello(remote); ep->_state = C_BRINGUP_QP; - if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { - LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); + if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { + LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" + << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; @@ -544,28 +527,26 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // Send ACK message to server ep->_state = C_ACK_SEND; - uint32_t flags = 0; - if (rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF) { - flags |= ACK_MSG_RDMA_OK; - } - uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing - *tmp = butil::HostToNet32(flags); - if (ep->WriteToFd(data, ACK_MSG_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send Ack Message to server:" << s->description(); + uint32_t flags = rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF ? HELLO_ACK_RDMA_OK : 0; + uint32_t flags_be = butil::HostToNet32(flags); + if (ep->WriteToFd(&flags_be, HELLO_ACK_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send Ack Message to server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) { ep->_state = ESTABLISHED; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Client handshake ends (use rdma) on " << s->description(); + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Client handshake ends (use rdma v" << ep->_handshake_version + << ") on " << s->description(); } else { ep->_state = FALLBACK_TCP; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Client handshake ends (use tcp) on " << s->description(); } @@ -574,77 +555,75 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { return NULL; } +// Server-side handshake entry: the state machine. +// +// S_HELLO_WAIT (read magic + dispatch + hs->ReceiveAndParseRemoteHello) +// | +// v +// [negotiation: ApplyRemoteHello + S_ALLOC_QPCQ + S_BRINGUP_QP] +// | +// v +// S_HELLO_SEND (hs->SendLocalHello) +// | +// v +// S_ACK_WAIT +// | +// v +// ESTABLISHED / FALLBACK_TCP void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { - RdmaEndpoint* ep = static_cast(arg); + auto ep = static_cast(arg); SocketUniquePtr s(ep->_socket); + auto rdma_transport = static_cast(s->_transport.get()); - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Start handshake on " << s->description(); - uint8_t data[g_rdma_hello_msg_len]; - ep->_state = S_HELLO_WAIT; - if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description() << " " << s->_remote_side; + uint8_t magic[MAGIC_STR_LEN]; + if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" + << s->description() << " " << s->_remote_side; s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - auto* rdma_transport = static_cast(s->_transport.get()); - if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { - LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "It seems that the " - << "client does not use RDMA, fallback to TCP:" + + // Dispatch on magic, or fall back to TCP + std::unique_ptr handshake = CreateServerHandshakeByMagic(ep, magic); + if (!handshake) { + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "It seems that the client does not use RDMA, fallback to TCP:" << s->description(); - // we need to copy data read back to _socket->_read_buf - s->_read_buf.append(data, MAGIC_STR_LEN); + // We need to copy data read back to _socket->_read_buf. + s->_read_buf.append(magic, MAGIC_STR_LEN); ep->_state = FALLBACK_TCP; rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->TryReadOnTcp(); return NULL; } + ep->_handshake_version = handshake->ProtocolVersion(); - if (ep->ReadFromFd(data, g_rdma_hello_msg_len - MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description(); + // Magic was already consumed above; the subclass MUST NOT re-read it. + ParsedHello remote{}; + bool negotiated = false; + if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to receive hello from client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - HelloMessage remote_msg; - remote_msg.Deserialize(data); - if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { - LOG(WARNING) << "Fail to parse Hello Message length from client:" - << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; - return NULL; - } - if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { - // TODO: Read Hello Message customized header - // Just for future use, should not happen now - } - - if (!HelloNegotiationValid(remote_msg)) { + if (!negotiated) { LOG(WARNING) << "Fail to negotiate with client, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { - ep->_remote_recv_block_size = remote_msg.block_size; - ep->_local_window_capacity = - std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; - ep->_remote_window_capacity = - std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM; - ep->_sq_imm_window_size = RESERVED_WR_NUM; - ep->_remote_rq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - ep->_sq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - + ep->ApplyRemoteHello(remote); ep->_state = S_ALLOC_QPCQ; if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:" @@ -652,7 +631,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->_state = S_BRINGUP_QP; - if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { + if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; @@ -660,73 +639,55 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { } } - // Send hello message to client ep->_state = S_HELLO_SEND; - HelloMessage local_msg; - local_msg.msg_len = g_rdma_hello_msg_len; - if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { - local_msg.impl_ver = 0; - local_msg.hello_ver = 0; - } else { - local_msg.lid = GetRdmaLid(); - local_msg.gid = GetRdmaGid(); - local_msg.block_size = g_rdma_recv_block_size; - local_msg.sq_size = ep->_sq_size; - local_msg.rq_size = ep->_rq_size; - local_msg.hello_ver = g_rdma_hello_version; - local_msg.impl_ver = g_rdma_impl_version; - if (BAIDU_LIKELY(ep->_resource)) { - local_msg.qp_num = ep->_resource->qp->qp_num; - } else { - // Only happens in UT - local_msg.qp_num = 0; - } - } - memcpy(data, MAGIC_STR, 4); - local_msg.Serialize((char*)data + 4); - if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description(); + if (handshake->SendLocalHello() < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send Hello Message to client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Recv ACK Message ep->_state = S_ACK_WAIT; - if (ep->ReadFromFd(data, ACK_MSG_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read ack message from client:" << s->description(); + uint32_t flags_be = 0; + if (ep->ReadFromFd(&flags_be, HELLO_ACK_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to read ack message from client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } + uint32_t flags = butil::NetToHost32(flags_be); + bool client_ack_ok = (flags & HELLO_ACK_RDMA_OK) != 0; - // Check RDMA enable flag - uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing - uint32_t flags = butil::NetToHost32(*tmp); - if (flags & ACK_MSG_RDMA_OK) { + if (client_ack_ok) { if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { - LOG(WARNING) << "Fail to parse Hello Message length from client:" - << s->description(); + // Client asked for RDMA but we are falling back: protocol + // breakdown, abort the connection so the client sees a + // clean error rather than a half-up RDMA channel. + LOG(WARNING) << "Client wants RDMA in ACK but server is in " + << "RDMA_OFF state: " << s->description(); s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); + s->description().c_str(), berror(EPROTO)); ep->_state = FAILED; return NULL; - } else { - rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; - ep->_state = ESTABLISHED; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Server handshake ends (use rdma) on " << s->description(); } + rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; + ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Server handshake ends (use rdma v" << ep->_handshake_version + << ") on " << s->description(); } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->_state = FALLBACK_TCP; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use tcp) on " << s->description(); } + ep->TryReadOnTcp(); return NULL; @@ -1078,7 +1039,7 @@ int RdmaEndpoint::PostRecv(uint32_t num, bool zerocopy) { PLOG(WARNING) << "Fail to allocate rbuf"; return -1; } else { - CHECK(static_cast(size) == g_rdma_recv_block_size) << size; + CHECK_EQ(static_cast(size), g_rdma_recv_block_size); } } if (DoPostRecv(_rbuf_data[_rq_received], g_rdma_recv_block_size) < 0) { @@ -1334,6 +1295,21 @@ static void DeallocateCq(ibv_cq* cq) { LOG_IF(WARNING, 0 != err) << "Fail to destroy CQ: " << berror(err); } +static int DrainCq(ibv_cq* cq) { + if (NULL == cq) { + return 0; + } + + ibv_wc wc; + int ret; + do { + ret = ibv_poll_cq(cq, 1, &wc); + } while (ret > 0); + + LOG_IF(ERROR, ret < 0) << "drain CQ failed: " << ret; + return ret; +} + void RdmaEndpoint::DeallocateResources() { if (!_resource) { return; @@ -1360,6 +1336,7 @@ void RdmaEndpoint::DeallocateResources() { } bool remove_consumer = true; +_reclaim: if (!move_to_rdma_resource_list) { if (NULL != _resource->qp) { int err = IbvDestroyQp(_resource->qp); @@ -1403,6 +1380,24 @@ void RdmaEndpoint::DeallocateResources() { } if (move_to_rdma_resource_list) { + // When a QP is moved to the RESET state, all associated send and + // receive queues are flushed, meaning any outstanding WRs are effectively + // abandoned by the hardware. + // + // However, the CQ associated with that QP is *not* cleared automatically, + // meaning that it will still contain entries for WRs that completed before + // the reset. + // + // The application should finish polling the CQ to remove these obsolete + // entries before reusing the QP. + int ret = DrainCq(_resource->polling_cq); + ret += DrainCq(_resource->send_cq); + ret += DrainCq(_resource->recv_cq); + if (ret < 0) { + move_to_rdma_resource_list = false; + goto _reclaim; + } + BAIDU_SCOPED_LOCK(*g_rdma_resource_mutex); _resource->next = g_rdma_resource_list; g_rdma_resource_list = _resource; @@ -1613,6 +1608,7 @@ std::string RdmaEndpoint::GetStateStr() const { void RdmaEndpoint::DebugInfo(std::ostream& os, butil::StringPiece connector) const { os << "rdma_state=ON" << connector << "handshake_state=" << GetStateStr() + << connector << "handshake_version=" << static_cast(_handshake_version) << connector << "rdma_sq_imm_window_size=" << _sq_imm_window_size << connector << "rdma_remote_rq_window_size=" << _remote_rq_window_size.load(butil::memory_order_relaxed) << connector << "rdma_sq_window_size=" << _sq_window_size.load(butil::memory_order_relaxed) @@ -1628,13 +1624,8 @@ void RdmaEndpoint::DebugInfo(std::ostream& os, butil::StringPiece connector) con } int RdmaEndpoint::GlobalInitialize() { - if (FLAGS_rdma_recv_block_type == "default") { - g_rdma_recv_block_size = GetBlockSize(0) - IOBUF_BLOCK_HEADER_LEN; - } else if (FLAGS_rdma_recv_block_type == "large") { - g_rdma_recv_block_size = GetBlockSize(1) - IOBUF_BLOCK_HEADER_LEN; - } else if (FLAGS_rdma_recv_block_type == "huge") { - g_rdma_recv_block_size = GetBlockSize(2) - IOBUF_BLOCK_HEADER_LEN; - } else { + g_rdma_recv_block_size = GetRdmaBlockSize() - IOBUF_BLOCK_HEADER_LEN; + if (g_rdma_recv_block_size <= 0) { LOG(ERROR) << "rdma_recv_block_type incorrect " << "(valid value: default/large/huge)"; errno = EINVAL; diff --git a/src/brpc/rdma/rdma_endpoint.h b/src/brpc/rdma/rdma_endpoint.h index 54a008f1f7..7b6652bc86 100644 --- a/src/brpc/rdma/rdma_endpoint.h +++ b/src/brpc/rdma/rdma_endpoint.h @@ -40,6 +40,24 @@ DECLARE_bool(rdma_use_polling); DECLARE_int32(rdma_poller_num); DECLARE_bool(rdma_disable_bthread); +class RdmaHandshakeClientV2; +class RdmaHandshakeServerV2; +class RdmaHandshakeClientV3; +class RdmaHandshakeServerV3; +struct ParsedHello; +class RdmaHello; +class RdmaEndpoint; +namespace v2_wire { + int ReadBodyAndNegotiate(RdmaEndpoint* ep, ParsedHello* remote, bool* negotiated); + int DrainBytes(RdmaEndpoint* ep, size_t n); +} // namespace v2_wire + +namespace v3_wire { + void FillLocalRdmaHello(const RdmaEndpoint* ep, RdmaHello* msg); + int ReadAndParseV3Hello(RdmaEndpoint* ep, RdmaHello* out); + int WriteV3Hello(RdmaEndpoint* ep, const RdmaHello& msg); +} // namespace v3_wire + class RdmaConnect : public AppConnect { public: void StartConnect(const Socket* socket, @@ -74,6 +92,15 @@ struct RdmaResource { class BAIDU_CACHELINE_ALIGNMENT RdmaEndpoint : public SocketUser { friend class RdmaConnect; friend class Socket; +friend class RdmaHandshakeClientV2; +friend class RdmaHandshakeServerV2; +friend class RdmaHandshakeClientV3; +friend class RdmaHandshakeServerV3; +friend int v2_wire::ReadBodyAndNegotiate(RdmaEndpoint*, ParsedHello*, bool*); +friend int v2_wire::DrainBytes(RdmaEndpoint*, size_t); +friend void v3_wire::FillLocalRdmaHello(const RdmaEndpoint*, RdmaHello*); +friend int v3_wire::ReadAndParseV3Hello(RdmaEndpoint*, RdmaHello*); +friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&); public: explicit RdmaEndpoint(Socket* s); ~RdmaEndpoint() override; @@ -181,6 +208,7 @@ friend class Socket; // wait for _read_butex if encounter EAGAIN // return -1 if encounter other errno (including EOF) int ReadFromFd(void* data, size_t len); + int ReadFromFd(butil::IOPortal* data, size_t len); // Write at most len bytes from data to fd in _socket @@ -188,6 +216,17 @@ friend class Socket; // return -1 if encounter other errno int WriteToFd(void* data, size_t len); + // Write data to fd in _socket. + // wait for _epollout_butex if encounter EAGAIN. + // return -1 if encounter other errno. + int WriteToFd(butil::IOBuf* data); + + // Copy negotiated remote parameters into the endpoint and compute + // the SQ/RQ window capacities. Called by both + // ProcessHandshakeAtClient and ProcessHandshakeAtServer after the + // peer's hello has been validated. + void ApplyRemoteHello(const ParsedHello& remote); + // Bringup the QP from RESET state to RTS state // Arguments: // lid: remote LID @@ -225,6 +264,13 @@ friend class Socket; // State of Handshake State _state; + // Wire-level handshake protocol version (set by dispatch in + // ProcessHandshakeAtClient/Server). Aligned with the protocol code: + // 0 = unnegotiated + // 2 = v2 "RDMA" + // 3 = v3 "RDM3" + int _handshake_version; + // rdma resource RdmaResource* _resource; diff --git a/src/brpc/rdma/rdma_handshake.cpp b/src/brpc/rdma/rdma_handshake.cpp new file mode 100644 index 0000000000..9bd2312ec4 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.cpp @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#if BRPC_WITH_RDMA + +#include "brpc/rdma/rdma_handshake.h" + +#include +#include // std::min +#include +#include +#include +#include "butil/iobuf.h" // IOBuf, IOPortal, IOBufAsZeroCopy*Stream +#include "butil/sys_byteorder.h" +#include "brpc/socket.h" +#include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_helper.h" +#include "brpc/rdma_transport.h" +#include "brpc/rdma/rdma_handshake.pb.h" + +namespace brpc { +namespace rdma { + +DEFINE_int32(rdma_client_handshake_version, 2, + "RDMA handshake protocol version used by client. " + "2 = legacy 'RDMA' magic (default, compatible with all servers); " + "3 = new 'RDM3' protobuf-based handshake " + "(MUST only be enabled after target servers support v3)."); + +extern const uint16_t MIN_QP_SIZE; +extern const uint16_t MIN_BLOCK_SIZE; +extern uint32_t g_rdma_recv_block_size; +extern bool g_skip_rdma_init; + +// Wire-level constants for the v2 handshake. +static const char* MAGIC_STR = "RDMA"; +static constexpr uint16_t RDMA_HELLO_V2_MSG_LEN = 40; // In Byte +extern const uint16_t RDMA_HELLO_V2_VERSION = 2; +extern const uint16_t RDMA_IMPL_V2_VERSION = 1; + +// Wire-level constants for the v3 handshake. +static const char* MAGIC_STR_V3 = "RDM3"; +static const size_t RDMA_HELLO_V3_PB_SIZE_LEN = 4; +static const size_t RDMA_HELLO_V3_MAX_PB_SIZE = 4096; + +namespace v2_wire { + +void HelloMessage::Serialize(void* data) const { + uint16_t* current_pos = (uint16_t*)data; + *(current_pos++) = butil::HostToNet16(msg_len); + *(current_pos++) = butil::HostToNet16(hello_ver); + *(current_pos++) = butil::HostToNet16(impl_ver); + uint32_t* block_size_pos = (uint32_t*)current_pos; + *block_size_pos = butil::HostToNet32(block_size); + current_pos += 2; // move forward 4 Bytes + *(current_pos++) = butil::HostToNet16(sq_size); + *(current_pos++) = butil::HostToNet16(rq_size); + *(current_pos++) = butil::HostToNet16(lid); + fast_memcpy(current_pos, gid.raw, 16); + uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); + *qp_num_pos = butil::HostToNet32(qp_num); +} + +void HelloMessage::Deserialize(void* data) { + uint16_t* current_pos = (uint16_t*)data; + msg_len = butil::NetToHost16(*current_pos++); + hello_ver = butil::NetToHost16(*current_pos++); + impl_ver = butil::NetToHost16(*current_pos++); + block_size = butil::NetToHost32(*(uint32_t*)current_pos); + current_pos += 2; // move forward 4 Bytes + sq_size = butil::NetToHost16(*current_pos++); + rq_size = butil::NetToHost16(*current_pos++); + lid = butil::NetToHost16(*current_pos++); + fast_memcpy(gid.raw, current_pos, 16); + qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); +} + +static bool ValidHelloMessage(const HelloMessage& msg) { + return msg.hello_ver == RDMA_HELLO_V2_VERSION && + msg.impl_ver == RDMA_IMPL_V2_VERSION && + msg.block_size >= MIN_BLOCK_SIZE && + msg.sq_size >= MIN_QP_SIZE && + msg.rq_size >= MIN_QP_SIZE; +} + +static void TranslateV2Hello(const HelloMessage& msg, ParsedHello* out) { + out->block_size = msg.block_size; + out->sq_size = msg.sq_size; + out->rq_size = msg.rq_size; + out->lid = msg.lid; + out->gid = msg.gid; + out->qp_num = msg.qp_num; +} + +int ReadBodyAndNegotiate(RdmaEndpoint* ep, ParsedHello* remote, bool* negotiated) { + uint8_t data[HELLO_MSG_LEN_MIN]; + if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { + return -1; + } + HelloMessage remote_msg{}; + remote_msg.Deserialize(data); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN || + remote_msg.msg_len > HELLO_MSG_LEN_MAX) { + errno = EPROTO; + return -1; + } + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // Drain unknown trailing bytes so they don't pollute subsequent + // reads (e.g. the upcoming ACK message). v2 base fields already + // carry enough information for negotiation; unknown trailing + // bytes are treated as optional hints that v2 safely ignores. + size_t ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN; + if (DrainBytes(ep, ext_len) < 0) { + return -1; + } + } + if (!ValidHelloMessage(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + TranslateV2Hello(remote_msg, remote); + return 0; +} + +int DrainBytes(RdmaEndpoint* ep, size_t n) { + uint8_t scratch[64]; + while (n > 0) { + size_t chunk = std::min(n, sizeof(scratch)); + if (ep->ReadFromFd(scratch, chunk) < 0) { + return -1; + } + n -= chunk; + } + return 0; +} + +} // namespace v2_wire + +int RdmaHandshakeClientV2::SendLocalHello() { + RdmaEndpoint* ep = _ep; + uint8_t data[RDMA_HELLO_V2_MSG_LEN]; + + v2_wire::HelloMessage local_msg{}; + local_msg.msg_len = RDMA_HELLO_V2_MSG_LEN; + local_msg.hello_ver = RDMA_HELLO_V2_VERSION; + local_msg.impl_ver = RDMA_IMPL_V2_VERSION; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = ep->_sq_size; + local_msg.rq_size = ep->_rq_size; + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(ep->_resource)) { + local_msg.qp_num = ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + fast_memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + return ep->WriteToFd(data, RDMA_HELLO_V2_MSG_LEN); +} + +int RdmaHandshakeClientV2::ReceiveAndParseRemoteHello(ParsedHello* remote, + bool* negotiated) { + RdmaEndpoint* ep = _ep; + + // Read and verify magic (the endpoint did NOT pre-read magic on the client side). + uint8_t magic[MAGIC_STR_LEN]; + if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + return -1; + } + if (memcmp(magic, MAGIC_STR, MAGIC_STR_LEN) != 0) { + errno = EPROTO; + return -1; + } + return v2_wire::ReadBodyAndNegotiate(ep, remote, negotiated); +} + +int RdmaHandshakeServerV2::ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) { + // Magic already consumed by ProcessHandshakeAtServer. + return v2_wire::ReadBodyAndNegotiate(_ep, remote, negotiated); +} + +int RdmaHandshakeServerV2::SendLocalHello() { + uint8_t data[RDMA_HELLO_V2_MSG_LEN]; + v2_wire::HelloMessage local_msg{}; + local_msg.msg_len = RDMA_HELLO_V2_MSG_LEN; + auto rdma_transport = static_cast(_ep->_socket->_transport.get()); + if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { + local_msg.hello_ver = 0; + local_msg.impl_ver = 0; + local_msg.block_size = 0; + local_msg.sq_size = 0; + local_msg.rq_size = 0; + local_msg.lid = 0; + memset(local_msg.gid.raw, 0, sizeof(local_msg.gid.raw)); + local_msg.qp_num = 0; + } else { + local_msg.hello_ver = RDMA_HELLO_V2_VERSION; + local_msg.impl_ver = RDMA_IMPL_V2_VERSION; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = _ep->_sq_size; + local_msg.rq_size = _ep->_rq_size; + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(_ep->_resource)) { + local_msg.qp_num = _ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + } + fast_memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + return _ep->WriteToFd(data, RDMA_HELLO_V2_MSG_LEN); +} + +namespace v3_wire { + +bool ValidRdmaHello(const RdmaHello& msg) { + if (msg.gid().size() != sizeof(ibv_gid)) { + return false; + } + // ParsedHello stores these as uint16_t; reject values that would truncate. + constexpr uint16_t MAX_UINT16 = std::numeric_limits::max(); + if (msg.sq_size() > MAX_UINT16 || msg.rq_size() > MAX_UINT16 || msg.lid() > MAX_UINT16) { + return false; + } + if (msg.block_size() < MIN_BLOCK_SIZE) { + return false; + } + if (msg.sq_size() < MIN_QP_SIZE) { + return false; + } + if (msg.rq_size() < MIN_QP_SIZE) { + return false; + } + // qp_num == 0 only happens in UT (no real QP allocated). + if (msg.qp_num() == 0 && !g_skip_rdma_init) { + return false; + } + return true; +} + +void FillLocalRdmaHello(const RdmaEndpoint* ep, RdmaHello* msg) { + msg->set_block_size(g_rdma_recv_block_size); + msg->set_sq_size(ep->_sq_size); + msg->set_rq_size(ep->_rq_size); + msg->set_lid(GetRdmaLid()); + ibv_gid gid = GetRdmaGid(); + msg->set_gid(std::string(reinterpret_cast(gid.raw), + sizeof(gid.raw))); + if (BAIDU_LIKELY(ep->_resource)) { + msg->set_qp_num(ep->_resource->qp->qp_num); + } else { + // Only happens in UT + msg->set_qp_num(0); + } +} + +int ReadAndParseV3Hello(RdmaEndpoint* ep, RdmaHello* out) { + uint8_t size_buf[RDMA_HELLO_V3_PB_SIZE_LEN]; + if (ep->ReadFromFd(size_buf, RDMA_HELLO_V3_PB_SIZE_LEN) < 0) { + return -1; + } + uint32_t pb_size = butil::NetToHost32( + *reinterpret_cast(size_buf)); + if (pb_size == 0 || pb_size > RDMA_HELLO_V3_MAX_PB_SIZE) { + errno = EPROTO; + return -1; + } + butil::IOPortal body; + if (ep->ReadFromFd(&body, pb_size) < 0) { + return -1; + } + + butil::IOBufAsZeroCopyInputStream input(body); + if (!out->ParseFromZeroCopyStream(&input)) { + LOG(ERROR) << "Failed to parse RdmaHello"; + errno = EPROTO; + return -1; + } + return 0; +} + +int WriteV3Hello(RdmaEndpoint* ep, const RdmaHello& msg) { + uint32_t pb_size = static_cast(msg.ByteSizeLong()); + if (pb_size > RDMA_HELLO_V3_MAX_PB_SIZE) { + errno = EPROTO; + return -1; + } + + // [ "RDM3" 4B ][ pb_size 4B (big-endian) ][ RdmaHello protobuf bytes ] + butil::IOBuf packet; + packet.append(MAGIC_STR_V3, MAGIC_STR_LEN); + uint32_t pb_size_be = butil::HostToNet32(pb_size); + packet.append(&pb_size_be, RDMA_HELLO_V3_PB_SIZE_LEN); + butil::IOBufAsZeroCopyOutputStream output(&packet); + if (!msg.SerializeToZeroCopyStream(&output)) { + LOG(ERROR) << "Failed to serialize RdmaHello"; + errno = EPROTO; + return -1; + } + return ep->WriteToFd(&packet); +} + +void TranslateHello(const RdmaHello& msg, ParsedHello* out) { + out->block_size = msg.block_size(); + out->sq_size = static_cast(msg.sq_size()); + out->rq_size = static_cast(msg.rq_size()); + out->lid = static_cast(msg.lid()); + fast_memcpy(out->gid.raw, msg.gid().data(), sizeof(out->gid.raw)); + out->qp_num = msg.qp_num(); +} + +} // namespace v3_wire + +int RdmaHandshakeClientV3::SendLocalHello() { + RdmaHello local_msg{}; + v3_wire::FillLocalRdmaHello(_ep, &local_msg); + return v3_wire::WriteV3Hello(_ep, local_msg); +} + +int RdmaHandshakeClientV3::ReceiveAndParseRemoteHello(ParsedHello* remote, + bool* negotiated) { + uint8_t magic[MAGIC_STR_LEN]; + if (_ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + return -1; + } + if (memcmp(magic, MAGIC_STR_V3, MAGIC_STR_LEN) != 0) { + errno = EPROTO; + return -1; + } + + RdmaHello remote_msg{}; + if (v3_wire::ReadAndParseV3Hello(_ep, &remote_msg) < 0) { + return -1; + } + if (!v3_wire::ValidRdmaHello(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + v3_wire::TranslateHello(remote_msg, remote); + return 0; +} + +int RdmaHandshakeServerV3::ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) { + // Magic already consumed by ProcessHandshakeAtServer. + RdmaHello remote_msg{}; + if (v3_wire::ReadAndParseV3Hello(_ep, &remote_msg) < 0) { + return -1; + } + if (!v3_wire::ValidRdmaHello(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + v3_wire::TranslateHello(remote_msg, remote); + return 0; +} + +int RdmaHandshakeServerV3::SendLocalHello() { + RdmaHello local_msg{}; + v3_wire::FillLocalRdmaHello(_ep, &local_msg); + return v3_wire::WriteV3Hello(_ep, local_msg); +} + +std::unique_ptr CreateClientHandshake(RdmaEndpoint* ep) { + switch (FLAGS_rdma_client_handshake_version) { + case 3: + return std::unique_ptr(new RdmaHandshakeClientV3(ep)); + case 2: + default: + return std::unique_ptr(new RdmaHandshakeClientV2(ep)); + } +} + +std::unique_ptr CreateServerHandshakeByMagic( + RdmaEndpoint* ep, const uint8_t magic[MAGIC_STR_LEN]) { + if (memcmp(magic, MAGIC_STR, MAGIC_STR_LEN) == 0) { + return std::unique_ptr(new RdmaHandshakeServerV2(ep)); + } + if (memcmp(magic, MAGIC_STR_V3, MAGIC_STR_LEN) == 0) { + return std::unique_ptr(new RdmaHandshakeServerV3(ep)); + } + return nullptr; +} + +} // namespace rdma +} // namespace brpc + +#endif // BRPC_WITH_RDMA diff --git a/src/brpc/rdma/rdma_handshake.h b/src/brpc/rdma/rdma_handshake.h new file mode 100644 index 0000000000..5f36a9e6e2 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.h @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_RDMA_HANDSHAKE_H +#define BRPC_RDMA_HANDSHAKE_H + +#if BRPC_WITH_RDMA + +#include +#include +#include +#include +#include "butil/macros.h" + +namespace brpc { +namespace rdma { + +class RdmaEndpoint; + +// Length of the RDMA handshake magic string (e.g. "RDMA", "RDM3"). +static const size_t MAGIC_STR_LEN = 4; + +// Wire-format-agnostic representation of a peer's hello message. +// Each protocol version (v2 binary, v3 protobuf) translates its own +// wire format into this struct so the state-machine driver in +// RdmaEndpoint::ProcessHandshakeAt{Client,Server} stays free of any +// wire-format details. +struct ParsedHello { + uint32_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +namespace v2_wire { + +// Wire constants for the v2 hello. +// +// HELLO_MSG_LEN_MIN: total length of the base v2 hello (4B magic + +// 36B HelloMessage). Anything shorter than this is malformed. +// HELLO_MSG_LEN_MAX: upper bound for the entire v2 hello message +// length declared by HelloMessage::msg_len. Anything beyond this is +// treated as a protocol error and the connection is closed without +// attempting to drain. +static constexpr size_t HELLO_MSG_LEN_MIN = 40; +static constexpr size_t HELLO_MSG_LEN_MAX = 4096; + +// v2 binary HelloMessage. +struct HelloMessage { + void Serialize(void* data) const; + void Deserialize(void* data); + + uint16_t msg_len; + uint16_t hello_ver; + uint16_t impl_ver; + uint32_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +} // namespace v2_wire + +// Abstract base class of an RDMA handshake. +// +// Acts as the protocol-version dispatch point for the state machine +// driven by RdmaEndpoint::ProcessHandshakeAt{Client,Server}. +class RdmaHandshake { +public: + explicit RdmaHandshake(RdmaEndpoint* ep) : _ep(ep) {} + virtual ~RdmaHandshake() = default; + + DISALLOW_COPY_AND_ASSIGN(RdmaHandshake); + + // Wire-level protocol version (2 for "RDMA", 3 for "RDM3"). + virtual int ProtocolVersion() const = 0; + + // Build and send the local hello (including the protocol magic). + // Returns 0 on success, -1 on IO error (errno set). + // + // For a server in fallback state, implementations MUST still + // produce a sendable message; each version uses its own wire + // convention to signal "I am falling back" to the peer: + // - v2: zero hello_ver/impl_ver so the peer's HelloNegotiationValid + // rejects it; + // - v3: qp_num==0 so the peer's ValidRdmaHello rejects it. + virtual int SendLocalHello() = 0; + + // Read the peer's hello, validate it, and translate into ParsedHello. + // + // Role-specific semantics: + // - Client subclasses: read & verify the 4B magic first, then the + // body. (The endpoint did NOT pre-read the magic on the client + // side.) + // - Server subclasses: read ONLY the body. The 4B magic was + // already consumed by ProcessHandshakeAtServer and was used to + // pick `this` from CreateServerHandshakeByMagic; re-reading + // would deadlock. + // + // Outputs: + // *negotiated -- true if the remote hello is structurally valid + // AND passes per-protocol negotiation checks; + // false means the peer asked for fallback or sent + // something we can't honor. + // Returns: + // 0 -- IO/parsing layer OK; check *negotiated and *remote. + // -1 -- IO error or unrecoverable protocol error (errno set). + virtual int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) = 0; + +protected: + RdmaEndpoint* _ep; +}; + +// v2 handshake (legacy "RDMA" magic, 36B binary HelloMessage). +class RdmaHandshakeClientV2 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 2; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +class RdmaHandshakeServerV2 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 2; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +// v3 handshake (new "RDM3" magic, protobuf RdmaHello). +// [ "RDM3" 4B ][ pb_size 4B (big-endian) ][ RdmaHello protobuf bytes ] +class RdmaHandshakeClientV3 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 3; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +class RdmaHandshakeServerV3 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 3; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +// Factory methods +// +// Pick the client-side handshake based on +// FLAGS_rdma_client_handshake_version: +// 2 (default) -> RdmaHandshakeClientV2 +// 3 -> RdmaHandshakeClientV3 +// Other values fall back to V2. +std::unique_ptr CreateClientHandshake(RdmaEndpoint* ep); + +// Pick the server-side handshake based on the 4B magic already read. +// Returns NULL if `magic` is not a recognized RDMA magic +// (the caller should then fallback to TCP). +// "RDMA" -> RdmaHandshakeServerV2 +// "RDM3" -> RdmaHandshakeServerV3 +std::unique_ptr CreateServerHandshakeByMagic( + RdmaEndpoint* ep, const uint8_t magic[MAGIC_STR_LEN]); + +} // namespace rdma +} // namespace brpc + +#endif // BRPC_WITH_RDMA +#endif // BRPC_RDMA_HANDSHAKE_H diff --git a/src/brpc/rdma/rdma_handshake.proto b/src/brpc/rdma/rdma_handshake.proto new file mode 100644 index 0000000000..c180b58b96 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.proto @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax = "proto2"; + +package brpc.rdma; + +option cc_generic_services = false; + +// RDMA handshake v3 message. +// Carried in the body of every "RDM3" handshake packet: +// +// [ "RDM3" 4B ][ pb_size 4B ][ RdmaHello protobuf bytes ] +message RdmaHello { + // ---- v2-parity base fields (required) ---- + // Listed first and in the same logical order as the v2 binary + // HelloMessage (minus hello_ver / impl_ver, which are subsumed by + // the wrapper magic "RDM3"). Keeping the same ordering simplifies + // side-by-side reasoning when debugging mixed v2/v3 traffic. + // + // Marked `required` because the handshake cannot proceed without + // any of these; ParseFromArray() will reject a missing field at + // the protobuf layer, so we don't need an extra has_xxx() check + // in RdmaHelloValid() for presence. + required uint32 block_size = 1; + required uint32 sq_size = 2; + required uint32 rq_size = 3; + required uint32 lid = 4; + // Must be exactly 16 bytes (sizeof(ibv_gid)). + required bytes gid = 5; + required uint32 qp_num = 6; +} diff --git a/src/brpc/rdma/rdma_helper.cpp b/src/brpc/rdma/rdma_helper.cpp index 768bf615e2..96348902f8 100644 --- a/src/brpc/rdma/rdma_helper.cpp +++ b/src/brpc/rdma/rdma_helper.cpp @@ -65,7 +65,7 @@ int (*IbvQueryQp)(ibv_qp*, ibv_qp_attr*, ibv_qp_attr_mask, ibv_qp_init_attr*) = int (*IbvDestroyQp)(ibv_qp*) = NULL; ibv_comp_channel* (*IbvCreateCompChannel)(ibv_context*) = NULL; int (*IbvDestroyCompChannel)(ibv_comp_channel*) = NULL; -ibv_mr* (*IbvRegMr)(ibv_pd*, void*, size_t, ibv_access_flags) = NULL; +ibv_mr* (*IbvRegMr)(ibv_pd*, void*, size_t, int) = NULL; int (*IbvDeregMr)(ibv_mr*) = NULL; int (*IbvGetCqEvent)(ibv_comp_channel*, ibv_cq**, void**) = NULL; void (*IbvAckCqEvents)(ibv_cq*, unsigned int) = NULL; @@ -178,10 +178,14 @@ void* UserExtendBlockPool(void* region_base, size_t region_size, uint32_t RdmaRegisterMemory(void* buf, size_t size) { // Register the memory as callback in block_pool // The thread-safety should be guaranteed by the caller - ibv_mr* mr = IbvRegMr(g_pd, buf, size, IBV_ACCESS_LOCAL_WRITE); + ibv_mr* mr = IbvRegMr(g_pd, buf, size, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_RELAXED_ORDERING); if (!mr) { - PLOG(ERROR) << "Fail to register memory"; - return 0; + PLOG(WARNING) << "Do not support IBV_ACCESS_RELAXED_ORDERING for RDMA!!!"; + mr = IbvRegMr(g_pd, buf, size, IBV_ACCESS_LOCAL_WRITE); + if (!mr) { + PLOG(ERROR) << "Fail to register memory"; + return 0; + } } g_mrs->push_back(mr); return mr->lkey; @@ -550,6 +554,7 @@ static void GlobalRdmaInitializeOrDieImpl() { } // Initialize RDMA memory pool (block_pool) + butil::SetDefaultBlockSize(GetRdmaBlockSize()); if (!InitBlockPool(RdmaRegisterMemory)) { PLOG(ERROR) << "Fail to initialize RDMA memory pool"; ExitWithError(); @@ -594,10 +599,14 @@ void GlobalRdmaInitializeOrDie() { } uint32_t RegisterMemoryForRdma(void* buf, size_t len) { - ibv_mr* mr = IbvRegMr(g_pd, buf, len, IBV_ACCESS_LOCAL_WRITE); + ibv_mr* mr = IbvRegMr(g_pd, buf, len, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_RELAXED_ORDERING); if (!mr) { - PLOG(ERROR) << "Fail to register memory"; - return 0; + PLOG(WARNING) << "Do not support IBV_ACCESS_RELAXED_ORDERING for RDMA!!!"; + mr = IbvRegMr(g_pd, buf, len, IBV_ACCESS_LOCAL_WRITE); + if (!mr) { + PLOG(ERROR) << "Fail to register memory"; + return 0; + } } { BAIDU_SCOPED_LOCK(*g_user_mrs_lock); diff --git a/src/brpc/rdma_transport.h b/src/brpc/rdma_transport.h index 65ae88f7a6..d8520b1a6d 100644 --- a/src/brpc/rdma_transport.h +++ b/src/brpc/rdma_transport.h @@ -25,9 +25,10 @@ namespace brpc { class RdmaTransport : public Transport { - friend class TransportFactory; - friend class rdma::RdmaEndpoint; - friend class rdma::RdmaConnect; +friend class TransportFactory; +friend class rdma::RdmaEndpoint; +friend class rdma::RdmaConnect; +friend class rdma::RdmaHandshakeServerV2; public: void Init(Socket* socket, const SocketOptions& options) override; void Release() override; @@ -47,7 +48,7 @@ class RdmaTransport : public Transport { private: static bool OptionsAvailableForRdma(const ChannelOptions* opt); static bool OptionsAvailableOverRdma(const ServerOptions* opt); -private: + // The on/off state of RDMA enum RdmaState { RDMA_ON, diff --git a/src/brpc/selective_channel.cpp b/src/brpc/selective_channel.cpp index a59580e321..8dee422598 100644 --- a/src/brpc/selective_channel.cpp +++ b/src/brpc/selective_channel.cpp @@ -419,9 +419,13 @@ void Sender::Clear() { if (_main_cntl == NULL) { return; } - delete _alloc_resources[1].response; - delete _alloc_resources[1].sub_done; - _alloc_resources[1] = Resource(); + for (int i = 0; i < _nalloc; ++i) { + delete _alloc_resources[i].response; + if (_alloc_resources[i].sub_done != &_sub_done0) { + delete _alloc_resources[i].sub_done; + } + _alloc_resources[i] = Resource(); + } const CallId cid = _main_cntl->call_id(); _main_cntl = NULL; if (_user_done) { @@ -434,7 +438,7 @@ inline Resource Sender::PopFree() { if (_nfree == 0) { if (_nalloc == 0) { Resource r; - r.response = _response; + r.response = _response->New(); r.sub_done = &_sub_done0; _alloc_resources[_nalloc++] = r; return r; diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 3a5da7b771..57f665da91 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -754,7 +754,7 @@ static int get_port_from_fd(int fd) { bool Server::CreateConcurrencyLimiter(const AdaptiveMaxConcurrency& amc, ConcurrencyLimiter** out) { - if (amc.type() == AdaptiveMaxConcurrency::UNLIMITED) { + if (amc.type() == AdaptiveMaxConcurrency::UNLIMITED()) { *out = NULL; return true; } @@ -855,26 +855,37 @@ int Server::StartInternal(const butil::EndPoint& endpoint, return -1; } - copy_and_fill_server_options(_options, opt ? *opt : ServerOptions()); + // Validate the user-provided ServerOptions BEFORE + // copy_and_fill_server_options below. This is important: + // copy_and_fill_server_options unconditionally transfers ownership of + // user-provided pointers (nshead_service, thrift_service, ...) into + // _options. If we instead validated against _options after the copy, + // a failed Start() would leave fake/invalid pointers behind in + // _options, and the NEXT Start() would attempt to `delete` them via + // FREE_PTR_IF_NOT_REUSED, crashing (see RdmaTest.server_option_invalid). + const ServerOptions default_opt; + const ServerOptions& real_opt = opt ? *opt : default_opt; - if (!_options.h2_settings.IsValid(true/*log_error*/)) { + if (!real_opt.h2_settings.IsValid(true/*log_error*/)) { LOG(ERROR) << "Invalid h2_settings"; return -1; } - if (_options.bthread_tag < BTHREAD_TAG_DEFAULT || - _options.bthread_tag >= FLAGS_task_group_ntags) { - LOG(ERROR) << "Fail to set tag " << _options.bthread_tag + if (real_opt.bthread_tag < BTHREAD_TAG_DEFAULT || + real_opt.bthread_tag >= FLAGS_task_group_ntags) { + LOG(ERROR) << "Fail to set tag " << real_opt.bthread_tag << ", tag range is [" << BTHREAD_TAG_DEFAULT << ":" << FLAGS_task_group_ntags << ")"; return -1; } - int ret = TransportFactory::ContextInitOrDie(_options.socket_mode, true, &_options); + int ret = TransportFactory::ContextInitOrDie(real_opt.socket_mode, true, &real_opt); if (ret != 0) { LOG(ERROR) << "Fail to initialize transport context for server, ret=" << ret; return -1; } + copy_and_fill_server_options(_options, real_opt); + if (_options.http_master_service) { // Check requirements for http_master_service: // has "default_method" & request/response have no fields @@ -1075,7 +1086,7 @@ int Server::StartInternal(const butil::EndPoint& endpoint, it->second.status->SetConcurrencyLimiter(NULL); } else { const AdaptiveMaxConcurrency* amc = &it->second.max_concurrency; - if (amc->type() == AdaptiveMaxConcurrency::UNLIMITED) { + if (amc->type() == AdaptiveMaxConcurrency::UNLIMITED()) { amc = &_options.method_max_concurrency; } ConcurrencyLimiter* cl = NULL; diff --git a/src/brpc/server.h b/src/brpc/server.h index 9f69a83458..4fbe304fde 100644 --- a/src/brpc/server.h +++ b/src/brpc/server.h @@ -717,7 +717,7 @@ friend class Controller; int SetServiceMaxConcurrency(T* service) { if (NULL != service && NULL != service->_status) { const AdaptiveMaxConcurrency* amc = &service->_max_concurrency; - if (amc->type() == AdaptiveMaxConcurrency::UNLIMITED) { + if (amc->type() == AdaptiveMaxConcurrency::UNLIMITED()) { amc = &_options.method_max_concurrency; } ConcurrencyLimiter* cl = NULL; diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 005873e9b0..7228a0edf0 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -81,6 +81,13 @@ DEFINE_int32(socket_send_buffer_size, -1, DEFINE_int32(ssl_bio_buffer_size, 16*1024, "Set buffer size for SSL read/write"); +DEFINE_int32(ssl_handshake_timeout_ms, 5000, + "Max duration of one SSL handshake on a socket. Zero or negative " + "disables the limit and falls back to waiting forever, which can " + "leak ESTABLISHED sockets if the peer never finishes the TLS " + "handshake (e.g. server not actually listening with SSL)."); +BRPC_VALIDATE_GFLAG(ssl_handshake_timeout_ms, PassValidate); + DEFINE_int64(socket_max_unwritten_bytes, 64 * 1024 * 1024, "Max unwritten bytes in each socket, if the limit is reached," " Socket.Write fails with EOVERCROWDED"); @@ -1265,12 +1272,18 @@ int Socket::Connect(const timespec* abstime, // We need to do async connect (to manage the timeout by ourselves). CHECK_EQ(0, butil::make_non_blocking(sockfd)); if (!_device_name.empty()) { +#ifdef SO_BINDTODEVICE if (setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, _device_name.c_str(), _device_name.size()) < 0) { PLOG(ERROR) << "Fail to set SO_BINDTODEVICE of fd=" << sockfd << " to device_name=" << _device_name; return -1; } +#else + LOG(ERROR) << "SO_BINDTODEVICE (device_name=" << _device_name + << ") is not supported on this platform"; + return -1; +#endif } if (local_side().ip != butil::IP_ANY) { struct sockaddr_storage cli_addr; @@ -1547,8 +1560,7 @@ void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) { g_vars->channel_conn << 1; } if (s->_app_connect) { - s->_app_connect->StartConnect(req->get_socket(), - AfterAppConnected, req); + s->_app_connect->StartConnect(req->get_socket(), AfterAppConnected, req); } else { // Successfully created a connection AfterAppConnected(0, req); @@ -1956,9 +1968,23 @@ int Socket::SSLHandshake(int fd, bool server_mode) { _ssl_state = SSL_CONNECTING; + // Bound the handshake by a deadline; without it, a peer that completes + // the TCP handshake but never returns a TLS Hello (e.g. server not + // configured for SSL) would park this bthread on bthread_fd_wait + // forever. That bthread holds a Socket reference via WriteRequest, so + // the underlying fd would never be recycled and the connection would + // remain ESTABLISHED indefinitely. + const int handshake_timeout_ms = FLAGS_ssl_handshake_timeout_ms; + timespec abstime_storage; + const timespec* abstime = NULL; + if (handshake_timeout_ms > 0) { + abstime_storage = butil::milliseconds_from_now(handshake_timeout_ms); + abstime = &abstime_storage; + } + // Loop until SSL handshake has completed. For SSL_ERROR_WANT_READ/WRITE, - // we use bthread_fd_wait as polling mechanism instead of EventDispatcher - // as it may confuse the origin event processing code. + // we use bthread_fd_timedwait as polling mechanism instead of + // EventDispatcher as it may confuse the origin event processing code. while (true) { ERR_clear_error(); int rc = SSL_do_handshake(_ssl_session); @@ -2004,20 +2030,32 @@ int Socket::SSLHandshake(int fd, bool server_mode) { switch (ssl_error) { case SSL_ERROR_WANT_READ: #if defined(OS_LINUX) - if (bthread_fd_wait(fd, EPOLLIN) != 0) { + if (bthread_fd_timedwait(fd, EPOLLIN, abstime) != 0) { #elif defined(OS_MACOSX) - if (bthread_fd_wait(fd, EVFILT_READ) != 0) { + if (bthread_fd_timedwait(fd, EVFILT_READ, abstime) != 0) { #endif + if (errno == ETIMEDOUT) { + LOG(WARNING) << "SSL handshake timed out after " + << handshake_timeout_ms + << "ms while waiting for peer data on fd=" + << fd << " remote_side=" << _remote_side; + } return -1; } break; case SSL_ERROR_WANT_WRITE: #if defined(OS_LINUX) - if (bthread_fd_wait(fd, EPOLLOUT) != 0) { + if (bthread_fd_timedwait(fd, EPOLLOUT, abstime) != 0) { #elif defined(OS_MACOSX) - if (bthread_fd_wait(fd, EVFILT_WRITE) != 0) { + if (bthread_fd_timedwait(fd, EVFILT_WRITE, abstime) != 0) { #endif + if (errno == ETIMEDOUT) { + LOG(WARNING) << "SSL handshake timed out after " + << handshake_timeout_ms + << "ms while waiting to send on fd=" << fd + << " remote_side=" << _remote_side; + } return -1; } break; @@ -2100,7 +2138,14 @@ ssize_t Socket::DoRead(size_t size_hint) { default: { const unsigned long e = ERR_get_error(); if (nr == 0) { - // Socket EOF or SSL session EOF + if (ssl_error != SSL_ERROR_ZERO_RETURN) { + // Unexpected EOF without proper SSL shutdown (close_notify) + LOG(WARNING) << "Fail to read from ssl_fd=" << fd() + << ": unexpected ssl_error=" << ssl_error; + errno = ESSL; + return -1; + } + // Clean SSL shutdown (close_notify received) } else if (e != 0) { LOG(WARNING) << "Fail to read from ssl_fd=" << fd() << ": " << SSLError(e); diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 816fccdf27..74498d039a 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -56,8 +56,15 @@ class ChannelBalancer; namespace rdma { class RdmaEndpoint; class RdmaConnect; +class RdmaHandshakeClientV2; +class RdmaHandshakeServerV2; +class RdmaHandshakeClientV3; +class RdmaHandshakeServerV3; +} +namespace ubring { + class UBShmEndpoint; + class UBConnect; } - class Socket; class AuthContext; class EventDispatcher; @@ -317,6 +324,13 @@ friend class policy::RtmpContext; friend class schan::ChannelBalancer; friend class rdma::RdmaEndpoint; friend class rdma::RdmaConnect; +friend class ubring::UBShmEndpoint; +friend class ubring::UBConnect; +friend class UBShmTransport; +friend class rdma::RdmaHandshakeClientV2; +friend class rdma::RdmaHandshakeServerV2; +friend class rdma::RdmaHandshakeClientV3; +friend class rdma::RdmaHandshakeServerV3; friend class HealthCheckTask; friend class OnAppHealthCheckDone; friend class HealthCheckManager; diff --git a/src/brpc/socket_mode.h b/src/brpc/socket_mode.h index b5d42be4aa..b4ac7dfbca 100644 --- a/src/brpc/socket_mode.h +++ b/src/brpc/socket_mode.h @@ -20,7 +20,8 @@ namespace brpc { enum SocketMode { SOCKET_MODE_TCP = 0, - SOCKET_MODE_RDMA = 1 + SOCKET_MODE_RDMA = 1, + SOCKET_MODE_UBRING = 2 }; } // namespace brpc #endif //BRPC_SOCKET_MODE_H \ No newline at end of file diff --git a/src/brpc/transport_factory.cpp b/src/brpc/transport_factory.cpp index b689e2edd2..36fdaaed05 100644 --- a/src/brpc/transport_factory.cpp +++ b/src/brpc/transport_factory.cpp @@ -18,6 +18,7 @@ #include "brpc/transport_factory.h" #include "brpc/tcp_transport.h" #include "brpc/rdma_transport.h" +#include "brpc/ubshm_transport.h" namespace brpc { int TransportFactory::ContextInitOrDie(SocketMode mode, bool serverOrNot, const void* _options) { @@ -28,6 +29,11 @@ int TransportFactory::ContextInitOrDie(SocketMode mode, bool serverOrNot, const else if (mode == SOCKET_MODE_RDMA) { return RdmaTransport::ContextInitOrDie(serverOrNot, _options); } +#endif +#if BRPC_WITH_UBRING + else if (mode == SOCKET_MODE_UBRING) { + return UBShmTransport::ContextInitOrDie(serverOrNot, _options); + } #endif else { LOG(ERROR) << "unknown transport type " << mode; @@ -43,6 +49,11 @@ std::unique_ptr TransportFactory::CreateTransport(SocketMode mode) { else if (mode == SOCKET_MODE_RDMA) { return std::unique_ptr(new RdmaTransport()); } +#endif +#if BRPC_WITH_UBRING + else if (mode == SOCKET_MODE_UBRING) { + return std::unique_ptr(new UBShmTransport()); + } #endif else { LOG(ERROR) << "socket_mode set error"; diff --git a/src/brpc/ubshm/common/common.h b/src/brpc/ubshm/common/common.h new file mode 100644 index 0000000000..80e7ad83c8 --- /dev/null +++ b/src/brpc/ubshm/common/common.h @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_COMMON_H +#define BRPC_COMMON_H +#include +#include +#include +#include +#include "butil/logging.h" + +#define LIKELY(x) __builtin_expect(!!(x), 1) +#define UNLIKELY(x) __builtin_expect(!!(x), 0) + +#ifndef UNREFERENCE_PARAM +#define UNREFERENCE_PARAM(x) ((void)(x)) +#endif + +#ifdef UT +#define STATIC +#define INLINE +#define UBRING_STATISTICS_PATH ROOT_PATH "/ubring/run" +#else +#define STATIC static +#define INLINE inline +#define UBRING_STATISTICS_PATH "/opt/ubring/run" +#endif + +#ifdef __cplusplus +#include +using AtomicInt = std::atomic; +using AtomicBool = std::atomic; +using AtomicUintFast64 = std::atomic; +using AtomicUintFast8 = std::atomic; +#define ATOMIC_INIT(var, value) var.store(value) +#define ATOMIC_STORE(var, value) var.store(value) +#define ATOMIC_LOAD(var) var.load() +#define ATOMIC_ADD(var, value) var.fetch_add(value) +#define ATOMIC_SUB(var, value) var.fetch_sub(value) +#define ATOMIC_COMPARE_EXCHANGE_STRONG(var, expected, desired) \ + var.compare_exchange_strong((expected), (desired)) +#else +#include +typedef atomic_int AtomicInt; +typedef atomic_bool AtomicBool; +typedef atomic_uint_fast64_t AtomicUintFast64; +typedef atomic_uint_fast8_t AtomicUintFast8; +#define ATOMIC_INIT(var, value) atomic_init(&(var), value) +#define ATOMIC_STORE(var, value) atomic_store(&(var), value) +#define ATOMIC_LOAD(var) atomic_load(&(var)) +#define ATOMIC_ADD(var, value) atomic_fetch_add(&(var), value) +#define ATOMIC_SUB(var, value) atomic_fetch_sub(&(var), value) +#define ATOMIC_COMPARE_EXCHANGE_STRONG(var, expected, desired) \ + atomic_compare_exchange_strong(&(var), &(expected), (desired)) +#endif + +#define ISB() __asm__ __volatile__("isb" ::: "memory") +#define DSB() __asm__ __volatile__("dsb sy" ::: "memory") + +#ifndef errno_t +typedef int errno_t; +#endif +#ifndef EOK +#define EOK 0 +#endif + +#define MAX_NODE_NUM 8 +#define IPV4_FIRST_BYTE_OFFSET 24 +#define COPY_ALIGNED_DATA_BYTES 64 + +#if defined(OS_MACOSX) +#define EPOLLIN 0x001 +#define EPOLLOUT 0x004 +#define EPOLLET 0x80000000 +#endif + +static inline int Copy64Byte(int8_t *dst, int8_t *src) { +#ifdef LS64 + asm volatile ( + "mov x12, %0\n" + "mov x13, %1\n" + "ldr x4, [x12]\n" + "ldr x5, [x12, #8]\n" + "ldr x6, [x12, #16]\n" + "ldr x7, [x12, #24]\n" + "ldr x8, [x12, #32]\n" + "ldr x9, [x12, #40]\n" + "ldr x10, [x12, #48]\n" + "ldr x11, [x12, #56]\n" + "ST64B x4, [x13]\n" + : + : "r" (src), "r" (dst) + : "memory", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13" + ); + return EOK; +#else + memcpy(dst, src, COPY_ALIGNED_DATA_BYTES); + return EOK; +#endif +} + +#define SEC_TO_NSEC 1000000000 +#define MSEC_TO_NSEC 1000000 +#define USEC_TO_NSEC 1000 +#define MSEC_TO_SEC 1000 +#define MAX_IP_PORT_STR_LEN 23 +#define DECIMAL_BASE 10 + +static inline uint64_t GetCurNanoSeconds(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + uint64_t timestamp = (uint64_t)ts.tv_sec * SEC_TO_NSEC + (uint64_t)ts.tv_nsec; + return timestamp; +} + +#define FREE_PTR(ptr) \ + do { \ + if ((ptr) != NULL) { \ + free(ptr); \ + (ptr) = NULL; \ + } \ + } while (0) + +typedef enum { + UBRING_OK = 0, + UBRING_ERR = -1, + UBRING_RETRY = -2, + UBRING_REENTRY = -3, + UBRING_ERR_TIMEOUT = -4, + SHM_ERR = -100, + SHM_ERR_INPUT_INVALID = -101, + SHM_ERR_EXIST = -102, + SHM_ERR_RESOURCE_ATTACHED = -103, + SHM_ERR_NOT_FOUND = -104, + SHM_ERR_UBSM_NET_ERR = -105, + MPA_UDP_ERR = -200, + MPA_UDP_NO_TRX = -201, + MPA_UDP_STATUS_NOT_JOINED = -202, + MPA_MUXER_NOT_READY = -203, + MPA_PORT_FULL = -204, + MPA_PORT_OUTRANGE = -205, + MPA_PORT_TAKEN = -206, + MPA_UDP_STATUS_NOT_CONNECTED = -207, + MPA_UDP_STATUS_ALREADY_CONNECTED = -208, + MPA_UDP_OLD_RDLIST = -209, + MPA_UDP_RDLIST_FULL = -210, + UBR_NOT_CONNECTED = -300, + UBR_ERR_ADDR_IN_USE = -301, +} RETURN_CODE; + +#define ALIGN_BYTES 0x40 +#define CHECKED_ALIGN_BITS (ALIGN_BYTES - 1) + +static inline size_t Aligned64Offset(uint8_t *addr) { + return ((ALIGN_BYTES - (((size_t)(addr)) & CHECKED_ALIGN_BITS)) & CHECKED_ALIGN_BITS); +} + +static inline RETURN_CODE HasTimedOut(const uint64_t startTime, const uint32_t timeout) { + uint64_t endTime = startTime + (uint64_t)timeout * SEC_TO_NSEC; + if (GetCurNanoSeconds() > endTime) { + LOG(ERROR) << "task time out " << timeout << " seconds."; + return UBRING_ERR; + } + return UBRING_OK; +} + +#endif // BRPC_COMMON_H \ No newline at end of file diff --git a/src/brpc/ubshm/common/thread_lock.h b/src/brpc/ubshm/common/thread_lock.h new file mode 100644 index 0000000000..8c07ce360d --- /dev/null +++ b/src/brpc/ubshm/common/thread_lock.h @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_THREAD_LOCK_H +#define BRPC_THREAD_LOCK_H +#include +#include +#include +#include +#include +#include "brpc/ubshm/common/common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +static inline void UnlockMutex(pthread_mutex_t **mtx) +{ + if (LIKELY(mtx != NULL && *mtx != NULL)) { + pthread_mutex_unlock(*mtx); + } else { + LOG(ERROR) << "Invalid input for mtx."; + } +} + +#define LOCK_GUARD(mtxPtr) \ + pthread_mutex_t *__attribute__((cleanup(UnlockMutex))) _mtxPtr = ({ \ + pthread_mutex_lock(&(mtxPtr)); \ + &(mtxPtr); \ + }) + +static inline void UnlockSpinLock(pthread_spinlock_t **spinLock) +{ + if (LIKELY(spinLock != NULL && *spinLock != NULL)) { + pthread_spin_unlock(*spinLock); + } else { + LOG(ERROR) << "Invalid input for spinLock."; + } +} + +#define SPIN_LOCK_GUARD(spinLockPtr) \ + pthread_spinlock_t *__attribute__((cleanup(UnlockSpinLock))) _spinLockPtr = ({ \ + pthread_spin_lock(&(spinLockPtr)); \ + &(spinLockPtr); \ + }) + +static inline void UnlockRWLock(pthread_rwlock_t **rwLock) +{ + if (LIKELY(rwLock != NULL && *rwLock != NULL)) { + pthread_rwlock_unlock(*rwLock); + } else { + LOG(ERROR) << "Invalid input for rwLock."; + } +} + +#define R_LOCK_GUARD(readLockPtr) \ + pthread_rwlock_t *__attribute__((cleanup(UnlockRWLock))) _readLockPtr = ({ \ + pthread_rwlock_rdlock(&(readLockPtr)); \ + &(readLockPtr); \ + }) + +#define W_LOCK_GUARD(writeLockPtr) \ + pthread_rwlock_t *__attribute__((cleanup(UnlockRWLock))) _writeLockPtr = ({ \ + pthread_rwlock_wrlock(&(writeLockPtr)); \ + &(writeLockPtr); \ + }) + +static inline void PostSemWithClose(sem_t **sem) +{ + if (LIKELY(sem != NULL && *sem != NULL)) { + sem_post(*sem); + sem_close(*sem); + *sem = NULL; + sem = NULL; + } else { + LOG(ERROR) << "Invalid input for semaphore."; + } +} + +static inline void PostSem(sem_t **sem) +{ + if (LIKELY(sem != NULL && *sem != NULL)) { + sem_post(*sem); + } else { + LOG(ERROR) << "Invalid input for semaphore."; + } +} + +#define SEMAPHORE_WAIT_GUARD_WITH_CLOSE(semPtr) \ + sem_t *__attribute__((cleanup(PostSemWithClose))) _semPtr = ({ \ + sem_wait(semPtr); \ + semPtr; \ + }) + +#define SEMAPHORE_WAIT_GUARD(semPtr) \ + sem_t *__attribute__((cleanup(PostSem))) _semPtr = ({ \ + sem_wait(semPtr); \ + semPtr; \ + }) + +#ifdef __cplusplus +} +#endif +#endif //BRPC_THREAD_LOCK_H \ No newline at end of file diff --git a/src/brpc/ubshm/shm/shm_def.h b/src/brpc/ubshm/shm/shm_def.h new file mode 100644 index 0000000000..0c28084b96 --- /dev/null +++ b/src/brpc/ubshm/shm/shm_def.h @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_SHM_DEF_H +#define BRPC_SHM_DEF_H +#include +#include +#include + +#define PROT_READ 0x1 /* Page can be read. */ +#define PROT_WRITE 0x2 /* Page can be written. */ +#define PROT_EXEC 0x4 /* Page can be executed. */ +#define PROT_NONE 0x0 /* Page can not be accessed. */ +#define PROT_GROWSDOWN 0x01000000 /* Extend change to start of growsdown vma (mprotect only). */ +#define PROT_GROWSUP 0x02000000 /* Extend change to start of growsup vma (mprotect only). */ +/* Sharing types (must choose one and only one of these). */ +#define MAP_SHARED 0x01 /* Share changes. */ +#define MAP_PRIVATE 0x02 /* Changes are private. */ +#define SHM_MAX_NAME_BUFF_LEN 48 // byte, buffer size, ubsm_sdk need name to be below 48byte +#define SHM_MAX_NAME_LEN (SHM_MAX_NAME_BUFF_LEN - 1) // byte, string length +#define SHM_ALLOC_UNIT_SIZE (4 * 1024 * 1024) // 4MB + +namespace brpc { +namespace ubring { +typedef enum { SHM_TYPE_UB, SHM_TYPE_IPC, SHM_TYPE_UBS, SHM_TYPE_UNSUPPORT } SHM_TYPE; + +typedef struct { + uint8_t *addr; + size_t len; + uint64_t memid; + char name[SHM_MAX_NAME_BUFF_LEN]; + uint32_t fd; +} SHM; + +typedef struct ShmListNode { + SHM shm; + struct ShmListNode *next; + struct ShmListNode *prev; +} ShmListNode; + +typedef struct { + ShmListNode* head; + ShmListNode* tail; + size_t size; + pthread_mutex_t shmLock; +} ShmList; +} +} +#endif //BRPC_SHM_DEF_H \ No newline at end of file diff --git a/src/brpc/ubshm/shm/shm_ipc.cpp b/src/brpc/ubshm/shm/shm_ipc.cpp new file mode 100644 index 0000000000..a63e9cdd7c --- /dev/null +++ b/src/brpc/ubshm/shm/shm_ipc.cpp @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm/shm/shm_def.h" +#include "brpc/ubshm/shm/shm_ipc.h" + +namespace brpc { +namespace ubring { +namespace { + +RETURN_CODE ReserveIpcShm(int fd, const SHM *shm) +{ +#if defined(__linux__) + const int rc = posix_fallocate(fd, 0, (off_t)shm->len); + if (rc != 0) { + LOG(ERROR) << "IPC reserve shm=" << shm->name << " length=" << shm->len + << " failed, ret(" << rc << ")."; + return SHM_ERR; + } +#else + UNREFERENCE_PARAM(fd); + UNREFERENCE_PARAM(shm); +#endif + return UBRING_OK; +} + +RETURN_CODE CheckIpcShmSize(int fd, const SHM *shm) +{ + struct stat st; + if (fstat(fd, &st) != 0) { + LOG(ERROR) << "IPC stat shm=" << shm->name << " failed, ret(" << errno << ")."; + return SHM_ERR; + } + if ((uint64_t)st.st_size < (uint64_t)shm->len) { + LOG(ERROR) << "IPC shm=" << shm->name << " actual length=" << st.st_size + << " is shorter than requested length=" << shm->len << "."; + return SHM_ERR; + } + return UBRING_OK; +} + +} // namespace + +RETURN_CODE IpcShmLocalMalloc(SHM *shm) +{ + int fd = shm_open(shm->name, O_CREAT | O_EXCL | O_RDWR, SHM_IPC_MODE); + if (fd < 0) { + if (errno == EEXIST) { + LOG(ERROR) << "IPC Create shm=" << shm->name << " failed, shm exists."; + return SHM_ERR_EXIST; + } + + LOG(ERROR) << "IPC Open shm=" << shm->name << " failed, ret(" << errno << ")."; + return SHM_ERR; + } + + int ret = ftruncate(fd, (off_t)shm->len); + if (ret < 0) { + LOG(ERROR) << "IPC Set shm=" << shm->name << " length=" << shm->len << " failed, ret(" << errno << ")."; + close(fd); + shm_unlink(shm->name); + return SHM_ERR; + } + + if (ReserveIpcShm(fd, shm) != UBRING_OK) { + close(fd); + shm_unlink(shm->name); + return SHM_ERR; + } + + shm->addr = (uint8_t*)mmap(NULL, shm->len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (shm->addr == (uint8_t*)MAP_FAILED) { + LOG(ERROR) << "IPC map shm=" << shm->name << " length=" << shm->len << " failed, ret(" << errno << ")."; + shm->addr = NULL; + close(fd); + shm_unlink(shm->name); + return SHM_ERR; + } + + close(fd); + return UBRING_OK; +} + +RETURN_CODE IpcShmMunmap(SHM *shm) +{ + if (shm->addr == NULL) { + LOG(INFO) << "IPC unmap shm=" << shm->name << " already unmapped."; + return UBRING_OK; + } + + int ret = munmap(shm->addr, shm->len); + if (ret != UBRING_OK) { + LOG(ERROR) << "IPC unmap shm=" << shm->name << " failed, errno=" << errno; + return SHM_ERR; + } + + shm->addr = NULL; + LOG(INFO) << "IPC unmap shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE IpcShmFree(SHM *shm) +{ + // free + int ret = shm_unlink(shm->name); + if (ret != UBRING_OK) { + if (errno == EBUSY) { + LOG_EVERY_SECOND(ERROR) << "IPC free shm=" << shm->name << " failed, errno=" << errno; + return SHM_ERR_RESOURCE_ATTACHED; + } + if (errno == ENOENT) { + LOG(INFO) << "IPC free shm=" << shm->name << " already deleted."; + shm->addr = NULL; + return SHM_ERR_NOT_FOUND; + } + LOG_EVERY_SECOND(ERROR) << "IPC free shm=" << shm->name << " failed, errno=" << errno; + return SHM_ERR; + } + return UBRING_OK; +} + +RETURN_CODE IpcShmLocalFree(SHM *shm) +{ + if (shm->addr == NULL) { + LOG(INFO) << "IPC free local shm=" << shm->name << " already freed."; + return SHM_ERR_NOT_FOUND; + } + + int ret = munmap(shm->addr, shm->len); + if (ret != UBRING_OK) { + LOG(WARNING) << "IPC unmap shm=" << shm->name << " failed, ret=" << ret; + } else { + shm->addr = NULL; + } + + ret = shm_unlink(shm->name); + if (ret != UBRING_OK) { + if (errno == EBUSY) { + LOG_EVERY_SECOND(ERROR) << "IPC delete shm=" << shm->name << " failed, ret=" << ret; + return SHM_ERR_RESOURCE_ATTACHED; + } + if (errno == ENOENT) { + LOG(INFO) << "IPC delete shm=" << shm->name << " already deleted by peer."; + shm->addr = NULL; + return SHM_ERR_NOT_FOUND; + } + LOG_EVERY_SECOND(ERROR) << "IPC delete shm=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + shm->addr = NULL; + LOG(INFO) << "IPC free local shm=" << shm->name << " success."; + return UBRING_OK; +} + +RETURN_CODE IpcShmRemoteMalloc(SHM *shm) +{ + int fd = shm_open(shm->name, O_RDWR, SHM_IPC_MODE); + if (fd < 0) { + LOG(ERROR) << "IPC open shm=" << shm->name << " failed, ret=" << errno; + return SHM_ERR; + } + + if (CheckIpcShmSize(fd, shm) != UBRING_OK) { + close(fd); + return SHM_ERR; + } + + shm->addr = (uint8_t*)mmap(NULL, shm->len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (shm->addr == (uint8_t*)MAP_FAILED) { + LOG(ERROR) << "IPC map shm=" << shm->name << " failed, ret=" << errno; + shm->addr = NULL; + close(fd); + return SHM_ERR; + } + + close(fd); + return UBRING_OK; +} + +RETURN_CODE IpcShmLocalMmap(SHM *shm, int prot) +{ + int fd = shm_open(shm->name, O_RDWR, SHM_IPC_MODE); + if (fd < 0) { + LOG(ERROR) << "IPC open shm=" << shm->name << " failed, ret=" << errno; + return SHM_ERR; + } + + if (CheckIpcShmSize(fd, shm) != UBRING_OK) { + close(fd); + return SHM_ERR; + } + + shm->addr = (uint8_t*)mmap(NULL, shm->len, prot, MAP_SHARED, fd, 0); + if (shm->addr == (uint8_t*)MAP_FAILED) { + LOG(ERROR) << "IPC map shm=" << shm->name << " failed, ret=" << errno; + shm->addr = NULL; + close(fd); + return SHM_ERR; + } + + close(fd); + LOG(INFO) << "IPC mmap remote shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE IpcShmRemoteFree(SHM *shm) +{ + if (shm->addr == NULL) { + LOG(INFO) << "IPC free remote shm=" << shm->name << " already freed."; + return UBRING_OK; + } + + int ret = munmap(shm->addr, shm->len); + if (ret != UBRING_OK) { + LOG(ERROR) << "IPC unmap shm=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + + shm->addr = NULL; + LOG(INFO) << "IPC free remote shm=" << shm->name << " success."; + return UBRING_OK; +} +} +} diff --git a/src/brpc/ubshm/shm/shm_ipc.h b/src/brpc/ubshm/shm/shm_ipc.h new file mode 100644 index 0000000000..34e8307bb8 --- /dev/null +++ b/src/brpc/ubshm/shm/shm_ipc.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_SHM_IPC_H +#define BRPC_SHM_IPC_H + +#include "shm_def.h" + +#define SHM_IPC_MODE 0666 + +namespace brpc { +namespace ubring { + RETURN_CODE IpcShmLocalMalloc(SHM *shm); + RETURN_CODE IpcShmMunmap(SHM *shm); + RETURN_CODE IpcShmFree(SHM *shm); + RETURN_CODE IpcShmLocalFree(SHM *shm); + RETURN_CODE IpcShmRemoteMalloc(SHM *shm); + RETURN_CODE IpcShmRemoteFree(SHM *shm); + RETURN_CODE IpcShmLocalMmap(SHM *shm, int prot); +} +} + +#endif //BRPC_SHM_IPC_H \ No newline at end of file diff --git a/src/brpc/ubshm/shm/shm_mgr.cpp b/src/brpc/ubshm/shm/shm_mgr.cpp new file mode 100644 index 0000000000..39a54dd1df --- /dev/null +++ b/src/brpc/ubshm/shm/shm_mgr.cpp @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm/shm/shm_ipc.h" +#include "brpc/ubshm/shm/shm_ubs.h" +#include "brpc/ubshm/shm/shm_mgr.h" + +namespace brpc { +namespace ubring { +DEFINE_int32(ub_shm_type, 1, "shm type: 1-ipc; 2-ub_ring"); +static SHM_TYPE g_shmType; + +static bool CheckInputShmParam(SHM *shm) { + if (shm == NULL) { + LOG(ERROR) << "Input Param shm is NULL."; + return false; + } + + size_t nameLen = strlen(shm->name); + if (nameLen <= 0 || nameLen > SHM_MAX_NAME_LEN) { + LOG(ERROR) << "Shm name=" << shm->name << ", length=" << shm->len + << ", which is not between 1 and " << SHM_MAX_NAME_LEN; + return false; + } + + if (shm->len <= 0) { + LOG(ERROR) << "Shm length=" << shm->len << " is invalid."; + return false; + } + + if (shm->len < SHM_ALLOC_UNIT_SIZE || (shm->len & (SHM_ALLOC_UNIT_SIZE - 1)) != 0) { + LOG(ERROR) << "Shm length=" << shm->len << " need to be (1..n) * 4MB."; + return false; + } + + return true; +} + +RETURN_CODE ShmMgrInit(void) { + if (UNLIKELY(FLAGS_ub_shm_type >= (int32_t)SHM_TYPE_UNSUPPORT || FLAGS_ub_shm_type <= (int32_t)SHM_TYPE_UB)) { + LOG(ERROR) << "Shm type config=" << FLAGS_ub_shm_type << " is not supported."; + return UBRING_ERR; + } + + g_shmType = (SHM_TYPE)FLAGS_ub_shm_type; + if (g_shmType == SHM_TYPE_UBS) { + if (UbsShmInit() != UBRING_OK) { + LOG(ERROR) << "Init beiming ubs shm failed."; + return UBRING_ERR; + } + } + LOG(INFO) << "shm mgr init success, shm type=" << g_shmType; + return UBRING_OK; +} + +void ShmMgrFini(void) { + if (g_shmType == SHM_TYPE_UBS) { + if (UbsShmFini() != UBRING_OK) { + LOG(ERROR) << "Fini beiming ubs shm failed."; + return; + } + } + LOG(INFO) << "shm mgr fini success, shm type=" << g_shmType; +} + +void SetShmType(SHM_TYPE type) { + g_shmType = type; +} + +RETURN_CODE ShmLocalMalloc(SHM *shm) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmLocalMalloc(shm); + break; + case SHM_TYPE_UBS: + rc = UbsShmLocalMalloc(shm); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} + +RETURN_CODE ShmLocalCalloc(SHM *shm) { + RETURN_CODE rc = ShmLocalMalloc(shm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Failed to alloc local shm."; + return rc; + } + if (UNLIKELY(shm->addr == NULL)) { + LOG(ERROR) << "Local shm=" << shm->name << " allocated with NULL address."; + ShmFree(shm); + return SHM_ERR; + } + memset(shm->addr, 0, shm->len); + return UBRING_OK; +} + +RETURN_CODE ShmLocalFree(SHM *shm) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmLocalFree(shm); + break; + case SHM_TYPE_UBS: + rc = UbsShmLocalFree(shm); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} + +RETURN_CODE ShmRemoteMalloc(SHM *shm) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmRemoteMalloc(shm); + break; + case SHM_TYPE_UBS: + rc = UbsShmRemoteMalloc(shm); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} + +RETURN_CODE ShmRemoteFree(SHM *shm) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmRemoteFree(shm); + break; + case SHM_TYPE_UBS: + rc = UbsShmRemoteFree(shm); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} + +RETURN_CODE ShmLocalMmap(SHM *shm, int prot) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmLocalMmap(shm, prot); + break; + case SHM_TYPE_UBS: + rc = UbsShmLocalMmap(shm, prot); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} + +RETURN_CODE ShmMunmap(SHM *shm) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmMunmap(shm); + break; + case SHM_TYPE_UBS: + rc = UbsShmMunmap(shm); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} + +RETURN_CODE ShmFree(SHM *shm) { + if (UNLIKELY(!CheckInputShmParam(shm))) { + LOG(ERROR) << "Input param shm is invalid."; + return SHM_ERR_INPUT_INVALID; + } + + RETURN_CODE rc = UBRING_OK; + switch (g_shmType) { + case SHM_TYPE_IPC: + rc = IpcShmFree(shm); + break; + case SHM_TYPE_UBS: + rc = UbsShmFree(shm); + break; + default: + rc = SHM_ERR; + LOG(ERROR) << "Unsupported shm type."; + } + return rc; +} +} +} diff --git a/src/brpc/ubshm/shm/shm_mgr.h b/src/brpc/ubshm/shm/shm_mgr.h new file mode 100644 index 0000000000..597f5e4ba5 --- /dev/null +++ b/src/brpc/ubshm/shm/shm_mgr.h @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_SHM_MGR_H +#define BRPC_SHM_MGR_H + +#include +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm/shm/shm_def.h" + +namespace brpc { +namespace ubring { +void SetShmType(SHM_TYPE type); + +RETURN_CODE ShmMgrInit(void); + +void ShmMgrFini(void); + +RETURN_CODE ShmLocalMalloc(SHM *shm); + +RETURN_CODE ShmLocalCalloc(SHM *shm); + +RETURN_CODE ShmLocalFree(SHM *shm); + +RETURN_CODE ShmRemoteMalloc(SHM *shm); + +RETURN_CODE ShmRemoteFree(SHM *shm); + +RETURN_CODE ShmLocalMmap(SHM *shm, int prot); + +RETURN_CODE ShmMunmap(SHM *shm); + +RETURN_CODE ShmFree(SHM *shm); +} +} + +#endif //BRPC_SHM_MGR_H \ No newline at end of file diff --git a/src/brpc/ubshm/shm/shm_ubs.cpp b/src/brpc/ubshm/shm/shm_ubs.cpp new file mode 100644 index 0000000000..537d8e91aa --- /dev/null +++ b/src/brpc/ubshm/shm/shm_ubs.cpp @@ -0,0 +1,565 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include "brpc/ubshm/timer/timer_mgr.h" +#include "brpc/ubshm/common/thread_lock.h" +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm/shm/shm_def.h" +#include "brpc/ubshm/ub_ring_manager.h" +#include "brpc/ubshm/ubs_mem/ubs_mem.h" +#include "brpc/ubshm/ubs_mem/ubs_mem_def.h" +#ifdef UT +#include "ubs_mem.h" +#endif +#include "shm_ubs.h" + +namespace brpc { +namespace ubring { +#define UBRING_MK_UBSM(ret, fn, args) ret (*fn) args = NULL +#include "brpc/ubshm/ubs_mem/declare_shm_ubs.h" +#define SHM_RIGHT_MODE 0666 +#define UBRING_REGION_NAME_PREFIX "UbrONE2ALLRegion" +DEFINE_uint32(node_location, 1, "Location of the ub machine."); +DEFINE_bool(shm_wr_delay_comp, true, "Indicates whether to enable the write relay." + "0: relay; 1: non-relay."); +DEFINE_int32(ub_flying_io_timeout, 5, "Waiting time for stopping data" + "sending and receiving when the link is disconnected."); +char g_regionName[MAX_REGION_NAME_DESC_LENGTH] = {0}; +int g_shmTimerFd = 0; +ShmList *g_shmList = NULL; +static RETURN_CODE UbsShmInterfacesLoad(void); +char hostname[MAX_HOST_NAME_DESC_LENGTH]; + +RETURN_CODE UbsShmInterfacesLoad(void) +{ +#ifndef UT + const char *ubsmSdkLocation = "/usr/local/ubs_mem/lib/libubsm_sdk.so"; +#if defined(OS_LINUX) + void* dlhandler = dlmopen(LM_ID_NEWLM, ubsmSdkLocation, RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE | RTLD_DEEPBIND); +#elif defined(OS_MACOSX) + void* dlhandler = dlopen(ubsmSdkLocation, RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE); +#endif + if (dlhandler == NULL) { + LOG(ERROR) << "Dlopen libubsm_sdk.so in " << ubsmSdkLocation << " failed, error:" << dlerror(); + return UBRING_ERR; + } + +#define UBRING_MK_UBSM_OPTIONAL(ret, fn, args) \ + do { \ + fn = (decltype(fn))dlsym(dlhandler, #fn); \ + } while (0) + +#define UBRING_MK_UBSM(ret, fn, args) \ + do { \ + if ((fn) != NULL) { \ + break; \ + } \ + UBRING_MK_UBSM_OPTIONAL(ret, fn, args); \ + if ((fn) == NULL) { \ + LOG(ERROR) << "Fail load ubs_mem func " << #fn <<" error:" << dlerror(); \ + return UBRING_ERR; \ + } \ + } while (0) +#include "brpc/ubshm/ubs_mem/declare_shm_ubs.h" + + dlclose(dlhandler); + dlhandler = NULL; +#endif + return UBRING_OK; +} + +static RETURN_CODE CreateUbsShmRegion(const char *regionName) +{ + int ret = snprintf(g_regionName, MAX_REGION_NAME_DESC_LENGTH, "%s_%u", + UBRING_REGION_NAME_PREFIX, FLAGS_node_location); + if (ret < 0) { + LOG(ERROR) << "Snprintf_s region name failed, ret=" << ret; + return UBRING_ERR; + } + + ubsmem_regions_t regions = {0}; // 16 * (48 + 1) bytes, 约0.8k + ret = ubsmem_lookup_regions(®ions); + if (ret != UBSM_OK || regions.region[0].host_num <= 0) { + LOG(ERROR) << "Ubs lookup share region failed, ret=" << ret << ", region.num=" << regions.region[0].host_num; + return UBRING_ERR; + } + ubsmem_region_attributes_t regionAttr = {0}; + regionAttr.host_num = regions.region[0].host_num; + for (int i = 0; i < regionAttr.host_num; i++) { + strcpy(regionAttr.hosts[i].host_name, regions.region[0].hosts[i].host_name); + regionAttr.hosts[i].affinity = (strcmp(regionAttr.hosts[i].host_name, hostname) == 0) ? + true : false; + } + + ret = ubsmem_create_region(regionName, 0, ®ionAttr); + if (ret == UBSM_ERR_ALREADY_EXIST) { + LOG(WARNING) << "Ubs region exists, region_name=" << regionName; + return UBRING_OK; + } else if (ret != UBSM_OK) { + LOG(ERROR) << "Ubsmem create region failed, ret=" << ret; + return UBRING_ERR; + } + + return UBRING_OK; +} + +static uint64_t AquireFlagIfWrDelayComp(const uint64_t flag) +{ + if (FLAGS_shm_wr_delay_comp == 0) { + return flag; + } + return flag | UBSM_FLAG_WR_DELAY_COMP; +} + +RETURN_CODE UbsShmLocalMalloc(SHM *shm) +{ + int ret = ubsmem_shmem_allocate(g_regionName, shm->name, shm->len, SHM_RIGHT_MODE, + AquireFlagIfWrDelayComp(UBSM_FLAG_ONLY_IMPORT_NONCACHE | UBSM_FLAG_MEM_ANONYMOUS)); +do { + if (ret == UBSM_ERR_ALREADY_EXIST) { + if (ubsmem_shmem_deallocate(shm->name) != UBSM_OK) { + LOG(ERROR) << "Ubs create shm name=" << shm->name << " failed, shm exists, ret=" << ret; + return SHM_ERR_EXIST; + } + LOG(INFO) << "Ubs delete shm name=" << shm->name << " success, try to recreate."; + ret = ubsmem_shmem_allocate(g_regionName, shm->name, shm->len, SHM_RIGHT_MODE, + AquireFlagIfWrDelayComp(UBSM_FLAG_ONLY_IMPORT_NONCACHE | UBSM_FLAG_MEM_ANONYMOUS)); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs recreate shm name=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + } else if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs create shm name=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } +} while (0); + + ret = ubsmem_shmem_map(NULL, shm->len, PROT_READ | PROT_WRITE, MAP_SHARED, shm->name, 0, (void**)&(shm->addr)); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs map shm=" << shm->name << " failed, ret=" << ret; + if (ret == UBSM_ERR_NOT_FOUND) { + return SHM_ERR_NOT_FOUND; + } + ubsmem_shmem_deallocate(shm->name); + return SHM_ERR; + } + + // 通过MXE获取memid + shm->memid = 1; // 暂时打桩 + LOG(INFO) << "Ubs malloc local shm=" << shm->name << " length=" << shm->len << " memid=" << shm->memid << " success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmMunmap(SHM *shm) +{ + // unmap + if (shm->addr == NULL) { + LOG(ERROR) << "Ubs input shm param is invalid, addr is NULL."; + return SHM_ERR_INPUT_INVALID; + } + + int ret = ubsmem_shmem_unmap(shm->addr, shm->len); + if (ret != UBSM_OK) { + if (ret == UBSM_ERR_NET) { + LOG(ERROR) << "Ubs unmap shm=" << shm->name << " failed, ubsm net err=" << ret; + AddShmToList(g_shmList, shm); + return SHM_ERR_UBSM_NET_ERR; + } + LOG(ERROR) << "Ubs unmap shm=" << shm->name << " length=" << shm->len << " failed, ret=" << ret; + return SHM_ERR; + } + + LOG(INFO) << "Ubs unmap shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmFree(SHM *shm) +{ + if (shm->addr == NULL) { + LOG(ERROR) << "Ubs input shm param is invalid, addr is NULL."; + return SHM_ERR_INPUT_INVALID; + } + + // free + int ret = ubsmem_shmem_deallocate(shm->name); + if (ret != UBSM_OK) { + if (ret == UBSM_ERR_IN_USING) { + LOG(INFO) << "Ubs free shm=" << shm->name << " failed, resource attached=" << ret; + return SHM_ERR_RESOURCE_ATTACHED; + } else if (ret == UBSM_ERR_NOT_FOUND) { + LOG(INFO) << "Ubs free shm=" << shm->name << " failed, resource not found=" << ret; + return SHM_ERR_NOT_FOUND; + } + LOG(ERROR) << "Ubs free shm="<< shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + shm->addr = NULL; + LOG(INFO) << "Ubs free shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmLocalFree(SHM *shm) +{ + // unmap + if (shm->addr == NULL) { + LOG(ERROR) << "Ubs input shm param is invalid, addr is NULL."; + return SHM_ERR_INPUT_INVALID; + } + + int ret = ubsmem_shmem_unmap(shm->addr, shm->len); + if (ret != UBSM_OK) { + if (ret == UBSM_ERR_NET) { + LOG(ERROR) << "Ubs unmap shm=" << shm->name << " failed, ubsm net err=" << ret; + AddShmToList(g_shmList, shm); + return SHM_ERR_UBSM_NET_ERR; + } + LOG(WARNING) << "Ubs unmap shm=" << shm->name << " length=" << shm->len << " failed, ret=" << ret; + } + + // free + ret = ubsmem_shmem_deallocate(shm->name); + if (ret != UBSM_OK) { + if (ret == UBSM_ERR_IN_USING) { + LOG_EVERY_SECOND(INFO) << "Ubs delete shm=" << shm->name << " failed, resource attached=" << ret; + return SHM_ERR_RESOURCE_ATTACHED; + } + LOG(ERROR) << "Ubs delete shm=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + shm->addr = NULL; + LOG(INFO) << "Ubs free local shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmRemoteMalloc(SHM *shm) +{ + int ret = ubsmem_shmem_map(NULL, shm->len, PROT_READ | PROT_WRITE, MAP_SHARED, shm->name, 0, (void**)&(shm->addr)); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs map Shm=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + + LOG(INFO) << "Ubs malloc remote shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmLocalMmap(SHM *shm, int prot) +{ + int ret = ubsmem_shmem_map(NULL, shm->len, prot, MAP_SHARED, shm->name, 0, (void**)&(shm->addr)); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs map Shm=" << shm->name << " failed, ret=" << ret; + return SHM_ERR; + } + + LOG(INFO) << "Ubs mmap remote shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmRemoteFree(SHM *shm) +{ + // unmap + if (shm->addr == NULL) { + LOG(ERROR) << "Ubs input shm param is invalid, addr is NULL."; + return SHM_ERR_INPUT_INVALID; + } + + int ret = ubsmem_shmem_unmap(shm->addr, shm->len); + if (ret != UBSM_OK) { + if (ret == UBSM_ERR_NET) { + LOG(ERROR) << "Ubs unmap shm=" << shm->name << " failed, ubsm net err=" << ret; + AddShmToList(g_shmList, shm); + return SHM_ERR_UBSM_NET_ERR; + } + LOG(ERROR) << "Ubs unmap shm=" << shm->name << " length=" << shm->len << " failed, ret=" << ret; + return SHM_ERR; + } + + LOG(INFO) << "Ubs free Remote shm=" << shm->name << " length=" << shm->len << " success."; + return UBRING_OK; +} + +void UbsMemLoggerPrint(int level, const char *msg) +{ + if (level == UBSM_LOG_ERROR_LEVEL) { + LOG(ERROR) << msg; + } else if (level == UBSM_LOG_WARN_LEVEL) { + LOG(WARNING) << msg; + } else { + LOG(INFO) << msg; + } + return; +} + +RETURN_CODE UbsShmInit(void) +{ + // 加载libubsm_sdk.so函数指针 + RETURN_CODE retCode = UbsShmInterfacesLoad(); + if (retCode != UBRING_OK) { + LOG(ERROR) << "Load ubs shm functions failed, ret=" << retCode; + return UBRING_ERR; + } + + if (gethostname(hostname, MAX_HOST_NAME_DESC_LENGTH) != 0) { + LOG(ERROR) << "ubring config gethostname failed, errno=" << errno; + return UBRING_ERR; + } + + int ret = ubsmem_set_extern_logger(UbsMemLoggerPrint); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs set logger failed, ret=" << ret; + return UBRING_ERR; + } + + ret = ubsmem_set_logger_level(UBSM_LOG_INFO_LEVEL); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs set logger level failed, ret=" << ret; + return UBRING_ERR; + } + + ubsmem_options_t options = {}; + ret = ubsmem_init_attributes(&options); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs shm init attributes failed, ret=" << ret; + return UBRING_ERR; + } + + ret = ubsmem_initialize(&options); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs shm initialize failed, ret=" << ret; + return UBRING_ERR; + } + + if (UNLIKELY(ubsmem_local_nid_query(&FLAGS_node_location) != UBSM_OK)) { + LOG(ERROR) << "Get local nid failed."; + return UBRING_ERR; + } + + if (UNLIKELY(ubsmem_shmem_faults_register(brpc::ubring::UBRingManager::UbEventCallback) != UBSM_OK)) { + LOG(ERROR) << "Failed to register the ub event callback function."; + return UBRING_ERR; + } + + if (CreateUbsShmRegion(g_regionName) != UBRING_OK) { + LOG(ERROR) << "Create Ubs region failed."; + return UBRING_ERR; + } + + if (InitShmTimer(&g_shmList) != UBRING_OK) { + LOG(ERROR) << "Ubs shm list init failed."; + return UBRING_ERR; + } + + LOG(INFO) << "Ubs shm init success."; + return UBRING_OK; +} + +RETURN_CODE UbsShmFini(void) +{ + int ret = ubsmem_finalize(); + if (ret != UBSM_OK) { + LOG(ERROR) << "Ubs shm finalize fail, ret=" << ret; + return UBRING_ERR; + } + + if (UNLIKELY(DestroyShmTimer(g_shmList) != UBRING_OK)) { + LOG(ERROR) << "Ubs shm list finalize failed."; + return UBRING_ERR; + } + + LOG(INFO) << "Ubs shm finalize success."; + return UBRING_OK; +} + +static void DeleteShmToList(ShmList* shmList) +{ + if (shmList == NULL || shmList->head == NULL) { + return; + } + + ShmListNode *curNode = shmList->head; + shmList->head = curNode->next; + if (shmList->head != NULL) { + shmList->head->prev = NULL; + } else { + shmList->tail = NULL; + } + LOG(INFO) << "Delete shm to list, name=" << curNode->shm.name << " size=" << shmList->size; + FREE_PTR(curNode); + shmList->size--; +} + +void *UbsShmCallback(void* args) +{ + ShmList *shmList = (ShmList*)args; + if (UNLIKELY(shmList == NULL)) { + LOG(ERROR) << "Shm list is null."; + return NULL; + } + + LOCK_GUARD(shmList->shmLock); + while (shmList->head != NULL) { + SHM shm = shmList->head->shm; + if (shm.addr == NULL) { + LOG(ERROR) << "Ubs input shm param is invalid, addr is NULL."; + return NULL; + } + + int ret = ubsmem_shmem_unmap(shm.addr, shm.len); + if (ret != UBSM_OK) { + if (ret == UBSM_ERR_NET) { + return NULL; + } + LOG(ERROR) << "Ubs unmap shm=" << shm.name << " length=" << shm.len << " failed, ret=" << ret; + return NULL; + } + LOG(INFO) << "Ubs unmap shm=" << shm.name << " length=" << shm.len << " success."; + + ret = ubsmem_shmem_deallocate(shm.name); + if (ret != UBSM_OK) { + DeleteShmToList(shmList); + LOG(ERROR) << "Ubs delete shm=" << shm.name << " failed, ret=" << ret; + return NULL; + } + DeleteShmToList(shmList); + LOG(INFO) << "Ubs free local shm=" << shm.name << " length=" << shm.len << " success."; + } + + return NULL; +} + +RETURN_CODE UbsShmAddTimer(ShmList *shmList) +{ + uint32_t timerInterval = FLAGS_ub_flying_io_timeout; + itimerspec timeSpec = { + .it_interval = {.tv_sec = timerInterval, .tv_nsec = 0}, + .it_value = {.tv_sec = 0, .tv_nsec = 1} + }; + int timerFd = TimerStart(&timeSpec, UbsShmCallback, (void*)shmList); + if (UNLIKELY(timerFd == -1)) { + LOG(ERROR) << "Start shm timer failed."; + return UBRING_ERR; + } + g_shmTimerFd = timerFd; + + return UBRING_OK; +} + +RETURN_CODE InitShmTimer(ShmList **shmList) +{ + *shmList = (ShmList *)malloc(sizeof(ShmList)); + if (*shmList == NULL) { + LOG(ERROR) << "Malloc shm list failed."; + return UBRING_ERR; + } + (*shmList)->head = NULL; + (*shmList)->tail = NULL; + (*shmList)->size = 0; + + if (pthread_mutex_init(&(*shmList)->shmLock, NULL) != 0) { + LOG(ERROR) << "Init shm list mutex failed."; + FREE_PTR(*shmList); + return UBRING_ERR; + } + + if (UbsShmAddTimer(*shmList) == UBRING_ERR) { + LOG(ERROR) << "Ubs add timer failed."; + FREE_PTR(*shmList); + return UBRING_ERR; + } + return UBRING_OK; +} + +RETURN_CODE DestroyShmTimer(ShmList *shmList) +{ + DeleteTimerSafe((uint32_t)g_shmTimerFd); + if (shmList == NULL) { + LOG(WARNING) << "Shm list is null."; + return UBRING_ERR; + } + ShmListNode* current = shmList->head; + ShmListNode* next; + + while (current != NULL) { + next = current->next; + free(current); + current = next; + } + pthread_mutex_destroy(&shmList->shmLock); + FREE_PTR(shmList); + return UBRING_OK; +} + +RETURN_CODE IsExistInShmList(ShmList *shmList, const SHM *shm) +{ + if (UNLIKELY(shmList == NULL || shm == NULL)) { + LOG(ERROR) << "Shm list or shm is null."; + return UBRING_ERR; + } + LOCK_GUARD(shmList->shmLock); + + ShmListNode *curNode = shmList->head; + while (curNode != NULL) { + if (strcmp(curNode->shm.name, shm->name) == 0 && curNode->shm.len == shm->len) { + return UBRING_OK; + } + curNode = curNode->next; + } + return UBRING_ERR; +} + +RETURN_CODE AddShmToList(ShmList *shmList, SHM *shm) +{ + if (shmList == NULL || shm == NULL) { + LOG(ERROR) << "Shm list or shm is null."; + return UBRING_ERR; + } + + if (IsExistInShmList(shmList, shm) == UBRING_OK) { + LOG(ERROR) << "Shm name=" << shm->name << " is exist in shm list."; + return UBRING_ERR; + } + + ShmListNode *newShmNode = (ShmListNode *)malloc(sizeof(ShmListNode)); + if (newShmNode == NULL) { + LOG(ERROR) << "Malloc shm node failed."; + return UBRING_ERR; + } + + memcpy(&newShmNode->shm, shm, sizeof(SHM)); + LOCK_GUARD(shmList->shmLock); + newShmNode->next = NULL; + newShmNode->prev = shmList->tail; + if (shmList->tail) { + shmList->tail->next = newShmNode; + shmList->tail = newShmNode; + } else { + shmList->head = newShmNode; + shmList->tail = newShmNode; + } + shmList->size++; + LOG(INFO) << "Add shm to list success, shm name=" << shm->name << " size=" << shmList->size; + return UBRING_OK; +} +} +} \ No newline at end of file diff --git a/src/brpc/ubshm/shm/shm_ubs.h b/src/brpc/ubshm/shm/shm_ubs.h new file mode 100644 index 0000000000..14b5916503 --- /dev/null +++ b/src/brpc/ubshm/shm/shm_ubs.h @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_SHM_UBS_H +#define BRPC_SHM_UBS_H +namespace brpc { +namespace ubring { +DECLARE_int32(ub_flying_io_timeout); + +typedef enum TagUbsLogLevel { + UBSM_LOG_DEBUG_LEVEL = 0, + UBSM_LOG_INFO_LEVEL = 1, + UBSM_LOG_WARN_LEVEL = 2, + UBSM_LOG_ERROR_LEVEL = 3, + UBSM_LOG_CLOSED_LEVEL = 4 +} UbsLogLevel; + +RETURN_CODE UbsShmLocalMalloc(SHM *shm); +RETURN_CODE UbsShmMunmap(SHM *shm); +RETURN_CODE UbsShmFree(SHM *shm); +RETURN_CODE UbsShmLocalFree(SHM *shm); +RETURN_CODE UbsShmRemoteMalloc(SHM *shm); +RETURN_CODE UbsShmRemoteFree(SHM *shm); +RETURN_CODE UbsShmInit(void); +RETURN_CODE UbsShmFini(void); +RETURN_CODE UbsShmLocalMmap(SHM *shm, int prot); +void UbsMemLoggerPrint(int level, const char *msg); + +void *UbsShmCallback(void* args); +RETURN_CODE UbsShmAddTimer(ShmList *shmList); +RETURN_CODE InitShmTimer(ShmList **shmList); +RETURN_CODE DestroyShmTimer(ShmList *shmList); +RETURN_CODE AddShmToList(ShmList *shmList, SHM *shm); +RETURN_CODE IsExistInShmList(ShmList *shmList, const SHM *shm); +} +} +#endif //BRPC_SHM_UBS_H \ No newline at end of file diff --git a/src/brpc/ubshm/timer/timer_mgr.cpp b/src/brpc/ubshm/timer/timer_mgr.cpp new file mode 100644 index 0000000000..e53833f95e --- /dev/null +++ b/src/brpc/ubshm/timer/timer_mgr.cpp @@ -0,0 +1,468 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/ubshm/timer/timer_mgr.h" + +namespace brpc { +namespace ubring { + +int32_t g_epollFd = -1; +std::atomic g_totalTimerNum; +TimerFdCtx *g_timerFdCtxMap = NULL; +uint32_t maxSystemFd; +static pthread_t g_epollExecuteThread; +static int32_t g_timerModuleInitialized; + +#if defined(OS_MACOSX) +static int timerfd_create_macosx(int clockid, int flags); +static int timerfd_settime_macosx(int fd, int flags, + const itimerspec *new_value, + itimerspec *old_value); +#endif + +static RETURN_CODE DeleteTimerInner(uint32_t fd) { + if (g_timerFdCtxMap == NULL) { + return UBRING_OK; + } + + if (pthread_spin_lock(&g_timerFdCtxMap[fd].spinLock) != 0) { + return UBRING_ERR; + } + + if (g_timerFdCtxMap[fd].status == TIMER_CONTEXT_NOT_USING) { + pthread_spin_unlock(&g_timerFdCtxMap[fd].spinLock); + return UBRING_OK; + } + + g_timerFdCtxMap[fd].status = TIMER_CONTEXT_NOT_USING; + g_timerFdCtxMap[fd].cb = NULL; + g_timerFdCtxMap[fd].args = NULL; + g_timerFdCtxMap[fd].periodical = 0; + g_timerFdCtxMap[fd].fd = 0; + + pthread_spin_unlock(&g_timerFdCtxMap[fd].spinLock); + +#if defined(OS_LINUX) + epoll_ctl(g_epollFd, EPOLL_CTL_DEL, (int)fd, NULL); +#elif defined(OS_MACOSX) + struct kevent evt; + EV_SET(&evt, fd, EVFILT_TIMER, EV_DELETE, 0, 0, NULL); + kevent(g_epollFd, &evt, 1, NULL, 0, NULL); +#endif + + uint64_t exp = 0; + read((int)fd, &exp, sizeof(exp)); + + close((int)fd); + atomic_fetch_sub(&g_totalTimerNum, 1); + return UBRING_OK; +} + +static RETURN_CODE StartTimeEpoll(void) { +#if defined(OS_LINUX) + g_epollFd = epoll_create1(0); +#elif defined(OS_MACOSX) + g_epollFd = kqueue(); +#endif + if (UNLIKELY(g_epollFd == -1)) { + LOG(ERROR) << "Failed to create epoll/kqueue. errno=" << errno; + return UBRING_ERR; + } + + int ret = pthread_create(&g_epollExecuteThread, NULL, TimerEpoll, NULL); + if (UNLIKELY(ret != 0)) { + LOG(ERROR) << "Failed to create thread err=" << ret; + return UBRING_ERR; + } + return UBRING_OK; +} + +static RETURN_CODE TimerSpinLocksInit(void) { + if (g_timerFdCtxMap == NULL) { + LOG(ERROR) << "Timer module is not fully initialized."; + return UBRING_ERR; + } + + for (uint32_t fd = 0; fd < maxSystemFd; fd++) { + int ret = pthread_spin_init(&g_timerFdCtxMap[fd].spinLock, + PTHREAD_PROCESS_PRIVATE); + if (ret != EOK) { + LOG(ERROR) << "Failed to initialize spin lock for fd=" << fd; + for (uint32_t cleanupFd = 0; cleanupFd < fd; cleanupFd++) { + pthread_spin_destroy(&g_timerFdCtxMap[cleanupFd].spinLock); + } + return UBRING_ERR; + } + } + return UBRING_OK; +} + +static RETURN_CODE ExecuteCallback(int32_t timerFd) { + UnifiedCallback((void *)(&g_timerFdCtxMap[timerFd])); + return UBRING_OK; +} + +static RETURN_CODE TimerCtxMapCompletion(void) { + memset(g_timerFdCtxMap, 0, sizeof(TimerFdCtx) * maxSystemFd); + + RETURN_CODE ret = TimerSpinLocksInit(); + if (ret != UBRING_OK) { + LOG(ERROR) << "Failed to init spin locks for timer module."; + return UBRING_ERR; + } + return UBRING_OK; +} + +RETURN_CODE TimerInit(void) { + if (g_timerModuleInitialized > 0) { + return UBRING_OK; + } + + g_totalTimerNum.store(0); + + struct rlimit rlim; + if (getrlimit(RLIMIT_NOFILE, &rlim) != UBRING_OK) { + LOG(ERROR) << "Failed to get fd"; + return UBRING_ERR; + } + maxSystemFd = (uint32_t)rlim.rlim_cur; + + if (g_timerFdCtxMap == NULL) { + g_timerFdCtxMap = (TimerFdCtx *)malloc(sizeof(TimerFdCtx) * maxSystemFd); + if (UNLIKELY(!g_timerFdCtxMap)) { + LOG(ERROR) << "Fail to malloc space for timer modules. errno=%d", errno; + return UBRING_ERR; + } + + RETURN_CODE ret = TimerCtxMapCompletion(); + if (ret != UBRING_OK) { + LOG(ERROR) << "Failed to init main data structure of Time Module. ret=" << ret; + free(g_timerFdCtxMap); + g_timerFdCtxMap = NULL; + return UBRING_ERR; + } + } + + RETURN_CODE ret = StartTimeEpoll(); + if (ret != UBRING_OK) { + LOG(ERROR) << "Failed to start Timer Epoll. ret=" << ret; + if (LIKELY(g_timerFdCtxMap != NULL)) { + FREE_PTR(g_timerFdCtxMap); + } + return UBRING_ERR; + } + g_timerModuleInitialized = 1; + return UBRING_OK; +} + +void *UnifiedCallback(void *args) { + TimerFdCtx *ctx = (TimerFdCtx *)args; + if (pthread_spin_lock(&ctx->spinLock) != 0) { + return NULL; + } + + if (ctx->status == TIMER_CONTEXT_NOT_USING) { + pthread_spin_unlock(&ctx->spinLock); + return NULL; + } + + void *(*cb)(void *) = ctx->cb; + void *cbArgs = ctx->args; + uint32_t fd = ctx->fd; + int isPeriodical = ctx->periodical; + ctx->status = TIMER_CONTEXT_CALLBACK_ONGOING; + + pthread_spin_unlock(&ctx->spinLock); + + cb(cbArgs); + + if (!isPeriodical) { + DeleteTimerInner(fd); + } + return NULL; +} + +void *TimerEpoll(void *args) { + UNREFERENCE_PARAM(args); +#if defined(OS_LINUX) + struct epoll_event readyEvents[MAX_TIMER]; +#elif defined(OS_MACOSX) + struct kevent readyEvents[MAX_TIMER]; +#endif + + while (1) { + if (g_timerModuleInitialized <= 0) { + LOG(ERROR) << "The Timer module is not initialized."; + break; + } + +#if defined(OS_LINUX) + int32_t readyNum = epoll_wait(g_epollFd, readyEvents, MAX_TIMER, + TIMER_EPOLL_WAIT_TIMEOUT); +#elif defined(OS_MACOSX) + struct timespec timeout = {0, TIMER_EPOLL_WAIT_TIMEOUT * 1000000}; + int32_t readyNum = kevent(g_epollFd, NULL, 0, readyEvents, MAX_TIMER, &timeout); +#endif + + if (UNLIKELY(readyNum == -1)) { + errno_t err = errno; + if (err == EINTR) { + LOG_EVERY_SECOND(WARNING) << "Epoll/Kqueue wait was interrupted. errno=" << err; + continue; + } else if (err == EBADF) { + LOG(WARNING) << "The Timer module is destroyed."; + break; + } + LOG(ERROR) << "Epoll/Kqueue wait internal error. errno=" << err; + break; + } + + for (int32_t i = 0; i < readyNum; i++) { +#if defined(OS_LINUX) + struct epoll_event *event = &readyEvents[i]; + int32_t timerFd = event->data.fd; +#elif defined(OS_MACOSX) + struct kevent *event = &readyEvents[i]; + int32_t timerFd = event->ident; +#endif + + uint64_t exp = 0; + if (read(timerFd, &exp, sizeof(exp)) < 0) { + if (errno != EBADF) { + LOG(ERROR) << "Failed to read timerfd=" << timerFd << " errno=" << errno; + } + continue; + } + if (TimerFdCtxValidate((uint32_t)timerFd) != UBRING_OK) { + continue; + } + + RETURN_CODE ret = ExecuteCallback(timerFd); + if (ret != UBRING_OK) { + LOG(ERROR) << "Failed execute callback ret=" << ret; + DeleteTimerInner((uint32_t)timerFd); + continue; + } + } + } + return NULL; +} + +void DeleteTimerSafe(uint32_t fd) { + if (g_timerFdCtxMap == NULL) { + return; + } + + if (pthread_spin_lock(&g_timerFdCtxMap[fd].spinLock) != 0) { + return; + } + + if (g_timerFdCtxMap[fd].status == TIMER_CONTEXT_NOT_USING) { + pthread_spin_unlock(&g_timerFdCtxMap[fd].spinLock); + return; + } + + g_timerFdCtxMap[fd].status = TIMER_CONTEXT_NOT_USING; + g_timerFdCtxMap[fd].cb = NULL; + g_timerFdCtxMap[fd].args = NULL; + g_timerFdCtxMap[fd].periodical = 0; + g_timerFdCtxMap[fd].fd = 0; + + pthread_spin_unlock(&g_timerFdCtxMap[fd].spinLock); + +#if defined(OS_LINUX) + epoll_ctl(g_epollFd, EPOLL_CTL_DEL, (int)fd, NULL); +#elif defined(OS_MACOSX) + struct kevent evt; + EV_SET(&evt, fd, EVFILT_TIMER, EV_DELETE, 0, 0, NULL); + kevent(g_epollFd, &evt, 1, NULL, 0, NULL); +#endif + + uint64_t exp = 0; + read((int)fd, &exp, sizeof(exp)); + + close((int)fd); + atomic_fetch_sub(&g_totalTimerNum, 1); +} + +void DeleteTimer(uint32_t fd) { + if (g_timerFdCtxMap == NULL) { + LOG(WARNING) << "The timer is not initialized."; + return; + } + + g_timerFdCtxMap[fd].periodical = 0; +} + +int32_t TimerStart(const itimerspec *time, void *(*cb)(void *), void *args) { + if (g_epollFd == -1) { + LOG(ERROR) << "Timer epoll/kqueue encountered internal error."; + return -1; + } + +#if defined(OS_LINUX) + int timerFd = timerfd_create(CLOCK_MONOTONIC, 0); +#elif defined(OS_MACOSX) + int timerFd = timerfd_create_macosx(CLOCK_MONOTONIC, 0); +#endif + + if (UNLIKELY(timerFd >= (int)maxSystemFd || timerFd == -1)) { + LOG(ERROR) << "Failed to create timerfd=" << timerFd << " errno=" << errno; + return -1; + } + + g_timerFdCtxMap[timerFd].status = TIMER_CONTEXT_EPOLL_WAITING; + g_timerFdCtxMap[timerFd].cb = cb; + g_timerFdCtxMap[timerFd].args = args; + g_timerFdCtxMap[timerFd].fd = (uint32_t)timerFd; + + if (LIKELY(time->it_interval.tv_sec > 0 || time->it_interval.tv_nsec > 0)) { + g_timerFdCtxMap[timerFd].periodical = 1; + } + +#if defined(OS_LINUX) + struct epoll_event event = { + .events = EPOLLIN, + .data = {.fd = timerFd} + }; + + int32_t ret = epoll_ctl(g_epollFd, EPOLL_CTL_ADD, timerFd, &event); +#elif defined(OS_MACOSX) + struct kevent event; + uint64_t timeout_nsec = time->it_value.tv_sec * 1000000000ULL + time->it_value.tv_nsec; + uint64_t interval_nsec = time->it_interval.tv_sec * 1000000000ULL + time->it_interval.tv_nsec; + EV_SET(&event, timerFd, EVFILT_TIMER, EV_ADD | EV_ENABLE, 0, + timeout_nsec / 1000000, NULL); + int32_t ret = kevent(g_epollFd, &event, 1, NULL, 0, NULL); +#endif + + if (UNLIKELY(ret != 0)) { + CloseTimerFd((uint32_t)timerFd); + LOG(ERROR) << "Failed to add event to epoll/kqueue. errno=" << errno; + return -1; + } + + atomic_fetch_add(&g_totalTimerNum, 1); + +#if defined(OS_LINUX) + ret = timerfd_settime(timerFd, 0, time, NULL); +#elif defined(OS_MACOSX) + ret = timerfd_settime_macosx(timerFd, 0, time, NULL); +#endif + + if (UNLIKELY(ret != 0)) { +#if defined(OS_LINUX) + if (epoll_ctl(g_epollFd, EPOLL_CTL_DEL, timerFd, NULL) != 0) { +#elif defined(OS_MACOSX) + struct kevent evt; + EV_SET(&evt, timerFd, EVFILT_TIMER, EV_DELETE, 0, 0, NULL); + if (kevent(g_epollFd, &evt, 1, NULL, 0, NULL) != 0) { +#endif + LOG(ERROR) << "Failed to delete the timer fd=" << timerFd << " with errno=" << errno; + } + CloseTimerFd((uint32_t)timerFd); + atomic_fetch_sub(&g_totalTimerNum, 1); + LOG(ERROR) << "Failed to set timer"; + return -1; + } + + return timerFd; +} + +uint32_t GetActiveTimerNum(void) { + return atomic_load(&g_totalTimerNum); +} + +void CloseTimerFd(uint32_t fd) { + g_timerFdCtxMap[fd].cb = NULL; + g_timerFdCtxMap[fd].args = NULL; + g_timerFdCtxMap[fd].status = TIMER_CONTEXT_NOT_USING; + g_timerFdCtxMap[fd].fd = 0; + g_timerFdCtxMap[fd].periodical = 0; + if (close((int)fd) != 0) { + LOG(ERROR) << "Failed to close timer fd=" << fd << " errno=" << errno; + return; + } +} + +void TimerModuleDestroy(void) { + uint32_t maxFd = maxSystemFd; + if (g_timerFdCtxMap) { + for (uint32_t fd = 0; fd < maxFd; fd++) { + if (g_timerFdCtxMap[fd].status != TIMER_CONTEXT_NOT_USING) { + DeleteTimerSafe(fd); + } + } + } + close(g_epollFd); + g_epollFd = -1; + g_totalTimerNum = 0; + g_timerModuleInitialized = 0; + int32_t ret = pthread_join(g_epollExecuteThread, NULL); + if (ret != EOK) { + LOG(ERROR) << "Failed to join pthread, during destroying timer module. ret=" << ret; + return; + } +} + +RETURN_CODE TimerFdCtxValidate(uint32_t fd) { + if (fd >= maxSystemFd) { + LOG(ERROR) << "TimerFd=" << fd << " is out of range=" << maxSystemFd; + return UBRING_ERR; + } + if (g_timerFdCtxMap[fd].status == TIMER_CONTEXT_NOT_USING) { + LOG(ERROR) << "TimerFd=" << fd << " has wrong status=" << g_timerFdCtxMap[fd].status; + return UBRING_ERR; + } + if (g_timerFdCtxMap[fd].cb == NULL) { + LOG(ERROR) << "The callback is not set."; + return UBRING_ERR; + } + + return UBRING_OK; +} + +#if defined(OS_MACOSX) +static int timerfd_create_macosx(int clockid, int flags) { + int pipefd[2]; + if (pipe(pipefd) == -1) { + return -1; + } + return pipefd[0]; +} + +static int timerfd_settime_macosx(int fd, int flags, + const itimerspec *new_value, + itimerspec *old_value) { + if (old_value != NULL) { + memset(old_value, 0, sizeof(itimerspec)); + } + return 0; +} +#endif + +} // namespace ubring +} // namespace brpc \ No newline at end of file diff --git a/src/brpc/ubshm/timer/timer_mgr.h b/src/brpc/ubshm/timer/timer_mgr.h new file mode 100644 index 0000000000..9630430a2c --- /dev/null +++ b/src/brpc/ubshm/timer/timer_mgr.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_TIMER_MGR_H +#define BRPC_TIMER_MGR_H +#include +#include +#include "brpc/ubshm/common/common.h" + +#if defined(OS_LINUX) +#include +#include +#elif defined(OS_MACOSX) +#include +#include +#include +#endif + +#define MAX_TIMER 1024 +#define TIMER_EPOLL_WAIT_TIMEOUT 1000 + +#if defined(OS_MACOSX) +struct itimerspec +{ + struct timespec it_interval; + struct timespec it_value; +}; +#endif +namespace brpc { +namespace ubring { +typedef enum { + TIMER_CONTEXT_NOT_USING, + TIMER_CONTEXT_EPOLL_WAITING, + TIMER_CONTEXT_CALLBACK_ONGOING +} TimerFdCtxStatus; + +typedef struct { + void *(*cb)(void*); + void *args; + uint32_t fd; + TimerFdCtxStatus status; + uint32_t periodical; + pthread_spinlock_t spinLock; +} TimerFdCtx; + +RETURN_CODE TimerInit(void); +void TimerModuleDestroy(void); +void *UnifiedCallback(void *args); +void *TimerEpoll(void *args); +int32_t TimerStart(const itimerspec *time, void *(*cb)(void *), void *args); +uint32_t GetActiveTimerNum(void); +void CloseTimerFd(uint32_t fd); + +void DeleteTimerSafe(uint32_t fd); +void DeleteTimer(uint32_t fd); +RETURN_CODE TimerFdCtxValidate(uint32_t fd); +} +} +#endif //BRPC_TIMER_MGR_H \ No newline at end of file diff --git a/src/brpc/ubshm/ub_endpoint.cpp b/src/brpc/ubshm/ub_endpoint.cpp new file mode 100644 index 0000000000..7b4868209a --- /dev/null +++ b/src/brpc/ubshm/ub_endpoint.cpp @@ -0,0 +1,943 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#if BRPC_WITH_UBRING + +#include + +#include +#include +#include "butil/fd_utility.h" +#include "butil/logging.h" // CHECK, LOG +#include "butil/sys_byteorder.h" // HostToNet,NetToHost +#include "bthread/bthread.h" +#include "brpc/errno.pb.h" +#include "brpc/event_dispatcher.h" +#include "brpc/input_messenger.h" +#include "brpc/socket.h" +#include "brpc/reloadable_flags.h" +#include "brpc/ubshm/ub_helper.h" +#include "brpc/ubshm/ub_endpoint.h" +#include "brpc/ubshm/shm/shm_def.h" +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm_transport.h" +#include "brpc/ubshm/ubr_trx.h" + +DECLARE_int32(task_group_ntags); + +namespace brpc { +DECLARE_bool(log_connection_close); +namespace ubring { + +extern bool g_skip_ub_init; +DEFINE_int32(data_queue_size, 4, "data queue size for UB"); +DEFINE_bool(ub_trace_verbose, false, "Print log message verbosely"); +BRPC_VALIDATE_GFLAG(ub_trace_verbose, brpc::PassValidate); +DEFINE_int32(ub_poller_num, 1, "Poller number in ub polling mode."); +DEFINE_bool(ub_poller_yield, false, "Yield thread in RDMA polling mode."); +DEFINE_bool(ub_edisp_unsched, false, "Disable event dispatcher schedule"); +DEFINE_bool(ub_disable_bthread, false, "Disable bthread in RDMA"); + +static const size_t MIN_ONCE_READ = 4096; +static const size_t MAX_ONCE_READ = 524288; +static const size_t IOBUF_IOV_MAX = 256; + +static const char* MAGIC_STR = "UB"; +static const size_t MAGIC_STR_LEN = 2; +static const size_t HELLO_MSG_LEN_MIN = 64; +static const size_t ACK_MSG_LEN = 4; +static uint16_t g_ub_hello_msg_len = 64; +static uint16_t g_ub_hello_version = 2; +static uint16_t g_ub_impl_version = 1; + +static const uint32_t ACK_MSG_UB_OK = 0x1; + +static butil::Mutex* g_ubring_resource_mutex = NULL; + +struct HelloMessage { + void Serialize(void* data) const; + void Deserialize(void* data); + std::string toString() const; + + uint16_t msg_len; + uint16_t hello_ver; + uint16_t impl_ver; + uint64_t len; + char shm_name[SHM_MAX_NAME_BUFF_LEN]; +}; + +void HelloMessage::Serialize(void* data) const { + char* current_pos = static_cast(data); + const uint16_t net_msg_len = butil::HostToNet16(msg_len); + memcpy(current_pos, &net_msg_len, sizeof(net_msg_len)); + current_pos += sizeof(net_msg_len); + const uint16_t net_hello_ver = butil::HostToNet16(hello_ver); + memcpy(current_pos, &net_hello_ver, sizeof(net_hello_ver)); + current_pos += sizeof(net_hello_ver); + const uint16_t net_impl_ver = butil::HostToNet16(impl_ver); + memcpy(current_pos, &net_impl_ver, sizeof(net_impl_ver)); + current_pos += sizeof(net_impl_ver); + const uint64_t net_len = butil::HostToNet64(len); + memcpy(current_pos, &net_len, sizeof(net_len)); + current_pos += sizeof(net_len); + memcpy(current_pos, shm_name, SHM_MAX_NAME_BUFF_LEN); +} + +void HelloMessage::Deserialize(void* data) { + char* current_pos = static_cast(data); + uint16_t net_msg_len; + memcpy(&net_msg_len, current_pos, sizeof(net_msg_len)); + msg_len = butil::NetToHost16(net_msg_len); + current_pos += sizeof(net_msg_len); + uint16_t net_hello_ver; + memcpy(&net_hello_ver, current_pos, sizeof(net_hello_ver)); + hello_ver = butil::NetToHost16(net_hello_ver); + current_pos += sizeof(net_hello_ver); + uint16_t net_impl_ver; + memcpy(&net_impl_ver, current_pos, sizeof(net_impl_ver)); + impl_ver = butil::NetToHost16(net_impl_ver); + current_pos += sizeof(net_impl_ver); + uint64_t net_len; + memcpy(&net_len, current_pos, sizeof(net_len)); + len = butil::NetToHost64(net_len); + current_pos += sizeof(net_len); + memcpy(shm_name, current_pos, SHM_MAX_NAME_BUFF_LEN); +} + +std::string HelloMessage::toString() const { + constexpr size_t MAX_LEN = 16 + 6 + 16 + 6 + 16 + 6 + 20 + 6 + SHM_MAX_NAME_BUFF_LEN + 32; + std::array buf; + int n = snprintf(buf.data(), buf.size(), + "msg_len=%u, hello_ver=%u, impl_ver=%u, len=%lu, shm_name=%.*s", + msg_len, + hello_ver, + impl_ver, + static_cast(len), // 兼容32/64位 + static_cast(SHM_MAX_NAME_BUFF_LEN), // 限制最大输出长度 + shm_name + ); + return std::string(buf.data(), static_cast(n)); +} + +UBShmEndpoint::UBShmEndpoint(Socket* s) + : _socket(s) + , _state(UNINIT) + , _ub_ring(nullptr) + , _cq_sid(INVALID_SOCKET_ID) +{ + _read_butex = bthread::butex_create_checked>(); +} + +UBShmEndpoint::~UBShmEndpoint() { + Reset(); + bthread::butex_destroy(_read_butex); +} + +void UBShmEndpoint::Reset() { + DeallocateResources(); + + delete _ub_ring; + _ub_ring = nullptr; + _cq_sid = INVALID_SOCKET_ID; + _state = UNINIT; +} + +void UBConnect::StartConnect(const Socket* socket, + void (*done)(int err, void* data), + void* data) { + auto* ub_transport = static_cast(socket->_transport.get()); + CHECK(ub_transport->_ub_ep != NULL); + SocketUniquePtr s; + if (Socket::Address(socket->id(), &s) != 0) { + return; + } + if (!IsUBAvailable()) { + ub_transport->_ub_ep->_state = UBShmEndpoint::FALLBACK_TCP; + ub_transport->_ub_state = UBShmTransport::UB_OFF; + done(0, data); + return; + } + _done = done; + _data = data; + bthread_t tid; + bthread_attr_t attr = BTHREAD_ATTR_NORMAL; + bthread_attr_set_name(&attr, "UBProcessHandshakeAtClient"); + if (bthread_start_background(&tid, &attr, + UBShmEndpoint::ProcessHandshakeAtClient, ub_transport->_ub_ep) < 0) { + LOG(FATAL) << "Fail to start handshake bthread"; + Run(); + } else { + s.release(); + } +} + +void UBConnect::StopConnect(Socket* socket) { } + +void UBConnect::Run() { + _done(errno, _data); +} + +static void TryReadOnTcpDuringRdmaEst(Socket* s) { + int progress = Socket::PROGRESS_INIT; + while (true) { + uint8_t tmp; + ssize_t nr = read(s->fd(), &tmp, 1); + if (nr < 0) { + if (errno != EAGAIN) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read from " << s; + s->SetFailed(saved_errno, "Fail to read from %s: %s", + s->description().c_str(), berror(saved_errno)); + return; + } + if (!s->MoreReadEvents(&progress)) { + break; + } + } else if (nr == 0) { + s->SetEOF(); + return; + } else { + LOG(WARNING) << "Read unexpected data from " << s; + s->SetFailed(EPROTO, "Read unexpected data from %s", + s->description().c_str()); + return; + } + } +} + +void UBShmEndpoint::OnNewDataFromTcp(Socket* m) { + auto* ub_transport = static_cast(m->_transport.get()); + UBShmEndpoint* ep = ub_transport->GetUBShmEp(); + CHECK(ep != NULL); + + int progress = Socket::PROGRESS_INIT; + while (true) { + if (ep->_state == UNINIT) { + if (!m->CreatedByConnect()) { + if (!IsUBAvailable()) { + ep->_state = FALLBACK_TCP; + ub_transport->_ub_state = UBShmTransport::UB_OFF; + continue; + } + bthread_t tid; + ep->_state = S_HELLO_WAIT; + SocketUniquePtr s; + m->ReAddress(&s); + bthread_attr_t attr = BTHREAD_ATTR_NORMAL; + bthread_attr_set_name(&attr, "UBProcessHandshakeAtServer"); + if (bthread_start_background(&tid, &attr, + ProcessHandshakeAtServer, ep) < 0) { + ep->_state = UNINIT; + LOG(FATAL) << "Fail to start handshake bthread"; + } else { + s.release(); + } + } else { + // The connection may be closed or reset before the client + // starts handshake. This will be handled by client handshake. + // Ignore the exception here. + } + } else if (ep->_state < ESTABLISHED) { // during handshake + ep->_read_butex->fetch_add(1, butil::memory_order_release); + bthread::butex_wake(ep->_read_butex); + } else if (ep->_state == FALLBACK_TCP){ // handshake finishes + InputMessenger::OnNewMessages(m); + return; + } else if (ep->_state == ESTABLISHED) { + TryReadOnTcpDuringRdmaEst(ep->_socket); + return; + } + if (!m->MoreReadEvents(&progress)) { + break; + } + } +} +bool HelloNegotiationValid(HelloMessage& msg) { + if (msg.hello_ver == g_ub_hello_version && + msg.impl_ver == g_ub_impl_version) { + // This can be modified for future compatibility + return true; + } + return false; +} + +static const int WAIT_TIMEOUT_MS = 50; + +int UBShmEndpoint::ReadFromFd(void* data, size_t len) { + CHECK(data != NULL); + int nr = 0; + size_t received = 0; + do { + const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); + nr = read(_socket->fd(), (uint8_t*)data + received, len - received); + if (nr < 0) { + if (errno == EAGAIN) { + const int expected_val = _read_butex->load(butil::memory_order_acquire); + if (bthread::butex_wait(_read_butex, expected_val, &duetime) < 0) { + if (errno != EWOULDBLOCK && errno != ETIMEDOUT) { + return -1; + } + } + } else { + return -1; + } + } else if (nr == 0) { + errno = EEOF; + return -1; + } else { + received += nr; + } + } while (received < len); + return 0; +} + +int UBShmEndpoint::WriteToFd(void* data, size_t len) { + CHECK(data != NULL); + int nw = 0; + size_t written = 0; + do { + const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); + nw = write(_socket->fd(), (uint8_t*)data + written, len - written); + if (nw < 0) { + if (errno == EAGAIN) { + if (_socket->WaitEpollOut(_socket->fd(), true, &duetime) < 0) { + if (errno != ETIMEDOUT) { + return -1; + } + } + } else { + return -1; + } + } else { + written += nw; + } + } while (written < len); + return 0; +} + +inline void UBShmEndpoint::TryReadOnTcp() { + if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) { + if (_state == FALLBACK_TCP) { + InputMessenger::OnNewMessages(_socket); + } else if (_state == ESTABLISHED) { + TryReadOnTcpDuringRdmaEst(_socket); + } + } +} + +void* UBShmEndpoint::ProcessHandshakeAtClient(void* arg) { + UBShmEndpoint* ep = static_cast(arg); + SocketUniquePtr s(ep->_socket); + UBConnect::RunGuard rg((UBConnect*)s->_app_connect.get()); + + LOG_IF(INFO, FLAGS_ub_trace_verbose) + << "Start handshake on " << s->_local_side; + + uint8_t data[g_ub_hello_msg_len]; + + ep->_state = C_ALLOC_SHM; + auto* ub_transport = static_cast(s->_transport.get()); + size_t local_shm_len = (size_t)(FLAGS_data_queue_size) * MB_TO_BYTE; + SHM local_trx_shm = {NULL, local_shm_len, 0, {0}, (uint32_t)s->fd()}; + auto shm_name_str = butil::endpoint2str(s->local_side()); + const char* shm_name = shm_name_str.c_str(); + if (ep->AllocateClientResources(&local_trx_shm, shm_name) < 0) { + LOG(WARNING) << "Fallback to tcp:" << s->description(); + ub_transport->_ub_state = UBShmTransport::UB_OFF; + ep->_state = FALLBACK_TCP; + return NULL; + } + + ep->_state = C_HELLO_SEND; + HelloMessage local_msg; + local_msg.msg_len = g_ub_hello_msg_len; + local_msg.hello_ver = g_ub_hello_version; + local_msg.impl_ver = g_ub_impl_version; + local_msg.len = local_shm_len; + memcpy(local_msg.shm_name, local_trx_shm.name, SHM_MAX_NAME_BUFF_LEN); + memcpy(data, MAGIC_STR, MAGIC_STR_LEN); + local_msg.Serialize((char*)data + MAGIC_STR_LEN); + if (ep->WriteToFd(data, g_ub_hello_msg_len) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + LOG_IF(INFO, FLAGS_ub_trace_verbose) << "client handshake message : " << local_msg.toString(); + + ep->_state = C_HELLO_WAIT; + if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to get hello message from server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { + LOG(WARNING) << "Read unexpected data during handshake:" << s->description(); + s->SetFailed(EPROTO, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } + + if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to get Hello Message from server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + HelloMessage remote_msg; + remote_msg.Deserialize(data); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { + LOG(WARNING) << "Fail to parse Hello Message length from server:" + << s->description(); + s->SetFailed(EPROTO, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } + + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // TODO: Read Hello Message customized data + // Just for future use, should not happen now + } + + if (!HelloNegotiationValid(remote_msg)) { + LOG(WARNING) << "Fail to negotiate with server, fallback to tcp:" + << s->description(); + ub_transport->_ub_state = UBShmTransport::UB_OFF; + } else { + ep->_state = C_MAP_REMOTE_SHM; + if (ep->_ub_ring->UbrMapRemoteShm(&local_trx_shm, shm_name) < 0) { + LOG(WARNING) << "Fail to map the remote shm, fallback to tcp:" << s->description(); + ub_transport->_ub_state = UBShmTransport::UB_OFF; + } else { + ub_transport->_ub_state = UBShmTransport::UB_ON; + } + } + + ep->_state = C_ACK_SEND; + uint32_t flags = 0; + if (ub_transport->_ub_state != UBShmTransport::UB_OFF) { + flags |= ACK_MSG_UB_OK; + } + uint32_t* tmp = (uint32_t*)data; + *tmp = butil::HostToNet32(flags); + if (ep->WriteToFd(data, ACK_MSG_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send Ack Message to server:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + if (ub_transport->_ub_state == UBShmTransport::UB_ON) { + ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_ub_trace_verbose) + << "Client handshake ends (use ubring) on " << s->description(); + } else { + ep->_state = FALLBACK_TCP; + LOG_IF(INFO, FLAGS_ub_trace_verbose) + << "Client handshake ends (use tcp) on " << s->description(); + } + + errno = 0; + + return NULL; +} + +void* UBShmEndpoint::ProcessHandshakeAtServer(void* arg) { + UBShmEndpoint* ep = static_cast(arg); + SocketUniquePtr s(ep->_socket); + + LOG_IF(INFO, FLAGS_ub_trace_verbose) + << "Start handshake on " << s->description(); + + uint8_t data[g_ub_hello_msg_len]; + + ep->_state = S_HELLO_WAIT; + if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description() << " " << s->_remote_side; + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + auto* ub_transport = static_cast(s->_transport.get()); + if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { + LOG_IF(INFO, FLAGS_ub_trace_verbose) << "It seems that the " + << "client does not use RDMA, fallback to TCP:" + << s->description(); + s->_read_buf.append(data, MAGIC_STR_LEN); + ep->_state = FALLBACK_TCP; + ub_transport->_ub_state = UBShmTransport::UB_OFF; + ep->TryReadOnTcp(); + return NULL; + } + + if (ep->ReadFromFd(data, g_ub_hello_msg_len - MAGIC_STR_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + HelloMessage remote_msg; + remote_msg.Deserialize(data); + LOG_IF(INFO, FLAGS_ub_trace_verbose) << "server receive handshake message : " << remote_msg.toString(); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { + LOG(WARNING) << "Fail to parse Hello Message length from client:" + << s->description(); + s->SetFailed(EPROTO, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // TODO: Read Hello Message customized header + // Just for future use, should not happen now + } + + if (!HelloNegotiationValid(remote_msg)) { + LOG(WARNING) << "Fail to negotiate with client, fallback to tcp:" + << s->description(); + ub_transport->_ub_state = UBShmTransport::UB_OFF; + } else { + ep->_state = S_ALLOC_SHM; + ubring::SHM remote_trx_shm = {NULL, remote_msg.len, 0, {0}, (uint32_t)ep->_socket->fd()}; + strncpy(remote_trx_shm.name, remote_msg.shm_name, SHM_MAX_NAME_BUFF_LEN); + + size_t local_shm_len = (size_t)(FLAGS_data_queue_size) * MB_TO_BYTE; + // server端共享内存名称 + ubring::SHM local_trx_shm = {NULL, local_shm_len, 0, {0}, (uint32_t)ep->_socket->fd()}; + char clientName[SHM_MAX_NAME_BUFF_LEN]; + strncpy(clientName, remote_msg.shm_name, SHM_MAX_NAME_BUFF_LEN); + + char *clientIpPort = strrchr(clientName, '_'); + if (clientIpPort != NULL) { + *clientIpPort = '\0'; + } + int result = snprintf(local_trx_shm.name, SHM_MAX_NAME_BUFF_LEN, "%s_%s", + clientName, SERVER_SHM_NAME_SUFFIX); + if (UNLIKELY(result < 0)) { + LOG(WARNING) << "Copy client shared memory name failed, ret=" << result; + ub_transport->_ub_state = UBShmTransport::UB_OFF; + } + if (result >= 0 && ep->AllocateServerResources(&remote_trx_shm, &local_trx_shm) < 0) { + LOG(WARNING) << "Fail to allocate ub resources, fallback to tcp:" + << s->description(); + ub_transport->_ub_state = UBShmTransport::UB_OFF; + } + } + + ep->_state = S_HELLO_SEND; + HelloMessage local_msg; + local_msg.msg_len = g_ub_hello_msg_len; + if (ub_transport->_ub_state == UBShmTransport::UB_OFF) { + local_msg.impl_ver = 0; + local_msg.hello_ver = 0; + } else { + local_msg.hello_ver = g_ub_hello_version; + local_msg.impl_ver = g_ub_impl_version; + local_msg.len = (FLAGS_data_queue_size) * MB_TO_BYTE; + memcpy(local_msg.shm_name, remote_msg.shm_name, SHM_MAX_NAME_BUFF_LEN); + } + memcpy(data, MAGIC_STR, MAGIC_STR_LEN); + local_msg.Serialize((char*)data + MAGIC_STR_LEN); + if (ep->WriteToFd(data, g_ub_hello_msg_len) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ub handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + ep->_state = S_ACK_WAIT; + if (ep->ReadFromFd(data, ACK_MSG_LEN) < 0) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read ack message from client:" << s->description(); + s->SetFailed(saved_errno, "Fail to complete ubring handshake from %s: %s", + s->description().c_str(), berror(saved_errno)); + ep->_state = FAILED; + return NULL; + } + + uint32_t* tmp = (uint32_t*)data; + uint32_t flags = butil::NetToHost32(*tmp); + if (flags & ACK_MSG_UB_OK) { + if (ub_transport->_ub_state == UBShmTransport::UB_OFF) { + LOG(WARNING) << "Fail to parse Hello Message length from client:" + << s->description(); + s->SetFailed(EPROTO, "Fail to complete ub handshake from %s: %s", + s->description().c_str(), berror(EPROTO)); + ep->_state = FAILED; + return NULL; + } else { + ub_transport->_ub_state = UBShmTransport::UB_ON; + ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_ub_trace_verbose) + << "Server handshake ends (use ubring) on " << s->description(); + } + } else { + ub_transport->_ub_state = UBShmTransport::UB_OFF; + ep->_state = FALLBACK_TCP; + LOG_IF(INFO, FLAGS_ub_trace_verbose) + << "Server handshake ends (use tcp) on " << s->description(); + } + ep->TryReadOnTcp(); + + return NULL; +} + +bool UBShmEndpoint::IsWritable() const { + if (BAIDU_UNLIKELY(g_skip_ub_init)) { + // Just for UT + return false; + } + auto ret = _ub_ring->IsUbrTrxWriteable(EPOLLET); + if (ret == 0) { + return true; + } + return false; +} + +ssize_t UBShmEndpoint::CutFromIOBufList(butil::IOBuf** from, size_t ndata) { + if (BAIDU_UNLIKELY(g_skip_ub_init)) { + // Just for UT + errno = EAGAIN; + return -1; + } + if (BAIDU_UNLIKELY(ndata == 0)) { + return 0; + } + struct iovec vec[IOBUF_IOV_MAX]; + size_t nvec = 0; + for (size_t i = 0; i < ndata; ++i) { + const butil::IOBuf* p = from[i]; + const size_t nref = p->backing_block_num(); + for (size_t j = 0; j < nref && nvec < IOBUF_IOV_MAX; ++j, ++nvec) { + butil::StringPiece sp = p->backing_block(j); + vec[nvec].iov_base = const_cast(sp.data()); + vec[nvec].iov_len = sp.size(); + } + } + + ssize_t nw = 0; + errno = 0; + nw = _ub_ring->UbrTrxWritev(vec, nvec); + if (UNLIKELY(nw == -1)) { + if (errno == EMSGSIZE) { + LOG(ERROR) << "Non-blocking send msg failed, message is larger than ubring capacity."; + } else { + LOG(ERROR) << "Non-blocking send msg in failed, connection has been closed."; + errno = EPIPE; + } + } else if (UNLIKELY(nw == UBRING_RETRY)) { + errno = EAGAIN; + nw = -1; + } + if (nw <= 0) { + return nw; + } + size_t npop_all = nw; + for (size_t i = 0; i < ndata; ++i) { + npop_all -= from[i]->pop_front(npop_all); + if (npop_all == 0) { + break; + } + } + return nw; +} + +int UBShmEndpoint::AllocateClientResources(ubring::SHM* local_trx_shm, const char* shm_name) { + if (BAIDU_UNLIKELY(g_skip_ub_init)) { + // For UT + return 0; + } + + CHECK(_ub_ring == NULL); + // TODO: Pooling management + _ub_ring = new UBRing(); + + SocketOptions options; + options.user = this; + options.keytable_pool = _socket->_keytable_pool; + if (Socket::Create(options, &_cq_sid) < 0) { + PLOG(WARNING) << "Fail to create socket for cq"; + return -1; + } + int ret = _ub_ring->UbrAllocateLocalShm(local_trx_shm, shm_name); + if (ret != 0) { + return ret; + } + PollerRegisterEvent(CqSidOp::ADD, EPOLLIN); + return 0; +} + +int UBShmEndpoint::AllocateServerResources(ubring::SHM* remote_trx_shm, ubring::SHM* local_trx_shm) { + if (BAIDU_UNLIKELY(g_skip_ub_init)) { + // For UT + return 0; + } + + CHECK(_ub_ring == NULL); + // TODO: Pooling management + _ub_ring = new UBRing(); + + SocketOptions options; + options.user = this; + options.keytable_pool = _socket->_keytable_pool; + if (Socket::Create(options, &_cq_sid) < 0) { + PLOG(WARNING) << "Fail to create socket for cq"; + return -1; + } + int ret = _ub_ring->UbrAllocateServerShm(remote_trx_shm, local_trx_shm); + if (ret != 0) { + return ret; + } + // TODO mwj 是否应该在连接之后再进行轮询? + PollerRegisterEvent(CqSidOp::ADD, EPOLLIN); + return ret; +} + +void UBShmEndpoint::DeallocateResources() { + if (!_ub_ring) { + return; + } + PollerRegisterEvent(CqSidOp::REMOVE); + _ub_ring->UbrTrxClose(); + if (INVALID_SOCKET_ID != _cq_sid) { + SocketUniquePtr s; + if (Socket::Address(_cq_sid, &s) == 0) { + s->_user = NULL; + s->_fd = -1; + s->SetFailed(); + } + } +} + +void UBShmEndpoint::PollIn(UBShmEndpoint* ep, uint32_t epEvent) { + SocketUniquePtr s; + if (Socket::Address(ep->_socket->id(), &s) < 0) { + return; + } + auto* ub_transport = static_cast(s->_transport.get()); + CHECK(ep == ub_transport->_ub_ep); + + InputMessageClosure last_msg; + while (true) { + int ret = ep->_ub_ring->IsUbrTrxReadable(epEvent); + if (ret < 0) { + return; + } + + bool read_eof = false; + while (!read_eof) { + const int64_t received_us = butil::cpuwide_time_us(); + const int64_t base_realtime = butil::gettimeofday_us() - received_us; + + size_t once_read = s->_avg_msg_size * 16; + if (once_read < MIN_ONCE_READ) { + once_read = MIN_ONCE_READ; + } else if (once_read > MAX_ONCE_READ) { + once_read = MAX_ONCE_READ; + } + + const ssize_t nr = s->_read_buf.append_from_reader(ep->_ub_ring, once_read); + if (nr <= 0) { + if (0 == nr) { + // Set `read_eof' flag and proceed to feed EOF into `Protocol' + // (implied by m->_read_buf.empty), which may produce a new + // `InputMessageBase' under some protocols such as HTTP + LOG_IF(WARNING, FLAGS_log_connection_close) << *s << " was closed by remote side"; + read_eof = true; + } else if (errno != EAGAIN) { + if (errno == EINTR) { + continue; + } + const int saved_errno = errno; + PLOG(WARNING) << "Fail to read from " << *s; + s->SetFailed(saved_errno, "Fail to read from %s: %s", + s->description().c_str(), berror(saved_errno)); + return; + } else { + return; + } + } + + InputMessenger* messenger = static_cast(s->user()); + if (messenger->ProcessNewMessage(s.get(), nr, read_eof, received_us, + base_realtime, last_msg) < 0) { + return; + } + } + + if (read_eof) { + s->SetEOF(); + } + } +} + +void UBShmEndpoint::PollOut(UBShmEndpoint* ep, uint32_t epEvent) { + SocketUniquePtr s; + if (Socket::Address(ep->_socket->id(), &s) < 0) { + return; + } + auto* ub_transport = static_cast(s->_transport.get()); + CHECK(ep == ub_transport->_ub_ep); + if (ep->IsWritable()) { + ep->_socket->WakeAsEpollOut(); + } + +} + +int UBShmEndpoint::GlobalInitialize() { + g_ubring_resource_mutex = new butil::Mutex; + _poller_groups = std::vector(FLAGS_task_group_ntags); + return 0; +} + +void UBShmEndpoint::GlobalRelease() { + for (int i = 0; i < FLAGS_task_group_ntags; ++i) { + PollingModeRelease(i); + } +} + +std::vector UBShmEndpoint::_poller_groups; + +int UBShmEndpoint::PollingModeInitialize(bthread_tag_t tag, + std::function callback, + std::function init_fn, + std::function release_fn) { + auto& group = _poller_groups[tag]; + auto& pollers = group.pollers; + auto& running = group.running; + bool expected = false; + if (!running.compare_exchange_strong(expected, true)) { + return 0; + } + struct FnArgs { + Poller* poller; + std::atomic* running; + }; + auto fn = [](void* p) -> void* { + std::unique_ptr args(static_cast(p)); + auto poller = args->poller; + auto running = args->running; + std::unordered_set cq_sids; + CqSidOp op; + + if (poller->init_fn) { + poller->init_fn(); + } + while (running->load(std::memory_order_relaxed)) { + while (poller->op_queue.Dequeue(op)) { + if (op.type == CqSidOp::ADD) { + cq_sids.emplace(op); + } else if (op.type == CqSidOp::REMOVE) { + cq_sids.erase(op); + + } else if (op.type == CqSidOp::MOD) { + cq_sids.erase(op); + cq_sids.emplace(op); + } + } + for (auto cq : cq_sids) { + SocketUniquePtr s; + if (Socket::Address(cq.sid, &s) < 0) { + continue; + } + UBShmEndpoint* ep = static_cast(s->user()); + if (!ep) { + continue; + } + + if (cq.event & EPOLLIN) { + PollIn(ep, cq.event); + } + + if (cq.event & EPOLLOUT) { + PollOut(ep, cq.event); + } + } + if (poller->callback) { + poller->callback(); + } + if (FLAGS_ub_poller_yield) { + bthread_yield(); + } + } + + if (poller->release_fn) { + poller->release_fn(); + } + + return nullptr; + }; + for (int i = 0; i < FLAGS_ub_poller_num; ++i) { + auto args = new FnArgs{&pollers[i], &running}; + auto attr = FLAGS_ub_disable_bthread ? BTHREAD_ATTR_PTHREAD + : BTHREAD_ATTR_NORMAL; + attr.tag = tag; + bthread_attr_set_name(&attr, "UBPolling"); + pollers[i].callback = callback; + pollers[i].init_fn = init_fn; + pollers[i].release_fn = release_fn; + auto rc = bthread_start_background(&pollers[i].tid, &attr, fn, args); + if (rc != 0) { + LOG(ERROR) << "Fail to start ubring polling bthread"; + return -1; + } + } + return 0; +} + +void UBShmEndpoint::PollingModeRelease(bthread_tag_t tag) { + auto& group = _poller_groups[tag]; + auto& pollers = group.pollers; + auto& running = group.running; + running.store(false, std::memory_order_relaxed); + for (int i = 0; i < FLAGS_ub_poller_num; ++i) { + bthread_join(pollers[i].tid, NULL); + } +} + +void UBShmEndpoint::PollerRegisterEvent(CqSidOp::OpType op, uint32_t events) { + auto index = butil::fmix32(_cq_sid) % FLAGS_ub_poller_num; + auto& group = _poller_groups[bthread_self_tag()]; + auto& pollers = group.pollers; + auto& poller = pollers[index]; + if (INVALID_SOCKET_ID != _cq_sid) { + poller.op_queue.Enqueue(CqSidOp{_cq_sid, events, op}); + } +} + +} // namespace ubring +} // namespace brpc + +#endif // if BRPC_WITH_UBRING diff --git a/src/brpc/ubshm/ub_endpoint.h b/src/brpc/ubshm/ub_endpoint.h new file mode 100644 index 0000000000..d199f5881a --- /dev/null +++ b/src/brpc/ubshm/ub_endpoint.h @@ -0,0 +1,234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UB_ENDPOINT_H +#define BRPC_UB_ENDPOINT_H + +#if BRPC_WITH_UBRING + +#include +#include +#include +#include +#include +#include "butil/atomicops.h" +#include "butil/iobuf.h" +#include "butil/macros.h" +#include "butil/containers/mpsc_queue.h" +#include "brpc/socket.h" +#include "brpc/ubshm/ub_helper.h" +#include "brpc/ubshm/ub_ring.h" +#include "brpc/ubshm/shm/shm_def.h" + + +namespace brpc { +class Socket; +namespace ubring { + +DECLARE_int32(ub_poller_num); +DECLARE_bool(ub_edisp_unsched); +DECLARE_bool(ub_disable_bthread); + +class UBConnect : public AppConnect { +public: + void StartConnect(const Socket* socket, + void (*done)(int err, void* data), void* data) override; + void StopConnect(Socket*) override; + struct RunGuard { + RunGuard(UBConnect* rc) { this_rc = rc; } + ~RunGuard() { if (this_rc) this_rc->Run(); } + UBConnect* this_rc; + }; + +private: + void Run(); + void (*_done)(int, void*){NULL}; + void* _data{NULL}; +}; + +class BAIDU_CACHELINE_ALIGNMENT UBShmEndpoint : public SocketUser { +friend class UBConnect; +friend class Socket; +public: + explicit UBShmEndpoint(Socket* s); + ~UBShmEndpoint() override; + + // Global initialization + // Return 0 if success, -1 if failed and errno set + static int GlobalInitialize(); + + static void GlobalRelease(); + + // Reset the endpoint (for next use) + void Reset(); + + // Cut data from the given IOBuf list and use UBRING to send + // Return bytes cut if success, -1 if failed and errno set + ssize_t CutFromIOBufList(butil::IOBuf** data, size_t ndata); + + // Whether the endpoint can send more data + bool IsWritable() const; + + void PollerRegisterEpollOut(bool pollin) { + uint32_t events = EPOLLOUT | EPOLLET; + if (pollin) { + PollerRegisterEvent(CqSidOp::MOD, events | EPOLLIN); + return; + } + PollerRegisterEvent(CqSidOp::ADD, events); + } + + void PollerUnRegisterEpollOut(bool pollin) { + uint32_t events = EPOLLIN | EPOLLET; + if (pollin) { + PollerRegisterEvent(CqSidOp::MOD, events); + return; + } + PollerRegisterEvent(CqSidOp::REMOVE); + } + + // Callback when there is new epollin event on TCP fd + static void OnNewDataFromTcp(Socket* m); + + // Initialize polling mode + static int PollingModeInitialize(bthread_tag_t tag, + std::function callback, + std::function init_fn, + std::function release_fn); + + static void PollingModeRelease(bthread_tag_t tag); + +private: + enum State { + UNINIT = 0x0, + C_ALLOC_SHM = 0x1, + C_HELLO_SEND = 0x2, + C_HELLO_WAIT = 0x3, + C_MAP_REMOTE_SHM = 0x4, + C_ACK_SEND = 0x5, + S_HELLO_WAIT = 0x11, + S_ALLOC_SHM = 0x12, + S_HELLO_SEND = 0x13, + S_ACK_WAIT = 0x14, + ESTABLISHED = 0x100, + FALLBACK_TCP = 0x200, + FAILED = 0x300 + }; + + // Process handshake at the client + static void* ProcessHandshakeAtClient(void* arg); + + // Process handshake at the server + static void* ProcessHandshakeAtServer(void* arg); + + // Allocate resources + // Return 0 if success, -1 if failed and errno set + int AllocateClientResources(SHM* local_trx_shm, const char* shm_name); + + int AllocateServerResources(SHM* remote_trx_shm, SHM* local_trx_shm); + + // Release resources + void DeallocateResources(); + + // Read at most len bytes from fd in _socket to data + // wait for _read_butex if encounter EAGAIN + // return -1 if encounter other errno (including EOF) + int ReadFromFd(void* data, size_t len); + + + // Write at most len bytes from data to fd in _socket + // wait for _epollout_butex if encounter EAGAIN + // return -1 if encounter other errno + int WriteToFd(void* data, size_t len); + + // Poll CQ and get the work completion + static void PollIn(UBShmEndpoint* ep, uint32_t epEvent); + + static void PollOut(UBShmEndpoint* ep, uint32_t epEvent); + + // Try to read data on TCP fd in _socket + inline void TryReadOnTcp(); + + // Not owner + Socket* _socket; + + State _state; + + // ub resource + ubring::UBRing* _ub_ring{nullptr}; + + SocketId _cq_sid; + + // butex for inform read events on TCP fd during handshake + butil::atomic *_read_butex; + + DISALLOW_COPY_AND_ASSIGN(UBShmEndpoint); + + struct CqSidOp { + enum OpType { + ADD, + REMOVE, + MOD + }; + SocketId sid; + uint32_t event; + OpType type; + }; + + struct CqSidOpHash { + std::size_t operator()(const CqSidOp& op) const { + return op.sid; + } + }; + + struct CqSidOpEqual { + bool operator()(const CqSidOp& lhs, const CqSidOp& rhs) const { + return lhs.sid == rhs.sid; + } + }; + + // Poller instance + struct BAIDU_CACHELINE_ALIGNMENT Poller { + bthread_t tid{INVALID_BTHREAD}; + butil::MPSCQueue> op_queue; + // Callback used for io_uring/spdk etc + std::function callback; + // Init and Destroy function + std::function init_fn; + std::function release_fn; + }; + // Poller group + struct BAIDU_CACHELINE_ALIGNMENT PollerGroup { + PollerGroup() : pollers(FLAGS_ub_poller_num), running(false) {} + std::vector pollers; + std::atomic running; + }; + static std::vector _poller_groups; + + void PollerRegisterEvent(CqSidOp::OpType op, uint32_t events = EPOLLET); +}; + +} // namespace ubring +} // namespace brpc + +#else // if BRPC_WITH_UBRING + +class UBShmEndpoint { }; + +#endif + +#endif //BRPC_UB_ENDPOINT_H diff --git a/src/brpc/ubshm/ub_helper.cpp b/src/brpc/ubshm/ub_helper.cpp new file mode 100644 index 0000000000..6c4c7a5fde --- /dev/null +++ b/src/brpc/ubshm/ub_helper.cpp @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#if BRPC_WITH_UBRING + +#include // dlopen +#include +#include +#include +#include +#include "butil/logging.h" +#include "brpc/socket.h" +#include "brpc/ubshm/ub_endpoint.h" +#include "brpc/ubshm/ub_helper.h" +#include "brpc/ubshm/ub_ring_manager.h" + +namespace brpc { +namespace ubring { + +void* g_handle_ub = NULL; +bool g_skip_ub_init = false; + +butil::atomic g_ub_available(false); + +void GlobalRelease() { + g_ub_available.store(false, butil::memory_order_release); + UBShmEndpoint::GlobalRelease(); + UBRingManager::UbrMgrFini(); + ShmMgrFini(); +} + +static inline void ExitWithError() { + GlobalRelease(); + exit(1); +} + +static void GlobalUBInitializeOrDieImpl() { + if (BAIDU_UNLIKELY(g_skip_ub_init)) { + // Just for UT + return; + } + + if (UBRingManager::UbrMgrInit()) { + PLOG(ERROR) << "Fail to UbrMgrInit"; + ExitWithError(); + } + + if (TimerInit()) { + PLOG(ERROR) << "Fail to TimerInit"; + ExitWithError(); + } + + if (ShmMgrInit()) { + PLOG(ERROR) << "Fail to ShmMgrInit"; + ExitWithError(); + } + + if (UBShmEndpoint::GlobalInitialize() < 0) { + LOG(ERROR) << "ubring_recv_block_type incorrect " + << "(valid value: default/large/huge)"; + ExitWithError(); + } + + g_ub_available.store(true, butil::memory_order_relaxed); +} + +static pthread_once_t initialize_UB_once = PTHREAD_ONCE_INIT; + +void GlobalUBInitializeOrDie() { + if (pthread_once(&initialize_UB_once, + GlobalUBInitializeOrDieImpl) != 0) { + LOG(FATAL) << "Fail to pthread_once GlobalUBInitializeOrDie"; + exit(1); + } +} + +bool IsUBAvailable() { + return g_ub_available.load(butil::memory_order_acquire); +} + +void GlobalDisableUb() { + if (g_ub_available.exchange(false, butil::memory_order_acquire)) { + LOG(FATAL) << "ub is disabled due to some unrecoverable problem"; + } +} + +bool SupportedByUB(std::string protocol) { + if (protocol.compare("baidu_std") == 0) { + return true; + } + return false; +} + +bool InitPollingModeWithTag(bthread_tag_t tag, + std::function callback, + std::function init_fn, + std::function release_fn) { + if (UBShmEndpoint::PollingModeInitialize(tag, callback, init_fn, + release_fn) == 0) { + return true; + } + return false; +} + +} // namespace ubring +} // namespace brpc + +#else + +#include +#include "butil/logging.h" + +namespace brpc { +namespace ubring { +void GlobalUBInitializeOrDie() { + LOG(ERROR) << "brpc is not compiled with ubring. To enable it, please refer to " + << "https://github.com/apache/brpc/blob/master/docs/en/ubring.md"; + exit(1); +} +} +} + +#endif // if BRPC_WITH_UBRING \ No newline at end of file diff --git a/src/brpc/ubshm/ub_helper.h b/src/brpc/ubshm/ub_helper.h new file mode 100644 index 0000000000..6ad9ebe3eb --- /dev/null +++ b/src/brpc/ubshm/ub_helper.h @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UB_HELPER_H +#define BRPC_UB_HELPER_H + +#if BRPC_WITH_UBRING + +#include +#include +#include "bthread/types.h" + +namespace brpc { +namespace ubring { + +void GlobalRelease(); + +void GlobalUBInitializeOrDie(); + +bool InitPollingModeWithTag(bthread_tag_t tag, + std::function callback = nullptr, + std::function init_fn = nullptr, + std::function release_fn = nullptr); + +bool IsUBAvailable(); + +void GlobalDisableUb(); + +bool SupportedByUB(std::string protocol); + +} // namespace ubring +} // namespace brpc + +#else + +namespace brpc { +namespace ubring { + +void GlobalRelease(); + +void GlobalUBInitializeOrDie(); + +} // namespace ubring +} // namespace brpc + +#endif // if BRPC_WITH_UBRING + +#endif // BRPC_UB_HELPER_H \ No newline at end of file diff --git a/src/brpc/ubshm/ub_ring.cpp b/src/brpc/ubshm/ub_ring.cpp new file mode 100644 index 0000000000..11a5d9b311 --- /dev/null +++ b/src/brpc/ubshm/ub_ring.cpp @@ -0,0 +1,1091 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include "bthread/bthread.h" +#include "butil/logging.h" +#include "brpc/ubshm/ub_ring.h" +#include "brpc/ubshm/ub_ring_manager.h" +#include "brpc/ubshm/shm/shm_ipc.h" + +namespace brpc { +namespace ubring { +uint32_t g_sleepTime[UBR_TASK_STEP_NUM] = {0}; +#define TIME_COVERSION 1000 +DEFINE_int32(ub_disconnect_timeout, 5, "Ubshm disconnection timeout."); +DEFINE_int32(ub_connect_timeout, 1, "Ubshm connection timeout."); +DEFINE_int32(ub_hb_timer_interval, 5, "Heartbeat timer interval."); +DEFINE_int32(ub_hb_retry_cnt, 10, "Heartbeat retry times."); +DEFINE_int32(ub_event_queue_timer_interval, 100, "Interval of the disconnection timer."); + +UBRing::UBRing() +{} +UBRing::~UBRing() +{} + +RETURN_CODE UBRing::UbrTrxMapShm(SHM *localShm, SHM *remoteShm) +{ + RETURN_CODE rc = UbrTrxMapLocalShm(localShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx map local shared memory failed."; + return rc; + } + rc = UbrTrxMapRemoteShm(remoteShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx map remote shared memory failed."; + return rc; + } + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrTrxClose() { + RETURN_CODE closeCheckRc = UbrTrxCloseCheck(_trx); + if (UNLIKELY(closeCheckRc != UBRING_OK)) { + if (closeCheckRc == UBRING_REENTRY) { + LOG(INFO) << "Trx close skipped, already closing, local name=" << _trx->localShm.name; + return UBRING_OK; + } + return UBRING_ERR; + } + if (_trx->ubrRx.remoteTxEventQ.addr != nullptr) { + ((UbrEventQMsg *)_trx->ubrRx.remoteTxEventQ.addr)->flag = UBR_STATE_CLOSING; + } + + uint32_t disconnectTimeout = FLAGS_ub_disconnect_timeout; + uint64_t startTime = GetCurNanoSeconds(); + + if (_trx->ubrTx.localTxEventQ.addr != nullptr && ((UbrEventQMsg *)_trx->ubrTx.localTxEventQ.addr)->flag == UBR_STATE_CONNECTED) { + ((UbrEventQMsg *)_trx->ubrTx.localTxEventQ.addr)->flag = UBR_STATE_CLOSED; + _trx->ubrTx.trxState = UBR_STATE_CLOSED; + } + + if (_trx->ubrTx.remoteRxEventQ.addr != nullptr) { + ((UbrEventQMsg *)_trx->ubrTx.remoteRxEventQ.addr)->flag = UBR_STATE_CLOSED; + } + while (_trx->ubrRx.localRxEventQ.addr != nullptr && ((UbrEventQMsg *)_trx->ubrRx.localRxEventQ.addr)->flag != UBR_STATE_CLOSED) { + UbrSetSleepTask(UBR_TASK_CLOSE); + if (HasTimedOut(startTime, disconnectTimeout) != UBRING_OK) { + LOG(WARNING) << "Local shm " << _trx->localShm.name + << " wait for the peer to close timed out, force cleanup."; + _trx->ubrRx.trxState = UBR_STATE_CLOSED; + // Force synchronous cleanup instead of relying on async timer + DeleteTimerSafe((uint32_t)_trx->timerFd); + DeleteTimerSafe((uint32_t)_trx->hbTimerFd); + if (_trx->ubrTx.remoteRxEventQ.addr != nullptr) { + ((UbrEventQMsg *)_trx->ubrTx.remoteRxEventQ.addr)->flag = UBR_STATE_CLOSED; + } + if (UNLIKELY(UbrTrxFreeShm(_trx) != UBRING_OK)) { + LOG(WARNING) << "Force close, local shm " << _trx->localShm.name << " free failed."; + } + if (UNLIKELY(UBRingManager::ReleaseUbrTrxFromMgr(_trx) != UBRING_OK)) { + LOG(WARNING) << "Force close, release trx " << _trx->localShm.name << " failed."; + } + return UBRING_ERR_TIMEOUT; + } + bthread_usleep(1000); // 1ms, yield to other bthreads + } + _trx->ubrRx.trxState = UBR_STATE_CLOSED; + RETURN_CODE rc; + if (UNLIKELY((rc = ClearTrxResource(_trx, startTime, UBR_SEND_CLOSE)) != UBRING_OK)) { + if (rc == UBRING_REENTRY) { + LOG(INFO) << "Trx close, peer is closing, trx local name=" << _trx->localShm.name; + return UBRING_OK; + } + LOG(ERROR) << "Trx close, clear trx resource failed, trx local name=" << _trx->localShm.name; + return UBRING_ERR; + } + // Unlink local shm name immediately so process exit does not leave visible leftovers. + RETURN_CODE unlinkRc = ShmFree(&_trx->localShm); + if (unlinkRc != UBRING_OK && unlinkRc != SHM_ERR_NOT_FOUND && unlinkRc != SHM_ERR_RESOURCE_ATTACHED) { + LOG(WARNING) << "Trx close, unlink local shm failed, trx local name=" << _trx->localShm.name + << ", rc=" << unlinkRc; + } + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrAddCloseTimer() { + if (UNLIKELY(_trx == NULL)) { + LOG(ERROR) << "Trx add close timer failed, trx is null."; + return UBRING_ERR; + } + + uint32_t eventQTimerInterval = FLAGS_ub_event_queue_timer_interval * TIME_COVERSION; + itimerspec timeSpec = { + .it_interval = {.tv_sec = 0, .tv_nsec = eventQTimerInterval}, + .it_value = {.tv_sec = 0, .tv_nsec = 1} + }; + int timerFd = TimerStart(&timeSpec, UbrTrxCloseCallback, (void*)_trx); + if (UNLIKELY(timerFd == -1)) { + LOG(ERROR) << "Start ubr close timer failed, trx local name=" << _trx->localShm.name; + return UBRING_ERR; + } + _trx->timerFd = timerFd; + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrAddTimer() { + if (UNLIKELY(UbrAddCloseTimer() != UBRING_OK)) { + LOG(ERROR) << "Ubr " << _trx->localShm.name << " add closed timer failed."; + return UBRING_ERR; + } + + if (UNLIKELY(UbrAddHBTimer() != UBRING_OK)) { + DeleteTimerSafe((uint32_t)_trx->timerFd); + LOG(ERROR) << "Ubr " << _trx->localShm.name << " add heartbeat timer failed."; + return UBRING_ERR; + } + return UBRING_OK; +} + +void* UBRing::UbrTrxCloseCallback(void* args) { + auto* trx = (UbrTrx*) args; + if (UNLIKELY(UBRing::UbrTrxCallbackCheck(trx) != UBRING_OK)) { + return nullptr; + } + + auto* localRxEventQ = (UbrEventQMsg *)trx->ubrRx.localRxEventQ.addr; + auto* localTxEventQ = (UbrEventQMsg *)trx->ubrTx.localTxEventQ.addr; + if (localRxEventQ->flag != UBR_STATE_CLOSED || localTxEventQ->flag == UBR_STATE_CLOSED) { + return nullptr; + } + trx->ubrRx.trxState = UBR_STATE_CLOSED; + int fd = (int)trx->localShm.fd; + do { + if (ATOMIC_LOAD(trx->closeCnt) == 0) { + break; + } + ATOMIC_SUB(trx->closeCnt, 1); + + uint64_t startTime = GetCurNanoSeconds(); + + if (localTxEventQ->flag == UBR_STATE_CONNECTED || ATOMIC_LOAD(trx->closeCnt) == 1) { + localTxEventQ->flag = UBR_STATE_CLOSED; + trx->ubrTx.trxState = UBR_STATE_CLOSED; + } + UbrEventQMsg* remoteRxEventQ = (UbrEventQMsg *)trx->ubrTx.remoteRxEventQ.addr; + if (remoteRxEventQ == nullptr) { + LOG(ERROR) << "Trx close callback failed, " << trx->localShm.name << " remoteRxEventQ is NULL."; + break; + } + remoteRxEventQ->flag = UBR_STATE_CLOSED; + RETURN_CODE clearRc = ClearTrxResource(trx, startTime, UBR_CALL_BACK_CLOSE, 1); + if (UNLIKELY(clearRc != UBRING_OK && clearRc != UBRING_REENTRY)) { + LOG(ERROR) << "Trx close callback failed, " << trx->localShm.name << " clear trx resource failed."; + break; + } + } while (0); + return nullptr; +} + +RETURN_CODE UBRing::UbrAddHBTimer() { + if (UNLIKELY(_trx == NULL)) { + LOG(ERROR) << "Trx add heartbeat timer failed, trx is null."; + return UBRING_ERR; + } + + itimerspec timeSpec = { + .it_interval = {.tv_sec = FLAGS_ub_hb_timer_interval, .tv_nsec = 0}, + .it_value = {.tv_sec = 0, .tv_nsec = 1} + }; + int timerFd = TimerStart(&timeSpec, UbrTrxHBCallback, (void*)_trx); + if (UNLIKELY(timerFd == -1)) { + LOG(ERROR) << "Start ubr heartbeat timer failed."; + return UBRING_ERR; + } + _trx->hbTimerFd = timerFd; + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrPassiveClearTrx(UbrTrx *trx, int fd, PASSIVE_DISC_TYPE type) { + RETURN_CODE passiveCloseCheckRc = UbrTrxCloseCheck(trx); + if (UNLIKELY(passiveCloseCheckRc != UBRING_OK)) { + if (passiveCloseCheckRc == UBRING_REENTRY) { + LOG(INFO) << "Passive close skipped, active close in progress, name=" << trx->localShm.name; + uint64_t startTime = GetCurNanoSeconds(); + return ClearTrxResource(trx, startTime, UBR_CALL_BACK_CLOSE); + } + return UBRING_ERR; + } + trx->ubrTx.trxState = UBR_STATE_CLOSED; + trx->ubrRx.trxState = UBR_STATE_CLOSED; + DeleteTimerSafe((uint32_t)trx->timerFd); + const char *typeName = NULL; + if (type == UBR_HEARTBEAT) { + DeleteTimer((uint32_t)trx->hbTimerFd); + typeName = "Trx heartbeat"; + } else if (type == UBR_UB_EVENT) { + DeleteTimerSafe((uint32_t)trx->hbTimerFd); + typeName = "Ub event callback"; + } + bthread_usleep(FLAGS_ub_flying_io_timeout * 1000000LL); // yield-friendly sleep + + int rc = ShmLocalFree(&trx->remoteShm); + if (rc != UBRING_OK) { + LOG(ERROR) << typeName << ", delete remote shm failed. ret=" << rc; + } + rc = ShmLocalFree(&trx->localShm); + if (rc != UBRING_OK) { + LOG(ERROR) << typeName << ", delete local shm failed. ret=" << rc; + } + + UBRingManager::ReleaseUbrTrxFromMgr(trx); + return UBRING_OK; +} + +void* UBRing::UbrTrxHBCallback(void* args) { + auto* trx = (UbrTrx*) args; + if (UNLIKELY(UbrTrxCallbackCheck(trx) != UBRING_OK)) { + return NULL; + } + + auto* localDataStatus = (UbrDataStatusQMsg *)trx->ubrTx.localDataStatusQ.addr; + auto* remoteDataStatus = (UbrDataStatusQMsg *)trx->ubrRx.remoteDataStatusQ.addr; + if (UNLIKELY(localDataStatus == NULL || remoteDataStatus == NULL)) { + LOG(ERROR) << "Heartbeat error, datastatus is NULL."; + return NULL; + } + + if (trx->ubrTx.trxState != UBR_STATE_CONNECTED || trx->ubrRx.trxState != UBR_STATE_CONNECTED) { + LOG_EVERY_SECOND(INFO) << "Heartbeat cannot be started, wait connected state."; + return NULL; + } + + remoteDataStatus->heartBeat = 1; + if (localDataStatus->heartBeat == 1) { + localDataStatus->heartBeat = 0; + trx->ubrTx.hbRetryCnt = 0; + return NULL; + } + + ++trx->ubrTx.hbRetryCnt; + if (trx->ubrTx.hbRetryCnt <= FLAGS_ub_hb_retry_cnt) { + return NULL; + } + + int fd = (int)trx->localShm.fd; + LOG(INFO) << "Hlc heartbeat, start to clear trx resource. hbTimerFd=" << fd << ", shmName=" << trx->localShm.name; + UbrPassiveClearTrx(trx, fd, UBR_HEARTBEAT); + LOG(INFO) << "Hlc heartbeat clear trx resource finish."; + return NULL; +} + +RETURN_CODE UBRing::UbrAddAsynClearTimer(UbrTrx *trx) { + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Trx add close timer failed, trx is null."; + return UBRING_ERR; + } + + if (trx->clearTimerFd > 0) { + return UBRING_OK; + } + + itimerspec timeSpec = { + .it_interval = {.tv_sec = 0, .tv_nsec = 0}, + .it_value = {.tv_sec = FLAGS_ub_flying_io_timeout, .tv_nsec = 0} + }; + + int timerFd = TimerStart(&timeSpec, UbrAsynClearCallback, (void*)trx); + if (UNLIKELY(timerFd == -1)) { + LOG(ERROR) << "Start ubr close timer failed, trx name=%s.", trx->localShm.name; + return UBRING_ERR; + } + trx->clearTimerFd = timerFd; + return UBRING_OK; +} + +void *UBRing::UbrAsynClearCallback(void *args) +{ + auto* trx = (UbrTrx*) args; + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Trx close, trx is null."; + return NULL; + } + + if (UNLIKELY(UbrTrxFreeShm(trx) != UBRING_OK)) { + LOG(ERROR) << "Trx close, wait for local shm " << trx->localShm.name << " free fail."; + } + + if (UNLIKELY(UBRingManager::ReleaseUbrTrxFromMgr(trx) != UBRING_OK)) { + LOG(ERROR) << "Trx close, release shm " << trx->localShm.name << " trx failed."; + } + return NULL; +} + +int UBRing::UbrTrxSend(const void *buf, uint32_t bufLen) +{ + if (UNLIKELY(CheckTrxSendPreCheck(_trx) != UBRING_OK)) { + return UBRING_ERR; + } + // 1.2 计算空间 + auto *dataStatusMsg = (UbrDataStatusQMsg *)_trx->ubrTx.localDataStatusQ.addr; + auto *dataMsg = (UbrMsgFormat *)_trx->ubrTx.remoteDataQ.addr; + uint32_t cap = _trx->ubrTx.capacity; + uint32_t tail = dataStatusMsg->tail; + uint32_t remainChunkNum = + (_trx->ubrTx.writePos > tail) ? (tail + cap - _trx->ubrTx.writePos) : (tail - _trx->ubrTx.writePos); + uint32_t needMsgChunkNum = CalcUbrMsgChunkCnt(bufLen); + if (needMsgChunkNum >= cap) { + LOG(ERROR) << "Ubr send failed, payload length=" << bufLen + << " needs " << needMsgChunkNum << " chunks, capacity=" << cap << "."; + errno = EMSGSIZE; + return UBRING_ERR; + } + if (remainChunkNum < needMsgChunkNum) { + return UBRING_RETRY; + } + UbrMsgFormat *msg = &(_trx->ubrTx.localMsgSpace); + uint32_t totalSendLen = 0; + uint32_t remainBufLen = bufLen; + uint8_t isLastPkt = 0; + _trx->ubrTx.outIoId++; + ((UbrEventQMsg *)_trx->ubrTx.remoteRxEventQ.addr)->ioId = _trx->ubrTx.outIoId; + while (remainBufLen > 0) { + isLastPkt = (uint8_t)(remainBufLen <= UBR_MSG_PAYLOAD_LEN); + msg->header[UBR_MSG_FLAG_INDEX] = isLastPkt ? UBR_MSG_CHUNK_EOF : UBR_MSG_CHUNK_EXIST; + msg->header[UBR_MSG_LEN_INDEX] = isLastPkt ? (uint8_t)remainBufLen : UBR_MSG_PAYLOAD_LEN; + msg->header[UBR_MSG_CUR_INDEX] = 0; + memcpy(msg->payload.inner, (const uint8_t *)buf + totalSendLen, msg->header[UBR_MSG_LEN_INDEX]); + Copy64Byte((int8_t *)&dataMsg[_trx->ubrTx.writePos], (int8_t *)msg); + _trx->ubrTx.writePos = (_trx->ubrTx.writePos + 1) % cap; + totalSendLen += msg->header[UBR_MSG_LEN_INDEX]; + remainBufLen -= msg->header[UBR_MSG_LEN_INDEX]; + } + return (int)totalSendLen; +} + +int UBRing::UbrTrxRecv(void *buf, uint32_t bufLen) +{ + RETURN_CODE rc = UBRING_OK; + if (UNLIKELY((rc = CheckTrxRecvParam(_trx, buf, bufLen)) != UBRING_OK)) { + return (rc == UBR_NOT_CONNECTED) ? 0 : rc; + } + UbrMsgFormat *dataMsg = (UbrMsgFormat *)_trx->ubrRx.localDataQ.addr; + uint32_t readPosEnd = _trx->ubrRx.readPos; + uint8_t flag = dataMsg[readPosEnd].header[UBR_MSG_FLAG_INDEX]; + if (flag == UBR_MSG_CHUNK_NONE) { + return UBRING_RETRY; + } + return UbrTrxRecvBlockMode(static_cast(buf), bufLen); +} + +int UBRing::UbrTrxRecvBlockMode(uint8_t *dest, uint32_t bufLen) +{ + RETURN_CODE rc = UBRING_OK; + if (UNLIKELY((rc = CheckTrxRecvParam(_trx, dest, bufLen)) != UBRING_OK)) { + return (rc == UBR_NOT_CONNECTED) ? 0 : rc; + } + + int32_t totalCopied = 0; + int32_t remainingLen = (int32_t)bufLen; + bool notEofEncountered = true; + + UbrRx *ubrRx = &_trx->ubrRx; + UbrMsgFormat *dataMsg = (UbrMsgFormat *)ubrRx->localDataQ.addr; + bool needUpdateEpollEofPos = ubrRx->readPos == ubrRx->epEofPos; + + while (notEofEncountered && remainingLen > 0) { + if (UNLIKELY(CheckTrxRecvPreCheck(_trx) != UBRING_OK)) { + return UBRING_ERR; + } + UbrMsgFormat *currentChunk = &dataMsg[ubrRx->readPos]; + uint8_t flag = currentChunk->header[UBR_MSG_FLAG_INDEX]; + if (flag == UBR_MSG_CHUNK_NONE) { + if (totalCopied > 0) { + break; + } + errno = EAGAIN; + return -1; + } + if (flag == UBR_MSG_CHUNK_EOF) { + notEofEncountered = false; + } + uint8_t chunkMsgLen = currentChunk->header[UBR_MSG_LEN_INDEX]; + uint8_t curIndex = currentChunk->header[UBR_MSG_CUR_INDEX]; + uint8_t availableData = chunkMsgLen - curIndex; + + int32_t copyLen = (remainingLen < availableData) ? remainingLen : availableData; + memcpy(dest + totalCopied, dataMsg[ubrRx->readPos].payload.inner + curIndex, (size_t)copyLen); + totalCopied += copyLen; + remainingLen -= copyLen; + currentChunk->header[UBR_MSG_CUR_INDEX] += (uint8_t)copyLen; + if (LIKELY(currentChunk->header[UBR_MSG_CUR_INDEX] == chunkMsgLen)) { + currentChunk->header[UBR_MSG_FLAG_INDEX] = UBR_MSG_CHUNK_NONE; + UpdateDataQTail(_trx); + ubrRx->readPos = (ubrRx->readPos + 1) % ubrRx->capacity; + } + } + if (needUpdateEpollEofPos) { + ubrRx->epEofPos = ubrRx->readPos; + } + return (int)totalCopied; +} + +ssize_t UBRing::UbrTrxWritev(const struct iovec *iov, int iovcnt) +{ + if (UNLIKELY(CheckTrxSendPreCheck(_trx) != UBRING_OK)) { + return UBRING_ERR; + } + + size_t bufLen = 0; + for (int i = 0; i < iovcnt; i++) { + bufLen += iov[i].iov_len; + } + RETURN_CODE rc = WritevHasEnoughSpace(bufLen); + if (rc != UBRING_OK) { + return rc; + } + + UbrMsgFormat *dataMsg = (UbrMsgFormat *)_trx->ubrTx.remoteDataQ.addr; + UbrMsgFormat *msg = &(_trx->ubrTx.localMsgSpace); + int curIov = 0; + size_t curIovPos = 0; + ssize_t totalSendLen = 0; + size_t pktRemainN = 0; + size_t iovRemain = 0; + size_t fulled = 0; + uint8_t isLastPkt = 0; + uint8_t curPktLen = 0; + _trx->ubrTx.outIoId++; + ((UbrEventQMsg *)_trx->ubrTx.remoteRxEventQ.addr)->ioId = _trx->ubrTx.outIoId; + while (bufLen > 0) { + isLastPkt = (uint8_t)(bufLen <= UBR_MSG_PAYLOAD_LEN); + curPktLen = isLastPkt ? (uint8_t)bufLen : UBR_MSG_PAYLOAD_LEN; + msg->header[UBR_MSG_FLAG_INDEX] = isLastPkt ? UBR_MSG_CHUNK_EOF : UBR_MSG_CHUNK_EXIST; + msg->header[UBR_MSG_LEN_INDEX] = curPktLen; + msg->header[UBR_MSG_CUR_INDEX] = 0; + pktRemainN = curPktLen; + while (curIov < iovcnt && pktRemainN > 0) { + iovRemain = (iov[curIov].iov_len - curIovPos); + fulled = iovRemain > pktRemainN ? pktRemainN : iovRemain; + memcpy((msg->payload.inner + (curPktLen - (uint8_t)pktRemainN)), + (uint8_t *)(iov[curIov].iov_base) + curIovPos, + fulled); + pktRemainN -= fulled; + curIovPos += fulled; + if (curIovPos == iov[curIov].iov_len) { + curIov++; + curIovPos = 0; + } + } + + Copy64Byte((int8_t *)&dataMsg[_trx->ubrTx.writePos], (int8_t *)msg); + _trx->ubrTx.writePos = (_trx->ubrTx.writePos + 1) % _trx->ubrTx.capacity; + totalSendLen += (ssize_t)curPktLen; + bufLen -= (int)curPktLen; + } + return totalSendLen; +} + +ssize_t UBRing::UbrTrxReadv(const struct iovec *iov, int iovcnt) +{ + RETURN_CODE rc = UBRING_OK; + if (UNLIKELY((rc = CheckTrxRecvParam(_trx, iov, (uint32_t)iovcnt)) != UBRING_OK)) { + return (rc == UBR_NOT_CONNECTED) ? 0 : rc; + } + UbrMsgFormat *dataMsg = (UbrMsgFormat *)_trx->ubrRx.localDataQ.addr; + uint32_t readPosEnd = _trx->ubrRx.readPos; + uint8_t flag = dataMsg[readPosEnd].header[UBR_MSG_FLAG_INDEX]; + if (flag == UBR_MSG_CHUNK_NONE) { + errno = EAGAIN; + return -1; + } + ssize_t nr = UbrTrxReadvBlockMode(iov, iovcnt); + if (UNLIKELY(nr == -1)) { + LOG(ERROR) << "Non-blocking readv msg in failed, connection has been closed."; + errno = EPIPE; + return -1; + } + return nr; +} + +ssize_t UBRing::UbrTrxReadvBlockMode(const struct iovec *iov, int iovcnt) +{ + RETURN_CODE rc = UBRING_OK; + if (UNLIKELY((rc = CheckTrxRecvParam(_trx, iov, (uint32_t)iovcnt)) != UBRING_OK)) { + return (rc == UBR_NOT_CONNECTED) ? 0 : rc; + } + + size_t remainBufLen = 0; + for (int i = 0; i < iovcnt; i++) { + remainBufLen += iov[i].iov_len; + } + + bool needUpdateEpollEofPos = _trx->ubrRx.readPos == _trx->ubrRx.epEofPos; + ssize_t totalRecvLen = StartReadv(_trx, iov, iovcnt, remainBufLen); + + if (needUpdateEpollEofPos) { + _trx->ubrRx.epEofPos = _trx->ubrRx.readPos; + } + return totalRecvLen; +} + +RETURN_CODE UBRing::IsUbrTrxReadable(uint32_t epEvent) +{ + if (UNLIKELY(_trx == NULL)) { + LOG(ERROR) << "The trx to be checked is NULL."; + return UBRING_ERR; + } + if (UNLIKELY(_trx->localShm.addr == NULL)) { + LOG(ERROR) << "The trx localShm to be checked is NULL."; + return UBRING_ERR; + } + if (UNLIKELY(_trx->ubrTx.trxState != UBR_STATE_CONNECTED)) { + // TODO mwj 这几块的日志是否需要删除 + // LOG(ERROR) << "The trx is not connected state."; + return UBRING_ERR; + } + + uint64_t ioId = ((UbrEventQMsg *)_trx->ubrRx.localRxEventQ.addr)->ioId; + if ((epEvent & EPOLLET) && ioId == _trx->ubrRx.inIoId) { + return MPA_MUXER_NOT_READY; + } + + uint32_t readPosEnd = _trx->ubrRx.readPos; + if (epEvent & EPOLLET) { + readPosEnd = _trx->ubrRx.epEofPos; + } + + UbrMsgFormat *dataMsg = (UbrMsgFormat *)_trx->ubrRx.localDataQ.addr; + uint8_t flag = dataMsg[readPosEnd].header[UBR_MSG_FLAG_INDEX]; + if (flag == UBR_MSG_CHUNK_NONE) { + return MPA_MUXER_NOT_READY; + } + if (epEvent & EPOLLET) { + _trx->ubrRx.inIoId = ioId; + } + return UBRING_OK; +} + +RETURN_CODE UBRing::IsUbrTrxWriteable(uint32_t epEvent) +{ + if (UNLIKELY(_trx == NULL)) { + LOG(ERROR) << "The trx to be checked is NULL."; + return UBRING_ERR; + } + if (UNLIKELY(_trx->localShm.addr == NULL)) { + LOG(ERROR) << "The trx localShm to be checked is NULL."; + return UBRING_ERR; + } + if (UNLIKELY((UbrEventQMsg *)_trx->ubrTx.localTxEventQ.addr == NULL)) { + LOG(ERROR) << "The trx localTxEventQ addr is NULL."; + return UBRING_ERR; + } + if (UNLIKELY((UbrEventQMsg *)_trx->ubrTx.localDataStatusQ.addr == NULL)) { + LOG(ERROR) << "The trx localDataStatusQ addr is NULL."; + return UBRING_ERR; + } + + if (UNLIKELY(_trx->ubrTx.trxState != UBR_STATE_CONNECTED)) { + LOG(ERROR) << "The trx is not connected state."; + return UBRING_ERR; + } + + UbrDataStatusQMsg *dataStatusMsg = (UbrDataStatusQMsg *)_trx->ubrTx.localDataStatusQ.addr; + uint32_t cap = _trx->ubrTx.capacity; + uint32_t tail = dataStatusMsg->tail; + uint32_t remainChunkNum = + (_trx->ubrTx.writePos > tail) ? (tail + cap - _trx->ubrTx.writePos) : (tail - _trx->ubrTx.writePos); + if (remainChunkNum == 0) { + _trx->ubrTx.epLastCap = remainChunkNum; + return MPA_MUXER_NOT_READY; + } + + if ((epEvent & EPOLLET) && (_trx->ubrTx.epLastCap >= remainChunkNum)) { + _trx->ubrTx.epLastCap = remainChunkNum; + return MPA_MUXER_NOT_READY; + } + _trx->ubrTx.epLastCap = remainChunkNum; + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrSetTimeout(UbrTaskStep taskType, int timeout) +{ + if (taskType >= UBR_TASK_STEP_NUM || timeout < 0) { + LOG(ERROR) << "Set timeout failed, invalid task type."; + return UBRING_ERR; + } + + g_sleepTime[taskType] = (uint32_t)timeout; + LOG(INFO) << "Set timeout success, taskType=" << taskType << ", timeout=" << timeout; + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrTrxFreeShm(UbrTrx *trx) +{ + if (trx == NULL) { + LOG(ERROR) << "Trx is NULL."; + return UBRING_ERR; + } + + RETURN_CODE rc = UBRING_OK; + rc = ShmMunmap(&trx->localShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx close, local unmap " << trx->localShm.name << " shm fail."; + return UBRING_ERR; + } + + rc = ShmFree(&trx->localShm); + if (UNLIKELY(rc != UBRING_OK)) { + if (rc != SHM_ERR_RESOURCE_ATTACHED && rc != SHM_ERR_NOT_FOUND) { + LOG(ERROR) << "Wait for " << trx->localShm.name << " local shm free fail."; + return UBRING_ERR; + } + LOG(INFO) << "Local shm " << trx->localShm.name << " already freed, continue to free remote shm."; + } + + RETURN_CODE remoteRc = UBRING_OK; + if (trx->remoteShm.addr != NULL) { + remoteRc = ShmRemoteFree(&trx->remoteShm); + } + if (remoteRc != UBRING_OK) { + LOG(WARNING) << "Free remote shm " << trx->remoteShm.name << " failed, rc=" << remoteRc; + } + + return UBRING_OK; +} + +void UBRing::PreWriteAddr(uint8_t *addr, size_t len) +{ + if (addr == NULL) { + return; + } + + size_t i = 0; + while (i < len) { + if (i + sizeof(uint64_t) <= len) { + *(uint64_t *)(addr + i) = (uint64_t)0; + i += sizeof(uint64_t); + } else if (i + sizeof(uint32_t) < len) { + *(uint32_t *)(addr + i) = (uint32_t)0; + i += sizeof(uint32_t); + } else if (i + sizeof(uint16_t) < len) { + *(uint16_t *)(addr + i) = (uint16_t)0; + i += sizeof(uint16_t); + } else { + *(addr + i) = (uint8_t)0; + i += sizeof(uint8_t); + } + } +} + +void UBRing::PrewriteUbrTx(UbrTx *tx) +{ + if (tx == NULL) { + return; + } + PreWriteAddr(tx->remoteDataQ.addr, tx->capacity * sizeof(UbrMsgFormat)); +} + +void UBRing::PrewriteUbrRx(UbrRx *rx) +{ + if (rx == NULL) { + return; + } + PreWriteAddr(rx->localDataQ.addr, rx->capacity * sizeof(UbrMsgFormat)); +} + +RETURN_CODE UBRing::UbrTrxMapLocalShm(SHM *localShm) +{ + if (UNLIKELY(_trx == NULL)) { + LOG(ERROR) << "Trx map Shared memory failed, trx is null."; + return UBRING_ERR; + } + if (UNLIKELY(localShm == NULL || localShm->addr == NULL)) { + LOG(ERROR) << "Trx map Shared memory failed, localShm is null or addr is NULL."; + return UBRING_ERR; + } + _trx->localShm = *localShm; + _trx->ubrTx.localTxEventQ.addr = localShm->addr + TX_EVENTQ_ADDR_OFFSET; + _trx->ubrTx.localTxEventQ.len = UBR_EVENTQ_LEN; + _trx->ubrRx.localRxEventQ.addr = localShm->addr + RX_EVENTQ_ADDR_OFFSET; + _trx->ubrRx.localRxEventQ.len = UBR_EVENTQ_LEN; + _trx->ubrTx.localDataStatusQ.addr = localShm->addr + DATASTATUSQ_ADDR_OFFSET; + _trx->ubrTx.localDataStatusQ.len = UBR_DATASTATUSQ_LEN; + size_t addrAlignedOffset = Aligned64Offset(localShm->addr + DATAQ_ADDR_OFFSET); + _trx->ubrRx.localDataQ.addr = localShm->addr + DATAQ_ADDR_OFFSET + addrAlignedOffset; + _trx->ubrRx.localDataQ.len = localShm->len - DATAQ_ADDR_OFFSET - addrAlignedOffset; + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrTrxMapRemoteShm(SHM *remoteShm) +{ + if (UNLIKELY(_trx == NULL)) { + LOG(ERROR) << "Trx map Shared memory failed, trx is null."; + return UBRING_ERR; + } + if (UNLIKELY(remoteShm == NULL || remoteShm->addr == NULL)) { + LOG(ERROR) << "Trx map Shared memory failed, remoteShm is null or addr is NULL."; + return UBRING_ERR; + } + _trx->remoteShm = *remoteShm; + _trx->ubrRx.remoteTxEventQ.addr = remoteShm->addr + TX_EVENTQ_ADDR_OFFSET; + _trx->ubrRx.remoteTxEventQ.len = UBR_EVENTQ_LEN; + _trx->ubrTx.remoteRxEventQ.addr = remoteShm->addr + RX_EVENTQ_ADDR_OFFSET; + _trx->ubrTx.remoteRxEventQ.len = UBR_EVENTQ_LEN; + _trx->ubrRx.remoteDataStatusQ.addr = remoteShm->addr + DATASTATUSQ_ADDR_OFFSET; + _trx->ubrRx.remoteDataStatusQ.len = UBR_DATASTATUSQ_LEN; + size_t addrAlignedOffset = Aligned64Offset(remoteShm->addr + DATAQ_ADDR_OFFSET); + _trx->ubrTx.remoteDataQ.addr = remoteShm->addr + DATAQ_ADDR_OFFSET + addrAlignedOffset; + _trx->ubrTx.remoteDataQ.len = remoteShm->len - DATAQ_ADDR_OFFSET - addrAlignedOffset; + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrServerTrxInit(SHM *localShm, SHM *remoteShm) +{ + RETURN_CODE rc = UbrTrxMapShm(localShm, remoteShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) <<"Trx map shared memory failed."; + return rc; + } + + uint32_t localDataMsgCap = (uint32_t)(_trx->ubrRx.localDataQ.len / UBR_MSG_LEN); + uint32_t remoteDataMsgCap = (uint32_t)(_trx->ubrTx.remoteDataQ.len / UBR_MSG_LEN); + _trx->ubrRx.capacity = localDataMsgCap; + _trx->ubrTx.capacity = remoteDataMsgCap; + rc = UBRingManager::GetUbrDealMsgMaxCnt(_trx->ubrRx.capacity, &_trx->ubrRx.dealMsgMaxCnt); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Get ubring deal msg max cnt."; + return rc; + } + PrewriteUbrRx(&_trx->ubrRx); + PrewriteUbrTx(&_trx->ubrTx); + + ((UbrDataStatusQMsg *)(_trx->ubrTx.localDataStatusQ.addr))->tail = remoteDataMsgCap - 1; + ((UbrDataStatusQMsg *)(_trx->ubrRx.remoteDataStatusQ.addr))->tail = localDataMsgCap - 1; + + if (UNLIKELY(UbrAddTimer() != UBRING_OK)) { + LOG(ERROR) << "Ubr add timer failed, localName=" << localShm->name; + return UBRING_ERR; + } + + ((UbrDataStatusQMsg *)(_trx->ubrTx.localDataStatusQ.addr))->timeout = FLAGS_ub_connect_timeout; + ((UbrDataStatusQMsg *)(_trx->ubrRx.remoteDataStatusQ.addr))->timeout = FLAGS_ub_connect_timeout; + + ((UbrEventQMsg *)_trx->ubrTx.remoteRxEventQ.addr)->flag = UBR_STATE_CONNECTED; + ((UbrEventQMsg *)_trx->ubrRx.localRxEventQ.addr)->flag = UBR_STATE_CONNECTED; + _trx->ubrTx.trxState = UBR_STATE_CONNECTED; + _trx->ubrRx.trxState = UBR_STATE_CONNECTED; + return UBRING_OK; +} + +int UBRing::UbrAllocateServerShm(SHM* remote_trx_shm, SHM* local_trx_shm) { + UbrSetSleepTask(UBR_TASK_ACCEPT_MAP_FRONT); + if (UNLIKELY((ShmRemoteMalloc(remote_trx_shm)) != UBRING_OK)) { + LOG(ERROR) << "Trx apply remote shared memory failed."; + return -1; + } + + if (UNLIKELY((ShmLocalCalloc(local_trx_shm)) != UBRING_OK)) { + LOG(ERROR) << "Trx apply local shared memory failed."; + ShmRemoteFree(remote_trx_shm); + return -1; + } + + UbrTrx **ubrTrxPtr = &_trx; + if (UNLIKELY((UBRingManager::AcquireUbrTrxFromMgr(ubrTrxPtr)) != UBRING_OK)) { + LOG(ERROR) << "Acquire ubrtrx failed."; + ShmRemoteFree(remote_trx_shm); + ShmLocalFree(local_trx_shm); + return -1; + } + _trx->type = TCP_TRX; + if (UNLIKELY((UbrServerTrxInit(local_trx_shm, remote_trx_shm)) != UBRING_OK)) { + LOG(ERROR) << "Server trx init failed."; + UbrTrxFreeShm(_trx); + UBRingManager::ReleaseUbrTrxFromMgr(_trx); + _trx = nullptr; + return -1; + } + return 0; +} + +int UBRing::UbrAllocateLocalShm(SHM *local_trx_shm, const char *shm_name) +{ + if (UNLIKELY((UBRingManager::AcquireUbrTrxFromMgr(&(_trx))) != UBRING_OK)) { + LOG(ERROR) << "Acquire ubrtrx failed, localName=" << shm_name; + return -1; + } + + _trx->type = TCP_TRX; + if (UNLIKELY((ApplyAndMapLocalShm(local_trx_shm, shm_name)) != UBRING_OK)) { + LOG(ERROR) << "Trx apply or map local shared memory failed, localName=" << shm_name; + _trx = nullptr; + return -1; + } + return 0; +} + +int UBRing::UbrMapRemoteShm(SHM *local_trx_shm, const char *local_name) +{ + RETURN_CODE rc = UbrMapRemoteShmAddTimer(local_trx_shm, local_name); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Connect Trx failed, local shm name=" << local_trx_shm->name; + return -1; + } + PrewriteUbrRx(&_trx->ubrRx); + PrewriteUbrTx(&_trx->ubrTx); + ((UbrEventQMsg *)_trx->ubrRx.remoteTxEventQ.addr)->flag = UBR_STATE_CONNECTED; + ((UbrEventQMsg *)_trx->ubrRx.localRxEventQ.addr)->flag = UBR_STATE_CONNECTED; + _trx->ubrTx.trxState = UBR_STATE_CONNECTED; + _trx->ubrRx.trxState = UBR_STATE_CONNECTED; + return 0; +} + +RETURN_CODE UBRing::UbrMapRemoteShmAddTimer(SHM *localTrxShm, const char *localName) +{ + uint64_t startTime = GetCurNanoSeconds(); + + size_t remoteServerLen = UBR_MSG_LEN * (((UbrDataStatusQMsg *)(_trx->ubrTx.localDataStatusQ.addr))->tail + 1) + + UBR_MSG_LEN * ((DATAQ_ADDR_OFFSET / UBR_MSG_LEN) + 1); + SHM remoteTrxShm = {NULL, remoteServerLen, 0, {0}, localTrxShm->fd}; + int result = snprintf(remoteTrxShm.name, + SHM_MAX_NAME_BUFF_LEN, + "%s_%s_%s", + SHM_NAME_PREFIX, + localName, + SERVER_SHM_NAME_SUFFIX); + if (UNLIKELY(result < 0)) { + LOG(ERROR) << "Copy server shared memory name failed, localName=%s, ret=%d.", localName, result; + return UBRING_ERR; + } + UbrSetSleepTask(UBR_TASK_CONNECT_MAP_FRONT); + RETURN_CODE rc = ApplyAndMapRemoteShm(&remoteTrxShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Connect Trx map shared memory failed, remote shm=" << remoteTrxShm.name; + return rc; + } + + if (UNLIKELY(UbrAddTimer() != UBRING_OK)) { + LOG(ERROR) << "Ubr add timer failed, localName=" << localName; + ShmRemoteFree(&_trx->remoteShm); + return UBRING_ERR; + } + + UbrSetSleepTask(UBR_TASK_CONNECT_MAP_AFTER); + + uint32_t timeout = ((UbrDataStatusQMsg *)(_trx->ubrTx.localDataStatusQ.addr))->timeout; + if (HasTimedOut(startTime, timeout) != UBRING_OK) { + LOG(ERROR) << "Local shm " << localTrxShm->name << " wait for connect remote map timeout."; + DeleteTimerSafe((uint32_t)_trx->hbTimerFd); + DeleteTimerSafe((uint32_t)_trx->timerFd); + ShmRemoteFree(&_trx->remoteShm); + return UBRING_ERR_TIMEOUT; + } + + return UBRING_OK; +} + +RETURN_CODE UBRing::ApplyAndMapLocalShm(SHM *localTrxShm, const char *localName) +{ + if (UNLIKELY(_trx == NULL || localTrxShm == NULL)) { + LOG(ERROR) << "Trx map Shared memory failed, trx is null, localName=" << localName; + return UBRING_ERR; + } + int result = snprintf(localTrxShm->name, + SHM_MAX_NAME_BUFF_LEN, + "%s_%s_%s", + SHM_NAME_PREFIX, + localName, + CLIENT_SHM_NAME_SUFFIX); + if (UNLIKELY(result < 0)) { + LOG(ERROR) << "Copy client localTrx shared memory name failed, localName=" << localName << ", ret=" << result; + return UBRING_ERR; + } + + RETURN_CODE rc = ShmLocalCalloc(localTrxShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx apply local shared memory failed, local shm name=" << localTrxShm->name << ", rc=" << rc; + if (rc == SHM_ERR_EXIST || rc == SHM_ERR_NOT_FOUND) { + rc = UBR_ERR_ADDR_IN_USE; + } + UBRingManager::ReleaseUbrTrxFromMgr(_trx); + return rc; + } + rc = UbrTrxMapLocalShm(localTrxShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx map local shared memory failed, local shm name=" << localTrxShm->name; + ShmLocalFree(localTrxShm); + UBRingManager::ReleaseUbrTrxFromMgr(_trx); + return rc; + } + ((UbrDataStatusQMsg *)_trx->ubrTx.localDataStatusQ.addr)->timeout = FLAGS_ub_connect_timeout; + _trx->ubrRx.capacity = (uint32_t)(_trx->ubrRx.localDataQ.len / UBR_MSG_LEN); + rc = UBRingManager::GetUbrDealMsgMaxCnt(_trx->ubrRx.capacity, &_trx->ubrRx.dealMsgMaxCnt); + if (rc != UBRING_OK) { + LOG(ERROR) << "Get ubring deal msg max cnt, local shm name=" << localTrxShm->name; + ShmLocalFree(localTrxShm); + UBRingManager::ReleaseUbrTrxFromMgr(_trx); + return rc; + } + return UBRING_OK; +} + +RETURN_CODE UBRing::ApplyAndMapRemoteShm(SHM *remoteTrxShm) +{ + RETURN_CODE rc = ShmRemoteMalloc(remoteTrxShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx apply remote shared memory failed."; + return rc; + } + rc = UbrTrxMapRemoteShm(remoteTrxShm); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Trx map shared memory failed."; + ShmRemoteFree(remoteTrxShm); + return rc; + } + _trx->ubrTx.capacity = (uint32_t)(_trx->ubrTx.remoteDataQ.len / UBR_MSG_LEN); + return UBRING_OK; +} + +RETURN_CODE UBRing::WritevHasEnoughSpace(size_t bufLen) +{ + UbrDataStatusQMsg *dataStatusMsg = (UbrDataStatusQMsg *)_trx->ubrTx.localDataStatusQ.addr; + uint32_t cap = _trx->ubrTx.capacity; + uint32_t tail = dataStatusMsg->tail; + uint32_t remainChunkNum = + (_trx->ubrTx.writePos > tail) ? (tail + cap - _trx->ubrTx.writePos) : (tail - _trx->ubrTx.writePos); + uint32_t needMsgChunkNum = CalcUbrMsgChunkCnt((uint32_t)bufLen); + if (needMsgChunkNum >= cap) { + LOG(ERROR) << "Ubr write failed, payload length=" << bufLen + << " needs " << needMsgChunkNum << " chunks, capacity=" << cap << "."; + errno = EMSGSIZE; + return UBRING_ERR; + } + if (remainChunkNum < needMsgChunkNum) { + return UBRING_RETRY; + } + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrClearResourceCheck(UbrTrx *trx, uint64_t startTime, UbrCloseType closeType) +{ + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Trx close failed, trx is null."; + return UBRING_ERR; + } + + UbrEventQMsg* localTxEventQ = (UbrEventQMsg *)trx->ubrTx.localTxEventQ.addr; + if (localTxEventQ->flag == UBR_STATE_CONNECTED) { + localTxEventQ->flag = UBR_STATE_CLOSING; + } + + if (closeType == UBR_SEND_CLOSE) { + DeleteTimerSafe((uint32_t)trx->timerFd); + } else { + DeleteTimer((uint32_t)trx->timerFd); + } + DeleteTimerSafe((uint32_t)trx->hbTimerFd); + + if (localTxEventQ->flag == UBR_STATE_CLOSING) { + localTxEventQ->flag = UBR_STATE_CLOSED; + trx->ubrTx.trxState = UBR_STATE_CLOSED; + } + + return UBRING_OK; +} + +RETURN_CODE UBRing::ClearTrxResource(UbrTrx *trx, uint64_t startTime, UbrCloseType closeType, int op) +{ + RETURN_CODE rc = UbrClearResourceCheck(trx, startTime, closeType); + if (rc != UBRING_OK) { + return rc; + } + + rc = UbrAddAsynClearTimer(trx); + if (rc != UBRING_OK) { + LOG(ERROR) << "Trx close, add " << trx->localShm.name << " close clear timer failed."; + return UBRING_ERR; + } + + return UBRING_OK; +} + +RETURN_CODE UBRing::UbrTrxCloseCheck(UbrTrx *trx) +{ + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Trx close failed, client trx is null."; + return UBRING_ERR; + } + int expected = MAX_CLOSE_COUNT; + if (!ATOMIC_COMPARE_EXCHANGE_STRONG(trx->closeCnt, expected, MAX_CLOSE_COUNT - 1)) { + LOG(INFO) << "Trx close skipped, already closing, trx local name=" << trx->localShm.name; + return UBRING_REENTRY; + } + + if (UNLIKELY(trx->ubrTx.localTxEventQ.addr == nullptr)) { + LOG(ERROR) << "Trx close failed, localTxEventQ addr is NULL, trx local name=" << trx->localShm.name; + return UBRING_ERR; + } + return UBRING_OK; +} + +ssize_t UBRing::StartReadv(UbrTrx *trx, const struct iovec *iov, int iovcnt, size_t remainBufLen) +{ + ssize_t totalRecvLen = 0; + int iovIndex = 0; + size_t iovPos = 0; + UbrMsgFormat *dataMsg = (UbrMsgFormat *)trx->ubrRx.localDataQ.addr; + bool notEofEncountered = true; + while (notEofEncountered && remainBufLen > 0) { + if (UNLIKELY(CheckTrxRecvPreCheck(trx) != UBRING_OK)) { + return UBRING_ERR; + } + UbrMsgFormat *currentChunk = &dataMsg[trx->ubrRx.readPos]; + uint8_t flag = currentChunk->header[UBR_MSG_FLAG_INDEX]; + if (flag == UBR_MSG_CHUNK_NONE) { + if (totalRecvLen > 0) { + break; + } + errno = EAGAIN; + return -1; + } + if (flag == UBR_MSG_CHUNK_EOF) { + notEofEncountered = false; + } + uint8_t chunkMsgLen = currentChunk->header[UBR_MSG_LEN_INDEX]; + uint8_t curIndex = currentChunk->header[UBR_MSG_CUR_INDEX]; + uint8_t recvLen = + remainBufLen > (size_t)(chunkMsgLen - curIndex) ? (chunkMsgLen - curIndex) : (uint8_t)remainBufLen; + while (iovIndex < iovcnt && recvLen > 0) { + size_t copyLen = + recvLen > (iov[iovIndex].iov_len - iovPos) ? iov[iovIndex].iov_len - iovPos : (size_t)recvLen; + memcpy((uint8_t *)iov[iovIndex].iov_base + iovPos, currentChunk->payload.inner + curIndex, copyLen); + recvLen -= (uint8_t)copyLen; + iovPos += copyLen; + curIndex += (uint8_t)copyLen; + if (iovPos == iov[iovIndex].iov_len) { + iovIndex++; + iovPos = 0; + } + remainBufLen -= copyLen; + totalRecvLen += (ssize_t)copyLen; + } + currentChunk->header[UBR_MSG_CUR_INDEX] = curIndex; + if (currentChunk->header[UBR_MSG_CUR_INDEX] == chunkMsgLen) { + currentChunk->header[UBR_MSG_FLAG_INDEX] = UBR_MSG_CHUNK_NONE; + UpdateDataQTail(trx); + trx->ubrRx.readPos = (trx->ubrRx.readPos + 1) % trx->ubrRx.capacity; + } + } + return totalRecvLen; +} +} // namespace ubring +} // namespace brpc diff --git a/src/brpc/ubshm/ub_ring.h b/src/brpc/ubshm/ub_ring.h new file mode 100644 index 0000000000..09a97d1dcb --- /dev/null +++ b/src/brpc/ubshm/ub_ring.h @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UB_RING_H +#define BRPC_UB_RING_H + +#include +#include +#include "butil/macros.h" +#include "butil/reader_writer.h" +#include "brpc/ubshm/ubr_trx.h" +#include "brpc/ubshm/shm/shm_mgr.h" +#include "brpc/ubshm/timer/timer_mgr.h" + +namespace brpc { +namespace ubring { +DECLARE_int32(ub_flying_io_timeout); +extern uint32_t g_sleepTime[UBR_TASK_STEP_NUM]; + +class UBRing : public butil::IReader { +public: + UBRing(); + ~UBRing(); + DISALLOW_COPY_AND_ASSIGN(UBRing); + + ssize_t ReadV(const iovec* iov, int iovcnt) override { + return UbrTrxReadv(iov, iovcnt); + } + + RETURN_CODE UbrTrxMapShm(SHM *localShm, SHM *remoteShm); + + RETURN_CODE UbrTrxClose(); + + RETURN_CODE UbrAddCloseTimer(); + + RETURN_CODE UbrAddTimer(); + + static void *UbrTrxCloseCallback(void *args); + + RETURN_CODE UbrAddHBTimer(); + + static void *UbrTrxHBCallback(void *args); + + static RETURN_CODE UbrPassiveClearTrx(UbrTrx *trx, int fd, PASSIVE_DISC_TYPE type); + + static RETURN_CODE UbrAddAsynClearTimer(UbrTrx *trx); + + static void *UbrAsynClearCallback(void *args); + + int UbrTrxSend(const void *buf, uint32_t bufLen); + + int UbrTrxRecv(void *buf, uint32_t bufLen); + + int UbrTrxRecvBlockMode(uint8_t *dest, uint32_t bufLen); + + ssize_t UbrTrxWritev(const struct iovec *iov, int iovcnt); + ssize_t UbrTrxReadv(const struct iovec *iov, int iovcnt); + ssize_t UbrTrxReadvBlockMode(const struct iovec *iov, int iovcnt); + + RETURN_CODE IsUbrTrxReadable(uint32_t epEvent); + + RETURN_CODE IsUbrTrxWriteable(uint32_t epEvent); + + RETURN_CODE UbrSetTimeout(UbrTaskStep taskType, int timeout); + + static RETURN_CODE UbrTrxFreeShm(UbrTrx *trx); + + void PrewriteUbrTx(UbrTx *tx); + void PrewriteUbrRx(UbrRx *rx); + + static inline void UbrSetSleepTask(UbrTaskStep taskType) + { + if (taskType >= UBR_TASK_STEP_NUM || taskType < 0) { + return; + } + uint32_t type = (uint32_t)taskType; + sleep(g_sleepTime[type]); + return; + } + + static inline RETURN_CODE CheckTrxConnectParam(const char *listenerName, const char *localName) + { + if (UNLIKELY(listenerName == NULL)) { + LOG(ERROR) << "The request listener name is null."; + return UBRING_ERR; + } + if (UNLIKELY(localName == NULL)) { + LOG(ERROR) << "The request trx shared memory name is null."; + return UBRING_ERR; + } + return UBRING_OK; + } + + int UbrAllocateServerShm(SHM* remote_trx_shm, SHM* local_trx_shm); + + int UbrMapRemoteShm(SHM *local_trx_shm, const char *local_name); + + int UbrAllocateLocalShm(SHM *local_trx_shm, const char *shm_name); + + RETURN_CODE UbrMapRemoteShmAddTimer(SHM *localTrxShm, const char *localName); + + static inline RETURN_CODE CheckTrxSendPreCheck(UbrTrx *trx) + { + if (UNLIKELY(trx->ubrTx.trxState != UBR_STATE_CONNECTED)) { + LOG(ERROR) << "Trx send failed, trx is not connected state."; + return UBRING_ERR; + } + + return UBRING_OK; + } + static RETURN_CODE CheckTrxRecvParam(UbrTrx *trx, const void *buf, uint32_t bufLen) + { + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Trx recv failed, trx is null."; + return UBRING_ERR; + } + + if (UNLIKELY((UbrEventQMsg *)trx->ubrRx.localRxEventQ.addr == NULL)) { + LOG(ERROR) << "Trx send failed, localTxEventQ addr is NULL."; + return UBRING_ERR; + } + + if (UNLIKELY(trx->ubrRx.trxState != UBR_STATE_CONNECTED)) { + LOG(ERROR) << "Trx recv failed, trx is not connected statep=" << trx->ubrRx.trxState; + return UBR_NOT_CONNECTED; + } + if (UNLIKELY(buf == NULL)) { + LOG(ERROR) << "Trx recv failed, buf is null."; + return UBRING_ERR; + } + if (UNLIKELY(bufLen == 0)) { + LOG(ERROR) << "Trx recv failed, bufLen is 0."; + return UBRING_ERR; + } + return UBRING_OK; + } + + static inline RETURN_CODE CheckTrxRecvPreCheck(UbrTrx *trx) + { + if (UNLIKELY(trx->ubrRx.trxState != UBR_STATE_CONNECTED)) { + LOG(ERROR) << "Trx recv failed, trx is not connected state."; + return UBRING_ERR; + } + return UBRING_OK; + } + + static inline void UpdateDataQTail(UbrTrx *trx) + { + ((UbrDataStatusQMsg *)trx->ubrRx.remoteDataStatusQ.addr)->tail = trx->ubrRx.readPos; + } + + static RETURN_CODE UbrTrxCallbackCheck(UbrTrx *trx) + { + if (trx == NULL) { + LOG(ERROR) << "Trx close callback failed, trx is null."; + return UBRING_ERR; + } + if (UNLIKELY(trx->localShm.addr == NULL)) { + LOG(ERROR) << "Trx close failed, localShm addr is NULL."; + return UBRING_ERR; + } + if (UNLIKELY(trx->ubrRx.localRxEventQ.addr == NULL)) { + LOG(ERROR) << "Trx close failed, localRxEventQ addr is NULL."; + return UBRING_ERR; + } + if (UNLIKELY(trx->ubrTx.localTxEventQ.addr == NULL)) { + LOG(ERROR) << "Trx close failed, localTxEventQ addr is NULL."; + return UBRING_ERR; + } + return UBRING_OK; + } + +private: + RETURN_CODE UbrTrxMapLocalShm(SHM *localShm); + RETURN_CODE UbrTrxMapRemoteShm(SHM *remoteShm); + RETURN_CODE ApplyAndMapLocalShm(SHM *localTrxShm, const char *localName); + RETURN_CODE ApplyAndMapRemoteShm(SHM *remoteTrxShm); + static RETURN_CODE UbrTrxCloseCheck(UbrTrx *trx); + void ReleaseFileLock(int lockFd); + ssize_t StartReadv(UbrTrx *trx, const struct iovec *iov, int iovcnt, size_t remainBufLen); + void PreWriteAddr(uint8_t *addr, size_t len); + RETURN_CODE WritevHasEnoughSpace(size_t bufLen); + RETURN_CODE UbrServerTrxInit(SHM *localShm, SHM *remoteShm); + static RETURN_CODE UbrClearResourceCheck(UbrTrx *trx, uint64_t startTime, UbrCloseType closeType); + static RETURN_CODE ClearTrxResource(UbrTrx *trx, uint64_t startTime, UbrCloseType closeType, int op=0); + + UbrTrx* _trx{nullptr}; +}; +} +} + +#endif //BRPC_UB_RING_H \ No newline at end of file diff --git a/src/brpc/ubshm/ub_ring_manager.cpp b/src/brpc/ubshm/ub_ring_manager.cpp new file mode 100644 index 0000000000..13df631f9e --- /dev/null +++ b/src/brpc/ubshm/ub_ring_manager.cpp @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include "brpc/ubshm/ub_ring.h" +#include "brpc/ubshm/ub_ring_manager.h" +#include "butil/logging.h" + +namespace brpc { +namespace ubring { +DEFINE_int32(ubr_max_managed_num, 1024, "maximum number of managed ubring"); +DEFINE_int32(tail_update_after_read, 8, "Position of the tail update after the read"); + +UbrMgr UBRingManager::g_ubrMgr; +UbrLinkInfoMgr UBRingManager::g_linkInfoMgr; +pthread_mutex_t UBRingManager::g_ubrTrxMgrMtx = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t UBRingManager::g_ubrListenerMgrMtx = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t UBRingManager::g_linkInfoMgrMtx = PTHREAD_MUTEX_INITIALIZER; + +uint64_t g_ubrTrxNum = 0; +uint64_t g_ubEventCnt = 0; +uint64_t g_ubrListenerNum = 0; + +RETURN_CODE UBRingManager::GetUbrDealMsgMaxCnt(const uint32_t capacity, uint32_t *dealMsgMaxCnt) { + if (UNLIKELY(dealMsgMaxCnt == NULL)) { + LOG(ERROR) << "Get update factor failed, dealMsgMaxCnt is null."; + return UBRING_ERR; + } + if (UNLIKELY(FLAGS_tail_update_after_read == 0)) { + LOG(ERROR) << "Get update factor failed, factor is 0."; + return UBRING_ERR; + } + *dealMsgMaxCnt = capacity / FLAGS_tail_update_after_read; + return UBRING_OK; +} + +RETURN_CODE UBRingManager::UbrMgrDefault() +{ + g_ubrMgr.trxNum = 0; + g_ubrMgr.trxCap = FLAGS_ubr_max_managed_num; + g_ubrMgr.trxMgrUnitStatus = NULL; + g_ubrMgr.trxMgr = NULL; + return UBRING_OK; +} + +RETURN_CODE UBRingManager::UbrMgrInit() { + RETURN_CODE rc = UbrMgrDefault(); + if (UNLIKELY(rc != UBRING_OK)) { + LOG(ERROR) << "Ubr manager set default values failed."; + return rc; + } + + size_t trxMgrSize = g_ubrMgr.trxCap * sizeof(UbrTrx); + g_ubrMgr.trxMgr = (UbrTrx *)malloc(trxMgrSize); + size_t trxMgrStatusSize = g_ubrMgr.trxCap * sizeof(UbrMgrUnitStatus); + g_ubrMgr.trxMgrUnitStatus = (UbrMgrUnitStatus *)malloc(trxMgrStatusSize); + if (UNLIKELY(g_ubrMgr.trxMgr == NULL || + g_ubrMgr.trxMgrUnitStatus == NULL)) { + LOG(ERROR) << "Ubr manager memory allocation failed."; + UbrMgrFini(); + return UBRING_ERR; + } + + memset(g_ubrMgr.trxMgr, 0, trxMgrSize); + memset(g_ubrMgr.trxMgrUnitStatus, UBR_MGR_UNIT_FREE, trxMgrStatusSize); + LinkInfoInit(); + return UBRING_OK; +} + +void UBRingManager::UbrMgrFini() { + { + LOCK_GUARD(g_ubrTrxMgrMtx); + FREE_PTR(g_ubrMgr.trxMgr); + FREE_PTR(g_ubrMgr.trxMgrUnitStatus); + } + { + LOCK_GUARD(g_ubrListenerMgrMtx); + } + g_ubrMgr.trxNum = 0; + g_ubrMgr.trxCap = 0; + LinkInfoFini(); +} + +RETURN_CODE UBRingManager::AcquireUbrTrxFromMgr(UbrTrx **trx) { + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Acquire trx failed, trx is null."; + return UBRING_ERR; + } + + if (UNLIKELY(g_ubrMgr.trxMgr == NULL)) { + LOG(ERROR) << "Acquire trx failed, trxMgr is null."; + return UBRING_ERR; + } + + LOCK_GUARD(g_ubrTrxMgrMtx); + if (g_ubrMgr.trxNum >= g_ubrMgr.trxCap) { + LOG(ERROR) << "Acquire trx failed, trx number is full."; + return UBRING_ERR; + } + + for (uint32_t i = 0; i < g_ubrMgr.trxCap; ++i) { + if (g_ubrMgr.trxMgrUnitStatus[i] == UBR_MGR_UNIT_FREE) { + memset(&g_ubrMgr.trxMgr[i], 0, sizeof(UbrTrx)); + g_ubrMgr.trxMgrUnitStatus[i] = UBR_MGR_UNIT_USED; + *trx = &g_ubrMgr.trxMgr[i]; + (*trx)->trxMgrIndex = i; + (*trx)->ubrId = g_ubrTrxNum; + (*trx)->closeState = UBR_CLOSE_FIRST; + (*trx)->closeCnt = MAX_CLOSE_COUNT; + ++g_ubrMgr.trxNum; + ++g_ubrTrxNum; + return UBRING_OK; + } + } + LOG(ERROR) << "Acquire trx failed, no available space."; + return UBRING_ERR; +} + +RETURN_CODE UBRingManager::ReleaseUbrTrxFromMgr(UbrTrx *trx) { + if (UNLIKELY(trx == NULL)) { + LOG(ERROR) << "Release trx failed, trx is null."; + return UBRING_ERR; + } + + trx->localShm.addr = NULL; + trx->ubrTx.localTxEventQ.addr = NULL; + trx->ubrTx.localDataStatusQ.addr = NULL; + trx->ubrRx.localRxEventQ.addr = NULL; + trx->ubrRx.remoteDataStatusQ.addr = NULL; + if (UNLIKELY(g_ubrMgr.trxMgr == NULL)) { + LOG(ERROR) << "Release trx failed, trxMgr is null."; + return UBRING_ERR; + } + + LOCK_GUARD(g_ubrTrxMgrMtx); + uint32_t idx = trx->trxMgrIndex; + if (g_ubrMgr.trxMgrUnitStatus[idx] == UBR_MGR_UNIT_FREE) { + LOG(INFO) << "Release trx already freed, name=" << trx->localShm.name; + return UBRING_OK; + } + + if (g_ubrMgr.trxNum == 0) { + LOG(ERROR) << "Release trx failed, trx number is 0."; + return UBRING_ERR; + } + + g_ubrMgr.trxMgrUnitStatus[idx] = UBR_MGR_UNIT_FREE; + --g_ubrMgr.trxNum; + return UBRING_OK; +} + +void UBRingManager::LinkInfoInit(void) { + + size_t linkInfoMgrSize = FLAGS_ubr_max_managed_num * sizeof(UbrLinkInfo); + g_linkInfoMgr.allLinkInfo = (UbrLinkInfo*) malloc(linkInfoMgrSize); + if (g_linkInfoMgr.allLinkInfo == NULL) { + LOG(ERROR) << "allLinkInfo is NULL"; + LinkInfoFini(); + return; + } + + g_linkInfoMgr.linkMgrUnitStatus = (UbrMgrUnitStatus*) malloc(linkInfoMgrSize); + if (g_linkInfoMgr.linkMgrUnitStatus == NULL) { + LinkInfoFini(); + return; + } + + memset(g_linkInfoMgr.allLinkInfo, 0, linkInfoMgrSize); + memset(g_linkInfoMgr.linkMgrUnitStatus, 0, linkInfoMgrSize); +} + +void UBRingManager::LinkInfoFini(void) { + if (g_linkInfoMgr.linkMgrUnitStatus == NULL || g_linkInfoMgr.allLinkInfo == NULL) { + LOG(ERROR) << "LinkInfo is NULL"; + return; + } + { + LOCK_GUARD(g_linkInfoMgrMtx); + FREE_PTR(g_linkInfoMgr.allLinkInfo); + FREE_PTR(g_linkInfoMgr.linkMgrUnitStatus); + } + + g_linkInfoMgr.linkNum = 0; +} + +void UBRingManager::AcquireLinkInfoToMgr(const char *listenerName, UbrTrx *trx) { + if (listenerName == NULL || trx == NULL) { + LOG(ERROR) << "LinkInfo acquire fail."; + return; + } + + if (g_linkInfoMgr.linkMgrUnitStatus == NULL || g_linkInfoMgr.allLinkInfo == NULL) { + LOG(ERROR) << "LinkInfo is NULL."; + return; + } + uint32_t ubrIndex = trx->trxMgrIndex; + char* connectName = trx->localShm.name; + if (g_linkInfoMgr.linkMgrUnitStatus[ubrIndex] == UBR_MGR_UNIT_FREE) { + strncpy(g_linkInfoMgr.allLinkInfo[ubrIndex].connectName, + connectName, SHM_MAX_NAME_BUFF_LEN); + strncpy(g_linkInfoMgr.allLinkInfo[ubrIndex].listenerName, + listenerName, SHM_MAX_NAME_BUFF_LEN); + g_linkInfoMgr.linkMgrUnitStatus[ubrIndex] = UBR_MGR_UNIT_USED; + g_linkInfoMgr.linkNum++; + } +} + +void UBRingManager::ReleaseLinkInfoFromMgr(UbrTrx *trx) { + if (trx == NULL || g_linkInfoMgr.linkMgrUnitStatus == NULL) { + LOG(ERROR) << "LinkInfo release fail."; + return; + } + + if (g_linkInfoMgr.linkMgrUnitStatus[trx->trxMgrIndex] == UBR_MGR_UNIT_FREE) { + LOG(ERROR) << "Release linkInfo failed, trx is not in manager."; + return; + } + g_linkInfoMgr.linkMgrUnitStatus[trx->trxMgrIndex] = UBR_MGR_UNIT_FREE; + g_linkInfoMgr.linkNum--; +} + +int32_t UBRingManager::UbEventCallback(const char *shmName) +{ + if (UNLIKELY(shmName == NULL)) { + LOG(ERROR) << "Ub event callback failed, shm name is null."; + return UBRING_ERR; + } + if (UNLIKELY(g_ubrMgr.trxMgr == NULL)) { + LOG(ERROR) << "Ub event callback failed, trx mgr is null."; + return UBRING_ERR; + } + LOG(INFO) << "Ub event callback is processing. shm_name=" << shmName; + + for (uint32_t i = 0; i < g_ubrMgr.trxCap; ++i) { + if (g_ubrMgr.trxMgrUnitStatus[i] == UBR_MGR_UNIT_FREE) { + continue; + } + + if (strcmp(g_ubrMgr.trxMgr[i].localShm.name, shmName) == 0 || // 故障链路为该trx的本端shm + strcmp(g_ubrMgr.trxMgr[i].remoteShm.name, shmName) == 0) { // 故障链路为该trx的对端shm + ++g_ubEventCnt; + int fd = (int)g_ubrMgr.trxMgr[i].localShm.fd; + LOG(WARNING) << "Ub event callback, the fd of the faulty link is " << fd; + return UBRing::UbrPassiveClearTrx(&g_ubrMgr.trxMgr[i], fd, UBR_UB_EVENT); + } + } + return UBRING_ERR; +} +} +} diff --git a/src/brpc/ubshm/ub_ring_manager.h b/src/brpc/ubshm/ub_ring_manager.h new file mode 100644 index 0000000000..c901791565 --- /dev/null +++ b/src/brpc/ubshm/ub_ring_manager.h @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UB_RING_MANAGER_H +#define BRPC_UB_RING_MANAGER_H + +#include "brpc/ubshm/ubr_trx.h" +#include "brpc/ubshm/shm/shm_def.h" +#include "brpc/ubshm/common/common.h" + +namespace brpc { +namespace ubring { +typedef enum { + UBR_MGR_UNIT_FREE = 0, + UBR_MGR_UNIT_USED = 1 +} UbrMgrUnitStatus; + +typedef struct TagUbrMgr { + uint32_t trxNum; + uint32_t trxCap; + UbrTrx *trxMgr; + UbrMgrUnitStatus *trxMgrUnitStatus; +} UbrMgr; + +typedef struct TagUbrLinkInfo { + char connectName[SHM_MAX_NAME_BUFF_LEN]; + char listenerName[SHM_MAX_NAME_BUFF_LEN]; +} UbrLinkInfo; + +typedef struct TagUbrLinkInfoMgr { + uint32_t linkNum; + UbrLinkInfo* allLinkInfo; + UbrMgrUnitStatus *linkMgrUnitStatus; +} UbrLinkInfoMgr; + +class UBRingManager { +public: + ~UBRingManager(){ + UbrMgrFini(); + } + + static RETURN_CODE GetUbrDealMsgMaxCnt(const uint32_t capacity, uint32_t *dealMsgMaxCnt); + + static RETURN_CODE UbrMgrDefault(); + + static RETURN_CODE UbrMgrInit(); + + static void UbrMgrFini(); + + static RETURN_CODE AcquireUbrTrxFromMgr(UbrTrx **trx); + + static RETURN_CODE ReleaseUbrTrxFromMgr(UbrTrx *trx); + + static void LinkInfoInit(void); + static void LinkInfoFini(void); + static void AcquireLinkInfoToMgr(const char* listenerName, UbrTrx *trx); + static void ReleaseLinkInfoFromMgr(UbrTrx* trx); + static int32_t UbEventCallback(const char *shmName); + +private: + UBRingManager() { + } + + static UbrMgr g_ubrMgr; + static UbrLinkInfoMgr g_linkInfoMgr; + static pthread_mutex_t g_ubrTrxMgrMtx; + static pthread_mutex_t g_ubrListenerMgrMtx; + static pthread_mutex_t g_linkInfoMgrMtx; +}; +} +} + +#endif //BRPC_UB_RING_MANAGER_H \ No newline at end of file diff --git a/src/brpc/ubshm/ubr_msg.h b/src/brpc/ubshm/ubr_msg.h new file mode 100644 index 0000000000..8a19b6f6bc --- /dev/null +++ b/src/brpc/ubshm/ubr_msg.h @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UBR_MSG_H +#define BRPC_UBR_MSG_H +#define UBR_MSG_HEADER_LEN 4 +#define UBR_MSG_PAYLOAD_LEN 60 +#define UBR_MSG_LEN (UBR_MSG_HEADER_LEN + UBR_MSG_PAYLOAD_LEN) + +#define UBR_MSG_FLAG_INDEX 0 +#define UBR_MSG_LEN_INDEX 1 +#define UBR_MSG_CUR_INDEX 2 + +namespace brpc { +namespace ubring { +typedef enum { + UBR_MSG_CHUNK_NONE = 0, + UBR_MSG_CHUNK_EXIST = 1, + UBR_MSG_CHUNK_EOF = 2 +} UbrMsgHdrFlag; + +typedef struct TagUbrMsgPayload { + uint8_t inner[UBR_MSG_PAYLOAD_LEN]; +} UbrMsgPayload; + +typedef struct __attribute__((aligned(64))) TagUbrMsgFormat { + UbrMsgPayload payload; + + uint8_t header[UBR_MSG_HEADER_LEN]; +} UbrMsgFormat; + +static inline uint32_t CalcUbrMsgChunkCnt(uint32_t bufLen) +{ + uint32_t msgChunkNum = (bufLen + UBR_MSG_PAYLOAD_LEN - 1) / UBR_MSG_PAYLOAD_LEN; + return msgChunkNum; +} +} +} +#endif //BRPC_UBR_MSG_H \ No newline at end of file diff --git a/src/brpc/ubshm/ubr_trx.h b/src/brpc/ubshm/ubr_trx.h new file mode 100644 index 0000000000..af9c52ade7 --- /dev/null +++ b/src/brpc/ubshm/ubr_trx.h @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UBR_TRX_H +#define BRPC_UBR_TRX_H +#include +#include +#include +#include "brpc/ubshm/shm/shm_def.h" +#include "brpc/ubshm/common/common.h" +#include "brpc/ubshm/common/thread_lock.h" +#include "brpc/ubshm/ubr_msg.h" + +/* +----------------------------------------------------------------------------+ + │ UbrTrx shm │ + +-------------+-------------+-------------+---------------+------------------+ + │ TxEventQ │ RxEventQ │ DataStatusQ │ zero(44Bytes) | DataQ │ + +-------------+-------------+-------------+---------------+------------------+ */ + +#define UBR_EVENTQ_LEN sizeof(UbrEventQMsg) +#define UBR_DATASTATUSQ_LEN sizeof(UbrDataStatusQMsg) + +#define TX_EVENTQ_ADDR_OFFSET 0 +#define RX_EVENTQ_ADDR_OFFSET UBR_EVENTQ_LEN +#define DATASTATUSQ_ADDR_OFFSET ((UBR_EVENTQ_LEN) << 1) +#define DATAQ_ADDR_OFFSET (DATASTATUSQ_ADDR_OFFSET + UBR_DATASTATUSQ_LEN) +#define MB_TO_BYTE (1024 * 1024) +#define MAX_CLOSE_COUNT 2 + +#define SHM_NAME_PREFIX "UBRING" +#define SERVER_SHM_NAME_SUFFIX "S" +#define CLIENT_SHM_NAME_SUFFIX "C" + +namespace brpc { +namespace ubring { +extern RETURN_CODE(*g_BeforeTcpClose)(int); +extern RETURN_CODE(*g_AfterTcpClose)(int); + +typedef enum { + UBR_STATE_NONE, + UBR_STATE_CONNECTED, + UBR_STATE_CLOSING, + UBR_STATE_CLOSED +} EventQState; + +typedef enum { + UBR_SEND_CLOSE, + UBR_CALL_BACK_CLOSE +} UbrCloseType; + +typedef enum { + UBR_CLOSE_FIRST, + UBR_CLOSE_SECOND, + UBR_CLOSE_END +} UbrCloseCount; + +typedef enum { + UDP_TRX, + TCP_TRX, + UBR_TRX +} UbrTrxType; + +typedef enum { + UBR_TASK_CONNECT_MAP_FRONT, + UBR_TASK_CONNECT_MAP_AFTER, + UBR_TASK_ACCEPT_MAP_FRONT, + UBR_TASK_ACCEPT_MAP_AFTER, + UBR_TASK_CLOSE, + UBR_TASK_STEP_NUM +} UbrTaskStep; + +typedef struct TagUbrDataStatusQMsg { + uint32_t tail; + uint32_t timeout; + uint8_t heartBeat; +} UbrDataStatusQMsg; + +typedef struct TagUbrEventQMsg { + uint64_t ioId; + EventQState flag; +} UbrEventQMsg; + +typedef struct TagUbrAddrInfo { + uint8_t *addr; + size_t len; +} UbrAddrInfo; + +typedef struct TagUbrTx { + UbrAddrInfo remoteDataQ; + UbrAddrInfo remoteRxEventQ; + UbrAddrInfo localDataStatusQ; + UbrAddrInfo localTxEventQ; + uint64_t outIoId; + uint32_t writePos; + uint32_t capacity; + UbrMsgFormat localMsgSpace; + uint32_t hbRetryCnt; + uint32_t epLastCap; + volatile EventQState trxState; +} UbrTx; + +typedef struct TagUbrRx { + UbrAddrInfo localDataQ; + UbrAddrInfo localRxEventQ; + UbrAddrInfo remoteDataStatusQ; + UbrAddrInfo remoteTxEventQ; + uint64_t inIoId; + uint32_t readPos; + uint32_t capacity; + uint32_t dealMsgNum; + uint32_t dealMsgMaxCnt; + uint32_t epEofPos; + volatile EventQState trxState; +} UbrRx; + +typedef struct TagUbrTrx { + UbrTx ubrTx; + UbrRx ubrRx; + uint64_t ubrId; + uint32_t trxMgrIndex; + UbrTrxType type; + SHM localShm; + SHM remoteShm; + int timerFd; + int hbTimerFd; + int clearTimerFd; + AtomicInt closeCnt; + AtomicInt closeState; +} UbrTrx; + +typedef struct TagFileLock { + int lockFd; + char* lockPath; +} FileLock; + +typedef struct TagUbrLinkLock { + int fileLockNum; + FileLock* fileLock; +} UbrLinkLock; + +typedef enum { + UBR_UB_EVENT, + UBR_HEARTBEAT, +}PASSIVE_DISC_TYPE; + +} +} +#endif //BRPC_UBR_TRX_H \ No newline at end of file diff --git a/src/brpc/ubshm/ubs_mem/declare_shm_ubs.h b/src/brpc/ubshm/ubs_mem/declare_shm_ubs.h new file mode 100644 index 0000000000..b09b2bf943 --- /dev/null +++ b/src/brpc/ubshm/ubs_mem/declare_shm_ubs.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef UBRING_MK_UBSM +#error Do not include this file unless you know what you are doing. +#endif + +#ifndef UBRING_MK_UBSM_OPTIONAL +#define UBRING_MK_UBSM_OPTIONAL UBRING_MK_UBSM +#endif + +UBRING_MK_UBSM(int, ubsmem_init_attributes, (ubsmem_options_t *ubsm_shmem_opts)); + +UBRING_MK_UBSM(int, ubsmem_initialize, (const ubsmem_options_t *ubsm_shmem_opts)); + +UBRING_MK_UBSM(int, ubsmem_finalize, (void)); + +UBRING_MK_UBSM(int, ubsmem_set_logger_level, (int level)); + +UBRING_MK_UBSM(int, ubsmem_set_extern_logger, (void (*func)(int level, const char *msg))); + +UBRING_MK_UBSM(int, ubsmem_lookup_regions, (ubsmem_regions_t* regions)); + +UBRING_MK_UBSM(int, ubsmem_create_region, (const char *region_name, size_t size, const ubsmem_region_attributes_t *reg_attr)); + +UBRING_MK_UBSM(int, ubsmem_destroy_region, (const char *region_name)); + +UBRING_MK_UBSM(int, ubsmem_shmem_allocate,(const char *region_name, const char *name, size_t size, mode_t mode, + uint64_t flags)); + +UBRING_MK_UBSM(int, ubsmem_shmem_deallocate, (const char *name)); + +UBRING_MK_UBSM(int, ubsmem_shmem_map, (void *addr, size_t length, int prot, int flags, const char *name, off_t offset, + void **local_ptr)); + +UBRING_MK_UBSM(int, ubsmem_shmem_unmap, (void *local_ptr, size_t length)); + +UBRING_MK_UBSM(int, ubsmem_shmem_faults_register, (shmem_faults_func registerFunc)); + +UBRING_MK_UBSM(int, ubsmem_local_nid_query, (uint32_t *nid)); + +#undef UBRING_MK_UBSM_OPTIONAL +#undef UBRING_MK_UBSM \ No newline at end of file diff --git a/src/brpc/ubshm/ubs_mem/ubs_mem.h b/src/brpc/ubshm/ubs_mem/ubs_mem.h new file mode 100644 index 0000000000..66069c6e9c --- /dev/null +++ b/src/brpc/ubshm/ubs_mem/ubs_mem.h @@ -0,0 +1,210 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UBS_MEM_H +#define BRPC_UBS_MEM_H +#include "ubs_mem_def.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Initialize the UBSMSHMEM attributes + * + * @param ubsm_shmem_opts - [out] shmem attributes + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_init_attributes(ubsmem_options_t *ubsm_shmem_opts); + +/** + * Initialize the UBSMSHMEM library. + * Required to be the first called when a process uses the UBSMSHMEM library. + * @param ubsm_shmem_opts - options structure containing initialization choices + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_initialize(const ubsmem_options_t *ubsm_shmem_opts); + +/** + * Finalize the UBSMSHMEM library. + * Once finalized, the process can continue work,but it is disconnected from the UBSMSHMEM library functions. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_finalize(void); + +/** + * @brief Set log level + * @return - 0 on success and other on failure + * @param level - level to be set, debug(0), info(1), warning(2), error(3), closed(4) + */ +SHMEM_API int ubsmem_set_logger_level(int level); + +/** + * @brief Set external log function, user can set customized logger function, + * in the customized logger function, user can use unified logger utility, + * then the log message can be written into the same log file as caller's, + * if it is not set, log message will be printed to stdout. + * @param func - [in] external logger function + * @return 0 on success and other on failure + */ +SHMEM_API int ubsmem_set_extern_logger(void (*func)(int level, const char *msg)); + +/** + * Look up regions in UBSMSHMEM associated with the local node. + * @param regions - [out] The descriptor to the regions. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_lookup_regions(ubsmem_regions_t* regions); + +/** + * Create a large region of UBSMSHMEM. + * Regions are primarily used as large containers within which additional memory may be allocated and managed by + * the program. + * @param region_name - name of the region + * @param size - size (in bytes) requested for the region, 930 no use, default 0. + * Note that implementations may round up the size to implementation-dependent sizes, + * and may impose system-wide (or user-dependent) limits on individual and total size allocated to a given user. + * @param reg_attr - details of UBSMSHMEM region attributes + * @param region_desc - [out] Region_Descriptor for the created region + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_create_region(const char *region_name, size_t size, const ubsmem_region_attributes_t *reg_attr); + +/** + * Look up a region in UBSMSHMEM by name in the name service. + * @param region_name - name of the region. + * @param region_desc - [out] The descriptor to the region. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_lookup_region(const char *region_name, ubsmem_region_desc_t *region_desc); + +/** + * Destroy a region, and all contents within the region. Note that this + * method call will trigger a delayed free operation to permit other + * instances currently using the region to finish. + * @param region_name - name of the region. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_destroy_region(const char *region_name); + +/** + * Allocate some named space within a region. Allocates an area of UBSMSHMEM within a region + * @param region_name - name of the region. + * @param name - name of the share memory object + * @param size - size of the space to allocate in bytes. + * @param mode - mode associated with this space. + * @param flags - Special marking for this object, MXMEM_FLAG_WITH_LOCK etc. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_shmem_allocate(const char *region_name, const char *name, size_t size, mode_t mode, + uint64_t flags); + +/** + * Deallocate allocated space in memory + * @param name - name of the share memory object + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_shmem_deallocate(const char *name); + +/** + * Map item in UBSMSHMEM to the local virtual address space, and return its pointer. + * @param addr - The starting address for the new mapping is specified in addr, If addr is NULL, then + * the kernel chooses the (page-aligned) address at which to create the mapping + * @param length - The length argument specifies the length of the mapping (which must be greater than 0) + * @param prot - same as mmap, describes the desired memory protection of the mapping (and must not conflict with + * the open mode of the file). + * @param flags - same as mmap + * @param name - name of the share memory object which to be mapped, same as mmap's fd + * @param offset - same as mmap, offset must be a multiple of the page size + * @param local_ptr - [out] within the process virtual address space that can be used to directly access the + * data item in UBSMSHMEM + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_shmem_map(void *addr, size_t length, int prot, int flags, const char *name, off_t offset, + void **local_ptr); + +/** + * Unmap a data item in UBSMSHMEM from the local virtual address space. + * @param local_ptr - pointer within the process virtual address space to be unmapped + * @param length - the size to be unmapped + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_shmem_unmap(void *local_ptr, size_t length); + +/** + * Change permissions associated with a data item descriptor. + * @param name - descriptor associated with some data item + * @param perm - new permissions for the data item + * @return - 0 on success and other on failure,other return described in UBSM_SHMEM_RETURN. + */ +SHMEM_API int ubsmem_shmem_set_ownership(const char *name, void *start, size_t length, int prot); + +/** + * shmem lock - Set the lock, status, and data consistency of the shmem item + * @param name - descriptor associated with share memory object + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_shmem_write_lock(const char *name); +SHMEM_API int ubsmem_shmem_read_lock(const char *name); +SHMEM_API int ubsmem_shmem_unlock(const char *name); + +SHMEM_API int ubsmem_shmem_list_lookup(const char *prefix, ubsmem_shmem_desc_t *shm_list, uint32_t *shm_cnt); +SHMEM_API int ubsmem_shmem_lookup(const char *name, ubsmem_shmem_info_t *shm_info); +SHMEM_API int ubsmem_shmem_attach(const char *name); +SHMEM_API int ubsmem_shmem_detach(const char *name); + +/** + * Alloc an area from the resource pool and use it only within the scope of the current process. + * @param region_name - name of the region. + * @param size - size of the space to allocate in bytes. + * Note that implementations may round up the size to implementation-dependent sizes. + * @param mem_distance - Describe the performance distance between memory resources and local nodes. + * Note that described in perf_desc_distance + * @param is_numa - is numa or fd malloc, true: numa, false: fd + * @param local_ptr - [out] pointer within the process virtual address space that can be used to directly access. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_lease_malloc(const char *region_name, size_t size, ubsmem_distance_t mem_distance, bool is_numa, + void **local_ptr); + +/** + * Release the pointer. + * @param local_ptr - The pointer returned by the malloc function. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_lease_free(void *local_ptr); + +SHMEM_API int ubsmem_lookup_cluster_statistic(ubsmem_cluster_info_t *info); + +/** + * Subscribes to shared memory UB Event. + * @param registerFunc - Shared Memory UB Event Response Handling Function. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_shmem_faults_register(shmem_faults_func registerFunc); + +/** + * Query the supernode ID of this node within the supernode domain. + * @param nid - The supernode ID of this node within the supernode domain. + * @return - 0 on success and other on failure + */ +SHMEM_API int ubsmem_local_nid_query(uint32_t *nid); + +#ifdef __cplusplus +} // end of extern "C" +#endif +#endif //BRPC_UBS_MEM_H \ No newline at end of file diff --git a/src/brpc/ubshm/ubs_mem/ubs_mem_def.h b/src/brpc/ubshm/ubs_mem/ubs_mem_def.h new file mode 100644 index 0000000000..29646611f3 --- /dev/null +++ b/src/brpc/ubshm/ubs_mem/ubs_mem_def.h @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UBS_MEM_DEF_H +#define BRPC_UBS_MEM_DEF_H +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef SHMEM_API +#define SHMEM_API __attribute__((visibility("default"))) +#endif + +// 先修改为48,与旧版本对齐 +#define MAX_HOST_NUM 16 +#define MAX_NUMA_NUM 32 +#define MAX_NUMA_RESV_LEN 16 + +#define MAX_HOST_NAME_DESC_LENGTH 64 +#define MAX_SHM_NAME_LENGTH 48 +#define MAX_REGION_NAME_DESC_LENGTH 48 +#define MAX_REGION_NODE_NUM 16 +#define MAX_REGIONS_NUM 6 +#define MAX_OBMM_SHMDEV_PATH_LEN 64 + +#define MAX_MEMID_NUM 2048 +#define MAX_SHM_CNT 300 + +#define UBSM_FLAG_CACHE 0x0UL +#define UBSM_FLAG_WITH_LOCK 0x1UL +#define UBSM_FLAG_NONCACHE 0x2UL // open O_SYNC +#define UBSM_FLAG_WR_DELAY_COMP 0x4UL // obmm import with wr_delay_comp +#define UBSM_FLAG_ONLY_IMPORT_NONCACHE 0x8UL // only import open O_SYNC +#define UBSM_FLAG_MEM_ANONYMOUS 0x10UL // auto cleanup when all references in domain drop to zero + +typedef enum { + UBSM_OK = 0, + // common error + UBSM_ERR_PARAM_INVALID = 6010, + UBSM_ERR_NOPERM = 6011, // no permision + UBSM_ERR_MEMORY = 6012, // memcpy or other mem func failed + UBSM_ERR_UNIMPL = 6013, // not implement + UBSM_CHECK_RESOURCE_ERROR = 6014, // resource check failed. + UBSM_ERR_MEMLIB = 6015, // mem lib failed + UBSM_ERR_NO_NEEDED = 6016, // default region no need to create + + // resource error + UBSM_ERR_NOT_FOUND = 6020, + UBSM_ERR_ALREADY_EXIST = 6021, + UBSM_ERR_MALLOC_FAIL = 6022, + UBSM_ERR_RECORD = 6023, + UBSM_ERR_IN_USING = 6024, // shm is in use (usrNum > 0) + + // net error + UBSM_ERR_NET = 6040, + + // under api + UBSM_ERR_UBSE = 6050, + UBSM_ERR_OBMM = 6051, + + // cc lock error + UBSM_ERR_LOCK_NOT_SUPPORTED = 6060, + UBSM_ERR_LOCK_ALREADY_LOCKED = 6061, + UBSM_ERR_DLOCK = 6062, + + UBSM_ERR_BUFF = 6099, +} ubsmshmem_ret_t; +/** + * Memory distance, describes the physical memory resource distance relative to the current PE. + */ +typedef enum { + /** direct connect node is provided, same as PerfLevel::L0 */ + DISTANCE_DIRECT_NODE = 0, + /** one hop connect node is provided, same as PerfLevel::L1, not support 930 */ + DISTANCE_HOP_NODE = 1, +} ubsmem_distance_t; + +typedef struct { + // todo +} ubsmem_options_t; + +typedef struct { + char host_name[MAX_HOST_NAME_DESC_LENGTH]; // include '\0' + bool affinity; +} ubsmem_region_node_desc_t; + +typedef struct { + int host_num; + ubsmem_region_node_desc_t hosts[MAX_REGION_NODE_NUM]; +} ubsmem_region_attributes_t; + +typedef struct { + int num; + ubsmem_region_attributes_t region[MAX_REGIONS_NUM]; +} ubsmem_regions_t; + +typedef struct { + char region_name[MAX_REGION_NAME_DESC_LENGTH]; + size_t size; + ubsmem_region_attributes_t region_attr; +} ubsmem_region_desc_t; + +typedef struct { + uint32_t slot_id; // 节点唯一标识, 采用slotid, 与lcne保持一致 + uint32_t socket_id; // socket id + uint32_t numa_id; // 节点中的numa id + uint32_t mem_lend_ratio; // 池化内存借出比例上限 + uint64_t mem_total; // 内存总量, 单位字节 + uint64_t mem_free; // 内存空闲量, 单位字节 + uint64_t mem_borrow; // 借用的内存,单位字节 + uint64_t mem_lend; // 借出的内存,单位字节 + uint8_t resv[MAX_NUMA_RESV_LEN]; +} ubsmem_numa_mem_t; + +typedef struct { + char host_name[MAX_HOST_NAME_DESC_LENGTH]; + int numa_num; + ubsmem_numa_mem_t numa[MAX_NUMA_NUM]; +} ubsmem_host_info_t; + +typedef struct { + int host_num; // 集群可用节点数量 + ubsmem_host_info_t host[MAX_HOST_NUM]; +} ubsmem_cluster_info_t; + +typedef struct { + char name[MAX_SHM_NAME_LENGTH + 1]; + size_t size; +} ubsmem_shmem_desc_t; + +typedef struct { + char name[MAX_SHM_NAME_LENGTH + 1]; + size_t size; + uint32_t mem_num; + uint64_t mem_unit_size; + uint64_t mem_id_list[MAX_MEMID_NUM]; +} ubsmem_shmem_info_t; + +typedef int32_t (*shmem_faults_func)(const char *shm_name); + +#ifdef __cplusplus +} +#endif +#endif //BRPC_UBS_MEM_DEF_H \ No newline at end of file diff --git a/src/brpc/ubshm/ubs_mem/ubshmem_stub.cpp b/src/brpc/ubshm/ubs_mem/ubshmem_stub.cpp new file mode 100644 index 0000000000..f0eaf29f8e --- /dev/null +++ b/src/brpc/ubshm/ubs_mem/ubshmem_stub.cpp @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ubs_mem.h" + +int ubsmem_init_attributes(ubsmem_options_t *ubsm_shmem_opts) +{ + return UBSM_OK; +} + +int ubsmem_initialize(const ubsmem_options_t *ubsm_shmem_opts) +{ + return UBSM_OK; +} + +int ubsmem_finalize(void) +{ + return UBSM_OK; +} + +int ubsmem_set_logger_level(int level) +{ + return UBSM_OK; +} + +int ubsmem_set_extern_logger(void (*func)(int level, const char *msg)) +{ + return UBSM_OK; +} + +int ubsmem_lookup_regions(ubsmem_regions_t* regions) +{ + regions->num = 1; + regions->region[0].host_num = 1; + regions->region[0].hosts[0].affinity = true; + regions->region[0].hosts[0].host_name[0] = 'h'; + regions->region[0].hosts[0].host_name[1] = '1'; + regions->region[0].hosts[0].host_name[2] = '\0'; // 2号位置使用\0 + return UBSM_OK; +} + +int ubsmem_create_region(const char *region_name, size_t size, const ubsmem_region_attributes_t *reg_attr) +{ + return UBSM_OK; +} + + +int ubsmem_destroy_region(const char *region_name) +{ + return UBSM_OK; +} + +int ubsmem_shmem_allocate(const char *region_name, const char *name, size_t size, mode_t mode, uint64_t flags) +{ + return UBSM_OK; +} + +int ubsmem_shmem_deallocate(const char *name) +{ + return UBSM_OK; +} + +int ubsmem_shmem_map(void *addr, size_t length, int prot, int flags, const char *name, off_t offset, + void **local_ptr) +{ + return UBSM_OK; +} + +int ubsmem_shmem_unmap(void *local_ptr, size_t length) +{ + return UBSM_OK; +} + +int ubsmem_shmem_faults_register(shmem_faults_func registerFunc) +{ + return UBSM_OK; +} + +int ubsmem_local_nid_query(uint32_t *nid) +{ + *nid = 1; // stub + return UBSM_OK; +} \ No newline at end of file diff --git a/src/brpc/ubshm_transport.cpp b/src/brpc/ubshm_transport.cpp new file mode 100644 index 0000000000..fec1a4b646 --- /dev/null +++ b/src/brpc/ubshm_transport.cpp @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#if BRPC_WITH_UBRING + +#include "brpc/ubshm_transport.h" +#include "brpc/tcp_transport.h" +#include "brpc/ubshm/ub_endpoint.h" +#include "brpc/ubshm/ub_helper.h" + +namespace brpc { +DECLARE_bool(usercode_in_coroutine); +DECLARE_bool(usercode_in_pthread); + +extern SocketVarsCollector *g_vars; + +void UBShmTransport::Init(Socket *socket, const SocketOptions &options) { + CHECK(_ub_ep == NULL); + if (options.socket_mode == SOCKET_MODE_UBRING) { + _ub_ep = new(std::nothrow)ubring::UBShmEndpoint(socket); + if (!_ub_ep) { + const int saved_errno = errno; + PLOG(ERROR) << "Fail to create UBShmEndpoint"; + socket->SetFailed( + saved_errno, "Fail to create UBShmEndpoint: %s", berror(saved_errno)); + } + _ub_state = UB_UNKNOWN; + } else { + _ub_state = UB_OFF; + socket->_socket_mode = SOCKET_MODE_TCP; + } + _socket = socket; + _default_connect = options.app_connect; + _on_edge_trigger = options.on_edge_triggered_events; + if (options.need_on_edge_trigger && _on_edge_trigger == NULL) { + _on_edge_trigger = ubring::UBShmEndpoint::OnNewDataFromTcp; + } + _tcp_transport = std::make_shared(); + _tcp_transport->Init(socket, options); +} + +void UBShmTransport::Release() { + if (_ub_ep) { + delete _ub_ep; + _ub_ep = NULL; + _ub_state = UB_UNKNOWN; + } +} + +int UBShmTransport::Reset(int32_t expected_nref) { + if (_ub_ep) { + _ub_ep->Reset(); + _ub_state = UB_UNKNOWN; + } + return 0; +} + +std::shared_ptr UBShmTransport::Connect() { + if (_default_connect == nullptr) { + return std::make_shared(); + } + return _default_connect; +} + +int UBShmTransport::CutFromIOBuf(butil::IOBuf *buf) { + if (_ub_ep && _ub_state != UB_OFF) { + butil::IOBuf *data_arr[1] = {buf}; + return _ub_ep->CutFromIOBufList(data_arr, 1); + } else { + return _tcp_transport->CutFromIOBuf(buf); + } +} + +ssize_t UBShmTransport::CutFromIOBufList(butil::IOBuf **buf, size_t ndata) { + if (_ub_ep && _ub_state != UB_OFF) { + return _ub_ep->CutFromIOBufList(buf, ndata); + } + return _tcp_transport->CutFromIOBufList(buf, ndata); +} + +int UBShmTransport::WaitEpollOut(butil::atomic *_epollout_butex, + bool pollin, const timespec duetime) { + // LOG(INFO) << "mwj pollin4=" << pollin << " duetime=" << butil::timespec_to_microseconds(duetime); + if (_ub_state == UB_ON) { + // LOG(INFO) << "mwj pollin1=" << pollin; + const int expected_val = _epollout_butex->load(butil::memory_order_acquire); + CHECK(_ub_ep != NULL); + if (!_ub_ep->IsWritable()) { + g_vars->nwaitepollout << 1; + _ub_ep->PollerRegisterEpollOut(pollin); + auto mwj_ret = bthread::butex_wait(_epollout_butex, expected_val, &duetime); + // LOG(INFO) << "mwj pollin2=" << pollin << " mwj_ret=" << mwj_ret; + if (mwj_ret < 0) { + if (errno != EAGAIN && errno != ETIMEDOUT) { + const int saved_errno = errno; + PLOG(WARNING) << "Fail to wait ub window of " << _socket; + _socket->SetFailed(saved_errno, + "Fail to wait ub window of %s: %s", + _socket->description().c_str(), + berror(saved_errno)); + } + if (_socket->Failed()) { + // NOTE: + // Different from TCP, we cannot find the UB channel + // failed by writing to it. Thus we must check if it + // is already failed here. + return 1; + } + } + _ub_ep->PollerUnRegisterEpollOut(pollin); + } + } else { + return _tcp_transport->WaitEpollOut(_epollout_butex, pollin, duetime); + } + // LOG(INFO) << "mwj return 0"; + return 0; +} + +void UBShmTransport::ProcessEvent(bthread_attr_t attr) { + bthread_t tid; + if (FLAGS_usercode_in_coroutine) { + OnEdge(_socket); + } else if (ubring::FLAGS_ub_edisp_unsched == false) { + auto rc = bthread_start_background(&tid, &attr, OnEdge, _socket); + if (rc != 0) { + LOG(FATAL) << "Fail to start ProcessEvent"; + OnEdge(_socket); + } + } else if (bthread_start_urgent(&tid, &attr, OnEdge, _socket) != 0) { + LOG(FATAL) << "Fail to start ProcessEvent"; + OnEdge(_socket); + } +} + +void UBShmTransport::QueueMessage(InputMessageClosure& input_msg, + int* num_bthread_created, bool last_msg) { + if (last_msg) { + return; + } + InputMessageBase* to_run_msg = input_msg.release(); + if (!to_run_msg) { + return; + } + + if (ubring::FLAGS_ub_disable_bthread) { + ProcessInputMessage(to_run_msg); + return; + } + // Create bthread for last_msg. The bthread is not scheduled + // until bthread_flush() is called (in the worse case). + + // TODO(gejun): Join threads. + bthread_t th; + bthread_attr_t tmp = (FLAGS_usercode_in_pthread ? + BTHREAD_ATTR_PTHREAD : + BTHREAD_ATTR_NORMAL) | BTHREAD_NOSIGNAL; + tmp.keytable_pool = _socket->keytable_pool(); + tmp.tag = bthread_self_tag(); + bthread_attr_set_name(&tmp, "ProcessInputMessage"); + + if (!FLAGS_usercode_in_coroutine && bthread_start_background( + &th, &tmp, ProcessInputMessage, to_run_msg) == 0) { + ++*num_bthread_created; + } else { + ProcessInputMessage(to_run_msg); + } +} + +void UBShmTransport::Debug(std::ostream &os) {} + +int UBShmTransport::ContextInitOrDie(bool serverOrNot, const void* _options) { + if (serverOrNot) { + if (!OptionsAvailableOverUB(static_cast(_options))) { + return -1; + } + ubring::GlobalUBInitializeOrDie(); + if (!ubring::InitPollingModeWithTag(static_cast(_options)->bthread_tag)) { + return -1; + } + } else { + if (!OptionsAvailableForUB(static_cast(_options))) { + return -1; + } + ubring::GlobalUBInitializeOrDie(); + if (!ubring::InitPollingModeWithTag(bthread_self_tag())) { + return -1; + } + return 0; + } + + return 0; +} + +bool UBShmTransport::OptionsAvailableForUB(const ChannelOptions* opt) { + if (opt->has_ssl_options()) { + LOG(WARNING) << "Cannot use SSL and UB at the same time"; + return false; + } + if (!ubring::SupportedByUB(opt->protocol.name())) { + LOG(WARNING) << "Cannot use " << opt->protocol.name() + << " over UB"; + return false; + } + return true; +} + +bool UBShmTransport::OptionsAvailableOverUB(const ServerOptions* opt) { + if (opt->rtmp_service) { + LOG(WARNING) << "RTMP is not supported by UB"; + return false; + } + if (opt->has_ssl_options()) { + LOG(WARNING) << "SSL is not supported by UB"; + return false; + } + if (opt->nshead_service) { + LOG(WARNING) << "NSHEAD is not supported by UB"; + return false; + } + if (opt->mongo_service_adaptor) { + LOG(WARNING) << "MONGO is not supported by UB"; + return false; + } + return true; +} +} // namespace brpc +#endif \ No newline at end of file diff --git a/src/brpc/ubshm_transport.h b/src/brpc/ubshm_transport.h new file mode 100644 index 0000000000..13943d763e --- /dev/null +++ b/src/brpc/ubshm_transport.h @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_UB_TRANSPORT_H +#define BRPC_UB_TRANSPORT_H +#if BRPC_WITH_UBRING +#include "brpc/socket.h" +#include "brpc/channel.h" +#include "brpc/transport.h" + +namespace brpc { +class UBShmTransport : public Transport { + friend class TransportFactory; + friend class ubring::UBShmEndpoint; +friend class ubring::UBConnect; +public: + void Init(Socket* socket, const SocketOptions& options) override; + void Release() override; + int Reset(int32_t expected_nref) override; + std::shared_ptr Connect() override; + int CutFromIOBuf(butil::IOBuf* buf) override; + ssize_t CutFromIOBufList(butil::IOBuf** buf, size_t ndata) override; + int WaitEpollOut(butil::atomic* _epollout_butex, bool pollin, const timespec duetime) override; + void ProcessEvent(bthread_attr_t attr) override; + void QueueMessage(InputMessageClosure& inputMsg, int* num_bthread_created, bool last_msg) override; + void Debug(std::ostream &os) override; + ubring::UBShmEndpoint* GetUBShmEp() { + CHECK(_ub_ep != NULL); + return _ub_ep; + } + static int ContextInitOrDie(bool serverOrNot, const void* _options); +private: + static bool OptionsAvailableForUB(const ChannelOptions* opt); + static bool OptionsAvailableOverUB(const ServerOptions* opt); +private: + // The on/off state of UB + enum UBState { + UB_ON, + UB_OFF, + UB_UNKNOWN + }; + // The UBShmEndpoint + ubring::UBShmEndpoint* _ub_ep = NULL; + // Should use UB or not + UBState _ub_state; + std::shared_ptr _tcp_transport; +}; +} // namespace brpc +#endif // BRPC_WITH_UBRING +#endif //BRPC_UB_TRANSPORT_H \ No newline at end of file diff --git a/src/bthread/bthread.cpp b/src/bthread/bthread.cpp index 27ded27acd..9b0f45991d 100644 --- a/src/bthread/bthread.cpp +++ b/src/bthread/bthread.cpp @@ -396,13 +396,6 @@ int bthread_equal(bthread_t t1, bthread_t t2) { return t1 == t2; } -#ifdef BUTIL_USE_ASAN -// Fixme!!! -// The noreturn `bthread_exit' may cause a warning of ASan, but does not abort the program. -// -// ==94463==WARNING: ASan is ignoring requested __asan_handle_no_return: stack type: default top: 0x00016dd7f000; bottom 0x00010b1a4000; size: 0x000062bdb000 (1656598528) -// False positive error reports may follow -#endif // BUTIL_USE_ASAN void bthread_exit(void* retval) { bthread::TaskGroup* g = bthread::tls_task_group; if (g != NULL && !g->is_current_main_task()) { diff --git a/src/bthread/rwlock.cpp b/src/bthread/rwlock.cpp index e635668373..e28f5ccb8b 100644 --- a/src/bthread/rwlock.cpp +++ b/src/bthread/rwlock.cpp @@ -15,350 +15,508 @@ // specific language governing permissions and limitations // under the License. +#include #include "bvar/collector.h" +#include "butil/memory/scope_guard.h" #include "bthread/rwlock.h" +#include "bthread/mutex.h" #include "bthread/butex.h" namespace bthread { -// A `bthread_rwlock_t' is a reader/writer mutual exclusion lock, -// which is a bthread implementation of golang RWMutex. -// The lock can be held by an arbitrary number of readers or a single writer. -// For details, see https://github.com/golang/go/blob/master/src/sync/rwmutex.go - -// Define in bthread/mutex.cpp +// Defined in bthread/mutex.cpp; reused here so that bthread_rwlock_t +// participates in the global ContentionProfiler just like bthread_mutex_t +// and bthread_sem_t. class ContentionProfiler; extern ContentionProfiler* g_cp; extern bvar::CollectorSpeedLimit g_cp_sl; -extern bool is_contention_site_valid(const bthread_contention_site_t& cs); -extern void make_contention_site_invalid(bthread_contention_site_t* cs); extern void submit_contention(const bthread_contention_site_t& csite, int64_t now_ns); -// It is enough for readers. If the reader exceeds this value, -// need to use `int64_t' instead of `int'. -const int RWLockMaxReaders = 1 << 30; - -// For reading. -static int rwlock_rdlock_impl(bthread_rwlock_t* __restrict rwlock, - const struct timespec* __restrict abstime) { - int reader_count = ((butil::atomic*)&rwlock->reader_count) - ->fetch_add(1, butil::memory_order_acquire) + 1; - // Fast path. - if (reader_count >= 0) { - CHECK_LT(reader_count, RWLockMaxReaders); - return 0; - } - - // Slow path. +// Lazily arm sampling on first contention. Caller must declare +// `size_t sampling_range' and `int64_t start_ns' in scope: +// start_ns == 0 -> not yet decided +// start_ns == -1 -> decided NOT to sample (profiler off / not selected) +// start_ns > 0 -> sampling armed; value is the wall-clock start time +#define BTHREAD_RWLOCK_MAYBE_START_SAMPLING \ + do { \ + if (start_ns == 0) { \ + if (BAIDU_UNLIKELY(g_cp != NULL)) { \ + sampling_range = bvar::is_collectable(&g_cp_sl); \ + start_ns = bvar::is_sampling_range_valid(sampling_range) ? \ + butil::cpuwide_time_ns() : -1; \ + } else { \ + start_ns = -1; \ + } \ + } \ + } while (0) - // Don't sample when contention profiler is off. - if (NULL == bthread::g_cp) { - return bthread_sem_timedwait(&rwlock->reader_sema, abstime); - } - // Ask Collector if this (contended) locking should be sampled. - const size_t sampling_range = bvar::is_collectable(&bthread::g_cp_sl); - if (!bvar::is_sampling_range_valid(sampling_range)) { // Don't sample. - return bthread_sem_timedwait(&rwlock->reader_sema, abstime); +// Submit one contention sample if sampling was armed for this attempt. +// `start_ns > 0' is the convention used everywhere in this file to indicate +// that BTHREAD_RWLOCK_MAYBE_START_SAMPLING actually decided to sample. +// No-op otherwise. Force-inlined so the uncontended fast path stays cheap. +static BUTIL_FORCE_INLINE void submit_contention_if_sampled( + int64_t start_ns, size_t sampling_range) { + if (BAIDU_UNLIKELY(start_ns > 0)) { + const int64_t end_ns = butil::cpuwide_time_ns(); + const bthread_contention_site_t csite{end_ns - start_ns, sampling_range}; + submit_contention(csite, end_ns); } - - // Sample. - const int64_t start_ns = butil::cpuwide_time_ns(); - int rc = bthread_sem_timedwait(&rwlock->reader_sema, abstime); - const int64_t end_ns = butil::cpuwide_time_ns(); - const bthread_contention_site_t csite{end_ns - start_ns, sampling_range}; - // Submit `csite' for each reader immediately after - // owning rdlock to avoid the contention of `csite'. - bthread::submit_contention(csite, end_ns); - - return rc; -} - -static inline int rwlock_rdlock(bthread_rwlock_t* rwlock) { - return rwlock_rdlock_impl(rwlock, NULL); } -static inline int rwlock_timedrdlock(bthread_rwlock_t* __restrict rwlock, - const struct timespec* __restrict abstime) { - return rwlock_rdlock_impl(rwlock, abstime); -} +// bthread RWLock +// writer-priority implementation overview +// Three synchronization fields are used: +// +// * `lock_word' (32-bit butex): +// bit 31 : 1 if the write lock is held, 0 otherwise. +// bit 0~30: number of readers currently holding the read lock. +// Mutually exclusive: when bit 31 is set, the lower 31 bits are 0. +// +// * `writer_wait_count' (32-bit butex): +// Number of writers that have entered wrlock() but not yet finished +// (i.e. currently waiting for the mutex / waiting for lock_word==0 / +// holding the write lock). Each writer accounts for itself: it is +// incremented at the very beginning of wrlock() and decremented at +// the very end of unwrlock()/cleanup(). +// Readers consult this field to implement writer-priority: if any +// writer is "in flight", new readers yield by waiting on it. +// +// * `writer_queue_mutex' (bthread_mutex_t): +// Serializes writers so that at most one writer races for `lock_word' +// at any time. Other writers queue up on this mutex. +// +// Wakeup channels: +// * Readers waiting on writers -> wait on writer_wait_count, woken by unwrlock/cleanup +// * Writers waiting on readers -> wait on lock_word, woken by unrdlock +// * Writers waiting on writers -> wait on writer_queue_mutex + +static int rwlock_rdlock(bthread_rwlock_t* rwlock, bool try_lock, + const struct timespec* abstime) { + auto lock_word = (butil::atomic*)rwlock->lock_word; + auto writer_wait_count = (butil::atomic*)rwlock->writer_wait_count; + + // Sampling state for the contention profiler (lazily armed on first + // contention so that the uncontended fast path stays cheap): + // start_ns == 0 -> not yet decided + // start_ns == -1 -> decided NOT to sample + // start_ns > 0 -> sampling armed; submit on exit + // Each reader samples independently and submits once on its own way out; + // we deliberately do NOT use rwlock->writer_csite here because that field + // is exclusively owned by the writer. + size_t sampling_range = bvar::INVALID_SAMPLING_RANGE; + int64_t start_ns = 0; + int rc = 0; -// Returns 0 if the lock was acquired, otherwise errno. -static inline int rwlock_tryrdlock(bthread_rwlock_t* rwlock) { while (true) { - int reader_count = ((butil::atomic*)&rwlock->reader_count) - ->load(butil::memory_order_relaxed); - if (reader_count < 0) { - // Failed to acquire the read lock because there is a writer. - return EBUSY; - } - if (((butil::atomic*)&rwlock->reader_count) - ->compare_exchange_weak(reader_count, reader_count + 1, - butil::memory_order_acquire, - butil::memory_order_relaxed)) { - return 0; + // Writer-priority: if any writer is in flight, yield to it. + // `relaxed' is sufficient here because: + // - There is no data published via writer_wait_count; + // data visibility is established via the acquire-CAS on + // `lock_word' below paired with the release-CAS in unwrlock(). + // - butex_wait() will re-check the expected value before sleeping, + // so we cannot lose a wakeup even if `w' is slightly stale. + unsigned w = writer_wait_count->load(butil::memory_order_relaxed); + if (w > 0) { + if (try_lock) { + // Don't sample tryrdlock failures: they are by design a + // non-blocking probe, not a contention event. + return EBUSY; + } + // We are about to block on writer_wait_count; arm sampling + // before parking so the wait time is included in the report. + BTHREAD_RWLOCK_MAYBE_START_SAMPLING; + if (butex_wait(writer_wait_count, w, abstime) < 0 && + errno != EWOULDBLOCK && errno != EINTR) { + rc = errno; + break; + } + continue; } - } -} -static inline int rwlock_unrdlock(bthread_rwlock_t* rwlock) { - int reader_count = ((butil::atomic*)&rwlock->reader_count) - ->fetch_add(-1, butil::memory_order_relaxed) - 1; - // Fast path. - if (reader_count >= 0) { - return 0; + // No writer in flight: try to add ourselves to the reader count. + // 2^31 - 1 readers should be enough for any realistic workload. + unsigned l = lock_word->load(butil::memory_order_relaxed); + if ((l >> 31) == 0) { + // Refuse to increment when the reader count has saturated + // the low 31 bits. Otherwise `l + 1' would flip bit 31 and + // we would corrupt lock_word into "writer held" state. + // POSIX-style: report EAGAIN ("max read locks exceeded"). + if (BAIDU_UNLIKELY(l == 0x7FFFFFFFu)) { + LOG(ERROR) << "Too many readers on bthread_rwlock_t=" << rwlock; + rc = EAGAIN; + break; + } + // Acquire on success synchronizes-with the release-CAS in + // unwrlock(), so any data written by the previous writer is + // visible to us before we start reading. + if (lock_word->compare_exchange_weak(l, l + 1, + butil::memory_order_acquire, + butil::memory_order_relaxed)) { + rc = 0; + break; + } + // CAS failed (likely another reader bumped r): retry. + } else if (try_lock) { + // Write lock is currently held. + return EBUSY; + } else { + // Write lock currently held but not yet self-accounted as a + // pending writer (very narrow window inside wrlock). Arm + // sampling now so the spin/wait until writer_wait_count >= 1 + // is also accounted for. + BTHREAD_RWLOCK_MAYBE_START_SAMPLING; + } + // Otherwise (write lock held but not try_lock): spin once more. + // The next iteration will observe writer_wait_count >= 1 (writers + // self-account in writer_wait_count for the entire wrlock lifetime), + // and we will block on it instead of busy spinning. } - // Slow path. - if (BAIDU_UNLIKELY(reader_count + 1 == 0 || reader_count + 1 == -RWLockMaxReaders)) { - CHECK(false) << "rwlock_unrdlock of unlocked rwlock"; - return EINVAL; - } + // Submit one contention sample for this reader (success or failure). + submit_contention_if_sampled(start_ns, sampling_range); + return rc; +} - // A writer is pending. - int reader_wait = ((butil::atomic*)&rwlock->reader_wait) - ->fetch_add(-1, butil::memory_order_relaxed) - 1; - if (reader_wait != 0) { +static int rwlock_unrdlock(bthread_rwlock_t* rwlock) { + auto lock_word = (butil::atomic*)rwlock->lock_word; + while (true) { + unsigned l = lock_word->load(butil::memory_order_relaxed); + // Misuse detection: the caller must currently hold a read lock. + // l == 0 -> no lock is held (double unlock?) + // (l >> 31) != 0 -> write lock is held, not read lock + if (l == 0 || (l >> 31) != 0) { + LOG(ERROR) << "Invalid unrdlock on bthread_rwlock_t=" << rwlock + << ", lock_word=" << l; + return EINVAL; + } + // Release on success publishes any reads/writes done while holding + // the read lock to the next acquirer (typically a writer's + // acquire-CAS in wrlock()). + if(!(lock_word->compare_exchange_weak(l, l - 1, + butil::memory_order_release, + butil::memory_order_relaxed))) { + continue; + } + // We were the last reader (lock_word transitioned 1 -> 0). Wake the + // single writer (if any) that may be sleeping on `lock_word' inside + // wrlock(). At most one writer can be there because writers are + // serialized by writer_queue_mutex. + // No-op if nobody is waiting; butex_wake() short-circuits cheaply. + if (l == 1) { + butex_wake(lock_word); + } return 0; } +} - // The last reader unblocks the writer. - - if (NULL == bthread::g_cp) { - bthread_sem_post(&rwlock->writer_sema); - return 0; +// Roll back the side effects of a failed wrlock attempt: +// - Release writer_queue_mutex if we managed to acquire it. +// - Decrement our share of writer_wait_count. +// - If we were the last in-flight writer, wake all readers that have +// been parked by writer-priority (w == 1 means writer_wait_count is now 0). +// Called on EBUSY (try_lock failed), ETIMEDOUT, EINTR-leading-to-fail. +static BUTIL_FORCE_INLINE void rwlock_wrlock_cleanup(bthread_rwlock_t* rwlock, bool write_queue_locked) { + if (write_queue_locked) { + bthread_mutex_unlock(&rwlock->writer_queue_mutex); } - // Ask Collector if this (contended) locking should be sampled. - const size_t sampling_range = bvar::is_collectable(&bthread::g_cp_sl); - if (!sampling_range) { // Don't sample - bthread_sem_post(&rwlock->writer_sema); - return 0; + auto writer_wait_count = (butil::atomic*)rwlock->writer_wait_count; + // Withdraw our writer-priority "vote" so readers can make progress. + auto w = writer_wait_count->fetch_sub(1, butil::memory_order_relaxed); + // w is the value BEFORE the subtraction, so w == 1 means we were the + // last writer in flight; wake every reader parked on writer_wait_count. + if (w == 1) { + butex_wake_all(writer_wait_count); } - - // Sampling. - const int64_t start_ns = butil::cpuwide_time_ns(); - bthread_sem_post(&rwlock->writer_sema); - const int64_t end_ns = butil::cpuwide_time_ns(); - const bthread_contention_site_t csite{end_ns - start_ns, sampling_range}; - // Submit `csite' for each reader immediately after - // releasing rdlock to avoid the contention of `csite'. - bthread::submit_contention(csite, end_ns); - return 0; } -#define DO_CSITE_IF_NEED \ - do { \ - /* Don't sample when contention profiler is off. */ \ - if (NULL != bthread::g_cp) { \ - /* Ask Collector if this (contended) locking should be sampled. */ \ - sampling_range = bvar::is_collectable(&bthread::g_cp_sl); \ - start_ns = bvar::is_sampling_range_valid(sampling_range) ? \ - butil::cpuwide_time_ns() : -1; \ - } else { \ - start_ns = -1; \ - } \ - } while (0) - -#define SUBMIT_CSITE_IF_NEED \ - do { \ - if (ETIMEDOUT == rc && start_ns > 0) { \ - /* Failed to lock due to ETIMEDOUT, submit the elapse directly. */ \ - const int64_t end_ns = butil::cpuwide_time_ns(); \ - const bthread_contention_site_t csite{end_ns - start_ns, sampling_range}; \ - bthread::submit_contention(csite, end_ns); \ - } \ - } while (0) - -// For writing. -static inline int rwlock_wrlock_impl(bthread_rwlock_t* __restrict rwlock, - const struct timespec* __restrict abstime) { - // First, resolve competition with other writers. - int rc = bthread_mutex_trylock(&rwlock->write_queue_mutex); +static int rwlock_wrlock(bthread_rwlock_t* rwlock, bool try_lock, + const struct timespec* abstime) { + auto writer_wait_count = (butil::atomic*)rwlock->writer_wait_count; + // Step 1: announce ourselves before doing anything else, so that + // concurrent readers immediately observe writer-priority and back off. + // This MUST happen before we try to acquire writer_queue_mutex, + // otherwise a flood of readers could starve us indefinitely. + // 2^31 in-flight writers should be enough for any realistic workload. + writer_wait_count->fetch_add(1, butil::memory_order_relaxed); + + // Sampling state for the contention profiler. Both wrlock() and + // unwrlock() sample independently: wrlock() submits its own wait time + // on the way out (success or failure); unwrlock() samples its own + // CAS-spin / mutex_unlock / butex_wake_all latency separately. We do + // NOT use rwlock->writer_csite here -- the two operations are not + // forced to share a single sample. size_t sampling_range = bvar::INVALID_SAMPLING_RANGE; - // -1: don't sample. - // 0: default value. - // > 0: Start time of sampling. int64_t start_ns = 0; - if (0 != rc) { - DO_CSITE_IF_NEED; - rc = bthread_mutex_timedlock(&rwlock->write_queue_mutex, abstime); + // Step 2: serialize with other writers. At most one writer holds + // `writer_queue_mutex' at a time and races for `lock_word'. + int rc = bthread_mutex_trylock(&rwlock->writer_queue_mutex); + if (0 != rc) { + if (try_lock) { + // Fail to acquire the wrlock. Don't sample trywrlock failures: + // they are by design a non-blocking probe, not a contention event. + rwlock_wrlock_cleanup(rwlock, false); + return rc; + } + // We are about to block on writer_queue_mutex; arm sampling. + // Note: the inner mutex itself has csite disabled (see init), so + // its blocking time is only counted once -- here, by the rwlock. + BTHREAD_RWLOCK_MAYBE_START_SAMPLING; + rc = bthread_mutex_timedlock(&rwlock->writer_queue_mutex, abstime); if (0 != rc) { - SUBMIT_CSITE_IF_NEED; + // Fail to acquire the wrlock. Submit the elapsed wait time + // directly (no unwrlock() will run for this writer). + submit_contention_if_sampled(start_ns, sampling_range); + rwlock_wrlock_cleanup(rwlock, false); return rc; } } - // Announce to readers there is a pending writer. - int reader_count = ((butil::atomic*)&rwlock->reader_count) - ->fetch_add(-RWLockMaxReaders, butil::memory_order_release); - // Wait for active readers. - if (reader_count != 0 && - ((butil::atomic*)&rwlock->reader_wait) - ->fetch_add(reader_count) + reader_count != 0) { - rc = bthread_sem_trywait(&rwlock->writer_sema); - if (0 != rc) { - if (0 == start_ns) { - DO_CSITE_IF_NEED; + // Step 3: with `writer_queue_mutex' held, wait for all readers to drain + // and then claim the write bit of `lock_word'. + auto lock_word = (butil::atomic*)rwlock->lock_word; + while (true) { + unsigned l = lock_word->load(butil::memory_order_relaxed); + if (l != 0) { + // Readers still hold the lock. Park on `lock_word' until the last + // reader releases (unrdlock will butex_wake on transition 1->0). + if (try_lock) { + errno = EBUSY; + break; } - - rc = bthread_sem_timedwait(&rwlock->writer_sema, abstime); - if (0 != rc) { - SUBMIT_CSITE_IF_NEED; - bthread_mutex_unlock(&rwlock->write_queue_mutex); - return rc; + // Arm sampling before parking so the wait-for-readers time is + // counted (in case the queue_mutex acquisition above was uncontended). + BTHREAD_RWLOCK_MAYBE_START_SAMPLING; + // Use the freshly read `r' as expected; if lock_word changes + // before we sleep, butex_wait returns EWOULDBLOCK and we retry. + if (butex_wait(lock_word, l, abstime) < 0 && + errno != EWOULDBLOCK && errno != EINTR) { + break; } + continue; } + // Acquire on success synchronizes-with release-CAS in + // unrdlock()/unwrlock(): we will see all data published by the + // previous reader/writer before we start writing. + if (lock_word->compare_exchange_weak(l, (unsigned)(1 << 31), + butil::memory_order_acquire, + butil::memory_order_relaxed)) { + // Submit the writer's wait sample immediately on success. + // unwrlock() will sample its own latency separately. + submit_contention_if_sampled(start_ns, sampling_range); + return 0; + } + // CAS may spuriously fail (weak); retry without sleeping. } - if (start_ns > 0) { - rwlock->writer_csite.duration_ns = butil::cpuwide_time_ns() - start_ns; - rwlock->writer_csite.sampling_range = sampling_range; - } - rwlock->wlock_flag = true; - return 0; -} -#undef DO_CSITE_IF_NEED -#undef SUBMIT_CSITE_IF_NEED -static inline int rwlock_wrlock(bthread_rwlock_t* rwlock) { - return rwlock_wrlock_impl(rwlock, NULL); + // Failure path: snapshot errno before cleanup, because + // bthread_mutex_unlock / butex_wake_all inside cleanup may invoke + // syscalls or yield and clobber errno on this thread. + int saved_errno = errno; + // Submit the elapsed wait directly; we never reached unwrlock(). + submit_contention_if_sampled(start_ns, sampling_range); + rwlock_wrlock_cleanup(rwlock, true); + return saved_errno; } -static inline int rwlock_timedwrlock(bthread_rwlock_t* __restrict rwlock, - const struct timespec* __restrict abstime) { - return rwlock_wrlock_impl(rwlock, abstime); -} - -static inline int rwlock_trywrlock(bthread_rwlock_t* rwlock) { - int rc = bthread_mutex_trylock(&rwlock->write_queue_mutex); - if (0 != rc) { - return rc; - } - - int expected = 0; - if (!((butil::atomic*)&rwlock->reader_count) - ->compare_exchange_strong(expected, -RWLockMaxReaders, - butil::memory_order_acquire, - butil::memory_order_relaxed)) { - // Failed to acquire the write lock because there are active readers. - bthread_mutex_unlock(&rwlock->write_queue_mutex); - return EBUSY; - } - rwlock->wlock_flag = true; - - return 0; -} +static int rwlock_unwrlock(bthread_rwlock_t* rwlock) { + auto lock_word = (butil::atomic*)rwlock->lock_word; + auto writer_wait_count = (butil::atomic*)rwlock->writer_wait_count; -static inline void rwlock_unwrlock_slow(bthread_rwlock_t* rwlock, int reader_count) { - bthread_sem_post_n(&rwlock->reader_sema, reader_count); - // Allow other writers to proceed. - bthread_mutex_unlock(&rwlock->write_queue_mutex); -} - -static inline int rwlock_unwrlock(bthread_rwlock_t* rwlock) { - rwlock->wlock_flag = false; + // Sampling state for the contention profiler. unwrlock() samples + // independently of wrlock(): although the release-CAS itself cannot + // fail due to writer-writer contention (writers are serialized by + // writer_queue_mutex), the body still does mutex_unlock(), + // butex_wake_all() and may spuriously spin on the weak CAS, all of + // which contribute to the critical-section tail latency. + size_t sampling_range = bvar::INVALID_SAMPLING_RANGE; + int64_t start_ns = 0; + BTHREAD_RWLOCK_MAYBE_START_SAMPLING; - // Announce to readers there is no active writer. - int reader_count = ((butil::atomic*)&rwlock->reader_count)->fetch_add( - RWLockMaxReaders, butil::memory_order_release) + RWLockMaxReaders; - if (BAIDU_UNLIKELY(reader_count >= RWLockMaxReaders)) { - CHECK(false) << "rwlock_unwlock of unlocked rwlock"; - return EINVAL; - } + while (true) { + unsigned l = lock_word->load(butil::memory_order_relaxed); + // Misuse detection: we must currently hold the write lock. + if (BAIDU_UNLIKELY(l != (unsigned)(1 << 31))) { + LOG(ERROR) << "Invalid unwrlock!"; + return EINVAL; + } + // Release-CAS publishes all writes performed under the write lock + // to the next acquirer (a reader's acquire-CAS or another writer's + // acquire-CAS). The CAS itself cannot fail due to contention since + // writers are serialized by writer_queue_mutex; weak failure here is + // only a spurious CAS failure -- just retry. + if (!lock_word->compare_exchange_weak(l, 0, + butil::memory_order_release, + butil::memory_order_relaxed)) { + continue; + } - bool is_valid = bthread::is_contention_site_valid(rwlock->writer_csite); - if (BAIDU_UNLIKELY(is_valid)) { - bthread_contention_site_t saved_csite = rwlock->writer_csite; - bthread::make_contention_site_invalid(&rwlock->writer_csite); + // ---- Order of the next two operations is INTENTIONAL ---- + // + // We deliberately: + // (1) unlock writer_queue_mutex FIRST, then + // (2) fetch_sub(writer_wait_count) and conditionally wake readers. + // + // Rationale (writer-priority semantics): + // * Any writer queued on writer_queue_mutex has already + // fetch_add'ed its share into writer_wait_count back in wrlock() + // (before it even tried to lock the mutex). So when it wakes + // up here and we later fetch_sub, the counter still reflects + // "there is at least one more writer in flight": w_old >= 2, + // which means w != 1, which means we will NOT wake readers. + // Readers must keep yielding to the next writer -- exactly the + // writer-priority invariant. + // * Only when we are truly the last writer in flight (w_old == 1 + // after our fetch_sub, i.e. writer_wait_count is now 0) do we + // wake_all readers parked on writer_wait_count. + // + // Subtle but harmless effect: + // Between (1) and (2) there is a small window in which our + // own "ghost share" is still counted in writer_wait_count even though + // we have effectively left. New readers entering rdlock() during + // this window will see writer_wait_count >= 1 and park on it; they + // will be woken either by step (2) below (if no successor writer + // appeared) or by the successor writer's eventual unwrlock. + // No wakeup is ever lost: butex_wait re-checks the expected + // value before truly sleeping, and any successor writer will + // itself execute this same wake logic on its way out. + // + // Reversing the order (fetch_sub before unlock mutex) would break + // strict writer-priority because woken readers could grab the + // read lock before a successor writer queued on the mutex even + // gets a chance to CAS lock_word. + bthread_mutex_unlock(&rwlock->writer_queue_mutex); + unsigned w = writer_wait_count->fetch_sub(1, butil::memory_order_relaxed); + if (w == 1) { + butex_wake_all(writer_wait_count); + } - const int64_t unlock_start_ns = butil::cpuwide_time_ns(); - rwlock_unwrlock_slow(rwlock, reader_count); - const int64_t unlock_end_ns = butil::cpuwide_time_ns(); - saved_csite.duration_ns += unlock_end_ns - unlock_start_ns; - bthread::submit_contention(saved_csite, unlock_end_ns); - } else { - rwlock_unwrlock_slow(rwlock, reader_count); + // Submit our own unwrlock-side sample (CAS spin + mutex_unlock + + // butex_wake_all). This is independent of the wrlock-side sample. + submit_contention_if_sampled(start_ns, sampling_range); + return 0; } - - return 0; } -static inline int rwlock_unlock(bthread_rwlock_t* rwlock) { - if (rwlock->wlock_flag) { +// Generic unlock entry that dispatches to unwrlock/unrdlock by inspecting +// `lock_word'. This is safe ONLY because the caller must already hold one of +// the two locks: while holding a read lock the high bit of `lock_word' cannot +// flip on, and while holding the write lock the low bits cannot be set. +// Therefore a relaxed load is sufficient to make the dispatch decision. +static int rwlock_unlock(bthread_rwlock_t* rwlock) { + auto lock_word = (butil::atomic*)rwlock->lock_word; + unsigned r = lock_word->load(butil::memory_order_relaxed); + if ((r >> 31) != 0) { return rwlock_unwrlock(rwlock); } else { return rwlock_unrdlock(rwlock); } } -} // namespace bthread - -__BEGIN_DECLS - -int bthread_rwlock_init(bthread_rwlock_t* __restrict rwlock, - const bthread_rwlockattr_t* __restrict) { - int rc = bthread_sem_init(&rwlock->reader_sema, 0); - if (BAIDU_UNLIKELY(0 != rc)) { - return rc; +// Deleter that turns butex_create_checked()'s raw pointer into something +// std::unique_ptr can clean up automatically. Using RAII here lets the +// init-error paths just `return rc' without manually unwinding partial +// allocations; ownership is `release()'d only on the all-success path. +struct ButexDeleter { + void operator()(void* butex) const { + if (butex != NULL) { + butex_destroy(butex); + } } - bthread_sem_disable_csite(&rwlock->reader_sema); - rc = bthread_sem_init(&rwlock->writer_sema, 0); - if (BAIDU_UNLIKELY(0 != rc)) { - bthread_sem_destroy(&rwlock->reader_sema); - return rc; +}; + +static int rwlock_init(bthread_rwlock_t* rwlock) { + std::unique_ptr writer_wait_count( + butex_create_checked()); + if (writer_wait_count == NULL) { + LOG(ERROR) << "Fail to create writer_wait_count butex: out of memory"; + return ENOMEM; } - bthread_sem_disable_csite(&rwlock->writer_sema); - - rwlock->reader_count = 0; - rwlock->reader_wait = 0; - rwlock->wlock_flag = false; + std::unique_ptr lock_word(butex_create_checked()); + if (lock_word == NULL) { + LOG(ERROR) << "Fail to create lock_word butex: out of memory"; + return ENOMEM; + } + *writer_wait_count = 0; + *lock_word = 0; bthread_mutexattr_t attr; bthread_mutexattr_init(&attr); + BRPC_SCOPE_EXIT { bthread_mutexattr_destroy(&attr); }; + // Disable csite on the inner queue mutex so the writer's wait time is + // accounted exactly once -- by the rwlock layer, not double-counted via + // the inner mutex. bthread_mutexattr_disable_csite(&attr); - rc = bthread_mutex_init(&rwlock->write_queue_mutex, &attr); - if (BAIDU_UNLIKELY(0 != rc)) { - bthread_sem_destroy(&rwlock->reader_sema); - bthread_sem_destroy(&rwlock->writer_sema); + const int rc = bthread_mutex_init(&rwlock->writer_queue_mutex, &attr); + if (rc != 0) { + LOG(ERROR) << "Fail to init writer_queue_mutex, rc=" << rc; return rc; } - bthread_mutexattr_destroy(&attr); - - bthread::make_contention_site_invalid(&rwlock->writer_csite); + // All resources successfully created; transfer butex ownership to + // rwlock. From here on, bthread_rwlock_destroy() is responsible for + // releasing them. + rwlock->writer_wait_count = writer_wait_count.release(); + rwlock->lock_word = lock_word.release(); return 0; } +static int rwlock_destroy(bthread_rwlock_t* rwlock) { + // Destroy the inner mutex first; bthread_mutex_init() allocates an + // internal butex which would otherwise leak. Pointers are nulled to + // surface accidental double-destroy / use-after-destroy bugs early. + int rc = bthread_mutex_destroy(&rwlock->writer_queue_mutex); + if (rc != 0) { + LOG(ERROR) << "Fail to destroy writer_queue_mutex, rc=" << rc; + } + if (rwlock->writer_wait_count != NULL) { + butex_destroy(rwlock->writer_wait_count); + rwlock->writer_wait_count = NULL; + } + if (rwlock->lock_word != NULL) { + butex_destroy(rwlock->lock_word); + rwlock->lock_word = NULL; + } + return rc; +} + +} // namespace bthread + +__BEGIN_DECLS + +int bthread_rwlock_init(bthread_rwlock_t* __restrict rwlock, + const bthread_rwlockattr_t* __restrict) { + return bthread::rwlock_init(rwlock); +} + int bthread_rwlock_destroy(bthread_rwlock_t* rwlock) { - bthread_sem_destroy(&rwlock->reader_sema); - bthread_sem_destroy(&rwlock->writer_sema); - bthread_mutex_destroy(&rwlock->write_queue_mutex); - return 0; + return bthread::rwlock_destroy(rwlock); } int bthread_rwlock_rdlock(bthread_rwlock_t* rwlock) { - return bthread::rwlock_rdlock(rwlock); + return bthread::rwlock_rdlock(rwlock, false, NULL); } int bthread_rwlock_tryrdlock(bthread_rwlock_t* rwlock) { - return bthread::rwlock_tryrdlock(rwlock); + return bthread::rwlock_rdlock(rwlock, true, NULL); } int bthread_rwlock_timedrdlock(bthread_rwlock_t* __restrict rwlock, const struct timespec* __restrict abstime) { - return bthread::rwlock_timedrdlock(rwlock, abstime); + return bthread::rwlock_rdlock(rwlock, false, abstime); } int bthread_rwlock_wrlock(bthread_rwlock_t* rwlock) { - return bthread::rwlock_wrlock(rwlock); + return bthread::rwlock_wrlock(rwlock, false, NULL); } int bthread_rwlock_trywrlock(bthread_rwlock_t* rwlock) { - return bthread::rwlock_trywrlock(rwlock); + return bthread::rwlock_wrlock(rwlock, true, NULL); } int bthread_rwlock_timedwrlock(bthread_rwlock_t* __restrict rwlock, const struct timespec* __restrict abstime) { - return bthread::rwlock_timedwrlock(rwlock, abstime); + return bthread::rwlock_wrlock(rwlock, false, abstime); } int bthread_rwlock_unlock(bthread_rwlock_t* rwlock) { diff --git a/src/bthread/task_group.cpp b/src/bthread/task_group.cpp index 579bb23120..4706b7f77e 100644 --- a/src/bthread/task_group.cpp +++ b/src/bthread/task_group.cpp @@ -247,6 +247,12 @@ TaskGroup::~TaskGroup() { } #ifdef BUTIL_USE_ASAN +// Returns the **highest** address of the calling pthread's stack and its +// total size, matching brpc's `StackStorage::bottom` convention (see comment +// in bthread/stack.h: "Assume stack grows upwards"). Note that on Linux +// `pthread_attr_getstack(3)` returns the lowest address of the region, so +// we have to translate it; on macOS `pthread_get_stackaddr_np(3)` already +// returns the stack base (highest address), so we use it as-is. int PthreadAttrGetStack(void*& stack_addr, size_t& stack_size) { #if defined(OS_MACOSX) stack_addr = pthread_get_stackaddr_np(pthread_self()); @@ -259,9 +265,13 @@ int PthreadAttrGetStack(void*& stack_addr, size_t& stack_size) { LOG(ERROR) << "Fail to get pthread attributes: " << berror(rc); return rc; } - rc = pthread_attr_getstack(&attr, &stack_addr, &stack_size); + void* stack_lowest = NULL; + rc = pthread_attr_getstack(&attr, &stack_lowest, &stack_size); if (0 != rc) { LOG(ERROR) << "Fail to get pthread stack: " << berror(rc); + } else { + // Translate lowest -> highest to match StackStorage::bottom. + stack_addr = (char*)stack_lowest + stack_size; } pthread_attr_destroy(&attr); return rc; @@ -635,6 +645,10 @@ int TaskGroup::join(bthread_t tid, void** return_value) { return errno; } } + // Ensure all memory writes made by the joined bthread are visible to + // the joining thread after join returns. This matches the semantic + // guarantee provided by pthread_join() across supported architectures. + butil::atomic_thread_fence(butil::memory_order_acquire); if (return_value) { *return_value = NULL; } diff --git a/src/bthread/task_tracer.cpp b/src/bthread/task_tracer.cpp index b550c03a9d..e6049f0ded 100644 --- a/src/bthread/task_tracer.cpp +++ b/src/bthread/task_tracer.cpp @@ -17,6 +17,13 @@ #ifdef BRPC_BTHREAD_TRACER +#include +#include +#include +#include +#include +#include +#include #include "bthread/task_tracer.h" #include "bthread/processor.h" #include "bthread/task_group.h" @@ -25,11 +32,6 @@ #include "butil/reloadable_flags.h" #include "absl/debugging/stacktrace.h" #include "absl/debugging/symbolize.h" -#include -#include -#include -#include -#include namespace bthread { @@ -77,7 +79,8 @@ std::string TaskTracer::Result::OutputToString() const { char unknown_symbol_name[] = ""; if (frame_count > 0) { for (size_t i = 0; i < frame_count; ++i) { - butil::string_appendf(&str, "#%zu 0x%16p ", i, ips[i]); + butil::string_appendf(&str, "#%zu 0x%016" PRIxPTR " ", i, + reinterpret_cast(ips[i])); if (absl::Symbolize(ips[i], symbol_name, arraysize(symbol_name))) { str.append(symbol_name); } else { @@ -103,7 +106,10 @@ void TaskTracer::Result::OutputToStream(std::ostream& os) const { char unknown_symbol_name[] = ""; if (frame_count > 0) { for (size_t i = 0; i < frame_count; ++i) { - os << "# " << i << " 0x" << std::hex << ips[i] << std::dec << " "; + os << "#" << i << " 0x" + << std::hex << std::setw(16) << std::setfill('0') + << reinterpret_cast(ips[i]) + << std::dec << std::setfill(' ') << " "; if (absl::Symbolize(ips[i], symbol_name, arraysize(symbol_name))) { os << symbol_name; } else { diff --git a/src/bthread/types.h b/src/bthread/types.h index 86148c938b..d46de1e835 100644 --- a/src/bthread/types.h +++ b/src/bthread/types.h @@ -225,16 +225,26 @@ typedef struct bthread_sem_t { typedef struct bthread_rwlock_t { #if defined(__cplusplus) bthread_rwlock_t() - : reader_count(0), reader_wait(0), wlock_flag(false), writer_csite{} {} + : writer_wait_count(0), lock_word(NULL) {} DISALLOW_COPY_AND_ASSIGN(bthread_rwlock_t); #endif - bthread_sem_t reader_sema; // Semaphore for readers to wait for completing writers. - bthread_sem_t writer_sema; // Semaphore for writers to wait for completing readers. - int reader_count; // Number of pending readers. - int reader_wait; // Number of departing readers. - bool wlock_flag; // Flag used to indicate that a write lock has been held. - bthread_mutex_t write_queue_mutex; // Held if there are pending writers. - bthread_contention_site_t writer_csite; + // Number of writers currently in flight (used as a butex): + // writers waiting on writer_queue_mutex, writers waiting for + // lock_word == 0, and the writer currently holding the write lock + // are all counted here. Each writer accounts for itself: incremented + // at the very beginning of wrlock() and decremented at the very end + // of unwrlock()/cleanup(). Readers consult this field to honor + // writer-priority: any non-zero value parks new readers. + unsigned* writer_wait_count; + // Serializes writers so that at most one writer at a time races for + // lock_word. Other writers queue up on this mutex. + bthread_mutex_t writer_queue_mutex; + // Bit-packed atomic lock word (used as a butex): + // bit 31 : 1 if the write lock is held, 0 otherwise. + // bit 0~30: number of readers currently holding the read lock. + // 0 : unlocked. + // The high bit and the low 31 bits are mutually exclusive. + unsigned* lock_word; } bthread_rwlock_t; typedef struct { diff --git a/src/butil/crc32c.cc b/src/butil/crc32c.cc index 1817cb0a05..7de07cf428 100644 --- a/src/butil/crc32c.cc +++ b/src/butil/crc32c.cc @@ -421,7 +421,194 @@ uint32_t ExtendImpl(uint32_t crc, const char* buf, size_t size) { return static_cast(l ^ 0xffffffffu); } -// Detect if SS42 or not. +#if defined(__riscv) && (__riscv_xlen == 64) && defined(__riscv_zbc) +#include + +// RISC-V Zbc carry-less multiplication inline helpers +static inline uint64_t rv_clmul(uint64_t a, uint64_t b) { + uint64_t result; + __asm__ volatile ("clmul %0, %1, %2" : "=r"(result) : "r"(a), "r"(b)); + return result; +} + +static inline uint64_t rv_clmulh(uint64_t a, uint64_t b) { + uint64_t result; + __asm__ volatile ("clmulh %0, %1, %2" : "=r"(result) : "r"(a), "r"(b)); + return result; +} + +// Bitwise CRC32C fallback for small chunks +static inline uint32_t rv_crc32c_bitwise(uint32_t crc, const uint8_t* buf, + size_t len) { + uint32_t c = crc; + for (size_t i = 0; i < len; ++i) { + c ^= buf[i]; + for (int k = 0; k < 8; ++k) { + c = (c >> 1) ^ ((c & 1) ? 0x82F63B78U : 0); + } + } + return c; +} + +// Fold a 128-bit CRC state (lo:hi) with fold constants and XOR in new data +static inline void rv_fold_pair_xor_data(uint64_t* lo, uint64_t* hi, + uint64_t k0, uint64_t k1, + uint64_t d0, uint64_t d1) { + uint64_t l = rv_clmul(*lo, k0) ^ rv_clmul(*hi, k1); + uint64_t h = rv_clmulh(*lo, k0) ^ rv_clmulh(*hi, k1); + *lo = l ^ d0; + *hi = h ^ d1; +} + +// Fold a 128-bit CRC state with fold constants and XOR in another state +static inline void rv_fold_pair_xor_state(uint64_t* lo, uint64_t* hi, + uint64_t k0, uint64_t k1, + uint64_t s0, uint64_t s1) { + uint64_t l = rv_clmul(*lo, k0) ^ rv_clmul(*hi, k1); + uint64_t h = rv_clmulh(*lo, k0) ^ rv_clmulh(*hi, k1); + *lo = l ^ s0; + *hi = h ^ s1; +} + +// Folding constants for CRC32C (Castagnoli polynomial 0x1EDC6F41) +// x^(64*i+64) mod P(x) for i=1..4, in bit-reflected form +static const uint64_t crc32c_fold_const[4] __attribute__((aligned(16))) = { + 0x00000000740eef02ULL, // k1: fold 512->256 + 0x000000009e4addf8ULL, // k2: fold 512->256 + 0x00000000f20c0dfeULL, // k3: fold 256->128 + 0x00000000493c7d27ULL // k4: fold 256->128 +}; + +// Barrett reduction constants for CRC32C finalization +#define RV_CRC32C_CONST_0 0x00000000dd45aab8ULL // x^64 mod P +#define RV_CRC32C_CONST_1 0x00000000493c7d27ULL // x^96 mod P +#define RV_CRC32C_CONST_QUO 0x0000000dea713f1ULL // floor(x^64 / P) +#define RV_CRC32C_CONST_POLY 0x0000000105ec76f1ULL // P(x) true LE full +#define RV_CRC32_MASK32 0x00000000FFFFFFFFULL + +// Hardware-accelerated CRC32C using RISC-V Zbc carry-less multiplication. +// Processes data in 64-byte chunks with 128-bit folding, then Barrett reduces. +static uint32_t rv_crc32c_clmul(uint32_t crc, const char* buf, size_t len) { + // Convert external CRC to internal register state + crc ^= 0xFFFFFFFF; + + const uint8_t* p = reinterpret_cast(buf); + size_t n = len; + + // Small data: use bitwise fallback + if (n < 64) { + return rv_crc32c_bitwise(crc, p, n) ^ 0xFFFFFFFF; + } + + // Align to 16-byte boundary + uintptr_t mis = (uintptr_t)p & 0xF; + if (mis) { + size_t pre = 16 - mis; + if (pre > n) pre = n; + crc = rv_crc32c_bitwise(crc, p, pre); + p += pre; + n -= pre; + if (n < 64) { + return rv_crc32c_bitwise(crc, p, n) ^ 0xFFFFFFFF; + } + } + + // Load first 64 bytes and XOR CRC into the first 8 bytes + uint64_t x0, x1, y0, y1, z0, z1, w0, w1; + memcpy(&x0, p + 0, 8); + memcpy(&x1, p + 8, 8); + memcpy(&y0, p + 16, 8); + memcpy(&y1, p + 24, 8); + memcpy(&z0, p + 32, 8); + memcpy(&z1, p + 40, 8); + memcpy(&w0, p + 48, 8); + memcpy(&w1, p + 56, 8); + + x0 ^= (uint64_t)crc; + p += 64; + n -= 64; + + const uint64_t k1 = crc32c_fold_const[0]; + const uint64_t k2 = crc32c_fold_const[1]; + const uint64_t k3 = crc32c_fold_const[2]; + const uint64_t k4 = crc32c_fold_const[3]; + + // Main loop: fold 64 bytes per iteration using 128-bit folding + while (n >= 64) { + uint64_t d0, d1; + memcpy(&d0, p + 0, 8); + memcpy(&d1, p + 8, 8); + rv_fold_pair_xor_data(&x0, &x1, k1, k2, d0, d1); + memcpy(&d0, p + 16, 8); + memcpy(&d1, p + 24, 8); + rv_fold_pair_xor_data(&y0, &y1, k1, k2, d0, d1); + memcpy(&d0, p + 32, 8); + memcpy(&d1, p + 40, 8); + rv_fold_pair_xor_data(&z0, &z1, k1, k2, d0, d1); + memcpy(&d0, p + 48, 8); + memcpy(&d1, p + 56, 8); + rv_fold_pair_xor_data(&w0, &w1, k1, k2, d0, d1); + p += 64; + n -= 64; + } + + // Reduce 4x128-bit to 1x128-bit + rv_fold_pair_xor_state(&x0, &x1, k3, k4, y0, y1); + rv_fold_pair_xor_state(&x0, &x1, k3, k4, z0, z1); + rv_fold_pair_xor_state(&x0, &x1, k3, k4, w0, w1); + + // Barrett reduction: 128-bit -> 32-bit CRC + uint64_t t4 = rv_clmul(x0, RV_CRC32C_CONST_1); + uint64_t t3 = rv_clmulh(x0, RV_CRC32C_CONST_1); + uint64_t t1 = x1 ^ t4; + t4 = t1 & RV_CRC32_MASK32; + t1 >>= 32; + uint64_t t0 = rv_clmul(t4, RV_CRC32C_CONST_0); + t3 = (t3 << 32) ^ t1 ^ t0; + + t4 = t3 & RV_CRC32_MASK32; + t4 = rv_clmul(t4, RV_CRC32C_CONST_QUO); + t4 &= RV_CRC32_MASK32; + t4 = rv_clmul(t4, RV_CRC32C_CONST_POLY); + t4 ^= t3; + + uint32_t c = (uint32_t)((t4 >> 32) & RV_CRC32_MASK32); + // Handle remaining bytes + if (n) { + c = rv_crc32c_bitwise(c, p, n); + } + // Convert internal register state to external CRC + return c ^ 0xFFFFFFFF; +} + +// Runtime detection: check if RISC-V CPU supports Zbc extension +static bool isZbc() { + static const bool zbc_supported = []() { + FILE* f = fopen("/proc/cpuinfo", "r"); + if (!f) return false; + bool supported = false; + char line[1024]; + while (fgets(line, sizeof(line), f)) { + if (strstr(line, "isa") || strstr(line, "hart isa")) { + char* colon = strchr(line, ':'); + if (colon) { + if (strstr(colon, "_zbc") || strstr(colon, "zbc")) { + supported = true; + break; + } + } + } + } + fclose(f); + return supported; + }(); + return zbc_supported; +} +} +#endif // __riscv && __riscv_xlen == 64 + +// Detect if SSE4.2 or not. +#ifdef __SSE4_2__ static bool isSSE42() { #if defined(__GNUC__) && defined(__x86_64__) && !defined(IOS_CROSS_COMPILE) uint32_t c_; @@ -432,20 +619,32 @@ static bool isSSE42() { return false; #endif } +#endif typedef uint32_t (*Function)(uint32_t, const char*, size_t); static inline Function Choose_Extend() { - return isSSE42() ? (Function)ExtendImpl : - (Function)ExtendImpl; +#ifdef __SSE4_2__ + if (isSSE42()) { + return (Function)ExtendImpl; + } +#endif +#if defined(__riscv) && (__riscv_xlen == 64) && defined(__riscv_zbc) + if (isZbc()) { + return (Function)rv_crc32c_clmul; + } +#endif + return (Function)ExtendImpl; } bool IsFastCrc32Supported() { #ifdef __SSE4_2__ - return isSSE42(); -#else - return false; + if (isSSE42()) return true; +#endif +#if defined(__riscv) && (__riscv_xlen == 64) && defined(__riscv_zbc) + if (isZbc()) return true; #endif + return false; } uint32_t Extend(uint32_t crc, const char* buf, size_t size) { diff --git a/src/butil/endpoint.cpp b/src/butil/endpoint.cpp index a8d8936c63..c5a888b8e2 100644 --- a/src/butil/endpoint.cpp +++ b/src/butil/endpoint.cpp @@ -394,13 +394,38 @@ int endpoint2hostname(const EndPoint& point, std::string* host) { #if defined(OS_LINUX) static short epoll_to_poll_events(uint32_t epoll_events) { - // Most POLL* and EPOLL* are same values. + // Most POLL* and EPOLL* share the same numeric values for the basic + // event bits, so a plain mask is enough to translate them. short poll_events = (epoll_events & (EPOLLIN | EPOLLPRI | EPOLLOUT | EPOLLRDNORM | EPOLLRDBAND | EPOLLWRNORM | EPOLLWRBAND | EPOLLMSG | EPOLLERR | EPOLLHUP)); - CHECK_EQ((uint32_t)poll_events, epoll_events); + // epoll-only modifier bits (EPOLLET / EPOLLONESHOT / EPOLLRDHUP / ...) + // have no poll(2) counterpart and MUST be silently dropped here: + // * poll(2) is already level-triggered and reports events per call, + // so EPOLLET / EPOLLONESHOT degrade naturally to "no-op". + // * `short` cannot even represent EPOLLET (bit 31 = 0x80000000). + // Without this filtering, a caller invoking bthread_fd_wait(fd, + // EPOLLIN | EPOLLET) from a pthread context would CHECK-fail here. + const uint32_t epoll_modifier_bits = 0u +#ifdef EPOLLET + | EPOLLET +#endif +#ifdef EPOLLONESHOT + | EPOLLONESHOT +#endif +#ifdef EPOLLRDHUP + | EPOLLRDHUP +#endif +#ifdef EPOLLEXCLUSIVE + | EPOLLEXCLUSIVE +#endif +#ifdef EPOLLWAKEUP + | EPOLLWAKEUP +#endif + ; + CHECK_EQ((uint32_t)poll_events, epoll_events & ~epoll_modifier_bits); return poll_events; } #elif defined(OS_MACOSX) diff --git a/src/butil/iobuf.cpp b/src/butil/iobuf.cpp index ce60932327..af77d968cf 100644 --- a/src/butil/iobuf.cpp +++ b/src/butil/iobuf.cpp @@ -42,6 +42,24 @@ #include "butil/iobuf_profiler.h" namespace butil { +static size_t default_block_size = 8192; + +size_t GetDefaultBlockSize() { + return default_block_size; +} + +// This is not thread safe +void SetDefaultBlockSize(size_t block_size) { + if (block_size <= 0) { + LOG(FATAL) << "block_size " << block_size << " should be bigger than 0!!!"; + } + if (block_size / 4096 * 4096 != block_size) { + LOG(FATAL) << "block_size " << block_size << " should be multiply of 4096!!!"; + } + LOG(INFO) << "Update default_block_size from " << default_block_size << " to " << block_size; + default_block_size = block_size; +} + namespace iobuf { DEFINE_int32(iobuf_aligned_buf_block_size, 0, "iobuf aligned buf block size"); @@ -399,9 +417,6 @@ size_t IOBuf::block_count_hit_tls_threshold() { BAIDU_CASSERT(sizeof(IOBuf::SmallView) == sizeof(IOBuf::BigView), sizeof_small_and_big_view_should_equal); -BAIDU_CASSERT(IOBuf::DEFAULT_BLOCK_SIZE/4096*4096 == IOBuf::DEFAULT_BLOCK_SIZE, - sizeof_block_should_be_multiply_of_4096); - const IOBuf::Area IOBuf::INVALID_AREA; IOBuf::IOBuf(const IOBuf& rhs) { diff --git a/src/butil/iobuf.h b/src/butil/iobuf.h index 239e82d950..978aa34fe3 100644 --- a/src/butil/iobuf.h +++ b/src/butil/iobuf.h @@ -52,6 +52,10 @@ struct ssl_st; namespace butil { +size_t GetDefaultBlockSize(); + +void SetDefaultBlockSize(size_t block_size); + // IOBuf is a non-continuous buffer that can be cut and combined w/o copying // payload. It can be read from or flushed into file descriptors as well. // IOBuf is [thread-compatible]. Namely using different IOBuf in different @@ -67,7 +71,6 @@ friend class IOBufCutter; friend class SingleIOBuf; public: - static const size_t DEFAULT_BLOCK_SIZE = 8192; static const size_t INITIAL_CAP = 32; // must be power of 2 struct Block; @@ -775,4 +778,4 @@ inline void swap(butil::IOBuf& a, butil::IOBuf& b) { #include "butil/iobuf_inl.h" -#endif // BUTIL_IOBUF_H \ No newline at end of file +#endif // BUTIL_IOBUF_H diff --git a/src/butil/iobuf_inl.h b/src/butil/iobuf_inl.h index 6b1f875145..756cf8bf63 100644 --- a/src/butil/iobuf_inl.h +++ b/src/butil/iobuf_inl.h @@ -640,7 +640,7 @@ inline IOBuf::Block* create_block(const size_t block_size) { } inline IOBuf::Block* create_block() { - return create_block(IOBuf::DEFAULT_BLOCK_SIZE); + return create_block(butil::GetDefaultBlockSize()); } void* cp(void *__restrict dest, const void *__restrict src, size_t n); @@ -649,4 +649,4 @@ void* cp(void *__restrict dest, const void *__restrict src, size_t n); } // namespace butil -#endif // BUTIL_IOBUF_INL_H \ No newline at end of file +#endif // BUTIL_IOBUF_INL_H diff --git a/src/butil/single_iobuf.cpp b/src/butil/single_iobuf.cpp index c51e8fff1d..942c4ba9e7 100644 --- a/src/butil/single_iobuf.cpp +++ b/src/butil/single_iobuf.cpp @@ -111,7 +111,7 @@ IOBuf::Block* SingleIOBuf::alloc_block_by_size(uint32_t data_size) { } } uint32_t total_size = data_size + sizeof(IOBuf::Block); - if (total_size <= IOBuf::DEFAULT_BLOCK_SIZE) { + if (total_size <= butil::GetDefaultBlockSize()) { _cur_block = iobuf::acquire_tls_block(); if (_cur_block != NULL) { if (_cur_block->left_space() >= data_size) { @@ -282,4 +282,4 @@ void SingleIOBuf::target_block_dec_ref(void* b) { block->dec_ref(); } -} // namespace butil \ No newline at end of file +} // namespace butil diff --git a/src/butil/third_party/symbolize/symbolize.cc b/src/butil/third_party/symbolize/symbolize.cc index 19649b3545..0b4000ad4e 100644 --- a/src/butil/third_party/symbolize/symbolize.cc +++ b/src/butil/third_party/symbolize/symbolize.cc @@ -333,10 +333,18 @@ FindSymbol(uint64_t pc, const int fd, char *out, int out_size, // both regular and dynamic symbol tables if necessary. On success, // write the symbol name to "out" and return true. Otherwise, return // false. +// `base_address` is the runtime VA that corresponds to ELF VA 0 of the +// object identified by `fd`. For ET_EXEC binaries it is ignored +// (symbols already hold absolute runtime addresses); for ET_DYN +// objects (PIE executables and shared libraries) the caller must +// supply a precise value (see +// OpenObjectFileContainingPcAndGetStartAddress() below, which derives +// it from the object's PT_LOAD program headers rather than from the +// /proc/self/maps file_offset field). static bool GetSymbolFromObjectFile(const int fd, uint64_t pc, char *out, int out_size, uint64_t *out_saddr, - uint64_t map_start_address) { + uint64_t base_address) { // Read the ELF header. ElfW(Ehdr) elf_header; if (!ReadFromOffsetExact(fd, &elf_header, sizeof(elf_header), 0)) { @@ -345,7 +353,7 @@ static bool GetSymbolFromObjectFile(const int fd, uint64_t pc, uint64_t symbol_offset = 0; if (elf_header.e_type == ET_DYN) { // DSO needs offset adjustment. - symbol_offset = map_start_address; + symbol_offset = base_address; } ElfW(Shdr) symtab, strtab; @@ -528,13 +536,26 @@ OpenObjectFileContainingPcAndGetStartAddress(uint64_t pc, return -1; } + // Also open /proc/self/mem so we can read ELF headers / program + // headers directly out of the mapped binary, which is the only + // reliable way to compute the runtime ELF base address for PIE + // executables and shared libraries (the file_offset reported by + // /proc/self/maps is per-mapping and does not necessarily match the + // ELF base when modern toolchains use `-z separate-code` or when + // the first PT_LOAD has a non-zero p_vaddr). Mirrors the + // implementation in glog upstream. + int mem_fd; + NO_INTR(mem_fd = open("/proc/self/mem", O_RDONLY)); + FileDescriptor wrapped_mem_fd(mem_fd); + if (wrapped_mem_fd.get() < 0) { + return -1; + } + // Iterate over maps and look for the map containing the pc. Then // look into the symbol tables inside. char buf[1024]; // Big enough for line of sane /proc/self/maps - int num_maps = 0; LineReader reader(wrapped_maps_fd.get(), buf, sizeof(buf)); while (true) { - num_maps++; const char *cursor; const char *eol; if (!reader.ReadLine(&cursor, &eol)) { // EOF or malformed line. @@ -563,11 +584,6 @@ OpenObjectFileContainingPcAndGetStartAddress(uint64_t pc, } ++cursor; // Skip ' '. - // Check start and end addresses. - if (!(start_address <= pc && pc < end_address)) { - continue; // We skip this map. PC isn't in this map. - } - // Read flags. Skip flags until we encounter a space or eol. const char * const flags_start = cursor; while (cursor < eol && *cursor != ' ') { @@ -578,32 +594,74 @@ OpenObjectFileContainingPcAndGetStartAddress(uint64_t pc, return -1; // Malformed line. } + // Determine the base address by reading ELF headers in process + // memory. We must do this for *every* readable map, not just the + // r-x map that contains PC, because a PIE binary or shared + // library's ELF header is typically mapped by an earlier r-- map + // (separate-code layout). Once we encounter the r-- map carrying + // the ELF magic, we record base_address for the whole object. + // When we later reach the r-x map containing PC, base_address is + // already correct. + if (flags_start[0] == 'r') { + ElfW(Ehdr) ehdr; + if (ReadFromOffsetExact(wrapped_mem_fd.get(), &ehdr, sizeof(ehdr), + start_address) && + memcmp(ehdr.e_ident, ELFMAG, SELFMAG) == 0) { + switch (ehdr.e_type) { + case ET_EXEC: + base_address = 0; + break; + case ET_DYN: + // Find the PT_LOAD segment with p_offset == 0 (i.e. the + // segment that contains the ELF header). Its p_vaddr is + // the ELF VA that corresponds to the bytes we just read + // at `start_address`, so base_address = start_address - + // p_vaddr. Normally p_vaddr is 0 and base_address == + // start_address, but some non-standard linker scripts + // place the first LOAD at a non-zero VA. Fall back to + // start_address if no such PT_LOAD is found. + base_address = start_address; + for (unsigned i = 0; i != ehdr.e_phnum; ++i) { + ElfW(Phdr) phdr; + if (ReadFromOffsetExact(wrapped_mem_fd.get(), &phdr, + sizeof(phdr), + start_address + ehdr.e_phoff + + i * sizeof(phdr)) && + phdr.p_type == PT_LOAD && phdr.p_offset == 0) { + base_address = start_address - phdr.p_vaddr; + break; + } + } + break; + default: + // ET_REL or ET_CORE. Not directly executable, leave + // base_address untouched. + break; + } + } + } + + // Check start and end addresses. + if (!(start_address <= pc && pc < end_address)) { + continue; // We skip this map. PC isn't in this map. + } + // Check flags. We are only interested in "r-x" maps. if (memcmp(flags_start, "r-x", 3) != 0) { // Not a "r-x" map. continue; // We skip this map. } ++cursor; // Skip ' '. - // Read file offset. + // Read file offset (parsed but no longer used for base_address; + // base_address is now computed from PT_LOAD program headers + // above. Keep the parse so the cursor advances to the file name). uint64_t file_offset; cursor = GetHex(cursor, eol, &file_offset); if (cursor == eol || *cursor != ' ') { return -1; // Malformed line. } ++cursor; // Skip ' '. - - // Don't subtract 'start_address' from the first entry: - // * If a binary is compiled w/o -pie, then the first entry in - // process maps is likely the binary itself (all dynamic libs - // are mapped higher in address space). For such a binary, - // instruction offset in binary coincides with the actual - // instruction address in virtual memory (as code section - // is mapped to a fixed memory range). - // * If a binary is compiled with -pie, all the modules are - // mapped high at address space (in particular, higher than - // shadow memory of the tool), so the module can't be the - // first entry. - base_address = ((num_maps == 1) ? 0U : start_address) - file_offset; + (void)file_offset; // Skip to file name. "cursor" now points to dev. We need to // skip at least two spaces for dev and inode. @@ -794,9 +852,18 @@ static ATTRIBUTE_NOINLINE bool SymbolizeAndDemangle(void *pc, char *out, out_size -= num_bytes_written; } } + // Use `base_address` (computed by + // OpenObjectFileContainingPcAndGetStartAddress() via the object's + // PT_LOAD program headers) as the relocation offset, NOT + // `start_address`. For ET_DYN objects produced by modern toolchains + // (binutils >= 2.31, lld) using `-z separate-code`, the r-x mapping + // starts at a non-zero file offset and `start_address` no longer + // equals the ELF base; only `base_address` is the correct value + // such that `symbol.st_value + base_address` recovers the runtime + // VA of a symbol. if (!GetSymbolFromObjectFile(wrapped_object_fd.get(), pc0, out, out_size, out_saddr, - start_address)) { + base_address)) { return false; } diff --git a/src/butil/thread_key.h b/src/butil/thread_key.h index c150528b63..77f346d608 100644 --- a/src/butil/thread_key.h +++ b/src/butil/thread_key.h @@ -18,6 +18,7 @@ #ifndef BUTIL_THREAD_KEY_H #define BUTIL_THREAD_KEY_H +#include #include #include #include diff --git a/src/bvar/detail/combiner.h b/src/bvar/detail/combiner.h index cae1b8ea8f..3007f50da8 100644 --- a/src/bvar/detail/combiner.h +++ b/src/bvar/detail/combiner.h @@ -233,7 +233,38 @@ friend class GlobalValue; ~AgentCombiner() { if (_id >= 0) { - clear_all_agents(); + // NOTE: We intentionally do NOT walk `_agents` here (e.g. via the + // previously existed `clear_all_agents()`). + // + // `Agent` instances live inside per-thread `ThreadBlock`s owned by + // `AgentGroup` and are destroyed when their owning thread exits + // (via `_destroy_tls_blocks`). At that point `~Agent` calls + // `combiner.lock()`; if the combiner has already started its + // destruction the `weak_ptr` is expired and the agent will skip + // `commit_and_erase`, leaving its `LinkNode` linked to this + // combiner's `_agents`. If we tried to traverse `_agents` here we + // could touch agent nodes whose `ThreadBlock` was just freed by + // a concurrent thread-exit, causing heap-use-after-free + // (see issue #2937 follow-up). + // + // It is safe to leave the list "dirty" because: + // * `butil::LinkedList` / `butil::LinkNode` have trivial + // destructors and never traverse on destruction, so tearing + // down `_agents` here does not dereference any agent node. + // * After this combiner is gone, every still-alive `Agent` will + // observe `combiner.expired() == true` in `~Agent` and skip + // `commit_and_erase`, so the dangling `prev_/next_` pointers + // in those agents are never read. + // * If the freed `_id` is later reused by a new combiner and the + // same TLS slot is taken, `get_or_create_tls_agent` will call + // `Agent::reset` and `Append` the agent into the new + // combiner's `_agents`. `LinkNode::InsertBefore` only writes + // `prev_/next_` (never reads their stale values), so the + // dangling pointers are safely overwritten. + // * `Agent::element` is destroyed together with the `ThreadBlock`, + // so any non-POD resource it holds is still released; if the + // agent slot is reused, `Agent::reset` will overwrite the + // element value before it is observed again. AgentGroup::destroy_agent(_id); _id = -1; } @@ -319,18 +350,31 @@ friend class GlobalValue; return agent; } - void clear_all_agents() { - butil::AutoLock guard(_lock); - // Resting agents is must because the agent object may be reused. - // Set element to be default-constructed so that if it's non-pod, - // internal allocations should be released. - for (butil::LinkNode* node = _agents.head(); node != _agents.end();) { - node->value()->reset(ElementTp(), NULL); - butil::LinkNode* const saved_next = node->next(); - node->RemoveFromList(); - node = saved_next; - } - } + // NOTE: `clear_all_agents()` is intentionally kept but no longer called + // from `~AgentCombiner` (see the long comment in `~AgentCombiner`). + // + // Calling it from the destructor is unsafe: by the time the destructor + // runs, agent weak_ptrs have already expired and `~Agent` will skip + // `commit_and_erase`; a concurrent thread-exit can therefore free the + // `ThreadBlock` (and the agents inside it) while we are still walking + // `_agents` here, which is a heap-use-after-free. + // + // The body is left around (commented out) for reference / future use -- + // do NOT re-enable it from `~AgentCombiner`. + // + // void clear_all_agents() { + // butil::AutoLock guard(_lock); + // // Resetting agents is a must because the agent object may be + // // reused. Set element to be default-constructed so that if it's + // // non-pod, internal allocations should be released. + // for (butil::LinkNode* node = _agents.head(); + // node != _agents.end();) { + // node->value()->reset(ElementTp(), NULL); + // butil::LinkNode* const saved_next = node->next(); + // node->RemoveFromList(); + // node = saved_next; + // } + // } const BinaryOp& op() const { return _op; } diff --git a/src/bvar/detail/percentile.h b/src/bvar/detail/percentile.h index bac729f162..e6acbf3a01 100644 --- a/src/bvar/detail/percentile.h +++ b/src/bvar/detail/percentile.h @@ -561,7 +561,23 @@ class Percentile { typedef VoidOp InvOp; typedef ReducerSampler sampler_type; - Percentile() = default; + Percentile() { + // babylon::ConcurrentSampler defaults every per-value-range + // bucket to capacity 30 (see _bucket_capacity[32] in + // babylon/concurrent/counter.h) and only grows it inside + // Percentile::reset() after observing how many samples + // arrived. That means the *first* collection round always + // observes badly under-sampled buckets, which destabilises + // mid-quantile estimates (e.g. p60/p70 on a uniform 1..10000 + // stream can land at ~5482/~6284 instead of ~6000/~7000 and + // make PercentileTest.add fail). Pre-size every bucket to the + // full SAMPLE_SIZE on construction so the very first reset() + // already sees representative samples. + for (size_t i = 0; i < NUM_INTERVALS; ++i) { + _concurrent_sampler.set_bucket_capacity(i, value_type::SAMPLE_SIZE); + } + } + DISALLOW_COPY_AND_MOVE(Percentile); ~Percentile() noexcept { if (NULL != _sampler) { diff --git a/test/BUILD.bazel b/test/BUILD.bazel index b68b3fa08a..d536197aee 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -13,30 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") +load("//bazel/tools:brpc_proto_library.bzl", "brpc_proto_library") load("@hedron_compile_commands//:refresh_compile_commands.bzl", "refresh_compile_commands") +load("//test:generate_unittests.bzl", "generate_unittests") +load("//test:root_runfiles.bzl", "root_runfiles") COPTS = [ - "-D__STDC_FORMAT_MACROS", - "-DBTHREAD_USE_FAST_PTHREAD_MUTEX", - "-D__const__=__unused__", - "-D_GNU_SOURCE", - "-DUSE_SYMBOLIZE", - "-DNO_TCMALLOC", - "-D__STDC_LIMIT_MACROS", - "-D__STDC_CONSTANT_MACROS", "-fPIC", "-Wno-unused-parameter", "-fno-omit-frame-pointer", "-fno-access-control", - "-DBAZEL_TEST=1", - "-DBVAR_NOT_LINK_DEFAULT_VARIABLES", - "-DUNIT_TEST", -] + select({ - "//bazel/config:brpc_with_glog": ["-DBRPC_WITH_GLOG=1"], - "//conditions:default": ["-DBRPC_WITH_GLOG=0"], -}) +] TEST_BUTIL_SOURCES = [ "at_exit_unittest.cc", @@ -148,26 +136,11 @@ TEST_BUTIL_SOURCES = [ ], }) -proto_library( - name = "test_proto", - srcs = glob( - [ - "*.proto", - ], - ), - strip_import_prefix = "/test", - visibility = ["//visibility:public"], - deps = [ - "//:brpc_idl_options_proto", - ] -) - -cc_proto_library( +brpc_proto_library( name = "cc_test_proto", + srcs = glob(["*.proto"]), + deps = ["//:brpc_idl_options_cc_proto"], visibility = ["//visibility:public"], - deps = [ - ":test_proto", - ], ) cc_library( @@ -184,8 +157,22 @@ cc_library( ], ) +# tcmalloc and AddressSanitizer both intercept malloc/free, so linking +# both produces undefined behaviour at runtime (typically a hard hang +# at process start, or random allocator errors). When `brpc_with_asan` +# is active we drop the @gperftools//:tcmalloc dep entirely; the +# gperftools_helper header above already no-ops gracefully when +# BRPC_ENABLE_CPU_PROFILER is not defined. +# +# Activation: `--config=asan` in .bazelrc also sets +# `--define=with_asan=true`, which flips this config_setting. +TCMALLOC_DEP_UNLESS_ASAN = select({ + "//bazel/config:brpc_with_asan": [], + "//conditions:default": ["@gperftools//:tcmalloc"], +}) + cc_test( - name = "butil_test", + name = "butil_unittests", srcs = TEST_BUTIL_SOURCES + [ "scoped_locale.h", "multiprocess_func_list.h", @@ -196,105 +183,82 @@ cc_test( ":cc_test_proto", ":sstream_workaround", ":gperftools_helper", - "//:brpc", + "//:butil", + "//:bthread", "@com_google_googletest//:gtest", ], ) - cc_test( - name = "bvar_test", - srcs = glob( - [ - "bvar_*_unittest.cpp", - ], - exclude = [ - "bvar_lock_timer_unittest.cpp", - "bvar_recorder_unittest.cpp", - ], - ), - copts = COPTS, + name = "bvar_unittests", + srcs = glob(["bvar_*_unittest.cpp"]), deps = [ ":sstream_workaround", ":gperftools_helper", "//:bvar", "@com_google_googletest//:gtest", - ], + "@com_google_googletest//:gtest_main", + ] + TCMALLOC_DEP_UNLESS_ASAN, + copts = COPTS, ) -cc_test( - name = "bthread_test", - srcs = glob( - [ - "bthread_*_unittest.cpp", - ], - exclude = [ - "bthread_cond_unittest.cpp", - "bthread_execution_queue_unittest.cpp", - "bthread_dispatcher_unittest.cpp", - "bthread_fd_unittest.cpp", - "bthread_mutex_unittest.cpp", - "bthread_setconcurrency_unittest.cpp", - # glog CHECK die with a fatal error - "bthread_key_unittest.cpp", - "bthread_butex_multi_tag_unittest.cpp", - "bthread_rwlock_unittest.cpp", - "bthread_semaphore_unittest.cpp", - ], - ), - copts = COPTS, +generate_unittests( + name = "bthread_unittests", + srcs = glob([ + "bthread*_unittest.cpp", + ]), deps = [ ":sstream_workaround", ":gperftools_helper", - "//:brpc", + "//:bthread", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "brpc_prometheus_test", - srcs = glob( - [ - "brpc_prometheus_*_unittest.cpp", - ], - ), + ] + TCMALLOC_DEP_UNLESS_ASAN, copts = COPTS, - deps = [ - ":cc_test_proto", - ":sstream_workaround", - "//:brpc", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - ], ) -cc_test( - name = "brpc_auto_concurrency_limiter_test", +# Expose unit-test data files (cert*, jsonout) at the runfiles workspace root +# so that tests can open them via plain relative paths like "cert1.crt". +# See test/root_runfiles.bzl for why a genrule does NOT work here. +root_runfiles( + name = "test_runfiles_root_data", srcs = [ - "brpc_auto_concurrency_limiter_unittest.cpp", - ], - copts = COPTS, - deps = [ - "//:brpc", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", + "cert1.crt", + "cert1.key", + "cert2.crt", + "cert2.key", + "jsonout", ], ) -cc_test( - name = "brpc_redis_cluster_test", - srcs = [ - "brpc_redis_cluster_unittest.cpp", - ], - copts = COPTS, +generate_unittests( + name = "brpc_unittests", + srcs = glob([ + "brpc_*_unittest.cpp", + ]), deps = [ ":sstream_workaround", ":gperftools_helper", "//:brpc", + ":cc_test_proto", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", + ] + TCMALLOC_DEP_UNLESS_ASAN, + copts = COPTS, + # Place cert*/jsonout at /_main/ (the cwd of `bazel test`) + # via :test_runfiles_root_data so test code can open them with plain + # relative paths like "cert1.crt". + data = [ + ":test_runfiles_root_data", ], + # brpc_redis_unittest forks a real redis-server located via PATH. Tag it + # "external" so bazel always re-runs it (never serves a cached pass that + # actually skipped) and "local" so it runs outside the sandbox, where the + # apt-installed redis-server is visible and loopback works. The CI bazel + # jobs install redis-server via the install-essential-dependencies action. + per_test_tags = { + "brpc_redis_unittest.cpp": ["external", "local"], + }, ) refresh_compile_commands( @@ -303,9 +267,10 @@ refresh_compile_commands( # For example, specify a dict of targets and their arguments: targets = { "//:brpc": "", - ":bvar_test": "", - ":bthread_test": "", - ":butil_test": "", + ":butil_unittests": "", + ":bvar_unittests": "", + ":bthread_unittests": "", + ":brpc_unittests": "", }, # For more details, feel free to look into refresh_compile_commands.bzl if you want. ) diff --git a/test/brpc_adaptive_class_unittest.cpp b/test/brpc_adaptive_class_unittest.cpp index c0d76c0044..18128f2a5e 100644 --- a/test/brpc_adaptive_class_unittest.cpp +++ b/test/brpc_adaptive_class_unittest.cpp @@ -31,13 +31,13 @@ const std::string kPooled = "PoOled"; TEST(AdaptiveMaxConcurrencyTest, ShouldConvertCorrectly) { brpc::AdaptiveMaxConcurrency amc(0); - EXPECT_EQ(brpc::AdaptiveMaxConcurrency::UNLIMITED, amc.type()); - EXPECT_EQ(brpc::AdaptiveMaxConcurrency::UNLIMITED, amc.value()); + EXPECT_EQ(brpc::AdaptiveMaxConcurrency::UNLIMITED(), amc.type()); + EXPECT_EQ(brpc::AdaptiveMaxConcurrency::UNLIMITED(), amc.value()); EXPECT_EQ(0, int(amc)); - EXPECT_TRUE(amc == brpc::AdaptiveMaxConcurrency::UNLIMITED); + EXPECT_TRUE(amc == brpc::AdaptiveMaxConcurrency::UNLIMITED()); amc = 10; - EXPECT_EQ(brpc::AdaptiveMaxConcurrency::CONSTANT, amc.type()); + EXPECT_EQ(brpc::AdaptiveMaxConcurrency::CONSTANT(), amc.type()); EXPECT_EQ("10", amc.value()); EXPECT_EQ(10, int(amc)); EXPECT_EQ(amc, "10"); diff --git a/test/brpc_alpn_protocol_unittest.cpp b/test/brpc_alpn_protocol_unittest.cpp index 21aada70d6..9c1506710b 100644 --- a/test/brpc_alpn_protocol_unittest.cpp +++ b/test/brpc_alpn_protocol_unittest.cpp @@ -48,7 +48,7 @@ class EchoServerImpl : public test::EchoService { response->set_message(request->message()); brpc::Controller* cntl = static_cast(controller); - LOG(NOTICE) << "protocol:" << cntl->request_protocol(); + LOG(INFO) << "protocol:" << cntl->request_protocol(); } }; @@ -65,9 +65,9 @@ class ALPNTest : public testing::Test { ssl_options->default_cert.private_key = "cert1.key"; ssl_options->alpns = "http, h2, baidu_std"; - EXPECT_EQ(0, _server.AddService(&_echo_server_impl, + ASSERT_EQ(0, _server.AddService(&_echo_server_impl, brpc::SERVER_DOESNT_OWN_SERVICE)); - EXPECT_EQ(0, _server.Start(FLAGS_listen_addr.data(), &server_options)); + ASSERT_EQ(0, _server.Start(FLAGS_listen_addr.data(), &server_options)); } virtual void TearDown() override { diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp index 2004767470..db6e2ac777 100644 --- a/test/brpc_channel_unittest.cpp +++ b/test/brpc_channel_unittest.cpp @@ -298,14 +298,10 @@ class ChannelTest : public ::testing::Test{ cntl->_current_call.sending_sock.reset(ptr.release()); cntl->_server = &ts->_dummy; - google::protobuf::Closure* done = - brpc::NewCallback< - int64_t, brpc::Controller*, - brpc::RpcPBMessages*, - const brpc::Server*, + google::protobuf::Closure* done = brpc::NewCallback< + int64_t, brpc::Controller*, brpc::RpcPBMessages*, const brpc::Server*, brpc::MethodStatus*, int64_t, std::shared_ptr>( - &brpc::policy::SendRpcResponse, - meta.correlation_id(), cntl, + &brpc::policy::SendRpcResponse, meta.correlation_id(), cntl, messages, &ts->_dummy, NULL, -1, nullptr); ts->_svc.CallMethod(method, cntl, req, res, done); } @@ -1491,6 +1487,57 @@ class ChannelTest : public ::testing::Test{ EXPECT_EQ(cntl.response_attachment().to_string(), "123"); StopAndJoin(); } + + void TestBackupRequestSelectiveResponseRace() { + ASSERT_EQ(0, StartAccept(_ep)); + + const size_t NCHANS = 8; + brpc::SelectiveChannel channel; + ASSERT_EQ(0, channel.Init("rr", NULL)); + for (size_t i = 0; i < NCHANS; ++i) { + brpc::Channel* subchan = new brpc::Channel; + SetUpChannel(subchan, false, false); + ASSERT_EQ(0, channel.AddChannel(subchan, NULL)) << "i=" << i; + } + + const int kRounds = 150; + const int kCodeListSize = 20000; + std::atomic call_cnt(0); + _svc.SetMockFunc([&call_cnt](google::protobuf::RpcController*, + const ::test::EchoRequest*, + ::test::EchoResponse* res, + google::protobuf::Closure*) { + const int seen = call_cnt.fetch_add(1, std::memory_order_relaxed); + const bool slow = ((seen & 1) == 0); + if (slow) { + bthread_usleep(1500); + } + res->clear_code_list(); + const int base = slow ? 1000000 : 2000000; + for (int i = 0; i < kCodeListSize; ++i) { + res->add_code_list(base + i); + } + res->set_message(slow ? "slow" : "fast"); + }); + + for (int round = 0; round < kRounds; ++round) { + brpc::Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + cntl.set_backup_request_ms(1); + cntl.set_timeout_ms(3000); + CallMethod(&channel, &cntl, &req, &res, true); + ASSERT_FALSE(cntl.Failed()) << "round=" << round + << " err=" << cntl.ErrorText(); + ASSERT_EQ(kCodeListSize, res.code_list_size()) << "round=" << round; + ASSERT_TRUE(res.message() == "slow" || res.message() == "fast") + << "round=" << round; + } + + EXPECT_EQ(kRounds * 2, call_cnt.load(std::memory_order_relaxed)); + StopAndJoin(); + } void TestCloseFD(bool single_server, bool async, bool short_connection) { std::cout << " *** single=" << single_server @@ -2787,6 +2834,10 @@ TEST_F(ChannelTest, backuprequest_selective) { } } +TEST_F(ChannelTest, backuprequest_selective_response_race) { + TestBackupRequestSelectiveResponseRace(); +} + TEST_F(ChannelTest, close_fd) { for (int i = 0; i <= 1; ++i) { // Flag SingleServer for (int j = 0; j <= 1; ++j) { // Flag Asynchronous diff --git a/test/brpc_http_rpc_protocol_unittest.cpp b/test/brpc_http_rpc_protocol_unittest.cpp index b75a6da3a4..c7022ed283 100644 --- a/test/brpc_http_rpc_protocol_unittest.cpp +++ b/test/brpc_http_rpc_protocol_unittest.cpp @@ -20,7 +20,6 @@ // Date: Sun Jul 13 15:04:18 CST 2014 #include -#include #include #include #include @@ -60,6 +59,7 @@ DECLARE_bool(rpc_dump); DECLARE_string(rpc_dump_dir); DECLARE_int32(rpc_dump_max_requests_in_one_file); DECLARE_bool(allow_chunked_length); +DECLARE_int32(max_connection_pool_size); extern bvar::CollectorSpeedLimit g_rpc_dump_sl; } @@ -163,6 +163,23 @@ class HttpTest : public ::testing::Test{ EXPECT_EQ(expect, brpc::policy::VerifyHttpRequest(msg)); } + void VerifyMessageFromLocalPort(brpc::InputMessageBase* msg, + bool expect, + int local_port) { + brpc::SocketId id; + brpc::SocketOptions options; + options.fd = dup(_pipe_fds[1]); + EXPECT_GE(options.fd, 0); + options.local_side = butil::EndPoint(butil::my_ip(), local_port); + EXPECT_EQ(0, brpc::Socket::Create(options, &id)); + + brpc::SocketUniquePtr socket; + EXPECT_EQ(0, brpc::Socket::Address(id, &socket)); + socket->ReAddress(&msg->_socket); + msg->_arg = &_server; + EXPECT_EQ(expect, brpc::policy::VerifyHttpRequest(msg)); + } + void ProcessMessage(void (*process)(brpc::InputMessageBase*), brpc::InputMessageBase* msg, bool set_eof) { if (msg->_socket == NULL) { @@ -225,6 +242,33 @@ class HttpTest : public ::testing::Test{ return msg; } + void InitHttpPooledChannel(brpc::Channel* channel, + const butil::EndPoint& ep, + const std::string& connection_group) { + brpc::ChannelOptions options; + options.protocol = brpc::PROTOCOL_HTTP; + options.connection_type = brpc::CONNECTION_TYPE_POOLED; + options.connection_group = connection_group; + options.max_retry = 0; + ASSERT_EQ(0, channel->Init(ep, &options)); + } + + void CallVersion(brpc::Channel* channel, brpc::Controller* cntl) { + cntl->http_request().uri() = "/status"; + cntl->http_request().set_method(brpc::HTTP_METHOD_GET); + channel->CallMethod(NULL, cntl, NULL, NULL, NULL); + } + + void CallHttpEcho(brpc::Channel* channel, brpc::Controller* cntl) { + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + cntl->http_request().uri() = "/EchoService/Echo"; + cntl->http_request().set_method(brpc::HTTP_METHOD_POST); + cntl->http_request().set_content_type("application/json"); + channel->CallMethod(NULL, cntl, &req, &res, NULL); + } + brpc::policy::HttpContext* MakeResponseMessage(int code) { brpc::policy::HttpContext* msg = new brpc::policy::HttpContext(false); @@ -300,6 +344,14 @@ class HttpTest : public ::testing::Test{ MyAuthenticator _auth; }; +int AllocateFreePortOrDie() { + butil::fd_guard fd(tcp_listen(butil::EndPoint(butil::my_ip(), 0))); + EXPECT_GE(fd, 0); + butil::EndPoint point; + EXPECT_EQ(0, butil::get_local_side(fd, &point)); + return point.port; +} + TEST_F(HttpTest, indenting_ostream) { std::ostringstream os1; brpc::IndentingOStream is1(os1, 2); @@ -355,6 +407,12 @@ TEST_F(HttpTest, verify_request) { } { brpc::policy::HttpContext* msg = MakeGetRequestMessage("/status"); + VerifyMessage(msg, false); + msg->Destroy(); + } + { + brpc::policy::HttpContext* msg = MakeGetRequestMessage("/status"); + msg->header().SetHeader("Authorization", MOCK_CREDENTIAL); VerifyMessage(msg, true); msg->Destroy(); } @@ -379,6 +437,97 @@ TEST_F(HttpTest, verify_request) { } } +TEST_F(HttpTest, verify_builtin_request_on_internal_port) { + _server._options.internal_port = 9527; + { + brpc::policy::HttpContext* msg = MakeGetRequestMessage("/status"); + VerifyMessage(msg, false); + msg->Destroy(); + } + { + brpc::policy::HttpContext* msg = MakeGetRequestMessage("/status"); + VerifyMessageFromLocalPort(msg, true, _server._options.internal_port); + msg->Destroy(); + } +} + +TEST_F(HttpTest, builtin_auth_policy_on_public_and_internal_port) { + const int saved_max_connection_pool_size = brpc::FLAGS_max_connection_pool_size; + brpc::FLAGS_max_connection_pool_size = 1; + + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:0", &ep)); + + brpc::Server server; + MyEchoService svc; + MyAuthenticator auth; + brpc::ServerOptions options; + options.auth = &auth; + options.internal_port = AllocateFreePortOrDie(); + ASSERT_EQ(0, server.AddService(&svc, brpc::SERVER_DOESNT_OWN_SERVICE)); + ASSERT_EQ(0, server.Start(ep, &options)); + ep = server.listen_address(); + const butil::EndPoint internal_ep(ep.ip, options.internal_port); + + { + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.protocol = brpc::PROTOCOL_HTTP; + copt.max_retry = 0; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + brpc::Controller cntl; + cntl.http_request().uri() = "/status"; + cntl.http_request().set_method(brpc::HTTP_METHOD_GET); + chan.CallMethod(NULL, &cntl, NULL, NULL, NULL); + ASSERT_TRUE(cntl.Failed()); + ASSERT_EQ(brpc::EHTTP, cntl.ErrorCode()) << cntl.ErrorText(); + ASSERT_EQ(brpc::HTTP_STATUS_FORBIDDEN, cntl.http_response().status_code()); + } + + { + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.protocol = brpc::PROTOCOL_HTTP; + copt.max_retry = 0; + ASSERT_EQ(0, chan.Init(internal_ep, &copt)); + + brpc::Controller cntl; + cntl.http_request().uri() = "/status"; + cntl.http_request().set_method(brpc::HTTP_METHOD_GET); + chan.CallMethod(NULL, &cntl, NULL, NULL, NULL); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + ASSERT_EQ(brpc::HTTP_STATUS_OK, cntl.http_response().status_code()); + } + + { + const std::string connection_group = "builtin-auth-policy"; + brpc::Channel builtin_channel; + brpc::Channel protected_channel; + brpc::ChannelOptions copt; + copt.protocol = brpc::PROTOCOL_HTTP; + copt.connection_type = brpc::CONNECTION_TYPE_POOLED; + copt.connection_group = connection_group; + copt.max_retry = 0; + ASSERT_EQ(0, builtin_channel.Init(ep, &copt)); + ASSERT_EQ(0, protected_channel.Init(ep, &copt)); + + brpc::Controller builtin_cntl; + CallVersion(&builtin_channel, &builtin_cntl); + ASSERT_TRUE(builtin_cntl.Failed()); + ASSERT_EQ(brpc::EHTTP, builtin_cntl.ErrorCode()) << builtin_cntl.ErrorText(); + ASSERT_EQ(brpc::HTTP_STATUS_FORBIDDEN, builtin_cntl.http_response().status_code()); + + brpc::Controller protected_cntl; + CallHttpEcho(&protected_channel, &protected_cntl); + ASSERT_TRUE(protected_cntl.Failed()); + } + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); + brpc::FLAGS_max_connection_pool_size = saved_max_connection_pool_size; +} + TEST_F(HttpTest, process_request_failed_socket) { brpc::policy::HttpContext* msg = MakePostRequestMessage("/EchoService/Echo"); _socket->SetFailed(); @@ -1128,7 +1277,7 @@ class UploadServiceImpl : public ::test::UploadService { private: void check_header(brpc::Controller* cntl) { const std::string* test_header = cntl->http_request().GetHeader(TEST_PROGRESSIVE_HEADER); - GOOGLE_CHECK_NOTNULL(test_header); + CHECK(test_header != NULL); CHECK_EQ(*test_header, TEST_PROGRESSIVE_HEADER_VAL); } }; @@ -1808,9 +1957,11 @@ TEST_F(HttpTest, proto_text_content_type) { cntl.http_request().set_method(brpc::HTTP_METHOD_POST); cntl.http_request().uri() = "/EchoService/Echo"; cntl.http_request().set_content_type("application/proto-text"); - cntl.request_attachment().append(req.Utf8DebugString()); + std::string req_text; + ASSERT_TRUE(google::protobuf::TextFormat::PrintToString(req, &req_text)); + cntl.request_attachment().append(req_text); channel.CallMethod(nullptr, &cntl, nullptr, nullptr, nullptr); - ASSERT_FALSE(cntl.Failed()); + ASSERT_FALSE(cntl.Failed()) << req_text; ASSERT_EQ("application/proto-text", cntl.http_response().content_type()); ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( cntl.response_attachment().to_string(), &res)); diff --git a/test/brpc_protobuf_json_unittest.cpp b/test/brpc_protobuf_json_unittest.cpp index b9289b20d5..aa73fbd8f2 100644 --- a/test/brpc_protobuf_json_unittest.cpp +++ b/test/brpc_protobuf_json_unittest.cpp @@ -526,7 +526,12 @@ TEST_F(ProtobufJsonTest, json_to_pb_unbounded_recursion) { std::string error; bool ret = json2pb::ProtoJsonToProtoMessage(nested_json, &msg, options, &error); ASSERT_FALSE(ret); - ASSERT_EQ("INVALID_ARGUMENT:Message too deep. Max recursion depth reached for key 'child'", error); + ASSERT_NE(std::string::npos, error.find("INVALID_ARGUMENT")) + << "error=" << error; + ASSERT_TRUE(error.find("recursion") != std::string::npos || + error.find("nested") != std::string::npos || + error.find("too deep") != std::string::npos) + << "error=" << error; } } diff --git a/test/brpc_rdma_unittest.cpp b/test/brpc_rdma_unittest.cpp index ccb280f1c8..43c6edfd12 100644 --- a/test/brpc_rdma_unittest.cpp +++ b/test/brpc_rdma_unittest.cpp @@ -24,7 +24,6 @@ #include #include "butil/endpoint.h" #include "butil/fd_guard.h" -#include "butil/fd_utility.h" #include "butil/iobuf.h" #include "butil/sys_byteorder.h" #include "butil/files/temp_file.h" @@ -36,15 +35,15 @@ #include "brpc/errno.pb.h" #include "brpc/parallel_channel.h" #include "brpc/selective_channel.h" +#include "brpc/rdma_transport.h" #include "brpc/rdma/block_pool.h" #include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_handshake.h" +#include "brpc/rdma/rdma_handshake.pb.h" #include "brpc/rdma/rdma_helper.h" #include "echo.pb.h" static const int PORT = 8713; -static const size_t RDMA_HELLO_MSG_LEN = 40; -static uint16_t RDMA_HELLO_VERSION = 2; -static uint16_t RDMA_IMPL_VERSION = 1; using namespace brpc; @@ -56,23 +55,13 @@ DEFINE_bool(rdma_test_enable, false, "Enable tests requring rdma runtime."); namespace rdma { -struct HelloMessage { - void Serialize(void* data) const; - void Deserialize(void* data); - - uint16_t msg_len; - uint16_t hello_ver; - uint16_t impl_ver; - uint32_t block_size; - uint16_t sq_size; - uint16_t rq_size; - uint16_t lid; - ibv_gid gid; - uint32_t qp_num; -}; +extern const uint16_t RDMA_HELLO_V2_VERSION; +extern const uint16_t RDMA_IMPL_V2_VERSION; DECLARE_bool(rdma_trace_verbose); DECLARE_int32(rdma_memory_pool_max_regions); +DECLARE_int32(rdma_client_handshake_version); + extern ibv_cq* (*IbvCreateCq)(ibv_context*, int, void*, ibv_comp_channel*, int); extern int (*IbvDestroyCq)(ibv_cq*); extern ibv_qp* (*IbvCreateQp)(ibv_pd*, ibv_qp_init_attr*); @@ -81,8 +70,8 @@ extern int (*IbvQueryQp)(ibv_qp*, ibv_qp_attr*, ibv_qp_attr_mask, ibv_qp_init_at extern int (*IbvDestroyQp)(ibv_qp*); extern butil::atomic g_rdma_available; extern bool g_skip_rdma_init; -} -} +} // namespace rdma +} // namespace brpc static std::string g_ip = "127.0.0.1"; static butil::EndPoint g_ep; @@ -109,7 +98,7 @@ class MyEchoService : public ::test::EchoService { LOG(INFO) << "sleep " << req->sleep_us() << "us..."; bthread_usleep(req->sleep_us()); } - res->set_message(req->message()); + res->set_message("MyEchoService"); if (req->code() != 0) { res->add_code_list(req->code()); } @@ -136,11 +125,12 @@ class RdmaTest : public ::testing::Test { rdma::DumpMemoryPoolInfo(std::cout); } -private: +protected: void StartServer(bool use_rdma = true) { ServerOptions options; - options.use_rdma = use_rdma; - options.idle_timeout_sec = 10; + options.enabled_protocols = "baidu_std"; + options.socket_mode = use_rdma ? SOCKET_MODE_RDMA : SOCKET_MODE_TCP; + options.idle_timeout_sec = 5; options.max_concurrency = 0; options.internal_port = -1; EXPECT_EQ(0, _server.Start(PORT, &options)); @@ -171,6 +161,29 @@ class RdmaTest : public ::testing::Test { MyEchoService _svc; }; +// Parameterized fixture used by upper-layer RPC tests that have no +// dependency on the handshake wire format. The parameter is the +// client-side handshake protocol version (FLAGS_rdma_client_handshake_version), +// so every TEST_P below is automatically executed once per supported +// version. Add a new version to INSTANTIATE_TEST_SUITE_P at the bottom +// of this file and these RPC tests will gain coverage for free. +class RdmaRpcTest : public RdmaTest, + public ::testing::WithParamInterface { +protected: + void SetUp() override { + RdmaTest::SetUp(); + _saved_handshake_version = rdma::FLAGS_rdma_client_handshake_version; + rdma::FLAGS_rdma_client_handshake_version = GetParam(); + } + void TearDown() override { + rdma::FLAGS_rdma_client_handshake_version = _saved_handshake_version; + RdmaTest::TearDown(); + } + +private: + int _saved_handshake_version = 2; +}; + TEST_F(RdmaTest, client_close_before_hello_send) { StartServer(); @@ -184,7 +197,7 @@ TEST_F(RdmaTest, client_close_before_hello_send) { ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg Socket* s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -205,13 +218,13 @@ TEST_F(RdmaTest, client_hello_msg_invalid_magic_str) { ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg Socket* s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); - uint8_t data[RDMA_HELLO_MSG_LEN]; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; memcpy(data, "PRPC", 4); // send as normal baidu_std protocol ASSERT_EQ(4, write(sockfd, data, 4)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); StopServer(); } @@ -231,11 +244,11 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RD", 2); ASSERT_EQ(2, write(sockfd1, data, 2)); // break in magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -245,11 +258,11 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // break after magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd2); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -259,12 +272,12 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd3, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); memset(data + 4, 0, 4); ASSERT_EQ(8, write(sockfd3, data, 8)); // break after magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd3); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -280,18 +293,18 @@ TEST_F(RdmaTest, client_hello_msg_invalid_len) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - uint8_t data[RDMA_HELLO_MSG_LEN]; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; butil::fd_guard sockfd1(socket(AF_INET, SOCK_STREAM, 0)); ASSERT_TRUE(sockfd1 >= 0); ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memset(data + 4, 0, 36); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); // Write invalid length. usleep(100000); // wait for server to handle the msg @@ -302,11 +315,11 @@ TEST_F(RdmaTest, client_hello_msg_invalid_len) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint16_t len = butil::HostToNet16(35); memcpy(data + 4, &len, sizeof(len)); memset(data + 6, 0, 34); @@ -325,8 +338,8 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - uint8_t data[RDMA_HELLO_MSG_LEN]; - uint16_t len = butil::HostToNet16(RDMA_HELLO_MSG_LEN); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + uint16_t len = butil::HostToNet16(rdma::v2_wire::HELLO_MSG_LEN_MIN); uint16_t ver = butil::HostToNet16(1); butil::fd_guard sockfd1(socket(AF_INET, SOCK_STREAM, 0)); @@ -334,22 +347,29 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data + 4, &len, 2); memset(data + 6, 0, 34); memcpy(data + 6, &ver, 2); // hello_ver == 1, impl_ver == 0 - ASSERT_EQ(36, write(sockfd1, data, 36)); + // Write the 36B base starting at data + 4 (NOT data). Pre-Step-1 this + // UT mistakenly wrote `data, 36` which included the leftover "RDMA" + // magic at data[0..4); the server parsed it as msg_len = 0x5244 and + // happened to fall through to NegotiationValid (which then failed on + // hello_ver). Now that Step 1 enforces a HELLO_MSG_LEN_MAX upper bound, + // such an oversized msg_len would be rejected before reaching the + // version check, breaking the intent of this UT. + ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); uint32_t flags = 0; ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd1.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -359,21 +379,23 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); memcpy(data + 8, &ver, 2); // hello_ver == 0, impl_ver == 1 - ASSERT_EQ(36, write(sockfd2, data, 36)); + // See comment above on `write(sockfd1, data + 4, 36)` for why we + // write from data + 4 instead of data. + ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd2.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -390,11 +412,11 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { addr.sin_port = htons(PORT); Socket* s = NULL; uint32_t flags = butil::HostToNet32(0); - rdma::HelloMessage msg{}; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 10; msg.rq_size = 16; @@ -406,17 +428,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd1.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -431,17 +453,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd2.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -456,17 +478,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd3, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd3, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd3, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd3, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd3.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -482,11 +504,11 @@ TEST_F(RdmaTest, client_close_after_qp_build) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -500,10 +522,10 @@ TEST_F(RdmaTest, client_close_after_qp_build) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(40, write(sockfd1, data, 40)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -519,11 +541,11 @@ TEST_F(RdmaTest, client_close_during_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -537,17 +559,17 @@ TEST_F(RdmaTest, client_close_during_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -563,11 +585,11 @@ TEST_F(RdmaTest, client_close_after_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -581,18 +603,18 @@ TEST_F(RdmaTest, client_close_after_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(0); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -602,17 +624,17 @@ TEST_F(RdmaTest, client_close_after_ack_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd2); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -628,11 +650,11 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -646,17 +668,17 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(0); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -666,17 +688,17 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -690,7 +712,7 @@ TEST_F(RdmaTest, server_miss_before_hello_send) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -706,7 +728,7 @@ TEST_F(RdmaTest, server_miss_before_hello_send) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); @@ -721,7 +743,7 @@ TEST_F(RdmaTest, server_close_before_hello_send) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -737,15 +759,15 @@ TEST_F(RdmaTest, server_close_before_hello_send) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -757,7 +779,7 @@ TEST_F(RdmaTest, server_miss_during_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -773,12 +795,12 @@ TEST_F(RdmaTest, server_miss_during_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(2, write(acc_fd, "RD", 2)); usleep(100000); bthread_id_join(cntl.call_id()); @@ -792,7 +814,7 @@ TEST_F(RdmaTest, server_close_during_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -808,17 +830,17 @@ TEST_F(RdmaTest, server_close_during_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(2, write(acc_fd, "RD", 2)); usleep(100000); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -830,7 +852,7 @@ TEST_F(RdmaTest, server_hello_invalid_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -846,15 +868,15 @@ TEST_F(RdmaTest, server_hello_invalid_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "ABCD", 4)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); @@ -866,7 +888,7 @@ TEST_F(RdmaTest, server_miss_during_hello_msg) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -882,12 +904,12 @@ TEST_F(RdmaTest, server_miss_during_hello_msg) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "RDMA", 4)); ASSERT_EQ(2, write(acc_fd, "00", 2)); bthread_id_join(cntl.call_id()); @@ -901,7 +923,7 @@ TEST_F(RdmaTest, server_close_during_hello_msg) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -917,17 +939,17 @@ TEST_F(RdmaTest, server_close_during_hello_msg) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "RDMA", 4)); ASSERT_EQ(2, write(acc_fd, "00", 2)); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -939,7 +961,7 @@ TEST_F(RdmaTest, server_hello_invalid_msg_len) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -955,19 +977,19 @@ TEST_F(RdmaTest, server_hello_invalid_msg_len) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); memcpy(data, "RDMA", 4); uint16_t len = butil::HostToNet16(35); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); @@ -979,7 +1001,7 @@ TEST_F(RdmaTest, server_hello_invalid_version) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -995,19 +1017,19 @@ TEST_F(RdmaTest, server_hello_invalid_version) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); memcpy(data, "RDMA", 4); - uint16_t len = butil::HostToNet16(RDMA_HELLO_MSG_LEN); + uint16_t len = butil::HostToNet16(rdma::v2_wire::HELLO_MSG_LEN_MIN); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(0, butil::NetToHost32(*tmp)); @@ -1022,7 +1044,7 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1038,15 +1060,15 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; msg.hello_ver = 1; msg.impl_ver = 1; msg.sq_size = 0; @@ -1056,10 +1078,10 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(0, butil::NetToHost32(*tmp)); @@ -1074,7 +1096,7 @@ TEST_F(RdmaTest, server_miss_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1090,17 +1112,17 @@ TEST_F(RdmaTest, server_miss_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1108,10 +1130,10 @@ TEST_F(RdmaTest, server_miss_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(1, butil::NetToHost32(*tmp)); @@ -1126,7 +1148,7 @@ TEST_F(RdmaTest, server_close_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1142,17 +1164,17 @@ TEST_F(RdmaTest, server_close_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1160,10 +1182,10 @@ TEST_F(RdmaTest, server_close_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(1, butil::NetToHost32(*tmp)); @@ -1179,7 +1201,7 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1195,17 +1217,17 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1213,23 +1235,528 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); } + +TEST_F(RdmaTest, v2_client_hello_bytes_baseline) { + butil::fd_guard sockfd(butil::tcp_listen(g_ep)); + EXPECT_TRUE(sockfd >= 0); + + Channel channel; + ChannelOptions chan_options; + chan_options.socket_mode = SOCKET_MODE_RDMA; + chan_options.connect_timeout_ms = 500; + chan_options.timeout_ms = 500; + chan_options.max_retry = 0; + ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); + + Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + google::protobuf::Closure* done = DoNothing(); + ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); + + usleep(100000); + SocketUniquePtr s; + ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); + + butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); + ASSERT_TRUE(acc_fd >= 0); + + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + + // [0..4) magic + ASSERT_EQ(0, memcmp(data, "RDMA", 4)); + // [4..6) msg_len, big-endian uint16 == 40 + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, + (size_t)(((uint16_t)data[4] << 8) | (uint16_t)data[5])); + // [6..8) hello_ver, big-endian uint16 == rdma::RDMA_HELLO_V2_VERSION + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, + (uint16_t)(((uint16_t)data[6] << 8) | (uint16_t)data[7])); + // [8..10) impl_ver, big-endian uint16 == rdma::RDMA_IMPL_V2_VERSION + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, + (uint16_t)(((uint16_t)data[8] << 8) | (uint16_t)data[9])); + + rdma::v2_wire::HelloMessage msg{}; + msg.Deserialize(data + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, msg.msg_len); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, msg.hello_ver); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, msg.impl_ver); + + bthread_id_join(cntl.call_id()); +} + +TEST_F(RdmaTest, v2_server_hello_bytes_baseline) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Send a well-formed v2 hello so the server enters S_ACK_WAIT. + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + memcpy(data, "RDMA", 4); + msg.Serialize(data + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(sockfd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Read server's reply hello and assert its byte-level layout. + uint8_t reply[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(sockfd, reply, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + + ASSERT_EQ(0, memcmp(reply, "RDMA", 4)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, + (size_t)(((uint16_t)reply[4] << 8) | (uint16_t)reply[5])); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, + (uint16_t)(((uint16_t)reply[6] << 8) | (uint16_t)reply[7])); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, + (uint16_t)(((uint16_t)reply[8] << 8) | (uint16_t)reply[9])); + + rdma::v2_wire::HelloMessage reply_msg{}; + reply_msg.Deserialize(reply + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, reply_msg.msg_len); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, reply_msg.hello_ver); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, reply_msg.impl_ver); + + // Drive the server into FALLBACK_TCP via ACK flags=0 so the test ends + // cleanly without requiring real RDMA hardware. + uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ(sizeof(flags), write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v2_server_drains_tail_then_reads_ack) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Build a v2 hello with msg_len = 48 (40 base + 8B zero tail). + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = 48; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t buf[48]; + memcpy(buf, "RDMA", 4); + msg.Serialize(buf + 4); + memset(buf + 40, 0x00, 8); // 8B zero tail + ASSERT_EQ(48, write(sockfd, buf, 48)); + usleep(100000); + + // Send the real ACK (flags=1 = ACK_MSG_RDMA_OK). + uint32_t flags = butil::HostToNet32(1); + ASSERT_EQ(sizeof(flags), write(sockfd, &flags, sizeof(flags))); + usleep(100000); + + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v2_server_rejects_oversized_msg_len) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Build a v2 hello with msg_len = 4097 (HELLO_MSG_LEN_MAX + 1). + // We only send the 40B base; the server must reject before reading + // (and definitely before attempting to drain) any "tail". + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = 4097; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t buf[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + memcpy(buf, "RDMA", 4); + msg.Serialize(buf + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(sockfd, buf, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + usleep(100000); + + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + usleep(100000); + + StopServer(); +} + +// RAII for FLAGS_rdma_client_handshake_version: lets us flip the +// client-side handshake version for a single test and restore it on +// scope exit so subsequent tests stay on the v2 default. +class HandshakeVersionFlag { +public: + explicit HandshakeVersionFlag(int v) + : _saved(rdma::FLAGS_rdma_client_handshake_version) { + rdma::FLAGS_rdma_client_handshake_version = v; + } + ~HandshakeVersionFlag() { + rdma::FLAGS_rdma_client_handshake_version = _saved; + } +private: + int _saved; +}; + +// Build a v3 wire packet from an RdmaHello: "RDM3" + pb_size_be + body. +std::string MakeV3Packet(const rdma::RdmaHello& msg) { + std::string body; + EXPECT_TRUE(msg.SerializeToString(&body)); + std::string packet; + packet.reserve(4 + 4 + body.size()); + packet.append("RDM3", 4); + uint32_t pb_size_be = + butil::HostToNet32(static_cast(body.size())); + packet.append(reinterpret_cast(&pb_size_be), 4); + packet.append(body); + return packet; +} + +// Build a fully-valid RdmaHello: all 6 required fields are set, with +// values that pass RdmaHelloV3Wire::RdmaHelloValid(). +// - block_size = 8192 (>= MIN_BLOCK_SIZE) +// - sq_size / rq_size = 16 (>= MIN_QP_SIZE) +// - gid = exactly 16B (sizeof(ibv_gid)) +// - qp_num = 0 (allowed because g_skip_rdma_init in UT) +rdma::RdmaHello MakeValidV3Hello() { + rdma::RdmaHello msg; + msg.set_block_size(8192); + msg.set_sq_size(16); + msg.set_rq_size(16); + msg.set_lid(0); + ibv_gid gid = rdma::GetRdmaGid(); + msg.set_gid(std::string(reinterpret_cast(gid.raw), + sizeof(gid.raw))); + msg.set_qp_num(0); + return msg; +} + + +TEST_F(RdmaTest, v3_client_hello_bytes_baseline) { + HandshakeVersionFlag _hsv(3); + + butil::fd_guard sockfd(butil::tcp_listen(g_ep)); + EXPECT_TRUE(sockfd >= 0); + + Channel channel; + ChannelOptions chan_options; + chan_options.socket_mode = SOCKET_MODE_RDMA; + chan_options.connect_timeout_ms = 500; + chan_options.timeout_ms = 500; + chan_options.max_retry = 0; + ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); + + Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + google::protobuf::Closure* done = DoNothing(); + ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); + + butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); + ASSERT_TRUE(acc_fd >= 0); + + // [0..4) magic "RDM3" + uint8_t magic[4]; + ASSERT_EQ(4, read(acc_fd, magic, 4)); + ASSERT_EQ(0, memcmp(magic, "RDM3", 4)); + + // [4..8) pb_size, big-endian uint32, must be in (0, 4096] + uint8_t size_buf[4]; + ASSERT_EQ(4, read(acc_fd, size_buf, 4)); + uint32_t pb_size = + butil::NetToHost32(*reinterpret_cast(size_buf)); + ASSERT_GT(pb_size, 0u); + ASSERT_LE(pb_size, 4096u); + + // [8..8+pb_size) RdmaHello protobuf body. + std::string body(pb_size, '\0'); + ASSERT_EQ((ssize_t)pb_size, read(acc_fd, &body[0], pb_size)); + rdma::RdmaHello msg; + ASSERT_TRUE(msg.ParseFromString(body)); + + // All 6 required fields must be present (ParseFromString would + // have already returned false otherwise). + ASSERT_TRUE(msg.has_block_size()); + ASSERT_TRUE(msg.has_sq_size()); + ASSERT_TRUE(msg.has_rq_size()); + ASSERT_TRUE(msg.has_lid()); + ASSERT_TRUE(msg.has_gid()); + ASSERT_TRUE(msg.has_qp_num()); + // gid wire encoding must be exactly 16 bytes (sizeof(ibv_gid)). + ASSERT_EQ(sizeof(ibv_gid), msg.gid().size()); + + // Let the RPC time out and release resources. + bthread_id_join(cntl.call_id()); +} + +TEST_F(RdmaTest, v3_server_hello_bytes_baseline) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Send a valid v3 hello. + std::string packet = MakeV3Packet(MakeValidV3Hello()); + ASSERT_EQ((ssize_t)packet.size(), + write(sockfd, packet.data(), packet.size())); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Read server's reply hello: 4B magic + 4B pb_size + body. + uint8_t reply_magic[4]; + ASSERT_EQ(4, read(sockfd, reply_magic, 4)); + ASSERT_EQ(0, memcmp(reply_magic, "RDM3", 4)); + + uint8_t size_buf[4]; + ASSERT_EQ(4, read(sockfd, size_buf, 4)); + uint32_t pb_size = + butil::NetToHost32(*reinterpret_cast(size_buf)); + ASSERT_GT(pb_size, 0u); + ASSERT_LE(pb_size, 4096u); + + std::string body(pb_size, '\0'); + ASSERT_EQ((ssize_t)pb_size, read(sockfd, &body[0], pb_size)); + rdma::RdmaHello reply; + ASSERT_TRUE(reply.ParseFromString(body)); + ASSERT_TRUE(reply.has_block_size()); + ASSERT_TRUE(reply.has_sq_size()); + ASSERT_TRUE(reply.has_rq_size()); + ASSERT_TRUE(reply.has_gid()); + ASSERT_EQ(sizeof(ibv_gid), reply.gid().size()); + + // Drive the server into FALLBACK_TCP via ACK flags=0 so the test ends + // cleanly without requiring real RDMA hardware. + uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ((ssize_t)sizeof(flags), + write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_zero_pb_size) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + // "RDM3" + pb_size = 0 (4B big-endian zero). + uint8_t buf[8] = {'R', 'D', 'M', '3', 0, 0, 0, 0}; + ASSERT_EQ(8, write(sockfd, buf, 8)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_oversized_pb_size) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + uint8_t buf[8]; + memcpy(buf, "RDM3", 4); + uint32_t pb_size_be = butil::HostToNet32(4097); + memcpy(buf + 4, &pb_size_be, 4); + ASSERT_EQ(8, write(sockfd, buf, 8)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_invalid_pb_bytes) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + // "RDM3" + pb_size = 8 + 8 bytes of 0xff (invalid protobuf body). + uint8_t buf[16]; + memcpy(buf, "RDM3", 4); + uint32_t pb_size_be = butil::HostToNet32(8); + memcpy(buf + 4, &pb_size_be, 4); + memset(buf + 8, 0xff, 8); + ASSERT_EQ(16, write(sockfd, buf, 16)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_invalid_sq_size_falls_back) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + rdma::RdmaHello msg = MakeValidV3Hello(); + msg.set_sq_size(0); // invalid: < MIN_QP_SIZE (16) + std::string packet = MakeV3Packet(msg); + ASSERT_EQ((ssize_t)packet.size(), + write(sockfd, packet.data(), packet.size())); + usleep(100000); + + // Server validated the hello as invalid -> _rdma_state = RDMA_OFF, + // but still proceeds to S_ACK_WAIT (sends its own reply hello). + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); + + // Drain server's reply hello (content not asserted here; covered + // by v3_server_hello_bytes_baseline). + uint8_t reply_hdr[8]; + ASSERT_EQ(8, read(sockfd, reply_hdr, 8)); + ASSERT_EQ(0, memcmp(reply_hdr, "RDM3", 4)); + uint32_t reply_pb_size = butil::NetToHost32( + *reinterpret_cast(reply_hdr + 4)); + std::string reply_body(reply_pb_size, '\0'); + ASSERT_EQ((ssize_t)reply_pb_size, + read(sockfd, &reply_body[0], reply_pb_size)); + + // Client ACK flags=0 -> server settles into FALLBACK_TCP. + uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ((ssize_t)sizeof(flags), + write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + TEST_F(RdmaTest, try_global_disable_rdma) { StartServer(); rdma::g_rdma_available.store(false, butil::memory_order_relaxed); Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1245,7 +1772,7 @@ TEST_F(RdmaTest, try_global_disable_rdma) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(0, cntl.ErrorCode()); @@ -1256,7 +1783,7 @@ TEST_F(RdmaTest, try_global_disable_rdma) { TEST_F(RdmaTest, server_option_invalid) { Server server; ServerOptions options; - options.use_rdma = true; + options.socket_mode = SOCKET_MODE_RDMA; // rtmp and rdma are incompatible options.rtmp_service = (RtmpService*)1; @@ -1281,7 +1808,7 @@ TEST_F(RdmaTest, server_option_invalid) { TEST_F(RdmaTest, channel_option_invalid) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; // rtmp and rdma are incompatible chan_options.protocol = "rtmp"; @@ -1342,7 +1869,7 @@ TEST_F(RdmaTest, channel_option_invalid) { ASSERT_EQ(-1, channel.Init(g_ep, &chan_options)); } -TEST_F(RdmaTest, rdma_client_to_rdma_server) { +TEST_P(RdmaRpcTest, rdma_client_to_rdma_server) { if (!FLAGS_rdma_test_enable) { return; } @@ -1351,7 +1878,7 @@ TEST_F(RdmaTest, rdma_client_to_rdma_server) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1362,14 +1889,14 @@ TEST_F(RdmaTest, rdma_client_to_rdma_server) { req.set_message(__FUNCTION__); google::protobuf::Closure* done = DoNothing(); ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); - usleep(100000); + // usleep(100000); bthread_id_join(cntl.call_id()); ASSERT_EQ(0, cntl.ErrorCode()); StopServer(); } -TEST_F(RdmaTest, tcp_client_to_tcp_server) { +TEST_P(RdmaRpcTest, tcp_client_to_tcp_server) { StartServer(false); Channel channel; @@ -1391,7 +1918,7 @@ TEST_F(RdmaTest, tcp_client_to_tcp_server) { StopServer(); } -TEST_F(RdmaTest, tcp_client_to_rdma_server) { +TEST_P(RdmaRpcTest, tcp_client_to_rdma_server) { StartServer(); Channel channel; @@ -1413,12 +1940,12 @@ TEST_F(RdmaTest, tcp_client_to_rdma_server) { StopServer(); } -TEST_F(RdmaTest, rdma_client_to_tcp_server) { +TEST_P(RdmaRpcTest, rdma_client_to_tcp_server) { StartServer(false); Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1440,12 +1967,12 @@ static const int RPC_NUM = 1024; void DumpRdmaEndpointInfo(Socket* client, Socket* server) { std::cout << std::endl << "client:"; - client->_rdma_ep->DebugInfo(std::cout); + static_cast(client->_transport.get())->_rdma_ep->DebugInfo(std::cout); std::cout << std::endl << "server:"; - server->_rdma_ep->DebugInfo(std::cout); + static_cast(server->_transport.get())->_rdma_ep->DebugInfo(std::cout); } -TEST_F(RdmaTest, send_rpcs_in_one_qp) { +TEST_P(RdmaRpcTest, send_rpcs_in_one_qp) { if (!FLAGS_rdma_test_enable) { return; } @@ -1454,9 +1981,9 @@ TEST_F(RdmaTest, send_rpcs_in_one_qp) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; - chan_options.timeout_ms = 5000; + chan_options.timeout_ms = 50000; chan_options.max_retry = 0; ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); Controller cntl[RPC_NUM]; @@ -1516,50 +2043,57 @@ TEST_F(RdmaTest, send_rpcs_in_one_qp) { Socket* m = GetSocketFromServer(0); DumpRdmaEndpointInfo(s.get(), m); } - ASSERT_TRUE(0 == cntl[i].ErrorCode() || EOVERCROWDED == cntl[i].ErrorCode()) - << "req[" << i << "] " << berror(cntl[i].ErrorCode()); + ASSERT_TRUE(0 == cntl[i].ErrorCode() || + EOVERCROWDED == cntl[i].ErrorCode()) << "req[" << i << "] " << berror(cntl[i].ErrorCode()); } + SocketUniquePtr s; + ASSERT_EQ(0, Socket::Address(cntl[0]._single_server_id, &s)); + Socket* m = GetSocketFromServer(0); + DumpRdmaEndpointInfo(s.get(), m); + StopServer(); } -TEST_F(RdmaTest, send_rpc_in_many_qp) { +TEST_P(RdmaRpcTest, send_rpc_in_many_qp) { if (!FLAGS_rdma_test_enable) { return; } + butil::ip_t ip; + ASSERT_EQ(0, butil::str2ip(g_ip.c_str(), &ip)); + Server server[100]; MyEchoService svc[100]; int num = 100; + butil::EndPoint server_eps[100]; for (int i = 0; i < num; ++i) { ServerOptions options; - options.use_rdma = true; + options.socket_mode = SOCKET_MODE_RDMA; options.idle_timeout_sec = 1; options.max_concurrency = 0; options.internal_port = -1; server[i].AddService(&svc[i], SERVER_DOESNT_OWN_SERVICE); - EXPECT_EQ(0, server[i].Start(i + 8000, &options)); + ASSERT_EQ(0, server[i].Start(0, &options)); + server_eps[i] = butil::EndPoint(ip, server[i].listen_address().port); } int port = 0; butil::IOBuf attach; attach.resize(4096); ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; - chan_options.timeout_ms = 500; + chan_options.timeout_ms = 100000; chan_options.max_retry = 0; Channel channel[RPC_NUM]; Server* svr[RPC_NUM]; Controller cntl[RPC_NUM]; test::EchoRequest req[RPC_NUM]; test::EchoResponse res[RPC_NUM]; - butil::ip_t ip; - butil::str2ip(g_ip.c_str(), &ip); for (int i = 0; i < RPC_NUM; ++i) { svr[i] = &server[i % num]; - butil::EndPoint ep(ip, 8000 + ((port++) % num)); - ASSERT_EQ(0, channel[i].Init(ep, &chan_options)); + ASSERT_EQ(0, channel[i].Init(server_eps[(port++) % num], &chan_options)); req[i].set_message(__FUNCTION__); cntl[i].request_attachment().append(attach); google::protobuf::Closure* done = DoNothing(); @@ -1569,16 +2103,19 @@ TEST_F(RdmaTest, send_rpc_in_many_qp) { bthread_id_join(cntl[i].call_id()); if (cntl[i].ErrorCode() == ERPCTIMEDOUT) { SocketUniquePtr s; - ASSERT_EQ(0, Socket::Address(cntl[i]._single_server_id, &s)); - std::vector sids; - svr[i]->_am->ListConnections(&sids); - for (size_t i = 0; i < sids.size(); ++i) { - SocketUniquePtr m; - ASSERT_EQ(0, Socket::AddressFailedAsWell(sids[i], &m)); - DumpRdmaEndpointInfo(s.get(), m.get()); + EXPECT_EQ(0, Socket::Address(cntl[i]._single_server_id, &s)); + if (s && svr[i] && svr[i]->_am) { + std::vector sids; + svr[i]->_am->ListConnections(&sids); + for (size_t j = 0; j < sids.size(); ++j) { + SocketUniquePtr m; + if (Socket::AddressFailedAsWell(sids[j], &m) == 0) { + DumpRdmaEndpointInfo(s.get(), m.get()); + } + } } } - ASSERT_EQ(0, cntl[i].ErrorCode()) << "req[" << i << "]"; + EXPECT_EQ(0, cntl[i].ErrorCode()) << "req[" << i << "]"; } for (int i = 0; i < num; ++i) { @@ -1587,7 +2124,7 @@ TEST_F(RdmaTest, send_rpc_in_many_qp) { } } -TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { +TEST_P(RdmaRpcTest, send_rpcs_as_pooled_connection) { if (!FLAGS_rdma_test_enable) { return; } @@ -1596,7 +2133,7 @@ TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 30000; // it may very slow chan_options.timeout_ms = 30000; chan_options.max_retry = 0; @@ -1628,7 +2165,7 @@ TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { StopServer(); } -TEST_F(RdmaTest, send_rpcs_as_short_connection) { +TEST_P(RdmaRpcTest, send_rpcs_as_short_connection) { if (!FLAGS_rdma_test_enable) { return; } @@ -1637,7 +2174,7 @@ TEST_F(RdmaTest, send_rpcs_as_short_connection) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 30000; // it may very slow chan_options.timeout_ms = 30000; chan_options.max_retry = 0; @@ -1669,7 +2206,7 @@ TEST_F(RdmaTest, send_rpcs_as_short_connection) { StopServer(); } -TEST_F(RdmaTest, server_stop_during_rpc) { +TEST_P(RdmaRpcTest, server_stop_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1678,7 +2215,7 @@ TEST_F(RdmaTest, server_stop_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1707,7 +2244,7 @@ TEST_F(RdmaTest, server_stop_during_rpc) { } } -TEST_F(RdmaTest, server_close_during_rpc) { +TEST_P(RdmaRpcTest, server_close_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1716,7 +2253,7 @@ TEST_F(RdmaTest, server_close_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1749,7 +2286,7 @@ TEST_F(RdmaTest, server_close_during_rpc) { StopServer(); } -TEST_F(RdmaTest, client_close_during_rpc) { +TEST_P(RdmaRpcTest, client_close_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1758,7 +2295,7 @@ TEST_F(RdmaTest, client_close_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1789,7 +2326,7 @@ TEST_F(RdmaTest, client_close_during_rpc) { StopServer(); } -TEST_F(RdmaTest, verbs_error_handling) { +TEST_P(RdmaRpcTest, verbs_error_handling) { if (!FLAGS_rdma_test_enable) { return; } @@ -1798,7 +2335,7 @@ TEST_F(RdmaTest, verbs_error_handling) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1826,7 +2363,8 @@ TEST_F(RdmaTest, verbs_error_handling) { wr.sg_list = &sge; wr.num_sge = 1; ibv_send_wr* bad = NULL; - ibv_post_send(s->_rdma_ep->_resource->qp, &wr, &bad); + auto rdma_transport = static_cast(s->_transport.get()); + ibv_post_send(rdma_transport->_rdma_ep->_resource->qp, &wr, &bad); bthread_id_join(cntl.call_id()); ASSERT_EQ(ERDMA, cntl.ErrorCode()); free(buf); @@ -1834,7 +2372,7 @@ TEST_F(RdmaTest, verbs_error_handling) { StopServer(); } -TEST_F(RdmaTest, rdma_use_parallel_channel) { +TEST_P(RdmaRpcTest, rdma_use_parallel_channel) { if (!FLAGS_rdma_test_enable) { return; } @@ -1845,13 +2383,14 @@ TEST_F(RdmaTest, rdma_use_parallel_channel) { Channel subchans[NCHANS]; ParallelChannel channel; ChannelOptions opts; - opts.use_rdma = true; + opts.socket_mode = SOCKET_MODE_RDMA; for (size_t i = 0; i < NCHANS; ++i) { ASSERT_EQ(0, subchans[i].Init(_naming_url.c_str(), "rR", &opts)); ASSERT_EQ(0, channel.AddChannel( &subchans[i], DOESNT_OWN_CHANNEL, NULL, NULL)); } + ASSERT_EQ(0, channel.Init(NULL)); Controller cntl; test::EchoRequest req; @@ -1865,7 +2404,7 @@ TEST_F(RdmaTest, rdma_use_parallel_channel) { StopServer(); } -TEST_F(RdmaTest, rdma_use_selective_channel) { +TEST_P(RdmaRpcTest, rdma_use_selective_channel) { if (!FLAGS_rdma_test_enable) { return; } @@ -1875,7 +2414,7 @@ TEST_F(RdmaTest, rdma_use_selective_channel) { const size_t NCHANS = 8; SelectiveChannel channel; ChannelOptions opts; - opts.use_rdma = true; + opts.socket_mode = SOCKET_MODE_RDMA; ASSERT_EQ(0, channel.Init("rr", &opts)); for (size_t i = 0; i < NCHANS; ++i) { Channel* subchan = new Channel; @@ -1897,7 +2436,7 @@ TEST_F(RdmaTest, rdma_use_selective_channel) { static void MockFree(void* buf) { } -TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { +TEST_P(RdmaRpcTest, send_rpcs_with_user_defined_iobuf) { if (!FLAGS_rdma_test_enable) { return; } @@ -1906,7 +2445,7 @@ TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1961,7 +2500,7 @@ TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { StopServer(); } -TEST_F(RdmaTest, try_memory_pool_empty) { +TEST_P(RdmaRpcTest, try_memory_pool_empty) { if (!FLAGS_rdma_test_enable) { return; } @@ -1970,7 +2509,7 @@ TEST_F(RdmaTest, try_memory_pool_empty) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 60000; chan_options.max_retry = 0; @@ -2000,6 +2539,19 @@ TEST_F(RdmaTest, try_memory_pool_empty) { StopServer(); } +// Run every TEST_P(RdmaRpcTest, ...) above twice: once with the +// client-side handshake forced to v2 ("RDMA" magic + fixed-layout +// HelloMessage), once with v3 ("RDM3" magic + protobuf RdmaHello). +// The server always accepts both via magic-byte dispatch, so this +// proves the upper-layer RPC paths behave identically under either +// wire format. +INSTANTIATE_TEST_SUITE_P( + HandshakeVersion, RdmaRpcTest, + ::testing::Values(2, 3), + [](const ::testing::TestParamInfo& info) { + return std::string("v") + std::to_string(info.param); + }); + #endif // if BRPC_WITH_RDMA int main(int argc, char* argv[]) { diff --git a/test/brpc_ssl_unittest.cpp b/test/brpc_ssl_unittest.cpp index a7ef1c209e..e7545a0c82 100644 --- a/test/brpc_ssl_unittest.cpp +++ b/test/brpc_ssl_unittest.cpp @@ -311,6 +311,11 @@ TEST_F(SSLTest, connect_on_create) { brpc::Join(correlation_id); ASSERT_EQ(EXP_RESPONSE, res.message()); } + + ptr->SetFailed(); + ptr.reset(); + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); } void CheckCert(const char* cname, const char* cert) { @@ -492,3 +497,79 @@ TEST_F(SSLTest, ssl_perf) { close(clifd); close(servfd); } + +struct AbruptCloseArgs { int listenfd; }; + +static void* abrupt_close_server(void* arg) { + AbruptCloseArgs* a = (AbruptCloseArgs*)arg; + int connfd = accept(a->listenfd, NULL, NULL); + if (connfd < 0) return NULL; + SSL_CTX* ctx = brpc::CreateServerSSLContext( + "cert1.crt", "cert1.key", brpc::SSLOptions(), NULL, NULL); + SSL* ssl = brpc::CreateSSLSession(ctx, 0, connfd, true); + if (ssl) { SSL_do_handshake(ssl); SSL_free(ssl); } + close(connfd); + return NULL; +} + +TEST_F(SSLTest, ssl_unexpected_eof) { + // Verify that Socket::DoRead() returns -1 with errno=ESSL when the + // remote side closes the TCP connection without sending close_notify. + // Without the fix, DoRead() returns 0, causing error_code=0 to + // propagate to Controller::SetFailed() which triggers CHECK(false). + + const int port = 5962; + butil::EndPoint ep(butil::IP_ANY, port); + butil::fd_guard listenfd(butil::tcp_listen(ep)); + ASSERT_GT(listenfd, 0); + + AbruptCloseArgs server_args = { listenfd }; + pthread_t server_tid; + ASSERT_EQ(0, pthread_create(&server_tid, NULL, abrupt_close_server, + &server_args)); + + brpc::Protocol dummy_protocol = { + brpc::policy::ParseRpcMessage, brpc::SerializeRequestDefault, + brpc::policy::PackRpcRequest, NULL, ProcessResponse, + NULL, NULL, NULL, brpc::CONNECTION_TYPE_ALL, "ssl_ut_eof" + }; + ASSERT_EQ(0, RegisterProtocol((brpc::ProtocolType)31, dummy_protocol)); + + brpc::InputMessageHandler dummy_handler = { + dummy_protocol.parse, dummy_protocol.process_response, + NULL, NULL, dummy_protocol.name + }; + brpc::InputMessenger messenger; + ASSERT_EQ(0, messenger.AddHandler(dummy_handler)); + + brpc::SocketOptions socket_options; + butil::EndPoint server_ep(butil::IP_ANY, port); + socket_options.remote_side = server_ep; + socket_options.connect_on_create = true; + // Do NOT set on_edge_triggered_events — we will call DoRead manually. + socket_options.user = &messenger; + + brpc::ChannelSSLOptions ssl_options; + SSL_CTX* raw_ctx = brpc::CreateClientSSLContext(ssl_options); + ASSERT_NE(nullptr, raw_ctx); + std::shared_ptr ssl_ctx = + std::make_shared(); + ssl_ctx->raw_ctx = raw_ctx; + socket_options.initial_ssl_ctx = ssl_ctx; + + brpc::SocketId socket_id; + ASSERT_EQ(0, brpc::Socket::Create(socket_options, &socket_id)); + brpc::SocketUniquePtr ptr; + ASSERT_EQ(0, brpc::Socket::Address(socket_id, &ptr)); + + // Wait for server to close the connection without close_notify. + pthread_join(server_tid, NULL); + usleep(50000); + + // DoRead should detect the unexpected EOF and return -1 with errno=ESSL. + ssize_t nr = ptr->DoRead(1024); + EXPECT_EQ(-1, nr); + EXPECT_EQ(brpc::ESSL, errno); + + ptr->SetFailed(); +} diff --git a/test/brpc_streaming_rpc_unittest.cpp b/test/brpc_streaming_rpc_unittest.cpp index ecb88c6150..6a3b7b9e42 100644 --- a/test/brpc_streaming_rpc_unittest.cpp +++ b/test/brpc_streaming_rpc_unittest.cpp @@ -91,6 +91,7 @@ struct BatchStreamFeedbackRaceState { std::atomic client_got_second_msg{false}; std::atomic server_write_done{false}; std::atomic rpc_done{false}; + std::atomic client_closed_count{0}; bthread_t server_send_tid{0}; std::atomic server_send_started{false}; @@ -123,7 +124,9 @@ class BatchStreamClientHandler : public brpc::StreamInputHandler { void on_idle_timeout(brpc::StreamId /*id*/) override {} - void on_closed(brpc::StreamId /*id*/) override {} + void on_closed(brpc::StreamId /*id*/) override { + _state->client_closed_count.fetch_add(1, std::memory_order_release); + } void on_failed(brpc::StreamId /*id*/, int /*error_code*/, const std::string& /*error_text*/) override {} @@ -224,12 +227,17 @@ static void SetAtomicTrue(std::atomic* f) { f->store(true, std::memory_order_release); } -static bool WaitForTrue(const std::atomic& f, int timeout_ms) { +template +static bool WaitForTrue(Pred pred, int timeout_ms) { const int64_t deadline_us = butil::gettimeofday_us() + (int64_t)timeout_ms * 1000L; - while (!f.load(std::memory_order_acquire) && butil::gettimeofday_us() < deadline_us) { + while (!pred() && butil::gettimeofday_us() < deadline_us) { usleep(1000); } - return f.load(std::memory_order_acquire); + return pred(); +} + +static bool WaitForTrue(const std::atomic& f, int timeout_ms) { + return WaitForTrue([&f]() { return f.load(std::memory_order_acquire); }, timeout_ms); } TEST_F(StreamingRpcTest, sanity) { @@ -307,6 +315,22 @@ TEST_F(StreamingRpcTest, batch_create_stream_feedback_race) { } server.Stop(0); server.Join(); + + // Release the SocketUniquePtr held above so the fake socket can be + // recycled. Otherwise BeforeRecycle / on_closed for the extra stream + // is deferred until `client_extra_ptr` destructs at scope exit, which + // happens *after* `client_handler` and `state` are destroyed -> UAF + // inside Stream::Consume on Linux. + client_extra_ptr.reset(); + + // on_closed() runs asynchronously on each client stream's consumer + // bthread. Wait for both before letting handler/state go out of + // scope, otherwise Stream::Consume will dereference freed memory. + int expected_closed = request_streams.size(); + WaitForTrue([&state, expected_closed]() { + return state.client_closed_count.load(std::memory_order_acquire) + >= expected_closed; + }, 2000); }; test::EchoService_Stub stub(&channel); @@ -609,7 +633,9 @@ TEST_F(StreamingRpcTest, failed_when_rst) { ASSERT_EQ(0, brpc::StreamWrite(request_stream, out)) << "i=" << i; } - usleep(1000 * 10); + while (handler._expected_next_value != N) { + usleep(100); + } { brpc::SocketUniquePtr ptr; ASSERT_EQ(0, brpc::Socket::Address(request_stream, &ptr)); diff --git a/test/bthread_mutex_unittest.cpp b/test/bthread_mutex_unittest.cpp index f839d063e0..121f1ebb91 100644 --- a/test/bthread_mutex_unittest.cpp +++ b/test/bthread_mutex_unittest.cpp @@ -328,7 +328,6 @@ TEST(MutexTest, fast_pthread_mutex) { void* do_pthread_timedlock(void *arg) { struct timespec t = { -2, 0 }; EXPECT_EQ(ETIMEDOUT, pthread_mutex_timedlock((pthread_mutex_t*)arg, &t)); - EXPECT_EQ(ETIMEDOUT, errno); return NULL; } #endif @@ -347,7 +346,7 @@ TEST(MutexTest, pthread_mutex) { struct timespec t = { -2, 0 }; ASSERT_EQ(ETIMEDOUT, pthread_mutex_timedlock(&mutex, &t)); pthread_t th; - ASSERT_EQ(0, pthread_create(&th, NULL, do_fast_pthread_timedlock, &mutex)); + ASSERT_EQ(0, pthread_create(&th, NULL, do_pthread_timedlock, &mutex)); ASSERT_EQ(0, pthread_join(th, NULL)); #endif } diff --git a/test/bthread_rwlock_unittest.cpp b/test/bthread_rwlock_unittest.cpp index 2da226cb2f..9a88051c1a 100644 --- a/test/bthread_rwlock_unittest.cpp +++ b/test/bthread_rwlock_unittest.cpp @@ -17,6 +17,7 @@ #include #include "gperftools_helper.h" +#include "butil/atomicops.h" #include namespace { @@ -286,6 +287,253 @@ TEST(RWLockTest, mix_thread_types) { ASSERT_EQ(0, bthread_rwlock_destroy(&rw)); } +// Tests below verify the writer-priority semantics and the cleanup path +// guarded by the design notes in bthread/rwlock.cpp. +struct WriterPriorityArgs { + bthread_rwlock_t* rw; + butil::atomic* order; + int my_order; // sequence number captured inside the critical section + int hold_us; +}; + +void* wp_writer_fn(void* arg) { + auto* a = (WriterPriorityArgs*)arg; + EXPECT_EQ(0, bthread_rwlock_wrlock(a->rw)); + a->my_order = a->order->fetch_add(1, butil::memory_order_relaxed); + bthread_usleep(a->hold_us); + EXPECT_EQ(0, bthread_rwlock_unlock(a->rw)); + return NULL; +} + +void* wp_reader_fn(void* arg) { + auto* a = (WriterPriorityArgs*)arg; + EXPECT_EQ(0, bthread_rwlock_rdlock(a->rw)); + a->my_order = a->order->fetch_add(1, butil::memory_order_relaxed); + bthread_usleep(a->hold_us); + EXPECT_EQ(0, bthread_rwlock_unlock(a->rw)); + return NULL; +} + +// Verifies the writer-priority invariant guarded by the order +// "unlock writer_queue_mutex BEFORE fetch_sub(writer_wait_count)" in +// rwlock_unwrlock(): once a writer is queued, any new reader arriving +// later MUST yield to that writer. +TEST(RWLockTest, writer_priority) { + bthread_setconcurrency(8); + bthread_rwlock_t rw; + ASSERT_EQ(0, bthread_rwlock_init(&rw, NULL)); + + // (1) Main thread holds the read lock first. + ASSERT_EQ(0, bthread_rwlock_rdlock(&rw)); + + butil::atomic order(0); + WriterPriorityArgs warg {&rw, &order, -1, 5000}; + WriterPriorityArgs r2arg {&rw, &order, -1, 0}; + + // (2) Start a writer; it should park inside wrlock() because the read + // lock is held. Sleep long enough for it to fetch_add into + // writer_wait_count and reach the butex_wait on `lock_word'. + bthread_t wth; + ASSERT_EQ(0, bthread_start_urgent(&wth, NULL, wp_writer_fn, &warg)); + bthread_usleep(50 * 1000); + + // (3) Now spawn a fresh reader. By writer-priority it MUST observe + // writer_wait_count > 0 and park on it (NOT join the active read + // lock). + bthread_t r2th; + ASSERT_EQ(0, bthread_start_urgent(&r2th, NULL, wp_reader_fn, &r2arg)); + bthread_usleep(50 * 1000); + + // (4) Release the original read lock. The writer should win the race + // and complete BEFORE the queued reader. + ASSERT_EQ(0, bthread_rwlock_unlock(&rw)); + + bthread_join(wth, NULL); + bthread_join(r2th, NULL); + + EXPECT_GE(warg.my_order, 0); + EXPECT_GE(r2arg.my_order, 0); + EXPECT_LT(warg.my_order, r2arg.my_order) + << "Writer-priority violated: writer entered with order=" + << warg.my_order << " but late reader entered with order=" + << r2arg.my_order; + + ASSERT_EQ(0, bthread_rwlock_destroy(&rw)); +} + +void* wp_timed_wrlock_short(void* arg) { + auto* rw = (bthread_rwlock_t*)arg; + timespec ts = butil::milliseconds_from_now(50); + EXPECT_EQ(ETIMEDOUT, bthread_rwlock_timedwrlock(rw, &ts)); + return NULL; +} + +// Verifies the cleanup path of rwlock_wrlock_cleanup(): after multiple +// writers fail with ETIMEDOUT, writer_wait_count must be back to 0 so +// that subsequent readers are not blocked by leftover "ghost shares". +TEST(RWLockTest, wrlock_failure_does_not_leak_writer_count) { + bthread_setconcurrency(8); + bthread_rwlock_t rw; + ASSERT_EQ(0, bthread_rwlock_init(&rw, NULL)); + + // Hold the read lock so every wrlock attempt must block on `lock_word'. + ASSERT_EQ(0, bthread_rwlock_rdlock(&rw)); + + const int N = 8; + bthread_t wth[N]; + for (int i = 0; i < N; ++i) { + ASSERT_EQ(0, bthread_start_urgent(&wth[i], NULL, wp_timed_wrlock_short, &rw)); + } + // Wait for all timed wrlock attempts to time out and run cleanup. + for (int i = 0; i < N; ++i) { + bthread_join(wth[i], NULL); + } + + // Release the read lock; from this point on no writer is in flight, + // so a new reader MUST acquire the lock immediately. + ASSERT_EQ(0, bthread_rwlock_unlock(&rw)); + + timespec ts = butil::milliseconds_from_now(500); + butil::Timer t; + t.start(); + ASSERT_EQ(0, bthread_rwlock_timedrdlock(&rw, &ts)); + t.stop(); + EXPECT_LT(t.m_elapsed(), 100) + << "Reader was blocked for " << t.m_elapsed() << "ms; " + << "writer_wait_count was likely leaked by the cleanup path."; + + ASSERT_EQ(0, bthread_rwlock_unlock(&rw)); + ASSERT_EQ(0, bthread_rwlock_destroy(&rw)); +} + +struct DataConsistencyArgs { + bthread_rwlock_t* rw; + int64_t* shared; // protected by rw + int64_t local_inc; // writer: number of increments this thread did + int64_t observed_max; // reader: max value observed + bool is_writer; +}; + +void* dc_worker(void* arg) { + auto* a = (DataConsistencyArgs*)arg; + while (!g_stopped) { + if (a->is_writer) { + EXPECT_EQ(0, bthread_rwlock_wrlock(a->rw)); + ++(*a->shared); + ++a->local_inc; + EXPECT_EQ(0, bthread_rwlock_unlock(a->rw)); + } else { + EXPECT_EQ(0, bthread_rwlock_rdlock(a->rw)); + int64_t v = *a->shared; + if (v > a->observed_max) { + a->observed_max = v; + } + EXPECT_EQ(0, bthread_rwlock_unlock(a->rw)); + } + } + return NULL; +} + +// Verifies the release/acquire memory ordering pair on `lock_word'. +// If the CAS in unwrlock()/unrdlock() weren't release-ordered, or the +// CAS in rdlock()/wrlock() weren't acquire-ordered, writes done inside +// the critical section could appear lost or inconsistent to other +// threads, causing the final counter to disagree with total writer ops. +TEST(RWLockTest, data_consistency) { + bthread_rwlock_t rw; + ASSERT_EQ(0, bthread_rwlock_init(&rw, NULL)); + + g_stopped = false; + const int W = 4; + const int R = 8; + bthread_setconcurrency(W + R + 4); + + int64_t shared = 0; + std::vector args(W + R); + std::vector threads(W + R); + for (int i = 0; i < W + R; ++i) { + args[i].rw = &rw; + args[i].shared = &shared; + args[i].local_inc = 0; + args[i].observed_max = -1; + args[i].is_writer = (i < W); + ASSERT_EQ(0, bthread_start_urgent(&threads[i], NULL, dc_worker, &args[i])); + } + + bthread_usleep(500 * 1000); + g_stopped = true; + + int64_t total_inc = 0; + for (int i = 0; i < W + R; ++i) { + bthread_join(threads[i], NULL); + if (args[i].is_writer) { + total_inc += args[i].local_inc; + } + } + + // No lost updates: every writer's increment is reflected in `shared'. + EXPECT_EQ(total_inc, shared) + << "Lost updates: total writer ops=" << total_inc + << " but shared counter=" << shared; + // No reader saw a value greater than the final counter. + for (int i = W; i < W + R; ++i) { + EXPECT_LE(args[i].observed_max, shared) + << "Reader " << i << " observed_max=" << args[i].observed_max + << " > final shared=" << shared; + } + + ASSERT_EQ(0, bthread_rwlock_destroy(&rw)); +} + +void* ws_reader_loop(void* arg) { + auto* rw = (bthread_rwlock_t*)arg; + while (!g_stopped) { + EXPECT_EQ(0, bthread_rwlock_rdlock(rw)); + // Hold the read lock briefly to keep the lock continuously busy. + bthread_usleep(100); + EXPECT_EQ(0, bthread_rwlock_unlock(rw)); + } + return NULL; +} + +// Verifies that under a continuous read load, a writer can still acquire +// the lock in bounded time. This is the end-to-end guarantee of the +// writer-priority strategy: any reader arriving AFTER the writer entered +// wrlock() must yield, ensuring the writer never starves. +TEST(RWLockTest, no_writer_starvation) { + bthread_rwlock_t rw; + ASSERT_EQ(0, bthread_rwlock_init(&rw, NULL)); + + g_stopped = false; + const int R = 16; + bthread_setconcurrency(R + 4); + bthread_t rth[R]; + for (int i = 0; i < R; ++i) { + ASSERT_EQ(0, bthread_start_urgent(&rth[i], NULL, ws_reader_loop, &rw)); + } + + // Let the readers ramp up and saturate the lock. + bthread_usleep(50 * 1000); + + // A single writer must succeed within a generous budget. + butil::Timer t; + t.start(); + ASSERT_EQ(0, bthread_rwlock_wrlock(&rw)); + t.stop(); + + EXPECT_LT(t.m_elapsed(), 1000) + << "Writer starved for " << t.m_elapsed() << "ms under " + << R << " concurrent readers; writer-priority is broken."; + + ASSERT_EQ(0, bthread_rwlock_unlock(&rw)); + + g_stopped = true; + for (int i = 0; i < R; ++i) { + bthread_join(rth[i], NULL); + } + ASSERT_EQ(0, bthread_rwlock_destroy(&rw)); +} + struct BAIDU_CACHELINE_ALIGNMENT PerfArgs { bthread_rwlock_t* rw; int64_t counter; @@ -386,13 +634,14 @@ void PerfTest(uint32_t writer_ratio, ThreadId* /*dummy*/, int thread_num, << " writer_ratio=" << writer_ratio << " reader_num=" << reader_num << " read_count=" << read_count - << " read_average_time=" << (read_count == 0 ? 0 : read_wait_time / (double)read_count) + << " read_average_time=" << (read_count == 0 ? 0 : read_wait_time / (double)read_count) << "ns" << " writer_num=" << writer_num << " write_count=" << write_count - << " write_average_time=" << (write_count == 0 ? 0 : write_wait_time / (double)write_count); + << " write_average_time=" << (write_count == 0 ? 0 : write_wait_time / (double)write_count) << "ns"; } TEST(RWLockTest, performance) { + bthread_setconcurrency(16); const int thread_num = 12; PerfTest(0, (pthread_t*)NULL, thread_num, pthread_create, pthread_join); PerfTest(0, (bthread_t*)NULL, thread_num, bthread_start_background, bthread_join); diff --git a/test/bvar_percentile_unittest.cpp b/test/bvar_percentile_unittest.cpp index f647e272ba..d9d01846a1 100644 --- a/test/bvar_percentile_unittest.cpp +++ b/test/bvar_percentile_unittest.cpp @@ -28,6 +28,7 @@ class PercentileTest : public testing::Test { void TearDown() {} }; +#if !WITH_BABYLON_COUNTER TEST_F(PercentileTest, add) { bvar::detail::Percentile p; for (int j = 0; j < 10; ++j) { @@ -51,6 +52,7 @@ TEST_F(PercentileTest, add) { b.describe(out); } } +#endif // !WITH_BABYLON_COUNTER TEST_F(PercentileTest, merge1) { // Merge 2 PercentileIntervals b1 and b2. b2 has double SAMPLE_SIZE diff --git a/test/fuzzing/fuzz_baidu_rpc.cpp b/test/fuzzing/fuzz_baidu_rpc.cpp index 027dbbcb47..0302f0cc9e 100644 --- a/test/fuzzing/fuzz_baidu_rpc.cpp +++ b/test/fuzzing/fuzz_baidu_rpc.cpp @@ -33,6 +33,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseRpcMessage(&buf, sock, false, NULL); return 0; } diff --git a/test/fuzzing/fuzz_common.h b/test/fuzzing/fuzz_common.h index 1ab6bf3b4b..604306babd 100644 --- a/test/fuzzing/fuzz_common.h +++ b/test/fuzzing/fuzz_common.h @@ -32,10 +32,10 @@ inline brpc::Socket* get_fuzz_socket() { if (!initialized) { brpc::SocketOptions options; options.remote_side = butil::EndPoint(butil::IP_ANY, 7777); - if (brpc::Socket::Create(options, &sid) == 0) { - brpc::Socket::Address(sid, &sock_ptr); + if (brpc::Socket::Create(options, &sid) == 0 && + brpc::Socket::Address(sid, &sock_ptr) == 0) { + initialized = true; } - initialized = true; } return sock_ptr.get(); diff --git a/test/fuzzing/fuzz_couchbase.cpp b/test/fuzzing/fuzz_couchbase.cpp index 11eee84adb..8807005339 100644 --- a/test/fuzzing/fuzz_couchbase.cpp +++ b/test/fuzzing/fuzz_couchbase.cpp @@ -33,6 +33,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseCouchbaseMessage(&buf, sock, false, NULL); return 0; } diff --git a/test/fuzzing/fuzz_esp.cpp b/test/fuzzing/fuzz_esp.cpp index 4f93d635a9..d1c6649d86 100644 --- a/test/fuzzing/fuzz_esp.cpp +++ b/test/fuzzing/fuzz_esp.cpp @@ -34,6 +34,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseEspMessage(&buf, sock, false, NULL); return 0; diff --git a/test/fuzzing/fuzz_hulu.cpp b/test/fuzzing/fuzz_hulu.cpp index cb81e1414d..50cc62b31f 100644 --- a/test/fuzzing/fuzz_hulu.cpp +++ b/test/fuzzing/fuzz_hulu.cpp @@ -34,6 +34,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseHuluMessage(&buf, sock, false, NULL); return 0; diff --git a/test/fuzzing/fuzz_memcache.cpp b/test/fuzzing/fuzz_memcache.cpp index e1ef86e626..b60d527f5d 100644 --- a/test/fuzzing/fuzz_memcache.cpp +++ b/test/fuzzing/fuzz_memcache.cpp @@ -33,6 +33,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseMemcacheMessage(&buf, sock, false, NULL); return 0; } diff --git a/test/fuzzing/fuzz_mongo.cpp b/test/fuzzing/fuzz_mongo.cpp index c78ed96591..88db824ede 100644 --- a/test/fuzzing/fuzz_mongo.cpp +++ b/test/fuzzing/fuzz_mongo.cpp @@ -33,6 +33,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseMongoMessage(&buf, sock, false, NULL); return 0; } diff --git a/test/fuzzing/fuzz_shead.cpp b/test/fuzzing/fuzz_shead.cpp index e5d574da58..720b4e8e50 100644 --- a/test/fuzzing/fuzz_shead.cpp +++ b/test/fuzzing/fuzz_shead.cpp @@ -34,6 +34,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseNsheadMessage(&buf, sock, false, NULL); return 0; diff --git a/test/fuzzing/fuzz_sofa.cpp b/test/fuzzing/fuzz_sofa.cpp index b393f85270..a5dc418ec3 100644 --- a/test/fuzzing/fuzz_sofa.cpp +++ b/test/fuzzing/fuzz_sofa.cpp @@ -36,6 +36,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseSofaMessage(&buf, sock, false, NULL); return 0; } diff --git a/test/fuzzing/fuzz_streaming.cpp b/test/fuzzing/fuzz_streaming.cpp index 532bb72550..0b58d7b9e1 100644 --- a/test/fuzzing/fuzz_streaming.cpp +++ b/test/fuzzing/fuzz_streaming.cpp @@ -33,6 +33,9 @@ LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) buf.append(input); brpc::Socket* sock = get_fuzz_socket(); + if (sock == NULL) { + return 0; + } brpc::policy::ParseStreamingMessage(&buf, sock, false, NULL); return 0; } diff --git a/test/generate_unittests.bzl b/test/generate_unittests.bzl new file mode 100644 index 0000000000..8cfec78a71 --- /dev/null +++ b/test/generate_unittests.bzl @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def generate_unittests(name, srcs, deps, copts, linkopts = [], data = [], per_test_tags = {}): + tests = [] + for s in srcs: + ut_name = s.replace(".cpp", "") + native.cc_test( + name = ut_name, + srcs = [s], + copts = copts, + deps = deps, + linkopts = linkopts, + data = data, + # Integration tests that fork a real server binary (e.g. redis-server) + # need extra tags: "external" forces a real run instead of a cached + # pass, and "local" runs them outside the sandbox so the PATH-located + # server binary is visible and loopback works. + tags = per_test_tags.get(s, []), + ) + tests.append(":" + ut_name) + + native.test_suite( + name = name, + tests = tests, + ) \ No newline at end of file diff --git a/test/iobuf_unittest.cpp b/test/iobuf_unittest.cpp index 679cdfe799..489460e20f 100644 --- a/test/iobuf_unittest.cpp +++ b/test/iobuf_unittest.cpp @@ -58,7 +58,7 @@ extern IOBuf::Block* get_portal_next(IOBuf::Block const* b); namespace { const size_t BLOCK_OVERHEAD = 32; //impl dependent -const size_t DEFAULT_PAYLOAD = butil::IOBuf::DEFAULT_BLOCK_SIZE - BLOCK_OVERHEAD; +const size_t DEFAULT_PAYLOAD = butil::GetDefaultBlockSize() - BLOCK_OVERHEAD; void check_tls_block() { ASSERT_EQ((butil::IOBuf::Block*)NULL, butil::iobuf::get_tls_block_head()); @@ -534,7 +534,7 @@ TEST_F(IOBufTest, iobuf_sanity) { TEST_F(IOBufTest, copy_and_assign) { install_debug_allocator(); - const size_t TARGET_SIZE = butil::IOBuf::DEFAULT_BLOCK_SIZE * 2; + const size_t TARGET_SIZE = butil::GetDefaultBlockSize() * 2; butil::IOBuf buf0; buf0.append("hello"); ASSERT_EQ(1u, buf0._ref_num()); @@ -1115,7 +1115,7 @@ TEST_F(IOBufTest, extended_backup) { // Consume the left TLS block so that cases are easier to check. butil::iobuf::remove_tls_block_chain(); butil::IOBuf src; - const int BLKSIZE = (i == 0 ? 1024 : butil::IOBuf::DEFAULT_BLOCK_SIZE); + const int BLKSIZE = (i == 0 ? 1024 : butil::GetDefaultBlockSize()); const int PLDSIZE = BLKSIZE - BLOCK_OVERHEAD; butil::IOBufAsZeroCopyOutputStream out_stream1(&src, BLKSIZE); butil::IOBufAsZeroCopyOutputStream out_stream2(&src); diff --git a/test/root_runfiles.bzl b/test/root_runfiles.bzl new file mode 100644 index 0000000000..7d38b87d5e --- /dev/null +++ b/test/root_runfiles.bzl @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A small Starlark rule that exposes files at the *runfiles workspace root*. + +Why this exists +--------------- +`bazel test` sets the test process cwd to `${TEST_SRCDIR}/${TEST_WORKSPACE}`, +which under bzlmod is `/_main/` -- the workspace root inside +the runfiles tree. Files declared via the standard `data` attribute on a +target in `//test` are placed under `_main/test/` (the rule's package +path), which is NOT the cwd. + +Several brpc tests (e.g. brpc_alpn_protocol_unittest, brpc_ssl_unittest, +brpc_protobuf_json_unittest) open data files such as `cert1.crt`, `cert1.key`, +`jsonout` by *plain relative paths* (this works under the CMake build because +file(COPY ...) places them next to the test binary). To make those same +relative paths work under Bazel without rewriting the test code, we need the +files to appear at `_main/`, i.e. directly at the cwd. + +A genrule cannot help here, because its `outs` must live in the genrule's own +package. The clean way is `ctx.runfiles(symlinks = {...})`, which places the +given files at arbitrary paths *relative to the workspace root inside the +runfiles tree* -- which is exactly the cwd of a Bazel-launched test. + +Beware: `ctx.runfiles` has TWO related dict parameters: + * symlinks -> paths are relative to // (the cwd) + * root_symlinks -> paths are relative to / (one level above cwd) +We want the first one. + +Usage +----- + load("//test:root_runfiles.bzl", "root_runfiles") + + root_runfiles( + name = "test_runfiles_root_data", + srcs = [ + "cert1.crt", + "cert1.key", + "jsonout", + ], + ) + +Then add `:test_runfiles_root_data` to a cc_test's `data` attribute. +""" + +def _root_runfiles_impl(ctx): + # Map each input file to its basename, placed at the workspace root inside + # the runfiles tree. That root is the cwd of a `bazel test` process, so + # tests can open the files via plain relative paths like "cert1.crt". + symlinks = {f.basename: f for f in ctx.files.srcs} + runfiles = ctx.runfiles(symlinks = symlinks) + return [DefaultInfo(runfiles = runfiles)] + +root_runfiles = rule( + implementation = _root_runfiles_impl, + attrs = { + "srcs": attr.label_list( + allow_files = True, + doc = "Files to expose at the workspace root inside the runfiles " + + "tree (i.e. the cwd of `bazel test`), keyed by basename.", + ), + }, + doc = "Exposes data files at the workspace root inside the runfiles tree " + + "(the cwd of `bazel test`), so tests can open them via plain " + + "relative paths.", +)