MuyGPyS
A fast, pure python implementation of the MuyGPs Gaussian process realization and training algorithm.
Fast implementation of the MuyGPs scalable Gaussian process algorithm
MuyGPs is a scalable approximate Gaussian process (GP) model that achieves fast prediction and model optimization while retaining high-accuracy predictions and uncertainty quantification. The MuyGPyS implementation allows the user to easily create GP models that can quickly train and predict on million-scale problems on a laptop or scale to billions of observations on distributed memory systems using the same front-end code.
What is MuyGPyS?
MuyGPyS is a general-purpose Gaussian process library, similar to GPy, GPyTorch, or GPflow.
MuyGPyS differs from the other options in that it constructs approximate GP models using nearest neighbors sparsification, conditioning predictions only on the most relevant training data to drastically improve training time and time-to-solution on large-scale problems. Indeed, MuyGPyS is intended for GP problems with millions or more observations, and supports a distributed memory backend for smoothly scaling to billion-scale problems.
MuyGPs uses nearest neighbors sparsification and performs leave-one-out cross validation using regularized loss functions to rapidly optimize a GP model without evaluating a much more expensive likelihood, which is required by similar scalable methods.
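The core idea can be illustrated with a short NumPy sketch. This is not MuyGPyS code. The RBF kernel, the brute-force neighbor search, the fixed nugget, the mean squared error loss, and the function names loo_nn_predictions and loo_mse are all illustrative assumptions. Each batch point is predicted from its k nearest training neighbors only, with the point itself held out, and the loss over those held-out predictions is used to compare candidate length scales without ever evaluating a likelihood.

```python
import numpy as np


def rbf(dists_sq, length_scale):
    """Squared-exponential kernel evaluated on squared distances."""
    return np.exp(-dists_sq / (2.0 * length_scale**2))


def loo_nn_predictions(X, y, batch, k, length_scale, nugget=1e-5):
    """Predict y[i] for each i in `batch` from its k nearest neighbors,
    excluding point i itself (leave-one-out)."""
    # Brute-force pairwise squared distances; a real library would use a KNN index.
    diffs = X[:, None, :] - X[None, :, :]
    dists_sq = np.sum(diffs**2, axis=-1)
    preds = np.empty(len(batch))
    for out, i in enumerate(batch):
        d = dists_sq[i].copy()
        d[i] = np.inf  # leave point i out of its own neighborhood
        nbrs = np.argsort(d)[:k]
        # Kernel among the neighbors, plus a small nugget for numerical stability.
        K = rbf(dists_sq[np.ix_(nbrs, nbrs)], length_scale) + nugget * np.eye(k)
        # Cross-covariance between point i and its neighbors.
        k_cross = rbf(dists_sq[i, nbrs], length_scale)
        # GP posterior mean conditioned only on the k neighbors.
        preds[out] = k_cross @ np.linalg.solve(K, y[nbrs])
    return preds


def loo_mse(X, y, batch, k, length_scale):
    """Leave-one-out mean squared error over the batch."""
    preds = loo_nn_predictions(X, y, batch, k, length_scale)
    return float(np.mean((preds - y[batch]) ** 2))


rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(200, 1))
y = np.sin(6.0 * X[:, 0])
batch = rng.choice(len(X), size=50, replace=False)

# Pick the length scale with the lowest leave-one-out loss; no likelihood needed.
candidates = [0.01, 0.1, 1.0]
losses = {ls: loo_mse(X, y, batch, k=10, length_scale=ls) for ls in candidates}
best = min(losses, key=losses.get)
print(losses, best)
```

Apart from the brute-force distance computation, which a KNN index would replace, each prediction costs one k-by-k solve, so the work grows linearly with the batch size rather than cubically with the size of the training set.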
Getting Started
See the illustration tutorial for an explanation of why the neighborhood sparsification approach of MuyGPs works.
Next, see the univariate regression tutorial for a full description of the API and an end-to-end walkthrough of a simple regression problem.
The full documentation, including several additional tutorials with code examples, can be found at readthedocs.io.
Read further in this document for installation instructions.
Backend Math Implementation Options
In addition to the default basic numpy backend, as of release v0.6.6, MuyGPyS
supports three additional backend implementations of all of its underlying math
functions:
- MPI - distributed memory acceleration
- PyTorch - GPU acceleration and neural network integration
- JAX - GPU acceleration
It is possible to include the dependencies of any, all, or none of these additional backends at install time. See the installation instructions below.
MuyGPyS uses the MUYGPYS_BACKEND environment variable to determine which
backend to use at import time.
It is also possible to manipulate MuyGPyS.config to switch between backends
programmatically.
This is not advisable unless the user knows exactly what they are doing
(and must occur before importing any other MuyGPyS components).
MuyGPyS will default to the numpy backend.
It is possible to switch backends by setting the MUYGPYS_BACKEND
environment variable in your shell, e.g.
$ export MUYGPYS_BACKEND=jax # turn on JAX backend
$ export MUYGPYS_BACKEND=torch # turn on Torch backend
$ export MUYGPYS_BACKEND=mpi # turn on MPI backend
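If experiments are launched from a Python driver rather than a shell, the same variable can be set through os.environ, as long as this happens before MuyGPyS is first imported. The helper below is a hypothetical sketch, not part of MuyGPyS. Its list of accepted names comes from the shell examples above plus the default numpy backend, and it assumes "numpy" is the name of that default.

```python
import os

# Backend names from this README: numpy (the default), jax, torch, and mpi.
KNOWN_BACKENDS = ("numpy", "jax", "torch", "mpi")


def set_muygpys_backend(name):
    """Hypothetical helper: validate a backend name and export it via
    MUYGPYS_BACKEND. Must run before the first `import MuyGPyS`."""
    name = name.lower()
    if name not in KNOWN_BACKENDS:
        raise ValueError(f"unknown backend {name!r}; expected one of {KNOWN_BACKENDS}")
    os.environ["MUYGPYS_BACKEND"] = name
    return name


set_muygpys_backend("jax")
# import MuyGPyS  # the first import would now see MUYGPYS_BACKEND=jax
```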
Distributed memory support with MPI
The MPI version of MuyGPyS performs all tensor manipulation in distributed
memory.
The tensor creation functions will in fact create and distribute a chunk of each
tensor to each MPI rank.
These data, along with derived quantities such as posterior means and variances,
remain partitioned, and most operations are embarrassingly parallel.
Global operations such as loss function computation make use of MPI collectives
like allreduce.
If the user needs to reason about all products of an experiment, such as the full
posterior distribution in local memory, it is necessary to employ a collective
such as MPI.gather.
The wrapped KNN algorithms are not distributed, and MuyGPyS does not yet
have an internal distributed KNN implementation.
Future versions will support a distributed memory approximate KNN solution.
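The partitioned layout and the role of the allreduce can be illustrated without MPI. The sketch below is a serial NumPy simulation, not the mpi backend's code: np.array_split cuts the arrays into one chunk per simulated rank, each rank computes its local sum of squared residuals independently (the embarrassingly parallel part), and a plain Python sum over ranks stands in for the allreduce that produces the global mean squared error. The function name partitioned_mse is hypothetical.

```python
import numpy as np


def partitioned_mse(predictions, targets, num_ranks):
    """Simulate a distributed MSE: local partial sums, then a sum 'allreduce'."""
    pred_chunks = np.array_split(predictions, num_ranks)
    targ_chunks = np.array_split(targets, num_ranks)
    # Each simulated rank works only on its own chunk (embarrassingly parallel).
    local_sse = [float(np.sum((p - t) ** 2)) for p, t in zip(pred_chunks, targ_chunks)]
    local_count = [len(t) for t in targ_chunks]
    # Stand-in for an MPI allreduce with a sum operation.
    global_sse = sum(local_sse)
    global_count = sum(local_count)
    return global_sse / global_count


rng = np.random.default_rng(1)
targets = rng.normal(size=1003)
predictions = targets + rng.normal(scale=0.1, size=1003)

# The partitioned result matches the single-process computation.
serial = float(np.mean((predictions - targets) ** 2))
distributed = partitioned_mse(predictions, targets, num_ranks=4)
print(serial, distributed)
```

In an actual MPI run each rank would hold only its own chunk, and the final reduction would be a collective such as mpi4py's comm.allreduce(local_sse), whose default operation is a sum.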
The user can run a script myscript.py with MPI using, e.g., mpirun (or srun
if using Slurm) via
$ export MUYGPYS_BACKEND=mpi
$ # mpirun version
$ mpirun -n 4 python myscript.py
$ # srun version
$ srun -N 1 --tasks-per-node 4 -p pbatch python myscript.py
PyTorch Integration
The torch version of MuyGPyS allows for construction and training of complex
kernels, e.g., convolutional neural network kernels. All low-level math is done
on torch.Tensor objects. Due to PyTorch's lack of support for the modified Bessel
function of the second kind, we only support special cases of the Matern kernel,
in particular when the smoothness parameter is $\nu = 1/2, 3/2,$ or $5/2$. The
RBF kernel is supported as the Matern kernel with $\nu = \infty$.
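The half-integer cases have closed forms built from exponentials and polynomials, so they need no Bessel function at all. The sketch below uses the standard textbook parameterization in terms of distance d and length scale ℓ (as in Rasmussen and Williams, Gaussian Processes for Machine Learning); MuyGPyS's own scaling conventions may differ, and the function name matern_closed_form is hypothetical. It is written in NumPy for portability, but each operation has a direct torch equivalent.

```python
import numpy as np


def matern_closed_form(d, nu, length_scale=1.0):
    """Matern kernel for nu in {0.5, 1.5, 2.5, inf} using elementary functions only."""
    r = np.asarray(d, dtype=float) / length_scale
    if nu == 0.5:
        # exp(-r): the exponential kernel.
        return np.exp(-r)
    if nu == 1.5:
        s = np.sqrt(3.0) * r
        return (1.0 + s) * np.exp(-s)
    if nu == 2.5:
        s = np.sqrt(5.0) * r
        # s**2 / 3 equals 5 r**2 / 3.
        return (1.0 + s + s**2 / 3.0) * np.exp(-s)
    if nu == np.inf:
        # The nu -> infinity limit is the RBF (squared-exponential) kernel.
        return np.exp(-0.5 * r**2)
    raise ValueError("general nu requires the modified Bessel function K_nu")


distances = np.linspace(0.0, 3.0, 7)
for nu in (0.5, 1.5, 2.5, np.inf):
    print(nu, np.round(matern_closed_form(distances, nu), 4))
```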
The MuyGPyS framework is implemented as a custom PyTorch layer. In the
high-level API found in examples/muygps_torch, a PyTorch MuyGPs model is
assumed to have two components: a model.embedding which deforms the original
feature data, and a model.GP_layer which does Gaussian Process regression on
the deformed feature space. A code example is provided below.
Most users will want to use the MuyGPyS.torch.muygps_layer module to construct
a custom MuyGPs model. The model can then be calibrated using a standard
PyTorch training loop. An example of the approach based on the low-level API
is provided in docs/examples/torch_tutorial.ipynb.
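The two-component structure, an embedding that deforms the features followed by a GP layer that regresses in the deformed space, can be sketched independently of PyTorch. The NumPy classes below are hypothetical and only mirror the attribute names model.embedding and model.GP_layer from the description above. The fixed linear embedding, the RBF kernel, and the nearest-neighbor posterior mean are illustrative assumptions, not the MuyGPyS.torch implementation; in the real torch model the embedding would be a trainable network and both stages would be optimized together in a PyTorch training loop.

```python
import numpy as np


class NNGPLayer:
    """Hypothetical GP layer: posterior mean from the k nearest training embeddings."""

    def __init__(self, train_z, train_y, k=8, length_scale=1.0, nugget=1e-6):
        self.train_z, self.train_y = train_z, train_y
        self.k, self.length_scale, self.nugget = k, length_scale, nugget

    def _kernel(self, a, b):
        # RBF kernel between the rows of a and the rows of b.
        d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-d2 / (2.0 * self.length_scale**2))

    def __call__(self, z):
        out = np.empty(len(z))
        for i, zi in enumerate(z):
            # Condition each prediction only on its k nearest training embeddings.
            d2 = np.sum((self.train_z - zi) ** 2, axis=1)
            nbrs = np.argsort(d2)[: self.k]
            Zn = self.train_z[nbrs]
            K = self._kernel(Zn, Zn) + self.nugget * np.eye(len(nbrs))
            kx = self._kernel(zi[None, :], Zn)[0]
            out[i] = kx @ np.linalg.solve(K, self.train_y[nbrs])
        return out


class EmbeddedGPModel:
    """Mirrors the structure model(x) = model.GP_layer(model.embedding(x))."""

    def __init__(self, embedding_matrix, train_x, train_y, **gp_kwargs):
        self.W = embedding_matrix
        self.GP_layer = NNGPLayer(self.embedding(train_x), train_y, **gp_kwargs)

    def embedding(self, x):
        # A fixed linear deformation of the features; a real model would learn this.
        return x @ self.W

    def __call__(self, x):
        return self.GP_layer(self.embedding(x))


rng = np.random.default_rng(2)
train_x = rng.uniform(-1.0, 1.0, size=(300, 3))
# The target depends only on the first feature, so the embedding keeps that direction.
train_y = np.cos(3.0 * train_x[:, 0])
W = np.array([[1.0], [0.0], [0.0]])
model = EmbeddedGPModel(W, train_x, train_y, k=8, length_scale=0.3)
test_x = rng.uniform(-1.0, 1.0, size=(20, 3))
pred = model(test_x)
```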
In order to use the MuyGPyS torch backend, run the following command in your
shell environment.
$ export MUYGPYS_BACKEND=torch
One can also use the following workflow to programmatically set the backend to torch, although the environment variable method is preferred.
from MuyGPyS import config
config.update("muygpys_backend", "torch")
# ...subsequent imports from MuyGPyS
Just-In-Time Compilation with JAX
MuyGPyS supports just-in-time compilation of the
underlying math functions to CPU or GPU using
JAX since version v0.5.0.
The JAX-compiled versions of the code are significantly faster than numpy,
especially on GPUs.
In order to use the MuyGPyS jax backend, run the following command in your
shell environment.
$ export MUYGPYS_BACKEND=jax
Precision
JAX and torch use 32 bit types by default, whereas numpy tends to promote
everything to 64 bits.
For highly stable operations like matrix multiplication, this difference in
precision tends to result in a roughly 1e-8 disagreement between 64 bit and 32
bit implementations.
However, MuyGPyS depends upon matrix-vector solves, which can result in
disagreements up to 1e-2.
Hence, MuyGPyS forces all backend implementations to use 64 bit types by
default.
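The gap between matrix products and linear solves is easy to reproduce with NumPy alone. The sketch below builds an RBF kernel matrix on evenly spaced points with a small nugget (an assumed, moderately ill-conditioned test problem, not MuyGPyS data) and compares 32 bit and 64 bit results for a matrix-vector product and for a linear solve. The solve error is amplified by the condition number of the matrix, which is why it comes out far larger than the product error.

```python
import numpy as np

# Assumed test problem: RBF kernel on 50 points in [0, 1] with a small nugget.
x = np.linspace(0.0, 1.0, 50)
K64 = np.exp(-((x[:, None] - x[None, :]) ** 2) / (2.0 * 0.1**2)) + 1e-4 * np.eye(50)
y64 = np.sin(2.0 * np.pi * x)

# The same problem in 32 bit precision.
K32, y32 = K64.astype(np.float32), y64.astype(np.float32)


def rel_err(approx, exact):
    """Relative 2-norm error of a (possibly 32 bit) result against a 64 bit one."""
    return float(np.linalg.norm(approx.astype(np.float64) - exact) / np.linalg.norm(exact))


matmul_err = rel_err(K32 @ y32, K64 @ y64)
solve_err = rel_err(np.linalg.solve(K32, y32), np.linalg.solve(K64, y64))
print(f"condition number:      {np.linalg.cond(K64):.1e}")
print(f"matmul relative error: {matmul_err:.1e}")
print(f"solve relative error:  {solve_err:.1e}")
```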
However, the 64 bit operations are slightly slower than their 32 bit
counterparts, and limit throughput on GPUs.
MuyGPyS accordingly supports 32 bit types, but this feature is experimental
and might have sharp edges.
For example, MuyGPyS might throw errors or otherwise behave strangely if the
user passes arrays of 64 bit types while in 32 bit mode.
Be sure to set your data types appropriately.
A user can have MuyGPyS use 32 bit types by setting the MUYGPYS_FTYPE
environment variable to "32", e.g.
$ export MUYGPYS_FTYPE=32 # use 32 bit types in MuyGPyS functions
It is also possible to manipulate MuyGPyS.config to switch between types
programmatically.
This is not advisable unless the user knows exactly what they are doing.
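Because passing 64 bit arrays in 32 bit mode is a known source of trouble, it can help to cast all input data once, up front, to match the MUYGPYS_FTYPE setting. The helpers below are a hypothetical sketch, not MuyGPyS functions. They assume the variable holds "32" or "64" and treat an unset variable as 64 bit, matching the documented default. Since MuyGPyS reads its configuration at import time, the variable should be set before MuyGPyS is imported.

```python
import os

import numpy as np


def muygpys_float_dtype():
    """Map MUYGPYS_FTYPE ("32" or "64"; unset means 64) to a NumPy dtype."""
    ftype = os.environ.get("MUYGPYS_FTYPE", "64")
    if ftype == "32":
        return np.float32
    if ftype == "64":
        return np.float64
    raise ValueError(f"unsupported MUYGPYS_FTYPE={ftype!r}")


def cast_inputs(*arrays):
    """Cast every array to the configured float type before handing it to MuyGPyS."""
    dtype = muygpys_float_dtype()
    return tuple(np.asarray(a, dtype=dtype) for a in arrays)


os.environ["MUYGPYS_FTYPE"] = "32"
rng = np.random.default_rng(0)
train_features, train_responses = cast_inputs(rng.random((100, 5)), rng.random((100, 1)))
```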
Installation
Installation using Pip: CPU
The package muygpys is published on PyPI and can be installed using pip.
muygpys supports many optional extras flags, which will install additional
dependencies if specified.
If installing CPU-only with pip, you might want to consider the following extras flags:
- hnswlib - install the hnswlib dependency to support fast approximate nearest neighbors indexing
- jax - install JAX dependencies to support just-in-time compilation of math functions on CPU (see below to install on GPU CUDA architectures)
- torch - install PyTorch dependencies to employ GPU acceleration and the use of the MuyGPyS.torch submodule
- mpi - install MPI dependencies to support distributed memory parallel computation. Requires that the user has installed a version of MPI such as mvapich or open-mpi.
$ # numpy-only installation. Functions will internally use numpy.
$ pip install --upgrade muygpys
$ # The same, but includes hnswlib.
$ pip install --upgrade muygpys[hnswlib]
$ # CPU-only JAX installation. Functions will be jit-compiled using JAX.
$ pip install --upgrade muygpys[jax]
$ # The same, but includes hnswlib.
$ pip install --upgrade muygpys[jax,hnswlib]
$ # MPI installation. Functions will operate in distributed memory.
$ pip install --upgrade muygpys[mpi]
$ # The same, but includes hnswlib.
$ pip install --upgrade muygpys[mpi,hnswlib]
$ # pytorch installation. MuyGPyS.torch will be usable.
$ pip install --upgrade muygpys[torch]
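After installation, one way to confirm which optional extras are importable in the current environment is to probe for their top-level modules without importing them. The mapping below is an assumption: hnswlib, jax, and torch are the obvious module names, and mpi4py is assumed to be what the mpi extra installs. The function name available_extras is hypothetical.

```python
import importlib.util

# Assumed mapping from muygpys extras flags to the modules they provide.
EXTRA_MODULES = {
    "hnswlib": "hnswlib",
    "jax": "jax",
    "torch": "torch",
    "mpi": "mpi4py",
}


def available_extras():
    """Return {extra_name: True/False} depending on whether its module can be found."""
    return {
        extra: importlib.util.find_spec(module) is not None
        for extra, module in EXTRA_MODULES.items()
    }


for extra, ok in available_extras().items():
    print(f"{extra:8s} {'installed' if ok else 'missing'}")
```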
Installation using Pip: GPU (CUDA)
JAX GPU Instructions
JAX also supports just-in-time compilation to
various GPU platforms, making the compiled math functions within MuyGPyS
