Modern Optimizers

This is a draft

The backbone of modern deep learning is gradient descent. We all know the pain of waiting for a model to train, so it is no surprise that a classic rite of passage for researchers is to think about ways to improve optimization. The current champion is Adam. However, a growing family of work claims to outperform Adam at the Pareto frontier of compute. In this post, we will explore the flavors of such optimizers, which we will refer to as whitening methods. Do these methods reliably outperform Adam? If so, what are the pros and cons of the various flavors?

Gradient Descent on Non-Euclidean Metrics

When we calculate a gradient, we get a direction in which to adjust the model parameters to reduce the loss. But this gradient is only accurate in a local neighborhood, so we typically take a small step in that direction, then recalculate before moving again. This notion can be formalized by framing each step of gradient descent as solving the following distance-penalized problem:

\[ u =\text{argmin}_{\Delta\theta} \; \underbrace{\; g^T\Delta\theta \;}_{\text{Improvement}} + \underbrace{\frac{1}{2\alpha}||\Delta\theta||^2}_{\text{Distance Penalty}} \quad = \quad -\alpha g \]

where \(g = \nabla_\theta L(\theta,x)\). Traditionally, we assume a Euclidean distance over parameters, in which case the solution (as shown above) is simply the negative gradient scaled by a constant learning rate \(\alpha\).

However, the Euclidean distance is an assumption, and it is often suboptimal. Certain parameters may be more sensitive than others, meaning the loss curves more sharply along them, and should thus be assigned a larger penalty. We can generally represent such second-order distances using a Riemannian metric \(M\) (a symmetric positive-definite matrix), under which the squared distance of an update is the quadratic form:

\[ ||\Delta\theta||^2_M = \Delta\theta^T M \Delta\theta. \]

When we use \(M\) as the distance metric, the solution then becomes:

\[ u = \text{argmin}_{\Delta\theta} \; \underbrace{\; g^T\Delta\theta \;}_{\text{Improvement}} + \underbrace{(1/2)\Delta\theta^TM\Delta\theta}_{\text{Distance Penalty}} \quad = \quad -M^{-1}g. \]
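As a quick sanity check of this closed form, the NumPy sketch below uses a random, purely illustrative metric \(M\) and gradient \(g\) (not taken from any model). It computes the minimizer with a linear solve rather than an explicit inverse, then confirms that random perturbations only increase the penalized objective.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
# Illustrative symmetric positive-definite metric M and gradient g.
A = rng.normal(size=(n, n))
M = A @ A.T + n * np.eye(n)
g = rng.normal(size=n)

def penalized_objective(d):
    # Linear improvement term plus the quadratic distance penalty under M.
    return g @ d + 0.5 * d @ M @ d

# Setting the objective's gradient g + M d to zero gives d = -M^{-1} g.
u = -np.linalg.solve(M, g)

# Because M is positive definite, every perturbation raises the objective.
perturbed = [penalized_objective(u + 1e-2 * rng.normal(size=n)) for _ in range(100)]
print(penalized_objective(u) <= min(perturbed))
```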

Why does it matter to use a good distance metric? The global learning rate of a neural network is typically bounded by its most sensitive parameter. Intuitively, if any parameter is updated with too high a learning rate, the system can oscillate or, at worst, diverge. A good distance metric should even out these sensitivities, so that each parameter is updated at close to its optimal rate. For more intuition, see the Optimization page, or this beautiful blog post on momentum.
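A two-parameter quadratic makes this concrete. The NumPy sketch below is a toy assumption in which the Hessian \(H\) itself serves as the metric \(M\), the ideal choice for a quadratic. Plain gradient descent must keep its learning rate under \(2/100\) because of the stiff coordinate, which leaves the flat coordinate barely trained, while the preconditioned update can take a unit step in every direction.

```python
import numpy as np

# Quadratic loss L(theta) = 0.5 * theta^T H theta with one stiff and one flat direction.
H = np.diag([100.0, 1.0])
theta0 = np.array([1.0, 1.0])

def run(precond, lr, steps=100):
    # Gradient descent with update theta <- theta - lr * precond @ grad.
    theta = theta0.copy()
    for _ in range(steps):
        grad = H @ theta
        theta = theta - lr * precond @ grad
    return 0.5 * theta @ H @ theta

# Euclidean metric: lr must stay below 2/100, so the flat direction crawls.
loss_euclid = run(np.eye(2), lr=0.019)
# Metric M = H: the update -M^{-1} g evens out sensitivities, and lr = 1 is safe everywhere.
loss_precond = run(np.linalg.inv(H), lr=1.0)
# Slightly exceeding the stiff direction's limit makes plain gradient descent diverge.
loss_diverged = run(np.eye(2), lr=0.021)
print(loss_euclid, loss_precond, loss_diverged)
```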

The Whitening Metric

What metric should we use? Of course, there is no single answer. But modern optimizers have converged on one metric in particular, which we will refer to here as the “whitening” metric, following [Yang 2008]. Let’s start with the definition:

\[ W = \mathbb{E}_x \left[ \nabla_\theta L(\theta,x) \nabla_\theta L(\theta,x)^T \right]^{1/2} = \mathbb{E}_x \left[ gg^T \right]^{1/2}. \]
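The name comes from what the inverse of this metric does to gradients. The NumPy sketch below uses synthetic correlated gradients as a stand-in for real per-example gradients (an assumption for illustration). It estimates \(\mathbb{E}[gg^T]\), forms \(W\) through a symmetric eigendecomposition, and checks that preconditioning with \(W^{-1}\) turns the gradients' second-moment matrix into the identity, i.e. it whitens them.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 10_000
# Synthetic per-example gradients with strongly correlated coordinates.
A = rng.normal(size=(d, d))
grads = rng.normal(size=(n, d)) @ A.T          # rows g_i with covariance close to A A^T

C = grads.T @ grads / n                        # E[g g^T] estimated from the samples
eigvals, eigvecs = np.linalg.eigh(C)
W = eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T          # whitening metric C^{1/2}
W_inv = eigvecs @ np.diag(1 / np.sqrt(eigvals)) @ eigvecs.T  # its inverse C^{-1/2}

# Preconditioning every gradient with W^{-1} yields an identity second-moment matrix.
whitened = grads @ W_inv.T
C_whitened = whitened.T @ whitened / n
print(np.round(C_whitened, 6))
```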

Should we introduce the unifying concept as the whitening metric, or specifically as the Kronecker-factored approximation of the whitening metric? The latter relates more directly to spectral descent methods, and a better name could be the “spectral-whitening” metric, which sets it apart from elementwise whitening metrics such as Adam.

Let’s look at three ways to interpret this object.

Relation to Newton’s Method

The classical second-order optimization technique is Newton’s method, which performs steepest descent using the Hessian as the distance metric. The Hessian is a matrix of second-order derivatives:

\[ H = E_{x\sim D} \left[ \nabla_\theta^2 L(\theta,x) \right]. \]

Newton’s method can also be interpreted as using a quadratic approximation of the loss function rather than a linear one. We can estimate the penalty for taking large steps entirely from the second-order effects of the step on the loss.

However, Newton’s method has practical problems when training deep neural networks. First, calculating the Hessian itself is expensive, and second, the Hessian can have negative eigenvalues and thus assign “negative” distance to certain directions. Both issues can be partially addressed by the Gauss-Newton approximation to the Hessian. For a squared-error loss \(L = \frac{1}{2}(f(\theta,x) - y)^2\), the Hessian splits into two terms, and Gauss-Newton keeps only the one built from first derivatives of the network output \(f\):

\[ \nabla_\theta^2 L(\theta,x) = \underbrace{\nabla_\theta f(\theta,x) \nabla_\theta f(\theta,x)^T}_{\text{Gauss-Newton term}} + \underbrace{(f(\theta,x) - y) \nabla^2_\theta f(\theta,x)}_{\text{Dropped second-order term}}. \]

Unlike the Hessian, the Gauss-Newton matrix is always positive semi-definite, and it is easy to compute as the outer product of the output gradient with itself:

\[ G = E_{x\sim D} \left[ \nabla_\theta f(\theta,x) \nabla_\theta f(\theta,x)^T \right]. \]

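The whitening metric can now be understood as the square root of the Gauss-Newton matrix. More precisely, the gradient of the squared-error loss is \(g = (f(\theta,x) - y)\nabla_\theta f(\theta,x)\), so \(\mathbb{E}[gg^T]\) is a residual-weighted Gauss-Newton matrix, and the whitening metric is its square root.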
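The decomposition can be checked numerically on a toy model. The sketch below assumes a hypothetical single tanh unit \(f(\theta,x) = \tanh(\theta^Tx)\) with a squared-error loss. It compares the two-term Hessian against central finite differences of the gradient, and shows that the Gauss-Newton term on its own is positive semi-definite.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
theta = rng.normal(size=d)
x = rng.normal(size=d)
y = 2.0  # a target far from any tanh output, so the residual term is not negligible

def f(theta):
    # Hypothetical tiny "network": a single tanh unit.
    return np.tanh(theta @ x)

def grad_loss(theta):
    # Gradient of L = 0.5 * (f - y)^2, i.e. (f - y) * grad f with grad f = (1 - f^2) x.
    out = f(theta)
    return (out - y) * (1 - out**2) * x

out = f(theta)
jac = (1 - out**2) * x                                   # grad_theta f
gauss_newton = np.outer(jac, jac)                        # kept term
hess_f = -2 * out * (1 - out**2) * np.outer(x, x)        # grad^2_theta f for tanh
hessian = gauss_newton + (out - y) * hess_f              # kept term + dropped term

# Central finite differences of the analytic gradient, one coordinate at a time.
eps = 1e-5
fd_hessian = np.stack([(grad_loss(theta + eps * e) - grad_loss(theta - eps * e)) / (2 * eps)
                       for e in np.eye(d)])
print(np.linalg.eigvalsh(gauss_newton), np.linalg.eigvalsh(hessian))
```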

What motivates the square root? This is an excellent question, and one I don’t have a great answer for. One viewpoint is that the square root is a conservative estimate of \(G\), as \(G^{1/2}\) is the (geometric) halfway point between \(I\) (Euclidean distance) and \(G\). In [Xi-Lin Li 2018], it is shown that the square root of \(G\) is the optimal preconditioner when accounting for noisy gradient estimates. [Yang 2008] argues that using the square root of \(G\) better approximates geodesic flows towards the solution.

Relation to Natural Gradient Descent

Another way to motivate the whitening metric is through natural gradient descent, which aims to perform descent over the manifold of probability distributions. Natural gradient descent uses the Fisher information matrix as a distance metric, which is defined as:

\[ F = E_{x\sim D, y \sim p_\theta(\cdot|x)} \left[ \nabla_\theta \log p_\theta(y|x) \nabla_\theta \log p_\theta(y|x)^T \right]. \]

The Fisher does not take the loss function into account. It is only affected by the shape of the probability distribution itself, as defined by the current neural network. This is why natural gradient descent is often motivated as being parameterization-invariant, and (at small enough steps) follows the same optimization trajectory regardless of the underlying parameter structure.

Note the particular expectation – \(y\) must be sampled from the current distribution, not from the dataset labels. This means the true Fisher cannot be calculated simply from gradients over samples from the dataset. When we instead use dataset labels, we end up with the empirical Fisher:

\[ \hat{F} = E_{x,y\sim D} \left[ \nabla_\theta \log p_\theta(y|x) \nabla_\theta \log p_\theta(y|x)^T \right]. \]

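In most deep learning objectives, our loss is the negative log-likelihood \(-\log p_\theta(y|x)\); remember that mean-squared error is, up to constants, the negative log-probability of a Gaussian distribution. The per-example gradient is then \(g = -\nabla_\theta \log p_\theta(y|x)\), so the empirical Fisher is exactly \(\mathbb{E}[gg^T]\), and the whitening metric is its square root. For a Gaussian likelihood with unit variance, the true Fisher also equals the Gauss-Newton matrix, which ties the Fisher, the Gauss-Newton approximation, and the whitening metric together.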
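To see the difference between the true and empirical Fisher, consider a hypothetical linear-Gaussian model \(p_\theta(y|x) = \mathcal{N}(\theta^Tx, 1)\). In the NumPy sketch below, sampling labels from the model recovers the Gauss-Newton matrix \(\mathbb{E}[xx^T]\), while using labels generated by different, synthetic parameters (`theta_data`, an assumption for illustration) gives a larger empirical Fisher, because it is weighted by the squared residuals of the dataset labels.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 2000, 3, 50
X = rng.normal(size=(n, d))
theta = rng.normal(size=d)            # current model parameters
theta_data = theta + 1.0              # hypothetical parameters that generated the dataset labels

# Model y ~ N(theta^T x, 1); its negative log-likelihood is 0.5 * (y - theta^T x)^2 + const.
pred = X @ theta

def per_example_grads(y):
    # Gradient of the negative log-likelihood, -(y - theta^T x) x, one row per example.
    return -(y - pred)[:, None] * X

# True Fisher: labels sampled from the model itself (k draws per input).
F_true = np.zeros((d, d))
for _ in range(k):
    g = per_example_grads(pred + rng.normal(size=n))
    F_true += g.T @ g / (n * k)

# Empirical Fisher: dataset labels, i.e. exactly E[g g^T] of the training gradients.
y_data = X @ theta_data + rng.normal(size=n)
g_data = per_example_grads(y_data)
F_emp = g_data.T @ g_data / n

# For a unit-variance Gaussian likelihood, the Gauss-Newton matrix is E[x x^T].
G = X.T @ X / n
print(np.trace(F_true), np.trace(G), np.trace(F_emp))
```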

Relation to Spectral Norm Descent

Finally, examining the whitening metric under a particular parameterization reveals a third relation. First, let’s consider a single dense layer with \(\theta \in R^{m,n}\). A common trick to avoid inverting a giant \(gg^T \in R^{nm,nm}\) matrix is to factorize this matrix into its Kronecker factors:

\[ gg^T \quad \approx \quad (G^TG)^{1/2} \otimes (GG^T)^{1/2} \]

where \(g \in R^{nm}\) is the flattened gradient (stacking the columns of \(G\)) and \(G \in R^{m,n}\) is the gradient in matrix form.

Kronecker factorization is used in many modern optimizers, and brings two key benefits. First, we can invert (or take matrix powers of) each factor independently, since \((A \otimes B)^{-1} = A^{-1} \otimes B^{-1}\), which is far cheaper than inverting the full product. Second, we never need to form the Kronecker product explicitly: instead, we can multiply the gradient by each factor in sequence. These techniques underlie the KFAC [Martens 2015] and Shampoo [Gupta 2018] optimizer families, and Shampoo uses them to form the following update:

\[ E[gg^T]^{-1/2}g \quad \approx \quad E[GG^T]^{-1/4} \; G \; E[G^TG]^{-1/4}. \]
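The two Kronecker properties above are easy to check with small matrices. The NumPy sketch below uses random SPD factors as stand-ins for \(E[GG^T]\) and \(E[G^TG]\), and assumes column-major flattening (with row-major flattening, the order of the factors in the Kronecker product swaps). It also counts how many entries the full matrix would need for a 4096×4096 layer, compared to the two factors.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4
G = rng.normal(size=(m, n))
L = rng.normal(size=(m, m))
L = L @ L.T + np.eye(m)   # left factor (m x m, SPD), stand-in for E[G G^T]
R = rng.normal(size=(n, n))
R = R @ R.T + np.eye(n)   # right factor (n x n, SPD), stand-in for E[G^T G]

def vec(A):
    # Column-major flattening, for which (R kron L) vec(G) = vec(L G R) when R is symmetric.
    return A.reshape(-1, order="F")

# Applying the full (nm x nm) Kronecker matrix equals two small matrix products.
full = np.kron(R, L) @ vec(G)
factored = vec(L @ G @ R)

# The inverse of a Kronecker product is the Kronecker product of the inverses.
inv_full = np.linalg.inv(np.kron(R, L))
inv_factored = np.kron(np.linalg.inv(R), np.linalg.inv(L))

# Storage for a 4096 x 4096 layer: the full matrix vs. the two factors.
full_entries = (4096 * 4096) ** 2
factor_entries = 4096**2 + 4096**2
print(full_entries, factor_entries)
```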

When we remove the expectations, a new relation appears [Bernstein 2024]. Let’s write the matrix \(G\) in terms of its singular-value decomposition, \(G = U \Sigma V^T\):

\[\begin{split} \begin{align} u & = (GG^T)^{-1/4} \; G \; (G^TG)^{-1/4} \\ & = (U \Sigma^2 U^T)^{-1/4} \; U \Sigma V^T \; (V \Sigma^2 V^T)^{-1/4} \\ & = U \Sigma^{-1/2} U^T \; U \Sigma V^T \; V \Sigma^{-1/2} V^T \\ & = U \Sigma^{-1/2} \Sigma \Sigma^{-1/2} V^T \\ & = UV^T \end{align} \end{split}\]

The resulting update is the projection of the gradient onto the closest orthogonal matrix. This is also the solution to steepest descent under the spectral norm, which measures the size of a matrix by its largest singular value. Intuitively, the resulting \(UV^T\) matrix has all of its singular values equal to 1, and thus updates maximally in every orthogonal direction.
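The derivation can be verified numerically. The sketch below assumes a square, full-rank gradient so that every inverse power exists (a rank-deficient \(G\) would need pseudo-inverse powers). It evaluates the expectation-free update with symmetric eigendecompositions and compares it against \(UV^T\) from the SVD.

```python
import numpy as np

rng = np.random.default_rng(0)
m = 4
G = rng.normal(size=(m, m))   # square and (almost surely) full rank

def sym_power(S, p):
    # Matrix power of a symmetric positive-definite matrix via its eigendecomposition.
    w, V = np.linalg.eigh(S)
    return V @ np.diag(w**p) @ V.T

# Expectation-free update: (G G^T)^{-1/4} G (G^T G)^{-1/4}.
u = sym_power(G @ G.T, -0.25) @ G @ sym_power(G.T @ G, -0.25)

U, S, Vt = np.linalg.svd(G)
orth = U @ Vt                 # the closest orthogonal matrix to G
print(np.linalg.svd(u, compute_uv=False))
```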

Why do we want to descend on the spectral norm? This is another great question, which again I don’t have a great answer for. The spectral norm bounds the maximum possible change in a layer’s output per unit of input norm. In [Yang 2023], it is argued that in neural networks, inputs tend to align with the top singular directions of a layer’s weight updates, so this bound is empirically tight.

Flavors of Optimizers

Now, let’s get into the various practical implementations of optimizers that descend on an approximate whitening metric. In all methods, the raw gradients are calculated via backpropagation over a batch of input/output pairs. All methods also treat each dense layer independently, which can be seen as using a block-diagonal approximation of the true whitening metric. We’ll describe the optimizer code in terms of an incoming gradient, but in practice, the momentum (a moving average of gradients) is passed in instead. For nonstandard layers, such as layer norm vectors or embedding layers, Adam is used.

To keep things simple, we will focus on the core whitening behavior of the various optimizers. The original implementations of these methods have various details (e.g. learning rate grafting, Nesterov momentum, iterate averaging) that we will put aside.

Adam/RMSProp

The most common approximation of the whitening metric is used in Adam, which takes its original inspiration from RMSProp. In these methods, we use an elementwise (diagonal) approximation of the whitening metric, treating each parameter as independent. This amounts to keeping track of a moving average of \(g^2\), then dividing each gradient coordinate by the square root of that average.

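import jax.numpy as jnp

b2 = 0.9
def do_adam_rmsprop(grad, v):
    # Elementwise moving average of the squared gradient (diagonal whitening statistics).
    v = b2 * v + (1 - b2) * grad**2
    # Divide each coordinate by its RMS; 1e-6 guards against division by zero.
    u = grad / jnp.sqrt(v + 1e-6)
    return v, u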
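To see the whitening effect in isolation, the NumPy port below feeds the same fixed gradient into the update repeatedly. This is a toy assumption chosen only to expose the normalization: coordinates whose gradients differ by four orders of magnitude all end up with updates close to 1.

```python
import numpy as np

b2 = 0.9

def do_adam_rmsprop(grad, v):
    # NumPy port of the update above: EMA of squared gradients, then elementwise RMS division.
    v = b2 * v + (1 - b2) * grad**2
    u = grad / np.sqrt(v + 1e-6)
    return v, u

# Three parameters whose gradients differ in scale by four orders of magnitude.
grad = np.array([1e-2, 1.0, 100.0])
v = np.zeros_like(grad)
for _ in range(100):
    v, u = do_adam_rmsprop(grad, v)
print(u)
```

Without bias correction, the first steps are larger than this (about \(\sqrt{10}\) on step one, since \(v\) starts at zero); Adam's bias correction, which divides \(v\) by \(1 - b_2^t\), removes this effect.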

Shampoo/SOAP/SPlus

In the Shampoo [Gupta 2018] family, we explicitly keep track of the Kronecker factors of each dense layer’s gradient covariance. To save on computational costs, it is common to only recompute the inverse matrix roots every \(N\) steps, and cache the results in between. Since the accumulated gradient covariances are always square, symmetric, and positive semi-definite, we can use the faster Hermitian eigendecomposition (eigh) to calculate their inverse fourth roots.

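import jax.numpy as jnp

def matrix_power(x, p, eps=1e-8):
    # x is symmetric PSD, so eigh returns real eigenvalues and orthonormal eigenvectors.
    eigvals, eigvecs = jnp.linalg.eigh(x)
    # Remove tiny negative round-off, then damp before taking a (negative) power.
    eigvals = jnp.abs(eigvals) + eps
    return eigvecs @ jnp.diag(eigvals ** p) @ eigvecs.T

b2 = 0.9
def do_shampoo(grad, lg, rg, lg_inv, rg_inv, step):
    # Moving averages of the left (m x m) and right (n x n) Kronecker factors.
    lg = b2 * lg + (1 - b2) * grad @ grad.T
    rg = b2 * rg + (1 - b2) * grad.T @ grad
    # Refresh the cached inverse fourth roots every 10 steps.
    if step % 10 == 0:
        lg_inv = matrix_power(lg, -0.25)
        rg_inv = matrix_power(rg, -0.25)
    u = lg_inv @ grad @ rg_inv
    return lg, rg, lg_inv, rg_inv, u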

SOAP [Vyas 2024] and SPlus [Frans 2025] are optimizers that build on Shampoo, with the intention of stabilizing training. The key idea is to view the Kronecker factors via their eigendecompositions:

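\[\begin{split} u_\text{Shampoo} = \underbrace{\bar{U} \bar{\Sigma}_L^{-1/2} \bar{U}^T}_{E[GG^T]^{-1/4}} \; G \; \underbrace{\bar{V} \bar{\Sigma}_R^{-1/2} \bar{V}^T}_{E[G^TG]^{-1/4}} \\ \end{split}\]

where \(E[GG^T] = \bar{U}\bar{\Sigma}_L^2\bar{U}^T\) and \(E[G^TG] = \bar{V}\bar{\Sigma}_R^2\bar{V}^T\). SOAP and SPlus keep the rotation into this eigenbasis, but replace the inner division by \(\bar{\Sigma}_L^{1/2}\) and \(\bar{\Sigma}_R^{1/2}\) with a more expressive procedure. SOAP applies an inner Adam algorithm, keeping track of \(g^2\) in the rotated eigenbasis, while SPlus instead uses the sign function.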

\[\begin{split} u_\text{SOAP} = \bar{U} \; \text{Adam} ( \bar{U}^T G \bar{V} ) \; \bar{V}^T \\ \end{split}\]
\[\begin{split} u_\text{SPlus} = \bar{U} \; \text{sign} ( \bar{U}^T G \bar{V} ) \; \bar{V}^T \\ \end{split}\]

import jax.numpy as jnp

def eigbasis(x):
    # Eigenvectors of a symmetric PSD Kronecker factor (as columns, ascending eigenvalues).
    eigvals, eigvecs = jnp.linalg.eigh(x)
    return eigvecs

b2 = 0.9
def do_soap(grad, lg, rg, lg_eig, rg_eig, v_rot, step):
    lg = b2 * lg + (1 - b2) * grad @ grad.T
    rg = b2 * rg + (1 - b2) * grad.T @ grad
    # Refresh the cached eigenbases every 10 steps.
    if step % 10 == 0:
        lg_eig = eigbasis(lg)
        rg_eig = eigbasis(rg)
    # Rotate the gradient into the eigenbasis and apply Adam's second-moment normalization there.
    g_rot = lg_eig.T @ grad @ rg_eig
    v_rot = b2 * v_rot + (1 - b2) * g_rot**2
    u_rot = g_rot / jnp.sqrt(v_rot + 1e-6)
    # Rotate the normalized update back to the original parameter basis.
    u = lg_eig @ u_rot @ rg_eig.T
    return lg, rg, lg_eig, rg_eig, v_rot, u

def do_splus(grad, lg, rg, lg_eig, rg_eig, step):
    lg = b2 * lg + (1 - b2) * grad @ grad.T
    rg = b2 * rg + (1 - b2) * grad.T @ grad
    if step % 10 == 0:
        lg_eig = eigbasis(lg)
        rg_eig = eigbasis(rg)
    # Same rotation as SOAP, but the inner step is an elementwise sign.
    u = lg_eig @ jnp.sign(lg_eig.T @ grad @ rg_eig) @ rg_eig.T
    return lg, rg, lg_eig, rg_eig, u

SOAP also introduces an alternate way to calculate the eigenbasis, using a step of power iteration with a QR decomposition instead of eigh. For simplicity, we will use the more direct version.
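As a rough sketch of the QR-based alternative, the NumPy code below refines an eigenbasis estimate with one power-iteration step followed by a QR decomposition per call, applied repeatedly to a fixed synthetic covariance. This only illustrates the idea; it does not reproduce SOAP's exact procedure (for example, how it orders eigenvectors or warm-starts between optimizer steps), and `refine_eigbasis` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
# A fixed covariance with known, well-separated eigenvalues.
basis, _ = np.linalg.qr(rng.normal(size=(d, d)))
cov = basis @ np.diag([4.0, 2.0, 1.0, 0.5]) @ basis.T

def refine_eigbasis(cov, q_prev):
    # One power-iteration step followed by QR re-orthonormalization.
    q, _ = np.linalg.qr(cov @ q_prev)
    return q

q = np.eye(d)  # cold start; inside an optimizer, the previous basis would serve as a warm start
for _ in range(60):
    q = refine_eigbasis(cov, q)

rotated = q.T @ cov @ q  # close to diagonal once q spans the eigenvectors
print(np.round(rotated, 8))
```

The appeal of this design is cost: a matrix product plus a QR decomposition per refresh, warm-started from the previous basis, can be cheaper than a full eigendecomposition while the covariance changes slowly.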

PSGD

Preconditioned Stochastic Gradient Descent (PSGD) [Xi-Lin Li 2015] provides another method of whitening. We will focus on the Kron-Fisher version, which has been noted to be empirically effective. In PSGD, we keep track of an explicit Kronecker-factored preconditioner for each layer, just as in Shampoo. However, instead of computing this preconditioner via a matrix inversion, we find it iteratively. The idea is that the optimal whitening preconditioner \(P\) should minimize the following objective:

\[ cost(P) = E[g^T P g] + E[v^T P^{-1} v] \]

where \(g\) represents incoming gradients, and \(v\) represents independent samples of standard Gaussian noise, so that \(E[vv^T] = I\). The positive-definite minimizer of this cost is exactly the inverse of the whitening metric:

\[\begin{split} \begin{align} cost(P) & = trace(P E[gg^T] + P^{-1} E[vv^T]) \\ cost(P) & = trace(P E[gg^T] + P^{-1}) \\ \nabla_P \; cost(P) & = E[gg^T] - P^{-2} = 0 \\ P & = E[gg^T]^{-1/2} \end{align} \end{split}\]

To find \(P\) iteratively, we first enforce symmetry and positive-definiteness by decomposing \(P = Q^TQ\), where \(Q\) is an upper-triangular matrix with positive diagonal entries. We can now perform gradient descent over \(Q\) on the cost above:

\[\begin{split} cost(Q) = E[g^T Q^TQ g + v^T Q^{-1}Q^{-T} v] \\ \nabla_Q = 2 triu(Qg g^T Q^T - Q^{-T} v v^T Q^{-1}) \end{split}\]

where \(triu\) takes only the upper-triangular component.

The PSGD algorithm involves iteratively following the above “inner gradient” on Q. In the original paper, Q is updated using the relative gradient, along with a dynamic learning rate:

\[ Q \leftarrow Q - \frac{\alpha}{\max|\nabla_Q|} \nabla_Q Q \]

although in principle, a naive standard gradient works as well:

\[ Q \leftarrow Q - \alpha \nabla_Q \]
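
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

inner_lr = 0.1  # Step size for the preconditioner factors.
def psgd_kron_iter(grad, Ql, Qr, key):
    # Probe the preconditioner with Gaussian noise V shaped like the gradient.
    v = jax.random.normal(key, grad.shape)
    # a = Ql G Qr^T and b = Ql^{-T} V Qr^{-1}, using triangular solves instead of inverses.
    a = Ql @ grad @ Qr.T
    b = solve_triangular(Ql, v, trans=1)
    b = solve_triangular(Qr, b.T, trans=1).T
    # Upper-triangular descent directions for the left and right factors.
    dl = jnp.triu(a @ a.T - b @ b.T)
    dr = jnp.triu(a.T @ a - b.T @ b)
    # Relative-gradient step with the dynamic step size 1 / max|delta|.
    Ql = Ql - inner_lr / (jnp.max(jnp.abs(dl)) + 1e-12) * dl @ Ql
    Qr = Qr - inner_lr / (jnp.max(jnp.abs(dr)) + 1e-12) * dr @ Qr
    return Ql, Qr

def do_psgd(grad, Ql, Qr, key):
    Ql, Qr = psgd_kron_iter(grad, Ql, Qr, key)
    # Apply P = Q^T Q on each side of the gradient.
    u = (Ql.T @ Ql) @ grad @ (Qr.T @ Qr)
    return Ql, Qr, u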
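The fixed point of this iteration can be checked without sampling noise. The NumPy sketch below is a simplified, single-factor assumption: it replaces \(aa^T\) and \(bb^T\) by their expectations \(QCQ^T\) and \((QQ^T)^{-1}\) for a fixed synthetic covariance \(C\), and uses a small constant step instead of the normalized one. After enough iterations, \(P = Q^TQ\) should match \(C^{-1/2}\), while \(Q\) stays upper-triangular.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
basis, _ = np.linalg.qr(rng.normal(size=(d, d)))
C = basis @ np.diag([0.5, 0.8, 1.3, 2.0]) @ basis.T   # stand-in for E[g g^T]

def psgd_expected_step(Q, lr=0.05):
    # Expected a a^T is Q C Q^T; expected b b^T is Q^{-T} E[v v^T] Q^{-1} = (Q Q^T)^{-1}.
    Q_inv = np.linalg.inv(Q)
    delta = Q @ C @ Q.T - Q_inv.T @ Q_inv
    # Relative-gradient step; triu(delta) @ Q keeps Q upper-triangular.
    return Q - lr * np.triu(delta) @ Q

Q = np.eye(d)
for _ in range(3000):
    Q = psgd_expected_step(Q)

P = Q.T @ Q
w, V = np.linalg.eigh(C)
C_inv_sqrt = V @ np.diag(w**-0.5) @ V.T
print(np.max(np.abs(P - C_inv_sqrt)))
```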

Explicit SVD

Noting the relation between the whitening metric and the SVD, another flavor of optimizers aims to directly orthogonalize the gradients. This implicitly follows the whitening metric, without ever forming an actual preconditioning matrix. The simplest way to achieve this is to explicitly compute the singular value decomposition at each iteration, as done in [Carlson 2015]. This directly descends along the spectral norm, but it is quite an expensive procedure.

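import jax.numpy as jnp

def do_svd(grad):
    # Reduced SVD, so that u @ vt has the same shape as grad even when grad is not square.
    u, s, vt = jnp.linalg.svd(grad, full_matrices=False)
    return u @ vt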

Muon

The Muon optimizer [Jordan 2024] uses a neat trick to approximate the SVD-based update at each iteration, saving on compute. The main idea is to use Newton-Schulz iteration, which is an iterative method for approximating \(G \rightarrow UV^T\). The intuition can be explained via polynomial iteration: for any nonzero number in [-1, 1], if you apply \(x \leftarrow \frac{3}{2}x - \frac{1}{2}x^3\) enough times, you will arrive at either \(-1\) or \(1\) in the limit. Newton-Schulz generalizes this notion to matrices, and can be understood as applying this procedure to the singular values of the matrix. For a detailed explanation, see this great page from the Modula project.

import jax.numpy as jnp

coeffs = (3, -16/5, 6/5)  # Many choices work here.
def newton_schulz_iterator(x):
    # Odd polynomial in the singular values: c0*X + c1*(X X^T) X + c2*(X X^T)^2 X.
    a = x @ x.T
    b = coeffs[1] * a + coeffs[2] * a @ a
    return coeffs[0] * x + b @ x

def do_muon(grad):
    # Frobenius normalization bounds every singular value by 1.
    x = grad / (jnp.linalg.norm(grad) + 1e-6)
    for _ in range(5):  # Iterate a few times to (approximately) converge.
        x = newton_schulz_iterator(x)
    return x
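The polynomial intuition can be checked directly. The NumPy sketch below swaps in the classic cubic Newton-Schulz step \(X \leftarrow \frac{3}{2}X - \frac{1}{2}XX^TX\) for Muon's tuned quintic (which only converges approximately in five steps), applies it to a matrix with known singular values, and iterates until the result matches \(UV^T\).

```python
import numpy as np

def newton_schulz_cubic(x, steps=30):
    # Classic cubic iteration: each step maps every singular value s to 1.5 s - 0.5 s^3.
    x = x / np.linalg.norm(x)  # Frobenius norm >= spectral norm, so singular values land in (0, 1]
    for _ in range(steps):
        x = 1.5 * x - 0.5 * x @ x.T @ x
    return x

rng = np.random.default_rng(0)
U0, _ = np.linalg.qr(rng.normal(size=(4, 3)))   # 4 x 3 with orthonormal columns
V0, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # 3 x 3 orthogonal
G = U0 @ np.diag([2.0, 1.0, 0.3]) @ V0.T        # singular values chosen by construction

orth = newton_schulz_cubic(G)
print(np.linalg.svd(orth, compute_uv=False))
```

Near 1 the cubic converges quadratically, but small singular values grow only by a factor of about 1.5 per step; tuned higher-degree polynomials such as Muon's trade exact convergence for a faster initial push.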

A Brief Benchmark

So, do these modern optimizers actually outperform a well-tuned Adam, on settings we care about? Let’s give them a whirl.

We will largely build on the experimental settings we examined in the SPlus paper, with a few adjustments for more accurate comparison. We will consider the training of a standard GPT-2 style transformer on language modelling. The network will be trained for 10k gradient steps with a batch size of 1024. All methods will use a warmup of 200 steps and a cosine learning rate schedule decaying to zero. For all layernorm, input, and output parameters, we will use an Adam optimizer with a fixed, tuned learning rate and weight decay. All methods use standard momentum. Training is done in bfloat16. We will measure performance on a fixed validation set.

The hardest part of optimizer comparisons is ensuring that all methods are properly tuned. For our experiments here, we tune the learning rate, weight decay, momentum (b1), and variance accumulation (b2) where applicable. Learning rate is tuned at a resolution of \(\sqrt[4]{10}\), weight decay at a resolution of \(10\), and b1/b2 are chosen from {0.9, 0.95, 0.99, 0.995}. For all methods, we ensure that the resulting best hyperparameters are at a local minimum of this search space, i.e. independently changing any single hyperparameter does not improve performance.

For each optimizer, we will attempt to use the highest-performing variant. For Muon, we use the version with 15 unique coefficients. For SOAP/SPlus, we compute the matrix inverse every 10 steps. For PSGD, we update the preconditioner every step (instead of the default decaying probability).

It’s important to note that comparisons won’t ever be perfect, since we can’t feasibly try every combination of hyperparameters. In our specific setting, we found that varying the learning rate by a factor of \(\sqrt[4]{10}\) (the resolution we search over) changes the final validation loss by about \(\pm 0.03\), so comparisons should be interpreted within this range of variation.

As raw validation losses are hard to interpret, we also calculate Steps-to-Adam: an estimate of the ratio between the number of gradient steps each optimizer uses and the number Adam needs to reach the same validation loss. These are calculated by running Adam for (10000, 10500 ... 15000) steps.
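The text does not spell out how the ratio is computed from these runs. One plausible reading, sketched below purely as an assumption, is to linearly interpolate Adam's final validation loss against its step count and find the Adam step count that matches a given optimizer's loss. The helper name and signature are hypothetical.

```python
import numpy as np

def steps_to_adam(opt_loss, opt_steps, adam_steps, adam_losses):
    """Hypothetical helper: optimizer steps divided by the Adam steps reaching the same loss.

    adam_steps must be increasing and adam_losses decreasing. np.interp needs increasing
    x-values, so both arrays are reversed. Losses outside Adam's range are clipped to the ends.
    """
    adam_steps = np.asarray(adam_steps, dtype=float)
    adam_losses = np.asarray(adam_losses, dtype=float)
    matched = np.interp(opt_loss, adam_losses[::-1], adam_steps[::-1])
    return opt_steps / matched
```

For example, if Adam's loss curve reaches an optimizer's 10K-step loss at 12,500 steps, the ratio is 0.8.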

| Optimizer | LR | WD | b1 | b2 | Val Loss @ 10K | Steps-to-Adam |
|---|---|---|---|---|---|---|
| Adam | 0.001 | 1.0 | 0.95 | 0.99 | 2.965 \(\pm\) 0.03 | 1.0 |
| Shampoo | 0.00132 | 1.0 | 0.95 | 0.99 | 2.975 \(\pm\) 0.03 | > 1.1 |
| SOAP | 0.00132 | 1.0 | 0.95 | 0.99 | 2.929 \(\pm\) 0.03 | 0.75 to 0.8 |
| SPlus | 0.1 | 0.01 | 0.99 | 0.99 | 2.939 \(\pm\) 0.03 | 0.8 to 0.83 |
| PSGD | 0.000264 | 0.001 | 0.95 | n/a | 2.956 \(\pm\) 0.03 | 0.95 to 1 |
| Muon | 0.00578 | 0.1 | 0.95 | n/a | 2.952 \(\pm\) 0.03 | 0.9 to 0.95 |

Raw per-step performance is not the only thing that matters; in the end, wallclock time matters as well. We did not try very hard to optimize runtime in our experiments. Wallclock time may vary greatly depending on hardware as well as distributed sharding, and many optimizers can trade off runtime for accuracy by updating the preconditioning matrix less frequently. That said, the approximate runtimes for the above results, relative to Adam, are:

| | Adam | Shampoo | SOAP | SPlus | PSGD | Muon |
|---|---|---|---|---|---|---|
| Runtime (relative to Adam) | 1.0 | 3.54 | 3.66 | 3.79 | 3.20 | 1.05 |

Modded-NanoGPT Speedrun

To sanity-check these findings, let’s also see how these optimizers perform on the Modded-NanoGPT Medium optimization leaderboard. The main point of comparison here is Muon, which has been tuned by the community, so we can assume that its hyperparameters are reasonably good. The leaderboard already contains entries for Adam and PSGD (although they may not be heavily tuned), so we will focus on benchmarking SOAP and SPlus. In the speedrun guidelines, the goal is to find the lowest number of steps needed to reach a validation loss of 2.92, with the architecture and data fixed. Since this involves sweeping over the total training length (by default, a linear LR decay is used), we searched over step counts at a resolution of \(\pm 125\). Other hyperparameters were chosen in an ad-hoc manner.

| Optimizer | LR | WD | b1 | b2 | Steps to 2.92 | Description |
|---|---|---|---|---|---|---|
| Adam | 0.0015 | 0.125 | 0.9 | 0.95 | 9500 | Warmup=500 |
| Muon | 0.025 | 0.01 | (0.85 - 0.95)\(^1\) | n/a | 6125 | Warmup=0 |
| PSGD | 0.0005 | 0.625 | ? | ? | 7875 | Fisher-Kron |
| SOAP (new) | 0.003 | 0.1 | (0.85 - 0.95)\(^1\) | 0.9 | 6000 \(\pm\) 125 | Warmup=25. Inv every 5 |
| SPlus (new) | 1.5 | 0.3 | (0.85 - 0.95)\(^1\) | 0.9 | 6125 \(\pm\) 125 | Warmup=25. Inv every 5 |

\(^1\)In the modded-nanoGPT base code, the b1 parameter is warmed up from 0.85 to 0.95 over the first 300 steps. We kept this behavior unchanged for SOAP/SPlus.

While these nanoGPT results are not as uniformly tuned as our above benchmark, the relationship between optimizers looks to be consistent.

Takeaways

So in the end, what did we find? I believe a main takeaway is that spectral-whitening methods reliably outperform Adam, and they do so by roughly the same margin. [Zhao 2024] shows that elementwise-whitening optimizers (Adam, Signum, Lion, etc.) all perform similarly when hyperparameters are properly tuned. My current belief is that a similar relation holds for spectral-whitening methods. SOAP is the most effective optimizer per gradient step, and this may be due to the role of the second moment in estimating signal-to-noise [Orvieto 2025]. In contrast, Muon and SPlus both apply an instantaneous sign operation (Muon to the singular values, SPlus to the gradient in the rotated eigenbasis), which loses this property. Muon is particularly powerful due to its efficient computational properties. There is likely a method that can use the orthogonalization machinery of Muon while keeping the signal-to-noise estimation of Adam, and such an optimizer would be great.

The corollary to the above conclusion is that we haven’t seen an optimization procedure that reliably goes beyond what the spectral-whitening optimizers can achieve. I would be very curious to see what such a method would look like.