Jaxns
Probabilistic Programming and Nested sampling in JAX

Mission: To make nested sampling faster, easier, and more powerful
What is it?
JAXNS is:
- a simple and powerful probabilistic programming framework using nested sampling as the engine;
- coded in JAX in a manner that allows lowering the entire inference algorithm to XLA primitives, which are JIT-compiled for high performance;
- continuously improving on its mission of making nested sampling faster, easier, and more powerful; and
- citable, use the (old) pre-print here.
What can you do with JAXNS?
- Compute the Bayesian evidence of a model or hypothesis (the ultimate scientific method);
- Produce high-quality samples from the posterior distribution;
- Easily handle degenerate, difficult, multi-modal posteriors;
- Model both discrete and continuous priors and likelihoods;
- Encode complex constraints on the prior space;
- Easily embed neural networks or any other ML model in the likelihood/prior.
JAXNS Probabilistic Programming Framework
JAXNS provides a powerful JAX-based probabilistic programming framework, which allows you to define probabilistic models easily, and use them for advanced purposes. Probabilistic models can have both Bayesian and parameterised variables. Bayesian variables are random variables, and are sampled from a prior distribution. Parameterised variables are point-wise representations of a prior distribution, and are thus not random. Associated with them is the log-probability of the prior distribution at that point.
Let's break apart an example of a simple probabilistic model. Note, this example can also be followed in docs/examples/intro_example.ipynb.
Defining a probabilistic model
- Prior models are functions that produce generators of Prior objects.
- The function must eventually return the inputs to the likelihood function.
- The value returned by a yielded Prior is a simple JAX array, i.e. you can do anything you want to it with JAX ops.
- The rules of static programming apply, i.e. you cannot dynamically allocate arrays.
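The yield/return protocol above can be illustrated with plain Python generators. The sketch below is a hypothetical toy, not JAXNS's implementation: `ToyPrior` and `run_prior_model` are made-up names, Python's `statistics.NormalDist` stands in for TFP distributions, and each yielded prior is assumed to consume one base-measure value in (0, 1) through its quantile function before the result is sent back into the generator. Named values are collected and unnamed ones stay hidden, mirroring the naming rule described in this section.

```python
from statistics import NormalDist
from typing import Callable, Dict, Generator, List, Optional, Tuple


class ToyPrior:
    """Hypothetical stand-in for a prior: a quantile function and an optional name."""

    def __init__(self, quantile: Callable[[float], float], name: Optional[str] = None):
        self.quantile = quantile
        self.name = name


def run_prior_model(
    prior_model: Callable[[], Generator], U: List[float]
) -> Tuple[object, Dict[str, float]]:
    """Drive a generator-based prior model with base-measure values U in (0, 1)."""
    gen = prior_model()
    named: Dict[str, float] = {}
    i = 0
    try:
        prior = next(gen)  # first yielded prior
        while True:
            value = prior.quantile(U[i])  # map U[i] onto this prior
            i += 1
            if prior.name is not None:
                named[prior.name] = value  # named variables are collected
            prior = gen.send(value)  # resume the model with the sampled value
    except StopIteration as stop:
        # The generator's return value is the input to the likelihood.
        return stop.value, named


def prior_model():
    # Unnamed, so it stays hidden.
    mu = yield ToyPrior(NormalDist(0.0, 1.0).inv_cdf)
    # Depends on the value that was sent back for mu.
    x = yield ToyPrior(NormalDist(mu, 1.0).inv_cdf, name="x")
    return x, mu


likelihood_input, named = run_prior_model(prior_model, [0.5, 0.975])
print(likelihood_input, named)
```

Here `mu` is consumed but not reported, while `x` is returned both as a likelihood input and as a named variable.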
JAXNS makes use of the TensorFlow Probability library for defining prior distributions, so you can use almost any of the TFP distributions. You can also use any of the TFP bijectors to define transformed distributions.
Distributions do have some requirements to be valid for use in JAXNS.
- They must have a quantile function, i.e. dist.quantile(dist.cdf(x)) == x.
- They must have a log_prob method that returns the log-probability of the distribution at a given value.

Most of the TFP distributions satisfy these requirements.
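A quantile function is what makes inverse-transform sampling possible: a uniform value p in (0, 1) maps to quantile(p) under the prior. A minimal stdlib-only sketch for checking the round-trip property numerically follows; `statistics.NormalDist` and a hand-written Exponential stand in for TFP distributions, and `roundtrip_error` is a hypothetical helper, not part of JAXNS.

```python
import math
from statistics import NormalDist


def roundtrip_error(quantile, cdf, xs):
    """Largest |quantile(cdf(x)) - x| over the test points xs."""
    return max(abs(quantile(cdf(x)) - x) for x in xs)


# Exponential(rate) written out by hand: cdf(x) = 1 - exp(-rate * x).
rate = 1.0


def exp_cdf(x):
    return 1.0 - math.exp(-rate * x)


def exp_quantile(p):
    # Inverse of exp_cdf; log1p keeps precision for small p.
    return -math.log1p(-p) / rate


normal = NormalDist(mu=0.0, sigma=1.0)
xs = [-2.0, -0.5, 0.0, 0.3, 1.7]
print(roundtrip_error(normal.inv_cdf, normal.cdf, xs))
print(roundtrip_error(exp_quantile, exp_cdf, [0.1, 0.5, 1.0, 3.0]))
```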
JAXNS has some special priors defined that can't be defined from TFP, see jaxns.framework.special_priors. You can
always request more if you need them.
Prior variables may be named but don't have to be. If they are named then they can be collected later via a transformation, otherwise they are deemed hidden variables.
The output values of prior models are the inputs to the likelihood function. They can be PyTrees,
e.g. typing.NamedTuples.
Finally, priors can become point-wise estimates of the prior distribution by calling parametrised(). This turns a
Bayesian variable into a parameterised variable, i.e. one which can be used in optimisation.
```python
import jax
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions

from jaxns.framework.model import Model
from jaxns.framework.prior import Prior


def prior_model():
    mu = yield Prior(tfpd.Normal(loc=0., scale=1.))
    # Let's make sigma a parameterised variable
    sigma = yield Prior(tfpd.Exponential(rate=1.), name='sigma').parametrised()
    x = yield Prior(tfpd.Cauchy(loc=mu, scale=sigma), name='x')
    uncert = yield Prior(tfpd.Exponential(rate=1.), name='uncert')
    return x, uncert


def log_likelihood(x, uncert):
    return tfpd.Normal(loc=0., scale=uncert).log_prob(x)


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

# You can sanity check the model (always a good idea when exploring)
model.sanity_check(key=jax.random.PRNGKey(0), S=100)

# The size of the Bayesian part of the prior space is `model.U_ndims`.
```
Sampling and transforming variables
There are two spaces of samples:
- U-space: samples in the base measure space; these are dimensionless, or rather have units of probability.
- X-space: samples in the space of the model, with the units of the prior variables.
```python
# Sample the prior in U-space (base measure)
U = model.sample_U(key=jax.random.PRNGKey(0))
# Transform to X-space
X = model.transform(U=U)
# Only named Bayesian prior variables are returned, the rest are treated as hidden variables.
assert set(X.keys()) == {'x', 'uncert'}

# Get the return value of the prior model, i.e. the input to the likelihood
x_sample, uncert_sample = model.prepare_input(U=U)
```
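For intuition about the two spaces, here is a stdlib-only sketch of a U-to-X transform for the model above. It assumes that each scalar Bayesian variable consumes one U dimension through its prior's quantile function and that the parametrised sigma consumes none, so U has three dimensions here. The function `transform` is hypothetical and only mimics what `model.transform` is described as doing; it is not the JAXNS implementation. The Cauchy and Exponential quantiles are written out by hand.

```python
import math
import random
from statistics import NormalDist


def transform(U, sigma):
    """Hypothetical U -> X map for the README's model, with sigma held at a point value."""
    u_mu, u_x, u_uncert = U
    mu = NormalDist(0.0, 1.0).inv_cdf(u_mu)            # Normal(0, 1) quantile; mu is hidden
    x = mu + sigma * math.tan(math.pi * (u_x - 0.5))   # Cauchy(mu, sigma) quantile
    uncert = -math.log1p(-u_uncert) / 1.0              # Exponential(rate=1) quantile
    # Only the named Bayesian variables are reported.
    return {"x": x, "uncert": uncert}


rng = random.Random(0)
U = [rng.random() for _ in range(3)]  # one uniform value per Bayesian dimension
X = transform(U, sigma=0.5)
print(U, X)
```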
Computing log-probabilities
All computations are based on the U-space variables.
```python
# Evaluate different parts of the model
log_prob_prior = model.log_prob_prior(U)
log_prob_likelihood = model.log_prob_likelihood(U, allow_nan=False)
log_prob_joint = model.log_prob_joint(U, allow_nan=False)
```
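The joint log-probability is the sum of the log-prior and the log-likelihood. The stdlib-only sketch below writes that decomposition out for the model above, working directly on X-space values for readability (JAXNS evaluates these from U). Treating sigma as a fixed value without its own log-prior term is an assumption of the sketch, as are the function names.

```python
import math
from statistics import NormalDist


def log_prior(mu: float, x: float, uncert: float, sigma: float) -> float:
    """Log-density of the Bayesian variables (assumes uncert > 0; sigma is a fixed value)."""
    lp_mu = math.log(NormalDist(0.0, 1.0).pdf(mu))          # Normal(0, 1)
    z = (x - mu) / sigma
    lp_x = -math.log(math.pi * sigma) - math.log1p(z * z)   # Cauchy(mu, sigma)
    lp_uncert = -uncert                                     # Exponential(rate=1)
    return lp_mu + lp_x + lp_uncert


def log_likelihood(x: float, uncert: float) -> float:
    # Same likelihood as the README: Normal(loc=0, scale=uncert) evaluated at x.
    return math.log(NormalDist(0.0, uncert).pdf(x))


def log_joint(mu: float, x: float, uncert: float, sigma: float) -> float:
    # Joint log-probability = log-prior + log-likelihood.
    return log_prior(mu, x, uncert, sigma) + log_likelihood(x, uncert)


print(log_joint(mu=0.0, x=0.0, uncert=1.0, sigma=1.0))
```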
Computing gradients of the joint probability w.r.t. parameters
```python
init_params = model.params


def log_prob_joint_fn(params, U):
    # Calling model with params returns a new model with the params set
    return model(params).log_prob_joint(U, allow_nan=False)


value, grad = jax.value_and_grad(log_prob_joint_fn)(init_params, U)
```
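`jax.value_and_grad` differentiates the joint log-probability exactly via automatic differentiation. To make concrete what a gradient with respect to a parametrised variable means, the stdlib-only sketch below checks the analytic derivative of the sigma-dependent Cauchy term of the toy model against a central finite difference. It is a generic sanity-check pattern with hypothetical helper names, not how JAX computes gradients.

```python
import math


def log_cauchy(x: float, mu: float, sigma: float) -> float:
    # log Cauchy(x; mu, sigma) = -log(pi * sigma) - log(1 + z^2), z = (x - mu) / sigma
    z = (x - mu) / sigma
    return -math.log(math.pi * sigma) - math.log1p(z * z)


def dlog_cauchy_dsigma(x: float, mu: float, sigma: float) -> float:
    # Analytic derivative of log_cauchy with respect to sigma.
    z = (x - mu) / sigma
    return -1.0 / sigma + 2.0 * z * z / (sigma * (1.0 + z * z))


def central_difference(f, s: float, h: float = 1e-6) -> float:
    # Second-order accurate numerical derivative of f at s.
    return (f(s + h) - f(s - h)) / (2.0 * h)


x, mu, sigma = 0.8, 0.1, 0.5
numeric = central_difference(lambda s: log_cauchy(x, mu, s), sigma)
analytic = dlog_cauchy_dsigma(x, mu, sigma)
print(numeric, analytic)
```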
Nested Sampling Engine
Given a probabilistic model, JAXNS can perform nested sampling on it. This allows computing the Bayesian evidence and posterior samples.
```python
from jaxns import NestedSampler

ns = NestedSampler(model=model, max_samples=1e5)

# Run the sampler
termination_reason, state = ns(jax.random.PRNGKey(42))
# Get the results
results = ns.to_results(termination_reason=termination_reason, state=state)
```
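For readers new to the algorithm, the textbook form of nested sampling (Skilling's scheme) fits in a few lines. The sketch below is a toy, not the JAXNS engine: it uses a one-dimensional Uniform(0, 1) prior, a Gaussian likelihood chosen so that log Z is known in closed form, naive rejection sampling to replace the worst live point, and the expected shrinkage exp(-1/N) of the prior volume per iteration. All names and settings are illustrative assumptions.

```python
import math
import random


def log_likelihood(u: float) -> float:
    # Toy Gaussian likelihood: mean 0.5, standard deviation 0.2 (illustrative choice).
    s = 0.2
    return -0.5 * ((u - 0.5) / s) ** 2 - math.log(math.sqrt(2.0 * math.pi) * s)


def logaddexp(a: float, b: float) -> float:
    # Numerically stable log(exp(a) + exp(b)).
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def toy_nested_sampling(num_live: int = 50, tol: float = 1e-3, seed: int = 0) -> float:
    """Estimate log Z for a Uniform(0, 1) prior with textbook nested sampling."""
    rng = random.Random(seed)
    live_logl = [log_likelihood(rng.random()) for _ in range(num_live)]
    log_z = -math.inf
    log_x = 0.0  # log of the prior volume still enclosed by the live points
    log_shrink = math.log1p(-math.exp(-1.0 / num_live))  # log(1 - e^{-1/N})
    while True:
        worst = min(range(num_live), key=lambda i: live_logl[i])
        logl_min = live_logl[worst]
        # Dead point weight: L_min times the volume shell X_{i-1} - X_i.
        log_z = logaddexp(log_z, logl_min + log_x + log_shrink)
        log_x -= 1.0 / num_live
        # Replace the worst live point by a prior draw with L > L_min.
        while True:
            logl = log_likelihood(rng.random())
            if logl > logl_min:
                break
        live_logl[worst] = logl
        # Stop when the live points can add at most a fraction `tol` to Z.
        if max(live_logl) + log_x < log_z + math.log(tol):
            break
    # Add the remaining live points, each carrying volume X / num_live.
    m = max(live_logl)
    log_mean_l = m + math.log(sum(math.exp(ll - m) for ll in live_logl) / num_live)
    return logaddexp(log_z, log_x + log_mean_l)


true_log_z = math.log(math.erf(2.5 / math.sqrt(2.0)))  # Gaussian mass inside [0, 1]
print(toy_nested_sampling(), true_log_z)
```

Rejection sampling from the whole prior is only workable here because the problem is tiny; practical samplers replace that step with much more efficient constrained sampling.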
To AOT or JIT-compile the sampler
```python
# Ahead-of-time compilation (sometimes useful)
ns_aot = jax.jit(ns).lower(jax.random.PRNGKey(42)).compile()

# Just-in-time compilation (usually useful)
ns_jit = jax.jit(ns)
```
You can inspect and plot the results.
```python
from jaxns import summary, plot_diagnostics, plot_cornerplot, save_results, load_results

# Optionally save the results to file
save_results(results, 'results.json')
# To load the results back use this
results = load_results('results.json')

summary(results)
plot_diagnostics(results)
plot_cornerplot(results)
```
Output:
```
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 149918
samples: 3780
phantom samples: 1710
likelihood evals / sample: 39.7
phantom fraction (%): 45.2%
--------
logZ=-1.65 +- 0.15
H=-1.13
ESS=132
--------
uncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
uncert: 0.68 +- 0.58 | 0.13 / 0.48 / 1.37 | 0.0 | 0.0
--------
x: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x: 0.07 +- 0.62 | -0.57 / 0.06 / 0.73 | 0.0 | 0.0
--------
```

Using the posterior samples
Nested sampling produces weighted posterior samples. For most use cases, you can simply resample them (with replacement) to obtain equally weighted samples.
```python
from jaxns import resample

samples = resample(
    key=jax.random.PRNGKey(0),
    samples=results.samples,
    log_weights=results.log_dp_mean,
    S=1000,
    replace=True
)
```
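What resampling does can be sketched with the standard library alone: normalise the log-weights stably, then draw samples with probability proportional to their weights. The sketch also computes Kish's effective sample size, 1 / sum(w_i^2) for normalised weights, which is a common definition; whether the ESS reported by `summary` uses exactly this formula is not assumed. The function names are hypothetical, and `jaxns.resample` is what you would use in practice.

```python
import math
import random


def normalised_weights(log_weights):
    # Subtract the max before exponentiating for numerical stability.
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    total = sum(w)
    return [wi / total for wi in w]


def kish_ess(log_weights):
    # Kish's effective sample size with normalised weights: 1 / sum(w^2).
    w = normalised_weights(log_weights)
    return 1.0 / sum(wi * wi for wi in w)


def resample_with_replacement(samples, log_weights, S, seed=0):
    # Draw S samples with probability proportional to their weights.
    rng = random.Random(seed)
    return rng.choices(samples, weights=normalised_weights(log_weights), k=S)


samples = [0.0, 1.0, 2.0, 3.0]
log_weights = [math.log(p) for p in (0.1, 0.2, 0.3, 0.4)]
draws = resample_with_replacement(samples, log_weights, S=10000)
print(sum(draws) / len(draws), kish_ess(log_weights))  # weighted mean of samples is 2.0
```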
Maximising the evidence
The Bayesian evidence is the ultimate model selection density, and choosing a model that maximises the evidence is
the best way to select a model. We can use the evidence maximisation algorithm to optimise the parametrised variables
of the model so that they maximise the evidence. Below, EvidenceMaximisation does this for the model we defined
above, where the parametrised variables are
automatically constrained to be in the right range, and numerical stability is ensured with proper scaling.
We see that evidence maximisation chooses a sigma that is very small.
```python
from jaxns.experimental import EvidenceMaximisation

# Let's train the sigma parameter to maximise the evidence
em = EvidenceMaximisation(model)
results, params = em.train(num_steps=5)
summary(results, with_parametrised=True)
```
Output:
```
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 72466
samples: 1440
phantom samples: 0
likelihood evals / sample: 50.3
phantom fraction (%): 0.0%
--------
logZ=-1.119 +- 0.098
H=-0.93
ESS=241
--------
sigma: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
sigma: 5.40077599e-05 +- 3.6e-12 | 5.40077563e-05 / 5.40077563e-05 / 5.40077563e-05 | 5.40077563e-05 | 5.40077563e-05
--------
uncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
uncert: 0.6 +- 0.54 | 0.05 / 0.45 / 1.37 | 0.0 | 0.0
--------
x: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x: 0.01 +- 0.56 | -0.6 / -0.0 / 0.69 | 0.0 | -0.0
--------
```
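The idea behind evidence maximisation (also known as type-II maximum likelihood or empirical Bayes) can be seen in a model whose evidence has a closed form. In the toy below, an observation y ~ N(mu, 1) with prior mu ~ N(0, tau^2) gives the marginal y ~ N(0, 1 + tau^2), and the prior width tau is chosen by grid search. This is a sketch under those assumptions with hypothetical function names, not the algorithm used by `EvidenceMaximisation`.

```python
import math


def log_evidence(y: float, tau: float) -> float:
    # y ~ N(mu, 1), mu ~ N(0, tau^2)  =>  marginally y ~ N(0, 1 + tau^2).
    var = 1.0 + tau * tau
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * y * y / var


def maximise_evidence(y: float, taus):
    # Pick the prior width with the largest log-evidence on the grid.
    return max(taus, key=lambda t: log_evidence(y, t))


taus = [0.01 * k for k in range(1, 501)]  # candidate widths 0.01 ... 5.00
print(maximise_evidence(2.0, taus))  # analytic optimum: tau = sqrt(y^2 - 1)
print(maximise_evidence(0.3, taus))  # |y| < 1: the smallest tau on the grid wins
```

Setting the derivative of the log-evidence to zero gives 1 + tau^2 = y^2, so the optimum is tau = sqrt(y^2 - 1) when |y| > 1, and the evidence increases as tau shrinks towards zero otherwise. Evidence maximisation can therefore drive a prior width towards zero when the data do not call for it.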
Documentation
You can read the documentation here. In addition, JAXNS is partially described in the original paper, as well as the paper on [Phantom-Powered Nested S
