A lightweight JAX-only rewrite of redback for electromagnetic transient modeling and Bayesian inference, designed to run efficiently on GPUs and TPUs in float32.
Redback-JAX reimplements redback's analytical transient models in JAX, using log10-space arithmetic throughout to stay float32-safe on GPU hardware. All bolometric functions return log10(L) rather than linear luminosities (which exceed the float32 maximum of ~3.4×10³⁸ erg/s). The full spectra pipeline — photosphere, blackbody SED, and bandflux integration — also operates in log10 space end-to-end.
- Float32-safe physics: All models operate in log10 space; no overflow on GPU even for luminosities ~10⁴⁵ erg/s
- JIT-compiled and differentiable: Every model is decorated with `@jax.jit`; gradients flow through the full pipeline via `jax.grad`
- vmap-based diffusion integrals: Arnett-style diffusion uses `jax.vmap` over time points with log-mirror quadrature nodes
- Spectra pipeline: `make_spectra_model(bolometric_fn)` wraps any bolometric model to produce time × wavelength spectra for bandflux/magnitude comparison
- Clean inference API: `Prior`, `Likelihood`, `NestedSampler`, and `MCMCSampler` compose a full Bayesian fit in ~15 lines
- Multi-sampler support: BlackJAX NUTS (MCMC) and nested sampling via `blackjax.nss`
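As a rough illustration of how JIT compilation, gradients, and batching compose on a log10-space model, here is a minimal sketch. The model `toy_log10_lbol` and the `chi2` helper are hypothetical stand-ins written for this example; they are not part of the Redback-JAX API.

```python
import jax
import jax.numpy as jnp

LOG10_E = 0.4342944819  # log10(e), converts natural-log terms to log10


@jax.jit
def toy_log10_lbol(time, log10_l0, tau):
    """Hypothetical engine L(t) = L0 * exp(-t / tau), returned as log10(L)."""
    return log10_l0 - (time / tau) * LOG10_E


def chi2(params, time, log10_obs, log10_err):
    """Chi-squared between observed and modelled log10 luminosities."""
    resid = (log10_obs - toy_log10_lbol(time, params[0], params[1])) / log10_err
    return jnp.sum(resid**2)


# Synthetic, noise-free data generated from the toy model itself
time = jnp.linspace(1.0, 50.0, 20)
truth = jnp.array([43.0, 15.0])  # log10(L0 / erg s^-1), tau in days
log10_obs = toy_log10_lbol(time, truth[0], truth[1])
log10_err = jnp.full_like(time, 0.05)

# Gradient of chi2 with respect to both parameters
grad_at_truth = jax.grad(chi2)(truth, time, log10_obs, log10_err)

# Evaluate chi2 for a batch of parameter vectors in one call
param_batch = jnp.array([[43.0, 15.0], [42.5, 20.0], [43.5, 10.0]])
chi2_batch = jax.vmap(chi2, in_axes=(0, None, None, None))(
    param_batch, time, log10_obs, log10_err
)
```

Because the model never leaves log10 space, every intermediate value stays within a few tens in magnitude, so the same code runs in float32 without special handling.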
| Function | Physics | Reference |
|---|---|---|
| `arnett_bolometric` | Ni/Co decay + Arnett diffusion | Arnett 1982 |
| `magnetar_powered_bolometric` | Dipole spin-down + Arnett diffusion | Nicholl+ 2017 |
| `csm_interaction_bolometric` | Forward/reverse shocks + CSM diffusion | Chatzopoulos+ 2013 |
| `tde_analytical_bolometric` | t⁻⁵/³ fallback + Arnett diffusion | - |
| `shock_cooling_bolometric` | Shock-cooling envelope (n=10) | Piro 2021 |
| `shocked_cocoon_bolometric` | Shocked jet cocoon | Piro & Kollmeier 2018 |
| `metzger_kilonova_bolometric` | r-process ODE, 200 shells | Metzger 2017 |
| `magnetar_boosted_kilonova_bolometric` | r-process ODE + magnetar injection | Yu+ 2013 |
All bolometric functions return log10_lbol (log base-10 of luminosity in erg/s).
This is the natural unit for GPU inference — float32 can represent log10 values for any physically realistic luminosity.
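To make the float32 argument concrete, the short NumPy sketch below compares a linear luminosity with its log10 counterpart. The value 10⁴⁵ erg/s is only an illustrative number, not taken from any particular model.

```python
import numpy as np

lbol = 1.0e45  # erg/s, illustrative bright transient, held in float64

# Casting the linear value to float32 overflows (float32 max is ~3.4e38)
with np.errstate(over="ignore"):
    lbol_f32 = np.float32(lbol)

# The log10 value fits easily in float32
log10_lbol_f32 = np.float32(np.log10(lbol))

# Spacing between adjacent float32 values near 45, i.e. the resolution in dex
resolution_dex = np.spacing(log10_lbol_f32)

print(np.isinf(lbol_f32), log10_lbol_f32, resolution_dex)
```

The float32 spacing near 45 is a few times 10⁻⁶ dex, which corresponds to a relative luminosity precision of roughly 10⁻⁵.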
`make_spectra_model(bolometric_fn)` wraps any bolometric model into a full SED pipeline:
- Calls `bolometric_fn(time, **kwargs)` → `log10_lbol`
- Computes photospheric temperature and radius in log10 space (with temperature floor)
- Evaluates blackbody flux density in log10 space
- Returns `(time, lambdas, spectra)` in observer frame
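The blackbody step can be sketched as follows. This is one possible way to evaluate the Planck function in log10 space, not necessarily how Redback-JAX implements it; the helper name `log10_planck_lambda` and its signature are assumptions made for this example.

```python
import math

import jax.numpy as jnp

# CGS constants as Python floats; only their log10 or O(1) ratios reach the arrays
H_PLANCK = 6.62607015e-27  # erg s
C_LIGHT = 2.99792458e10    # cm / s
K_BOLTZ = 1.380649e-16     # erg / K
LOG10_E = 0.4342944819     # log10(e)

LOG10_2HC2 = math.log10(2.0 * H_PLANCK * C_LIGHT**2)
LOG10_HC_OVER_K = math.log10(H_PLANCK * C_LIGHT / K_BOLTZ)  # hc/k ~ 1.4388 cm K


def log10_planck_lambda(log10_wavelength_cm, log10_temperature):
    """log10 of B_lambda in erg/s/cm^2/cm/sr (hypothetical helper)."""
    # x = h c / (lambda k T), assembled from log10 inputs
    x = 10.0 ** (LOG10_HC_OVER_K - log10_wavelength_cm - log10_temperature)
    # ln(exp(x) - 1) = x + ln(1 - exp(-x)); never evaluates exp(x) directly
    ln_expm1 = x + jnp.log(-jnp.expm1(-x))
    return LOG10_2HC2 - 5.0 * log10_wavelength_cm - ln_expm1 * LOG10_E


# 100 Angstrom and 5000 Angstrom at T = 10^4 K
log10_wave_cm = jnp.log10(jnp.array([1.0e-6, 5.0e-5]))
log10_b = log10_planck_lambda(log10_wave_cm, jnp.log10(1.0e4))
```

At 100 Å and 10⁴ K the exponent x is about 144, so a direct `exp(x)` would overflow float32, while the rearranged form stays finite.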
Since models return log10_lbol, fit observed bolometric luminosities in log10 space:
```python
import numpy as np
import jax.numpy as jnp
from redback_jax.models.supernova_models import arnett_bolometric

# Observed data: observed_lbol and sigma_lbol are assumed to be linear erg/s
# values in float64 NumPy arrays. Convert on the host with NumPy, because
# linear values above ~3.4e38 become inf once cast to float32.
log10_lbol_obs = jnp.asarray(np.log10(observed_lbol))  # convert once
log10_lbol_err = jnp.asarray(sigma_lbol / (observed_lbol * np.log(10.0)))  # propagate errors

# Model prediction
log10_lbol_model = arnett_bolometric(time, f_nickel=0.5, mej=1.0,
                                     vej=10000.0, kappa=0.1, kappa_gamma=10.0)

# Gaussian log-likelihood in log10 space
log_like = -0.5 * jnp.sum(((log10_lbol_obs - log10_lbol_model) / log10_lbol_err)**2)
```

The `Prior` / `Likelihood` / `NestedSampler` / `MCMCSampler` API handles the full pipeline: model evaluation, bandflux integration, and sampling.
```python
import jax
from redback_jax.inference import Prior, Uniform, Likelihood, NestedSampler, MCMCSampler
from redback_jax.utils import luminosity_distance_cm
from redback_jax.transient import Transient

REDSHIFT = 0.01
DL_CM = luminosity_distance_cm(REDSHIFT)  # ~1.37e26 cm

# Free parameters
prior = Prior([
    Uniform(58580, 58620, name='t0'),  # MJD explosion epoch
    Uniform(0.05, 0.30, name='f_nickel'),
    Uniform(0.5, 3.0, name='mej'),
    Uniform(3000, 12000, name='vej'),
])

# Similar to how you would load a transient for Redback, but with JAX arrays
transient = Transient(
    time=times_list,
    y=mags_list,
    y_err=y_err,
    bands=bands_list,
    data_mode='magnitude',
    name='SN2019abcde',
    redshift=REDSHIFT,
)

# Likelihood built from transient.time (MJD), transient.y (AB mag),
# transient.y_err, and transient.bands
likelihood = Likelihood(
    model='arnett_spectra',
    transient=transient,
    fixed_params={
        'redshift': REDSHIFT,
        'lum_dist': DL_CM,
        'temperature_floor': 5000.0,
        'kappa': 0.07,
        'kappa_gamma': 0.1,
    },
)

# Nested sampling (BlackJAX)
ns_result = NestedSampler(likelihood, prior, outdir='results/').run(jax.random.PRNGKey(0))
ns_result.summary()

# Or MCMC with NUTS (BlackJAX)
mcmc_result = MCMCSampler(likelihood, prior, n_warmup=500, n_samples=2000, n_chains=4).run(
    jax.random.PRNGKey(1)
)
mcmc_result.summary()
```

Pass any string from `redback_jax.models.MODELS` as the `model` argument:
```python
from redback_jax.models import MODELS

print(list(MODELS.keys()))
# ['arnett_spectra', 'magnetar_spectra', 'csm_spectra', ...]
```

To compute magnitudes outside of inference (e.g. for plotting):
```python
import jax.numpy as jnp
from redback_jax.sources import PrecomputedSpectraSource

source = PrecomputedSpectraSource.from_arnett_model(
    f_nickel=0.15, mej=1.0, vej=8000.0,
    redshift=0.01,
    cosmo_H0=67.66, cosmo_Om0=0.3111,
)

# AB magnitude in ztfr at a set of phases
phases = jnp.linspace(-5, 40, 200)
mags = source.bandmag({'amplitude': 1.0}, 'ztfr', phases)
```

Some parameters changed from the original redback package for float32 safety:
| Model | Old parameter | New parameter | Reason |
|---|---|---|---|
| `tde_analytical_bolometric` | `l0` (erg/s, ~10⁴³) | `log10_l0` | Linear value overflows float32 |
| `shock_cooling_bolometric` | `mass` (Msun), `radius` (cm), `energy` (erg) | `log10_mass`, `log10_radius`, `log10_energy` | Intermediate products overflow float32 |
All other parameter names match redback exactly.
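When porting an existing redback parameter set, the renamed parameters can be converted on the host in ordinary float64 Python before anything is passed to JAX. The helper `to_log10_params` below is a hypothetical convenience written for this example, and it assumes each `log10_<name>` parameter is the base-10 logarithm of the old value in the same units; the numerical values are illustrative only.

```python
import math


def to_log10_params(redback_params, keys):
    """Return a copy with the listed linear parameters replaced by log10_<name>."""
    converted = dict(redback_params)
    for key in keys:
        # math.log10 works in float64, so 1e43 or 1e51 cannot overflow here
        converted[f"log10_{key}"] = math.log10(converted.pop(key))
    return converted


# TDE model: l0 (erg/s) becomes log10_l0
old_tde = {"l0": 1.0e43}
new_tde = to_log10_params(old_tde, ["l0"])

# Shock-cooling model: mass (Msun), radius (cm), energy (erg)
old_sc = {"mass": 0.1, "radius": 1.0e13, "energy": 1.0e51}
new_sc = to_log10_params(old_sc, ["mass", "radius", "energy"])

print(new_tde)
print(new_sc)
```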
Physical luminosities of transients (~10³⁸–10⁴⁵ erg/s) can exceed the float32 maximum (~3.4×10³⁸). Redback-JAX avoids overflow by:
- Storing all engine luminosities as `log10(L)` throughout
- Using log-sum-exp for combining decay terms (Ni/Co engine)
- Normalising ODE state variables by a scale factor (`E_scale`) in the kilonova scan
- Computing prefactors in log10 before any exponentiation
- Keeping the blackbody SED, temperature, and photospheric radius all in log10 space
The only step that materialises linear values is the final bandflux integral over the SED — where the flux densities (~10⁻²⁰ erg/s/cm²/Å) are comfortably within float32 range.
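The log-sum-exp step for the Ni/Co engine might look like the sketch below. It is an illustration of the technique rather than the package's own code: the function name `log10_nickel_cobalt` is hypothetical, and the decay constants are commonly quoted approximate values that may differ from those used in Redback-JAX.

```python
import math

import jax.numpy as jnp
from jax.scipy.special import logsumexp

LN10 = math.log(10.0)
LOG10_MSUN_G = math.log10(1.989e33)  # log10 of the solar mass in grams

# Commonly quoted approximate values (assumed here, not read from the package)
EPS_NI, EPS_CO = 3.9e10, 6.78e9  # specific heating rates, erg/s/g
TAU_NI, TAU_CO = 8.8, 111.3      # e-folding times, days


def log10_nickel_cobalt(time_days, log10_mni_msun):
    """log10 of the Ni/Co decay luminosity in erg/s, never forming L itself."""
    # L = M_Ni * [(eps_Ni - eps_Co) * exp(-t/tau_Ni) + eps_Co * exp(-t/tau_Co)]
    ln_ni = math.log(EPS_NI - EPS_CO) - time_days / TAU_NI
    ln_co = math.log(EPS_CO) - time_days / TAU_CO
    # Combine the two positive terms in natural-log space
    ln_sum = logsumexp(jnp.stack([ln_ni, ln_co]), axis=0)
    return log10_mni_msun + LOG10_MSUN_G + ln_sum / LN10


time_days = jnp.array([1.0, 20.0, 100.0])
log10_l = log10_nickel_cobalt(time_days, jnp.log10(0.1))  # 0.1 Msun of Ni-56
```

The largest number the arrays ever hold is the log-space heating rate (about 24), so the result near 10⁴² to 10⁴³ erg/s is carried only as a log10 value.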
```bash
git clone https://github.com/nikhil-sarin/redback-jax.git
cd redback-jax
pip install -e .
```

Python 3.12+ required.
Core: jax, numpy, scipy, pandas, matplotlib, astropy, wcosmo
Optional (inference): blackjax, flowmc, optax
- redback: the original full-featured package
- JAX-bandflux: jax-bandflux
- JAX: the underlying numerical computing library
GNU General Public License v3.0 — see LICENSE.
Based on the original redback package.
If you use Redback-JAX, please cite the redback paper. Please make sure you also cite all relevant papers for the models. These are the same as the papers cited in the original redback package.
If you use magnitude/flux evaluation, please also cite:
- JAX-bandflux: jax-bandflux, and any other papers recommended by those authors.

If you do any sampling, please cite the relevant sampling papers:
- BlackJAX: blackjax, and any papers recommended by those authors.
A paper describing the Redback-JAX package is in preparation. Redback-JAX is still under active development, and the API and other details may change.