Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/network/rpc_message.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ enum class RpcType : uint8_t {
UnmatchedRowsPush = 14, // Coordinator sends unmatched rows for NULL-padding
FetchUnmatchedRows = 15, // Coordinator fetches stored unmatched rows from data node
// LEFT-side counterparts for FULL join
UnmatchedLeftRowsReport = 16, // Data node reports unmatched LEFT rows for FULL join
FetchUnmatchedLeftRows = 17, // Coordinator fetches stored unmatched LEFT rows
UnmatchedLeftRowsReport = 16, // Data node reports unmatched LEFT rows for FULL join
FetchUnmatchedLeftRows = 17, // Coordinator fetches stored unmatched LEFT rows
Error = 255
};

Expand Down Expand Up @@ -710,9 +710,9 @@ struct FetchUnmatchedRowsArgs {
struct UnmatchedLeftRowsReportArgs {
std::string context_id;
std::string left_table;
std::string join_key_col; // Which column was the join key
std::vector<std::string> unmatched_keys; // LEFT key values that had no match
uint32_t right_column_count = 0; // Number of right columns for NULL-padding
std::string join_key_col; // Which column was the join key
std::vector<std::string> unmatched_keys; // LEFT key values that had no match
uint32_t right_column_count = 0; // Number of right columns for NULL-padding

[[nodiscard]] std::vector<uint8_t> serialize() const {
std::vector<uint8_t> out;
Expand Down
116 changes: 107 additions & 9 deletions src/distributed/distributed_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
bool is_outer_join_join_query = false;
std::string outer_join_left_table;
std::string outer_join_right_table;
std::string outer_join_left_key;
std::string outer_join_right_key;
parser::SelectStatement::JoinType outer_join_type = parser::SelectStatement::JoinType::Inner;

Expand Down Expand Up @@ -234,6 +235,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
is_outer_join_join_query = true;
outer_join_left_table = left_table;
outer_join_right_table = right_table;
outer_join_left_key = left_key;
outer_join_right_key = right_key;
outer_join_type = join.type;
}
Expand Down Expand Up @@ -611,15 +613,15 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
}
}

// Phase 3-5: Currently disabled for all outer joins due to issues with column indexing
// when SELECT doesn't use SELECT * (causes duplicate rows instead of correct results).
// Phase 3-5: For FULL JOIN, collect unmatched LEFT rows from data nodes
// LEFT rows are emitted during probe phase when no match found, but we need to
// COLLECT them from all data nodes for the coordinator's final result.
//
// For RIGHT JOIN: Local executor on each data node correctly handles unmatched right rows.
// For FULL JOIN: Unmatched LEFT rows are not collected (to be implemented in separate PR).
// For RIGHT JOIN: Local executor on each data node correctly handles unmatched right rows
// (no collection needed - each node emits them locally).
//
// TODO: Re-enable Phase 3-5 for FULL JOIN once column indexing is fixed to properly
// identify which rows were unmatched during the distributed join.
if (false && is_outer_join_join_query && all_success) {
// This block is only enabled for FULL JOIN.
if (outer_join_type == parser::SelectStatement::JoinType::Full && all_success) {
// Extract matched right keys from aggregated results
// The right key column is at a known position in the result schema
std::vector<std::string> matched_keys;
Expand All @@ -643,7 +645,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
}
}

// Phase 3: Ask each node to scan local table and store unmatched rows
// Phase 3: Ask each node to scan local right table and store unmatched rows
// First, compute the left column count for NULL-padding
uint32_t left_column_count = 0;
if (!outer_join_left_table.empty()) {
Expand Down Expand Up @@ -713,7 +715,7 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
}));
}

// Aggregate all unmatched rows from all nodes
// Aggregate all unmatched RIGHT rows from all nodes
for (auto& f : fetch_futures) {
auto result = f.get();
if (result.first) {
Expand All @@ -722,6 +724,102 @@ QueryResult DistributedExecutor::execute(const parser::Statement& stmt,
}
}
}

// === LEFT-side Phase 3-4 for FULL JOIN ===
// Extract matched LEFT keys from aggregated results
std::vector<std::string> matched_left_keys;
size_t left_key_idx = static_cast<size_t>(-1);
for (size_t i = 0; i < result_schema.columns().size(); ++i) {
const auto& col = result_schema.columns()[i];
if (col.name() == outer_join_left_key) {
left_key_idx = i;
break;
}
}
if (left_key_idx != static_cast<size_t>(-1)) {
for (const auto& row : aggregated_rows) {
if (row.size() > left_key_idx) {
matched_left_keys.push_back(row.get(left_key_idx).to_string());
}
}
}

// LEFT-side Phase 3: Ask each node to scan local left table and store unmatched rows
uint32_t right_column_count = 0;
if (!outer_join_right_table.empty()) {
auto right_table_info = catalog_.get_table_by_name(outer_join_right_table);
if (right_table_info.has_value()) {
right_column_count = static_cast<uint32_t>((*right_table_info)->columns.size());
}
}

std::vector<std::future<std::pair<bool, network::UnmatchedLeftRowsReportArgs>>>
left_report_futures;

for (const auto& node : data_nodes) {
left_report_futures.push_back(std::async(
std::launch::async, [node, context_id, outer_join_left_table, outer_join_left_key,
matched_left_keys, right_column_count]() {
network::RpcClient client(node.address, node.cluster_port);
network::UnmatchedLeftRowsReportArgs reply;
if (client.connect()) {
network::UnmatchedLeftRowsReportArgs report_args;
report_args.context_id = context_id;
report_args.left_table = outer_join_left_table;
report_args.join_key_col = outer_join_left_key;
report_args.unmatched_keys = matched_left_keys;
report_args.right_column_count = right_column_count;

std::vector<uint8_t> resp;
if (client.call(network::RpcType::UnmatchedLeftRowsReport,
report_args.serialize(), resp)) {
reply = network::UnmatchedLeftRowsReportArgs::deserialize(resp);
return std::make_pair(true, reply);
}
}
return std::make_pair(false, reply);
}));
}

// Wait for all LEFT report futures to complete
for (auto& f : left_report_futures) {
f.get();
}

// LEFT-side Phase 4: Fetch stored unmatched LEFT rows from each node
std::vector<std::future<std::pair<bool, std::vector<executor::Tuple>>>> left_fetch_futures;

for (const auto& node : data_nodes) {
left_fetch_futures.push_back(
std::async(std::launch::async, [node, context_id, outer_join_left_table]() {
network::RpcClient client(node.address, node.cluster_port);
std::vector<executor::Tuple> rows;
if (client.connect()) {
network::FetchUnmatchedLeftRowsArgs fetch_args;
fetch_args.context_id = context_id;
fetch_args.table_name = outer_join_left_table;

std::vector<uint8_t> resp;
if (client.call(network::RpcType::FetchUnmatchedLeftRows,
fetch_args.serialize(), resp)) {
auto reply = network::UnmatchedRowsPushArgs::deserialize(resp);
rows = std::move(reply.unmatched_rows);
return std::make_pair(true, std::move(rows));
}
}
return std::make_pair(false, std::move(rows));
}));
}

// Aggregate all unmatched LEFT rows from all nodes
for (auto& f : left_fetch_futures) {
auto result = f.get();
if (result.first) {
for (auto& row : result.second) {
aggregated_rows.push_back(std::move(row));
}
}
}
}

if (all_success) {
Expand Down