Trax — Deep Learning with Clear Code and Speed

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained by the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where to find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

General Setup

Execute the following cell (once) before running any of the code samples.

import os
import numpy as np

!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)

Es ist schön, heute neue Dinge zu lernen!
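
The temperature argument in the decoding call controls how random the choice of each next token is. As a rough illustration of that idea (a plain numpy sketch, not Trax's implementation; the sample_token helper and the logits values are hypothetical), a temperature of 0.0 reduces to always picking the highest-scoring token, while higher values flatten the distribution before sampling:

```python
import numpy as np

def sample_token(logits, temperature=1.0, rng=None):
    """Pick one token ID from unnormalized scores (logits)."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        # Greedy decoding: always take the highest-scoring token.
        return int(np.argmax(logits))
    rng = np.random.default_rng() if rng is None else rng
    scaled = logits / temperature
    scaled -= scaled.max()  # Subtract the max for numerical stability.
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return int(rng.choice(len(probs), p=probs))

logits = [2.0, 0.5, 0.1]  # Hypothetical scores for a 3-token vocabulary.
greedy = sample_token(logits, temperature=0.0)

rng = np.random.default_rng(0)
samples = [sample_token(logits, temperature=2.0, rng=rng) for _ in range(5)]
print(greedy, samples)
```

With temperature=0.0 the result is deterministic, which is why the translation example above gives the same output on every run.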

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.

You can use Trax either as a library from your own Python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.

3. Walkthrough

You can learn here how Trax works, how to create new models and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, sometimes also called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax uses the numpy API for tensor operations too.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. Both are provided by the trax.fastmath package through its backends, JAX and TensorFlow numpy.

from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')

matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
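
Since trax.fastmath.numpy follows the numpy API, the computation above maps one-to-one onto plain numpy. The sketch below is only a comparison under that assumption: it runs on the CPU, has no automatic differentiation, and the printed floating-point digits may differ because numpy defaults to 64-bit floats.

```python
import numpy as np

matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
vector = np.ones(3)

# A vector-matrix product with a vector of ones sums each column.
product = np.dot(vector, matrix)
activated = np.tanh(product)

print(f'product = {product}')
print(f'tanh(product) = {activated}')
```

Code written against this API can usually switch between numpy and fastnp by changing only the import.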

Gradients can be calculated using trax.fastmath.grad.

def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')

grad(2x^2) at 1 = 4.0
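
One way to sanity-check a gradient like this is to compare it with a numerical estimate. The sketch below uses a central finite difference in plain Python; the central_difference helper is a hypothetical name introduced here, and this is a check on the result, not how trax.fastmath.grad computes it.

```python
def f(x):
    return 2.0 * x * x

def central_difference(fn, x, eps=1e-5):
    """Approximate the derivative of fn at x with a symmetric difference."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

approx = central_difference(f, 1.0)
print(f'finite-difference estimate at 1 = {approx:.4f}')
```

The estimate should agree with the analytic derivative 4x evaluated at x = 1, up to floating-point rounding.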

Layers

Layers are the basic building blocks of Trax models. You will learn all about them in the layers intro, but for now just take a look at the implementation of one core Trax layer, Embedding:

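# Imports this excerpt relies on.
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import initializers as init

class Embedding(base.Layer):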
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w
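
The forward method above is a table lookup: each token ID selects one row of a (vocab_size, d_feature) weight matrix. A minimal numpy sketch of the same lookup follows; the sizes and token IDs are made up for illustration, and the weights are random stand-ins for what init_weights_and_state would create.

```python
import numpy as np

vocab_size, d_feature = 5, 3  # Hypothetical sizes for illustration.
rng = np.random.default_rng(0)
weights = rng.normal(size=(vocab_size, d_feature))  # One row per token ID.

token_ids = np.array([[0, 2], [4, 7]])  # 7 is out of range on purpose.

# mode='clip' maps out-of-range IDs to the nearest valid row (here, row 4).
vectors = np.take(weights, token_ids, axis=0, mode='clip')

print(vectors.shape)
```

The output keeps the shape of the input IDs and adds a trailing axis of size d_feature.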

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and can then be run by calling them.
