From 23c3decd75168006b8d69c07b88555937809282a Mon Sep 17 00:00:00 2001 From: Maarten Marsman Date: Thu, 14 May 2026 21:46:50 +0200 Subject: [PATCH] refactor: align MH proposal-SD adaptation across OMRF/GGM/MixedMRF - Centralize stage-3b RM weight in WarmupSchedule::rm_weight_for_proposal_sd. - Wire GGM and MixedMRF MH paths to MetropolisAdaptationController, mirroring OMRF. update_* return per-slot accept_prob; do_one_metropolis_step collects them and calls the adapter(s) once per iteration. - Decouple NUTS step-size target (default 0.80) from MH componentwise RW MH target (0.44) in sample_{omrf,ggm,mixed}.cpp. - Switch GGM proposal_sds_ and MixedMRF proposal_sd_main_continuous_ from arma::vec to (N, 1) arma::mat so the adapter can hold an arma::mat&. OMRF/GGM/MixedMRF MH samples bitwise identical to main; MH vs NUTS per-parameter posterior means agree across all three models (cor > 0.999). --- src/mcmc/execution/warmup_schedule.h | 20 ++++ src/models/ggm/ggm_model.cpp | 53 +++++---- src/models/ggm/ggm_model.h | 43 +++++--- src/models/mixed/mixed_mrf_metropolis.cpp | 48 ++++----- src/models/mixed/mixed_mrf_model.cpp | 124 +++++++++++++++++++--- src/models/mixed/mixed_mrf_model.h | 73 ++++++++++--- src/models/omrf/omrf_model.cpp | 8 +- src/sample_ggm.cpp | 13 ++- src/sample_mixed.cpp | 13 ++- src/sample_omrf.cpp | 15 ++- 10 files changed, 298 insertions(+), 112 deletions(-) diff --git a/src/mcmc/execution/warmup_schedule.h b/src/mcmc/execution/warmup_schedule.h index b4e50086..84af6fbf 100644 --- a/src/mcmc/execution/warmup_schedule.h +++ b/src/mcmc/execution/warmup_schedule.h @@ -2,6 +2,8 @@ #include #include +#include +#include /** @@ -167,6 +169,24 @@ struct WarmupSchedule { return learn_proposal_sd && !stage3b_skipped && in_stage3b(i); } + /// Robbins-Monro decay rate for proposal-SD adaptation. Single source of + /// truth; every model's tune_proposal_sd consults this. + static constexpr double proposal_sd_rm_decay = 0.75; + + /// Robbins-Monro weight for proposal-SD adaptation at the given iteration. + /// + /// Returns the RM weight (1-indexed since stage 3b began, decay + /// `proposal_sd_rm_decay`) iff `adapt_proposal_sd(iter)` is true. Returns + /// nullopt otherwise. Use this in every `tune_proposal_sd` rather than + /// computing the weight inline — keeps the policy (which iterations adapt, + /// what the weight schedule looks like) in one place and makes stage-3b + /// adaptation sampler-agnostic by construction. + std::optional rm_weight_for_proposal_sd(int iter) const { + if (!adapt_proposal_sd(iter)) return std::nullopt; + const double t = static_cast(iter - stage3b_start + 1); + return std::pow(t, -proposal_sd_rm_decay); + } + /// Current Stage-2 window index (-1 outside Stage-2) int current_window(int i) const { for (size_t k = 0; k < window_ends.size(); ++k) diff --git a/src/models/ggm/ggm_model.cpp b/src/models/ggm/ggm_model.cpp index 4ed455b3..82cba081 100644 --- a/src/models/ggm/ggm_model.cpp +++ b/src/models/ggm/ggm_model.cpp @@ -673,10 +673,10 @@ double GGMModel::log_density_impl_diag(size_t j) const { } -void GGMModel::update_edge_parameter(size_t i, size_t j, int iteration) { +double GGMModel::update_edge_parameter(size_t i, size_t j) { if (edge_indicators_(i, j) == 0) { - return; // Edge is not included; skip update + return 0.0; // Edge is not included; skip update (AR irrelevant, masked out) } get_constants(i, j); @@ -721,12 +721,7 @@ void GGMModel::update_edge_parameter(size_t i, size_t j, int iteration) { cholesky_update_after_edge(omega_ij_old, omega_jj_old, i, j); } - // Robbins-Monro proposal-SD adaptation (warmup only) - if (iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); - proposal_sds_(e) = update_proposal_sd_with_robbins_monro( - proposal_sds_(e), ln_alpha, rm_weight, target_accept_); - } + return std::min(1.0, std::exp(ln_alpha)); } void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_old, size_t i, size_t j) @@ -768,7 +763,7 @@ void GGMModel::cholesky_update_after_edge(double omega_ij_old, double omega_jj_o } -void GGMModel::update_diagonal_parameter(size_t i, int iteration) { +double GGMModel::update_diagonal_parameter(size_t i) { double logdet_omega = cholesky_helpers::get_log_det(cholesky_of_precision_); double logdet_omega_sub_ii = logdet_omega + MY_LOG(covariance_matrix_(i, i)); @@ -793,12 +788,7 @@ void GGMModel::update_diagonal_parameter(size_t i, int iteration) { cholesky_update_after_diag(omega_ii, i); } - // Robbins-Monro proposal-SD adaptation (warmup only) - if (iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); - proposal_sds_(e) = update_proposal_sd_with_robbins_monro( - proposal_sds_(e), ln_alpha, rm_weight, target_accept_); - } + return std::min(1.0, std::exp(ln_alpha)); } void GGMModel::cholesky_update_after_diag(double omega_ii_old, size_t i) @@ -955,22 +945,41 @@ void GGMModel::update_edge_indicator_parameter_pair(size_t i, size_t j) { } void GGMModel::do_one_metropolis_step(int iteration) { + // Collect per-slot accept probabilities for the Robbins-Monro adapter. + // proposal_sds_ is stored as a flat dim_-length vec indexed by the + // upper-triangle scheme `e = j * (j + 1) / 2 + i`; we mirror that here + // as a dim_ x 1 matrix. + arma::mat accept_prob(dim_, 1, arma::fill::zeros); + arma::umat index_mask(dim_, 1, arma::fill::zeros); // Update off-diagonals (upper triangle) for (size_t i = 0; i < p_ - 1; ++i) { for (size_t j = i + 1; j < p_; ++j) { - update_edge_parameter(i, j, iteration); + double ap = update_edge_parameter(i, j); + if (edge_indicators_(i, j) == 1) { + size_t e = j * (j + 1) / 2 + i; + accept_prob(e, 0) = ap; + index_mask(e, 0) = 1; + } } } // Update diagonals for (size_t i = 0; i < p_; ++i) { - update_diagonal_parameter(i, iteration); + double ap = update_diagonal_parameter(i); + size_t e = i * (i + 3) / 2; + accept_prob(e, 0) = ap; + index_mask(e, 0) = 1; + } + + if (metropolis_adapter_) { + metropolis_adapter_->update(index_mask, accept_prob, iteration); } } void GGMModel::init_metropolis_adaptation(const WarmupSchedule& schedule) { - total_warmup_ = schedule.total_warmup; + metropolis_adapter_ = std::make_unique( + proposal_sds_, schedule, target_accept_); } void GGMModel::prepare_iteration() { @@ -1000,12 +1009,10 @@ void GGMModel::update_edge_indicators() { } void GGMModel::tune_proposal_sd(int iteration, const WarmupSchedule& schedule) { - if (!schedule.adapt_proposal_sd(iteration)) return; - + auto rm_weight_opt = schedule.rm_weight_for_proposal_sd(iteration); + if (!rm_weight_opt) return; + const double rm_weight = *rm_weight_opt; const double target_accept = target_accept_; - const double rm_decay = 0.75; - double t = iteration - schedule.stage3b_start + 1; - double rm_weight = std::pow(t, -rm_decay); // Off-diagonal sweeps for (size_t i = 0; i < p_ - 1; ++i) { diff --git a/src/models/ggm/ggm_model.h b/src/models/ggm/ggm_model.h index e9f4f8d1..367e3ae3 100644 --- a/src/models/ggm/ggm_model.h +++ b/src/models/ggm/ggm_model.h @@ -8,6 +8,7 @@ #include "models/ggm/graph_constraint_structure.h" #include "models/ggm/ggm_gradient.h" #include "priors/parameter_prior.h" +#include "mcmc/samplers/metropolis_adaptation.h" /** @@ -60,7 +61,7 @@ class GGMModel : public BaseModel { edge_indicators_(initial_edge_indicators), vectorized_parameters_(dim_), vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), - proposal_sds_(arma::vec(dim_, arma::fill::ones) * 0.25), + proposal_sds_(arma::mat(dim_, 1, arma::fill::ones) * 0.25), num_pairwise_(p_ * (p_ - 1) / 2), observations_(na_impute ? observations : arma::mat()), precision_proposal_(arma::mat(p_, p_, arma::fill::none)) @@ -108,7 +109,7 @@ class GGMModel : public BaseModel { edge_indicators_(initial_edge_indicators), vectorized_parameters_(dim_), vectorized_indicator_parameters_(edge_selection_ ? dim_ : 0), - proposal_sds_(arma::vec(dim_, arma::fill::ones) * 0.25), + proposal_sds_(arma::mat(dim_, 1, arma::fill::ones) * 0.25), num_pairwise_(p_ * (p_ - 1) / 2), precision_proposal_(arma::mat(p_, p_, arma::fill::none)) { @@ -139,7 +140,6 @@ class GGMModel : public BaseModel { vectorized_parameters_(other.vectorized_parameters_), vectorized_indicator_parameters_(other.vectorized_indicator_parameters_), proposal_sds_(other.proposal_sds_), - total_warmup_(other.total_warmup_), shuffled_edge_order_(other.shuffled_edge_order_), num_pairwise_(other.num_pairwise_), rng_(other.rng_), @@ -191,9 +191,6 @@ class GGMModel : public BaseModel { edge_selection_active_ = active; } - /** Store warmup length for Robbins-Monro proposal-SD adaptation. */ - void init_metropolis_adaptation(const WarmupSchedule& schedule) override; - /** * Set the Robbins-Monro target acceptance rate used by the * adaptive-Metropolis updates of this GGM. Honoured by all @@ -203,6 +200,13 @@ class GGMModel : public BaseModel { target_accept_ = target; } + /** + * Construct Robbins-Monro adaptation controller for the per-iteration + * MH proposal SDs. Called once by MetropolisSampler before warmup; + * under NUTS this is never called and the controller stays null. + */ + void init_metropolis_adaptation(const WarmupSchedule& schedule) override; + /** Shuffle edge visit order (random scan). */ void prepare_iteration() override; @@ -456,6 +460,10 @@ class GGMModel : public BaseModel { // to 0.44 (componentwise random-walk Metropolis optimum). double target_accept_ = 0.44; + /// Per-iteration adaptation controller (MH mode only — under NUTS this + /// stays null and the stage-3b path in tune_proposal_sd is used instead). + std::unique_ptr metropolis_adapter_; + /** Extract upper triangle of the precision matrix into a vector. */ arma::vec extract_upper_triangle() const { arma::vec result(dim_); @@ -500,10 +508,10 @@ class GGMModel : public BaseModel { /// Pre-allocated storage returned by get_vectorized_indicator_parameters(). arma::ivec vectorized_indicator_parameters_; - /// Proposal standard deviations for Metropolis updates (one per element). - arma::vec proposal_sds_; - /// Total number of warmup iterations (for Robbins-Monro adaptation). - int total_warmup_ = 0; + /// Proposal standard deviations for Metropolis updates (one per element, + /// stored as a (dim_, 1) matrix so it can be wrapped by + /// MetropolisAdaptationController). + arma::mat proposal_sds_; /// Shuffled edge visit order for random-scan edge selection. arma::uvec shuffled_edge_order_; @@ -562,21 +570,22 @@ class GGMModel : public BaseModel { * on an unconstrained reparameterization. Accepts or rejects with a * Metropolis ratio using the Gaussian likelihood and Cauchy prior. * - * @param i Row index (i < j) - * @param j Column index - * @param iteration Current iteration (for Robbins-Monro adaptation) + * @param i Row index (i < j) + * @param j Column index + * @return Metropolis acceptance probability min(1, exp(ln_alpha)), + * or 0.0 if the edge is inactive (caller masks it out). */ - void update_edge_parameter(size_t i, size_t j, int iteration); + double update_edge_parameter(size_t i, size_t j); /** * Propose a new diagonal precision entry on the log scale. * Accepts or rejects with a Metropolis ratio using the Gaussian * likelihood, a Gamma(1,1) prior, and a Jacobian correction. * - * @param i Diagonal index - * @param iteration Current iteration (for Robbins-Monro adaptation) + * @param i Diagonal index + * @return Metropolis acceptance probability min(1, exp(ln_alpha)). */ - void update_diagonal_parameter(size_t i, int iteration); + double update_diagonal_parameter(size_t i); /** * Metropolis-Hastings add-delete move for an edge indicator. diff --git a/src/models/mixed/mixed_mrf_metropolis.cpp b/src/models/mixed/mixed_mrf_metropolis.cpp index e4850eb9..a323b2f0 100644 --- a/src/models/mixed/mixed_mrf_metropolis.cpp +++ b/src/models/mixed/mixed_mrf_metropolis.cpp @@ -17,7 +17,7 @@ // The accept/reject uses log_marginal_omrf(s) + beta-type prior. // ============================================================================= -void MixedMRFModel::update_main_effect(int s, int c, int iteration) { +double MixedMRFModel::update_main_effect(int s, int c, std::optional rm_weight) { double& current = main_effects_discrete_(s, c); double proposal_sd = proposal_sd_main_discrete_(s, c); @@ -39,11 +39,11 @@ void MixedMRFModel::update_main_effect(int s, int c, int iteration) { current = current_val; // reject } - if(iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); + if (rm_weight) { proposal_sd_main_discrete_(s, c) = update_proposal_sd_with_robbins_monro( - proposal_sd_main_discrete_(s, c), ln_alpha, rm_weight, target_accept_); + proposal_sd_main_discrete_(s, c), ln_alpha, *rm_weight, target_accept_); } + return std::min(1.0, std::exp(ln_alpha)); } @@ -55,7 +55,7 @@ void MixedMRFModel::update_main_effect(int s, int c, int iteration) { // Must save/restore conditional_mean_ around the proposal. // ============================================================================= -void MixedMRFModel::update_continuous_mean(int j, int iteration) { +double MixedMRFModel::update_continuous_mean(int j, std::optional rm_weight) { double current_val = main_effects_continuous_(j); double proposed = rnorm(rng_, current_val, proposal_sd_main_continuous_(j)); @@ -80,11 +80,11 @@ void MixedMRFModel::update_continuous_mean(int j, int iteration) { conditional_mean_ = std::move(cond_mean_saved); } - if(iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); + if (rm_weight) { proposal_sd_main_continuous_(j) = update_proposal_sd_with_robbins_monro( - proposal_sd_main_continuous_(j), ln_alpha, rm_weight, target_accept_); + proposal_sd_main_continuous_(j), ln_alpha, *rm_weight, target_accept_); } + return std::min(1.0, std::exp(ln_alpha)); } @@ -96,7 +96,7 @@ void MixedMRFModel::update_continuous_mean(int j, int iteration) { // Acceptance: log_marginal_omrf(i) + log_marginal_omrf(j) + Cauchy prior. // ============================================================================= -void MixedMRFModel::update_pairwise_discrete(int i, int j, int iteration) { +double MixedMRFModel::update_pairwise_discrete(int i, int j, std::optional rm_weight) { double current_val = pairwise_effects_discrete_(i, j); double proposed = rnorm(rng_, current_val, proposal_sd_pairwise_discrete_(i, j)); @@ -119,11 +119,11 @@ void MixedMRFModel::update_pairwise_discrete(int i, int j, int iteration) { recompute_marginal_interactions(); } - if(iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); + if (rm_weight) { proposal_sd_pairwise_discrete_(i, j) = update_proposal_sd_with_robbins_monro( - proposal_sd_pairwise_discrete_(i, j), ln_alpha, rm_weight, target_accept_); + proposal_sd_pairwise_discrete_(i, j), ln_alpha, *rm_weight, target_accept_); } + return std::min(1.0, std::exp(ln_alpha)); } @@ -378,7 +378,7 @@ void MixedMRFModel::cholesky_update_after_precision_diag(double old_ii, int i) { // Storage: pairwise_effects_continuous_ = -1/2 * precision. // ============================================================================= -void MixedMRFModel::update_pairwise_effects_continuous_offdiag(int i, int j, int iteration) { +double MixedMRFModel::update_pairwise_effects_continuous_offdiag(int i, int j, std::optional rm_weight) { get_precision_constants(i, j); double phi_curr = cont_constants_[0]; // Phi_q1q @@ -444,11 +444,11 @@ void MixedMRFModel::update_pairwise_effects_continuous_offdiag(int i, int j, int recompute_marginal_interactions(); } - if(iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); + if (rm_weight) { proposal_sd_pairwise_continuous_(i, j) = update_proposal_sd_with_robbins_monro( - proposal_sd_pairwise_continuous_(i, j), ln_alpha, rm_weight, target_accept_); + proposal_sd_pairwise_continuous_(i, j), ln_alpha, *rm_weight, target_accept_); } + return std::min(1.0, std::exp(ln_alpha)); } @@ -461,7 +461,7 @@ void MixedMRFModel::update_pairwise_effects_continuous_offdiag(int i, int j, int // Prior: Gamma(1, 1) on negative diagonal + Jacobian for log-scale proposal. // ============================================================================= -void MixedMRFModel::update_pairwise_effects_continuous_diag(int i, int iteration) { +double MixedMRFModel::update_pairwise_effects_continuous_diag(int i, std::optional rm_weight) { double logdet = cholesky_helpers::get_log_det(cholesky_of_precision_); double logdet_sub_ii = logdet + MY_LOG(covariance_continuous_(i, i)); @@ -512,11 +512,11 @@ void MixedMRFModel::update_pairwise_effects_continuous_diag(int i, int iteration recompute_marginal_interactions(); } - if(iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); + if (rm_weight) { proposal_sd_pairwise_continuous_(i, i) = update_proposal_sd_with_robbins_monro( - proposal_sd_pairwise_continuous_(i, i), ln_alpha, rm_weight, target_accept_); + proposal_sd_pairwise_continuous_(i, i), ln_alpha, *rm_weight, target_accept_); } + return std::min(1.0, std::exp(ln_alpha)); } @@ -528,7 +528,7 @@ void MixedMRFModel::update_pairwise_effects_continuous_diag(int i, int iteration // Must save/restore conditional_mean_ and marginal_interactions_ around the proposal. // ============================================================================= -void MixedMRFModel::update_pairwise_cross(int i, int j, int iteration) { +double MixedMRFModel::update_pairwise_cross(int i, int j, std::optional rm_weight) { double current_val = pairwise_effects_cross_(i, j); double proposed = rnorm(rng_, current_val, proposal_sd_pairwise_cross_(i, j)); @@ -558,11 +558,11 @@ void MixedMRFModel::update_pairwise_cross(int i, int j, int iteration) { marginal_interactions_ = std::move(marginal_saved); } - if(iteration >= 1 && iteration < total_warmup_) { - double rm_weight = std::pow(iteration, -0.75); + if (rm_weight) { proposal_sd_pairwise_cross_(i, j) = update_proposal_sd_with_robbins_monro( - proposal_sd_pairwise_cross_(i, j), ln_alpha, rm_weight, target_accept_); + proposal_sd_pairwise_cross_(i, j), ln_alpha, *rm_weight, target_accept_); } + return std::min(1.0, std::exp(ln_alpha)); } diff --git a/src/models/mixed/mixed_mrf_model.cpp b/src/models/mixed/mixed_mrf_model.cpp index 50529dd0..55676699 100644 --- a/src/models/mixed/mixed_mrf_model.cpp +++ b/src/models/mixed/mixed_mrf_model.cpp @@ -74,7 +74,7 @@ MixedMRFModel::MixedMRFModel( // Initialize proposal SDs proposal_sd_main_discrete_ = arma::ones(p_, max_cats_); - proposal_sd_main_continuous_ = arma::ones(q_); + proposal_sd_main_continuous_ = arma::ones(q_, 1); proposal_sd_pairwise_discrete_ = arma::ones(p_, p_); proposal_sd_pairwise_continuous_ = arma::ones(q_, q_); proposal_sd_pairwise_cross_ = arma::ones(p_, q_); @@ -162,7 +162,6 @@ MixedMRFModel::MixedMRFModel(const MixedMRFModel& other) proposal_sd_pairwise_discrete_(other.proposal_sd_pairwise_discrete_), proposal_sd_pairwise_continuous_(other.proposal_sd_pairwise_continuous_), proposal_sd_pairwise_cross_(other.proposal_sd_pairwise_cross_), - total_warmup_(other.total_warmup_), cholesky_of_precision_(other.cholesky_of_precision_), inv_cholesky_of_precision_(other.inv_cholesky_of_precision_), covariance_continuous_(other.covariance_continuous_), @@ -1185,42 +1184,134 @@ void MixedMRFModel::impute_missing() { // ============================================================================= void MixedMRFModel::do_one_metropolis_step(int iteration) { + // Per-slot accept-probability and visit-mask matrices for the five + // proposal-SD storages. Only entries we actually visit get mask=1; the + // adapter only RM-updates those slots. + arma::mat ar_main_disc = arma::zeros(p_, max_cats_); + arma::umat mask_main_disc= arma::zeros(p_, max_cats_); + arma::mat ar_main_cont = arma::zeros(q_, 1); + arma::umat mask_main_cont= arma::zeros(q_, 1); + arma::mat ar_pair_disc = arma::zeros(p_, p_); + arma::umat mask_pair_disc= arma::zeros(p_, p_); + arma::mat ar_pair_cont = arma::zeros(q_, q_); + arma::umat mask_pair_cont= arma::zeros(q_, q_); + arma::mat ar_pair_cross = arma::zeros(p_, q_); + arma::umat mask_pair_cross= arma::zeros(p_, q_); + + // Step 1: main effects (ordinal thresholds or BC α/β) + for(size_t s = 0; s < p_; ++s) { + if(is_ordinal_variable_(s)) { + for(int c = 0; c < num_categories_(s); ++c) { + ar_main_disc(s, c) = update_main_effect(s, c, std::nullopt); + mask_main_disc(s, c) = 1; + } + } else { + ar_main_disc(s, 0) = update_main_effect(s, 0, std::nullopt); + ar_main_disc(s, 1) = update_main_effect(s, 1, std::nullopt); + mask_main_disc(s, 0) = 1; + mask_main_disc(s, 1) = 1; + } + } + + // Step 2: continuous means + for(size_t j = 0; j < q_; ++j) { + ar_main_cont(j, 0) = update_continuous_mean(j, std::nullopt); + mask_main_cont(j, 0) = 1; + } + + // Step 3: pairwise_effects_discrete_ (upper triangle, edge-gated) + for(size_t i = 0; i < p_ - 1; ++i) + for(size_t j = i + 1; j < p_; ++j) + if(!edge_selection_active_ || gxx(i, j) == 1) { + ar_pair_disc(i, j) = update_pairwise_discrete(i, j, std::nullopt); + mask_pair_disc(i, j) = 1; + } + + // Step 4: pairwise_effects_continuous_ (off-diag + diagonal, edge-gated) + if(q_ >= 2) { + for(size_t i = 0; i < q_ - 1; ++i) + for(size_t j = i + 1; j < q_; ++j) + if(!edge_selection_active_ || gyy(i, j) == 1) { + ar_pair_cont(i, j) = update_pairwise_effects_continuous_offdiag(i, j, std::nullopt); + mask_pair_cont(i, j) = 1; + } + } + for(size_t i = 0; i < q_; ++i) { + ar_pair_cont(i, i) = update_pairwise_effects_continuous_diag(i, std::nullopt); + mask_pair_cont(i, i) = 1; + } + + // Step 5: pairwise_effects_cross_ (edge-gated) + for(size_t i = 0; i < p_; ++i) + for(size_t j = 0; j < q_; ++j) + if(!edge_selection_active_ || gxy(i, j) == 1) { + ar_pair_cross(i, j) = update_pairwise_cross(i, j, std::nullopt); + mask_pair_cross(i, j) = 1; + } + + // Robbins-Monro batch update on each storage's adapter (MH mode only). + if (mh_adapter_main_discrete_) + mh_adapter_main_discrete_->update(mask_main_disc, ar_main_disc, iteration); + if (mh_adapter_main_continuous_) + mh_adapter_main_continuous_->update(mask_main_cont, ar_main_cont, iteration); + if (mh_adapter_pairwise_discrete_) + mh_adapter_pairwise_discrete_->update(mask_pair_disc, ar_pair_disc, iteration); + if (mh_adapter_pairwise_continuous_) + mh_adapter_pairwise_continuous_->update(mask_pair_cont, ar_pair_cont, iteration); + if (mh_adapter_pairwise_cross_) + mh_adapter_pairwise_cross_->update(mask_pair_cross, ar_pair_cross, iteration); +} + +void MixedMRFModel::init_metropolis_adaptation(const WarmupSchedule& schedule) { + mh_adapter_main_discrete_ = std::make_unique( + proposal_sd_main_discrete_, schedule, target_accept_); + mh_adapter_main_continuous_ = std::make_unique( + proposal_sd_main_continuous_, schedule, target_accept_); + mh_adapter_pairwise_discrete_ = std::make_unique( + proposal_sd_pairwise_discrete_, schedule, target_accept_); + mh_adapter_pairwise_continuous_ = std::make_unique( + proposal_sd_pairwise_continuous_, schedule, target_accept_); + mh_adapter_pairwise_cross_ = std::make_unique( + proposal_sd_pairwise_cross_, schedule, target_accept_); +} + +void MixedMRFModel::sweep_within_model_mh(std::optional rm_weight) { // Step 1: Update all main effects (ordinal thresholds or BC α/β) for(size_t s = 0; s < p_; ++s) { if(is_ordinal_variable_(s)) { for(int c = 0; c < num_categories_(s); ++c) - update_main_effect(s, c, iteration); + update_main_effect(s, c, rm_weight); } else { - update_main_effect(s, 0, iteration); // linear α - update_main_effect(s, 1, iteration); // quadratic β + update_main_effect(s, 0, rm_weight); // linear α + update_main_effect(s, 1, rm_weight); // quadratic β } } // Step 2: Update all continuous means for(size_t j = 0; j < q_; ++j) - update_continuous_mean(j, iteration); + update_continuous_mean(j, rm_weight); // Step 3: Update pairwise_effects_discrete_ (upper triangle, edge-gated) for(size_t i = 0; i < p_ - 1; ++i) for(size_t j = i + 1; j < p_; ++j) if(!edge_selection_active_ || gxx(i, j) == 1) - update_pairwise_discrete(i, j, iteration); + update_pairwise_discrete(i, j, rm_weight); // Step 4: Update pairwise_effects_continuous_ (off-diag + diagonal, edge-gated) if(q_ >= 2) { for(size_t i = 0; i < q_ - 1; ++i) for(size_t j = i + 1; j < q_; ++j) if(!edge_selection_active_ || gyy(i, j) == 1) - update_pairwise_effects_continuous_offdiag(i, j, iteration); + update_pairwise_effects_continuous_offdiag(i, j, rm_weight); } for(size_t i = 0; i < q_; ++i) - update_pairwise_effects_continuous_diag(i, iteration); + update_pairwise_effects_continuous_diag(i, rm_weight); // Step 5: Update pairwise_effects_cross_ (edge-gated) for(size_t i = 0; i < p_; ++i) for(size_t j = 0; j < q_; ++j) if(!edge_selection_active_ || gxy(i, j) == 1) - update_pairwise_cross(i, j, iteration); + update_pairwise_cross(i, j, rm_weight); // Edge-indicator updates are handled by ChainRunner, not here. // (Matches the OMRF pattern; avoids double-counting indicator proposals.) @@ -1281,11 +1372,10 @@ void MixedMRFModel::prepare_iteration() { edge_order_xy_ = arma_randperm(rng_, num_cross_); } -void MixedMRFModel::init_metropolis_adaptation(const WarmupSchedule& schedule) { - total_warmup_ = schedule.total_warmup; -} - -void MixedMRFModel::tune_proposal_sd(int /*iteration*/, const WarmupSchedule& /*schedule*/) { - // Robbins-Monro adaptation is embedded in each MH update function, - // gated by iteration < total_warmup_. No separate tuning pass needed. +void MixedMRFModel::tune_proposal_sd(int iteration, const WarmupSchedule& schedule) { + auto rm_weight_opt = schedule.rm_weight_for_proposal_sd(iteration); + if (!rm_weight_opt) return; + // Stage-3b sweep: re-run every within-model MH proposal with RM + // applied to its proposal-SD slot via *rm_weight_opt. Sampler-agnostic. + sweep_within_model_mh(rm_weight_opt); } diff --git a/src/models/mixed/mixed_mrf_model.h b/src/models/mixed/mixed_mrf_model.h index 7c8056d3..7bd00a8b 100644 --- a/src/models/mixed/mixed_mrf_model.h +++ b/src/models/mixed/mixed_mrf_model.h @@ -2,6 +2,7 @@ #include #include +#include #include "models/base_model.h" #include "models/ggm/graph_constraint_structure.h" #include "models/ggm/ggm_gradient.h" @@ -9,6 +10,7 @@ #include "math/cholupdate.h" #include "rng/rng_utils.h" #include "priors/parameter_prior.h" +#include "mcmc/samplers/metropolis_adaptation.h" /** * MixedMRFModel - Mixed Markov Random Field Model @@ -121,12 +123,6 @@ class MixedMRFModel : public BaseModel { */ void do_one_metropolis_step(int iteration = -1) override; - /** - * Initialize Metropolis adaptation controllers for proposal-SD tuning. - * Called before warmup begins. - */ - void init_metropolis_adaptation(const WarmupSchedule& schedule) override; - /** * Set the Robbins-Monro target acceptance rate used by the * adaptive-Metropolis updates of this mixed model. Honoured by the @@ -138,9 +134,21 @@ class MixedMRFModel : public BaseModel { target_accept_ = target; } + /** + * Construct Robbins-Monro adaptation controllers for the per-iteration + * MH proposal SDs. Called once by MetropolisSampler before warmup; under + * NUTS this is never called and the controllers stay null. One adapter + * per proposal-SD storage (5 in total). + */ + void init_metropolis_adaptation(const WarmupSchedule& schedule) override; + /** * Tune proposal SDs via Robbins-Monro (Stage 3b). - * Called every iteration; checks schedule internally. + * + * Re-runs every MH sweep with the schedule-supplied RM weight applied to + * each proposal-SD slot. Outside stage 3b the schedule returns nullopt + * and this is a no-op. Mirrors `OMRFModel::tune_proposal_sd` / + * `GGMModel::tune_proposal_sd`; sampler-agnostic by construction. */ void tune_proposal_sd(int iteration, const WarmupSchedule& schedule) override; @@ -294,6 +302,16 @@ class MixedMRFModel : public BaseModel { // to 0.44 (componentwise random-walk Metropolis optimum). double target_accept_ = 0.44; + /// Per-iteration adaptation controllers (MH mode only — under NUTS these + /// stay null and the stage-3b path in tune_proposal_sd is used instead). + /// One adapter per proposal-SD storage; off-diag and diag of the continuous + /// pairwise share `proposal_sd_pairwise_continuous_` and thus one adapter. + std::unique_ptr mh_adapter_main_discrete_; + std::unique_ptr mh_adapter_main_continuous_; + std::unique_ptr mh_adapter_pairwise_discrete_; + std::unique_ptr mh_adapter_pairwise_continuous_; + std::unique_ptr mh_adapter_pairwise_cross_; + // ========================================================================= // Counts and dimensions // ========================================================================= @@ -371,11 +389,10 @@ class MixedMRFModel : public BaseModel { // ========================================================================= arma::mat proposal_sd_main_discrete_; ///< p x max_cats - arma::vec proposal_sd_main_continuous_; ///< q-vector + arma::mat proposal_sd_main_continuous_; ///< q x 1 (mat-shaped for MetropolisAdaptationController) arma::mat proposal_sd_pairwise_discrete_; ///< p x p arma::mat proposal_sd_pairwise_continuous_; ///< q x q arma::mat proposal_sd_pairwise_cross_; ///< p x q - int total_warmup_ = 0; ///< Stored by init_metropolis_adaptation // ========================================================================= // Cached quantities @@ -518,23 +535,49 @@ class MixedMRFModel : public BaseModel { // --- Parameter update sweeps --- + /** + * Run every within-model MH proposal once (main effects, continuous + * means, pairwise discrete/continuous/cross). Edge-indicator updates + * are handled separately by ChainRunner. + * + * If `rm_weight` is set, each proposal also Robbins-Monro-updates its + * proposal-SD slot using `ln_alpha` and `*rm_weight`. If nullopt, only + * the accept/reject step runs. + * + * Shared between `do_one_metropolis_step` (called by MetropolisSampler + * every iteration, with nullopt) and `tune_proposal_sd` (called every + * iteration but a no-op outside stage 3b; during 3b passes the + * schedule-supplied weight). + */ + void sweep_within_model_mh(std::optional rm_weight); + + // All within-model MH update sweeps below take an optional `rm_weight` + // and return the Metropolis acceptance probability for the proposal. + // - From `do_one_metropolis_step`: called with std::nullopt; the returned + // AR is collected per slot and fed to the MetropolisAdaptationController + // batch update after the sweep. + // - From `tune_proposal_sd` during stage 3b: called with the schedule's + // `rm_weight_for_proposal_sd(iter)`; Robbins-Monro updates the matching + // proposal-SD slot inline using `ln_alpha` and `*rm_weight`. Returned + // AR is discarded. + /** Update one main-effect: main_effects_discrete_(s, c). Ordinal threshold or BC α/β. */ - void update_main_effect(int s, int c, int iteration); + double update_main_effect(int s, int c, std::optional rm_weight); /** Update one continuous mean: main_effects_continuous_(j). */ - void update_continuous_mean(int j, int iteration); + double update_continuous_mean(int j, std::optional rm_weight); /** Update one discrete interaction: pairwise_effects_discrete_(i, j). Symmetric. */ - void update_pairwise_discrete(int i, int j, int iteration); + double update_pairwise_discrete(int i, int j, std::optional rm_weight); /** Update one off-diagonal precision element. Cholesky-based. */ - void update_pairwise_effects_continuous_offdiag(int i, int j, int iteration); + double update_pairwise_effects_continuous_offdiag(int i, int j, std::optional rm_weight); /** Update one diagonal precision element. Log-scale Cholesky. */ - void update_pairwise_effects_continuous_diag(int i, int iteration); + double update_pairwise_effects_continuous_diag(int i, std::optional rm_weight); /** Update one cross interaction: pairwise_effects_cross_(i, j). */ - void update_pairwise_cross(int i, int j, int iteration); + double update_pairwise_cross(int i, int j, std::optional rm_weight); // --- Edge-indicator update sweeps --- diff --git a/src/models/omrf/omrf_model.cpp b/src/models/omrf/omrf_model.cpp index 3b0b1a86..d912eb78 100644 --- a/src/models/omrf/omrf_model.cpp +++ b/src/models/omrf/omrf_model.cpp @@ -252,12 +252,10 @@ void OMRFModel::init_metropolis_adaptation(const WarmupSchedule& schedule) { void OMRFModel::tune_proposal_sd(int iteration, const WarmupSchedule& schedule) { - if (!schedule.adapt_proposal_sd(iteration)) return; - + auto rm_weight_opt = schedule.rm_weight_for_proposal_sd(iteration); + if (!rm_weight_opt) return; + const double rm_weight = *rm_weight_opt; const double target_accept = target_accept_; - const double rm_decay = 0.75; - double t = iteration - schedule.stage3b_start + 1; - double rm_weight = std::pow(t, -rm_decay); const int num_variables = static_cast(p_); diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index caf77f2b..9a7ab080 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -68,9 +68,16 @@ Rcpp::List sample_ggm( edge_selection, std::move(interaction_prior), std::move(diagonal_prior), na_impute); - // Forward target_accept to the model so adaptive-Metropolis updates - // target the user's value rather than the hard-coded 0.44 default. - model.set_metropolis_target_accept(target_acceptance); + // Forward target_accept to the model's MH proposal-SD tuner. + // - Under "adaptive-metropolis": user's target_accept goes through + // directly (default 0.44 = componentwise RW MH optimum). + // - Under "nuts": user's target_accept (default 0.80) is the + // HMC step-size dual-averaging target and should NOT govern the + // between-model MH proposal SDs, which are still 1-D componentwise + // RW MH. Hardcode 0.44 there to keep stage-3b RM on the right + // fixed point. + const double mh_target = (sampler_type == "nuts") ? 0.44 : target_acceptance; + model.set_metropolis_target_accept(mh_target); // Set up missing data imputation (same pattern as OMRF) if (na_impute && missing_index_nullable.isNotNull()) { diff --git a/src/sample_mixed.cpp b/src/sample_mixed.cpp index 7fda6616..557caf17 100644 --- a/src/sample_mixed.cpp +++ b/src/sample_mixed.cpp @@ -136,9 +136,16 @@ Rcpp::List sample_mixed_mrf( seed ); - // Forward target_accept to the model so adaptive-Metropolis updates - // target the user's value rather than the hard-coded 0.44 default. - model.set_metropolis_target_accept(target_acceptance); + // Forward target_accept to the model's MH proposal-SD tuner. + // - Under "adaptive-metropolis": user's target_accept goes through + // directly (default 0.44 = componentwise RW MH optimum). + // - Under "nuts": user's target_accept (default 0.80) is the + // HMC step-size dual-averaging target and should NOT govern the + // between-model MH proposal SDs, which are still 1-D componentwise + // RW MH. Hardcode 0.44 there to keep stage-3b RM on the right + // fixed point. + const double mh_target = (sampler_type == "nuts") ? 0.44 : target_acceptance; + model.set_metropolis_target_accept(mh_target); // Set up missing data imputation if(na_impute) { diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index 1ab157ee..1606a964 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -95,11 +95,16 @@ Rcpp::List sample_omrf( std::move(interaction_prior), std::move(threshold_prior), edge_selection); - // Forward the user's target_accept to the model. This drives the - // Robbins-Monro proposal-SD adaptation under update_method = - // "adaptive-metropolis"; ignored under "nuts" (whose own dual - // averaging consumes target_acceptance via the SamplerConfig). - model.set_metropolis_target_accept(target_acceptance); + // Forward target_accept to the model's MH proposal-SD tuner. + // - Under "adaptive-metropolis": user's target_accept goes through + // directly (default 0.44 = componentwise RW MH optimum). + // - Under "nuts": user's target_accept (default 0.80) is the + // HMC step-size dual-averaging target and should NOT govern the + // between-model MH proposal SDs, which are still 1-D componentwise + // RW MH. Hardcode 0.44 there to keep stage-3b RM on the right + // fixed point. + const double mh_target = (sampler_type == "nuts") ? 0.44 : target_acceptance; + model.set_metropolis_target_accept(mh_target); // Set pairwise scaling factors (if provided) if (pairwise_scaling_factors_nullable.isNotNull()) {