# Paxml (aka Pax)

Pax is a Jax-based machine learning framework for configuring and running large-scale training experiments on top of Jax. It allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOPs utilization rates.
## Quickstart

### Setting up a Cloud TPU VM
We refer to this page for more exhaustive documentation about starting a Cloud TPU project. The following command is sufficient to create a Cloud TPU VM with 8 cores from a corp machine.
```
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-8
export TPU_NAME=paxml

# create a TPU VM
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --version=$VERSION \
--project=$PROJECT \
--accelerator-type=$ACCELERATOR
```
If you are using TPU Pod slices, please refer to this guide. Run all the commands from a local machine using gcloud with the --worker=all option:
```
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \
--worker=all --command="<commands>"
```
The following quickstart sections assume you run on a single-host TPU, so you can ssh to the VM and run the commands there.
```
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
```
### Installing Pax

After ssh-ing into the VM, you can install the paxml stable release from PyPI, or the dev version from GitHub.
For installing the stable release from PyPI (https://pypi.org/project/paxml/):
```
python3 -m pip install -U pip
python3 -m pip install paxml jax[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
If you encounter issues with transitive dependencies and you are using the native Cloud TPU VM environment, please navigate to the corresponding release branch rX.Y.Z and download paxml/pip_package/requirements.txt. This file includes the exact versions of all transitive dependencies needed in the native Cloud TPU VM environment, in which we build/test the corresponding release.
```
git clone -b rX.Y.Z https://github.com/google/paxml
pip install --no-deps -r paxml/paxml/pip_package/requirements.txt
```
To install the dev version from GitHub (and for ease of editing code):
```
# install the dev version of praxis first
git clone https://github.com/google/praxis
pip install -e praxis

git clone https://github.com/google/paxml
pip install -e paxml

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
### Run a test model
```
# example model using pjit (SPMD)
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://<your-bucket>

# example model using pmap
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \
--job_log_dir=gs://<your-bucket> \
--pmap_use_tensorstore=True
```
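The `--exp` flag names the experiment configuration class by its dotted Python path. Conceptually, the launcher resolves that path to a class object, roughly as in the sketch below (`resolve_experiment` is a hypothetical helper for illustration; the actual lookup in `paxml/main.py` may differ):

```python
import importlib

def resolve_experiment(dotted_path: str):
    """Resolve 'pkg.module.ClassName' to the class object, as --exp conceptually does."""
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

# Stand-in demo with a stdlib class (a real run would pass e.g.
# "tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps"):
print(resolve_experiment("collections.OrderedDict").__name__)  # -> OrderedDict
```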
## Documentation

Please visit our docs folder for documentation and Jupyter Notebook tutorials. See the following section for instructions on running Jupyter Notebooks on a Cloud TPU VM.
### Run a notebook
You can run the example notebooks in the TPU VM in which you just installed paxml.
#### Steps to enable a notebook in a v4-8

1. ssh into the TPU VM with port forwarding:

```
gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_NAME --zone=$ZONE --ssh-flag="-4 -L 8080:localhost:8080"
```

2. Install jupyter notebook on the TPU VM and downgrade markupsafe:

```
pip install notebook
pip install markupsafe==2.0.1
```

3. Export the jupyter path:

```
export PATH=/home/$USER/.local/bin:$PATH
```

4. scp the example notebooks to your TPU VM:

```
gcloud compute tpus tpu-vm scp <local path of the notebooks> $TPU_NAME:<path inside TPU> --zone=$ZONE --project=$PROJECT
```

5. Start jupyter notebook from the TPU VM and note the token it generates:

```
jupyter notebook --no-browser --port=8080
```

6. Then, in your local browser, go to http://localhost:8080/ and enter the token provided.
Note: In case you need to start using a second notebook while the first notebook is still occupying the TPUs, you can run

```
pkill -9 python3
```

to free up the TPUs.
## Run on GPU
Note: NVIDIA has released an updated version of Pax with H100 FP8 support and broad GPU performance improvements. Please visit the NVIDIA Rosetta repository for more details and usage instructions.
## FAQs

- Pax runs on Jax; you can find details on running Jax jobs on Cloud TPU here, and details on running Jax jobs on a Cloud TPU Pod here.
- If you run into dependency errors, please refer to the requirements.txt file in the branch corresponding to the stable release you are installing. E.g., for the stable release 0.4.0, use branch r0.4.0 and refer to its requirements.txt for the exact versions of the dependencies used for that release.
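To see which of your installed packages diverge from the pinned versions, you can compare them against the requirements file directly. The helper below is a minimal sketch (not part of paxml), assuming simple `name==version` lines:

```python
from importlib import metadata

def find_version_mismatches(requirements_text: str):
    """Return (name, pinned, installed) for pinned deps whose installed version differs.

    installed is None when the package is not installed at all.
    """
    mismatches = []
    for line in requirements_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue  # skip blanks, comments, and unpinned requirements
        name, _, pinned = line.partition("==")
        name, pinned = name.strip(), pinned.strip()
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches

# Example: a package that is certainly absent shows up as a mismatch.
print(find_version_mismatches("not-a-real-package==9.9.9"))
```

Running it over the contents of the release branch's requirements.txt points at exactly the packages to re-pin with `pip install --no-deps`.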
## Example Convergence Runs

Here are some sample convergence runs on the c4 dataset.
### 1B model on c4 dataset

You can run a 1B-params model on the c4 dataset on TPU v4-8 using the config C4Spmd1BAdam4Replicas from c4.py as follows:
```
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
--job_log_dir=gs://<your-bucket>
```
You can observe the loss curve and log perplexity graph as follows:
<img src=paxml/docs/images/1B-loss.png width="400" height="300"> <img src=paxml/docs/images/1B-pplx.png width="400" height="300">
### 16B model on c4 dataset

You can run a 16B-params model on the c4 dataset on TPU v4-64 using the config C4Spmd16BAdam32Replicas from c4.py as follows:
```
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \
--job_log_dir=gs://<your-bucket>
```
You can observe the loss curve and log perplexity graph as follows:
<img src=paxml/docs/images/16B-loss.png width="400" height="300"> <img src=paxml/docs/images/16B-pplx.png width="400" height="300">
### GPT3-XL model on c4 dataset

You can run the GPT3-XL model on the c4 dataset on TPU v4-128 using the config C4SpmdPipelineGpt3SmallAdam64Replicas from c4.py as follows:
```
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \
--job_log_dir=gs://<your-bucket>
```
You can observe the loss curve and log perplexity graph as follows:
<img src=paxml/docs/images/GPT3-XL-loss.png width="400" height="300"> <img src=paxml/docs/images/GPT3-XL-pplx.png width="400" height="300">
## Benchmark on Cloud TPU v4
The PaLM paper introduced an efficiency metric called Model FLOPs Utilization (MFU). This is measured as the ratio of the observed throughput (in, for example, tokens per second for a language model) to the theoretical maximum throughput of a system harnessing 100% of peak FLOPs. It differs from other ways of measuring compute utilization because it doesn’t include FLOPs spent on activation rematerialization during the backward pass, meaning that efficiency as measured by MFU translates directly into end-to-end training speed.
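For a decoder-only language model, a common back-of-the-envelope estimate is ~6 FLOPs per token per parameter for a training step (forward plus backward, excluding rematerialization, in line with the MFU definition above). The function below sketches this calculation; the throughput and per-chip peak FLOPs in the demo are illustrative numbers, not measured Pax results:

```python
def model_flops_utilization(tokens_per_sec, num_params, num_chips, peak_flops_per_chip):
    """Estimate MFU for a decoder-only LM under the ~6*N FLOPs-per-token rule of thumb.

    The 6*N approximation ignores attention FLOPs, so treat the result as a rough
    estimate rather than an exact accounting.
    """
    observed_flops_per_sec = tokens_per_sec * 6 * num_params
    peak_flops_per_sec = num_chips * peak_flops_per_chip
    return observed_flops_per_sec / peak_flops_per_sec

# Hypothetical numbers: 1e5 tokens/s for a 1B-param model on 4 chips,
# each with an assumed 275 TFLOP/s bf16 peak.
print(f"{model_flops_utilization(1e5, 1e9, 4, 275e12):.1%}")  # -> 54.5%
```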
To evaluate the MFU of a key class of workloads on TPU v4 Pods with Pax, we carried out an in-depth benchmark campaign on a series of decoder-only Transformer language model (GPT) configurations that range in size from billions to trillions of parameters on the c4 dataset. The following graph shows the training efficiency using the "weak scaling" pattern where we grew the model size in proportion to the number of chips used.
<img src=paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png width="500" height="300">
## Pax on Multislice

The multislice configs in this repo refer to 1) the single-slice configs for syntax / model architecture and 2) the MaxText repo for config values. We provide example runs under c4_multislice.py as a starting point for Pax on Multislice.
### Setting up Cloud TPU VMs using Queued Resources
We refer to this page for more exhaustive documentation about using Queued Resources for a multi-slice Cloud TPU project. The following shows the steps needed to set up TPUs for running example configs in this repo.
```
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-128 # or v4-384 depending on which config you run
```
For example, to run C4Spmd22BAdam2xv4_128 on 2 slices of v4-128, you would set up the TPUs as follows:
```
export TPU_PREFIX=<your-prefix> # New TPUs will be created based off this prefix
export QR_ID=$TPU_PREFIX
export NODE_COUNT=<number-of-slices> # 1, 2, or 4 depending on which config you run

# create a TPU VM
gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX
```
### Installing Pax

The setup commands described earlier need to be run on ALL workers in ALL slices. You can 1) ssh into each worker of each slice individually, or 2) use a for loop with the --worker=all flag, as in the following command.
```
for ((i=0; i<$NODE_COUNT; i++))
do
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all \
--command="python3 -m pip install -U pip && \
python3 -m pip install paxml jax[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
done
```