diff --git a/tests/scenarios.py b/tests/scenarios.py index f8e7cb8..717ab56 100644 --- a/tests/scenarios.py +++ b/tests/scenarios.py @@ -41,7 +41,7 @@ def get_multivariate_normal_log_prob( ) -> tuple[LogProbFn, tuple[torch.Tensor, torch.Tensor]]: mean = torch.randn(dim) sqrt_cov = torch.randn(dim, dim) - cov = sqrt_cov @ sqrt_cov.T + cov = sqrt_cov @ sqrt_cov.T * 0.3 + torch.eye(dim) * 0.7 chol_cov = torch.linalg.cholesky(cov) # Lower triangular with positive diagonal def log_prob(p, batch): diff --git a/tests/sgmcmc/test_baoa.py b/tests/sgmcmc/test_baoa.py index 76ba518..468f4f9 100644 --- a/tests/sgmcmc/test_baoa.py +++ b/tests/sgmcmc/test_baoa.py @@ -10,7 +10,7 @@ def test_baoa(): torch.manual_seed(42) # Set inference parameters (with torch.optim.SGD parameterization) - lr = 1e-3 + lr = 1e-2 mu = 0.9 tau = 0.9 diff --git a/tests/sgmcmc/test_sghmc.py b/tests/sgmcmc/test_sghmc.py index 321be53..f0bdc24 100644 --- a/tests/sgmcmc/test_sghmc.py +++ b/tests/sgmcmc/test_sghmc.py @@ -10,7 +10,7 @@ def test_sghmc(): torch.manual_seed(42) # Set inference parameters (with torch.optim.SGD parameterization) - lr = 1e-3 + lr = 1e-2 mu = 0.9 tau = 0.9 diff --git a/tests/sgmcmc/test_sgld.py b/tests/sgmcmc/test_sgld.py index adcc76e..348853e 100644 --- a/tests/sgmcmc/test_sgld.py +++ b/tests/sgmcmc/test_sgld.py @@ -10,7 +10,7 @@ def test_sgld(): torch.manual_seed(42) # Set inference parameters - lr = 1e-3 + lr = 1e-2 beta = 0.0 # Run MCMC test on Gaussian diff --git a/tests/sgmcmc/test_sgnht.py b/tests/sgmcmc/test_sgnht.py index cc8dd78..74b0d9a 100644 --- a/tests/sgmcmc/test_sgnht.py +++ b/tests/sgmcmc/test_sgnht.py @@ -10,7 +10,7 @@ def test_sgnht(): torch.manual_seed(42) # Set inference parameters (with torch.optim.SGD parameterization) - lr = 1e-3 + lr = 1e-2 mu = 0.9 tau = 0.9 diff --git a/tests/sgmcmc/utils.py b/tests/sgmcmc/utils.py index 520d5cb..4fcddea 100644 --- a/tests/sgmcmc/utils.py +++ b/tests/sgmcmc/utils.py @@ -8,9 +8,9 @@ def run_test_sgmcmc_gaussian( transform_builder: Callable[[LogProbFn], Transform], - dim: int = 3, - n_steps: int = 15_000, - burnin: int = 5_000, + dim: int = 2, + n_steps: int = 100_000, + burnin: int = 15_000, rtol: float = 1e-2, # Relative reduction of KL for final distribution compared to initial distribution ): # Load log posterior @@ -33,6 +33,13 @@ def run_test_sgmcmc_gaussian( # Remove burnin all_states = all_states[burnin:] + samples_mean = all_states.params.mean(0) + samples_cov = all_states.params.T.cov() + + # Check sufficient statistics + assert torch.allclose(samples_mean, mean, atol=1e-1) + assert torch.allclose(samples_cov, cov, atol=1e-1) + # Check KL divergence between true and inferred Gaussian kl_init = utils.kl_gaussians(torch.zeros(dim), torch.eye(dim) * init_var, mean, cov) kl_inferred = utils.kl_gaussians( @@ -54,7 +61,12 @@ def run_test_sgmcmc_gaussian( start_num_samples = 500 # Omit first few KLs with high variance due to few samples spacing = 100 - kl_divs_spaced = kl_divs[start_num_samples::spacing] + + kl_divs_min_idx = torch.argmin(kl_divs[2:]) + 2 + kl_divs_before_min = kl_divs[:kl_divs_min_idx] + kl_divs_spaced = kl_divs_before_min[start_num_samples::spacing] spaced_decreasing = kl_divs_spaced[:-1] > kl_divs_spaced[1:] proportion_decreasing = spaced_decreasing.float().mean() - assert proportion_decreasing > 0.8 + assert proportion_decreasing > 0.5, ( + f"Proportion decreasing: {proportion_decreasing}" + )