Skip to content

Commit 2481607

Browse files
committed
fix: bloom filter core implementation
- Use uint64_t with bswap64 for portable serialization
- Add input validation for zero expected_elements and invalid FPR
- Fix truncated payload handling in BloomFilter constructor
- Make h2 in get_bit_position key-dependent via rehashing
- Add num_bits_ == 0 guards in insert() and might_contain()
1 parent 5154dde commit 2481607

2 files changed

Lines changed: 92 additions & 35 deletions

File tree

include/network/rpc_message.hpp

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -449,8 +449,8 @@ struct BloomFilterArgs {
449449
std::string probe_table;
450450
std::string probe_key_col; // Join key column on probe side for filtering
451451
std::vector<uint8_t> filter_data;
452-
size_t expected_elements;
453-
size_t num_hashes;
452+
size_t expected_elements = 0;
453+
size_t num_hashes = 0;
454454

455455
[[nodiscard]] std::vector<uint8_t> serialize() const {
456456
std::vector<uint8_t> out;
@@ -466,11 +466,13 @@ struct BloomFilterArgs {
466466
std::memcpy(out.data() + off, &filter_len, Serializer::VAL_SIZE_32);
467467
out.insert(out.end(), filter_data.begin(), filter_data.end());
468468

469-
// Serialize metadata
469+
// Serialize metadata using fixed-width temporaries
470+
uint64_t tmp_expected = static_cast<uint64_t>(expected_elements);
471+
uint8_t tmp_hashes = static_cast<uint8_t>(num_hashes);
470472
const size_t off2 = out.size();
471-
out.resize(off2 + Serializer::VAL_SIZE_32);
472-
std::memcpy(out.data() + off2, &expected_elements, Serializer::VAL_SIZE_32);
473-
out.push_back(static_cast<uint8_t>(num_hashes));
473+
out.resize(off2 + 9); // 8 bytes for expected_elements + 1 for num_hashes
474+
std::memcpy(out.data() + off2, &tmp_expected, 8);
475+
out[off2 + 8] = tmp_hashes;
474476
return out;
475477
}
476478

@@ -493,12 +495,13 @@ struct BloomFilterArgs {
493495
offset += filter_len;
494496
}
495497

496-
if (offset + Serializer::VAL_SIZE_32 <= in.size()) {
497-
std::memcpy(&args.expected_elements, in.data() + offset, Serializer::VAL_SIZE_32);
498-
offset += Serializer::VAL_SIZE_32;
499-
}
500-
if (offset < in.size()) {
501-
args.num_hashes = in[offset];
498+
// Deserialize metadata using fixed-width temporaries
499+
if (offset + 9 <= in.size()) {
500+
uint64_t tmp_expected = 0;
501+
std::memcpy(&tmp_expected, in.data() + offset, 8);
502+
args.expected_elements = static_cast<size_t>(tmp_expected);
503+
offset += 8;
504+
args.num_hashes = static_cast<size_t>(in[offset]);
502505
}
503506
return args;
504507
}

src/common/bloom_filter.cpp

Lines changed: 77 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -6,15 +6,36 @@
66
#include "common/bloom_filter.hpp"
77

88
#include <cmath>
9+
#include <cstdint>
10+
11+
#if defined(__APPLE__)
12+
#include <libkern/OSByteOrder.h>
13+
#define bswap64(x) OSSwapInt64(x)
14+
#else
15+
#include <endian.h>
16+
#define bswap64(x) __builtin_bswap64(x)
17+
#endif
918

1019
namespace cloudsql::common {
1120

1221
BloomFilter::BloomFilter(size_t expected_elements, double false_positive_rate)
1322
: expected_elements_(expected_elements) {
23+
// Handle zero expected_elements as empty filter
24+
if (expected_elements == 0) {
25+
num_bits_ = 0;
26+
num_hashes_ = 0;
27+
return;
28+
}
29+
30+
// Fall back to a default of 0.01 when false_positive_rate is outside the open interval (0, 1)
31+
double p = false_positive_rate;
32+
if (p <= 0.0 || p >= 1.0) {
33+
p = 0.01; // Safe default
34+
}
35+
1436
// m = -n * ln(p) / (ln(2)^2)
1537
// k = m/n * ln(2)
1638
double n = static_cast<double>(expected_elements);
17-
double p = false_positive_rate;
1839

1940
double m = -n * std::log(p) / (std::log(2) * std::log(2));
2041
double k = (m / n) * std::log(2);
@@ -37,25 +58,45 @@ BloomFilter::BloomFilter(size_t expected_elements, double false_positive_rate)
3758
}
3859

3960
BloomFilter::BloomFilter(const uint8_t* data, size_t size) {
40-
if (size < sizeof(size_t) * 3) {
61+
// Minimum size: 3 x uint64_t header + at least 1 byte of bits
62+
if (size < sizeof(uint64_t) * 3 + 1) {
4163
return; // Invalid data
4264
}
4365

4466
size_t offset = 0;
45-
std::memcpy(&num_bits_, data + offset, sizeof(size_t));
46-
offset += sizeof(size_t);
47-
48-
std::memcpy(&num_hashes_, data + offset, sizeof(size_t));
49-
offset += sizeof(size_t);
50-
51-
std::memcpy(&expected_elements_, data + offset, sizeof(size_t));
52-
offset += sizeof(size_t);
5367

68+
// Read with fixed-width uint64_t and proper byte-order conversion
69+
uint64_t tmp_num_bits = 0;
70+
std::memcpy(&tmp_num_bits, data + offset, sizeof(uint64_t));
71+
tmp_num_bits = bswap64(tmp_num_bits);
72+
num_bits_ = static_cast<size_t>(tmp_num_bits);
73+
offset += sizeof(uint64_t);
74+
75+
uint64_t tmp_num_hashes = 0;
76+
std::memcpy(&tmp_num_hashes, data + offset, sizeof(uint64_t));
77+
tmp_num_hashes = bswap64(tmp_num_hashes);
78+
num_hashes_ = static_cast<size_t>(tmp_num_hashes);
79+
offset += sizeof(uint64_t);
80+
81+
uint64_t tmp_expected = 0;
82+
std::memcpy(&tmp_expected, data + offset, sizeof(uint64_t));
83+
tmp_expected = bswap64(tmp_expected);
84+
expected_elements_ = static_cast<size_t>(tmp_expected);
85+
offset += sizeof(uint64_t);
86+
87+
// Validate bit array size
5488
size_t bit_bytes = (num_bits_ + 7) / 8;
55-
if (size >= offset + bit_bytes) {
56-
bits_.resize(bit_bytes);
57-
std::memcpy(bits_.data(), data + offset, bit_bytes);
89+
if (size < offset + bit_bytes) {
90+
// Truncated payload - reset to safe empty state
91+
num_bits_ = 0;
92+
num_hashes_ = 0;
93+
expected_elements_ = 0;
94+
bits_.clear();
95+
return;
5896
}
97+
98+
bits_.resize(bit_bytes);
99+
std::memcpy(bits_.data(), data + offset, bit_bytes);
59100
}
60101

61102
size_t BloomFilter::murmur3_hash(const Value& key) const {
@@ -84,14 +125,21 @@ size_t BloomFilter::murmur3_hash(const uint8_t* data, size_t len, size_t seed) c
84125

85126
size_t BloomFilter::get_bit_position(size_t hash, size_t i) const {
86127
// Double hashing technique: h(i) = h1 + i * h2
87-
// Use two different hash seeds
128+
// Make h2 key-dependent by rehashing the input hash with a different seed
88129
size_t h1 = hash;
89-
size_t h2 = murmur3_hash(reinterpret_cast<const uint8_t*>("salt"), 4, 0xcafebabe);
130+
size_t h2 = murmur3_hash(reinterpret_cast<const uint8_t*>(&hash), sizeof(hash), 0xcafebabe);
131+
132+
// Ensure h2 is non-zero to avoid degenerate probing
133+
if (h2 == 0) {
134+
h2 = 1;
135+
}
90136

91137
return (h1 + i * h2) % num_bits_;
92138
}
93139

94140
void BloomFilter::insert(const Value& key) {
141+
if (num_bits_ == 0) return; // Empty filter
142+
95143
size_t base_hash = murmur3_hash(key);
96144

97145
for (size_t i = 0; i < num_hashes_; ++i) {
@@ -103,6 +151,8 @@ void BloomFilter::insert(const Value& key) {
103151
}
104152

105153
bool BloomFilter::might_contain(const Value& key) const {
154+
if (num_bits_ == 0) return false; // Empty filter
155+
106156
size_t base_hash = murmur3_hash(key);
107157

108158
for (size_t i = 0; i < num_hashes_; ++i) {
@@ -121,17 +171,21 @@ bool BloomFilter::might_contain(const Value& key) const {
121171
std::vector<uint8_t> BloomFilter::serialize() const {
122172
std::vector<uint8_t> out;
123173

124-
// Store metadata
125-
out.resize(sizeof(size_t) * 3);
174+
// Store metadata using fixed-width uint64_t with byte-order conversion
175+
out.resize(sizeof(uint64_t) * 3);
126176
size_t offset = 0;
127-
std::memcpy(out.data() + offset, &num_bits_, sizeof(size_t));
128-
offset += sizeof(size_t);
129177

130-
std::memcpy(out.data() + offset, &num_hashes_, sizeof(size_t));
131-
offset += sizeof(size_t);
178+
uint64_t tmp_num_bits = bswap64(static_cast<uint64_t>(num_bits_));
179+
std::memcpy(out.data() + offset, &tmp_num_bits, sizeof(uint64_t));
180+
offset += sizeof(uint64_t);
181+
182+
uint64_t tmp_num_hashes = bswap64(static_cast<uint64_t>(num_hashes_));
183+
std::memcpy(out.data() + offset, &tmp_num_hashes, sizeof(uint64_t));
184+
offset += sizeof(uint64_t);
132185

133-
std::memcpy(out.data() + offset, &expected_elements_, sizeof(size_t));
134-
offset += sizeof(size_t);
186+
uint64_t tmp_expected = bswap64(static_cast<uint64_t>(expected_elements_));
187+
std::memcpy(out.data() + offset, &tmp_expected, sizeof(uint64_t));
188+
offset += sizeof(uint64_t);
135189

136190
// Store bits
137191
size_t bit_bytes = (num_bits_ + 7) / 8;

0 commit comments

Comments
 (0)