Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ mixed_test_logp_and_gradient <- function(params, discrete_observations, continuo
.Call(`_bgms_mixed_test_logp_and_gradient`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate)
}

mixed_test_logp_and_gradient_full <- function(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha = 1.0, main_beta = 1.0, interaction_prior_type = "cauchy", threshold_prior_type = "beta-prime", threshold_scale = 1.0, means_prior_type = "normal", means_scale = 1.0, diagonal_prior_type = "gamma", diagonal_shape = 1.0, diagonal_rate = 1.0) {
.Call(`_bgms_mixed_test_logp_and_gradient_full`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate)
mixed_test_logp_and_gradient_full <- function(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha = 1.0, main_beta = 1.0, interaction_prior_type = "cauchy", threshold_prior_type = "beta-prime", threshold_scale = 1.0, means_prior_type = "normal", means_scale = 1.0, diagonal_prior_type = "gamma", diagonal_shape = 1.0, diagonal_rate = 1.0, inv_mass_diag = NULL) {
.Call(`_bgms_mixed_test_logp_and_gradient_full`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate, inv_mass_diag)
}

mixed_test_project_position <- function(x, inv_mass, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha = 1.0, main_beta = 1.0, interaction_prior_type = "cauchy", threshold_prior_type = "beta-prime", threshold_scale = 1.0) {
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ BEGIN_RCPP
END_RCPP
}
// mixed_test_logp_and_gradient_full
Rcpp::List mixed_test_logp_and_gradient_full(const arma::vec& params, const arma::imat& discrete_observations, const arma::mat& continuous_observations, const arma::ivec& num_categories, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const arma::imat& edge_indicators, double pairwise_scale, double main_alpha, double main_beta, std::string interaction_prior_type, std::string threshold_prior_type, double threshold_scale, std::string means_prior_type, double means_scale, std::string diagonal_prior_type, double diagonal_shape, double diagonal_rate);
RcppExport SEXP _bgms_mixed_test_logp_and_gradient_full(SEXP paramsSEXP, SEXP discrete_observationsSEXP, SEXP continuous_observationsSEXP, SEXP num_categoriesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP interaction_prior_typeSEXP, SEXP threshold_prior_typeSEXP, SEXP threshold_scaleSEXP, SEXP means_prior_typeSEXP, SEXP means_scaleSEXP, SEXP diagonal_prior_typeSEXP, SEXP diagonal_shapeSEXP, SEXP diagonal_rateSEXP) {
Rcpp::List mixed_test_logp_and_gradient_full(const arma::vec& params, const arma::imat& discrete_observations, const arma::mat& continuous_observations, const arma::ivec& num_categories, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, const arma::imat& edge_indicators, double pairwise_scale, double main_alpha, double main_beta, std::string interaction_prior_type, std::string threshold_prior_type, double threshold_scale, std::string means_prior_type, double means_scale, std::string diagonal_prior_type, double diagonal_shape, double diagonal_rate, Rcpp::Nullable<Rcpp::NumericVector> inv_mass_diag);
RcppExport SEXP _bgms_mixed_test_logp_and_gradient_full(SEXP paramsSEXP, SEXP discrete_observationsSEXP, SEXP continuous_observationsSEXP, SEXP num_categoriesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP edge_indicatorsSEXP, SEXP pairwise_scaleSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP interaction_prior_typeSEXP, SEXP threshold_prior_typeSEXP, SEXP threshold_scaleSEXP, SEXP means_prior_typeSEXP, SEXP means_scaleSEXP, SEXP diagonal_prior_typeSEXP, SEXP diagonal_shapeSEXP, SEXP diagonal_rateSEXP, SEXP inv_mass_diagSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -326,7 +326,8 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< std::string >::type diagonal_prior_type(diagonal_prior_typeSEXP);
Rcpp::traits::input_parameter< double >::type diagonal_shape(diagonal_shapeSEXP);
Rcpp::traits::input_parameter< double >::type diagonal_rate(diagonal_rateSEXP);
rcpp_result_gen = Rcpp::wrap(mixed_test_logp_and_gradient_full(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate));
Rcpp::traits::input_parameter< Rcpp::Nullable<Rcpp::NumericVector> >::type inv_mass_diag(inv_mass_diagSEXP);
rcpp_result_gen = Rcpp::wrap(mixed_test_logp_and_gradient_full(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate, inv_mass_diag));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -802,7 +803,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_bgms_compute_rhat_cpp", (DL_FUNC) &_bgms_compute_rhat_cpp, 1},
{"_bgms_compute_indicator_ess_cpp", (DL_FUNC) &_bgms_compute_indicator_ess_cpp, 1},
{"_bgms_mixed_test_logp_and_gradient", (DL_FUNC) &_bgms_mixed_test_logp_and_gradient, 18},
{"_bgms_mixed_test_logp_and_gradient_full", (DL_FUNC) &_bgms_mixed_test_logp_and_gradient_full, 18},
{"_bgms_mixed_test_logp_and_gradient_full", (DL_FUNC) &_bgms_mixed_test_logp_and_gradient_full, 19},
{"_bgms_mixed_test_project_position", (DL_FUNC) &_bgms_mixed_test_project_position, 14},
{"_bgms_mixed_test_project_momentum", (DL_FUNC) &_bgms_mixed_test_project_momentum, 15},
{"_bgms_mixed_test_leapfrog_constrained", (DL_FUNC) &_bgms_mixed_test_leapfrog_constrained, 17},
Expand Down
10 changes: 9 additions & 1 deletion src/mixed_gradient_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ Rcpp::List mixed_test_logp_and_gradient_full(
double means_scale = 1.0,
std::string diagonal_prior_type = "gamma",
double diagonal_shape = 1.0,
double diagonal_rate = 1.0)
double diagonal_rate = 1.0,
Rcpp::Nullable<Rcpp::NumericVector> inv_mass_diag = R_NilValue)
{
size_t p = discrete_observations.n_cols;
size_t q = continuous_observations.n_cols;
Expand All @@ -95,6 +96,13 @@ Rcpp::List mixed_test_logp_and_gradient_full(
42
);

// Plug the integrator's inverse-mass diagonal through to the gradient
// so the mass-weighted Pfaffian correction can be exercised from R.
// Empty/NULL falls back to identity in the gradient engine.
if(inv_mass_diag.isNotNull()) {
model.set_inv_mass(Rcpp::as<arma::vec>(inv_mass_diag));
}

auto result = model.logp_and_gradient_full(params);

return Rcpp::List::create(
Expand Down
192 changes: 162 additions & 30 deletions src/models/mixed/mixed_mrf_gradient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ std::pair<double, arma::vec> MixedMRFModel::logp_and_gradient(
const arma::vec& parameters)
{
ensure_gradient_cache();
ensure_constraint_structure();

// --- Unvectorize into temporaries (blocks 1–4) ---
arma::mat temp_main_discrete = main_effects_discrete_;
Expand Down Expand Up @@ -637,32 +638,85 @@ std::pair<double, arma::vec> MixedMRFModel::logp_and_gradient(
arma::mat Omega_bar_sym = Omega_bar + Omega_bar.t();
arma::mat R_bar = temp_cholesky * Omega_bar_sym;

// Roverato graph-constrained Cholesky Jacobian for the Kyy block:
// log|det J| = q log 2 + Σ_j (deg_higher_yy(j) + 2) ψ_j
// where deg_higher_yy(j) = number of active Kyy edges (j, qq) with qq > j.
// For the full Kyy graph deg_higher_yy(j) = q - 1 - j, so this reduces to
// (q - j + 1) — the original formula. For sparse Kyy graphs, the
// graph-aware count is required for the prior to integrate correctly
// (Marsman/Claude, 2026-05-01).
size_t gidx = static_cast<size_t>(chol_grad_offset_);
// Cholesky-to-K Jacobian (graph-agnostic) + per-column Pfaffian correction.
// Mirrors GGMGradientEngine::logp_and_gradient_full:
// ldj = q*log(2) + Σ_j (q+1-j) ψ_j
// pfaffian = 0.5 * Σ_qq log det(A_qq A_qq^T) (identity mass here)
// logp += ldj - pfaffian
// Theta-space integration uses identity mass on the manifold; for the
// full-space (RATTLE) path we plug through the integrator's inverse-mass
// diagonal. For the full Kyy graph, all A_qq are empty, so the Pfaffian
// collapses to 0 and the formula reduces to the original q-j+1 weight.
logp += static_cast<double>(q_) * std::log(2.0);
for(size_t j = 0; j < q_; ++j) {
double psi_j = std::log(temp_cholesky(j, j));
size_t deg_higher_yy = 0;
for(size_t qq = j + 1; qq < q_; ++qq) {
if(edge_indicators_(p_ + j, p_ + qq) == 1) ++deg_higher_yy;
logp += static_cast<double>(q_ + 1 - j) * std::log(temp_cholesky(j, j));
}

const auto& cs = chol_constraint_structure_;
arma::mat Aq_buf;
std::vector<arma::mat> G_chol(q_);
std::vector<arma::mat> Aq_cache(q_);
double pfaffian = 0.0;
for(size_t col = 1; col < q_; ++col) {
const auto& cc = cs.columns[col];
if(cc.m_q == 0) continue;

GGMGradientEngine::build_Aq(temp_cholesky, cc, col, Aq_buf);
Aq_cache[col] = Aq_buf;

// Identity mass (theta-space): G_q = A_q A_q^T.
arma::mat G_q = Aq_buf * Aq_buf.t();

arma::mat L_q;
bool chol_ok = arma::chol(L_q, G_q, "lower");
if(!chol_ok) {
double ridge = 1e-12 * (arma::trace(G_q) /
static_cast<double>(cc.m_q) + 1.0);
chol_ok = arma::chol(L_q, G_q + ridge * arma::eye(cc.m_q, cc.m_q),
"lower");
if(!chol_ok) {
return {-std::numeric_limits<double>::infinity(),
arma::vec(grad.n_elem, arma::fill::zeros)};
}
}
G_chol[col] = L_q;
pfaffian += arma::accu(arma::log(arma::diagvec(L_q)));
}
logp -= pfaffian;

// Pfaffian adjoint: d/dA_q [-0.5 log det(A_q A_q^T)] = -G_q^{-1} A_q.
// Each A_q(r, l) = R(l, i_r) for l <= i_r, so the adjoint flows back to
// column i_r of R_bar at rows l = 0..i_r.
for(size_t col = 1; col < q_; ++col) {
const auto& cc = cs.columns[col];
if(cc.m_q == 0) continue;

const arma::mat& L_q = G_chol[col];
const arma::mat& Aq = Aq_cache[col];

arma::mat Z = arma::solve(arma::trimatl(L_q), Aq,
arma::solve_opts::fast);
Z = arma::solve(arma::trimatu(L_q.t()), Z,
arma::solve_opts::fast);

for(size_t r = 0; r < cc.m_q; ++r) {
size_t i_r = cc.excluded_indices[r];
for(size_t l = 0; l <= i_r; ++l) {
R_bar(l, i_r) -= Z(r, l);
}
}
double jac_weight = static_cast<double>(deg_higher_yy + 2);
logp += jac_weight * psi_j;
}

// Off-diagonal Cholesky entries: ∂ℓ/∂R_{ij} = R̄_{ij}
// Extract position gradient from R_bar with the unified weight (q+1-j)
// on the diagonal-psi entries.
size_t gidx = static_cast<size_t>(chol_grad_offset_);
for(size_t j = 0; j < q_; ++j) {
double w_j = static_cast<double>(q_ + 1 - j);
for(size_t i = 0; i < j; ++i) {
grad(gidx++) = R_bar(i, j);
}
// Diagonal (log-scale): ∂ℓ/∂ψ_j = R̄_{jj} R_{jj} + (deg_higher_yy + 2)
grad(gidx++) = R_bar(j, j) * temp_cholesky(j, j) + jac_weight;
grad(gidx++) = R_bar(j, j) * temp_cholesky(j, j) + w_j;
}
// Add constant Jacobian term to logp
logp += static_cast<double>(q_) * std::log(2.0);

return {logp, grad};
}
Expand All @@ -680,6 +734,7 @@ std::pair<double, arma::vec> MixedMRFModel::logp_and_gradient(
std::pair<double, arma::vec> MixedMRFModel::logp_and_gradient_full(
const arma::vec& x)
{
ensure_constraint_structure();
const size_t full_dim = full_parameter_dimension();

// --- Unpack all 5 blocks from full-space vector ---
Expand Down Expand Up @@ -1110,24 +1165,101 @@ std::pair<double, arma::vec> MixedMRFModel::logp_and_gradient_full(
arma::mat Omega_bar_sym = Omega_bar + Omega_bar.t();
arma::mat R_bar = temp_cholesky * Omega_bar_sym;

// Roverato graph-constrained Cholesky Jacobian for the Kyy block.
// jac_weight = deg_higher_yy(j) + 2; reduces to q - j + 1 for full graph.
size_t gidx = chol_offset;
// Cholesky-to-K Jacobian (graph-agnostic) + mass-weighted per-column
// Pfaffian correction for the RATTLE manifold marginal:
// ldj = q*log(2) + Σ_j (q+1-j) ψ_j
// pfaffian = 0.5 * Σ_qq log det(A_qq diag(M_qq^{-1}) A_qq^T)
// logp += ldj - pfaffian
// Mass diagonal is plumbed from this->inv_mass_ (empty ⇒ identity).
// For the full Kyy graph, all A_qq are empty and Pfaffian = 0.
logp += static_cast<double>(q_) * std::log(2.0);
for(size_t j = 0; j < q_; ++j) {
double psi_j = std::log(temp_cholesky(j, j));
size_t deg_higher_yy = 0;
for(size_t qq = j + 1; qq < q_; ++qq) {
if(edge_indicators_(p_ + j, p_ + qq) == 1) ++deg_higher_yy;
logp += static_cast<double>(q_ + 1 - j) * std::log(temp_cholesky(j, j));
}

const auto& cs = chol_constraint_structure_;
const bool identity_mass = inv_mass_.is_empty();
arma::mat Aq_buf;
std::vector<arma::mat> G_chol(q_);
std::vector<arma::mat> Aq_cache(q_);
std::vector<arma::vec> inv_mass_q_cache(q_);
double pfaffian = 0.0;
for(size_t col = 1; col < q_; ++col) {
const auto& cc = cs.columns[col];
if(cc.m_q == 0) continue;

GGMGradientEngine::build_Aq(temp_cholesky, cc, col, Aq_buf);
Aq_cache[col] = Aq_buf;

arma::vec inv_mass_q(col);
if(identity_mass) {
inv_mass_q.ones();
} else {
size_t off_q = chol_block_offset_ + cs.full_theta_offsets[col];
for(size_t l = 0; l < col; ++l) {
inv_mass_q(l) = inv_mass_(off_q + l);
}
}
double jac_weight = static_cast<double>(deg_higher_yy + 2);
logp += jac_weight * psi_j;
inv_mass_q_cache[col] = inv_mass_q;

// G_q = A_q diag(inv_mass_q) A_q^T
arma::mat Aq_scaled = Aq_buf;
Aq_scaled.each_row() %= inv_mass_q.t();
arma::mat G_q = Aq_scaled * Aq_buf.t();

arma::mat L_q;
bool chol_ok = arma::chol(L_q, G_q, "lower");
if(!chol_ok) {
double ridge = 1e-12 * (arma::trace(G_q) /
static_cast<double>(cc.m_q) + 1.0);
chol_ok = arma::chol(L_q, G_q + ridge * arma::eye(cc.m_q, cc.m_q),
"lower");
if(!chol_ok) {
return {-std::numeric_limits<double>::infinity(),
arma::vec(full_dim, arma::fill::zeros)};
}
}
G_chol[col] = L_q;
pfaffian += arma::accu(arma::log(arma::diagvec(L_q)));
}
logp -= pfaffian;

// Pfaffian adjoint: d/dA_q [-0.5 log det(A_q M_q^{-1} A_q^T)] flows back
// to R_bar at the excluded-edge columns. dA_q = G_q^{-1} A_q · diag(M_q^{-1}).
for(size_t col = 1; col < q_; ++col) {
const auto& cc = cs.columns[col];
if(cc.m_q == 0) continue;

const arma::mat& L_q = G_chol[col];
const arma::mat& Aq = Aq_cache[col];
const arma::vec& inv_mass_q = inv_mass_q_cache[col];

arma::mat Z = arma::solve(arma::trimatl(L_q), Aq,
arma::solve_opts::fast);
Z = arma::solve(arma::trimatu(L_q.t()), Z,
arma::solve_opts::fast);

arma::mat dAq = Z;
dAq.each_row() %= inv_mass_q.t();

for(size_t r = 0; r < cc.m_q; ++r) {
size_t i_r = cc.excluded_indices[r];
for(size_t l = 0; l <= i_r; ++l) {
R_bar(l, i_r) -= dAq(r, l);
}
}
}

// Extract position gradient from R_bar with the unified weight (q+1-j)
// on the diagonal-psi entries.
size_t gidx = chol_offset;
for(size_t j = 0; j < q_; ++j) {
double w_j = static_cast<double>(q_ + 1 - j);
for(size_t i = 0; i < j; ++i) {
grad(gidx++) = R_bar(i, j);
}
grad(gidx++) = R_bar(j, j) * temp_cholesky(j, j) + jac_weight;
grad(gidx++) = R_bar(j, j) * temp_cholesky(j, j) + w_j;
}
logp += static_cast<double>(q_) * std::log(2.0);

return {logp, grad};
}
Loading
Loading