Backpropagation
Backpropagation is a technique for recursively calculating the gradients of neural networks.
The essential recipe for training neural networks is gradient descent. The gradient of a neural network is the set of first-order derivatives of the loss with respect to each of its parameters. If we iteratively update our parameters \(\theta\) by their negative gradient, scaled by a step size \(\alpha\):
\[\theta \leftarrow \theta - \alpha \nabla_\theta L(\theta)\]
then with a small enough step size, our loss should decrease. So, how do we calculate the gradient?
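To see the update rule in action, here is a minimal sketch in plain Python on a toy one-parameter loss \(L(\theta) = (\theta - 3)^2\), chosen only for illustration; the function names are hypothetical.

```python
# Toy loss L(theta) = (theta - 3)^2 and its gradient dL/dtheta = 2 * (theta - 3).
def loss(theta):
    return (theta - 3.0) ** 2

def grad(theta):
    return 2.0 * (theta - 3.0)

def gradient_descent(theta, step_size, num_steps):
    """Repeatedly apply theta <- theta - step_size * grad(theta); record the loss."""
    history = [loss(theta)]
    for _ in range(num_steps):
        theta = theta - step_size * grad(theta)  # step against the gradient
        history.append(loss(theta))
    return theta, history

theta, history = gradient_descent(theta=0.0, step_size=0.1, num_steps=50)
print(round(theta, 4))  # 3.0
```

With step_size=1.1 the same loop overshoots the minimum by more on every step and the loss grows, which is why the step size has to be small enough.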
Chain Rule
Remember that neural networks are just sequences of functions that feed into one another. The chain rule tells us how to calculate the gradient of a composition of functions. Given a composition \(f(g(x))\):
\[\frac{d}{dx} f(g(x)) = f'(g(x))\, g'(x)\]
That is, we can take the derivative of each sub-function, evaluate it at the point where that sub-function is applied, and multiply the results together to get the derivative of the entire composition.
The chain rule still holds when \(f\) and \(g\) are vector- or matrix-valued functions; the derivatives become Jacobians and the product becomes a matrix product. This means we can recursively compute the gradients of complex neural networks by breaking them down into simple components whose gradients we know.
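The rule is easy to check numerically. The sketch below, in plain Python, takes \(g(x) = x^2\) and \(f(u) = \sin u\) as illustrative choices and compares the chain-rule derivative \(\cos(x^2) \cdot 2x\) with a central finite-difference estimate.

```python
import math

def g(x):
    return x ** 2

def f(u):
    return math.sin(u)

def chain_rule_derivative(x):
    # f'(g(x)) * g'(x) = cos(x^2) * 2x
    return math.cos(g(x)) * 2.0 * x

def finite_difference(fn, x, eps=1e-6):
    # Central-difference approximation of d fn / dx.
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

x = 1.3
analytic = chain_rule_derivative(x)
numeric = finite_difference(lambda t: f(g(t)), x)
print(analytic, numeric)
```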
Example: Gradient of a Two-Layer MLP
Let’s go through a simple two-layer neural network as an example and derive its gradient by hand. Assume we have some inputs x and outputs y, and we would like to minimize the mean squared error. Our network would look like:
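import jax
import jax.numpy as jnp
import numpy as np
# Parameters: 16 inputs -> 128 hidden units (ReLU) -> 1 output
W1 = jnp.array(np.random.randn(128, 16) / np.sqrt(16))
b1 = jnp.array(np.random.randn(128) * 0.01)
W2 = jnp.array(np.random.randn(1, 128) / np.sqrt(128))
# Data: 10 examples, stored as columns
x = jnp.array(np.random.randn(16, 10))
y = jnp.array(np.random.randn(1, 10))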
# Loss calculation
h1 = W1 @ x
h2 = h1 + b1[:, None]
h3 = jnp.clip(h2, 0, None) # ReLU
h4 = W2 @ h3
loss = jnp.mean((h4 - y)**2)
We want to compute the gradient of the loss with respect to the parameters W1, b1, and W2. We can do this recursively, starting from the loss and working backwards through the network. We start with the gradient of the loss with respect to the network output h4.
d_output = 2 * (h4 - y) # Gradient of x^2 = 2x
d_h4 = d_output / y.size # Gradient of jnp.mean()
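Written out, with \(N\) the number of entries in y (y.size, here 10), the two lines above are the derivative of the mean squared error with respect to each entry of h4:

\[
L = \frac{1}{N} \sum_{n} \left(h_4[n] - y[n]\right)^2, \qquad
\frac{\partial L}{\partial h_4[n]} = \frac{2}{N}\left(h_4[n] - y[n]\right)
\]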
Now that we know the gradient of the loss with respect to h4, we can keep working backwards. As an example, let’s calculate the gradient with respect to h3. We know that h4 = W2 @ h3, so via the chain rule:
d_h3 = W2.T @ d_h4
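Where the transposes come from can be seen entry by entry. Since \(h_4[i,n] = \sum_j W_2[i,j]\, h_3[j,n]\), the chain rule sums over every entry of h4 that each variable affects:

\[
\frac{\partial L}{\partial h_3[j,n]} = \sum_i W_2[i,j]\, \frac{\partial L}{\partial h_4[i,n]}
\;\Rightarrow\;
\frac{\partial L}{\partial h_3} = W_2^\top \frac{\partial L}{\partial h_4}
\]
\[
\frac{\partial L}{\partial W_2[i,j]} = \sum_n \frac{\partial L}{\partial h_4[i,n]}\, h_3[j,n]
\;\Rightarrow\;
\frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial h_4}\, h_3^\top
\]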
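The same recursive logic gives the gradients with respect to h2 (the ReLU passes the gradient through only where its input was positive) and h1 (adding the bias leaves the gradient unchanged):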
d_h2 = d_h3 * (h2 > 0)
d_h1 = d_h2
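In the same notation, with ReLU defined as \(h_3 = \max(h_2, 0)\) element-wise (its derivative at exactly zero taken to be 0, matching the h2 > 0 mask), \(\odot\) denoting element-wise multiplication, and b1 added to every column of h1:

\[
\frac{\partial L}{\partial h_2} = \frac{\partial L}{\partial h_3} \odot \mathbb{1}[h_2 > 0], \qquad
\frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_2}, \qquad
\frac{\partial L}{\partial b_1[i]} = \sum_n \frac{\partial L}{\partial h_2[i,n]}
\]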
Finally, we can calculate the gradients for the actual parameters in terms of the h gradients. Each of these is a simple one-step operation.
d_W2 = d_h4 @ h3.T
d_b1 = jnp.sum(d_h2, axis=1)  # b1 is broadcast across all columns, so its gradient sums over them
d_W1 = d_h1 @ x.T
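An independent way to check hand-derived gradients is a finite-difference gradient check. The NumPy sketch below is an illustration under assumptions: it uses smaller, hypothetical shapes (4 inputs, 8 hidden units, 5 examples) and float64 so the check runs quickly and precisely. It perturbs each parameter entry by ±eps and compares the resulting central difference with the backpropagated gradient.

```python
import numpy as np

def forward(W1, b1, W2, x, y):
    h1 = W1 @ x
    h2 = h1 + b1[:, None]
    h3 = np.maximum(h2, 0.0)  # ReLU
    h4 = W2 @ h3
    return np.mean((h4 - y) ** 2), (h2, h3, h4)

def backward(W1, b1, W2, x, y):
    _, (h2, h3, h4) = forward(W1, b1, W2, x, y)
    d_h4 = 2.0 * (h4 - y) / y.size
    d_h3 = W2.T @ d_h4
    d_h2 = d_h3 * (h2 > 0)
    d_h1 = d_h2
    d_W2 = d_h4 @ h3.T
    d_b1 = np.sum(d_h2, axis=1)  # b1 was broadcast over columns, so sum
    d_W1 = d_h1 @ x.T
    return d_W1, d_b1, d_W2

def numerical_grad(fn, param, eps=1e-6):
    # Central differences, one entry at a time; param is perturbed in place and restored.
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        original = param[idx]
        param[idx] = original + eps
        plus = fn()
        param[idx] = original - eps
        minus = fn()
        param[idx] = original
        grad[idx] = (plus - minus) / (2.0 * eps)
    return grad

rng = np.random.default_rng(0)
W1 = rng.standard_normal((8, 4)) / np.sqrt(4)
b1 = rng.standard_normal(8) * 0.01
W2 = rng.standard_normal((1, 8)) / np.sqrt(8)
x = rng.standard_normal((4, 5))
y = rng.standard_normal((1, 5))

loss_fn = lambda: forward(W1, b1, W2, x, y)[0]
analytic = backward(W1, b1, W2, x, y)
numeric = [numerical_grad(loss_fn, p) for p in (W1, b1, W2)]
for name, a, n in zip(["W1", "b1", "W2"], analytic, numeric):
    print(name, np.max(np.abs(a - n)))  # maximum disagreement per parameter
```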
Let’s sanity-check our derivation by comparing it against jax.grad, JAX’s automatic differentiation engine.
def manual_grad(params, data_input, data_output):
    W1, b1, W2 = params
    # Forward pass, keeping the intermediate values
    h1 = W1 @ data_input
    h2 = h1 + b1[:, None]
    h3 = h2 * (h2 > 0)  # ReLU
    h4 = W2 @ h3
    loss = jnp.mean(jnp.square(h4 - data_output))
    # Backward pass
    d_h4 = 2 * (h4 - data_output) / data_output.shape[1]
    d_h3 = W2.T @ d_h4
    d_h2 = d_h3 * (h2 > 0)
    d_h1 = d_h2
    d_W2 = d_h4 @ h3.T
    d_b1 = jnp.sum(d_h2, axis=1)  # sum, not mean: b1 is shared by every column
    d_W1 = d_h1 @ data_input.T
    return [d_W1, d_b1, d_W2], loss
manual_grad_output = manual_grad([W1, b1, W2], x, y)
# JAX automatic gradient.
def loss(params, data_input, data_output):
    W1, b1, W2 = params
    h1 = W1 @ data_input
    h2 = h1 + b1[:, None]
    h3 = h2 * (h2 > 0)
    h4 = W2 @ h3
    return jnp.mean(jnp.square(h4 - data_output))
jax_grad = jax.grad(loss)
jax_grad_output = jax_grad([W1, b1, W2], x, y)
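# Compare all three parameter gradients (W1, b1, W2), not just W1
print('Grads match?', all(jnp.allclose(g_jax, g_manual) for g_jax, g_manual in zip(jax_grad_output, manual_grad_output[0])))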
Grads match? True
Forward and Backward
By now, you should see a structure in the way our gradient is computed. First, during the forward pass, we run through the neural network as normal, keeping the intermediate features in memory because the backward pass reuses them (for example, the ReLU mask needs h2, and d_W2 needs h3). Next, during the backward pass, we start from the loss function and move in reverse, computing the gradient for each intermediate feature and each parameter. This is the backpropagation algorithm.
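The same forward/backward structure is often packaged as layers that each cache what they need during forward and consume it during backward. This is a minimal NumPy sketch; the Linear and ReLU classes and the forward_backward function are hypothetical illustrations (no bias term, for brevity), not any library's API.

```python
import numpy as np

class Linear:
    """out = W @ x, with no bias (hypothetical minimal layer)."""
    def __init__(self, W):
        self.W = W
    def forward(self, x):
        self.x = x  # cached for the backward pass
        return self.W @ x
    def backward(self, d_out):
        self.d_W = d_out @ self.x.T  # gradient of the loss w.r.t. W
        return self.W.T @ d_out      # gradient w.r.t. the input, passed to the previous layer

class ReLU:
    def forward(self, x):
        self.mask = x > 0  # cached for the backward pass
        return x * self.mask
    def backward(self, d_out):
        return d_out * self.mask

def forward_backward(layers, x, y):
    # Forward pass: run the layers in order; each caches what it needs.
    out = x
    for layer in layers:
        out = layer.forward(out)
    loss = np.mean((out - y) ** 2)
    # Backward pass: start from the loss and visit the layers in reverse.
    d = 2.0 * (out - y) / y.size
    for layer in reversed(layers):
        d = layer.backward(d)
    return loss

rng = np.random.default_rng(0)
W1 = rng.standard_normal((8, 4)) / np.sqrt(4)
W2 = rng.standard_normal((1, 8)) / np.sqrt(8)
x = rng.standard_normal((4, 5))
y = rng.standard_normal((1, 5))
layers = [Linear(W1), ReLU(), Linear(W2)]
loss = forward_backward(layers, x, y)
d_W1, d_W2 = layers[0].d_W, layers[2].d_W
```

Each layer only needs to know its own local gradient; the loop in forward_backward supplies the chain rule.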
Automatic Differentiation
Machine learning libraries today generally implement automatic differentiation, which saves us from writing gradient functions by hand. JAX provides jax.grad, which we used above to sanity-check our hand-written gradient.
Under the hood, reverse-mode automatic differentiation (the mode jax.grad uses) follows the same procedure we did above: run a forward pass, store the intermediate values, then run a backward pass to compute the gradients.
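To make the idea concrete, here is a toy scalar reverse-mode autodiff in plain Python. It is a hypothetical illustration, not how JAX is implemented (JAX transforms traced array programs rather than Python scalar objects): each Value records its parents and the local derivative with respect to each, and backward() visits the recorded graph in reverse, applying the chain rule.

```python
class Value:
    """A scalar that records how it was computed (hypothetical, for illustration)."""
    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents          # Values this one was computed from
        self.local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def relu(self):
        return Value(max(self.data, 0.0), (self,), (1.0 if self.data > 0 else 0.0,))

    def backward(self):
        # Order the recorded graph so every node comes after its parents (forward order).
        order, seen = [], set()
        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v.parents:
                    visit(p)
                order.append(v)
        visit(self)
        # Backward pass: walk that order in reverse, accumulating gradients.
        self.grad = 1.0
        for v in reversed(order):
            for p, local in zip(v.parents, v.local_grads):
                p.grad += local * v.grad  # chain rule, summed over every use of p

w, x, b = Value(2.0), Value(3.0), Value(-1.0)
out = (w * x + b).relu()  # relu(2 * 3 - 1) = 5
out.backward()
print(w.grad, x.grad, b.grad)  # 3.0 2.0 1.0
```

Because gradients are accumulated with +=, a value used twice (for example x * x) correctly receives the sum of both contributions.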
Do we need to store intermediate values?
Storing intermediate values lets us avoid unnecessary computation. It’s possible to calculate the gradient of each parameter from scratch, but that would repeat much of the same work for every parameter. If we don’t have enough GPU memory to store every intermediate value, we can instead recompute certain portions during the backward pass. This trades extra computation for lower memory usage.
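A minimal NumPy sketch of this recomputation idea, often called gradient checkpointing or rematerialization, is below. It assumes a hypothetical stack of bias-free ReLU layers: the forward pass keeps only the input to every checkpoint_every-th layer, and the backward pass recomputes the activations inside each segment right before it needs them. JAX offers this transformation as jax.checkpoint; the hand-written version here only shows the mechanics.

```python
import numpy as np

def layer(W, h):
    return np.maximum(W @ h, 0.0)  # one bias-free ReLU layer

def layer_backward(W, h_in, d_out):
    # Recompute the pre-activation from the layer input, then backprop through ReLU and W.
    pre = W @ h_in
    d_pre = d_out * (pre > 0)
    return W.T @ d_pre, d_pre @ h_in.T  # (grad w.r.t. h_in, grad w.r.t. W)

def loss_and_grads(Ws, x, y, checkpoint_every):
    # Forward pass: store only the input of every `checkpoint_every`-th layer.
    saved = {}
    h = x
    for k, W in enumerate(Ws):
        if k % checkpoint_every == 0:
            saved[k] = h
        h = layer(W, h)
    loss = np.mean((h - y) ** 2)

    d = 2.0 * (h - y) / y.size  # gradient of the loss w.r.t. the final output
    grads = [None] * len(Ws)
    # Backward pass: handle segments last-to-first, recomputing each segment's inputs.
    for start in reversed(range(0, len(Ws), checkpoint_every)):
        end = min(start + checkpoint_every, len(Ws))
        inputs = [saved[start]]
        for k in range(start, end - 1):
            inputs.append(layer(Ws[k], inputs[-1]))  # recomputed, not stored earlier
        for k in reversed(range(start, end)):
            d, grads[k] = layer_backward(Ws[k], inputs[k - start], d)
    return loss, grads

rng = np.random.default_rng(0)
Ws = [rng.standard_normal((6, 6)) / np.sqrt(6) for _ in range(5)]
x = rng.standard_normal((6, 4))
y = rng.standard_normal((6, 4))
loss_full, grads_full = loss_and_grads(Ws, x, y, checkpoint_every=1)  # store every input
loss_ckpt, grads_ckpt = loss_and_grads(Ws, x, y, checkpoint_every=2)  # store every 2nd
print(all(np.allclose(a, b) for a, b in zip(grads_full, grads_ckpt)))  # True
```

With checkpoint_every=2, roughly half of the layer inputs are kept during the forward pass, and the rest are rebuilt one segment at a time during the backward pass.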
Does backpropagation work for any network?
Backpropagation is a recursive algorithm: it relies on modules that are either composed of backprop-able sub-modules or have a manually written forward and backward function. Most mathematical operations (e.g. matrix multiplication, element-wise operations) have an analytical gradient we can use. If we want to use a forward function without a useful gradient, such as rounding (whose derivative is zero almost everywhere and undefined at the jumps), we can’t naively train through it with backpropagation. In some cases, we define a surrogate backward pass that approximates the gradient we want to use for training, as done in the straight-through estimator.
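A minimal NumPy sketch of the straight-through estimator: the module rounds in its forward pass, but its hand-written backward pass passes the incoming gradient through unchanged, as if rounding were the identity. The RoundSTE class and the toy problem of fitting a scalar w so that round(w * x) matches y are hypothetical illustrations. (In JAX, a common way to get the same behavior is x + jax.lax.stop_gradient(jnp.round(x) - x), whose value is round(x) and whose gradient is 1.)

```python
import numpy as np

class RoundSTE:
    """Rounding with a straight-through surrogate gradient (hypothetical minimal module)."""
    def forward(self, x):
        return np.round(x)  # the true derivative is 0 almost everywhere
    def backward(self, d_out):
        return d_out        # surrogate gradient: treat rounding as the identity

x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])
w = 1.0
rounder = RoundSTE()
for step in range(10):
    y_hat = rounder.forward(w * x)
    loss = np.mean((y_hat - y) ** 2)
    d_y_hat = 2.0 * (y_hat - y) / y.size
    d_wx = rounder.backward(d_y_hat)
    d_w = np.sum(d_wx * x)  # chain rule through the product w * x
    w -= 0.05 * d_w         # gradient descent step
print(w, loss)
```

With the true gradient of rounding (zero), d_w would be 0 at every step and w would never move; the surrogate gradient is what lets the loss reach zero in this toy problem.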