From 616350c5d5cb0f7ed50c8570a6277e80057ce106 Mon Sep 17 00:00:00 2001 From: Sid Madhuk Date: Tue, 8 Jul 2025 23:54:00 +0000 Subject: [PATCH 1/3] Track joint probabilities of correlated errors --- .../sparse_blossom/driver/user_graph.cc | 37 ++++- .../sparse_blossom/driver/user_graph.h | 15 +- .../sparse_blossom/driver/user_graph.test.cc | 154 +++++++++++++++--- 3 files changed, 182 insertions(+), 24 deletions(-) diff --git a/src/pymatching/sparse_blossom/driver/user_graph.cc b/src/pymatching/sparse_blossom/driver/user_graph.cc index b14713984..766297fe0 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.cc @@ -385,16 +385,48 @@ double pm::UserGraph::get_edge_weight_normalising_constant(size_t max_num_distin } } +namespace { + +double bernoulli_xor(double p1, double p2) { + return p1 * (1 - p2) + p2 * (1 - p1); +} + +} // namespace + +void pm::add_decomposed_error_to_joint_probabilities( + DecomposedDemError& error, + std::map, std::map, double>>& joint_probabilites) { + if (error.components.size() > 1) { + for (size_t k0 = 0; k0 < error.components.size(); k0++) { + for (size_t k1 = k0 + 1; k1 < error.components.size(); k1++) { + auto& c0 = error.components[k0]; + auto& c1 = error.components[k1]; + double& p01 = joint_probabilites[{c0.node1, c0.node2}][{c1.node1, c1.node2}]; + double& p10 = joint_probabilites[{c1.node1, c1.node2}][{c0.node1, c0.node2}]; + p01 = bernoulli_xor(p01, error.probability); + p10 = bernoulli_xor(p10, error.probability); + } + } + } + + for (auto& e : error.components) { + double& p = joint_probabilites[{e.node1, e.node2}][{e.node1, e.node2}]; + p = bernoulli_xor(p, error.probability); + } +}; + pm::UserGraph pm::detector_error_model_to_user_graph( const stim::DetectorErrorModel& detector_error_model, const bool enable_correlations) { pm::UserGraph user_graph(detector_error_model.count_detectors(), detector_error_model.count_observables()); + std::map, std::map, double>> joint_probabilites; if (enable_correlations) { - // TODO: Support correlated matching. pm::iter_dem_instructions_include_correlations( detector_error_model, [&](double p, const std::vector& detectors, std::vector& observables) { return; - }); + }, + joint_probabilites); + // TODO: Support correlated matching. Add implied edge weights to the User Graph here. } else { pm::iter_detector_error_model_edges( detector_error_model, @@ -402,5 +434,6 @@ pm::UserGraph pm::detector_error_model_to_user_graph( user_graph.handle_dem_instruction(p, detectors, observables); }); } + return user_graph; } diff --git a/src/pymatching/sparse_blossom/driver/user_graph.h b/src/pymatching/sparse_blossom/driver/user_graph.h index 511dffc03..8cd4b3fdc 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.h +++ b/src/pymatching/sparse_blossom/driver/user_graph.h @@ -239,10 +239,15 @@ struct DecomposedDemError { bool operator!=(const DecomposedDemError& other) const; }; -// TODO: Capture information about correlations. +void add_decomposed_error_to_joint_probabilities( + DecomposedDemError& error, + std::map, std::map, double>>& joint_probabilites); + template void iter_dem_instructions_include_correlations( - const stim::DetectorErrorModel& detector_error_model, const Handler& handle_dem_error) { + const stim::DetectorErrorModel& detector_error_model, + const Handler& handle_dem_error, + std::map, std::map, double>>& joint_probabilites) { detector_error_model.iter_flatten_error_instructions([&](const stim::DemInstruction& instruction) { double p = instruction.arg_data[0]; pm::DecomposedDemError decomposed_err; @@ -251,6 +256,10 @@ void iter_dem_instructions_include_correlations( throw ::std::invalid_argument( "Errors with probability greater than 0.5 are not supported with correlations enabled"); } + if (p == 0) { + // Ignore errors with no error probability. + return; + } decomposed_err.components = {}; decomposed_err.components.push_back({}); UserEdge* component = &decomposed_err.components.back(); @@ -297,7 +306,7 @@ void iter_dem_instructions_include_correlations( handle_dem_error(p, {component->node1, component->node2}, component->observable_indices); } - // TODO: Capture information from decomposed_error into correlation data structure here. + add_decomposed_error_to_joint_probabilities(decomposed_err, joint_probabilites); }); } diff --git a/src/pymatching/sparse_blossom/driver/user_graph.test.cc b/src/pymatching/sparse_blossom/driver/user_graph.test.cc index a3043c265..5ad191272 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.test.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.test.cc @@ -215,81 +215,129 @@ struct TestHandler { TEST(IterDemInstructionsTest, EmptyDem) { stim::DetectorErrorModel dem; TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); ASSERT_TRUE(handler.handled_errors.empty()); + ASSERT_TRUE(joint_probabilities.empty()); } // Test a simple error involving one detector, which is an error on the boundary. TEST(IterDemInstructionsTest, SingleDetectorErrorToBoundary) { stim::DetectorErrorModel dem("error(0.1) D0"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + + // Check handler calls ASSERT_EQ(handler.handled_errors.size(), 1); EXPECT_EQ(handler.handled_errors[0], (HandledError{0.1, 0, SIZE_MAX, {}})); + + // Check joint probabilities (marginal probability in this case) + std::pair key = {0, SIZE_MAX}; + ASSERT_EQ(joint_probabilities.size(), 1); + ASSERT_EQ(joint_probabilities[key].size(), 1); + EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.1); } // Test a standard error between two detectors. TEST(IterDemInstructionsTest, TwoDetectorError) { stim::DetectorErrorModel dem("error(0.25) D5 D10"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + ASSERT_EQ(handler.handled_errors.size(), 1); EXPECT_EQ(handler.handled_errors[0], (HandledError{0.25, 5, 10, {}})); + + std::pair key = {5, 10}; + EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.25); } // Test an error that also flips a logical observable. TEST(IterDemInstructionsTest, ErrorWithOneObservable) { stim::DetectorErrorModel dem("error(0.125) D1 D2 L0"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + ASSERT_EQ(handler.handled_errors.size(), 1); EXPECT_EQ(handler.handled_errors[0], (HandledError{0.125, 1, 2, {0}})); + + std::pair key = {1, 2}; + EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.125); } -// Test an error that flips multiple logical observables. TEST(IterDemInstructionsTest, ErrorWithMultipleObservables) { stim::DetectorErrorModel dem("error(0.3) D3 D4 L1 L3"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + ASSERT_EQ(handler.handled_errors.size(), 1); EXPECT_EQ(handler.handled_errors[0], (HandledError{0.3, 3, 4, {1, 3}})); + + std::pair key = {3, 4}; + EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.3); } -// Test an error with probability 0. It should be ignored. TEST(IterDemInstructionsTest, ZeroProbabilityError) { stim::DetectorErrorModel dem("error(0.0) D0 D1"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); ASSERT_TRUE(handler.handled_errors.empty()); + ASSERT_TRUE(joint_probabilities.empty()); } -// Test an error involving more than two detectors (a hyperedge). This should be ignored. TEST(IterDemInstructionsTest, ThreeDetectorErrorIsIgnored) { stim::DetectorErrorModel dem("error(0.1) D0 D1 D2"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); ASSERT_TRUE(handler.handled_errors.empty()); + ASSERT_TRUE(joint_probabilities.empty()); } // Test a decomposed error instruction. The handler should be called for each component. TEST(IterDemInstructionsTest, DecomposedError) { stim::DetectorErrorModel dem("error(0.1) D0 D1 ^ D2 D3 L0 ^ D4"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); - ASSERT_EQ(handler.handled_errors.size(), 3); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); - EXPECT_EQ(handler.handled_errors[0], (HandledError{0.1, 0, 1, {}})); - EXPECT_EQ(handler.handled_errors[1], (HandledError{0.1, 2, 3, {0}})); - EXPECT_EQ(handler.handled_errors[2], (HandledError{0.1, 4, SIZE_MAX, {}})); + // Check handler calls + ASSERT_EQ(handler.handled_errors.size(), 3); + std::vector expected_handled = {{0.1, 0, 1, {}}, {0.1, 2, 3, {0}}, {0.1, 4, SIZE_MAX, {}}}; + EXPECT_EQ(handler.handled_errors, expected_handled); + + // Check joint probabilities + std::pair key01 = {0, 1}; + std::pair key23 = {2, 3}; + std::pair key4B = {4, SIZE_MAX}; + + // Marginal probabilities + EXPECT_DOUBLE_EQ(joint_probabilities[key01][key01], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[key23][key23], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[key4B][key4B], 0.1); + + // Joint probabilities between components + EXPECT_DOUBLE_EQ(joint_probabilities[key01][key23], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[key23][key01], 0.1); // Symmetric + EXPECT_DOUBLE_EQ(joint_probabilities[key01][key4B], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[key4B][key01], 0.1); // Symmetric + EXPECT_DOUBLE_EQ(joint_probabilities[key23][key4B], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[key4B][key23], 0.1); // Symmetric } // Test that a decomposed error with a hyperedge component throws an exception. TEST(IterDemInstructionsTest, DecomposedErrorWithHyperedgeThrows) { stim::DetectorErrorModel dem("error(0.15) D0 D1 ^ D2 D3 D4 ^ D5 D6 L2"); TestHandler handler; + std::map, std::map, double>> joint_probabilities; // Assert that the function throws std::invalid_argument when processing the DEM. - ASSERT_THROW(pm::iter_dem_instructions_include_correlations(dem, handler), std::invalid_argument); + ASSERT_THROW( + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument); } // Test a complex DEM with multiple instruction types and edge cases combined. @@ -302,19 +350,87 @@ TEST(IterDemInstructionsTest, CombinedComplexDem) { error(0.4) D8 ^ D9 L1 # Instruction 5: Decomposed )DEM"); TestHandler handler; - pm::iter_dem_instructions_include_correlations(dem, handler); + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); ASSERT_EQ(handler.handled_errors.size(), 4); std::vector expected = { {0.1, 0, SIZE_MAX, {}}, {0.2, 1, 2, {0}}, {0.4, 8, SIZE_MAX, {}}, {0.4, 9, SIZE_MAX, {1}}}; - EXPECT_EQ(handler.handled_errors, expected); + + // Check joint probabilities + std::pair key0B = {0, SIZE_MAX}; + std::pair key12 = {1, 2}; + std::pair key8B = {8, SIZE_MAX}; + std::pair key9B = {9, SIZE_MAX}; + + // Marginal probabilities from each instruction + EXPECT_DOUBLE_EQ(joint_probabilities[key0B][key0B], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[key12][key12], 0.2); + EXPECT_DOUBLE_EQ(joint_probabilities[key8B][key8B], 0.4); + EXPECT_DOUBLE_EQ(joint_probabilities[key9B][key9B], 0.4); + + // Joint probability from the last instruction + EXPECT_DOUBLE_EQ(joint_probabilities[key8B][key9B], 0.4); + EXPECT_DOUBLE_EQ(joint_probabilities[key9B][key8B], 0.4); + + // Check that there are no other joint probabilities + EXPECT_EQ(joint_probabilities[key0B].count(key12), 0); +} + +double bernoulli_xor(double p1, double p2) { + return p1 * (1 - p2) + p2 * (1 - p1); +} + +// Tests that multiple error instructions on the same edge correctly combine their probabilities. +TEST(IterDemInstructionsTest, MultipleErrorsOnSameEdgeCombine) { + stim::DetectorErrorModel dem("error(0.1) D0 D1\n error(0.2) D0 D1"); + TestHandler handler; + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + + // Expected probability = 0.1*(1-0.2) + 0.2*(1-0.1) = 0.08 + 0.18 = 0.26 + double expected_p = bernoulli_xor(0.1, 0.2); + + std::pair key = {0, 1}; + EXPECT_DOUBLE_EQ(joint_probabilities[key][key], expected_p); +} + +// Tests how marginal and joint probabilities are combined across different decomposed error instructions. +TEST(IterDemInstructionsTest, ComplexCombinationOfErrors) { + stim::DetectorErrorModel dem(R"DEM( + error(0.1) D0 ^ D1 + error(0.2) D0 ^ D2 + )DEM"); + TestHandler handler; + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + + std::pair k0 = {0, SIZE_MAX}; + std::pair k1 = {1, SIZE_MAX}; + std::pair k2 = {2, SIZE_MAX}; + + // Marginal probabilities + EXPECT_DOUBLE_EQ(joint_probabilities[k0][k0], bernoulli_xor(0.1, 0.2)); // 0.26 + EXPECT_DOUBLE_EQ(joint_probabilities[k1][k1], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[k2][k2], 0.2); + + // Joint probabilities + EXPECT_DOUBLE_EQ(joint_probabilities[k0][k1], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[k1][k0], 0.1); + EXPECT_DOUBLE_EQ(joint_probabilities[k0][k2], 0.2); + EXPECT_DOUBLE_EQ(joint_probabilities[k2][k0], 0.2); + + // No instruction connects D1 and D2, so their joint probability should be 0. + EXPECT_DOUBLE_EQ(joint_probabilities[k1][k2], 0.0); } // Test that an error greater than 0.5 results in a throw. TEST(IterDemInstructionsTest, ProbabilityGreaterThanHalfThrows) { stim::DetectorErrorModel dem("error(0.51) D0 D2"); TestHandler handler; - ASSERT_THROW(pm::iter_dem_instructions_include_correlations(dem, handler), std::invalid_argument); + std::map, std::map, double>> joint_probabilities; + ASSERT_THROW( + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities), std::invalid_argument); } From 28bea911801d9128d96e29727dce0fa9fb66e716 Mon Sep 17 00:00:00 2001 From: Sid Madhuk Date: Wed, 9 Jul 2025 18:49:54 +0000 Subject: [PATCH 2/3] clenaup: remove extra newline --- src/pymatching/sparse_blossom/driver/user_graph.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pymatching/sparse_blossom/driver/user_graph.cc b/src/pymatching/sparse_blossom/driver/user_graph.cc index 766297fe0..2b7f2312c 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.cc @@ -434,6 +434,5 @@ pm::UserGraph pm::detector_error_model_to_user_graph( user_graph.handle_dem_instruction(p, detectors, observables); }); } - return user_graph; } From 60bf4347ec5e706b663699afdd14a8cf63c210d4 Mon Sep 17 00:00:00 2001 From: Sid Madhuk Date: Wed, 9 Jul 2025 23:50:36 +0000 Subject: [PATCH 3/3] pr feedback: keep edges sorted in the user_graph and joint_probabilities --- src/pymatching/sparse_blossom/driver/user_graph.cc | 8 +++++--- src/pymatching/sparse_blossom/driver/user_graph.h | 8 +++++++- .../sparse_blossom/driver/user_graph.test.cc | 14 ++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/pymatching/sparse_blossom/driver/user_graph.cc b/src/pymatching/sparse_blossom/driver/user_graph.cc index 2b7f2312c..ac598548e 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.cc @@ -401,8 +401,10 @@ void pm::add_decomposed_error_to_joint_probabilities( for (size_t k1 = k0 + 1; k1 < error.components.size(); k1++) { auto& c0 = error.components[k0]; auto& c1 = error.components[k1]; - double& p01 = joint_probabilites[{c0.node1, c0.node2}][{c1.node1, c1.node2}]; - double& p10 = joint_probabilites[{c1.node1, c1.node2}][{c0.node1, c0.node2}]; + std::pair e0 = std::minmax(c0.node1, c0.node2); + std::pair e1 = std::minmax(c1.node1, c1.node2); + double& p01 = joint_probabilites[e0][e1]; + double& p10 = joint_probabilites[e1][e0]; p01 = bernoulli_xor(p01, error.probability); p10 = bernoulli_xor(p10, error.probability); } @@ -410,7 +412,7 @@ void pm::add_decomposed_error_to_joint_probabilities( } for (auto& e : error.components) { - double& p = joint_probabilites[{e.node1, e.node2}][{e.node1, e.node2}]; + double& p = joint_probabilites[std::minmax(e.node1, e.node2)][std::minmax(e.node1, e.node2)]; p = bernoulli_xor(p, error.probability); } }; diff --git a/src/pymatching/sparse_blossom/driver/user_graph.h b/src/pymatching/sparse_blossom/driver/user_graph.h index 8cd4b3fdc..e96d75e34 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.h +++ b/src/pymatching/sparse_blossom/driver/user_graph.h @@ -275,7 +275,13 @@ void iter_dem_instructions_include_correlations( const size_t& d1 = target.raw_id(); component->node1 = d1; } else if (num_component_detectors == 2) { - component->node2 = target.raw_id(); + // Maintain invariant that node1 <= node2. + if (component->node1 <= target.raw_id()) { + component->node2 = target.raw_id(); + } else { + component->node2 = component->node1; + component->node1 = target.raw_id(); + } } else { // We mark errors which have 3 or more detectors as a special boundary-to-boundary edge. component->node1 = SIZE_MAX; diff --git a/src/pymatching/sparse_blossom/driver/user_graph.test.cc b/src/pymatching/sparse_blossom/driver/user_graph.test.cc index 5ad191272..26b909e1a 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.test.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.test.cc @@ -253,6 +253,20 @@ TEST(IterDemInstructionsTest, TwoDetectorError) { EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.25); } +// Test a standard error between two detectors where they are not sorted in the DEM. +TEST(IterDemInstructionsTest, TwoDetectorErrorNotSorted) { + stim::DetectorErrorModel dem("error(0.25) D10 D5"); + TestHandler handler; + std::map, std::map, double>> joint_probabilities; + pm::iter_dem_instructions_include_correlations(dem, handler, joint_probabilities); + + ASSERT_EQ(handler.handled_errors.size(), 1); + EXPECT_EQ(handler.handled_errors[0], (HandledError{0.25, 5, 10, {}})); + + std::pair key = {5, 10}; + EXPECT_DOUBLE_EQ(joint_probabilities[key][key], 0.25); +} + // Test an error that also flips a logical observable. TEST(IterDemInstructionsTest, ErrorWithOneObservable) { stim::DetectorErrorModel dem("error(0.125) D1 D2 L0");