Trax
Trax — Deep Learning with Clear Code and Speed
Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in colab) shows how to use Trax and where you can find more information.
- Run a pre-trained Transformer: create a translator in a few lines of code
- Features and resources: API docs, where to talk to us, how to open an issue and more
- Walkthrough: how Trax works, how to make new models and train on your own data
We welcome contributions to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!
Here are a few example notebooks:
- trax.data API explained: Explains some of the major functions in the trax.data API.
- Named Entity Recognition using Reformer: Uses a Kaggle dataset for implementing Named Entity Recognition using the Reformer architecture.
- Deep N-Gram models: Implementation of deep n-gram models trained on Shakespeare's works.
General Setup
Execute the following cell (once) before running any of the code samples.
import os
import numpy as np
!pip install -q -U trax
import trax
1. Run a pre-trained Transformer
Here is how you create an English-German translator in a few lines of code:
- create a Transformer model in Trax with trax.models.Transformer
- initialize it from a file with pre-trained weights with model.init_from_file
- tokenize your input sentence to input into the model with trax.data.tokenize
- decode from the Transformer with trax.supervised.decoding.autoregressive_sample
- de-tokenize the decoded result to get the translation with trax.data.detokenize
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!
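The temperature argument above controls how much randomness goes into picking each output token. The sketch below is not Trax code: it is a plain numpy illustration of temperature-scaled sampling, and it assumes that a temperature of 0.0 means greedy (argmax) selection. The function name sample_with_temperature and the example scores are hypothetical.

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
  """Picks a token ID from unnormalized scores (logits)."""
  logits = np.asarray(logits, dtype=np.float64)
  if temperature == 0.0:
    # Assumed convention: zero temperature means greedy decoding.
    return int(np.argmax(logits))
  # Divide by the temperature, then apply a numerically stable softmax.
  scaled = logits / temperature
  scaled -= scaled.max()
  probs = np.exp(scaled)
  probs /= probs.sum()
  return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(0)
logits = [2.0, 1.0, 0.1]  # Hypothetical scores for a 3-token vocabulary.
greedy = sample_with_temperature(logits, 0.0, rng)
diverse = [sample_with_temperature(logits, 2.0, rng) for _ in range(200)]
print(f'greedy pick = {greedy}')
print(f'distinct picks at temperature 2.0 = {sorted(set(diverse))}')
```

With temperature 0.0 the same token is chosen every time, while a higher temperature flattens the distribution so that lower-scoring tokens are picked more often.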
2. Features and resources
Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.
You can use Trax either as a library from your own python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.
- API docs
- chat with us
- open an issue
- subscribe to trax-discuss for news
3. Walkthrough
You can learn here how Trax works, how to create new models and how to train them on your own data.
Tensors and Fast Math
The basic units flowing through Trax models are tensors: multi-dimensional arrays, sometimes also called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide, since Trax uses the numpy API for that as well.
In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends, JAX and TensorFlow numpy.
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.
matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix =
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
Gradients can be calculated using trax.fastmath.grad.
def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)
print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0
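As a sanity check on the value above, the derivative of 2x^2 is 4x, which equals 4 at x = 1. The sketch below confirms this with a central finite difference in plain Python. It does not use Trax, and the helper name numerical_grad and the step size eps are illustrative choices.

```python
def f(x):
  return 2.0 * x * x

def numerical_grad(fn, x, eps=1e-6):
  """Central-difference estimate of the derivative of fn at x."""
  return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

approx = numerical_grad(f, 1.0)
print(f'numerical grad(2x^2) at 1 = {approx:.4f}')
```

Automatic differentiation computes the derivative exactly rather than approximating it, so a finite-difference estimate like this is mainly useful for checking results.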
Layers
Layers are basic building blocks of Trax models. You will learn all about them in the layers intro but for now, just take a look at the implementation of one core Trax layer, Embedding:
# Excerpt from the Trax layers library; these imports provide the
# names (base, init, jnp) that the class refers to.
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import initializers as init


class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w
Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.
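To make the initialize-then-call pattern concrete, here is a minimal numpy-only sketch. It is not the Trax class: ToyEmbedding, its init method, and the (shape, dtype) tuple used as a signature are hypothetical stand-ins that only mirror the idea of creating weights from an input signature and then looking up rows with a clipped take.

```python
import numpy as np

class ToyEmbedding:
  """Illustrative stand-in for an embedding layer (not the Trax class)."""

  def __init__(self, vocab_size, d_feature):
    self._vocab_size = vocab_size
    self._d_feature = d_feature
    self.weights = None  # Created by init().

  def init(self, input_signature, seed=0):
    """Creates weights; the signature here is a (shape, dtype) tuple."""
    _, dtype = input_signature
    if not np.issubdtype(dtype, np.integer):
      raise TypeError('token IDs must be integers')
    rng = np.random.default_rng(seed)
    self.weights = rng.normal(size=(self._vocab_size, self._d_feature))

  def __call__(self, x):
    # Out-of-range IDs are clipped to the nearest valid row.
    return np.take(self.weights, x, axis=0, mode='clip')

x = np.arange(15)
layer = ToyEmbedding(vocab_size=20, d_feature=32)
layer.init((x.shape, x.dtype))  # Initialize from the input signature.
y = layer(x)                    # Run the layer by calling it.
print(f'shape of y = {y.shape}')
```

Each of the 15 token IDs is mapped to a 32-dimensional vector, so the output has shape (15, 32).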
