Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/pymatching/sparse_blossom/driver/user_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,16 +385,50 @@ double pm::UserGraph::get_edge_weight_normalising_constant(size_t max_num_distin
}
}

namespace {

double bernoulli_xor(double p1, double p2) {
return p1 * (1 - p2) + p2 * (1 - p1);
}

} // namespace

void pm::add_decomposed_error_to_joint_probabilities(
DecomposedDemError& error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
if (error.components.size() > 1) {
for (size_t k0 = 0; k0 < error.components.size(); k0++) {
for (size_t k1 = k0 + 1; k1 < error.components.size(); k1++) {
auto& c0 = error.components[k0];
auto& c1 = error.components[k1];
std::pair<size_t, size_t> e0 = std::minmax(c0.node1, c0.node2);
std::pair<size_t, size_t> e1 = std::minmax(c1.node1, c1.node2);
double& p01 = joint_probabilites[e0][e1];
double& p10 = joint_probabilites[e1][e0];
p01 = bernoulli_xor(p01, error.probability);
p10 = bernoulli_xor(p10, error.probability);
}
}
}

for (auto& e : error.components) {
double& p = joint_probabilites[std::minmax(e.node1, e.node2)][std::minmax(e.node1, e.node2)];
p = bernoulli_xor(p, error.probability);
}
};

pm::UserGraph pm::detector_error_model_to_user_graph(
const stim::DetectorErrorModel& detector_error_model, const bool enable_correlations) {
pm::UserGraph user_graph(detector_error_model.count_detectors(), detector_error_model.count_observables());
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilites;
if (enable_correlations) {
// TODO: Support correlated matching.
pm::iter_dem_instructions_include_correlations(
detector_error_model,
[&](double p, const std::vector<size_t>& detectors, std::vector<size_t>& observables) {
return;
});
},
joint_probabilites);
// TODO: Support correlated matching. Add implied edge weights to the User Graph here.
} else {
pm::iter_detector_error_model_edges(
detector_error_model,
Expand Down
23 changes: 19 additions & 4 deletions src/pymatching/sparse_blossom/driver/user_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,15 @@ struct DecomposedDemError {
bool operator!=(const DecomposedDemError& other) const;
};

// TODO: Capture information about correlations.
void add_decomposed_error_to_joint_probabilities(
DecomposedDemError& error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites);

template <typename Handler>
void iter_dem_instructions_include_correlations(
const stim::DetectorErrorModel& detector_error_model, const Handler& handle_dem_error) {
const stim::DetectorErrorModel& detector_error_model,
const Handler& handle_dem_error,
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>>& joint_probabilites) {
detector_error_model.iter_flatten_error_instructions([&](const stim::DemInstruction& instruction) {
double p = instruction.arg_data[0];
pm::DecomposedDemError decomposed_err;
Expand All @@ -251,6 +256,10 @@ void iter_dem_instructions_include_correlations(
throw ::std::invalid_argument(
"Errors with probability greater than 0.5 are not supported with correlations enabled");
}
if (p == 0) {
// Ignore errors with no error probability.
return;
}
decomposed_err.components = {};
decomposed_err.components.push_back({});
UserEdge* component = &decomposed_err.components.back();
Expand All @@ -266,7 +275,13 @@ void iter_dem_instructions_include_correlations(
const size_t& d1 = target.raw_id();
component->node1 = d1;
} else if (num_component_detectors == 2) {
component->node2 = target.raw_id();
// Maintain invariant that node1 <= node2.
if (component->node1 <= target.raw_id()) {
component->node2 = target.raw_id();
} else {
component->node2 = component->node1;
component->node1 = target.raw_id();
}
} else {
// We mark errors which have 3 or more detectors as a special boundary-to-boundary edge.
component->node1 = SIZE_MAX;
Expand Down Expand Up @@ -297,7 +312,7 @@ void iter_dem_instructions_include_correlations(
handle_dem_error(p, {component->node1, component->node2}, component->observable_indices);
}

// TODO: Capture information from decomposed_error into correlation data structure here.
add_decomposed_error_to_joint_probabilities(decomposed_err, joint_probabilites);
});
}

Expand Down
168 changes: 149 additions & 19 deletions src/pymatching/sparse_blossom/driver/user_graph.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,81 +215,143 @@ struct TestHandler {
TEST(IterDemInstructionsTest, EmptyDem) {
stim::DetectorErrorModel dem;
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);
ASSERT_TRUE(handler.handled_errors.empty());
ASSERT_TRUE(joint_probabilities.empty());
}

// Test a simple error involving one detector, which is an error on the boundary.
TEST(IterDemInstructionsTest, SingleDetectorErrorToBoundary) {
stim::DetectorErrorModel dem("error(0.1) D0");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

// Check handler calls
ASSERT_EQ(handler.handled_errors.size(), 1);
EXPECT_EQ(handler.handled_errors[0], (HandledError{0.1, 0, SIZE_MAX, {}}));

// Check joint probabilities (marginal probability in this case)
std::pair<size_t, size_t> key = {0, SIZE_MAX};
ASSERT_EQ(joint_probabilities.size(), 1);
ASSERT_EQ(joint_probabilities[key].size(), 1);
EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.1);
}

// Test a standard error between two detectors.
TEST(IterDemInstructionsTest, TwoDetectorError) {
stim::DetectorErrorModel dem("error(0.25) D5 D10");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 1);
EXPECT_EQ(handler.handled_errors[0], (HandledError{0.25, 5, 10, {}}));

std::pair<size_t, size_t> key = {5, 10};
EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.25);
}

// Test a standard error between two detectors where they are not sorted in the DEM.
TEST(IterDemInstructionsTest, TwoDetectorErrorNotSorted) {
stim::DetectorErrorModel dem("error(0.25) D10 D5");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 1);
EXPECT_EQ(handler.handled_errors[0], (HandledError{0.25, 5, 10, {}}));

std::pair<size_t, size_t> key = {5, 10};
EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.25);
}

// Test an error that also flips a logical observable.
TEST(IterDemInstructionsTest, ErrorWithOneObservable) {
stim::DetectorErrorModel dem("error(0.125) D1 D2 L0");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 1);
EXPECT_EQ(handler.handled_errors[0], (HandledError{0.125, 1, 2, {0}}));

std::pair<size_t, size_t> key = {1, 2};
EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.125);
}

// Test an error that flips multiple logical observables.
TEST(IterDemInstructionsTest, ErrorWithMultipleObservables) {
stim::DetectorErrorModel dem("error(0.3) D3 D4 L1 L3");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 1);
EXPECT_EQ(handler.handled_errors[0], (HandledError{0.3, 3, 4, {1, 3}}));

std::pair<size_t, size_t> key = {3, 4};
EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.3);
}

// Test an error with probability 0. It should be ignored.
TEST(IterDemInstructionsTest, ZeroProbabilityError) {
stim::DetectorErrorModel dem("error(0.0) D0 D1");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);
ASSERT_TRUE(handler.handled_errors.empty());
ASSERT_TRUE(joint_probabilities.empty());
}

// Test an error involving more than two detectors (a hyperedge). This should be ignored.
TEST(IterDemInstructionsTest, ThreeDetectorErrorIsIgnored) {
stim::DetectorErrorModel dem("error(0.1) D0 D1 D2");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);
ASSERT_TRUE(handler.handled_errors.empty());
ASSERT_TRUE(joint_probabilities.empty());
}

// Test a decomposed error instruction. The handler should be called for each component.
TEST(IterDemInstructionsTest, DecomposedError) {
stim::DetectorErrorModel dem("error(0.1) D0 D1 ^ D2 D3 L0 ^ D4");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
ASSERT_EQ(handler.handled_errors.size(), 3);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

EXPECT_EQ(handler.handled_errors[0], (HandledError{0.1, 0, 1, {}}));
EXPECT_EQ(handler.handled_errors[1], (HandledError{0.1, 2, 3, {0}}));
EXPECT_EQ(handler.handled_errors[2], (HandledError{0.1, 4, SIZE_MAX, {}}));
// Check handler calls
ASSERT_EQ(handler.handled_errors.size(), 3);
std::vector<HandledError> expected_handled = {{0.1, 0, 1, {}}, {0.1, 2, 3, {0}}, {0.1, 4, SIZE_MAX, {}}};
EXPECT_EQ(handler.handled_errors, expected_handled);

// Check joint probabilities
std::pair<size_t, size_t> key01 = {0, 1};
std::pair<size_t, size_t> key23 = {2, 3};
std::pair<size_t, size_t> key4B = {4, SIZE_MAX};

// Marginal probabilities
EXPECT_DOUBLE_EQ(joint_probabilities[key01][key01], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[key23][key23], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[key4B][key4B], 0.1);

// Joint probabilities between components
EXPECT_DOUBLE_EQ(joint_probabilities[key01][key23], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[key23][key01], 0.1); // Symmetric
EXPECT_DOUBLE_EQ(joint_probabilities[key01][key4B], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[key4B][key01], 0.1); // Symmetric
EXPECT_DOUBLE_EQ(joint_probabilities[key23][key4B], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[key4B][key23], 0.1); // Symmetric
}

// Test that a decomposed error with a hyperedge component throws an exception.
TEST(IterDemInstructionsTest, DecomposedErrorWithHyperedgeThrows) {
stim::DetectorErrorModel dem("error(0.15) D0 D1 ^ D2 D3 D4 ^ D5 D6 L2");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;

// Assert that the function throws std::invalid_argument when processing the DEM.
ASSERT_THROW(pm::iter_dem_instructions_include_correlations(dem, handler), std::invalid_argument);
ASSERT_THROW(
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument);
}

// Test a complex DEM with multiple instruction types and edge cases combined.
Expand All @@ -302,19 +364,87 @@ TEST(IterDemInstructionsTest, CombinedComplexDem) {
error(0.4) D8 ^ D9 L1 # Instruction 5: Decomposed
)DEM");
TestHandler handler;
pm::iter_dem_instructions_include_correlations(dem, handler);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

ASSERT_EQ(handler.handled_errors.size(), 4);

std::vector<HandledError> expected = {
{0.1, 0, SIZE_MAX, {}}, {0.2, 1, 2, {0}}, {0.4, 8, SIZE_MAX, {}}, {0.4, 9, SIZE_MAX, {1}}};

EXPECT_EQ(handler.handled_errors, expected);

// Check joint probabilities
std::pair<size_t, size_t> key0B = {0, SIZE_MAX};
std::pair<size_t, size_t> key12 = {1, 2};
std::pair<size_t, size_t> key8B = {8, SIZE_MAX};
std::pair<size_t, size_t> key9B = {9, SIZE_MAX};

// Marginal probabilities from each instruction
EXPECT_DOUBLE_EQ(joint_probabilities[key0B][key0B], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[key12][key12], 0.2);
EXPECT_DOUBLE_EQ(joint_probabilities[key8B][key8B], 0.4);
EXPECT_DOUBLE_EQ(joint_probabilities[key9B][key9B], 0.4);

// Joint probability from the last instruction
EXPECT_DOUBLE_EQ(joint_probabilities[key8B][key9B], 0.4);
EXPECT_DOUBLE_EQ(joint_probabilities[key9B][key8B], 0.4);

// Check that there are no other joint probabilities
EXPECT_EQ(joint_probabilities[key0B].count(key12), 0);
}

double bernoulli_xor(double p1, double p2) {
return p1 * (1 - p2) + p2 * (1 - p1);
}

// Tests that multiple error instructions on the same edge correctly combine their probabilities.
TEST(IterDemInstructionsTest, MultipleErrorsOnSameEdgeCombine) {
stim::DetectorErrorModel dem("error(0.1) D0 D1\n error(0.2) D0 D1");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

// Expected probability = 0.1*(1-0.2) + 0.2*(1-0.1) = 0.08 + 0.18 = 0.26
double expected_p = bernoulli_xor(0.1, 0.2);

std::pair<size_t, size_t> key = {0, 1};
EXPECT_DOUBLE_EQ(joint_probabilities[key][key], expected_p);
}

// Tests how marginal and joint probabilities are combined across different decomposed error instructions.
TEST(IterDemInstructionsTest, ComplexCombinationOfErrors) {
stim::DetectorErrorModel dem(R"DEM(
error(0.1) D0 ^ D1
error(0.2) D0 ^ D2
)DEM");
TestHandler handler;
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities);

std::pair<size_t, size_t> k0 = {0, SIZE_MAX};
std::pair<size_t, size_t> k1 = {1, SIZE_MAX};
std::pair<size_t, size_t> k2 = {2, SIZE_MAX};

// Marginal probabilities
EXPECT_DOUBLE_EQ(joint_probabilities[k0][k0], bernoulli_xor(0.1, 0.2)); // 0.26
EXPECT_DOUBLE_EQ(joint_probabilities[k1][k1], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[k2][k2], 0.2);

// Joint probabilities
EXPECT_DOUBLE_EQ(joint_probabilities[k0][k1], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[k1][k0], 0.1);
EXPECT_DOUBLE_EQ(joint_probabilities[k0][k2], 0.2);
EXPECT_DOUBLE_EQ(joint_probabilities[k2][k0], 0.2);

// No instruction connects D1 and D2, so their joint probability should be 0.
EXPECT_DOUBLE_EQ(joint_probabilities[k1][k2], 0.0);
}

// Test that an error greater than 0.5 results in a throw.
TEST(IterDemInstructionsTest, ProbabilityGreaterThanHalfThrows) {
stim::DetectorErrorModel dem("error(0.51) D0 D2");
TestHandler handler;
ASSERT_THROW(pm::iter_dem_instructions_include_correlations(dem, handler), std::invalid_argument);
std::map<std::pair<size_t, size_t>, std::map<std::pair<size_t, size_t>, double>> joint_probabilities;
ASSERT_THROW(
pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument);
}
Loading