Cell motility inference with Kalman filters and simulation-based inference
The simulator generates two-dimensional cell trajectories from a polarity-driven Ornstein-Uhlenbeck (OU) motility model. Each cell has an unobserved polarity-like variable that relaxes stochastically and drives the observed displacement. The model has four kinetic parameters:
- $a < 0$ controls polarity relaxation,
- $b$ controls how strongly polarity drives displacement,
- $s_0$ is the polarity diffusion scale,
- $s_1$ is the positional diffusion scale.
In two dimensions, the same dynamics are applied independently to each of the two spatial coordinates.
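One continuous-time form consistent with the parameter roles above can be written as follows. The symbols $p_t$ (polarity), $x_t$ (position along one coordinate) and the independent Wiener processes $W^{(0)}_t$ and $W^{(1)}_t$ are notation assumed here for illustration, as is the convention that $s_0$ and $s_1$ multiply the noise increments directly:

$$
\begin{aligned}
dp_t &= a\,p_t\,dt + s_0\,dW^{(0)}_t,\\
dx_t &= b\,p_t\,dt + s_1\,dW^{(1)}_t.
\end{aligned}
$$

Under this form, $a < 0$ makes polarity decay towards zero on a timescale $1/|a|$, and setting $b = 0$ reduces the position to pure diffusion with scale $s_1$.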
The onset/offset-style simulations use two regimes:
- Pre-onset / offset regime: motion is close to pure stochastic motion. Polarity relaxes rapidly and coupling from polarity to position is weak or zero.
- Post-onset regime: the cell enters a stochastic-plus-driving regime in which polarity persistence and polarity-position coupling generate directed motility.
The transition is smoothed with a sigmoid gate around the supplied onset frame rather than imposed as a hard switch. This gives trajectories that begin as weakly persistent random motion and then acquire a polarity-driven component.
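The gated two-regime behaviour can be sketched with a simple Euler-Maruyama simulator. This is a minimal illustration rather than the project's simulator: it assumes the drift $a$ and coupling $b$ are linearly interpolated between pre-onset and post-onset values by the gate, and the function names, the `(a, b)` tuple layout, and the gate `width` parameter are all hypothetical.

```python
import math
import random


def sigmoid_gate(frame, onset, width):
    """Smooth 0 -> 1 switch centred on the onset frame (tanh form avoids overflow)."""
    return 0.5 * (1.0 + math.tanh((frame - onset) / (2.0 * width)))


def simulate_track(n_frames, onset, pre, post, s0, s1, dt=0.1, width=2.0, seed=0):
    """Simulate one 2D track; `pre` and `post` are (a, b) pairs (assumed layout)."""
    rng = random.Random(seed)
    pol = [0.0, 0.0]  # unobserved polarity, one component per coordinate
    pos = [0.0, 0.0]  # observed position
    track = [tuple(pos)]
    for k in range(1, n_frames):
        g = sigmoid_gate(k, onset, width)
        a = pre[0] + g * (post[0] - pre[0])  # gated relaxation rate
        b = pre[1] + g * (post[1] - pre[1])  # gated polarity-position coupling
        for d in range(2):  # same dynamics, independently per coordinate
            # Explicit Euler-Maruyama step: position uses the current polarity.
            pos[d] += b * pol[d] * dt + s1 * math.sqrt(dt) * rng.gauss(0.0, 1.0)
            pol[d] += a * pol[d] * dt + s0 * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        track.append(tuple(pos))
    return track


# Fast relaxation and no coupling before onset; persistent, coupled motion after.
track = simulate_track(100, onset=50, pre=(-5.0, 0.0), post=(-0.5, 1.0), s0=0.5, s1=0.1)
```

The explicit Euler step is only stable when $|1 + a\,\Delta t| < 1$, so the step size has to be small relative to the fastest relaxation rate.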
The same generative structure is used by both the analytical Kalman-filter/MCMC workflow and the simulation-based inference workflows.
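The analytical inference route uses the fact that the OU motility model is linear and Gaussian after exact discretisation. For a fixed sampling interval, each coordinate of a trajectory follows a discrete-time linear-Gaussian state-space model whose transition matrix and process-noise covariance are closed-form functions of the kinetic parameters. For a two-dimensional trajectory, the latent state collects the polarity and position components of both coordinates.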
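As a sketch of what exact discretisation gives for one coordinate, assume the continuous-time form $dp = a\,p\,dt + s_0\,dW^{(0)}$, $dx = b\,p\,dt + s_1\,dW^{(1)}$ with parameters held constant over a sampling interval $\Delta t$. The symbols $p_k$, $x_k$, $F$, $Q$, $E_1$ and $E_2$ are notation assumed here. Integrating the linear SDE over one interval gives

$$
\begin{pmatrix} p_{k+1} \\ x_{k+1} \end{pmatrix}
= F \begin{pmatrix} p_k \\ x_k \end{pmatrix} + w_k,
\qquad w_k \sim \mathcal{N}(0, Q),
\qquad
F = \begin{pmatrix} e^{a\Delta t} & 0 \\ b\,E_1 & 1 \end{pmatrix},
$$

with $E_1 = (e^{a\Delta t} - 1)/a$ and $E_2 = (e^{2a\Delta t} - 1)/(2a)$, and

$$
Q = \begin{pmatrix}
s_0^2\,E_2 & \dfrac{b\,s_0^2}{a}\,(E_2 - E_1) \\[2ex]
\dfrac{b\,s_0^2}{a}\,(E_2 - E_1) & s_1^2\,\Delta t + \dfrac{b^2 s_0^2}{a^2}\,(E_2 - 2E_1 + \Delta t)
\end{pmatrix}.
$$

For small $\Delta t$ these entries reduce to $s_0^2\,\Delta t$, $b\,s_0^2\,\Delta t^2/2$ and $s_1^2\,\Delta t + b^2 s_0^2\,\Delta t^3/3$, which is a quick check on the signs.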
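The pre/post transition is encoded by making the drift and coupling time-dependent. For each trajectory, they follow the sigmoid gate around that trajectory's onset frame, moving from the pre-onset to the post-onset regime.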
For fixed kinetic parameters, the latent polarity paths do not need to be sampled. They are integrated out exactly with a Kalman filter, producing the marginal likelihood of each observed trajectory given the kinetic parameters.
The total log likelihood is the sum over trajectories. Missing or padded frames are handled by skipping the measurement update while retaining state propagation.
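The marginalisation and the handling of missing frames can be sketched for a single coordinate as below. This is a minimal pure-Python illustration, not the kaljax implementation. It assumes the continuous-time form $dp = a\,p\,dt + s_0\,dW^{(0)}$, $dx = b\,p\,dt + s_1\,dW^{(1)}$ with constant parameters, a small measurement-noise variance `r`, and an initial state with polarity at its stationary distribution and position at the first observation. The function names and the use of `None` to mark a missing frame are hypothetical.

```python
import math


def transition(a, b, s0, s1, dt):
    """Exact one-step transition terms (f, c) and process-noise covariance entries."""
    f = math.exp(a * dt)
    e1 = math.expm1(a * dt) / a
    e2 = math.expm1(2.0 * a * dt) / (2.0 * a)
    qpp = s0**2 * e2
    qpx = b * s0**2 * (e2 - e1) / a
    qxx = s1**2 * dt + (b * s0 / a) ** 2 * (e2 - 2.0 * e1 + dt)
    return f, b * e1, qpp, qpx, qxx


def kalman_loglik(obs, a, b, s0, s1, dt=1.0, r=1e-4):
    """Marginal log-likelihood of one coordinate; obs[0] must be observed."""
    f, c, qpp, qpx, qxx = transition(a, b, s0, s1, dt)
    # Assumed initial state: stationary polarity, position at the first observation.
    mp, mx = 0.0, obs[0]
    ppp, ppx, pxx = s0**2 / (-2.0 * a), 0.0, r
    loglik = 0.0
    for y in obs[1:]:
        # Predict: propagate mean and covariance through the exact transition.
        mp, mx = f * mp, c * mp + mx
        ppp, ppx, pxx = (
            f * f * ppp + qpp,
            f * (c * ppp + ppx) + qpx,
            c * c * ppp + 2.0 * c * ppx + pxx + qxx,
        )
        if y is None:
            continue  # missing or padded frame: keep the prediction, skip the update
        # Update: only the position is observed.
        s = pxx + r          # innovation variance
        v = y - mx           # innovation
        loglik -= 0.5 * (math.log(2.0 * math.pi * s) + v * v / s)
        kp, kx = ppx / s, pxx / s  # Kalman gain
        mp, mx = mp + kp * v, mx + kx * v
        ppp, ppx, pxx = ppp - kp * ppx, ppx - kp * pxx, pxx - kx * pxx
    return loglik


ll = kalman_loglik([0.0, 0.1, 0.3, None, 0.8], a=-0.5, b=1.0, s0=0.5, s1=0.1)
```

Because the two coordinates evolve independently, a two-dimensional track's log-likelihood under this sketch is the sum of `kalman_loglik` over its two coordinate series. With time-dependent drift and coupling, `transition` would be evaluated per frame rather than once.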
Parameter uncertainty is inferred with NumPyro. The Kalman filter supplies the marginal likelihood term, and Hamiltonian Monte Carlo / NUTS samples the posterior over the kinetic parameters. A MAP estimate from deterministic variational optimisation is used to initialise the chain.
The implementation supports:
- a single-population model, where all trajectories share one parameter vector;
- a labelled multi-class model, where class-specific parameters are inferred with hierarchical priors.
In the multi-class case, each trajectory contributes its own Kalman-filter likelihood under the parameters of its assigned class.
For the differentiable JAX Kalman filtering implementation, see:
https://github.com/dpohanlon/kaljax
The simulation-based inference (SBI) route avoids evaluating the analytical likelihood directly. Instead, it learns a compact representation of simulated trajectories and uses this representation for approximate posterior inference.
Training data are read from HDF5 simulation files. The expected fields include:
- observed tracks,
- masks for padded time points,
- class labels,
- onset indices,
- per-cell or per-class kinetic parameters.
When full unmasked tracks are available, the training scripts can use those directly; otherwise they use the stored masked trajectories.
For the onset model, the training target is the vector of kinetic parameters stored with each simulation.
A Transformer encoder maps each trajectory to a fixed-dimensional embedding. The input consists of the two positional coordinates, optionally augmented with a post-onset indicator channel.
The model is trained so that embeddings capture parameter-relevant variation in the trajectories. Auxiliary heads and losses can encourage:
- class separation,
- parameter prediction,
- onset-aware structure in the learned representation.
The learned embedding is then used as the summary statistic for approximate Bayesian computation (ABC), and as the context representation for Gaussian mixture density network (MDN) posterior inference.
The ABC workflow uses the trained embedding model as a summary-statistic map.
For observed tracks:
- compute observed summaries using the trained encoder;
- draw candidate parameters from the prior;
- simulate trajectories under those parameters;
- embed the simulated trajectories with the same encoder;
- compare simulated and observed summaries.
The sampler uses ABC with sequential Monte Carlo (SMC) rather than simple rejection ABC. A pilot stage sets an initial distance threshold. Successive populations then:
- perturb previously accepted particles,
- resimulate trajectories,
- recompute summaries,
- retain particles whose summary distance is below the current threshold.
Importance weights correct for the proposal distribution.
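The population loop above can be sketched on a toy problem. In this illustration a one-parameter random-walk simulator and a hand-written summary stand in for the trajectory simulator and the trained encoder, and the prior is uniform. The median-based threshold schedule and the Gaussian perturbation kernel with twice the weighted particle variance are common choices assumed here, not details taken from the project. All names are hypothetical.

```python
import math
import random


def simulate(theta, rng, n=50):
    """Toy stand-in simulator: 1D random walk with step scale theta."""
    x, track = 0.0, []
    for _ in range(n):
        x += theta * rng.gauss(0.0, 1.0)
        track.append(x)
    return track


def summary(track):
    """Stand-in for the trained encoder: root-mean-square step size."""
    steps = [b - a for a, b in zip([0.0] + track[:-1], track)]
    return math.sqrt(sum(s * s for s in steps) / len(steps))


def normal_pdf(x, mu, sd):
    return math.exp(-0.5 * ((x - mu) / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))


def abc_smc(obs_summary, lo, hi, n_particles=100, n_pops=3, seed=0):
    rng = random.Random(seed)

    def dist(theta):
        return abs(summary(simulate(theta, rng)) - obs_summary)

    # Pilot stage: the median prior-predictive distance sets the first threshold.
    pilot = sorted(dist(rng.uniform(lo, hi)) for _ in range(n_particles))
    eps = pilot[len(pilot) // 2]
    particles, weights, sd = [], [], None
    for _ in range(n_pops):
        if particles:
            mean = sum(w * p for p, w in zip(particles, weights))
            var = sum(w * (p - mean) ** 2 for p, w in zip(particles, weights))
            sd = math.sqrt(2.0 * var)  # perturbation kernel width
        new_particles, new_dists = [], []
        while len(new_particles) < n_particles:
            if not particles:
                theta = rng.uniform(lo, hi)  # first population: draw from the prior
            else:
                theta = rng.gauss(rng.choices(particles, weights=weights)[0], sd)
                if not lo <= theta <= hi:
                    continue  # zero prior density
            d = dist(theta)  # resimulate and recompute the summary
            if d < eps:
                new_particles.append(theta)
                new_dists.append(d)
        if not particles:
            raw = [1.0] * n_particles
        else:
            # Importance weight: prior (constant here) over the proposal density.
            raw = [
                1.0 / sum(w * normal_pdf(t, p, sd) for p, w in zip(particles, weights))
                for t in new_particles
            ]
        total = sum(raw)
        particles, weights = new_particles, [w / total for w in raw]
        eps = sorted(new_dists)[len(new_dists) // 2]  # tighten the threshold
    return particles, weights


obs_summary = summary(simulate(1.0, random.Random(1)))
particles, weights = abc_smc(obs_summary, lo=0.1, hi=3.0)
posterior_mean = sum(w * p for p, w in zip(particles, weights))
```

The weighted particles returned by the last population play the role of the weighted posterior sample. In the actual workflow the distance would be computed between encoder embeddings rather than scalar summaries.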
The final output is a weighted posterior sample over the inferred parameters, saved alongside the corresponding physical-scale samples.
The MDN workflow amortises posterior inference. It uses the same trajectory encoder, but trains a mixture-density head to predict a conditional density over parameters from a set-level context.
The context is built from embeddings of all selected tracks in a class or dataset subset. It can include summary statistics such as:
- embedding mean,
- embedding standard deviation,
- log number of tracks.
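A set-level context built from the three statistics listed above could look like the following. This is a minimal sketch: the function name, the plain-list representation of embeddings, and the concatenation order are assumptions, and the population standard deviation is used so that a single track still gives a defined context.

```python
import math


def set_context(embeddings):
    """Pool per-track embeddings (equal-length lists) into one context vector."""
    n = len(embeddings)
    dim = len(embeddings[0])
    mean = [sum(e[j] for e in embeddings) / n for j in range(dim)]
    std = [
        math.sqrt(sum((e[j] - mean[j]) ** 2 for e in embeddings) / n)
        for j in range(dim)
    ]
    # Concatenate: embedding mean, embedding std, log number of tracks.
    return mean + std + [math.log(n)]


context = set_context([[0.0, 2.0], [2.0, 2.0]])
```

All three statistics are invariant to the order of the tracks, so the context depends only on the set of embeddings and its size.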
The MDN head represents the conditional density over parameters given the context as a Gaussian mixture with learned mixture weights, component means, and full covariance factors.
At inference time:
- observed tracks are embedded;
- a set-level context vector is constructed;
- posterior samples are drawn directly from the learned mixture;
- samples are transformed from the normalised training parameter space back to physical parameters;
- results are written to the same posterior-style HDF5 format used by the other inference routes.
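The sampling and back-transformation steps can be sketched as below. This is an illustration under stated assumptions: the covariance factors are taken to be lower-triangular Cholesky factors, the normalisation is taken to be an affine standardisation, and the mixture parameters are passed in as plain lists rather than produced by a trained network. The function names are hypothetical.

```python
import random


def sample_mixture(weights, means, chol_factors, n_samples, rng):
    """Draw from sum_k w_k N(mu_k, L_k L_k^T), with L_k lower-triangular."""
    dim = len(means[0])
    samples = []
    for _ in range(n_samples):
        # Pick a component according to the mixture weights.
        k = rng.choices(range(len(weights)), weights=weights)[0]
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # x = mu_k + L_k z, using only the lower triangle of L_k.
        samples.append([
            means[k][i] + sum(chol_factors[k][i][j] * z[j] for j in range(i + 1))
            for i in range(dim)
        ])
    return samples


def denormalise(sample, loc, scale):
    """Undo an assumed normalisation theta_norm = (theta - loc) / scale."""
    return [s * sc + l for s, l, sc in zip(sample, loc, scale)]


rng = random.Random(0)
weights = [0.3, 0.7]
means = [[0.0, 0.0], [1.0, -1.0]]
chol = [[[0.5, 0.0], [0.1, 0.5]], [[0.2, 0.0], [0.0, 0.2]]]
draws = sample_mixture(weights, means, chol, n_samples=200, rng=rng)
physical = [denormalise(d, loc=[-1.0, 0.5], scale=[0.4, 0.2]) for d in draws]
```

If some parameters are trained on a log scale, the back-transformation would also need an exponential for those components.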