Skip to content

Commit d100371

Browse files
committed
feat: implement bloom filter building from left table data
- Build local bloom filter during Phase 1 left table scan - Collect and OR-aggregate bits via BloomFilterBits RPC - Broadcast aggregated filter via BloomFilterPush before Phase 2 - Apply sender-side filtering before PushData in Phase 2
1 parent 078c09b commit d100371

5 files changed

Lines changed: 196 additions & 12 deletions

File tree

include/common/cluster_manager.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,45 @@ class ClusterManager {
279279
return "";
280280
}
281281

282+
/**
283+
* @brief Store local bloom filter bits from this node (called on data nodes)
284+
*/
285+
void set_local_bloom_bits(const std::string& context_id, std::vector<uint8_t> bits,
286+
size_t expected_elements, size_t num_hashes) {
287+
const std::scoped_lock<std::mutex> lock(mutex_);
288+
local_bloom_bits_[context_id] = std::move(bits);
289+
local_expected_elements_ = expected_elements;
290+
local_num_hashes_ = num_hashes;
291+
}
292+
293+
/**
294+
* @brief Get stored local bloom filter bits for a context
295+
*/
296+
[[nodiscard]] std::vector<uint8_t> get_local_bloom_bits(const std::string& context_id) const {
297+
const std::scoped_lock<std::mutex> lock(mutex_);
298+
auto it = local_bloom_bits_.find(context_id);
299+
if (it != local_bloom_bits_.end()) {
300+
return it->second;
301+
}
302+
return {};
303+
}
304+
305+
/**
306+
* @brief Get expected_elements for local bloom filter
307+
*/
308+
[[nodiscard]] size_t get_local_expected_elements() const {
309+
const std::scoped_lock<std::mutex> lock(mutex_);
310+
return local_expected_elements_;
311+
}
312+
313+
/**
314+
* @brief Get num_hashes for local bloom filter
315+
*/
316+
[[nodiscard]] size_t get_local_num_hashes() const {
317+
const std::scoped_lock<std::mutex> lock(mutex_);
318+
return local_num_hashes_;
319+
}
320+
282321
/**
283322
* @brief Clear bloom filter for a context
284323
*/
@@ -311,6 +350,10 @@ class ClusterManager {
311350
shuffle_buffers_;
312351
/* context_id -> bloom filter data */
313352
std::unordered_map<std::string, BloomFilterEntry> bloom_filters_;
353+
/* context_id -> local bloom filter bits (for aggregation during distributed build) */
354+
std::unordered_map<std::string, std::vector<uint8_t>> local_bloom_bits_;
355+
size_t local_expected_elements_ = 0;
356+
size_t local_num_hashes_ = 0;
314357
mutable std::mutex mutex_;
315358
};
316359

include/network/rpc_message.hpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ enum class RpcType : uint8_t {
3434
PushData = 9,
3535
ShuffleFragment = 10,
3636
BloomFilterPush = 11,
37+
BloomFilterBits = 12,
3738
Error = 255
3839
};
3940

@@ -507,6 +508,64 @@ struct BloomFilterArgs {
507508
}
508509
};
509510

511+
/**
512+
* @brief Arguments for sending local bloom filter bits from data node to coordinator
513+
* Used during Phase 1 to collect and aggregate bloom filters from all nodes
514+
*/
515+
struct BloomFilterBitsArgs {
516+
std::string context_id;
517+
std::vector<uint8_t> filter_data;
518+
size_t expected_elements = 0;
519+
size_t num_hashes = 0;
520+
521+
[[nodiscard]] std::vector<uint8_t> serialize() const {
522+
std::vector<uint8_t> out;
523+
Serializer::serialize_string(context_id, out);
524+
525+
// Serialize filter data (blob)
526+
const uint32_t filter_len = static_cast<uint32_t>(filter_data.size());
527+
const size_t off = out.size();
528+
out.resize(off + Serializer::VAL_SIZE_32);
529+
std::memcpy(out.data() + off, &filter_len, Serializer::VAL_SIZE_32);
530+
out.insert(out.end(), filter_data.begin(), filter_data.end());
531+
532+
// Serialize metadata
533+
uint64_t tmp_expected = static_cast<uint64_t>(expected_elements);
534+
uint8_t tmp_hashes = static_cast<uint8_t>(num_hashes);
535+
const size_t off2 = out.size();
536+
out.resize(off2 + 9); // 8 bytes for expected_elements + 1 for num_hashes
537+
std::memcpy(out.data() + off2, &tmp_expected, 8);
538+
out[off2 + 8] = tmp_hashes;
539+
return out;
540+
}
541+
542+
static BloomFilterBitsArgs deserialize(const std::vector<uint8_t>& in) {
543+
BloomFilterBitsArgs args;
544+
size_t offset = 0;
545+
args.context_id = Serializer::deserialize_string(in.data(), offset, in.size());
546+
547+
uint32_t filter_len = 0;
548+
if (offset + Serializer::VAL_SIZE_32 <= in.size()) {
549+
std::memcpy(&filter_len, in.data() + offset, Serializer::VAL_SIZE_32);
550+
offset += Serializer::VAL_SIZE_32;
551+
}
552+
if (offset + filter_len <= in.size()) {
553+
args.filter_data.resize(filter_len);
554+
std::memcpy(args.filter_data.data(), in.data() + offset, filter_len);
555+
offset += filter_len;
556+
}
557+
558+
if (offset + 9 <= in.size()) {
559+
uint64_t tmp_expected = 0;
560+
std::memcpy(&tmp_expected, in.data() + offset, 8);
561+
args.expected_elements = static_cast<size_t>(tmp_expected);
562+
offset += 8;
563+
args.num_hashes = static_cast<size_t>(in[offset]);
564+
}
565+
return args;
566+
}
567+
};
568+
510569
/**
511570
* @brief Arguments for TxnPrepare/Commit/Abort RPC
512571
*/

src/distributed/distributed_executor.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -242,23 +242,43 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
242242
return res;
243243
}
244244

245-
// After Phase 1, each node will have received left table data.
246-
// Now broadcast bloom filter built from that data to all nodes for Phase 2
247-
// filtering. The filter is sent as a separate RPC that data nodes will store and
248-
// apply to their right table shuffle. For now, we send a simple metadata-only
249-
// filter that signals "filtering enabled" - the actual filter building happens on
250-
// each data node during Phase 1 and they stash it for use during Phase 2.
251-
//
252-
// In production, we'd collect and OR all local bloom filters, but for POC
253-
// we just signal that bloom filtering is enabled for this context.
245+
// After Phase 1, collect bloom filter bits from each data node and aggregate
246+
// via bitwise OR to create the combined bloom filter
247+
std::vector<uint8_t> aggregated_bits;
248+
size_t total_expected = 0;
249+
size_t max_hashes = 0;
250+
251+
for (const auto& node : data_nodes) {
252+
network::RpcClient client(node.address, node.cluster_port);
253+
if (!client.connect()) {
254+
continue;
255+
}
256+
network::BloomFilterBitsArgs bits_args;
257+
bits_args.context_id = context_id;
258+
std::vector<uint8_t> resp;
259+
if (client.call(network::RpcType::BloomFilterBits, bits_args.serialize(), resp)) {
260+
auto reply = network::BloomFilterBitsArgs::deserialize(resp);
261+
if (reply.filter_data.size() > aggregated_bits.size()) {
262+
aggregated_bits.resize(reply.filter_data.size(), 0);
263+
}
264+
// Bitwise OR aggregation
265+
for (size_t i = 0; i < reply.filter_data.size(); i++) {
266+
aggregated_bits[i] |= reply.filter_data[i];
267+
}
268+
total_expected += reply.expected_elements;
269+
max_hashes = std::max(max_hashes, reply.num_hashes);
270+
}
271+
}
272+
273+
// Broadcast the aggregated bloom filter to all nodes for Phase 2 filtering
254274
network::BloomFilterArgs bf_args;
255275
bf_args.context_id = context_id;
256276
bf_args.build_table = left_table;
257277
bf_args.probe_table = right_table;
258278
bf_args.probe_key_col = right_key; // Tell probe side which column to filter on
259-
bf_args.filter_data.clear(); // Empty = filter built distributed
260-
bf_args.expected_elements = data_nodes.size() * 1000; // Estimate
261-
bf_args.num_hashes = 4;
279+
bf_args.filter_data = aggregated_bits;
280+
bf_args.expected_elements = total_expected;
281+
bf_args.num_hashes = max_hashes > 0 ? max_hashes : 4;
262282
auto bf_payload = bf_args.serialize();
263283

264284
for (const auto& node : data_nodes) {

src/main.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,31 @@ int main(int argc, char* argv[]) {
516516
static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
517517
});
518518

519+
// Handler for collecting local bloom filter bits from data nodes
520+
// Coordinator calls this after Phase 1 to aggregate bloom filters
521+
rpc_server->set_handler(
522+
cloudsql::network::RpcType::BloomFilterBits,
523+
[&](const cloudsql::network::RpcHeader& h, const std::vector<uint8_t>& p,
524+
int fd) {
525+
(void)h;
526+
auto args = cloudsql::network::BloomFilterBitsArgs::deserialize(p);
527+
cloudsql::network::BloomFilterBitsArgs reply_args;
528+
reply_args.context_id = args.context_id;
529+
reply_args.filter_data = cluster_manager->get_local_bloom_bits(args.context_id);
530+
reply_args.expected_elements = cluster_manager->get_local_expected_elements();
531+
reply_args.num_hashes = cluster_manager->get_local_num_hashes();
532+
533+
auto resp_p = reply_args.serialize();
534+
cloudsql::network::RpcHeader resp_h;
535+
resp_h.type = cloudsql::network::RpcType::QueryResults;
536+
resp_h.payload_len = static_cast<uint16_t>(resp_p.size());
537+
char h_buf[cloudsql::network::RpcHeader::HEADER_SIZE];
538+
resp_h.encode(h_buf);
539+
static_cast<void>(
540+
send(fd, h_buf, cloudsql::network::RpcHeader::HEADER_SIZE, 0));
541+
static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
542+
});
543+
519544
rpc_server->set_handler(
520545
cloudsql::network::RpcType::ShuffleFragment,
521546
[&](const cloudsql::network::RpcHeader& h, const std::vector<uint8_t>& p,
@@ -556,11 +581,18 @@ int main(int argc, char* argv[]) {
556581
partitions[node.id] = {};
557582
}
558583

584+
// Estimate expected elements for bloom filter
585+
// For now, estimate based on table size (will be refined with actual count)
586+
size_t estimated_count = 1000;
587+
cloudsql::common::BloomFilter local_bloom(estimated_count);
588+
559589
auto iter = table.scan();
560590
cloudsql::storage::HeapTable::TupleMeta t_meta;
561591
while (iter.next_meta(t_meta)) {
562592
if (t_meta.xmax == 0) { // Visible
563593
const auto& key_val = t_meta.tuple.get(key_idx);
594+
// Build bloom filter from join key values
595+
local_bloom.insert(key_val);
564596
uint32_t node_idx =
565597
cloudsql::cluster::ShardManager::compute_shard(
566598
key_val, static_cast<uint32_t>(data_nodes.size()));
@@ -569,6 +601,14 @@ int main(int argc, char* argv[]) {
569601
}
570602
}
571603

604+
// Store local bloom filter bits for coordinator to collect
605+
// The coordinator will aggregate these during Phase 1
606+
auto bloom_bits = local_bloom.serialize();
607+
cluster_manager->set_local_bloom_bits(
608+
args.context_id, bloom_bits,
609+
local_bloom.expected_elements(),
610+
local_bloom.num_hashes());
611+
572612
bool overall_success = true;
573613
std::string delivery_errors;
574614

tests/distributed_tests.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,36 @@ TEST(DistributedExecutorTests, ShuffleJoinOrchestration) {
330330
static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
331331
};
332332

333+
auto bloom_bits_handler = [&](const RpcHeader& h, const std::vector<uint8_t>& p, int fd) {
334+
(void)h;
335+
auto args = BloomFilterBitsArgs::deserialize(p);
336+
BloomFilterBitsArgs reply_args;
337+
reply_args.context_id = args.context_id;
338+
// Return empty bloom filter bits for mock - real implementation would return actual bits
339+
reply_args.filter_data = {};
340+
reply_args.expected_elements = 0;
341+
reply_args.num_hashes = 4;
342+
343+
auto resp_p = reply_args.serialize();
344+
RpcHeader resp_h;
345+
resp_h.type = RpcType::QueryResults;
346+
resp_h.payload_len = static_cast<uint16_t>(resp_p.size());
347+
char h_buf[RpcHeader::HEADER_SIZE];
348+
resp_h.encode(h_buf);
349+
static_cast<void>(send(fd, h_buf, RpcHeader::HEADER_SIZE, 0));
350+
static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
351+
};
352+
333353
node1.set_handler(RpcType::ShuffleFragment, handler);
334354
node1.set_handler(RpcType::PushData, handler);
335355
node1.set_handler(RpcType::ExecuteFragment, handler);
336356
node1.set_handler(RpcType::BloomFilterPush, handler);
357+
node1.set_handler(RpcType::BloomFilterBits, bloom_bits_handler);
337358
node2.set_handler(RpcType::ShuffleFragment, handler);
338359
node2.set_handler(RpcType::PushData, handler);
339360
node2.set_handler(RpcType::ExecuteFragment, handler);
340361
node2.set_handler(RpcType::BloomFilterPush, handler);
362+
node2.set_handler(RpcType::BloomFilterBits, bloom_bits_handler);
341363

342364
ASSERT_TRUE(node1.start());
343365
ASSERT_TRUE(node2.start());

0 commit comments

Comments
 (0)