Building NNs with Flax#
On top of JAX, we will use Flax, a higher-level framework for building neural networks. Flax is still fairly minimal (there is no one-line keras.fit) and should be seen as a collection of helpful primitives we can put together however we want.
Taking a gradient in JAX#
Built into JAX is an autograd engine, which lets us automatically compute gradients of functions with jax.grad. We can take a pure function mapping x to y, and get back a new function which calculates the gradient of y with respect to x.
Our high-level strategy will be to write our neural networks using Flax, then use jax.grad to run gradient descent and minimize the loss.
def loss(x):
    return jnp.square(x - 2).mean()

grad_x = jax.grad(loss)  # grad(x) = 2 * (x-2)
print('grad(x=8) = 2*(x-2) =', grad_x(8.0))  # 12.0
grad(x=8) = 2*(x-2) = 12.0
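A closely related helper is jax.value_and_grad, which returns both the function's value and its gradient in one call; we will lean on it later when training. A minimal sketch using the same loss function (the name value_and_grad_fn is just our own choice):
value_and_grad_fn = jax.value_and_grad(loss)
value, grad = value_and_grad_fn(8.0)  # evaluate loss and gradient together
print('loss(8.0) =', value)           # 36.0
print('grad(8.0) =', grad)            # 12.0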
Flax Modules#
The main abstraction of Flax is a module, which represents some kind of computation along with relevant parameters. Modules are composable, so we can construct complicated modules out of simpler ones. We will view everything from a feedforward layer to an entire neural network as a module.
Flax separates the concept of the computation graph from the actual parameters. When we write a Flax module, we declare parameters with an initialization function that describes how to create them (e.g. sample from a normal distribution). These parameters are treated as placeholders until we actually call init on the module, at which point we get the parameters back as a separate object.
import flax.linen as nn
class MyModule(nn.Module):
    output: int = 3

    @nn.compact
    def __call__(self, x):
        weights = self.param('weights', nn.initializers.lecun_normal(),
                             (self.output, x.shape[-1]))
        return weights @ x
x = jnp.array([1,2,3])
m = MyModule()
k = jax.random.key(42)
params = m.init(k, x) # Initialize the module, which creates parameters.
y = m.apply(params, x) # Call the module
print('Parameters:', params['params']['weights'])
print('Output:', y)
Parameters: [[-0.606425 -0.58962685 -0.4824153 ]
[ 0.33814558 1.2511139 0.02361481]
[ 0.5528764 -0.15400036 0.06896964]]
Output: [-3.2329245 2.9112177 0.4517846]
In the example above we defined our module from scratch, but in general we can compose our networks out of built-in Flax primitives.
class MyNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.LayerNorm()(x)
        x = nn.Dense(128)(x)
        x = jnp.tanh(x)
        return x
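As a quick sanity check (a minimal sketch; the input shape of 16 is an arbitrary choice, and the variable names are our own), we can initialize and apply this composed module exactly like the hand-written one:
x_example = jnp.linspace(0, 1, 16)                       # arbitrary example input
mynet = MyNet()
mynet_params = mynet.init(jax.random.key(0), x_example)  # creates params for every sub-module
y = mynet.apply(mynet_params, x_example)                 # shape (128,) after the Dense layer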
Pytrees#
Flax modules naturally have a tree-like structure. For example, the MyNet module above has an nn.Dense sub-module, which contains kernel and bias parameters. We want a nice way to reason about trees without having to know their structure.
JAX uses the pytree as a helpful abstraction. Pytrees are just nested native Python containers (dictionaries, lists, or tuples) with arrays at the leaves. JAX provides jax.tree_util helper functions, which can operate over pytrees of arbitrary structure. For example, if we wanted to add ten to every parameter in params:
# This works regardless if `params` is a list, a dictionary, a list of dicts, etc.
params = m.init(k, x)
params_new = jax.tree_util.tree_map(lambda x : x + 10, params)
print('params:', params)
print('params_new:', params_new)
params: {'params': {'weights': Array([[-0.606425 , -0.58962685, -0.4824153 ],
[ 0.33814558, 1.2511139 , 0.02361481],
[ 0.5528764 , -0.15400036, 0.06896964]], dtype=float32)}}
params_new: {'params': {'weights': Array([[ 9.393575, 9.410373, 9.517585],
[10.338145, 11.251114, 10.023615],
[10.552876, 9.846 , 10.06897 ]], dtype=float32)}}
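The same jax.tree_util helpers are handy for inspecting a model. For example (a small sketch), we can flatten the pytree into its leaves and count the total number of parameters:
leaves = jax.tree_util.tree_leaves(params)
print('Total parameter count:', sum(leaf.size for leaf in leaves))  # 9 for the 3x3 weight matrix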
Optimization#
Let’s build a simple gradient descent algorithm out of these building blocks. We’ll start by defining a simple neural network, then we will write down our loss function, and pass the entire thing through jax.grad. We will iteratively compute the gradient and adjust our parameters slightly, minimizing our loss.
class MyNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)
x = jnp.arange(128) / 128
y_true = jnp.sin(x)
net = MyNet()
k = jax.random.key(42)
params = net.init(k, x)
def loss_fn(params, x):
    y_pred = net.apply(params, x)
    return jnp.mean((y_pred - y_true)**2)
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))
for i in range(10):
    loss, grad = loss_and_grad(params, x)
    params = jax.tree_util.tree_map(lambda p, g: p - g * 0.005, params, grad)
    if i % 3 == 0:
        print(f'Iter {i}: loss = ', loss)
Iter 0: loss = 0.50568265
Iter 3: loss = 0.07405198
Iter 6: loss = 0.061948203
Iter 9: loss = 0.06161511
Writing our gradient descent step by hand is nice for understanding, but Flax also works with the optax package, which gives us a set of common optimizers to use. To achieve the same behavior as above, we can use optax.sgd:
import optax
tx = optax.sgd(learning_rate=0.005)
params = net.init(k, x)
opt_state = tx.init(params) # This is empty for SGD, but some optimizers need it.
for i in range(10):
    loss, grad = loss_and_grad(params, x)
    updates, opt_state = tx.update(grad, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 3 == 0:
        print(f'Iter {i}: loss = ', loss)
Iter 0: loss = 0.50568265
Iter 3: loss = 0.07405198
Iter 6: loss = 0.061948203
Iter 9: loss = 0.06161511
If we wanted to use Adam instead, we can just change optax.sgd to optax.adam. Optax also has a clean implementation of AdamW if we want weight decay.
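For example (a sketch; the learning rate and weight decay values here are arbitrary choices):
tx = optax.adam(learning_rate=1e-3)
# or, with decoupled weight decay:
# tx = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
opt_state = tx.init(params)  # Adam actually uses this state (its moment estimates).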
Train State#
A common paradigm when using Flax is to define a training state that contains everything relevant to our model: the Flax module, its parameters, the optimizer state, and the current training step. It can also be convenient to keep the RNG state here. If the run ever crashes, as long as we can recover the training state we have everything we need to keep going.
Generally, we will JIT an update function that returns a new TrainState object each time.
from typing import Any

import flax

class TrainState(flax.struct.PyTreeNode):
    params: Any
    opt_state: Any
    step: int

@jax.jit
def update(train_state, x):
    loss, grad = loss_and_grad(train_state.params, x)
    updates, opt_state = tx.update(grad, train_state.opt_state)
    params = optax.apply_updates(train_state.params, updates)
    return train_state.replace(params=params, opt_state=opt_state,
                               step=train_state.step + 1)
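Putting it together (a sketch reusing the net, tx, and loss_and_grad defined above; the variable names are our own):
params = net.init(k, x)
train_state = TrainState(params=params, opt_state=tx.init(params), step=0)

for i in range(10):
    train_state = update(train_state, x)  # each call returns a fresh TrainState

print('Finished at step', int(train_state.step))  # 10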