Skip to content

Auto reparametrize PyMC models for Normalizing Flow adaptation#301

Draft
ricardoV94 wants to merge 1 commit into
pymc-devs:mainfrom
ricardoV94:pymc_auto_reparam
Draft

Auto reparametrize PyMC models for Normalizing Flow adaptation#301
ricardoV94 wants to merge 1 commit into
pymc-devs:mainfrom
ricardoV94:pymc_auto_reparam

Conversation

@ricardoV94

@ricardoV94 ricardoV94 commented Apr 17, 2026

Copy link
Copy Markdown
Member

Example: Neal's funnel

import pymc as pm
import nutpie

with pm.Model(coords={"group": range(9)}) as funnel:
    log_sigma = pm.Normal("log_sigma", 0, 3)
    pm.Normal("x", 0, pm.math.exp(log_sigma / 2), dims="group")

compiled = nutpie.compile_pymc_model(
    funnel, backend="jax", gradient_backend="jax", auto_reparam=True
)
# auto_reparam: reparametrizing 1 of 2 free variables: x (AffineFlow)

trace = nutpie.sample(compiled, adaptation="flow", seed=0)

print("divergences:", int(trace.sample_stats.diverging.sum()))
# divergences: 0  (baseline with diag mass matrix adaptation: 30)
print(float(trace.posterior.log_sigma.mean()), float(trace.posterior.log_sigma.std()))
# 0.043 3.21  (true posterior is the N(0, 3) prior)

auto_reparam=True rewrites eligible free RVs into a continuously parametrized centered/non-centered form (VIP, Gorinova et al. 2019) and attaches it as an AutoFlow bijection; adaptation="flow" then fits the per-element centering knobs during tuning.

It requires backend="jax" and gradient_backend="jax" (anything else raises). By default only the reparametrization (plus a diagonal affine) is fitted — add neural coupling layers with e.g. compiled.with_transform_adapt(num_layers=8).

@ricardoV94 ricardoV94 requested a review from aseyboldt April 17, 2026 21:06
@ricardoV94 ricardoV94 force-pushed the pymc_auto_reparam branch 2 times, most recently from edac9ca to b0fd14a Compare June 11, 2026 13:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant