Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export(extract_sbm)
export(gamma_prior)
export(mrfSampler)
export(normal_prior)
export(sample_precision_prior)
export(sample_ggm_prior)
export(sbm_prior)
export(simulate_mrf)
import(RcppParallel)
Expand Down
16 changes: 8 additions & 8 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ ggm_test_leapfrog_constrained_checked <- function(x0, r0, step_size, n_steps, su
.Call(`_bgms_ggm_test_leapfrog_constrained_checked`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, reverse_check_tol, inv_mass_in)
}

sample_precision_prior_cpp <- function(p, n_samples, n_warmup = 1000L, pairwise_scale = 2.5, interaction_prior_type = "cauchy", scale_prior_type = "gamma", gamma_shape = 1.0, gamma_rate = 1.0, step_size = 0.1, max_depth = 10L, seed = 1L, verbose = TRUE, edge_indicators_nullable = NULL) {
.Call(`_bgms_sample_precision_prior`, p, n_samples, n_warmup, pairwise_scale, interaction_prior_type, scale_prior_type, gamma_shape, gamma_rate, step_size, max_depth, seed, verbose, edge_indicators_nullable)
sample_ggm_prior_cpp <- function(p, n_samples, n_warmup = 1000L, pairwise_scale = 2.5, interaction_prior_type = "cauchy", scale_prior_type = "gamma", gamma_shape = 1.0, gamma_rate = 1.0, step_size = 0.1, max_depth = 10L, seed = 1L, verbose = TRUE, edge_indicators_nullable = NULL, delta = 0.0) {
.Call(`_bgms_sample_ggm_prior`, p, n_samples, n_warmup, pairwise_scale, interaction_prior_type, scale_prior_type, gamma_shape, gamma_rate, step_size, max_depth, seed, verbose, edge_indicators_nullable, delta)
}

.compute_ess_cpp <- function(array3d) {
Expand All @@ -69,8 +69,8 @@ mixed_test_logp_and_gradient <- function(params, discrete_observations, continuo
.Call(`_bgms_mixed_test_logp_and_gradient`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate)
}

mixed_test_logp_and_gradient_full <- function(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha = 1.0, main_beta = 1.0, interaction_prior_type = "cauchy", threshold_prior_type = "beta-prime", threshold_scale = 1.0, means_prior_type = "normal", means_scale = 1.0, diagonal_prior_type = "gamma", diagonal_shape = 1.0, diagonal_rate = 1.0, inv_mass_diag = NULL) {
.Call(`_bgms_mixed_test_logp_and_gradient_full`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate, inv_mass_diag)
mixed_test_logp_and_gradient_full <- function(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha = 1.0, main_beta = 1.0, interaction_prior_type = "cauchy", threshold_prior_type = "beta-prime", threshold_scale = 1.0, means_prior_type = "normal", means_scale = 1.0, diagonal_prior_type = "gamma", diagonal_shape = 1.0, diagonal_rate = 1.0, inv_mass_diag = NULL, delta = 0.0) {
.Call(`_bgms_mixed_test_logp_and_gradient_full`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha, main_beta, interaction_prior_type, threshold_prior_type, threshold_scale, means_prior_type, means_scale, diagonal_prior_type, diagonal_shape, diagonal_rate, inv_mass_diag, delta)
}

mixed_test_project_position <- function(x, inv_mass, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pairwise_scale, main_alpha = 1.0, main_beta = 1.0, interaction_prior_type = "cauchy", threshold_prior_type = "beta-prime", threshold_scale = 1.0) {
Expand Down Expand Up @@ -141,12 +141,12 @@ ggm_test_logp_and_gradient_full_prior <- function(x, suf_stat, n, edge_indicator
.Call(`_bgms_ggm_test_logp_and_gradient_full_prior`, x, suf_stat, n, edge_indicators, interaction_prior_type, interaction_scale, interaction_alpha, interaction_beta, diagonal_prior_type, diagonal_shape, diagonal_rate, inv_mass_diag)
}

sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)
sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL, delta = 0.0) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable, delta)
}

sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) {
.Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)
sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL, delta = 0.0) {
.Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable, delta)
}

sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, pairwise_scaling_factors_nullable = NULL) {
Expand Down
10 changes: 10 additions & 0 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@
#' Ignored for pure ordinal models.
#' Default: \code{gamma_prior(shape = 1, rate = 1)}.
#'
#' @param delta Non-negative numeric. Determinant-tilt exponent on the
#' continuous-block precision matrix \eqn{K} (GGM) or \eqn{K_{yy}}
#' (mixed MRF): multiplies the prior by \eqn{|K|^{\delta}}, pushing the
#' chain away from the positive-definite cone boundary. Both NUTS and
#' adaptive-Metropolis update paths apply the tilt. \code{delta = 0}
#' (default) recovers the untilted prior. Not allowed for pure ordinal
#' models (no precision matrix to tilt).
#'
#' @param pairwise_scale `r lifecycle::badge("deprecated")` Double.
#' Scale of the Cauchy prior for pairwise
#' interaction parameters. Use \code{interaction_prior} instead.
Expand Down Expand Up @@ -325,6 +333,7 @@ bgm = function(
threshold_prior = beta_prime_prior(alpha = 0.5, beta = 0.5),
means_prior = normal_prior(scale = 1),
precision_scale_prior = gamma_prior(shape = 1, rate = 1),
delta = 0,
edge_selection = TRUE,
edge_prior = bernoulli_prior(0.5),
na_action = c("listwise", "impute"),
Expand Down Expand Up @@ -492,6 +501,7 @@ bgm = function(
scale_prior_type = sp$scale_prior_type,
scale_shape = sp$scale_shape,
scale_rate = sp$scale_rate,
delta = delta,
standardize = standardize,
edge_selection = edge_selection,
edge_prior = edge_prior,
Expand Down
20 changes: 20 additions & 0 deletions R/bgm_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ bgm_spec = function(x,
scale_prior_type = "gamma",
scale_shape = 1,
scale_rate = 1,
delta = 0,
standardize = FALSE,
edge_selection = TRUE,
edge_prior = bernoulli_prior(0.5),
Expand Down Expand Up @@ -344,6 +345,19 @@ bgm_spec = function(x,
model_type = "mixed_mrf"
}

# Validate determinant-tilt exponent and reject for pure-ordinal models
if(!is.numeric(delta) || length(delta) != 1L || is.na(delta) ||
!is.finite(delta) || delta < 0) {
stop("'delta' must be a single finite non-negative numeric.")
}
if(delta > 0 && model_type %in% c("omrf", "compare")) {
stop(
"'delta' (determinant tilt) requires continuous variables; the ",
"current model_type is '", model_type, "', which has no precision ",
"matrix to tilt. Pass delta = 0 (default) or use continuous data."
)
}

# --- Sampler (needs is_continuous and edge_selection early) ------------------
sampler = validate_sampler(
update_method = update_method,
Expand Down Expand Up @@ -415,6 +429,7 @@ bgm_spec = function(x,
scale_prior_type = scale_prior_type,
scale_shape = scale_shape,
scale_rate = scale_rate,
delta = delta,
edge_prior_flat = ep_flat
)
} else if(model_type == "mixed_mrf") {
Expand All @@ -438,6 +453,7 @@ bgm_spec = function(x,
scale_prior_type = scale_prior_type,
scale_shape = scale_shape,
scale_rate = scale_rate,
delta = delta,
standardize = standardize,
edge_prior_flat = ep_flat
)
Expand Down Expand Up @@ -505,6 +521,7 @@ build_spec_ggm = function(x, data_columnnames, num_variables,
interaction_prior_type, pairwise_scale,
interaction_alpha, interaction_beta,
scale_prior_type, scale_shape, scale_rate,
delta = 0,
edge_prior_flat) {
# Missing data
md = validate_missing_data(
Expand Down Expand Up @@ -545,6 +562,7 @@ build_spec_ggm = function(x, data_columnnames, num_variables,
scale_prior_type = scale_prior_type,
scale_shape = scale_shape,
scale_rate = scale_rate,
delta = delta,
edge_selection = ep$edge_selection,
edge_prior = ep$edge_prior,
inclusion_probability = ep$inclusion_probability,
Expand Down Expand Up @@ -683,6 +701,7 @@ build_spec_mixed_mrf = function(x, data_columnnames, num_variables,
means_prior_type, means_scale,
means_alpha, means_beta,
scale_prior_type, scale_shape, scale_rate,
delta = 0,
standardize,
edge_prior_flat) {
# Identify discrete vs continuous columns
Expand Down Expand Up @@ -818,6 +837,7 @@ build_spec_mixed_mrf = function(x, data_columnnames, num_variables,
scale_prior_type = scale_prior_type,
scale_shape = scale_shape,
scale_rate = scale_rate,
delta = delta,
standardize = standardize,
edge_selection = ep$edge_selection,
edge_prior = ep$edge_prior,
Expand Down
6 changes: 4 additions & 2 deletions R/run_sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ run_sampler_ggm = function(spec) {
target_acceptance = s$target_accept,
max_tree_depth = s$nuts_max_depth,
na_impute = m$na_impute,
missing_index_nullable = m$missing_index
missing_index_nullable = m$missing_index,
delta = p$delta
)

out_raw
Expand Down Expand Up @@ -236,7 +237,8 @@ run_sampler_mixed_mrf = function(spec) {
max_tree_depth = s$nuts_max_depth,
na_impute = m$na_impute,
missing_index_discrete_nullable = m$missing_index_discrete,
missing_index_continuous_nullable = m$missing_index_continuous
missing_index_continuous_nullable = m$missing_index_continuous,
delta = p$delta
)

out_raw
Expand Down
Loading
Loading