Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,14 @@ def make_adapter(*args, **kwargs):
)

def with_transform_adapt(self, **kwargs):
return dataclasses.replace(self, _transform_adapt_args=kwargs)
"""Set arguments for the flow transform adapter (``adaptation="flow"``).

Arguments accumulate across calls; pass ``None`` to reset an
argument to its default.
"""
merged = {**(self._transform_adapt_args or {}), **kwargs}
merged = {k: v for k, v in merged.items() if v is not None}
return dataclasses.replace(self, _transform_adapt_args=merged)


def update_user_data(user_data, user_data_storage):
Expand Down Expand Up @@ -413,6 +420,7 @@ def _compile_pymc_model_jax(
gradient_backend=None,
pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
var_names: Iterable[str] | None = None,
auto_reparam: bool = False,
**kwargs,
):
if find_spec("jax") is None:
Expand Down Expand Up @@ -504,7 +512,7 @@ def expand(_x, **shared):

dims, coords = _prepare_dims_and_coords(model, shape_info, reparameterized_names)

return from_pyfunc(
compiled = from_pyfunc(
ndim=n_dim,
make_logp_fn=make_logp_func,
make_expand_fn=make_expand_func,
Expand All @@ -519,6 +527,16 @@ def expand(_x, **shared):
reparameterized_names=reparameterized_names,
)

if auto_reparam:
from nutpie.flow_reparam import build_auto_flow

# None (with a warning) when the rewrite found nothing to do.
auto_flow = build_auto_flow(model, compiled)
if auto_flow is not None:
compiled = compiled.with_transform_adapt(auto_flow=auto_flow)

return compiled


def compile_pymc_model(
model: "pm.Model",
Expand All @@ -533,6 +551,7 @@ def compile_pymc_model(
] = "support_point",
var_names: Iterable[str] | None = None,
freeze_model: bool | None = None,
auto_reparam: bool = False,
**kwargs,
) -> CompiledModel:
"""Compile necessary functions for sampling a pymc model.
Expand Down Expand Up @@ -561,6 +580,18 @@ def compile_pymc_model(
freeze_model : bool | None
Freeze all dimensions and shared variables to treat them as compile time
constants.
auto_reparam : bool
Automatically reparametrize free random variables (e.g. continuous
VIP centering of location-scale families) and attach the resulting
flow to the model, so that ``nutpie.sample(compiled_model,
adaptation="flow")`` fits the reparametrization during tuning.
Prints a summary of the reparametrized variables; if nothing can be
reparametrized, warns and attaches no flow. Requires
``backend="jax"`` and ``gradient_backend="jax"``. With a flow
attached, only the reparametrization (plus a diagonal affine) is
fitted by default; further flow options (e.g. ``num_layers`` to add
neural coupling layers) can be set with
``compiled_model.with_transform_adapt``.

Returns
-------
Expand All @@ -581,6 +612,13 @@ def compile_pymc_model(
if backend is not None:
backend = backend.lower() # type: ignore[assignment]

# The flow transform adapter needs the raw JAX logp function, which is
# only kept with the jax gradient backend.
if auto_reparam and (backend != "jax" or gradient_backend != "jax"):
raise ValueError(
"auto_reparam requires backend='jax' and gradient_backend='jax'"
)

from pymc.initial_point import make_initial_point_fn
from pymc.model.transform.optimization import freeze_dims_and_data

Expand Down Expand Up @@ -618,6 +656,7 @@ def compile_pymc_model(
gradient_backend=gradient_backend,
pymc_initial_point_fn=initial_point_fn,
var_names=var_names,
auto_reparam=auto_reparam,
**kwargs,
)
else:
Expand Down
10 changes: 9 additions & 1 deletion python/nutpie/compiled_pyfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,15 @@ def with_data(self, **updates):
return dataclasses.replace(self, _shared_data=updated)

def with_transform_adapt(self, **kwargs):
return dataclasses.replace(self, _transform_adapt_args=kwargs)
"""Set arguments for the flow transform adapter (``adaptation="flow"``).

Arguments accumulate across calls (so e.g. the ``auto_flow`` attached
by ``compile_pymc_model(..., auto_reparam=True)`` survives later
tuning calls); pass ``None`` to reset an argument to its default.
"""
merged = {**(self._transform_adapt_args or {}), **kwargs}
merged = {k: v for k, v in merged.items() if v is not None}
return dataclasses.replace(self, _transform_adapt_args=merged)

def _make_sampler(
self,
Expand Down
Loading
Loading