Skip to content
Merged
43 changes: 43 additions & 0 deletions include/common/cluster_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,45 @@ class ClusterManager {
return "";
}

/**
 * @brief Record this data node's locally built bloom filter so the
 *        coordinator can later collect it via the BloomFilterBits RPC.
 * @param context_id        distributed-join context the filter belongs to
 * @param bits              serialized bloom filter bit array (moved in)
 * @param expected_elements capacity the filter was sized for
 * @param num_hashes        number of hash functions used by the filter
 */
void set_local_bloom_bits(const std::string& context_id, std::vector<uint8_t> bits,
                          size_t expected_elements, size_t num_hashes) {
    const std::lock_guard<std::mutex> guard(mutex_);
    local_bloom_bits_[context_id] = std::move(bits);
    // NOTE(review): expected_elements/num_hashes are single shared fields,
    // not keyed by context_id — concurrent join contexts would clobber each
    // other's metadata. TODO confirm only one context is in flight at a time.
    local_expected_elements_ = expected_elements;
    local_num_hashes_ = num_hashes;
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

/**
 * @brief Fetch the bloom filter bits previously stored for @p context_id.
 * @return A copy of the stored bit array, or an empty vector when this node
 *         has no filter for the context.
 */
[[nodiscard]] std::vector<uint8_t> get_local_bloom_bits(const std::string& context_id) const {
    const std::lock_guard<std::mutex> guard(mutex_);
    const auto found = local_bloom_bits_.find(context_id);
    return found == local_bloom_bits_.end() ? std::vector<uint8_t>{} : found->second;
}

/**
 * @brief Expected-element capacity of the locally built bloom filter.
 * NOTE(review): one value shared by all contexts, not per-context — TODO
 * confirm the single-context assumption made by set_local_bloom_bits.
 */
[[nodiscard]] size_t get_local_expected_elements() const {
    const std::lock_guard<std::mutex> guard(mutex_);
    const size_t expected = local_expected_elements_;
    return expected;
}

/**
 * @brief Hash-function count of the locally built bloom filter.
 * NOTE(review): one value shared by all contexts, not per-context — TODO
 * confirm the single-context assumption made by set_local_bloom_bits.
 */
[[nodiscard]] size_t get_local_num_hashes() const {
    const std::lock_guard<std::mutex> guard(mutex_);
    const size_t hashes = local_num_hashes_;
    return hashes;
}

/**
* @brief Clear bloom filter for a context
*/
Expand Down Expand Up @@ -311,6 +350,10 @@ class ClusterManager {
shuffle_buffers_;
/* context_id -> bloom filter data */
std::unordered_map<std::string, BloomFilterEntry> bloom_filters_;
/* context_id -> local bloom filter bits (for aggregation during distributed build) */
std::unordered_map<std::string, std::vector<uint8_t>> local_bloom_bits_;
size_t local_expected_elements_ = 0;
size_t local_num_hashes_ = 0;
mutable std::mutex mutex_;
};

Expand Down
59 changes: 59 additions & 0 deletions include/network/rpc_message.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ enum class RpcType : uint8_t {
PushData = 9,
ShuffleFragment = 10,
BloomFilterPush = 11,
BloomFilterBits = 12,
Error = 255
};

Expand Down Expand Up @@ -507,6 +508,64 @@ struct BloomFilterArgs {
}
};

/**
 * @brief Arguments for sending local bloom filter bits from data node to coordinator.
 * Used during Phase 1 to collect and aggregate bloom filters from all nodes.
 *
 * Wire format: [string context_id][u32 blob_len][blob_len bytes]
 *              [u64 expected_elements][u8 num_hashes]
 */
struct BloomFilterBitsArgs {
    std::string context_id;            // join context the filter belongs to
    std::vector<uint8_t> filter_data;  // raw bloom filter bit array (empty = none)
    size_t expected_elements = 0;      // capacity the filter was sized for
    size_t num_hashes = 0;             // hash count (serialized as one byte, max 255)

    /**
     * @brief Serialize all fields into a byte buffer (see wire format above).
     */
    [[nodiscard]] std::vector<uint8_t> serialize() const {
        std::vector<uint8_t> out;
        Serializer::serialize_string(context_id, out);

        // Length-prefixed filter blob.
        const uint32_t filter_len = static_cast<uint32_t>(filter_data.size());
        const size_t off = out.size();
        out.resize(off + Serializer::VAL_SIZE_32);
        std::memcpy(out.data() + off, &filter_len, Serializer::VAL_SIZE_32);
        out.insert(out.end(), filter_data.begin(), filter_data.end());

        // Trailing fixed-width metadata (8 bytes expected + 1 byte hashes).
        const uint64_t tmp_expected = static_cast<uint64_t>(expected_elements);
        const uint8_t tmp_hashes = static_cast<uint8_t>(num_hashes);
        const size_t off2 = out.size();
        out.resize(off2 + sizeof(tmp_expected) + sizeof(tmp_hashes));
        std::memcpy(out.data() + off2, &tmp_expected, sizeof(tmp_expected));
        out[off2 + sizeof(tmp_expected)] = tmp_hashes;
        return out;
    }

    /**
     * @brief Deserialize from a byte buffer.
     * A truncated or malformed payload returns the fields parsed so far with
     * the rest default-initialized; it never reads past @p in and never
     * misinterprets later bytes as metadata (the original fell through on a
     * short blob and read the trailing metadata from a stale offset).
     */
    static BloomFilterBitsArgs deserialize(const std::vector<uint8_t>& in) {
        BloomFilterBitsArgs args;
        size_t offset = 0;
        args.context_id = Serializer::deserialize_string(in.data(), offset, in.size());

        // Length prefix of the filter blob.
        uint32_t filter_len = 0;
        if (offset + Serializer::VAL_SIZE_32 > in.size()) {
            return args;  // truncated: no length prefix present
        }
        std::memcpy(&filter_len, in.data() + offset, Serializer::VAL_SIZE_32);
        offset += Serializer::VAL_SIZE_32;

        // Compare against the remaining byte count: avoids the size_t
        // overflow of `offset + filter_len` on 32-bit builds, and stops
        // cleanly on truncation instead of continuing with a stale offset.
        if (filter_len > in.size() - offset) {
            return args;
        }
        if (filter_len > 0) {  // memcpy into an empty vector's data() is UB
            args.filter_data.resize(filter_len);
            std::memcpy(args.filter_data.data(), in.data() + offset, filter_len);
        }
        offset += filter_len;

        // Trailing metadata: u64 expected_elements + u8 num_hashes.
        constexpr size_t kMetaBytes = sizeof(uint64_t) + sizeof(uint8_t);
        if (offset + kMetaBytes <= in.size()) {
            uint64_t tmp_expected = 0;
            std::memcpy(&tmp_expected, in.data() + offset, sizeof(tmp_expected));
            args.expected_elements = static_cast<size_t>(tmp_expected);
            offset += sizeof(tmp_expected);
            args.num_hashes = static_cast<size_t>(in[offset]);
        }
        return args;
    }
};

/**
* @brief Arguments for TxnPrepare/Commit/Abort RPC
*/
Expand Down
45 changes: 33 additions & 12 deletions src/distributed/distributed_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,44 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
return res;
}

// After Phase 1, each node will have received left table data.
// Now broadcast bloom filter built from that data to all nodes for Phase 2
// filtering. The filter is sent as a separate RPC that data nodes will store and
// apply to their right table shuffle. For now, we send a simple metadata-only
// filter that signals "filtering enabled" - the actual filter building happens on
// each data node during Phase 1 and they stash it for use during Phase 2.
//
// In production, we'd collect and OR all local bloom filters, but for POC
// we just signal that bloom filtering is enabled for this context.
// After Phase 1, collect bloom filter bits from each data node and aggregate
// via bitwise OR to create the combined bloom filter
std::vector<uint8_t> aggregated_bits;
size_t total_expected = 0;
size_t max_hashes = 0;

for (const auto& node : data_nodes) {
network::RpcClient client(node.address, node.cluster_port);
if (!client.connect()) {
continue;
}
network::BloomFilterBitsArgs bits_args;
bits_args.context_id = context_id;
std::vector<uint8_t> resp;
if (client.call(network::RpcType::BloomFilterBits, bits_args.serialize(),
resp)) {
auto reply = network::BloomFilterBitsArgs::deserialize(resp);
if (reply.filter_data.size() > aggregated_bits.size()) {
aggregated_bits.resize(reply.filter_data.size(), 0);
}
// Bitwise OR aggregation
for (size_t i = 0; i < reply.filter_data.size(); i++) {
aggregated_bits[i] |= reply.filter_data[i];
}
total_expected += reply.expected_elements;
max_hashes = std::max(max_hashes, reply.num_hashes);
}
}

// Broadcast the aggregated bloom filter to all nodes for Phase 2 filtering
network::BloomFilterArgs bf_args;
bf_args.context_id = context_id;
bf_args.build_table = left_table;
bf_args.probe_table = right_table;
bf_args.probe_key_col = right_key; // Tell probe side which column to filter on
bf_args.filter_data.clear(); // Empty = filter built distributed
bf_args.expected_elements = data_nodes.size() * 1000; // Estimate
bf_args.num_hashes = 4;
bf_args.filter_data = aggregated_bits;
bf_args.expected_elements = total_expected;
bf_args.num_hashes = max_hashes > 0 ? max_hashes : 4;
auto bf_payload = bf_args.serialize();

for (const auto& node : data_nodes) {
Expand Down
42 changes: 42 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,33 @@ int main(int argc, char* argv[]) {
static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
});

// Handler for collecting local bloom filter bits from data nodes.
// Coordinator calls this after Phase 1 to aggregate bloom filters.
rpc_server->set_handler(
    cloudsql::network::RpcType::BloomFilterBits,
    [&](const cloudsql::network::RpcHeader& h, const std::vector<uint8_t>& p,
        int fd) {
      (void)h;
      // The request carries the context_id; reply with whatever this node
      // stashed during Phase 1 (empty filter_data if nothing was stored).
      auto args = cloudsql::network::BloomFilterBitsArgs::deserialize(p);
      cloudsql::network::BloomFilterBitsArgs reply_args;
      reply_args.context_id = args.context_id;
      reply_args.filter_data =
          cluster_manager->get_local_bloom_bits(args.context_id);
      // NOTE(review): expected_elements/num_hashes are global on the
      // ClusterManager rather than keyed by context_id — concurrent join
      // contexts would interleave metadata. TODO confirm single-context use.
      reply_args.expected_elements =
          cluster_manager->get_local_expected_elements();
      reply_args.num_hashes = cluster_manager->get_local_num_hashes();

      auto resp_p = reply_args.serialize();
      cloudsql::network::RpcHeader resp_h;
      // Reply type mirrors the other handlers (generic QueryResults).
      resp_h.type = cloudsql::network::RpcType::QueryResults;
      // NOTE(review): payload_len is uint16_t — a bloom filter payload over
      // 64 KiB truncates here; verify filter sizing upstream.
      resp_h.payload_len = static_cast<uint16_t>(resp_p.size());
      char h_buf[cloudsql::network::RpcHeader::HEADER_SIZE];
      resp_h.encode(h_buf);
      // Best-effort sends; return values intentionally discarded as in the
      // sibling handlers.
      static_cast<void>(
          send(fd, h_buf, cloudsql::network::RpcHeader::HEADER_SIZE, 0));
      static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
    });
Comment thread
coderabbitai[bot] marked this conversation as resolved.

rpc_server->set_handler(
cloudsql::network::RpcType::ShuffleFragment,
[&](const cloudsql::network::RpcHeader& h, const std::vector<uint8_t>& p,
Expand Down Expand Up @@ -556,11 +583,19 @@ int main(int argc, char* argv[]) {
partitions[node.id] = {};
}

// Estimate expected elements for bloom filter
// For now, estimate based on table size (will be refined with actual
// count)
size_t estimated_count = 1000;
cloudsql::common::BloomFilter local_bloom(estimated_count);

auto iter = table.scan();
cloudsql::storage::HeapTable::TupleMeta t_meta;
while (iter.next_meta(t_meta)) {
if (t_meta.xmax == 0) { // Visible
const auto& key_val = t_meta.tuple.get(key_idx);
// Build bloom filter from join key values
local_bloom.insert(key_val);
uint32_t node_idx =
cloudsql::cluster::ShardManager::compute_shard(
key_val, static_cast<uint32_t>(data_nodes.size()));
Expand All @@ -569,6 +604,13 @@ int main(int argc, char* argv[]) {
}
}

// Store local bloom filter bits for coordinator to collect
// The coordinator will aggregate these during Phase 1
auto bloom_bits = local_bloom.serialize();
cluster_manager->set_local_bloom_bits(args.context_id, bloom_bits,
local_bloom.expected_elements(),
local_bloom.num_hashes());

bool overall_success = true;
std::string delivery_errors;

Expand Down
22 changes: 22 additions & 0 deletions tests/distributed_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,36 @@ TEST(DistributedExecutorTests, ShuffleJoinOrchestration) {
static_cast<void>(send(fd, resp_p.data(), resp_p.size(), 0));
};

// Mock BloomFilterBits responder: echoes the context id back with an empty
// filter so the coordinator's aggregation path runs without real bloom data.
auto bloom_bits_handler = [&](const RpcHeader& h, const std::vector<uint8_t>& p, int fd) {
    static_cast<void>(h);
    const auto request = BloomFilterBitsArgs::deserialize(p);

    BloomFilterBitsArgs reply;
    reply.context_id = request.context_id;
    reply.filter_data.clear();    // mock returns no bits
    reply.expected_elements = 0;
    reply.num_hashes = 4;

    const auto payload = reply.serialize();
    RpcHeader header;
    header.type = RpcType::QueryResults;
    header.payload_len = static_cast<uint16_t>(payload.size());
    char header_buf[RpcHeader::HEADER_SIZE];
    header.encode(header_buf);
    static_cast<void>(send(fd, header_buf, RpcHeader::HEADER_SIZE, 0));
    static_cast<void>(send(fd, payload.data(), payload.size(), 0));
};

node1.set_handler(RpcType::ShuffleFragment, handler);
node1.set_handler(RpcType::PushData, handler);
node1.set_handler(RpcType::ExecuteFragment, handler);
node1.set_handler(RpcType::BloomFilterPush, handler);
node1.set_handler(RpcType::BloomFilterBits, bloom_bits_handler);
node2.set_handler(RpcType::ShuffleFragment, handler);
node2.set_handler(RpcType::PushData, handler);
node2.set_handler(RpcType::ExecuteFragment, handler);
node2.set_handler(RpcType::BloomFilterPush, handler);
node2.set_handler(RpcType::BloomFilterBits, bloom_bits_handler);

ASSERT_TRUE(node1.start());
ASSERT_TRUE(node2.start());
Expand Down