Overview of JAX#
Throughout these notes, we're going to need a framework for our numerical computation. I'm a big fan of JAX, which is best thought of as a GPU-accelerated version of numpy; in fact, we will often use the jax.numpy library, which implements most numpy functions one-to-one. JAX is great because it's relatively barebones, so we can build a lot of things from scratch and learn along the way.
JAX has a growing ecosystem with many contributors writing their own JAX libraries. For example, Flax is a PyTorch-like library for training neural networks, which we will use later. For an in-depth guide to JAX, I would highly recommend the JAX quickstart guide.
Arrays#
The basic element of JAX is the array, which represents an N-dimensional collection of numbers. Elementary scalar operations (addition, multiplication) are performed element-wise. Arrays can also be manipulated and combined using linear algebra operations such as matrix multiplication, dot products, etc.
import jax
import jax.numpy as jnp
x = jnp.array([1,2,0,5])
y = jnp.array([5,2,3,1])
print('x * 2 = ', x * 2)
print('x @ y = ', x @ y)
print('y[:2] = ', y[:2])
x * 2 = [ 2 4 0 10]
x @ y = 14
y[:2] = [5 2]
The important properties of an array are its shape and data type. Most head-scratching situations can be resolved by printing out the shapes of each array. Remember that JAX follows numpy's broadcasting semantics when arrays have differing shapes. The data type represents what quantity the array holds; most of the time, we're happy with float32 or bfloat16 to represent real numbers.
z = jnp.array([[1.,2.,3.5], [4.,5.,6.1]])
print('Shape of z:', z.shape)
print('Dtype of z:', z.dtype)
Shape of z: (2, 3)
Dtype of z: float32
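As a quick illustration of broadcasting (a small sketch with made-up arrays, not part of the examples above), adding a (3,)-shaped vector to a (2, 3)-shaped matrix repeats the vector across each row:
a = jnp.ones((2, 3))            # Shape (2, 3)
b = jnp.array([10., 20., 30.])  # Shape (3,)
c = a + b                       # b is broadcast across each row of a; result shape (2, 3)
print('Shape of c:', c.shape)   # (2, 3)
print(c)                        # Each row is [11. 21. 31.]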
Computational Backend#
We interface with JAX through Python code, but under the hood JAX supports a range of backends that handle the computation: GPU, TPU, and CPU. By default, JAX checks whether your system has an accelerator device (GPU or TPU) and uses it if available. This means that if you have a GPU, JAX arrays are created on the GPU's memory by default. Generally, if you have a single GPU or are running on Colab, you can ignore device placement entirely. For multi-GPU training, see the section [todo].
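If you want to check which backend JAX picked, here is a quick sketch (for inspection only) that reports the default backend and the devices JAX can see:
print('Default backend:', jax.default_backend())  # e.g. 'gpu', 'tpu', or 'cpu'
print('Devices:', jax.devices())                   # List of available devices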
Asynchronous Execution#
JAX will execute commands asynchronously when it is able to do so. This means that when you call jnp.square(x), JAX starts processing that command on the backend, but the main CPU thread does not wait for the result. The main thread only stops if it needs a previous result for the next computation. This also means that if multiple lines of JAX code don't depend on each other's outputs, JAX can execute them simultaneously.
Asynchronous execution saves time, and the JAX backend handles everything so you don't have to think about it. Just be careful when inserting print commands or profiling time, since the CPU can run ahead of the computation.
x = jnp.arange(1024)[:, None]
y = x @ x.T # This takes a long time to execute.
w = (x @ x.T + 1) # This line can run in parallel to the above line.
print('This print statement happens **before** the above commands are done.')
print('Result:', y.mean(), w.mean()) # The CPU will wait until y and w are computed.
This print statement happens **before** the above commands are done.
Result: 261632.17 261633.06
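When you do need an accurate measurement, you can force synchronization yourself with block_until_ready. A minimal sketch, timing the matrix product above:
import time
start = time.time()
y = (x @ x.T).block_until_ready()  # Block until the computation actually finishes.
print('Elapsed seconds:', time.time() - start)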
JIT (Just-in-Time Compilation)#
When we naively run JAX code, we send instructions from the CPU to the GPU one at a time. This can be inefficient; it would be faster if we could compile a set of instructions and execute them all at once. This is what just-in-time (JIT) compilation does. JIT acts as a simple wrapper around a Python function: when a JIT-wrapped function is first called, JAX compiles the function by tracing it, then executes the compiled version. Compilation only happens once; the result is stored in a cache and reused on later calls. JIT gives us dramatic speedups and is a large part of why JAX code runs fast.
x = jnp.arange(1024)
y = jnp.arange(1024) - 512
def my_function(x, y):
    x = x * y
    x = jnp.minimum(x, 0)
    x = x + y
    return x
# Slow naive version. 'block_until_ready' forces the CPU to wait.
%timeit my_function(x,y).block_until_ready()
13.6 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# Fast JIT version.
my_function_jit = jax.jit(my_function)
_ = my_function_jit(x,y) # Run once to compile.
%timeit my_function_jit(x,y).block_until_ready()
3.72 µs ± 47.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Functions can also be decorated with @jax.jit when we declare them. JAX will only compile the function when it is first called.
@jax.jit
def new_function(x):
    print('Function is compiling. X has shape:', x.shape)
    return x @ x.T
print('Calling function.')
new_function(x);
Calling function.
Function is compiling. X has shape: (4,)
Restrictions on JIT functions#
JIT compilation has two key restrictions: functions must have no side effects, and their execution must be value-agnostic.
To understand these, note that JIT compilation works by tracing your Python code: during compilation, JAX runs through your function once and records the operations it performs. This means the JITted function won't replicate any side-effect behavior, such as print statements. You can't modify any arrays in-place, either.
@jax.jit
def bad_function_1(x):
    w = open('temp.txt', 'w').write('test') # Don't create side effects like file writes.
@jax.jit
def bad_function_2(x):
    x[2] = 1 # Can't modify x in-place.
Value-agnostic execution means that the structure of the function should not depend on the numerical contents of an array. For example, we cannot use a Python if statement that branches when a value is less than zero (as in the sketch below). This restriction comes from the nature of tracing: we want the traced function to work on any input values, so its structure cannot change based on those values.
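Continuing the bad_function naming above, a hypothetical bad_function_3 (a sketch, not from the original examples) shows the failure mode: the if statement needs the concrete value of x, which is unknown during tracing, so JAX raises a tracer error.
@jax.jit
def bad_function_3(x):
    if x.sum() < 0:  # Error: the value of x is unknown at trace time.
        return -x
    return x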
Value-agnostic execution is built into JIT's core design. When a function is traced, the compiler does not see the actual input arrays, but rather placeholder arrays with the same shape and dtype. JIT knows your array is a \((2,4)\) array of integers, but not its contents.
@jax.jit
def compile_func(x):
    print('Compiling with input:', x) # Traced array has a shape, but not a value.
compile_func(jnp.array([1,2,3]))
Compiling with input: Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=1/0)>
Compiled JIT functions are cached, but only for inputs with the exact same shape and dtype. For example, each of the following calls triggers a recompile:
compile_func(jnp.array([4,5]))
compile_func(jnp.array([4,5,6,7])) # This will trigger a recompile.
Compiling with input: Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>
Compiling with input: Traced<ShapedArray(int32[4])>with<DynamicJaxprTrace(level=1/0)>
JIT Tricks#
JAX has a few helpful techniques for dealing with the limitations of JIT.
Array Modification. If we really want to set a slice of an array to a different value, we can use the .at[...].set() syntax, which returns a modified copy of the array. Remember that this is just syntactic sugar: under the hood we are creating a new array, not modifying one in place.
@jax.jit
def f(x):
    return x.at[2].set(1)
f(jnp.array([0, 0, 0]))
Array([0, 0, 1], dtype=int32)
Prints. JIT functions don't work nicely with Python print statements, because the prints only run while the function is being traced, not on every call. JAX gives us a helper, jax.debug.print, which does get traced; it will execute every time the JITted function is called.
@jax.jit
def f(x):
    print('Compiling function, this will print once.')
    jax.debug.print('Printing value of {arr} during execution', arr=x)
f(jnp.array([1,2,3]))
f(jnp.array([4,5,6]))
Compiling function, this will print once.
Printing value of [1 2 3] during execution
Printing value of [4 5 6] during execution
Conditionals. We can't use Python if statements in a JIT function, but we can use jnp.where to set an array's values based on the contents of another (boolean) array. When writing JIT functions, jnp.where is how you should handle conditional-like behavior. Remember that array shapes are always constant; only the values can change.
@jax.jit
def f(x, y):
    return jnp.where(y == 1, x, 10)
f(jnp.array([1,2,3]), jnp.array([1,0,0]))
Array([ 1, 10, 10], dtype=int32)
Random State#
In numpy or PyTorch, you can call np.random or torch.random and expect a different result every time. Under the hood, the library stores a global random state and updates it on each call.
JAX takes a strict stance on random number generation. Instead of relying on a global random state, we keep track of a random key manually. The random key is also a JAX array. All JAX random functions are deterministic functions with the random key as an input.
key = jax.random.key(42) # Create random key.
print('Random with state `key`: ', jax.random.normal(key))
print('Random with state `key`: ', jax.random.normal(key)) # This will be the same as the above line.
Random with state `key`: -0.18471177
Random with state `key`: -0.18471177
key2 = jax.random.fold_in(key, 43) # Introduce additional seed.
print('Random with state `key2`: ', jax.random.normal(key2))
Random with state `key2`: 0.11627013
new_key, new_key2 = jax.random.split(key) # Generates two new random keys.
print('Random with state `new_key`: ', jax.random.normal(new_key))
print('Random with state `new_key2`: ', jax.random.normal(new_key2))
Random with state `new_key`: 0.13790321
Random with state `new_key2`: 1.3694694
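A common usage pattern (a small sketch, assuming you want fresh randomness at each step of a loop) is to split the key every iteration, keeping one key for the future and using the subkey for the current draw:
key = jax.random.key(0)
for step in range(3):
    key, subkey = jax.random.split(key)      # Replace the key, get a fresh subkey.
    noise = jax.random.normal(subkey, (2,))  # Deterministic given subkey.
    print('Step', step, 'noise:', noise)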