diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 13245f7..8994670 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -233,7 +233,14 @@ def make_adapter(*args, **kwargs): ) def with_transform_adapt(self, **kwargs): - return dataclasses.replace(self, _transform_adapt_args=kwargs) + """Set arguments for the flow transform adapter (``adaptation="flow"``). + + Arguments accumulate across calls; pass ``None`` to reset an + argument to its default. + """ + merged = {**(self._transform_adapt_args or {}), **kwargs} + merged = {k: v for k, v in merged.items() if v is not None} + return dataclasses.replace(self, _transform_adapt_args=merged) def update_user_data(user_data, user_data_storage): @@ -413,6 +420,7 @@ def _compile_pymc_model_jax( gradient_backend=None, pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]], var_names: Iterable[str] | None = None, + auto_reparam: bool = False, **kwargs, ): if find_spec("jax") is None: @@ -504,7 +512,7 @@ def expand(_x, **shared): dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names) - return from_pyfunc( + compiled = from_pyfunc( ndim=n_dim, make_logp_fn=make_logp_func, make_expand_fn=make_expand_func, @@ -519,6 +527,16 @@ def expand(_x, **shared): reparameterized_names=reparameterized_names, ) + if auto_reparam: + from nutpie.flow_reparam import build_auto_flow + + # None (with a warning) when the rewrite found nothing to do. + auto_flow = build_auto_flow(model, compiled) + if auto_flow is not None: + compiled = compiled.with_transform_adapt(auto_flow=auto_flow) + + return compiled + def compile_pymc_model( model: "pm.Model", @@ -533,6 +551,7 @@ def compile_pymc_model( ] = "support_point", var_names: Iterable[str] | None = None, freeze_model: bool | None = None, + auto_reparam: bool = False, **kwargs, ) -> CompiledModel: """Compile necessary functions for sampling a pymc model. @@ -561,6 +580,18 @@ def compile_pymc_model( freeze_model : bool | None Freeze all dimensions and shared variables to treat them as compile time constants. + auto_reparam : bool + Automatically reparametrize free random variables (e.g. continuous + VIP centering of location-scale families) and attach the resulting + flow to the model, so that ``nutpie.sample(compiled_model, + adaptation="flow")`` fits the reparametrization during tuning. + Prints a summary of the reparametrized variables; if nothing can be + reparametrized, warns and attaches no flow. Requires + ``backend="jax"`` and ``gradient_backend="jax"``. With a flow + attached, only the reparametrization (plus a diagonal affine) is + fitted by default; further flow options (e.g. ``num_layers`` to add + neural coupling layers) can be set with + ``compiled_model.with_transform_adapt``. Returns ------- @@ -581,6 +612,13 @@ def compile_pymc_model( if backend is not None: backend = backend.lower() # type: ignore[assignment] + # The flow transform adapter needs the raw JAX logp function, which is + # only kept with the jax gradient backend. + if auto_reparam and (backend != "jax" or gradient_backend != "jax"): + raise ValueError( + "auto_reparam requires backend='jax' and gradient_backend='jax'" + ) + from pymc.initial_point import make_initial_point_fn from pymc.model.transform.optimization import freeze_dims_and_data @@ -618,6 +656,7 @@ def compile_pymc_model( gradient_backend=gradient_backend, pymc_initial_point_fn=initial_point_fn, var_names=var_names, + auto_reparam=auto_reparam, **kwargs, ) else: diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 07602ee..a034ebe 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -46,7 +46,15 @@ def with_data(self, **updates): return dataclasses.replace(self, _shared_data=updated) def with_transform_adapt(self, **kwargs): - return dataclasses.replace(self, _transform_adapt_args=kwargs) + """Set arguments for the flow transform adapter (``adaptation="flow"``). + + Arguments accumulate across calls (so e.g. the ``auto_flow`` attached + by ``compile_pymc_model(..., auto_reparam=True)`` survives later + tuning calls); pass ``None`` to reset an argument to its default. + """ + merged = {**(self._transform_adapt_args or {}), **kwargs} + merged = {k: v for k, v in merged.items() if v is not None} + return dataclasses.replace(self, _transform_adapt_args=merged) def _make_sampler( self, diff --git a/python/nutpie/flow_reparam.py b/python/nutpie/flow_reparam.py new file mode 100644 index 0000000..97fd9a5 --- /dev/null +++ b/python/nutpie/flow_reparam.py @@ -0,0 +1,949 @@ +"""Automatic flow-based reparametrization of free RVs in a PyMC model. + +:func:`reparametrize` picks flows per RV via the extensible rewrite +database :data:`flow_db` and returns plain :class:`FlowSpec` records in +ordinary rv/value/transform space (the rewrite IR never escapes it). +:func:`automatic_flow_reparam` reports the chosen flows, +:func:`build_flow_graph` turns the specs into symbolic +constrain/unconstrain maps over Nutpie's flat point vector, and +:func:`build_auto_flow` wraps those into an ``AutoFlow`` for flow +adaptation. +""" + +from __future__ import annotations + +import abc +import warnings +from itertools import zip_longest +from typing import NamedTuple + +import numpy as np +import pytensor +import pytensor.tensor as pt +from pytensor.compile import optdb +from pytensor.graph.basic import Apply, clone_get_equiv +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import dfs_rewriter, node_rewriter +from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB +from pytensor.graph.traversal import ancestors, explicit_graph_inputs +from pytensor.tensor.math import variadic_add +from pytensor.tensor.random.basic import ( + CauchyRV, + ExponentialRV, + GammaRV, + GumbelRV, + HalfCauchyRV, + HalfNormalRV, + InvGammaRV, + LaplaceRV, + LogisticRV, + LogNormalRV, + NormalRV, + ParetoRV, + StudentTRV, + WeibullRV, +) +from pytensor.xtensor.basic import ( + XTensorFromTensor, + tensor_from_xtensor, + xtensor_from_tensor, +) +from pytensor.xtensor.type import XTensorType + + +from pymc.dims.distributions.transforms import LogTransform as DimLogTransform +from pymc.distributions.multivariate import ZeroSumNormalRV +from pymc.distributions.transforms import ZeroSumTransform +from pymc.distributions.transforms import log as log_transform +from pymc.logprob.transforms import LogTransform +from pymc.logprob.utils import replace_rvs_by_values +from pymc.model.core import Model +from pymc.model.fgraph import ( + ModelFreeRV, + ModelValuedVar, + ModelVar, + fgraph_from_model, +) +from pymc.pytensorf import toposort_replace + + +# (loc_idx, scale_idx, required_value_transform) per RV op type. The transform +# entry is the class the ModelFreeRV's value transform must be for this RV +# op's sampling-space distribution to be loc-scale. ``None`` means no +# transform (rv and value share the same space, already loc-scale). +# ``LogTransform`` covers log-loc-scale families (LogNormal in log-space +# is Normal, so the same affine flow applies). +LOC_SCALE_FAMILIES: dict[type, tuple[int, int, type | None]] = { + NormalRV: (0, 1, None), + CauchyRV: (0, 1, None), + LaplaceRV: (0, 1, None), + LogisticRV: (0, 1, None), + GumbelRV: (0, 1, None), + StudentTRV: (1, 2, None), + LogNormalRV: (0, 1, LogTransform), +} + + +# RV ops of the form ``X = S · Y`` (``Y`` fixed-shape, ``S`` the scale) +# that land in log-space under their default ``LogTransform``. Shifting +# the log-space value by a constant ↔ scaling ``X`` by that constant, so +# a shift-only affine flow in the sampling space captures hierarchical +# variation in the scale parameter without an icdf call. Value is the +# index of the scale/rate parameter within ``dist_params``. +SCALE_SHIFT_FAMILIES: dict[type, int] = { + GammaRV: 1, # β (rate) + WeibullRV: 1, # β + InvGammaRV: 1, # β + ExponentialRV: 0, # λ + HalfNormalRV: 0, # σ + HalfCauchyRV: 0, # β + ParetoRV: 1, # m (scale; α fixed) +} + + +class Flow(abc.ABC): + """Pure-math descriptor for a per-RV reparametrization. + + Subclasses declare how many of the ``constrain`` / ``unconstrain`` + arguments come from the model graph (``n_model_params`` — e.g. the + location/scale a child reads off its parents) versus how many are + trainable per-flow parameters fitted during adaptation + (``n_hyper_params`` — the VIP centering knobs). Both maps take the + same signature ``(point, *model_params, *hyper_params)``: + + - ``unconstrain(x, ...)`` — value-space point ``x`` → NUTS-space + point ``y``. + - ``constrain(y, ...)`` — NUTS-space point ``y`` → value-space + point ``x``. + + Both log|det| Jacobians default to the full Jacobian via + ``pt.jacobian(..., vectorize=True)`` — correct for any invertible + flow; subclasses override with a closed form when it's cheaper (see + :class:`BaseAffineFlow` for the constant-Jacobian shortcut). + """ + + # Leading args sourced from the model fgraph (loc/scale subgraphs). + n_model_params: int = 0 + # Trailing args fitted during adaptation; one entry per ``param_shapes``. + n_hyper_params: int = 0 + + @staticmethod + @abc.abstractmethod + def unconstrain(x, *params): # x -> y + ... + + @staticmethod + @abc.abstractmethod + def constrain(y, *params): # y -> x + ... + + @classmethod + def log_jac_det_constrain(cls, y, *params): + x = cls.constrain(y, *params).ravel() + J = pt.jacobian(x, y, vectorize=True).reshape((x.size, x.size)) + return pt.linalg.slogdet(J)[1] + + @classmethod + def log_jac_det_unconstrain(cls, x, *params): + y = cls.unconstrain(x, *params).ravel() + J = pt.jacobian(y, x, vectorize=True).reshape((y.size, y.size)) + return pt.linalg.slogdet(J)[1] + + +class BaseAffineFlow(Flow): + """Flow that is affine in the point, i.e. its Jacobian is constant. + + For such flows the ``unconstrain`` log|det| is just the negation of + the ``constrain`` one, and the point at which the formula is + evaluated is irrelevant — so only ``log_jac_det_constrain`` needs a + closed form. + """ + + @classmethod + def log_jac_det_unconstrain(cls, x, *params): + return -cls.log_jac_det_constrain(x, *params) + + +class NoFlow(BaseAffineFlow): + """Identity flow. Used for RVs the rewrite skipped.""" + + n_model_params = 0 + n_hyper_params = 0 + + @staticmethod + def unconstrain(x): + return x + + @staticmethod + def constrain(y): + return y + + @staticmethod + def log_jac_det_constrain(y): + return pt.zeros((), dtype=y.dtype) + + +def _pin_if_empty(h): + """Hyper params arrive with concrete static shapes (they are unpacked + from the flat trainable vector); size 0 marks a knob the rewrite + withheld because its dist param did not qualify. Pin it at the + centred no-op ``h = 0`` so it drops out of the transform.""" + if 0 in h.type.shape: + return pt.zeros((), dtype=h.dtype) + return h + + +class AffineFlow(BaseAffineFlow): + """Variationally Inferred Parameterisation of a location-scale RV. + + Following Gorinova et al. (2019), a child ``z ~ Dist(loc, scale)`` is + expressed via a standardized ``y`` interpolating continuously between + the centred (CP) and non-centred (NCP) parameterisations. A single + trainable knob ``h`` per variable (the paper's ``1 - λ``) controls + both location and scale, so ``h = 0`` is the no-op (centred):: + + constrain(y, loc, scale, h) = (y - (1 - h)·loc)·scale^h + loc + + ``h = 0`` ⇒ identity (CP); ``h = 1`` ⇒ full NCP (``y`` is the + standardized residual). ``loc`` and ``scale`` are read off the parent + RVs (model params); ``h`` is the trainable hyper param, left + unconstrained (over-/under-centering is allowed). + + The hyper params broadcast against ``y`` / ``x`` via normal numpy + rules. The rewrite allocates one knob per element of the value var + (following Gorinova et al., whose λ is shaped like the RV), so + different elements of one hierarchical group can settle on different + centerings — except on axes where elements cannot be reparametrized + independently (e.g. ZeroSumNormal core dims), which get size 1. + + A dist param can also fail to qualify for a knob altogether (constant + loc, full-shape scale, ...). The rewrite then allocates its hyper + param at size 0, and the flow pins that knob at the centred no-op + ``h = 0`` (see :func:`_pin_if_empty`) — e.g. with ``h_sigma`` pinned + the transform degenerates to the translation ``y + h_mu·loc``. + """ + + n_model_params = 2 + n_hyper_params = 2 + + @staticmethod + def unconstrain(x, pop_mu, pop_sigma, h_mu, h_sigma): + h_mu, h_sigma = _pin_if_empty(h_mu), _pin_if_empty(h_sigma) + return (x - pop_mu) * (pop_sigma**-h_sigma) + (1 - h_mu) * pop_mu + + @staticmethod + def constrain(y, pop_mu, pop_sigma, h_mu, h_sigma): + h_mu, h_sigma = _pin_if_empty(h_mu), _pin_if_empty(h_sigma) + return (y - (1 - h_mu) * pop_mu) * (pop_sigma**h_sigma) + pop_mu + + @staticmethod + def log_jac_det_constrain(y, pop_mu, pop_sigma, h_mu, h_sigma): + # d(value)/d(y) = scale^h; h = 0 ⇒ identity. + log_det = _pin_if_empty(h_sigma) * pt.log(pop_sigma) + return pt.broadcast_to(log_det, y.shape).sum() + + +class ShiftFlow(BaseAffineFlow): + """VIP reparametrization of a log-transformed scale-family RV. + + For ``X`` whose default ``LogTransform`` lands it in log-space, the + parent enters the log-space value additively through ``log(scale)``, + so non-centering is a *shift* by that amount (no scale exponent):: + + constrain(y, shift, h) = y + h·shift with shift = log(scale) + + ``h = 0`` is the no-op (centred); ``h`` is unconstrained, so the + optimizer reaches full decoupling regardless of whether the model + param is a rate or a scale (the sign is absorbed into ``h``). The + Jacobian is a translation (det = 1 ⇒ log|det| = 0). + """ + + n_model_params = 1 + n_hyper_params = 1 + + @staticmethod + def unconstrain(x, shift, h): + return x - h * shift + + @staticmethod + def constrain(y, shift, h): + return y + h * shift + + @staticmethod + def log_jac_det_constrain(y, shift, h): + return pt.zeros((), dtype=y.dtype) + + +class FlowSpec(NamedTuple): + """Per-free-RV reparametrization record in plain rv/value/transform + space — the rewrite IR never escapes :func:`reparametrize`. + + ``rv`` is the genuine random variable (so PyMC's + ``replace_rvs_by_values`` passes the actual distribution parameters + to ``transform.backward``); ``model_params`` are the loc/scale + subgraphs the flow reads off the model, expressed in terms of the + parent specs' ``rv`` variables; ``param_shapes`` are the concrete + shapes of the flow's trainable hyper params, evaluated at the model's + initial point. + """ + + flow_cls: type[Flow] + rv: object + value: object + transform: object | None + model_params: list + param_shapes: list[tuple[int, ...]] + + +class FlowFreeRV(ModelValuedVar): + """Rewrite-internal marker carrying the chosen flow, the RV's value + transform, the model-derived params, and per-hyper-param shape exprs. + Lives only between the flow rewrite and the IR strip inside + :func:`reparametrize`. + + Inputs (positional, flat so ``op.make_node(*node.inputs)`` is + idempotent):: + + (rv, value, *model_params, *hyper_shape_exprs, *dims) + + with ``len(model_params) == flow_cls.n_model_params`` and + ``len(hyper_shape_exprs) == flow_cls.n_hyper_params``. + """ + + __props__ = ("flow_cls", "transform") + + def __init__(self, flow_cls: type[Flow], transform=None): + self.flow_cls = flow_cls + super().__init__(transform=transform) + + def __call__(self, rv, value, *params, hyperparam_shapes=(), dims=()): + # Ergonomic construction: callers group the variadic chunks by name. + return super().__call__(rv, value, *params, *hyperparam_shapes, *dims) + + def make_node(self, rv, value, *rest): + nm = self.flow_cls.n_model_params + nh = self.flow_cls.n_hyper_params + assert len(rest) >= nm + nh + return Apply(self, [rv, value, *rest], [value.type(name=value.name)]) + + +class _FlowParts(NamedTuple): + """IR-side view of a (Flow|Model)FreeRV node, used only while the + rewritten fgraph is alive.""" + + flow_cls: type[Flow] + rv: object + value: object + out: object + transform: object | None + model_params: list + hyper_shapes: list + + +def _flow_node_parts(node) -> _FlowParts | None: + """Extract :class:`_FlowParts` from a ``FlowFreeRV`` or plain + ``ModelFreeRV`` node (the latter mapped to :class:`NoFlow`); returns + ``None`` for any other node.""" + op = node.op + if isinstance(op, FlowFreeRV): + rv, value, *rest = node.inputs + nm = op.flow_cls.n_model_params + nh = op.flow_cls.n_hyper_params + return _FlowParts( + flow_cls=op.flow_cls, + rv=rv, + value=value, + out=node.outputs[0], + transform=op.transform, + model_params=list(rest[:nm]), + hyper_shapes=list(rest[nm : nm + nh]), + ) + if isinstance(op, ModelFreeRV): + rv, value, *_dims = node.inputs + return _FlowParts( + flow_cls=NoFlow, + rv=rv, + value=value, + out=node.outputs[0], + transform=op.transform, + model_params=[], + hyper_shapes=[], + ) + return None + + +def _depends_on_free_rv(vars_) -> bool: + return any( + anc.owner is not None and isinstance(anc.owner.op, ModelFreeRV) + for anc in ancestors(vars_) + ) + + +@node_rewriter([ModelFreeRV]) +def lift_xtensor_from_model_free_rv(fgraph, node): + """Pull ``XTensorFromTensor`` wrappers out of a ``ModelFreeRV``'s rv + and value inputs, leaving plain tensors inside so downstream flow + rewrites don't have to know about xtensor:: + + ModelFreeRV(XTensorFromTensor(rv), XTensorFromTensor(value), *dims) + -> XTensorFromTensor(ModelFreeRV(rv, aligned_value, *dims)) + + Only fires on RVs with ``transform=None`` or a :class:`DimTransform` + known to have a plain counterpart (see the ``match`` below); unknown + transforms are left alone since peeling might change semantics — + such RVs stay xtensor-typed end to end, which the spec extraction + and graph build handle natively. Invariant: rv and value are both + xtensor (or both plain tensors) — they may declare different dim + orders, in which case the value is ``dimshuffle``d into rv's order + so the rebuilt inner ``ModelFreeRV`` sees matched-axis tensors. + """ + current_transform = node.op.transform + # Swap any dim-aware transform for its plain logprob counterpart so + # downstream rewrites see a single class hierarchy. Unknown transforms + # aren't safe to peel — leave the node alone. + match current_transform: + case None: + new_transform = None + case DimLogTransform(): + new_transform = log_transform + case _: + return None + + xrv, xvalue, *dims = node.inputs + if not isinstance(xrv.owner.op, XTensorFromTensor): + return None + rv_dims = xrv.type.dims + value_dims = xvalue.type.dims + if set(rv_dims) != set(value_dims): + # Unrelated named axes — no safe permutation; leave the node alone. + return None + rv = xrv.owner.inputs[0] + if xvalue.owner is not None and isinstance(xvalue.owner.op, XTensorFromTensor): + value = xvalue.owner.inputs[0] + else: + value = tensor_from_xtensor(xvalue) + value.name = xvalue.name + if rv_dims != value_dims: + value = value.dimshuffle([value_dims.index(d) for d in rv_dims]) + value.name = xvalue.name + new_op = ( + node.op + if new_transform is current_transform + else type(node.op)(transform=new_transform) + ) + new_free_rv = new_op(rv, value, *dims) + return [XTensorFromTensor(dims=rv_dims)(new_free_rv)] + + +_one = pt.constant(1) +_empty_shape = pt.constant(np.array([0], dtype="int64")) + + +def _hyper_shape(value, n_core: int = 0): + """Per-element hyper-param shape: the value var's shape, collapsed to + size 1 on the trailing ``n_core`` axes (elements along support dims — + e.g. ZeroSumNormal core dims — cannot be reparametrized independently + and must share one knob).""" + ndim = value.ndim + return pt.stack( + [*(value.shape[i] for i in range(ndim - n_core)), *([_one] * n_core)] + ) + + +def _qualifies_for_hyper_param(rv, param) -> bool: + """A dist param earns a centering knob only when the RV has event + axes the param unit-broadcasts along (a hierarchical group) and the + param carries free-RV randomness for the knob to decouple.""" + pairs = zip_longest( + reversed(param.type.broadcastable), + reversed(rv.type.broadcastable), + fillvalue=True, + ) + if not any(p_bc and not rv_bc for p_bc, rv_bc in pairs): + return False + return _depends_on_free_rv([param]) + + +@node_rewriter([ModelFreeRV]) +def loc_scale_affine_flow(fgraph, node): + rv, value, *dims = node.inputs + rv_node = rv.owner + + entry = LOC_SCALE_FAMILIES.get(type(rv_node.op)) + if entry is None: + return None + + loc_idx, scale_idx, expected_transform = entry + if expected_transform is None: + if node.op.transform is not None: + return None + elif not isinstance(node.op.transform, expected_transform): + return None + + dist_params = list(rv_node.op.dist_params(rv_node)) + loc, scale = dist_params[loc_idx], dist_params[scale_idx] + loc_qualifies = _qualifies_for_hyper_param(rv, loc) + scale_qualifies = _qualifies_for_hyper_param(rv, scale) + if not (loc_qualifies or scale_qualifies): + return None + + # Each dist param earns its own per-element knob; a param that does + # not qualify gets a size-0 hyper param, which the flow pins at the + # centred no-op. + param_shape = _hyper_shape(value) + flow_rv = FlowFreeRV(AffineFlow, transform=node.op.transform)( + rv, + value, + loc, + scale, + hyperparam_shapes=[ + param_shape if loc_qualifies else _empty_shape, + param_shape if scale_qualifies else _empty_shape, + ], + dims=dims, + ) + return {node.outputs[0]: flow_rv} + + +@node_rewriter([ModelFreeRV]) +def scale_shift_flow(fgraph, node): + if not isinstance(node.op.transform, LogTransform): + return None + + rv, value, *dims = node.inputs + rv_node = rv.owner + + scale_idx = SCALE_SHIFT_FAMILIES.get(type(rv_node.op)) + if scale_idx is None: + return None + + scale = rv_node.op.dist_params(rv_node)[scale_idx] + if not _qualifies_for_hyper_param(rv, scale): + return None + + # VIP shift: the parent enters the log-space value additively through + # ``log(scale)``, so the flow shifts by ``h·log(scale)``; h = 0 is + # centred. + flow_rv = FlowFreeRV(ShiftFlow, transform=node.op.transform)( + rv, value, pt.log(scale), hyperparam_shapes=[_hyper_shape(value)], dims=dims + ) + return {node.outputs[0]: flow_rv} + + +@node_rewriter([ModelFreeRV]) +def zerosum_scale_flow(fgraph, node): + """VIP scale flow for ``ZeroSumNormal`` with a common (per-batch) + sigma. + + Under the default ``ZeroSumTransform`` the value var is iid + ``Normal(0, σ)`` over the reduced support dims (the transform is an + isometry onto the zero-sum hyperplane and σ has core shape 1, so the + scaling commutes with it). That makes the value-space RV a zero-loc + scale family: :class:`AffineFlow` with ``loc = 0`` (loc knob + withheld) applies, with one scale knob per batch element shared + across the core dims (a per-element knob there would break the + zero-sum coupling). + """ + if not isinstance(node.op.transform, ZeroSumTransform): + return None + + rv, value, *dims = node.inputs + rv_node = rv.owner + if not isinstance(rv_node.op, ZeroSumNormalRV): + return None + + # ZeroSumNormalRV constructs sigma with core shape (1, ...), so it is + # always broadcast along the core dims; guard the invariant anyway. + sigma = rv_node.op.dist_params(rv_node)[0] + if not all(sigma.type.broadcastable[sigma.type.ndim - rv_node.op.ndim_supp :]): + return None + if not _depends_on_free_rv([sigma]): + return None + + n_core = rv_node.op.ndim_supp + flow_rv = FlowFreeRV(AffineFlow, transform=node.op.transform)( + rv, + value, + pt.zeros((), dtype=value.dtype), + sigma, + hyperparam_shapes=[_empty_shape, _hyper_shape(value, n_core=n_core)], + dims=dims, + ) + return {node.outputs[0]: flow_rv} + + +# Tag taxonomy: +# "default" — plumbing + safe-by-default flow rewrites (used by the default +# query); always produces correct posteriors. +# "all" — everything, including opt-in rewrites that a user may want +# finer-grained control over. +# Per-flow tags ("affine", "icdf") allow targeted selection. +default_flow_query = RewriteDatabaseQuery(include=("default",)) +flow_db = SequenceDB() +flow_db.register("lower_xtensor", optdb.query("+lower_xtensor"), "default", "all") +flow_db.register( + "lift_xtensor_from_model_free_rv", + dfs_rewriter(lift_xtensor_from_model_free_rv), + "default", + "all", +) +flow_db.register( + "affine_flow", + dfs_rewriter( + loc_scale_affine_flow, + scale_shift_flow, + zerosum_scale_flow, + ), + "default", + "all", + "affine", +) + + +def _eval_hyper_shapes( + model: Model, parts: list[_FlowParts] +) -> dict[str, list[tuple[int, ...]]]: + """Evaluate every flow's symbolic hyper-param shape expressions at the + model's initial point, in one compile. Keyed by value-var name.""" + exprs = [s for p in parts for s in p.hyper_shapes] + if exprs: + inputs = list(explicit_graph_inputs(exprs)) + ip = model.initial_point() + vals = pytensor.function(inputs, exprs)(**{v.name: ip[v.name] for v in inputs}) + else: + vals = [] + shapes: dict[str, list[tuple[int, ...]]] = {} + idx = 0 + for p in parts: + k = p.flow_cls.n_hyper_params + shapes[p.value.name] = [tuple(int(x) for x in s) for s in vals[idx : idx + k]] + idx += k + return shapes + + +def _first_non_model_var(var): + while var.owner is not None and isinstance(var.owner.op, ModelVar): + var = var.owner.inputs[0] + return var + + +def reparametrize( + model: Model, + flow_db_query: RewriteDatabaseQuery = default_flow_query, + db: SequenceDB = flow_db, +) -> list[FlowSpec]: + """Run the flow rewrite and return one :class:`FlowSpec` per free RV, + in fgraph toposort order. + + The rewrite happens on the PyMC model IR (``fgraph_from_model``); + afterwards every ``ModelVar`` dummy is stripped — the same in-place + replacement ``model_from_fgraph`` performs — so the returned specs + live in plain rv/value/transform space. PyMC's + ``replace_rvs_by_values`` then composes parent dependencies and value + transforms correctly on them, including conditional transforms (which + read the RV's actual distribution parameters) and xtensor variables. + """ + fgraph, _memo = fgraph_from_model(model) + db.query(flow_db_query).rewrite(fgraph) + # Per-free-RV parts, in fgraph toposort order. + parts = [ + p for node in fgraph.toposort() if (p := _flow_node_parts(node)) is not None + ] + shapes = _eval_hyper_shapes(model, parts) + # Resolve held references to non-dummy vars *before* the strip: the + # in-place replacement below rewires the ancestors of vars that stay + # in the fgraph, but vars removed from it keep stale inputs. + resolved = [ + (p, [_first_non_model_var(g) for g in (p.rv, *p.model_params)]) for p in parts + ] + # Strip the IR in place (cf. ``model_from_fgraph``). Forward toposort + # order: a dummy is always replaced while its consumers are still in + # the graph, so subgraphs only reachable through a later-stripped + # FlowFreeRV node are rewired before they are pruned. + dummy_replacements = [ + (node.outputs[0], _first_non_model_var(node.inputs[0])) + for node in fgraph.toposort() + if isinstance(node.op, ModelVar) + ] + toposort_replace(fgraph, dummy_replacements) + return [ + FlowSpec( + flow_cls=p.flow_cls, + rv=rv, + value=p.value, + transform=p.transform, + model_params=model_params, + param_shapes=shapes[p.value.name], + ) + for p, (rv, *model_params) in resolved + ] + + +def automatic_flow_reparam( + model: Model, + flow_db_query: RewriteDatabaseQuery = default_flow_query, + db: SequenceDB = flow_db, +) -> dict[str, dict]: + """Run the flow rewrite and report, per unconstrained value variable, + which flow was chosen and the concrete shapes of its hyper params. + + Returns a ``dict`` keyed by the value variable's name (insertion order + matches fgraph toposort). Each value is a dict with: + + ``flow_cls`` — the :class:`Flow` descriptor class (``NoFlow`` when the + rewrite skipped the RV). + ``dtype`` — dtype of the flow's hyper params. + ``transform`` — the RV's value transform (``None`` or a PyMC + :class:`Transform`). + ``param_shapes`` — ``list[tuple[int, ...]]`` concrete shape per hyper + param. + """ + specs = reparametrize(model, flow_db_query, db) + return { + s.value.name: dict( + flow_cls=s.flow_cls, + dtype=s.rv.type.dtype, + transform=s.transform, + param_shapes=s.param_shapes, + ) + for s in specs + } + + +def build_flow_graph_from_specs( + specs: list[FlowSpec], + free_vars_info, + n_dim: int, +) -> dict[str, object]: + """Build the symbolic constrain/unconstrain flow maps over Nutpie's + flat point vector, plus a flat trainable flow-params vector. + + Each RV's ``constrain`` / ``unconstrain`` is first expressed against + fresh value/hyper-param placeholders with its value transform folded + in, then PyMC's :func:`replace_rvs_by_values` composes the + parent→child dependency in topological order. A final + ``toposort_replace`` substitutes each value placeholder by its chunk + of the flat point vector and each hyper placeholder by its slice of + the flat params vector — since those are introduced last, the + returned inputs are exact (no recovering cloned inputs by name). + + Parameters + ---------- + specs + Output of :func:`reparametrize`. + free_vars_info + Per-free-variable descriptors whose ``.name``, ``.start_idx``, + ``.end_idx`` and ``.shape`` define the flat point vector layout — + typically ``compiled_model._variables`` filtered to the free + (unconstrained) ones. Order determines packing into the flat + vector and must match Nutpie's. + n_dim + Total dimension of the flat point vector (``compiled_model.n_dim``). + + Returns + ------- + dict with keys: + ``flow_params_vector`` — flat trainable params ``pt.vector``. + ``constrain`` — ``(inputs, outputs)`` for ``y -> value``: + inputs ``[y_vector, flow_params_vector]``, + outputs ``[value_point, total_log_jac_det_constrain]``. + ``unconstrain`` — ``(inputs, outputs)`` for ``value -> y``. + """ + n_dim = int(n_dim) + order = [v.name for v in free_vars_info] + info = {v.name: v for v in free_vars_info} + + # Flat trainable params vector, sliced into each flow's hyper params. + param_shapes = [sh for s in specs for sh in s.param_shapes] + total = int(sum(np.prod(sh) for sh in param_shapes)) if param_shapes else 0 + flow_params = pt.vector("flow_params", dtype="float64", shape=(total,)) + splits = list(pt.unpack(flow_params, param_shapes)) if param_shapes else [] + # Fresh placeholder per hyper param; substituted by its split last. + hyper: dict[str, list] = {} + hyper_to_split: dict = {} + idx = 0 + for s in specs: + phs = [ + pt.tensor(f"{s.value.name}_hyper{i}", shape=sh, dtype="float64") + for i, sh in enumerate(s.param_shapes) + ] + hyper[s.value.name] = phs + for ph in phs: + hyper_to_split[ph] = splits[idx] + idx += 1 + + def _root_and_chunk(vec, spec): + # The lift rewrite leaves lifted values as derived expressions + # (``tensor_from_xtensor(xvalue)``); the flat chunk substitutes + # the *root* value var, in its own layout. + if spec.value.owner is None: + root = spec.value + else: + (root,) = explicit_graph_inputs([spec.value]) + v = info[spec.value.name] + chunk = vec[v.start_idx : v.end_idx].reshape(tuple(int(x) for x in v.shape)) + if isinstance(root.type, XTensorType): + chunk = xtensor_from_tensor(chunk, dims=root.type.dims, name=root.name) + return root, chunk + + def _build(direction: str): + y = pt.vector("y", shape=(n_dim,)) + # Per-direction copy of the model-param subgraphs: + # replace_rvs_by_values mutates replacement expressions in place + # when they nest other replaced rvs (see replace_vars_in_graphs), + # so the shared spec graphs must not be fed to it directly. The + # rv keys are pinned to identity so they stay valid keys. + memo = {s.rv: s.rv for s in specs} + all_params = [p for s in specs for p in s.model_params] + equiv = clone_get_equiv([], all_params, False, False, memo) + model_params = {s.value.name: [equiv[p] for p in s.model_params] for s in specs} + points: dict[str, object] = {} + ljds: dict[str, object] = {} + rvs_to_values: dict = {} + replacements: list = [] + for s in specs: + name = s.value.name + # Fresh root placeholder for this RV's flat chunk, substituted + # at the end; cloning inside replace_rvs_by_values keeps graph + # inputs identical, so the substitution is exact. The value + # var's derivation (if any) is rebuilt on top of it so the + # flow math sees the same layout as the spec graphs. + root, chunk = _root_and_chunk(y, s) + z_root = root.type(name=root.name) + replacements.append((z_root, chunk)) + if s.value is root: + z = z_root + else: + memo_v = clone_get_equiv( + [root], [s.value], False, False, {root: z_root} + ) + z = memo_v[s.value] + params = model_params[name] + if s.flow_cls is NoFlow: + point, ljd = z, pt.zeros(()) + elif direction == "constrain": + point = s.flow_cls.constrain(z, *params, *hyper[name]) + ljd = s.flow_cls.log_jac_det_constrain(z, *params, *hyper[name]) + else: + point = s.flow_cls.unconstrain(z, *params, *hyper[name]) + ljd = s.flow_cls.log_jac_det_unconstrain(z, *params, *hyper[name]) + points[name], ljds[name] = point, ljd + # A child reads this RV's *constrained* value off its parents: + # the flow output (value space) in constrain, the value var in + # unconstrain — backward-transformed with the RV's actual + # distribution parameters, so conditional transforms compose + # correctly. The backward is folded into the value here + # instead of passing rvs_to_transforms: with transforms, + # replace_rvs_by_values clones the graphs and remaps the + # *keys* to clones, so rvs nested inside other replacement + # values (a flow parent's point expression) would be missed. + rv_value = point if direction == "constrain" else z + if s.transform is not None: + rv_value = s.transform.backward(rv_value, *s.rv.owner.inputs) + rv_value = s.rv.type.filter_variable(rv_value, allow_convert=True) + rv_value.name = s.rv.name + rvs_to_values[s.rv] = rv_value + + graphs = [points[nm] for nm in order] + [ljds[nm] for nm in order] + graphs = replace_rvs_by_values(graphs, rvs_to_values=rvs_to_values) + n = len(order) + point_parts = [ + tensor_from_xtensor(g) if isinstance(g.type, XTensorType) else g + for g in graphs[:n] + ] + point_out = pt.concatenate([g.ravel() for g in point_parts]) + ljd_out = variadic_add(*graphs[n:]) + # Substitute the value/hyper placeholders by their flat-vector + # slices. replace_rvs_by_values keeps graph inputs identical when + # cloning, so the placeholders (and thus these replacements) are + # exact. + fg = FunctionGraph(outputs=[point_out, ljd_out], clone=False) + final_replacements = [ + (ph, repl) + for ph, repl in (*replacements, *hyper_to_split.items()) + if ph in fg.variables + ] + toposort_replace(fg, final_replacements) + return [y, flow_params], list(fg.outputs) + + return dict( + flow_params_vector=flow_params, + constrain=_build("constrain"), + unconstrain=_build("unconstrain"), + ) + + +def build_flow_graph( + model: Model, + free_vars_info, + n_dim: int, + flow_db_query: RewriteDatabaseQuery = default_flow_query, + db: SequenceDB = flow_db, +) -> dict[str, object]: + """:func:`reparametrize` + :func:`build_flow_graph_from_specs`.""" + specs = reparametrize(model, flow_db_query, db) + return build_flow_graph_from_specs(specs, free_vars_info, n_dim) + + +def free_vars_info(compiled_model): + """The compiled model's free (unconstrained) variable descriptors, + whose ``start_idx``/``end_idx``/``shape`` define Nutpie's flat point + vector layout.""" + n_dim = int(compiled_model.n_dim) + return [v for v in compiled_model._variables if v.end_idx <= n_dim] + + +def build_auto_flow( + model: Model, + compiled_model, + *, + init_params=None, + flow_db_query: RewriteDatabaseQuery = default_flow_query, + db: SequenceDB = flow_db, +): + """Build the VIP reparametrization of ``model`` as a single + :class:`nutpie.normalizing_flow.AutoFlow` over ``compiled_model``'s + flat point vector, ready to pass to + ``compiled_model.with_transform_adapt(auto_flow=...)``. + + Prints a summary of the reparametrized variables; if the rewrite found + nothing to reparametrize, warns and returns ``None`` instead (use + :func:`automatic_flow_reparam` for the full per-variable report). + + The PyTensor constrain/unconstrain maps from :func:`build_flow_graph` + are JIT-compiled to JAX. The flow's trainable parameters are the VIP + centering knobs; ``init_params`` defaults to zeros (λ = ½, halfway + between centred and non-centred). + + The flowjax/JAX bijection convention is ``transform: base -> target``, + matching Nutpie's ``transform_and_log_det(sampler) -> value``; so the + flow's ``transform`` is :func:`build_flow_graph`'s ``constrain`` and + its ``inverse`` is ``unconstrain``. + """ + from nutpie.normalizing_flow import AutoFlow + import jax.numpy as jnp + + specs = reparametrize(model, flow_db_query, db) + flowed = [s for s in specs if s.flow_cls is not NoFlow] + if not flowed: + warnings.warn( + "Automatic reparametrization did not find any variables to " + "reparametrize in this model." + ) + return None + chosen = ", ".join(f"{s.value.name} ({s.flow_cls.__name__})" for s in flowed) + print( + f"auto_reparam: reparametrizing {len(flowed)} of {len(specs)} " + f"free variables: {chosen}" + ) + + n_dim = int(compiled_model.n_dim) + g = build_flow_graph_from_specs(specs, free_vars_info(compiled_model), n_dim) + constrain_fn = pytensor.function(*g["constrain"], mode="JAX").vm.jit_fn + unconstrain_fn = pytensor.function(*g["unconstrain"], mode="JAX").vm.jit_fn + + total = int(g["flow_params_vector"].type.shape[0]) + if init_params is None: + init_params = jnp.zeros((total,)) + + return AutoFlow(init_params, (n_dim,), constrain_fn, unconstrain_fn) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index 97c54be..2123b11 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -473,6 +473,31 @@ def inverse_and_log_det(self, y: Array, condition: Array | None = None): return self._householder(y), jnp.zeros(()) +class AutoFlow(AbstractBijection): + shape: tuple[int, ...] + params: Array + transform_and_log_det_fn: Callable + inverse_and_log_det_fn: Callable + cond_shape = None + + def __init__( + self, params: ArrayLike, shape, transform_and_log_det_fn, inverse_and_log_det_fn + ): + params = arraylike_to_array(params) + if params.ndim != 1: + raise ValueError("params must be a vector.") + self.shape = shape + self.params = params + self.transform_and_log_det_fn = transform_and_log_det_fn + self.inverse_and_log_det_fn = inverse_and_log_det_fn + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + return self.transform_and_log_det_fn(x, self.params) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + return self.inverse_and_log_det_fn(y, self.params) + + class MvScale(bijections.AbstractBijection): shape: tuple[int, ...] params: Array @@ -1878,6 +1903,7 @@ def make_flow( sandwich_householder=False, activation=None, reuse_embed=False, + auto_flow=None, ): if activation is None: activation = jax.nn.leaky_relu @@ -1901,6 +1927,14 @@ def make_flow( n_draws, n_dim = positions.shape assert positions.shape == gradients.shape + # Push positions and gradients through autoflow + if auto_flow is not None: + from nutpie.transform_adapter import inverse_gradient_and_val + + positions, gradients, _logp = eqx.filter_vmap( + inverse_gradient_and_val, in_axes=(None, 0, 0, 0) + )(auto_flow, positions, gradients, jnp.zeros((n_draws,))) + if n_draws == 0: raise ValueError("No draws") elif n_draws == 1: @@ -1927,9 +1961,9 @@ def make_flow( replace=diag_param, ) - flows = [ - diag_affine, - ] + flows = [diag_affine] + if auto_flow is not None: + flows.append(auto_flow) if n_layers == 0: return bijections.Chain(flows) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 41d18f4..b263ffa 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -874,7 +874,7 @@ def make_transform_adapter( show_progress=False, nn_depth=None, nn_width=None, - num_layers=8, + num_layers=None, num_diag_windows=6, learning_rate=5e-4, untransformed_dim=None, @@ -905,10 +905,21 @@ def make_transform_adapter( contract_transformer=True, asymmetric_transformer=False, reuse_embed=True, + auto_flow=None, ): if extension_windows is None: extension_windows = [] + # Several auto flows compose into a single bijection; flowjax applies + # Chain members in order in the transform (sampler -> value) direction. + if isinstance(auto_flow, (list, tuple)): + auto_flow = bijections.Chain(list(auto_flow)) if auto_flow else None + + # With an auto flow, default to fitting only the reparametrization (and + # the diag affine); set num_layers explicitly to add coupling layers. + if num_layers is None: + num_layers = 0 if auto_flow is not None else 8 + return partial( TransformAdapter, verbose=verbose, @@ -930,6 +941,7 @@ def make_transform_adapter( contract_transformer=contract_transformer, asymmetric_transformer=asymmetric_transformer, reuse_embed=reuse_embed, + auto_flow=auto_flow, ), show_progress=show_progress, num_diag_windows=num_diag_windows, diff --git a/tests/test_flow_reparam.py b/tests/test_flow_reparam.py new file mode 100644 index 0000000..7c45273 --- /dev/null +++ b/tests/test_flow_reparam.py @@ -0,0 +1,794 @@ +from importlib.util import find_spec + +import pytest + +if find_spec("pymc") is None: + pytest.skip("Skip pymc tests", allow_module_level=True) + +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt +from pymc import dims as pmd +from pymc.distributions.transforms import Interval + +try: + from pytensor.gradient import pullback +except ImportError: # pytensor < 3.0 used the name Lop + from pytensor.gradient import Lop as pullback + +import nutpie +from nutpie.flow_reparam import ( + AffineFlow, + NoFlow, + ShiftFlow, + automatic_flow_reparam, + build_auto_flow, + build_flow_graph, + free_vars_info, +) + + +def _compile_flow(model): + """Compile the model with nutpie, build the flow graph over its + flat variables, and compile the pytensor constrain/unconstrain + functions.""" + records = automatic_flow_reparam(model) + compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax") + g = build_flow_graph(model, free_vars_info(compiled), compiled.n_dim) + constrain_fn = pytensor.function(*g["constrain"]) + unconstrain_fn = pytensor.function(*g["unconstrain"]) + return compiled, records, constrain_fn, unconstrain_fn + + +def _total_params(records): + return sum(int(np.prod(sh)) for r in records.values() for sh in r["param_shapes"]) + + +def _var_slice(compiled, name): + v = next(v for v in compiled._variables if v.name == name) + return slice(v.start_idx, v.end_idx) + + +@pytest.mark.pymc +def test_root_rv_not_reparametrized(): + with pm.Model() as m: + pm.Normal("x", 0, 1, shape=(3,)) + + records = automatic_flow_reparam(m) + assert records["x"]["flow_cls"] is NoFlow + assert records["x"]["param_shapes"] == [] + + +@pytest.mark.pymc +def test_transformed_rv_not_reparametrized(): + coords = {"group": [0, 1, 2]} + with pm.Model(coords=coords) as m: + pop_mu = pm.Normal("pop_mu", 0, 1) + pop_sigma = pm.HalfNormal("pop_sigma", 1) + # This could be fine, but there may rewrites + # like ordered/zerosum that change things too much? + pm.Normal( + "ind_mu", + pop_mu, + pop_sigma, + dims="group", + transform=Interval(lower=-10, upper=10), + ) + + records = automatic_flow_reparam(m) + (name,) = [n for n in records if n.startswith("ind_mu")] + assert records[name]["flow_cls"] is NoFlow + + +@pytest.mark.parametrize( + "dist_fn", + [ + lambda mu, sigma, **k: pm.Normal("y", mu, sigma, **k), + lambda mu, sigma, **k: pm.Cauchy("y", mu, sigma, **k), + lambda mu, sigma, **k: pm.Laplace("y", mu, sigma, **k), + lambda mu, sigma, **k: pm.Logistic("y", mu, sigma, **k), + lambda mu, sigma, **k: pm.Gumbel("y", mu, sigma, **k), + lambda mu, sigma, **k: pm.StudentT("y", nu=3, mu=mu, sigma=sigma, **k), + lambda mu, sigma, **k: pm.LogNormal("y", mu, sigma, **k), + ], + ids=["Normal", "Cauchy", "Laplace", "Logistic", "Gumbel", "StudentT", "LogNormal"], +) +@pytest.mark.pymc +def test_loc_scale_affine_flow(dist_fn): + with pm.Model(coords={"group": range(4)}) as m: + mu = pm.Normal("mu", 0, 1) + sigma = pm.HalfNormal("sigma", 1) + dist_fn(mu=mu, sigma=sigma, dims="group") + + records = automatic_flow_reparam(m) + (name,) = [n for n in records if n in {"y", "y_log__"}] + assert records[name]["flow_cls"] is AffineFlow + assert records[name]["param_shapes"] == [(4,), (4,)] + + +@pytest.mark.pymc +def test_per_param_qualification(): + """A dist param that is constant or full-shape gets a size-0 knob, + pinned at the centred no-op.""" + coords = {"group": range(3)} + + with pm.Model(coords=coords) as m: + sigma = pm.HalfNormal("sigma", 1) + pm.Normal("y", 0.0, sigma, dims="group") + r = automatic_flow_reparam(m)["y"] + assert r["flow_cls"] is AffineFlow + assert r["param_shapes"] == [(0,), (3,)] + + with pm.Model(coords=coords) as m: + mu = pm.Normal("mu", 0, 1) + pm.Normal("y", mu, 2.0, dims="group") + r = automatic_flow_reparam(m)["y"] + assert r["flow_cls"] is AffineFlow + assert r["param_shapes"] == [(3,), (0,)] + + with pm.Model(coords=coords) as m: + mu = pm.Normal("mu", 0, 1) + sigma = pm.HalfNormal("sigma", 1, dims="group") + pm.Normal("y", mu, sigma, dims="group") + r = automatic_flow_reparam(m)["y"] + assert r["flow_cls"] is AffineFlow + assert r["param_shapes"] == [(3,), (0,)] + + with pm.Model(coords=coords) as m: + pm.Normal("root", 0, 1) + pm.Normal("y", 0.0, 2.0, dims="group") + r = automatic_flow_reparam(m)["y"] + assert r["flow_cls"] is NoFlow + + +@pytest.mark.pymc +def test_scale_shift_flow(): + coords = {"group": [0, 1, 2]} + with pm.Model(coords=coords) as m: + beta = pm.HalfNormal("beta", 1) + pm.Gamma("x", alpha=2.0, beta=beta, dims="group") + + r = automatic_flow_reparam(m)["x_log__"] + assert r["flow_cls"] is ShiftFlow + assert r["param_shapes"] == [(3,)] + + +@pytest.mark.pymc +def test_zerosum_scale_flow(): + with pm.Model(coords={"group": range(5)}) as m: + sigma = pm.HalfNormal("pop_sigma", 1) + pm.ZeroSumNormal("x", sigma=sigma, dims="group") + + r = automatic_flow_reparam(m)["x_zerosum__"] + assert r["flow_cls"] is AffineFlow + # No loc knob; one scale knob shared across the zero-sum core dim. + assert r["param_shapes"] == [(0,), (1,)] + + with pm.Model(coords={"batch": range(2), "group": range(5)}) as m: + sigma = pm.HalfNormal("pop_sigma", 1) + pm.ZeroSumNormal("x", sigma=sigma, dims=("batch", "group")) + + r = automatic_flow_reparam(m)["x_zerosum__"] + assert r["flow_cls"] is AffineFlow + assert r["param_shapes"] == [(0,), (2, 1)] + + with pm.Model(coords={"group": range(5)}) as m: + pm.Normal("root", 0, 1) + pm.ZeroSumNormal("x", sigma=2.0, dims="group") + + r = automatic_flow_reparam(m)["x_zerosum__"] + assert r["flow_cls"] is NoFlow + + +@pytest.mark.pymc +def test_hierarchical_normal(): + coords = {"group": [0, 1, 2]} + with pm.Model(coords=coords) as m: + pop_mu = pm.Normal("pop_mu", 0, 1) + pop_sigma = pm.HalfNormal("pop_sigma", 1) + pm.Normal("ind_mu", pop_mu, pop_sigma, dims="group") + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + + assert records["pop_mu"]["flow_cls"] is NoFlow + assert records["pop_sigma_log__"]["flow_cls"] is NoFlow + assert records["ind_mu"]["flow_cls"] is AffineFlow + assert records["ind_mu"]["param_shapes"] == [(3,), (3,)] + + n_dim = int(compiled.n_dim) + assert n_dim == 5 # pop_mu + pop_sigma_log__ + ind_mu(group=3) + total_params = _total_params(records) + assert total_params == 6 + pop_mu_sl = _var_slice(compiled, "pop_mu") + pop_sigma_sl = _var_slice(compiled, "pop_sigma_log__") + ind_mu_sl = _var_slice(compiled, "ind_mu") + + # Centred no-op (h = 0): both directions are the identity, log|J| = 0. + point = np.arange(n_dim, dtype="float64") + zero_params = np.zeros(total_params, dtype="float64") + c_point, ljd_c = constrain_fn(point, zero_params) + u_point, ljd_u = unconstrain_fn(point, zero_params) + np.testing.assert_allclose(c_point, point, atol=1e-9) + np.testing.assert_allclose(u_point, point, atol=1e-9) + np.testing.assert_allclose(ljd_c, 0.0, atol=1e-9) + np.testing.assert_allclose(ljd_u, 0.0, atol=1e-9) + + # Random non-identity: roundtrip and log|J| cancellation. + rng = np.random.default_rng(0) + phi0 = rng.normal(size=n_dim).astype("float64") + rand_params = rng.normal(size=total_params).astype("float64") * 0.3 + value, ljd_c = constrain_fn(phi0, rand_params) + phi_back, ljd_u = unconstrain_fn(value, rand_params) + np.testing.assert_allclose(phi_back, phi0, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + assert not np.isclose(ljd_c, 0.0) + + # Analytical per-element VIP transform (Gorinova et al. 2019, with + # h = 1-λ so that h = 0 is centred): + # value_i = μ + σ^h_σi·(y_i - (1-h_μi)·μ), log|J| = Σ_i h_σi·log σ + h_mu = np.array([0.4, 0.2, 0.0]) + h_sigma = np.array([0.1, 0.5, 0.9]) + analytic_params = np.concatenate([h_mu, h_sigma]) + y = rng.normal(size=n_dim).astype("float64") + mu = y[pop_mu_sl][0] + sigma = np.exp(y[pop_sigma_sl][0]) + value, ljd_c = constrain_fn(y, analytic_params) + expected = mu + sigma**h_sigma * (y[ind_mu_sl] - (1 - h_mu) * mu) + np.testing.assert_allclose(value[ind_mu_sl], expected, atol=1e-10) + np.testing.assert_allclose(ljd_c, (h_sigma * np.log(sigma)).sum(), atol=1e-10) + + +@pytest.mark.pymc +def test_partial_broadcast(): + coords = {"group": range(3), "rep": range(4)} + with pm.Model(coords=coords) as m: + pop_mu = pm.Normal("pop_mu", 0, 1, dims="group") + pm.Normal("x", pop_mu[:, None], 1.0, dims=("group", "rep")) + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + assert records["x"]["flow_cls"] is AffineFlow + assert records["x"]["param_shapes"] == [(3, 4), (0,)] + + n_dim = int(compiled.n_dim) + assert n_dim == 15 # pop_mu(group=3) + x(group=3, rep=4) + rng = np.random.default_rng(1) + phi0 = rng.normal(size=n_dim).astype("float64") + flow_params = rng.normal(size=_total_params(records)).astype("float64") * 0.2 + + value, ljd_c = constrain_fn(phi0, flow_params) + phi_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(phi_back, phi0, atol=1e-10) + # With the scale knob pinned, the transform is a translation. + np.testing.assert_allclose(ljd_c, 0.0, atol=1e-10) + np.testing.assert_allclose(ljd_u, 0.0, atol=1e-10) + + +@pytest.mark.pymc +def test_flow_alongside_dirichlet(): + coords = {"group": [0, 1], "k": range(3)} + with pm.Model(coords=coords) as m: + pi = pm.Dirichlet("pi", a=np.ones(3), dims="k") + pop_sigma = pm.HalfNormal("pop_sigma", 1) + pm.Normal("ind_mu", pi, pop_sigma, dims=("group", "k")) + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + assert records["ind_mu"]["flow_cls"] is AffineFlow + assert records["ind_mu"]["param_shapes"] == [(2, 3), (2, 3)] + (pi_name,) = [n for n in records if n.startswith("pi")] + assert records[pi_name]["flow_cls"] is NoFlow + + n_dim = int(compiled.n_dim) + assert n_dim == 9 # 2 (dirichlet, simplex-transformed) + 1 + 6 + total_params = _total_params(records) + + phi0 = np.arange(n_dim, dtype="float64") + zero_params = np.zeros(total_params, dtype="float64") + c_point, ljd_c = constrain_fn(phi0, zero_params) + np.testing.assert_allclose(c_point, phi0, atol=1e-9) + np.testing.assert_allclose(ljd_c, 0.0, atol=1e-9) + + rng = np.random.default_rng(0) + flow_params = rng.normal(size=total_params).astype("float64") * 0.3 + value, ljd_c = constrain_fn(phi0, flow_params) + phi_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(phi_back, phi0, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + + +@pytest.mark.pymc +def test_zerosum_roundtrip(): + with pm.Model(coords={"group": range(5)}) as m: + sigma = pm.HalfNormal("pop_sigma", 1) + pm.ZeroSumNormal("x", sigma=sigma, dims="group") + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + n_dim = int(compiled.n_dim) + assert n_dim == 5 # pop_sigma_log__ + x_zerosum__(4) + total_params = _total_params(records) + assert total_params == 1 + + phi0 = np.arange(n_dim, dtype="float64") + zero_params = np.zeros(total_params, dtype="float64") + c_point, ljd_c = constrain_fn(phi0, zero_params) + np.testing.assert_allclose(c_point, phi0, atol=1e-9) + np.testing.assert_allclose(ljd_c, 0.0, atol=1e-9) + + rng = np.random.default_rng(0) + phi0 = rng.normal(size=n_dim).astype("float64") + flow_params = np.array([0.7]) + value, ljd_c = constrain_fn(phi0, flow_params) + phi_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(phi_back, phi0, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + + # Full NCP (h = 1): value = y·σ, log|J| = 4·log σ. + x_sl = _var_slice(compiled, "x_zerosum__") + s_sl = _var_slice(compiled, "pop_sigma_log__") + value, ljd_c = constrain_fn(phi0, np.array([1.0])) + sigma_val = np.exp(phi0[s_sl][0]) + np.testing.assert_allclose(value[x_sl], phi0[x_sl] * sigma_val, atol=1e-10) + np.testing.assert_allclose(ljd_c, 4 * np.log(sigma_val), atol=1e-10) + + +@pytest.mark.pymc +def test_dim_distributions(): + coords = {"group": [0, 1, 2]} + with pm.Model(coords=coords) as m: + pop_mu = pmd.Normal("pop_mu", 0, 1) + pop_sigma = pmd.HalfNormal("pop_sigma", 1) + pmd.Normal("ind_mu", pop_mu, pop_sigma, dims=("group",)) + pmd.LogNormal("scale", pop_mu, pop_sigma, dims=("group",)) + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + assert records["pop_mu"]["flow_cls"] is NoFlow + assert records["pop_sigma_log__"]["flow_cls"] is NoFlow + assert records["ind_mu"]["flow_cls"] is AffineFlow + assert records["ind_mu"]["param_shapes"] == [(3,), (3,)] + assert records["scale_log__"]["flow_cls"] is AffineFlow + assert records["scale_log__"]["param_shapes"] == [(3,), (3,)] + + n_dim = int(compiled.n_dim) + assert n_dim == 8 # pop_mu + pop_sigma_log__ + ind_mu(3) + scale_log__(3) + total_params = _total_params(records) + phi0 = np.arange(n_dim, dtype="float64") + + zero_params = np.zeros(total_params, dtype="float64") + c_point, ljd_c = constrain_fn(phi0, zero_params) + np.testing.assert_allclose(c_point, phi0, atol=1e-9) + np.testing.assert_allclose(ljd_c, 0.0, atol=1e-9) + + rng = np.random.default_rng(0) + flow_params = rng.normal(size=total_params).astype("float64") * 0.2 + value, ljd_c = constrain_fn(phi0, flow_params) + phi_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(phi_back, phi0, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + + +def _fisher_loss_fn(model, compiled): + """Build ``loss(draws, flow_params)`` mirroring nutpie's FisherLoss: + pull value-space posterior draws and logp-gradients back through the + flow and measure deviation from a standard-normal score, minimized + analytically over the per-coordinate affine that the production + chain's diagonal-affine layers would absorb. Zero iff the pulled-back + posterior is iid normal up to a diagonal affine.""" + n_dim = int(compiled.n_dim) + free_vars = free_vars_info(compiled) + g = build_flow_graph(model, free_vars, n_dim) + + (y, flow_params), (value_out, ljd) = g["constrain"] + value_grad = value_out.type("value_grad") + # vjp of (constrain, log|J|) at cotangent (grad, 1), as in + # transform_adapter.inverse_gradient_and_val. + pulled_grad = pullback([value_out, ljd], y, [value_grad, pt.ones(())]) + pullback_fn = pytensor.function([y, flow_params, value_grad], pulled_grad) + unconstrain_fn = pytensor.function(*g["unconstrain"]) + + value_vars = {v.name: v for v in model.value_vars} + ordered = [value_vars[v.name] for v in free_vars] + grad_fn = pytensor.function(ordered, pt.grad(model.logp(), ordered)) + + def flat_grad(draw): + vals = [ + draw[v.start_idx : v.end_idx].reshape(tuple(int(s) for s in v.shape)) + for v in free_vars + ] + return np.concatenate([np.asarray(a).ravel() for a in grad_fn(*vals)]) + + def loss(draws, params): + params = np.asarray(params, dtype="float64") + xs, gs = [], [] + for draw in draws: + x, _ = unconstrain_fn(draw, params) + xs.append(x) + gs.append(pullback_fn(x, params, flat_grad(draw))) + x, g = np.array(xs), np.array(gs) + # min over per-coordinate affine z = a·(x-b) of E[(z + g/a)²]: + # 2·(sqrt(Var x·Var g) + Cov(x, g)), ≥ 0 by Cauchy-Schwarz. + cov = ((x - x.mean(0)) * (g - g.mean(0))).mean(0) + # ≥ 0 by Cauchy-Schwarz; clamp the float noise around exact zeros. + costs = np.maximum(2 * (np.sqrt(x.var(0) * g.var(0)) + cov), 0.0) + return float(np.log(np.maximum(costs.sum(), 1e-300))) + + return loss + + +def _funnel_model(obs_sigma=None, n_groups=5, n_obs=1000): + """The mixed-balance funnel from notebooks/auto_reparam-Copy1.ipynb: + per-group means with a common log-scale hyper, optionally observed + with per-group noise.""" + coords = {"group": range(n_groups)} + with pm.Model(coords=coords) as m: + s = pm.Normal("pop_sigma_log", 0, 1) + ind_mu = pm.Normal("ind_mu", 0, pm.math.exp(s / 2), dims="group") + if obs_sigma is not None: + rng = np.random.default_rng(1) + y = rng.normal( + loc=np.linspace(-0.5, 0.5, n_groups), + scale=1.0, + size=(n_obs, n_groups), + ) + pm.Normal("y", ind_mu, sigma=np.asarray(obs_sigma), observed=y) + return m, y + return m, None + + +def _funnel_posterior_draws(compiled, rng, n_draws, obs_sigma=None, y=None): + """Exact posterior draws in nutpie's flat value space: dense-grid + sampling for the 1-D hyper (conjugate marginal over the group means), + then the exact Gaussian conditional for ind_mu | s, y.""" + s_sl = _var_slice(compiled, "pop_sigma_log") + m_sl = _var_slice(compiled, "ind_mu") + n_groups = m_sl.stop - m_sl.start + + if obs_sigma is None: + s = rng.normal(0.0, 1.0, size=n_draws) + means = np.zeros((n_draws, n_groups)) + post_sd = np.exp(s / 2)[:, None] * np.ones(n_groups) + else: + obs_sigma = np.asarray(obs_sigma, dtype="float64") + n_obs = y.shape[0] + ybar = y.mean(axis=0) + grid = np.linspace(-10, 10, 8001) + tau2 = np.exp(grid) + # p(s | y) ∝ p(s) · Π_i N(ȳ_i; 0, τ² + σ_i²/n) + marg_var = tau2[:, None] + (obs_sigma**2 / n_obs)[None, :] + log_post = -0.5 * grid**2 - 0.5 * ( + np.log(marg_var) + ybar[None, :] ** 2 / marg_var + ).sum(axis=1) + p = np.exp(log_post - log_post.max()) + s = rng.choice(grid, size=n_draws, p=p / p.sum()) + prec = 1 / np.exp(s)[:, None] + (n_obs / obs_sigma**2)[None, :] + means = (n_obs * ybar / obs_sigma**2)[None, :] / prec + post_sd = 1 / np.sqrt(prec) + + ind_mu = rng.normal(means, post_sd) + draws = np.empty((n_draws, int(compiled.n_dim))) + draws[:, s_sl] = s[:, None] + draws[:, m_sl] = ind_mu + return draws + + +@pytest.mark.pymc +def test_loss_prefers_noncentered_on_prior_funnel(): + m, _ = _funnel_model() + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + loss = _fisher_loss_fn(m, compiled) + + rng = np.random.default_rng(42) + draws = _funnel_posterior_draws(compiled, rng, n_draws=256) + + cp = loss(draws, np.zeros(5)) + ncp = loss(draws, np.ones(5)) + assert ncp < cp + # Full NCP standardizes the prior funnel exactly, so the loss is ~0. + assert ncp < -25 + + +@pytest.mark.pymc +def test_loss_prefers_centered_on_strong_data(): + obs_sigma = np.ones(5) + m, y = _funnel_model(obs_sigma=obs_sigma) + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + loss = _fisher_loss_fn(m, compiled) + + rng = np.random.default_rng(42) + draws = _funnel_posterior_draws(compiled, rng, 256, obs_sigma=obs_sigma, y=y) + + assert loss(draws, np.zeros(5)) < loss(draws, np.ones(5)) + + +@pytest.mark.pymc +def test_loss_prefers_mixed_on_mixed_balance(): + """Strong-evidence groups want centred, weak ones non-centred: the + per-element mixed parameterization beats both global ones — the case + a single per-group knob cannot express.""" + obs_sigma = np.array([1.0, 1000.0, 1.0, 1000.0, 1.0]) + m, y = _funnel_model(obs_sigma=obs_sigma) + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + loss = _fisher_loss_fn(m, compiled) + + rng = np.random.default_rng(42) + draws = _funnel_posterior_draws(compiled, rng, 256, obs_sigma=obs_sigma, y=y) + + mixed = loss(draws, np.array([0.0, 1.0, 0.0, 1.0, 0.0])) + assert mixed < loss(draws, np.zeros(5)) + assert mixed < loss(draws, np.ones(5)) + + +@pytest.mark.pymc +def test_loss_prefers_noncentered_on_zerosum_prior(): + with pm.Model(coords={"group": range(5)}) as m: + s = pm.Normal("s", 0, 1) + x = pm.ZeroSumNormal("x", sigma=pm.math.exp(s), dims="group") + + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + loss = _fisher_loss_fn(m, compiled) + + s_draws, x_draws = pm.draw([s, x], draws=256, random_seed=42) + x_in = pt.matrix("x_in") + forward_fn = pytensor.function([x_in], m.rvs_to_transforms[x].forward(x_in)) + + draws = np.empty((256, int(compiled.n_dim))) + draws[:, _var_slice(compiled, "s")] = s_draws[:, None] + draws[:, _var_slice(compiled, "x_zerosum__")] = forward_fn(x_draws) + + cp = loss(draws, np.zeros(1)) + ncp = loss(draws, np.ones(1)) + assert ncp < cp + # The zero-sum transform is an isometry, so full NCP is exact. + assert ncp < -25 + + +@pytest.mark.pymc +def test_conditional_transform_parent(): + # A parent with a *conditional* transform (Interval reads the RV's + # distribution parameters) feeding a flow child: the transform must be + # applied with the actual rv inputs, not IR wrapper inputs. + coords = {"group": [0, 1, 2]} + with pm.Model(coords=coords) as m: + a = pm.Uniform("a", -1.0, 3.0) + pop_sigma = pm.HalfNormal("pop_sigma", 1) + pm.Normal("x", a, pop_sigma, dims="group") + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + assert records["a_interval__"]["flow_cls"] is NoFlow + assert records["pop_sigma_log__"]["flow_cls"] is NoFlow + assert records["x"]["flow_cls"] is AffineFlow + assert records["x"]["param_shapes"] == [(3,), (3,)] + + n_dim = int(compiled.n_dim) + assert n_dim == 5 + a_sl = _var_slice(compiled, "a_interval__") + sigma_sl = _var_slice(compiled, "pop_sigma_log__") + x_sl = _var_slice(compiled, "x") + + rng = np.random.default_rng(2) + y = rng.normal(size=n_dim) + h_mu = np.array([0.7, 0.3, 0.1]) + h_sigma = np.array([0.2, 0.5, 0.8]) + flow_params = np.concatenate([h_mu, h_sigma]) + + value, ljd_c = constrain_fn(y, flow_params) + # interval backward: lower + (upper - lower) * sigmoid(value) + a_con = -1.0 + 4.0 / (1.0 + np.exp(-y[a_sl][0])) + sigma_con = np.exp(y[sigma_sl][0]) + expected_x = (y[x_sl] - (1 - h_mu) * a_con) * sigma_con**h_sigma + a_con + np.testing.assert_allclose(value[x_sl], expected_x, atol=1e-10) + np.testing.assert_allclose(ljd_c, (h_sigma * np.log(sigma_con)).sum(), atol=1e-10) + + y_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(y_back, y, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + + +@pytest.mark.pymc +def test_flow_parent_read_through_expression(): + # A *flow* parent read by its child through an expression (indexing): + # the constrain and unconstrain graphs must each compose their own + # direction of the parent flow (regression for in-place mutation of + # the shared param subgraphs between the two builds). + coords = {"group": range(3), "rep": range(4)} + with pm.Model(coords=coords) as m: + pop_mu = pm.Normal("pop_mu", 0, 1) + pop_sigma = pm.HalfNormal("pop_sigma", 1) + mu_g = pm.Normal("mu_g", pop_mu, pop_sigma, dims="group") + pm.Normal("x", mu_g[:, None], 1.0, dims=("group", "rep")) + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + assert records["mu_g"]["flow_cls"] is AffineFlow + assert records["mu_g"]["param_shapes"] == [(3,), (3,)] + assert records["x"]["flow_cls"] is AffineFlow + # Constant sigma: the scale knob is withheld (size 0) and pinned. + assert records["x"]["param_shapes"] == [(3, 4), (0,)] + + n_dim = int(compiled.n_dim) + assert n_dim == 17 # 1 + 1 + 3 + 12 + total_params = _total_params(records) + assert total_params == 18 + + rng = np.random.default_rng(3) + y = rng.normal(size=n_dim) + h1_mu = rng.normal(size=3) * 0.4 + h1_sigma = rng.normal(size=3) * 0.4 + h2_mu = rng.normal(size=(3, 4)) * 0.4 + flow_params = np.concatenate([h1_mu, h1_sigma, h2_mu.ravel()]) + + mu_sl = _var_slice(compiled, "pop_mu") + sigma_sl = _var_slice(compiled, "pop_sigma_log__") + mug_sl = _var_slice(compiled, "mu_g") + x_sl = _var_slice(compiled, "x") + + value, ljd_c = constrain_fn(y, flow_params) + pop_mu_con = y[mu_sl][0] + sigma_con = np.exp(y[sigma_sl][0]) + mug_con = (y[mug_sl] - (1 - h1_mu) * pop_mu_con) * sigma_con**h1_sigma + pop_mu_con + # With the scale knob pinned the child flow is the translation + # y + h_mu·loc. + expected_x = y[x_sl].reshape(3, 4) + h2_mu * mug_con[:, None] + np.testing.assert_allclose(value[mug_sl], mug_con, atol=1e-10) + np.testing.assert_allclose(value[x_sl], expected_x.ravel(), atol=1e-10) + np.testing.assert_allclose(ljd_c, (h1_sigma * np.log(sigma_con)).sum(), atol=1e-10) + + y_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(y_back, y, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + + +@pytest.mark.pymc +def test_xtensor_unknown_transform_parent(): + # A dims RV whose transform has no plain counterpart (Beta -> logodds) + # is not lifted: it stays xtensor-typed end to end and its dim + # transform is applied natively. + coords = {"group": [0, 1, 2]} + with pm.Model(coords=coords) as m: + p = pmd.Beta("p", 1.0, 1.0) + pop_sigma = pmd.HalfNormal("pop_sigma", 1) + pmd.Normal("x", p, pop_sigma, dims=("group",)) + + compiled, records, constrain_fn, unconstrain_fn = _compile_flow(m) + (p_name,) = [n for n in records if n.startswith("p_")] + assert records[p_name]["flow_cls"] is NoFlow + assert records["x"]["flow_cls"] is AffineFlow + assert records["x"]["param_shapes"] == [(3,), (3,)] + + n_dim = int(compiled.n_dim) + assert n_dim == 5 + p_sl = _var_slice(compiled, p_name) + sigma_sl = _var_slice(compiled, "pop_sigma_log__") + x_sl = _var_slice(compiled, "x") + + rng = np.random.default_rng(4) + y = rng.normal(size=n_dim) + h_mu = np.array([0.6, 0.2, 0.9]) + h_sigma = np.array([0.4, 0.7, 0.1]) + flow_params = np.concatenate([h_mu, h_sigma]) + + value, ljd_c = constrain_fn(y, flow_params) + p_con = 1.0 / (1.0 + np.exp(-y[p_sl][0])) # logodds backward + sigma_con = np.exp(y[sigma_sl][0]) + expected_x = (y[x_sl] - (1 - h_mu) * p_con) * sigma_con**h_sigma + p_con + np.testing.assert_allclose(value[x_sl], expected_x, atol=1e-10) + np.testing.assert_allclose(ljd_c, (h_sigma * np.log(sigma_con)).sum(), atol=1e-10) + + y_back, ljd_u = unconstrain_fn(value, flow_params) + np.testing.assert_allclose(y_back, y, atol=1e-10) + np.testing.assert_allclose(ljd_c + ljd_u, 0.0, atol=1e-10) + + +@pytest.mark.pymc +@pytest.mark.flow +def test_build_auto_flow_roundtrip(): + import jax.numpy as jnp + + from nutpie.normalizing_flow import AutoFlow + + m, _ = _funnel_model() + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + flow = build_auto_flow(m, compiled, init_params=jnp.full((5,), 0.3)) + assert isinstance(flow, AutoFlow) + assert flow.shape == (int(compiled.n_dim),) + + rng = np.random.default_rng(0) + y = jnp.asarray(rng.normal(size=flow.shape)) + x, ljd = flow.transform_and_log_det(y) + y_back, ljd_back = flow.inverse_and_log_det(x) + np.testing.assert_allclose(np.asarray(y_back), np.asarray(y), atol=1e-10) + np.testing.assert_allclose(float(ljd) + float(ljd_back), 0.0, atol=1e-10) + assert not np.isclose(float(ljd), 0.0) + + +@pytest.mark.pymc +@pytest.mark.flow +def test_auto_reparam_compile_api(capsys): + from nutpie.normalizing_flow import AutoFlow + + m, _ = _funnel_model() + compiled = nutpie.compile_pymc_model( + m, backend="jax", gradient_backend="jax", auto_reparam=True + ) + summary = capsys.readouterr().out + assert "reparametrizing 1 of 2 free variables" in summary + assert "ind_mu (AffineFlow)" in summary + auto_flow = compiled._transform_adapt_args["auto_flow"] + assert isinstance(auto_flow, AutoFlow) + + tuned = compiled.with_transform_adapt(num_layers=0) + assert tuned._transform_adapt_args["auto_flow"] is auto_flow + assert tuned._transform_adapt_args["num_layers"] == 0 + cleared = tuned.with_transform_adapt(auto_flow=None) + assert "auto_flow" not in cleared._transform_adapt_args + + with pytest.raises(ValueError, match="auto_reparam"): + nutpie.compile_pymc_model(m, auto_reparam=True) + with pytest.raises(ValueError, match="auto_reparam"): + nutpie.compile_pymc_model( + m, backend="jax", gradient_backend="pytensor", auto_reparam=True + ) + + +@pytest.mark.pymc +@pytest.mark.flow +@pytest.mark.parametrize("n_layers", [0, 2]) +def test_auto_flow_is_outermost_bijection(n_layers): + """The VIP flow's constrain output is the model's value vector, so it + must sit at the value-space end of the chain; the diag affine and any + coupling layers operate in its base space.""" + from nutpie.normalizing_flow import AutoFlow, make_flow + + m, _ = _funnel_model() + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + auto_flow = build_auto_flow(m, compiled) + + rng = np.random.default_rng(0) + n_dim = int(compiled.n_dim) + positions = rng.normal(size=(10, n_dim)) + gradients = rng.normal(size=(10, n_dim)) + chain = make_flow(1, positions, gradients, n_layers=n_layers, auto_flow=auto_flow) + assert isinstance(chain.bijections[-1], AutoFlow) + + +@pytest.mark.pymc +@pytest.mark.flow +def test_auto_reparam_nothing_found(): + with pm.Model() as m: + pm.Normal("x", 0, 1, shape=(3,)) + + with pytest.warns(UserWarning, match="did not find any variables"): + compiled = nutpie.compile_pymc_model( + m, backend="jax", gradient_backend="jax", auto_reparam=True + ) + assert "auto_flow" not in (compiled._transform_adapt_args or {}) + + +@pytest.mark.pymc +@pytest.mark.flow +def test_multiple_auto_flows_chain_to_one(): + from flowjax import bijections + + from nutpie.transform_adapter import make_transform_adapter + + m, _ = _funnel_model() + compiled = nutpie.compile_pymc_model(m, backend="jax", gradient_backend="jax") + flow = build_auto_flow(m, compiled) + adapter = make_transform_adapter(auto_flow=[flow, flow]) + chained = adapter.keywords["make_flow_fn"].keywords["auto_flow"] + assert isinstance(chained, bijections.Chain) + assert len(chained.bijections) == 2 + + +@pytest.mark.pymc +@pytest.mark.flow +def test_auto_reparam_sampling(): + m, _ = _funnel_model() + compiled = nutpie.compile_pymc_model( + m, backend="jax", gradient_backend="jax", auto_reparam=True + ) + trace = nutpie.sample( + compiled, chains=1, seed=1, adaptation="flow", tune=1000, draws=500 + ) + assert float(trace.sample_stats.diverging.sum()) <= 5 + np.testing.assert_allclose( + float(trace.posterior.pop_sigma_log.std()), 1.0, atol=0.3 + )