# The `jaxtransformer` Library
Along with these notes, we've compiled a small library for training transformer models, called `jaxtransformer`. The source code is available at kvfrans/jaxtransformer.

These days, almost every model uses a Transformer backbone. Rather than repeat this code every time, we found it cleaner to collect the common functions into a helper library. We use the `jaxtransformer` library throughout these notes, and I encourage you to use these utilities as a starting point for your own projects. The repo includes example code for training a diffusion model, a language model, and an image classifier under `examples/`.
The `jaxtransformer` library is essentially a set of useful utilities. There are only three main files:
- `transformer.py`: The main transformer backbone; does not import any other code.
- `modalities.py`: Useful modules such as token embedding, patch embedding, and positional encoding (sketched below).
- `configs.py`: Default hyperparameters for model sizes and optimizers.
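To give a sense of what lives in `modalities.py`, here is a minimal sketch of a ViT-style patch embedding and a sinusoidal positional encoding, written in plain Flax. The names `PatchEmbed` and `sinusoidal_embed` are illustrative assumptions, not necessarily the library's actual API.

```python
# A minimal sketch of typical modalities.py contents, written in plain
# Flax. PatchEmbed and sinusoidal_embed are illustrative names, not
# necessarily jaxtransformer's actual API.
import jax.numpy as jnp
import flax.linen as nn

def sinusoidal_embed(positions, dim):
    # Standard sinusoidal encoding: geometrically spaced frequencies,
    # half the channels sines and half cosines.
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(dim // 2) / (dim // 2))
    angles = positions[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

class PatchEmbed(nn.Module):
    patch_size: int = 16
    hidden_dim: int = 768

    @nn.compact
    def __call__(self, images):  # images: (batch, height, width, channels)
        # A strided convolution splits the image into non-overlapping
        # patches and projects each patch to hidden_dim, as in ViT.
        x = nn.Conv(self.hidden_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(images)
        b, h, w, d = x.shape
        return x.reshape(b, h * w, d)  # (batch, num_patches, hidden_dim)
```

A transformer over image patches then just adds `sinusoidal_embed(jnp.arange(num_patches), hidden_dim)` to the patch tokens before the first attention block.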
Along with some nice utilities:
- `utils/checkpoint.py`: Minimal checkpointer for Flax networks.
- `utils/datasets.py`: Minimal dataloader using TFDS; see kvfrans/tfds_builders.
- `utils/sharding.py`: Minimal implementation of fully-sharded data parallelism.
- `utils/train_state.py`: Simple train state object to hold parameters and the model definition (see the sketch after this list).
- `utils/wandb.py`: For logging to wandb.
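As a reference for the train-state pattern, here is a minimal sketch built on Flax's own `TrainState` rather than `jaxtransformer`'s class, whose exact interface may differ; the idea is the same: parameters, optimizer state, and the model's apply function travel together through each training step.

```python
# A minimal sketch of the train-state pattern, using Flax's built-in
# TrainState rather than jaxtransformer's own class (which may differ).
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

class TinyMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(nn.relu(nn.Dense(64)(x)))

model = TinyMLP()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
state = train_state.TrainState.create(
    apply_fn=model.apply,        # the model definition travels with the state
    params=variables['params'],
    tx=optax.adamw(3e-4),        # the optimizer and its state live here too
)

@jax.jit
def train_step(state, x, y):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss
```

Because the whole state is one pytree, it is easy to checkpoint, replicate, or shard as a unit, which is what the surrounding utilities rely on.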
There are also some generative-model-specific utilities:
- `utils/fid.py`: Utilities to measure Fréchet Inception Distance (FID); see kvfrans/jax-fid-parallel and the sketch below.
- `utils/pretrained_resnet.py`: Used for FID measurement.
- `utils/stable_vae.py`: Implements the Stable Diffusion VAE, for training generative models over latent representations.
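For reference, FID compares the mean and covariance of features extracted from real and generated images. Below is a minimal single-process version of the standard formula, independent of the parallelized implementation in `utils/fid.py`; the function name `fid_from_activations` is illustrative.

```python
# A minimal FID computation from two sets of network activations:
# FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))
import numpy as np
from scipy import linalg

def fid_from_activations(acts1, acts2):
    # acts1, acts2: (num_samples, feature_dim) feature activations
    # for the real and generated image sets, respectively.
    mu1, mu2 = acts1.mean(axis=0), acts2.mean(axis=0)
    sigma1 = np.cov(acts1, rowvar=False)
    sigma2 = np.cov(acts2, rowvar=False)
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    covmean = covmean.real  # sqrtm can return tiny imaginary parts
    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
```

Lower is better: the score is zero only when the two feature distributions have identical means and covariances.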