From 003507333223e0d33d977b9b643f17824601f8b4 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 24 Sep 2025 10:20:12 +0100 Subject: [PATCH 1/2] Longer sgmcmc tests, also check mean and cov --- tests/scenarios.py | 2 +- tests/sgmcmc/test_baoa.py | 2 +- tests/sgmcmc/test_sghmc.py | 2 +- tests/sgmcmc/test_sgld.py | 2 +- tests/sgmcmc/test_sgnht.py | 2 +- tests/sgmcmc/utils.py | 17 +++++++++++++---- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/scenarios.py b/tests/scenarios.py index f8e7cb8..717ab56 100644 --- a/tests/scenarios.py +++ b/tests/scenarios.py @@ -41,7 +41,7 @@ def get_multivariate_normal_log_prob( ) -> tuple[LogProbFn, tuple[torch.Tensor, torch.Tensor]]: mean = torch.randn(dim) sqrt_cov = torch.randn(dim, dim) - cov = sqrt_cov @ sqrt_cov.T + cov = sqrt_cov @ sqrt_cov.T * 0.3 + torch.eye(dim) * 0.7 chol_cov = torch.linalg.cholesky(cov) # Lower triangular with positive diagonal def log_prob(p, batch): diff --git a/tests/sgmcmc/test_baoa.py b/tests/sgmcmc/test_baoa.py index 76ba518..468f4f9 100644 --- a/tests/sgmcmc/test_baoa.py +++ b/tests/sgmcmc/test_baoa.py @@ -10,7 +10,7 @@ def test_baoa(): torch.manual_seed(42) # Set inference parameters (with torch.optim.SGD parameterization) - lr = 1e-3 + lr = 1e-2 mu = 0.9 tau = 0.9 diff --git a/tests/sgmcmc/test_sghmc.py b/tests/sgmcmc/test_sghmc.py index 321be53..f0bdc24 100644 --- a/tests/sgmcmc/test_sghmc.py +++ b/tests/sgmcmc/test_sghmc.py @@ -10,7 +10,7 @@ def test_sghmc(): torch.manual_seed(42) # Set inference parameters (with torch.optim.SGD parameterization) - lr = 1e-3 + lr = 1e-2 mu = 0.9 tau = 0.9 diff --git a/tests/sgmcmc/test_sgld.py b/tests/sgmcmc/test_sgld.py index adcc76e..348853e 100644 --- a/tests/sgmcmc/test_sgld.py +++ b/tests/sgmcmc/test_sgld.py @@ -10,7 +10,7 @@ def test_sgld(): torch.manual_seed(42) # Set inference parameters - lr = 1e-3 + lr = 1e-2 beta = 0.0 # Run MCMC test on Gaussian diff --git a/tests/sgmcmc/test_sgnht.py b/tests/sgmcmc/test_sgnht.py index cc8dd78..74b0d9a 100644 --- a/tests/sgmcmc/test_sgnht.py +++ b/tests/sgmcmc/test_sgnht.py @@ -10,7 +10,7 @@ def test_sgnht(): torch.manual_seed(42) # Set inference parameters (with torch.optim.SGD parameterization) - lr = 1e-3 + lr = 1e-2 mu = 0.9 tau = 0.9 diff --git a/tests/sgmcmc/utils.py b/tests/sgmcmc/utils.py index 520d5cb..6d0dc3e 100644 --- a/tests/sgmcmc/utils.py +++ b/tests/sgmcmc/utils.py @@ -8,9 +8,9 @@ def run_test_sgmcmc_gaussian( transform_builder: Callable[[LogProbFn], Transform], - dim: int = 3, - n_steps: int = 15_000, - burnin: int = 5_000, + dim: int = 2, + n_steps: int = 100_000, + burnin: int = 15_000, rtol: float = 1e-2, # Relative reduction of KL for final distribution compared to initial distribution ): # Load log posterior @@ -33,6 +33,13 @@ def run_test_sgmcmc_gaussian( # Remove burnin all_states = all_states[burnin:] + samples_mean = all_states.params.mean(0) + samples_cov = all_states.params.T.cov() + + # Check sufficient statistics + assert torch.allclose(samples_mean, mean, atol=1e-1) + assert torch.allclose(samples_cov, cov, atol=1e-1) + # Check KL divergence between true and inferred Gaussian kl_init = utils.kl_gaussians(torch.zeros(dim), torch.eye(dim) * init_var, mean, cov) kl_inferred = utils.kl_gaussians( @@ -57,4 +64,6 @@ def run_test_sgmcmc_gaussian( kl_divs_spaced = kl_divs[start_num_samples::spacing] spaced_decreasing = kl_divs_spaced[:-1] > kl_divs_spaced[1:] proportion_decreasing = spaced_decreasing.float().mean() - assert proportion_decreasing > 0.8 + assert proportion_decreasing > 0.5, ( + f"Proportion decreasing: {proportion_decreasing}" + ) From 3637d9de280f88c246e5bcb6919228c0652f7483 Mon Sep 17 00:00:00 2001 From: Sam Duffield Date: Wed, 24 Sep 2025 10:29:10 +0100 Subject: [PATCH 2/2] Fix kl spacing test --- tests/sgmcmc/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/sgmcmc/utils.py b/tests/sgmcmc/utils.py index 6d0dc3e..4fcddea 100644 --- a/tests/sgmcmc/utils.py +++ b/tests/sgmcmc/utils.py @@ -61,7 +61,10 @@ def run_test_sgmcmc_gaussian( start_num_samples = 500 # Omit first few KLs with high variance due to few samples spacing = 100 - kl_divs_spaced = kl_divs[start_num_samples::spacing] + + kl_divs_min_idx = torch.argmin(kl_divs[2:]) + 2 + kl_divs_before_min = kl_divs[:kl_divs_min_idx] + kl_divs_spaced = kl_divs_before_min[start_num_samples::spacing] spaced_decreasing = kl_divs_spaced[:-1] > kl_divs_spaced[1:] proportion_decreasing = spaced_decreasing.float().mean() assert proportion_decreasing > 0.5, (