Mesh TensorFlow - Model Parallelism Made Easier
Introduction
Mesh TensorFlow (mtf) is a language for distributed deep learning, capable of
specifying a broad class of distributed tensor computations. The purpose of
Mesh TensorFlow is to formalize and implement distribution strategies for your
computation graph over your hardware/processors. For example: "Split the batch
over rows of processors and split the units in the hidden layer across columns
of processors." Mesh TensorFlow is implemented as a layer over TensorFlow.
Watch our YouTube video.
Do I need Mesh TensorFlow?
If you just want data-parallel training (batch-splitting), then you do not need Mesh TensorFlow, though Mesh TensorFlow can do this. The most common reasons for more sophisticated parallel computation are:
- The parameters of the model do not fit on one device, e.g. a 5-billion-parameter language model.
- An example is so large that the activations do not fit on one device, e.g. a large 3D image model (experimental/unet.py).
- Lower-latency parallel inference (at batch size 1).
The Mesh TensorFlow Approach to Distributed Computation
- A "Mesh" is an n-dimensional array of processors, connected by a network.
- Each tensor is distributed (split and/or replicated) across all processors in a mesh.
- Tensor dimensions and mesh dimensions are named. The layouts of all tensors follow from a set of user-defined layout rules which specify which tensor-dimensions are split across which mesh-dimensions. This ensures that the corresponding dimensions in different tensors are split in the same manner.
- Layouts do not affect results, only performance.
- The implementation of an operation involves parallel computation on all processors in the mesh, and sometimes also collective communication. A processor usually just manipulates the slices of the input tensors already resident on that processor, and produces the slice of the output that goes on that processor.
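To make the layout rules concrete, here is a small pure-Python sketch of how a per-processor slice shape could be derived from named tensor dimensions, a named mesh, and layout rules. `slice_shape` is a hypothetical helper, not part of the mtf API. It assumes each tensor dimension maps to at most one mesh dimension and that split dimensions divide evenly. It also rejects a layout that splits two dimensions of one tensor across the same mesh dimension. The example uses dimensions from the MNIST model in this README.

```python
# Hypothetical helper, not part of Mesh TensorFlow: derives the shape of the
# slice each processor holds from named dims, a named mesh and layout rules.


def slice_shape(tensor_shape, mesh_shape, layout_rules):
    """tensor_shape: [(dim_name, size)], mesh_shape: [(mesh_dim_name, size)],
    layout_rules: [(tensor_dim_name, mesh_dim_name)]."""
    mesh_sizes = dict(mesh_shape)
    rules = dict(layout_rules)
    used_mesh_dims = set()
    result = []
    for dim_name, dim_size in tensor_shape:
        mesh_dim = rules.get(dim_name)
        if mesh_dim is None:
            result.append((dim_name, dim_size))  # no rule: replicated on every processor
            continue
        if mesh_dim in used_mesh_dims:
            raise ValueError("two dimensions of one tensor are split across mesh "
                             "dimension %r" % mesh_dim)
        used_mesh_dims.add(mesh_dim)
        num_parts = mesh_sizes[mesh_dim]
        if dim_size % num_parts:
            raise ValueError("%s=%d is not divisible by %d" % (dim_name, dim_size, num_parts))
        result.append((dim_name, dim_size // num_parts))
    return result


mesh_shape = [("all_processors", 4)]
layout_rules = [("batch", "all_processors")]  # data parallelism
print(slice_shape([("batch", 100), ("rows", 28), ("cols", 28)], mesh_shape, layout_rules))
# [('batch', 25), ('rows', 28), ('cols', 28)]
print(slice_shape([("rows", 28), ("cols", 28), ("hidden", 1024)], mesh_shape, layout_rules))
# [('rows', 28), ('cols', 28), ('hidden', 1024)]
```

Tensors with a split dimension get a smaller slice. Tensors without one, such as the weights here, keep their full shape on every processor.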
Getting Started
Installation
To install the latest stable version, run
pip install mesh-tensorflow
To install the latest development version, run
pip install -e "git+https://github.com/tensorflow/mesh.git#egg=mesh-tensorflow"
Installing mesh-tensorflow does not automatically install or update
TensorFlow. We recommend installing it via pip install tensorflow or pip install tensorflow-gpu. See TensorFlow’s
installation instructions for details.
If you're using a development version of Mesh TensorFlow, you may need to
use TensorFlow's nightly package (tf-nightly).
Example Network (MNIST)
To illustrate, let us consider a simple model for the MNIST image-classification task. Our network has one hidden layer with 1024 units, and an output layer with 10 units (corresponding to the 10 digit classes).
The code consists of two parts, the first describing the mathematical
operations, and the second describing the devices and tensor/computation layout.
For the full example, see examples/mnist.py.
TODO(noam): verify that this code works.
import mesh_tensorflow as mtf

# tf_images is a tf.Tensor with shape [100, 28, 28] and dtype tf.float32
# tf_labels is a tf.Tensor with shape [100] and dtype tf.int32
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 100)
rows_dim = mtf.Dimension("rows", 28)
cols_dim = mtf.Dimension("cols", 28)
hidden_dim = mtf.Dimension("hidden", 1024)
classes_dim = mtf.Dimension("classes", 10)
images = mtf.import_tf_tensor(
mesh, tf_images, shape=[batch_dim, rows_dim, cols_dim])
labels = mtf.import_tf_tensor(mesh, tf_labels, [batch_dim])
w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
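# einsum is a generalization of matrix multiplication (see numpy.einsum).
# mtf.einsum takes a list of input tensors and sums over every dimension
# that does not appear in output_shape.
hidden = mtf.relu(mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])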
loss = mtf.reduce_mean(mtf.layers.softmax_cross_entropy_with_logits(
logits, mtf.one_hot(labels, classes_dim), classes_dim))
w1_grad, w2_grad = mtf.gradients([loss], [w1, w2])
update_w1_op = mtf.assign(w1, w1 - w1_grad * 0.001)
update_w2_op = mtf.assign(w2, w2 - w2_grad * 0.001)
In the code above, we have built a Mesh TensorFlow graph, which is simply a Python structure. We have completely defined the mathematical operations. In the code below, we specify the mesh of processors and the layout of the computation.
devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"]
mesh_shape = [("all_processors", 4)]
layout_rules = [("batch", "all_processors")]
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
mesh_shape, layout_rules, devices)
lowering = mtf.Lowering(graph, {mesh:mesh_impl})
tf_update_ops = [lowering.lowered_operation(update_w1_op),
lowering.lowered_operation(update_w2_op)]
The particular layout above implements data-parallelism, splitting the batch of
examples evenly across all four processors. Any Tensor with a "batch" dimension
(e.g. images, hidden, logits, and their gradients) is split in that dimension
across all processors, while any tensor without a "batch" dimension (e.g. the
model parameters) is replicated identically on every processor.
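The following minimal pure-Python sketch shows why this layout yields the same parameter gradients as single-device training. Plain lists stand in for processors, and `grad_w` and `data_parallel_grad_w` are hypothetical names, not mtf functions. The gradient of a weight matrix is a sum over the batch. Each processor can therefore compute a partial gradient from its own slice of the batch. Summing the partial gradients (an allreduce) then gives every processor the same full gradient, which keeps the replicated parameters identical.

```python
# Illustrative simulation of data parallelism; not Mesh TensorFlow code.


def grad_w(x, dy):
    """Gradient of y = x @ w with respect to w, i.e. x^T @ dy (a sum over the batch)."""
    n_in, n_out = len(x[0]), len(dy[0])
    return [[sum(x[b][i] * dy[b][j] for b in range(len(x))) for j in range(n_out)]
            for i in range(n_in)]


def data_parallel_grad_w(x, dy, num_procs):
    """Split the batch across num_procs processors, then sum (allreduce) the results."""
    chunk = len(x) // num_procs  # assumes the batch size divides evenly
    local_grads = [
        grad_w(x[p * chunk:(p + 1) * chunk], dy[p * chunk:(p + 1) * chunk])
        for p in range(num_procs)
    ]
    n_in, n_out = len(x[0]), len(dy[0])
    return [[sum(g[i][j] for g in local_grads) for j in range(n_out)]
            for i in range(n_in)]


x = [[1, 2], [3, 4], [5, 6], [7, 8]]               # [batch=4, in=2]
dy = [[1, 0, 1], [0, 1, 1], [2, 1, 0], [1, 1, 1]]  # upstream gradient, [batch=4, out=3]
print(grad_w(x, dy))                   # [[18, 15, 11], [22, 18, 14]]
print(data_parallel_grad_w(x, dy, 2))  # [[18, 15, 11], [22, 18, 14]]
```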
Alternatively, for model-parallelism, we can set
layout_rules=[("hidden", "all_processors")]. In this case,
any tensor with a "hidden" dimension (e.g. hidden, w1, w2) is split,
while any other tensor (e.g. images, logits) is fully replicated.
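The pure-Python sketch below checks numerically that the model-parallel layout gives the same logits as an unsplit computation. Plain lists stand in for processors, `x` stands in for the flattened images, and the tiny sizes and helper names are illustrative assumptions, not mtf code. With "hidden" split, each processor computes its own columns of `hidden` from its slice of `w1` without communication. Computing `logits` reduces out the split "hidden" dimension, so the per-processor partial results must be summed, which is an allreduce.

```python
# Illustrative simulation of model parallelism; not Mesh TensorFlow code.
# Each "processor" holds a column slice of w1 and the matching row slice of w2.


def matmul(a, b):
    """Multiply an [n x k] matrix a by a [k x m] matrix b (lists of rows)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def model_parallel_logits(x, w1, w2, num_procs):
    chunk = len(w1[0]) // num_procs  # assumes the hidden size divides evenly
    partial_logits = []
    for p in range(num_procs):
        cols = range(p * chunk, (p + 1) * chunk)
        w1_slice = [[row[j] for j in cols] for row in w1]  # [in, chunk]
        w2_slice = [w2[j] for j in cols]                   # [chunk, classes]
        hidden_slice = relu(matmul(x, w1_slice))           # local: no communication
        partial_logits.append(matmul(hidden_slice, w2_slice))
    # "hidden" is reduced out and it is split, so the partial results are
    # summed across processors (an allreduce).
    return [[sum(part[i][j] for part in partial_logits) for j in range(len(w2[0]))]
            for i in range(len(x))]


x = [[1.0, -2.0, 0.5], [0.0, 1.0, 3.0]]                 # [batch=2, in=3]
w1 = [[0.1, -0.2, 0.3, 0.4],
      [0.5, 0.6, -0.7, 0.8],
      [-0.9, 1.0, 1.1, -1.2]]                           # [in=3, hidden=4]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]  # [hidden=4, classes=2]

full = matmul(relu(matmul(x, w1)), w2)                  # single-device result
split = model_parallel_logits(x, w1, w2, num_procs=2)   # equal up to roundoff
```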
We can even combine data-parallelism and model-parallelism on a 2-dimensional mesh of processors. We split the batch along one dimension of the mesh, and the units in the hidden layer along the other dimension of the mesh, as below. In this case, the hidden layer is actually tiled between the four processors, being split in both the "batch" and "hidden" dimensions.
mesh_shape = [("processor_rows", 2), ("processor_cols", 2)]
layout_rules = [("batch", "processor_rows"), ("hidden", "processor_cols")]
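To visualize the tiling under this 2-D layout, the following sketch uses a hypothetical helper, `tile_ranges`, which is not an mtf function and assumes every split dimension divides evenly. It lists the index ranges of the [batch=100, hidden=1024] `hidden` tensor held by each processor, identified by its (processor_rows, processor_cols) coordinates.

```python
# Hypothetical helper, not part of Mesh TensorFlow: which index ranges of a
# tensor each processor holds, assuming split dimensions divide evenly.
import itertools


def tile_ranges(tensor_shape, mesh_shape, layout_rules):
    mesh_names = [name for name, _ in mesh_shape]
    mesh_sizes = dict(mesh_shape)
    rules = dict(layout_rules)
    tiles = {}
    for coords in itertools.product(*(range(size) for _, size in mesh_shape)):
        position = dict(zip(mesh_names, coords))
        ranges = []
        for dim_name, dim_size in tensor_shape:
            mesh_dim = rules.get(dim_name)
            if mesh_dim is None:
                ranges.append((0, dim_size))  # not split: full range on every processor
            else:
                chunk = dim_size // mesh_sizes[mesh_dim]
                start = position[mesh_dim] * chunk
                ranges.append((start, start + chunk))
        tiles[coords] = ranges
    return tiles


mesh_shape = [("processor_rows", 2), ("processor_cols", 2)]
layout_rules = [("batch", "processor_rows"), ("hidden", "processor_cols")]
hidden = [("batch", 100), ("hidden", 1024)]
for coords, ranges in sorted(tile_ranges(hidden, mesh_shape, layout_rules).items()):
    print(coords, ranges)
# (0, 0) [(0, 50), (0, 512)]
# (0, 1) [(0, 50), (512, 1024)]
# (1, 0) [(50, 100), (0, 512)]
# (1, 1) [(50, 100), (512, 1024)]
```

A weight such as w2 ([hidden, classes]) has no "batch" dimension. It is therefore split along processor_cols and replicated along processor_rows.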
Where does the network communication happen?
Some Mesh TensorFlow operations cause network communication. For example, an einsum (generalized matrix multiplication) is computed as follows:
- On each processor, compute the einsum of the slices of the two operands that are local to that processor.
- If no reduced-out dimensions are split, then we are done.
- If reduced-out dimensions are split, then perform an "allreduce" operation on the resulting slices - summing across any mesh dimensions over which the reduced-out dimensions are split.
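The rule above can be sketched as a small function. `allreduce_mesh_dims` is a hypothetical helper, not the mtf implementation. Given the dimension names of an einsum's inputs and output plus the layout rules, it returns the mesh dimensions over which the locally computed slices would need to be summed. The examples use einsums from the MNIST model in this README.

```python
# Hypothetical helper, not the Mesh TensorFlow implementation: mesh dimensions
# over which an einsum's local results must be allreduced.


def allreduce_mesh_dims(input_dims, output_dims, layout_rules):
    rules = dict(layout_rules)
    # Reduced-out dimensions appear in some input but not in the output.
    reduced_out = [d for d in dict.fromkeys(input_dims) if d not in output_dims]
    # Only reduced-out dimensions that are split require communication.
    return [rules[d] for d in reduced_out if d in rules]


data_parallel = [("batch", "all_processors")]
model_parallel = [("hidden", "all_processors")]

# Forward pass of the hidden layer: reduces out "rows" and "cols" (not split).
print(allreduce_mesh_dims(["batch", "rows", "cols", "rows", "cols", "hidden"],
                          ["batch", "hidden"], data_parallel))          # []
# Gradient of w1: reduces out "batch", which is split.
print(allreduce_mesh_dims(["batch", "rows", "cols", "batch", "hidden"],
                          ["rows", "cols", "hidden"], data_parallel))   # ['all_processors']
# Logits under model parallelism: reduces out "hidden", which is split.
print(allreduce_mesh_dims(["batch", "hidden", "hidden", "classes"],
                          ["batch", "classes"], model_parallel))        # ['all_processors']
```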
Where the allreduces happen depends on the computation layout. For example, in a data-parallel layout where the "batch" dimension is split, allreduces will happen when computing the parameter gradients, since this involves matrix multiplications which reduce out the "batch" dimension.
How do I pick a layout?
While results do not depend on layout (except in the realm of roundoff errors and random seeds), performance and memory consumption depend heavily on layout. Fortunately, the auto_mtf subpackage provides a method for automatically choosing a layout. For more information about what auto_mtf is doing to choose a layout, see its README file.
import mesh_tensorflow as mtf
import mesh_tensorflow.auto_mtf
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
# Insert model code here.
outputs = [logits, loss] # iterable of mtf.Tensor, the outputs you're computing
mesh_shape = [("processor_rows", 2), ("processor_cols", 2)]
layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, outputs)
It is possible for advanced users to eke out additional performance by tuning the layout (and model) further. Mesh TensorFlow helps by accumulating and printing counters of computation/communication. To start, here are some tricks/guidelines.
- It is illegal for two dimensions of the same tensor to be split across the same mesh dimension.
- For any compute-intensive operation (e.g. einsum), make sure that all mesh-dimensions are used to split dimensions of the inputs or outputs. Otherwise, computation is duplicated.
- To keep the ratio of compute/communication high (i.e. not be bandwidth-bound), split dimensions into large chunks. This should be familiar in the data-parallelism case, where we want a large batch size per processor to avoid spending most of our time communicating.
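The last two guidelines can be made concrete with back-of-envelope arithmetic. The sketch below rests on rough assumptions and does not model what Mesh TensorFlow actually measures. It counts about 2·b·d_in·d_out FLOPs for a data-parallel dense layer's forward matmul on one processor. It also assumes the gradient allreduce is a ring allreduce, in which each processor sends about 2(n-1)/n·d_in·d_out values. The ratio of these two quantities grows linearly with the per-processor batch b. A second hypothetical helper computes how many processors repeat the same work when some mesh dimension splits none of an einsum's dimensions.

```python
# Back-of-envelope estimates under stated assumptions; not Mesh TensorFlow code.


def compute_to_comm_ratio(batch_per_proc, d_in, d_out, num_procs):
    """Forward-matmul FLOPs per value sent in the gradient allreduce, per processor."""
    flops = 2 * batch_per_proc * d_in * d_out  # one multiply and one add per term
    # Assumed ring allreduce of a [d_in, d_out] gradient: each processor sends
    # about 2 * (n - 1) / n * d_in * d_out values.
    values_sent = 2 * (num_procs - 1) / num_procs * d_in * d_out
    return flops / values_sent


def duplication_factor(einsum_dim_names, mesh_shape, layout_rules):
    """Product of the sizes of mesh dimensions that split none of the einsum's dims."""
    rules = dict(layout_rules)
    used = {rules[d] for d in einsum_dim_names if d in rules}
    factor = 1
    for mesh_dim, size in mesh_shape:
        if mesh_dim not in used:
            factor *= size  # processors along this mesh dim repeat the same work
    return factor


# The ratio grows linearly with the per-processor batch (w1: 784 -> 1024, 4 processors).
print(round(compute_to_comm_ratio(25, 784, 1024, 4), 1))   # 33.3
print(round(compute_to_comm_ratio(250, 784, 1024, 4), 1))  # 333.3

mesh_shape = [("processor_rows", 2), ("processor_cols", 2)]
hidden_einsum_dims = ["batch", "rows", "cols", "hidden"]
print(duplication_factor(hidden_einsum_dims, mesh_shape, [("batch", "processor_rows")]))  # 2
print(duplication_factor(hidden_einsum_dims, mesh_shape,
                         [("batch", "processor_rows"), ("hidden", "processor_cols")]))  # 1
```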
The Mesh TensorFlow Language
Mesh TensorFlow (v0.0) is implemented as a Python library which can generate
part of a TensorFlow graph. The user first builds a mtf.Graph (the analog of
a TensorFlow graph) made up of mtf.Tensors and mtf.Operations. As in
TensorFlow, this graph consists of simple Python objects. The user then creates
a mtf.Lowering object, which lowers the mtf.Graph into TensorFlow, adding to
the default TensorFlow graph.
The Mesh TensorFlow language is nearly identical to TensorFlow, with the familiar notion of a Graph, Tensors, Operations, and automatic gradient computation. The principal differences are