The jaxtransformer Library

Along with these notes, we’ve compiled a small library for training transformer models, known as jaxtransformer. You can see the source code at kvfrans/jaxtransformer.

These days, almost every model uses a Transformer backbone. Rather than repeat this code every time, we found it cleaner to put the common functions into a helper library. We use the jaxtransformer library throughout these notes, and we encourage you to use these utilities as a starting point for your own projects. The repo includes example code in examples/ for training a diffusion model, a language model, and an image classifier.

jaxtransformer is less a framework than a set of useful utilities. There are only three main files (a rough sketch of how these pieces fit together follows the list below):

  • transformer.py: Main transformer backbone, does not import any other code.

  • modalities.py: Useful modules such as token embedding, patch embedding, positional encoding.

  • configs.py: Default hyperparameters for model sizes and optimizers.
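
To make the split concrete, here is a minimal, self-contained sketch of the kind of components these files contain: a pre-LayerNorm transformer block (the sort of thing in transformer.py) and a patch embedding (the sort of thing in modalities.py). The module names, signatures, and shapes here are illustrative assumptions, not the library's actual API.

```python
# Illustrative sketch only; names and shapes are assumptions, not jaxtransformer's API.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TransformerBlock(nn.Module):
    """Pre-LayerNorm self-attention + MLP block."""
    hidden_dim: int
    num_heads: int

    @nn.compact
    def __call__(self, x):  # x: [batch, tokens, hidden_dim]
        h = nn.LayerNorm()(x)
        h = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(h, h)
        x = x + h
        h = nn.LayerNorm()(x)
        h = nn.Dense(4 * self.hidden_dim)(h)
        h = nn.gelu(h)
        h = nn.Dense(self.hidden_dim)(h)
        return x + h


class PatchEmbed(nn.Module):
    """Image -> token sequence, the kind of module a modalities file provides."""
    patch_size: int
    hidden_dim: int

    @nn.compact
    def __call__(self, imgs):  # imgs: [batch, H, W, C]
        x = nn.Conv(self.hidden_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(imgs)
        return x.reshape(x.shape[0], -1, self.hidden_dim)


# Example: embed 32x32 images into tokens, then run one transformer block.
model = nn.Sequential([PatchEmbed(patch_size=4, hidden_dim=64),
                       TransformerBlock(hidden_dim=64, num_heads=4)])
params = model.init(jax.random.PRNGKey(0), jnp.zeros((2, 32, 32, 3)))
out = model.apply(params, jnp.zeros((2, 32, 32, 3)))  # shape [2, 64, 64]
```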

Along with some nice utilities:

  • utils/checkpoint.py: Minimal checkpointer for flax networks.

  • utils/datasets.py: Minimal dataloader using TFDS, see kvfrans/tfds_builders.

  • utils/sharding.py: Minimal implementation of fully-sharded data parallelism (see the sketch after this list).

  • utils/train_state.py: Simple train state object to hold parameters and model definition.

  • utils/wandb.py: For logging to wandb.
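
As a rough illustration of what fully-sharded data parallelism involves, here is a minimal JAX sketch using the jax.sharding API: both the batch and the parameters are split across devices, so no single device holds a full copy of the weights. It shows the general mechanism only, and is not necessarily how utils/sharding.py implements it.

```python
# Illustrative FSDP sketch with jax.sharding; utils/sharding.py may differ in detail.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis, "fsdp", spanning all available devices.
mesh = Mesh(jax.devices(), axis_names=("fsdp",))

# Shard data along the batch dimension, and shard each parameter along its
# first dimension, so every device holds only a slice of the weights.
data_sharding = NamedSharding(mesh, P("fsdp"))
param_sharding = NamedSharding(mesh, P("fsdp"))

params = jax.device_put(jnp.ones((1024, 1024)), param_sharding)
batch = jax.device_put(jnp.ones((256, 1024)), data_sharding)

@jax.jit
def forward(params, batch):
    # XLA inserts the all-gathers / reduce-scatters needed to compute with
    # sharded weights; the user-level code stays device-agnostic.
    return batch @ params

out = forward(params, batch)
print(out.sharding)  # output stays sharded along the "fsdp" axis
```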

There are also some generative-model-specific utilities:

  • utils/fid.py: Utilities to measure Fréchet Inception Distance, see kvfrans/jax-fid-parallel; the underlying formula is sketched after this list.

  • utils/pretrained_resnet.py: Used for FID measurement.

  • utils/stable_vae.py: Implements the Stable Diffusion VAE, for training generative models over latent representations.
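
For reference, FID compares Gaussian fits of Inception-network features extracted from real and generated samples. A minimal version of the distance computation itself (assuming you already have the two feature arrays, and independent of how utils/fid.py organizes things) looks like this:

```python
# Standard FID formula, given Inception features for real and generated images.
import numpy as np
from scipy import linalg

def frechet_distance(feats_real, feats_gen):
    # feats_*: [num_images, feature_dim] arrays of Inception activations.
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    # FID = ||mu_r - mu_g||^2 + Tr(cov_r + cov_g - 2 * (cov_r cov_g)^{1/2})
    covmean = linalg.sqrtm(cov_r @ cov_g).real  # drop tiny imaginary parts from numerics
    return float(np.sum((mu_r - mu_g) ** 2)
                 + np.trace(cov_r + cov_g - 2.0 * covmean))
```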