Information Geometry

Geometry of Deep Learning Loss Landscapes

The parameter space of a neural network is not a flat Euclidean space. Equipped with the Fisher information metric, it is a Riemannian manifold. The natural gradient follows this geometry and often converges in fewer steps than plain gradient descent, but it requires inverting a d x d matrix, with d in the billions for large models.

  • K-FAC in practice: DeepMind maintains the open-source KFAC-JAX library; the per-layer Kronecker factorization F_l approx A_l (kron) G_l reduces inversion cost from O((d_in d_out)^3) to O(d_in^3 + d_out^3)
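  • ResNet-50 on ImageNet: K-FAC was introduced by Martens and Grosse (2015) and first evaluated on deep autoencoders; later distributed K-FAC implementations have been reported to train ResNet-50 on ImageNet in fewer epochs than momentum SGD at comparable accuracy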
  • TRPO and PPO in RL: the KL constraint between successive policies is, to second order, a bounded step in the Fisher metric (KL approx 1/2 Delta^T F Delta); PPO approximates this constraint by clipping the probability ratio
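  • Adam: v_t, the running average of squared gradients, approximates the diagonal of the empirical Fisher matrix, so dividing by sqrt(v_t) acts as a diagonal preconditioner; this per-parameter rescaling helps explain Adam's robustness to heterogeneous parameter scales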
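The cost reduction in the K-FAC bullet rests on the identity (A (kron) G)^{-1} = A^{-1} (kron) G^{-1}, which turns one large solve into two small ones. The minimal NumPy sketch below uses random symmetric positive-definite matrices as stand-ins for A_l and G_l. The helper names random_spd and vec are illustrative. It checks that both routes give the same preconditioned gradient. For d_in = d_out = 1000, the direct route would factor a 10^6 x 10^6 matrix (on the order of 10^18 operations), while the factored route needs two 1000 x 1000 solves (a few times 10^9).

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3


def random_spd(d):
    """Random symmetric positive-definite matrix (stand-in for a Kronecker factor)."""
    B = rng.standard_normal((d, d))
    return B @ B.T + np.eye(d)


def vec(M):
    """Column-major vectorization: under this convention vec(g a^T) = a kron g."""
    return M.ravel(order="F")


A = random_spd(d_in)                     # stand-in for A_l = E[a a^T], d_in x d_in
G = random_spd(d_out)                    # stand-in for G_l = E[g g^T], d_out x d_out
V = rng.standard_normal((d_out, d_in))   # a gradient dL/dW_l with the shape of W_l

# Direct route: solve against the full (d_in*d_out) x (d_in*d_out) matrix, O((d_in d_out)^3)
direct = np.linalg.solve(np.kron(A, G), vec(V))

# Factored route: (A kron G)^{-1} vec(V) = vec(G^{-1} V A^{-1}), O(d_in^3 + d_out^3)
factored = vec(np.linalg.solve(G, V) @ np.linalg.inv(A))

print(np.allclose(direct, factored))     # True
```

In a real layer, the factored route never materializes the Kronecker product at all. Only the two small factors are stored and inverted.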

Prerequisites

  • Amari natural gradient
  • Backpropagation
  • Kronecker product
  • Natural Gradient
  • Quantum Information Geometry

Natural Gradient and Fisher Metric of Neural Networks

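A neural network with parameters theta defines a parametric family of distributions {p_theta(y|x)}. The Fisher information matrix F(theta) = E[nabla_theta log p_theta * nabla_theta log p_theta^T] is the Riemannian metric on this manifold. Here x is drawn from the data and y from the model p_theta(y|x). At d = 10^8 parameters, F is a 10^8 x 10^8 matrix. Merely storing it in float32 would take 4 x 10^16 bytes (about 40 PB), and direct inversion costs O(d^3), so computing F^{-1} exactly is infeasible. Practical algorithms therefore use structured approximations. K-FAC, which exploits the layer-wise structure of the network, is one of the most widely used and most geometrically motivated.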
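For a tiny softmax-regression model, the Fisher matrix can be formed exactly by summing over every label y, weighted by its model probability. The sketch below uses toy sizes and random, hypothetical weights and inputs. It builds F that way and checks it against the closed form E_x[x x^T (kron) (diag(p) - p p^T)]. It also confirms that a small parameter change delta costs KL approx 1/2 vec(delta)^T F vec(delta), which is the second-order relation that makes F a metric and underlies TRPO's KL constraint. Each input contributes an exact Kronecker product here. K-FAC approximates the average of such terms by a single Kronecker product.

```python
import numpy as np

rng = np.random.default_rng(0)
n, D, K = 64, 4, 3                       # inputs, features, classes (toy sizes)
X = rng.standard_normal((n, D))          # hypothetical inputs
W = 0.5 * rng.standard_normal((K, D))    # hypothetical softmax-regression weights


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def vec(M):
    """Column-major vectorization, matching the A (kron) G ordering."""
    return M.ravel(order="F")


def fisher_by_enumeration(W, X):
    """F = E_x E_{y ~ p(y|x)}[s s^T] with score s = vec(grad_W log p(y|x))."""
    K = W.shape[0]
    F = np.zeros((W.size, W.size))
    P = np.exp(log_softmax(X @ W.T))     # (n, K) model probabilities
    for x, p in zip(X, P):
        for y in range(K):
            s = vec(np.outer(np.eye(K)[y] - p, x))   # grad_W log p(y|x) = (e_y - p) x^T
            F += p[y] * np.outer(s, s)
    return F / len(X)


def fisher_closed_form(W, X):
    """The same Fisher as an average of exact Kronecker products, one per input."""
    P = np.exp(log_softmax(X @ W.T))
    return sum(np.kron(np.outer(x, x), np.diag(p) - np.outer(p, p))
               for x, p in zip(X, P)) / len(X)


def mean_kl(W1, W2, X):
    """Average KL(p_W1(.|x) || p_W2(.|x)) over the inputs."""
    lp, lq = log_softmax(X @ W1.T), log_softmax(X @ W2.T)
    return np.mean(np.sum(np.exp(lp) * (lp - lq), axis=1))


F = fisher_by_enumeration(W, X)
delta = 1e-3 * rng.standard_normal(W.shape)
kl = mean_kl(W, W + delta, X)
quad = 0.5 * vec(delta) @ F @ vec(delta)
print(kl, quad)                          # nearly equal for small delta
```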

Practical K-FAC uses damped inversion, (A_l + lambda I)^{-1} and (G_l + lambda I)^{-1}, instead of exact inverses. The damping parameter lambda regularizes the update and controls how much the Fisher approximation is trusted. A larger lambda means less trust in the curvature model and a smaller, more gradient-like step, much as shrinking a trust-region radius does.
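The effect of damping can be made concrete for a single linear layer s = W a. This sketch uses random stand-ins for the layer inputs a and the backpropagated gradients g. A real implementation would collect these from the network, with g computed from labels sampled from the model. The function kfac_step is a hypothetical helper that computes the damped step (G + lambda I)^{-1} dL/dW (A + lambda I)^{-1} for several values of lambda. As lambda grows, the step norm shrinks monotonically, and the step approaches dL/dW / lambda^2, a scaled gradient step. This is the trust-region-like behavior of the damping parameter. The same lambda is applied to both factors, as in the formula above. Published K-FAC variants often split the damping between the factors differently.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n = 5, 3, 256

# Hypothetical batch statistics for one linear layer s = W a
a = rng.standard_normal((n, d_in))     # layer inputs a_{l-1}
g = rng.standard_normal((n, d_out))    # backpropagated gradients dL/ds_l
A = a.T @ a / n                        # A_l = E[a a^T]
G = g.T @ g / n                        # G_l = E[g g^T]
grad_W = g.T @ a / n                   # dL/dW, shape (d_out, d_in)


def kfac_step(grad_W, A, G, lam):
    """Damped K-FAC preconditioning: (G + lam I)^{-1} grad_W (A + lam I)^{-1}."""
    G_d = G + lam * np.eye(G.shape[0])
    A_d = A + lam * np.eye(A.shape[0])
    left = np.linalg.solve(G_d, grad_W)        # G_d^{-1} grad_W
    return np.linalg.solve(A_d, left.T).T      # left A_d^{-1} (A_d is symmetric)


lams = [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0]
norms = [np.linalg.norm(kfac_step(grad_W, A, G, lam)) for lam in lams]
for lam, nrm in zip(lams, norms):
    print(f"lambda={lam:g}  step norm={nrm:.4g}")

# For very large lambda the rescaled step lambda^2 * step approaches grad_W (an SGD direction)
big = 1e4
sgd_like = big**2 * kfac_step(grad_W, A, G, big)
```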

Why is the natural gradient invariant to reparameterization of the network?

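F(theta) is the metric tensor of the statistical manifold. Consider a reparameterization theta = g(phi) with invertible Jacobian J = dg/dphi. The metric transforms as F_phi = J^T F_theta J, and the gradient transforms as nabla_phi L = J^T nabla_theta L. Hence F_phi^{-1} nabla_phi L = J^{-1} F_theta^{-1} J^{-T} J^T nabla_theta L = J^{-1} F_theta^{-1} nabla_theta L. Mapping this step back to theta-coordinates through J gives Delta theta = J Delta phi = -eta F_theta^{-1} nabla_theta L. To first order, the update is therefore the same direction in distribution space regardless of the parameter coordinate system.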
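The invariance argument can be checked numerically for a linear reparameterization theta = J phi. For a linear map the invariance holds exactly. For a nonlinear g it holds to first order, with J evaluated at the current point. The Fisher matrix, gradient, and Jacobian below are random, hypothetical values. The check also shows that the plain gradient does not enjoy this property.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
B = rng.standard_normal((d, d))
F_theta = B @ B.T + 0.1 * np.eye(d)                # hypothetical SPD Fisher in theta-coordinates
grad_theta = rng.standard_normal(d)                 # hypothetical gradient of L w.r.t. theta
J = rng.standard_normal((d, d)) + 3.0 * np.eye(d)   # Jacobian of theta = J phi (invertible)

F_phi = J.T @ F_theta @ J        # metric pulled back to phi-coordinates
grad_phi = J.T @ grad_theta      # chain rule: nabla_phi L = J^T nabla_theta L

nat_theta = np.linalg.solve(F_theta, grad_theta)   # F_theta^{-1} nabla_theta L
nat_phi = np.linalg.solve(F_phi, grad_phi)         # F_phi^{-1} nabla_phi L

# Natural gradient: the phi-step mapped back through J is the theta-step
print(np.allclose(J @ nat_phi, nat_theta))         # True
# Plain gradient: J @ grad_phi = J J^T grad_theta, which is generally not grad_theta
plain_matches = np.allclose(J @ grad_phi, grad_theta)
print(plain_matches)                               # False
```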

Loss Landscape Sharpness and Generalization

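The parameter space of a neural network, equipped with the Fisher metric, is a Riemannian manifold. Near a minimum of the negative log-likelihood where the model fits the data well, the loss Hessian approximates the Fisher matrix. When labels are drawn from the model itself, the two coincide in expectation (the second Bartlett identity). Sharp minima, with large Hessian eigenvalues, are often associated with poorer generalization, and flat minima with better generalization. This is an empirical tendency rather than a guarantee. The geometry of the optimum can therefore matter not just for convergence speed but also for test performance.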
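The Bartlett identity can be checked on a toy model whose Hessian genuinely depends on the label. This sketch assumes a hypothetical two-parameter Bernoulli model with logit z = theta_2 tanh(theta_1 x). It uses central finite differences for gradients and Hessians, so agreement is up to small numerical error. With labels drawn from the model, the expected Hessian matches the Fisher matrix. With labels that disagree with the model (here, hypothetically, every y = 1), it does not. This is why Hessian approx Fisher is only expected near a minimum where the model fits the data well.

```python
import numpy as np

rng = np.random.default_rng(5)
X = rng.standard_normal(200)          # hypothetical scalar inputs
theta = np.array([0.7, -1.3])         # hypothetical parameters


def logit(th, x):
    """Toy non-linear model: z = theta_2 * tanh(theta_1 * x)."""
    return th[1] * np.tanh(th[0] * x)


def nll(th, x, y):
    """-log p(y | x) for y in {0, 1}, with p(y=1 | x) = sigmoid(z)."""
    z = logit(th, x)
    return np.logaddexp(0.0, z) - y * z


def grad_fd(f, th, eps=1e-5):
    """Central finite-difference gradient."""
    g = np.zeros_like(th)
    for i in range(th.size):
        e = np.zeros_like(th)
        e[i] = eps
        g[i] = (f(th + e) - f(th - e)) / (2 * eps)
    return g


def hess_fd(f, th, eps=1e-4):
    """Finite-difference Hessian built from finite-difference gradients."""
    H = np.zeros((th.size, th.size))
    for i in range(th.size):
        e = np.zeros_like(th)
        e[i] = eps
        H[:, i] = (grad_fd(f, th + e) - grad_fd(f, th - e)) / (2 * eps)
    return H


F = np.zeros((2, 2))         # Fisher: E_x E_{y ~ model}[s s^T], s = grad log p
H_model = np.zeros((2, 2))   # expected Hessian with labels drawn from the model
H_data = np.zeros((2, 2))    # Hessian with mismatched labels (every y = 1)
for x in X:
    p1 = 1.0 / (1.0 + np.exp(-logit(theta, x)))
    for y, py in ((0, 1.0 - p1), (1, p1)):
        s = -grad_fd(lambda th: nll(th, x, y), theta)
        F += py * np.outer(s, s)
        H_model += py * hess_fd(lambda th: nll(th, x, y), theta)
    H_data += hess_fd(lambda th: nll(th, x, 1), theta)
F /= X.size
H_model /= X.size
H_data /= X.size
print(F, H_model, H_data, sep="\n\n")
```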

The generalization gap phenomenon: large-batch SGD tends to converge to sharper minima, with worse test accuracy than small-batch SGD. K-FAC is argued to partially mitigate this. By using curvature information, it takes smaller effective steps in sharp directions and larger steps in flat directions, which may steer it toward flatter minima even with large batches.

Why does K-FAC tend to find flatter minima than large-batch SGD?

Write F = sum_i lambda_i u_i u_i^T. The natural-gradient step F^{-1} nabla L scales the gradient component along each eigenvector u_i by 1/lambda_i. In high-curvature directions lambda_i is large, so the step there is tiny. In flat directions lambda_i is small, so the step is large. This automatic, curvature-adaptive scaling is the proposed mechanism by which K-FAC avoids sharp minima without explicit sharpness regularization.
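The eigenbasis argument can be seen on a hypothetical 3 x 3 Fisher matrix with eigenvalues 100, 1, and 0.01. The first part checks that the natural-gradient step rescales each eigen-component of the gradient by 1/lambda_i. The second part compares one step on the quadratic L(theta) = 1/2 theta^T F theta. The natural gradient with eta = 1 lands exactly on the minimum. Gradient descent, with a step size just under its stability limit 2 / lambda_max, overshoots in the sharp direction and barely moves in the flat one.

```python
import numpy as np

rng = np.random.default_rng(2)
U, _ = np.linalg.qr(rng.standard_normal((3, 3)))   # random orthonormal eigenvectors u_i
lams = np.array([100.0, 1.0, 0.01])                # sharp, medium, flat (hypothetical)
F = U @ np.diag(lams) @ U.T
grad = rng.standard_normal(3)

nat = np.linalg.solve(F, grad)
coef_grad = U.T @ grad     # gradient components along each u_i
coef_nat = U.T @ nat       # natural-gradient components along each u_i
print(coef_nat / coef_grad)   # approx 1 / lams = [0.01, 1, 100]

# One step on L(theta) = 0.5 theta^T F theta, whose gradient is F theta
theta0 = U @ np.ones(3)                               # unit offset along every eigenvector
nat_step = theta0 - np.linalg.solve(F, F @ theta0)    # natural gradient, eta = 1
eta_gd = 1.9 / lams.max()                             # just under the GD stability limit 2 / lambda_max
gd_step = theta0 - eta_gd * (F @ theta0)
print(U.T @ gd_step)   # sharp component flips sign; flat component barely moves
```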

Fisher Geometry of Network Activations and Batch Normalization

Batch normalization can be interpreted through the lens of Fisher geometry. It standardizes each activation to zero mean and unit variance over the mini-batch. This makes the activation statistics that enter the Fisher matrix better conditioned and the loss landscape more isotropic. It offers one explanation for why BatchNorm accelerates training: it preconditions the optimization problem by normalizing the effective curvature across layers.

Layer normalization (used in Transformers) and group normalization are geometric cousins of batch normalization. They standardize different subsets of activations: each example across its features, or within groups of channels. All of them rest on the same principle of making the local Fisher metric more isotropic and thereby reducing its effective condition number for optimization.

How does batch normalization reduce the condition number of the per-layer Fisher matrix in K-FAC?

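In K-FAC, F_l approx A_l (kron) G_l. Because the eigenvalues of a Kronecker product are the pairwise products of the factors' eigenvalues, kappa(F_l) approx kappa(A_l) * kappa(G_l). BatchNorm standardizes each coordinate of a_{l-1} to zero mean and unit variance, so A_l = E[a a^T] becomes the feature correlation matrix, with unit diagonal. This removes the ill-conditioning caused by disparate means and scales. However, kappa(A_l) approaches 1 only if the features are also nearly uncorrelated. Shrinking this multiplicative factor chiefly helps first-order methods such as SGD. K-FAC already multiplies by A_l^{-1}, so it partly compensates for a poorly conditioned A_l on its own.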
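Both parts of this argument can be checked numerically. The sketch first confirms kappa(A (kron) G) = kappa(A) kappa(G) for random SPD factors. It then standardizes hypothetical correlated activations with very different means and scales, as BatchNorm does without its learned scale and shift. Standardization makes the diagonal of A = E[a a^T] exactly 1 and sharply reduces the condition number, but the remaining correlations keep kappa(A) well above 1. The mixing matrix and scales are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)


def cond(M):
    """Condition number of a symmetric positive-definite matrix."""
    w = np.linalg.eigvalsh(M)            # eigenvalues in ascending order
    return w[-1] / w[0]


def random_spd(d):
    B = rng.standard_normal((d, d))
    return B @ B.T + 0.5 * np.eye(d)


# 1) Eigenvalues of A kron G are pairwise products, so condition numbers multiply
A, G = random_spd(4), random_spd(3)
print(cond(np.kron(A, G)), cond(A) * cond(G))

# 2) Hypothetical layer inputs: correlated features with very different means and scales
n = 10_000
z = rng.standard_normal((n, 3))
mix = np.array([[1.0, 0.0, 0.0],
                [0.8, 0.6, 0.0],
                [0.5, 0.5, 0.7]])        # introduces correlations between features
a_raw = (z @ mix.T) * np.array([1.0, 10.0, 100.0]) + np.array([5.0, -3.0, 20.0])
A_raw = a_raw.T @ a_raw / n              # A = E[a a^T] (uncentered second moment)

# BatchNorm-style standardization (no learned scale/shift): zero mean, unit variance
a_bn = (a_raw - a_raw.mean(axis=0)) / a_raw.std(axis=0)
A_bn = a_bn.T @ a_bn / n                 # now the feature correlation matrix

print(np.diag(A_bn))                     # all ones
print(cond(A_raw), cond(A_bn))           # large vs. moderate: correlations remain
```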

Connections to other topics

Fisher geometry of neural network parameter space connects optimization, information theory, and differential geometry.

  • Newton's method
  • TRPO and PPO
  • Quantum natural gradient

Summary

  • F(theta) = E[nabla log p * nabla log p^T]: Riemannian metric on the d-dimensional distribution manifold
  • Natural gradient theta_{t+1} = theta_t - eta F^{-1} nabla L: invariant (to first order) under smooth reparameterizations and often needs fewer iterations than SGD
  • K-FAC: F_l approx A_l (kron) G_l; inversion in O(d_in^3 + d_out^3) vs. O((d_in d_out)^3)
  • Adam's v_t approx diagonal of the empirical Fisher: a diagonal preconditioner (applied as 1/sqrt(v_t)) that is robust but misses cross-parameter correlations
  • TRPO's KL constraint = bounded Fisher metric step, motivated by a monotone policy improvement bound