Information Geometry
K-FAC: Kronecker-factored approximation of the Fisher matrix
The Fisher matrix for GPT-3 with 175 billion parameters has about 3·10^22 entries (one per pair of parameters), which is physically impossible to store. K-FAC (Martens & Grosse, 2015) addresses this with a Kronecker factorization. For a layer with n inputs and n outputs, it drops the memory of that layer's Fisher block from O(n^4) to O(n^2).
- Distributed K-FAC trained ResNet-50 on ImageNet to the same accuracy as SGD in 55 epochs instead of 90.
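The entry counts above are plain arithmetic. Below is a minimal sketch that reproduces them, assuming a square layer with n inputs and n outputs. The helper names and the width 4096 are illustrative choices, not taken from the text.

```python
def fisher_entries(num_params: int) -> int:
    """Number of entries in a dense Fisher matrix over num_params parameters."""
    return num_params ** 2


def layer_block_sizes(n: int) -> tuple:
    """(dense Fisher block entries, K-FAC factor entries) for an n x n weight matrix."""
    num_weights = n * n              # the layer has n^2 parameters
    dense_block = num_weights ** 2   # full Fisher block: (n^2)^2 = n^4 entries
    kfac_factors = 2 * n * n         # A (n x n) plus G (n x n): O(n^2) entries
    return dense_block, kfac_factors


gpt3 = fisher_entries(175 * 10**9)
dense, kfac = layer_block_sizes(4096)   # 4096 is an illustrative width (assumption)
print(f"GPT-3 Fisher entries: {gpt3:.2e}")
print(f"n=4096 layer: dense block {dense:.2e} entries, K-FAC factors {kfac:.2e} entries")
```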
Kronecker factorization of Fisher blocks
The Fisher matrix for a 100M-parameter neural network would have 10^16 entries, so storing it is infeasible. K-FAC (Martens & Grosse, 2015) approximates the Fisher block F_l of each layer l as a Kronecker product A_l ⊗ G_l. Here A_l is the (uncentered) covariance of the layer's inputs, and G_l is the covariance of the gradients with respect to the layer's outputs. The approximation amounts to treating inputs and output gradients as independent, so E[a a^T ⊗ g g^T] is replaced with E[a a^T] ⊗ E[g g^T]. For a layer with n inputs and n outputs, this drops memory from O(n^4) to O(n^2).
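The factorization can be checked numerically on a toy layer. The sketch below builds A and G from a batch and compares A ⊗ G with the exact empirical Fisher block. It assumes toy sizes and uses random stand-ins for real activations and backpropagated gradients. For a single example the two coincide, because vec(g a^T) vec(g a^T)^T = (a a^T) ⊗ (g g^T). Over a batch, K-FAC swaps the average of Kronecker products for the Kronecker product of averages.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, batch = 5, 3, 1000                  # toy sizes: n inputs, m outputs (assumed)

# Random stand-ins for real data (assumption): layer inputs a and
# gradients g of the loss w.r.t. the layer's outputs.
a = rng.standard_normal((batch, n))
g = rng.standard_normal((batch, m))

# K-FAC factors: uncentered second moments averaged over the batch.
A = a.T @ a / batch                       # n x n
G = g.T @ g / batch                       # m x m

# Per-example weight gradient dW = g a^T has shape m x n; its column-stacked
# vec is kron(a, g). Stack them to form the exact empirical Fisher block.
vec_dw = np.stack([np.kron(a[k], g[k]) for k in range(batch)])   # batch x nm
F_exact = vec_dw.T @ vec_dw / batch       # nm x nm, i.e. O((nm)^2) memory
F_kfac = np.kron(A, G)                    # same shape, but stored as A and G only

rel_err = np.linalg.norm(F_exact - F_kfac) / np.linalg.norm(F_exact)
print("Fisher block", F_exact.shape, "factors", A.shape, G.shape,
      "relative error", round(float(rel_err), 3))
```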
What is the cost of inverting a K-FAC block for a layer with n inputs and m outputs?
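One way to see the answer is that (A ⊗ G)^{-1} = A^{-1} ⊗ G^{-1}. The preconditioned gradient is therefore G^{-1} dW A^{-1}, which needs only one n x n solve and one m x m solve. Below is a minimal numpy sketch, assuming symmetric positive-definite factors. It adds damping to each factor separately, which is a simple choice made here and not the only option. The function name is hypothetical.

```python
import numpy as np


def kfac_precondition(dW, A, G, damping=1e-3):
    """Return G^{-1} dW A^{-1}, i.e. (A ⊗ G)^{-1} vec(dW) reshaped to m x n.

    Costs O(n^3 + m^3) for the two solves, versus O((nm)^3) for a dense solve.
    Damping is added to each factor separately (an assumption of this sketch).
    """
    n, m = A.shape[0], G.shape[0]
    left = np.linalg.solve(G + damping * np.eye(m), dW)        # G^{-1} dW, m x n
    # X A^{-1} = (A^{-1} X^T)^T because A is symmetric.
    return np.linalg.solve(A + damping * np.eye(n), left.T).T


# Check against the dense Kronecker solve on a small example.
rng = np.random.default_rng(1)
n, m = 6, 4
xa = rng.standard_normal((50, n))
xg = rng.standard_normal((50, m))
A, G = xa.T @ xa / 50, xg.T @ xg / 50
dW = rng.standard_normal((m, n))

fast = kfac_precondition(dW, A, G, damping=0.0)
dense = np.linalg.solve(np.kron(A, G), dW.flatten(order="F")).reshape((m, n), order="F")
print(np.allclose(fast, dense))
```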
Convergence and practical settings of K-FAC
K-FAC requires periodic refreshes of the A and G statistics (every 10-100 steps) and of their inverses (every 100-500 steps). Martens & Grosse showed that, with appropriate damping, K-FAC converges in 10-20x fewer iterations than SGD on deep autoencoder training tasks. On ResNet-50 (ImageNet), K-FAC reaches 75% top-1 accuracy in 55 epochs versus 90 for SGD.
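The refresh schedule can be sketched for a single linear layer. Everything here is an assumption for illustration:

- the class and parameter names (stats_every stands for T, inverse_every for T', plus ema_decay and damping);
- the noise-free regression problem with a unit-variance Gaussian loss;
- the learning rate of 0.5.

Output gradients for the statistics are sampled from the model's own predictive distribution (the "true Fisher"). For this loss that is just standard normal noise.

```python
import numpy as np


class KFACLinear:
    """Toy K-FAC preconditioner for one linear layer s = W a (W is m x n).

    Hypothetical names: stats_every plays the role of T, inverse_every of T'.
    """

    def __init__(self, stats_every=10, inverse_every=100, ema_decay=0.95, damping=1e-3):
        self.stats_every, self.inverse_every = stats_every, inverse_every
        self.decay, self.damping = ema_decay, damping
        self.A = self.G = self.A_inv = self.G_inv = None
        self.t = 0
        self.num_stat_updates = self.num_inversions = 0

    def precondition(self, a, g_sampled, dW):
        """a: batch x n inputs, g_sampled: batch x m sampled output gradients, dW: m x n."""
        if self.t % self.stats_every == 0:        # moderate cost: O(batch * (n^2 + m^2))
            A_new = a.T @ a / len(a)
            G_new = g_sampled.T @ g_sampled / len(g_sampled)
            if self.A is None:                    # first refresh: no history yet
                self.A, self.G = A_new, G_new
            else:                                 # exponential moving average
                self.A = self.decay * self.A + (1 - self.decay) * A_new
                self.G = self.decay * self.G + (1 - self.decay) * G_new
            self.num_stat_updates += 1
        if self.t % self.inverse_every == 0:      # expensive: O(n^3 + m^3)
            self.A_inv = np.linalg.inv(self.A + self.damping * np.eye(len(self.A)))
            self.G_inv = np.linalg.inv(self.G + self.damping * np.eye(len(self.G)))
            self.num_inversions += 1
        self.t += 1
        return self.G_inv @ dW @ self.A_inv       # stale inverses are reused in between


# Toy problem (assumed): noise-free linear regression, loss 0.5 * mean ||W x - y||^2.
rng = np.random.default_rng(0)
n, m, N = 5, 3, 256
X = rng.standard_normal((N, n))
W_true = rng.standard_normal((m, n))
Y = X @ W_true.T
W = np.zeros((m, n))
opt = KFACLinear()


def loss(W):
    return 0.5 * np.mean(np.sum((X @ W.T - Y) ** 2, axis=1))


initial_loss = loss(W)
for _ in range(300):
    dW = (X @ W.T - Y).T @ X / N                  # mean gradient, m x n
    # Sampling y ~ N(W x, I) makes the output gradient W x - y pure N(0, I) noise.
    g_sampled = rng.standard_normal((N, m))
    W -= 0.5 * opt.precondition(X, g_sampled, dW)

print(initial_loss, loss(W), opt.num_stat_updates, opt.num_inversions)
```

Over 300 steps, the statistics are refreshed 30 times and the inverses are recomputed 3 times. Between inversions the update reuses stale inverses, which is exactly the trade-off this schedule makes.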
Why does K-FAC not refresh the A and G statistics on every step?
Distributed K-FAC and extensions
Distributed K-FAC (Osawa et al., 2019; Pauloski et al., 2021) scales the algorithm to hundreds of GPUs. The A_l and G_l statistics for different layers are computed in parallel, and the factor inversions are sharded across processes. For ResNet-50 on ImageNet with 512 GPUs, K-FAC matches SGD accuracy in 35 epochs instead of 90.
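The text does not say how the inversion work is split. One plausible scheme treats each factor inversion as a job and assigns jobs greedily to the least-loaded worker, which then inverts its factors and broadcasts the results. This is a hypothetical illustration, not the method of Osawa et al. or Pauloski et al. The cost model n^3 (for A) and m^3 (for G) and the layer sizes below are assumptions.

```python
import heapq


def shard_inversions(layer_shapes, num_workers):
    """Greedily assign every factor inversion to the currently least-loaded worker.

    layer_shapes: list of (n_inputs, m_outputs). Each layer contributes two jobs,
    inverting A (cost ~ n^3) and G (cost ~ m^3); the cost model is an assumption.
    Returns (assignment: {(layer, factor): worker}, loads: per-worker total cost).
    """
    jobs = []
    for layer, (n, m) in enumerate(layer_shapes):
        jobs.append((n ** 3, layer, "A"))
        jobs.append((m ** 3, layer, "G"))
    jobs.sort(reverse=True)                       # largest jobs first (LPT heuristic)

    heap = [(0, w) for w in range(num_workers)]   # (current load, worker id)
    heapq.heapify(heap)
    assignment, loads = {}, [0] * num_workers
    for cost, layer, factor in jobs:
        load, worker = heapq.heappop(heap)        # least-loaded worker
        assignment[(layer, factor)] = worker
        loads[worker] = load + cost
        heapq.heappush(heap, (load + cost, worker))
    return assignment, loads


# Illustrative sizes (assumed): conv layers flattened to (k*k*C_in, C_out) plus a classifier.
shapes = [(3 * 3 * 64, 64)] * 3 + [(3 * 3 * 128, 128)] * 3 + [(2048, 1000)]
assignment, loads = shard_inversions(shapes, num_workers=4)
print(loads)
```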
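Extensions of K-FAC:

- EKFAC (George et al., 2018) keeps the K-FAC eigenbasis but re-estimates the diagonal scaling in that basis directly from gradients, instead of using products of factor eigenvalues, which gives a better fit to the Fisher.
- KFRA (Botev et al., 2017) approximates the Gauss-Newton blocks with a recursive backward pass that propagates curvature through the nonlinear activations.
- Shampoo (Gupta et al., 2018) builds a Kronecker-factored preconditioner with one factor per tensor dimension, generalizing the idea to tensor-shaped parameters.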
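To make the EKFAC idea concrete, here is a toy comparison on random stand-in data in which the output gradients are deliberately correlated with the inputs. Both methods share the eigenbasis Q_A ⊗ Q_G. EKFAC replaces the eigenvalue products λ_i^A · λ_j^G with the measured second moments of the projected gradients. Because the basis is orthogonal, that choice is the diagonal closest to the empirical Fisher in Frobenius norm within this basis. Its error therefore cannot exceed K-FAC's.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, batch = 4, 3, 2000                     # toy sizes (assumed)

# Make gradient magnitudes depend on the inputs, so the K-FAC
# independence assumption between a and g does not hold exactly.
a = rng.standard_normal((batch, n))
g = rng.standard_normal((batch, m)) * (1.0 + np.abs(a[:, :1]))

dW = np.einsum("bj,bi->bji", g, a)          # per-example gradients g a^T, batch x m x n
A, G = a.T @ a / batch, g.T @ g / batch
lam_A, Q_A = np.linalg.eigh(A)
lam_G, Q_G = np.linalg.eigh(G)

# K-FAC scaling in the Kronecker eigenbasis: products of factor eigenvalues.
s_kfac = np.outer(lam_G, lam_A)             # entry (j, i) = lam_G[j] * lam_A[i]
# EKFAC scaling: second moment of each gradient projected onto the same basis.
proj = np.einsum("jp,bji,iq->bpq", Q_G, dW, Q_A)   # Q_G^T dW Q_A per example
s_ekfac = np.mean(proj ** 2, axis=0)

# Exact empirical Fisher block, rotated into the eigenbasis Q = Q_A ⊗ Q_G.
vec = dW.transpose(0, 2, 1).reshape(batch, n * m)  # column-stacked vec(dW)
F = vec.T @ vec / batch
Q = np.kron(Q_A, Q_G)
F_rot = Q.T @ F @ Q


def fit_error(scaling):
    """Frobenius distance between F and Q diag(scaling) Q^T, measured in the rotated basis."""
    return np.linalg.norm(F_rot - np.diag(scaling.T.reshape(-1)))


print(f"K-FAC error {fit_error(s_kfac):.3f}, EKFAC error {fit_error(s_ekfac):.3f}")
```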
What is the advantage of eigendecomposing the factors A_l and G_l in K-FAC?
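In the eigenbasis the Kronecker product is diagonal. With A_l = Q_A Λ_A Q_A^T and G_l = Q_G Λ_G Q_G^T, we have A_l ⊗ G_l = (Q_A ⊗ Q_G)(Λ_A ⊗ Λ_G)(Q_A ⊗ Q_G)^T, and Λ_A ⊗ Λ_G is diagonal with entries λ_i^A · λ_j^G for i = 1..n, j = 1..m. Inverting the damped block A_l ⊗ G_l + λ_damp I then reduces to dividing each entry by λ_i^A · λ_j^G + λ_damp. That division costs O(nm) instead of O((nm)^3), and the eigendecompositions themselves cost O(n^3 + m^3). Unlike adding damping to A_l and G_l separately, this applies λ_damp I exactly.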
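Below is a minimal numpy sketch of this computation, assuming symmetric positive semi-definite factors A and G; the function name is hypothetical. It rotates the m x n gradient into the eigenbasis, divides elementwise, rotates back, and checks the result against a dense solve with the (nm) x (nm) damped matrix.

```python
import numpy as np


def damped_kfac_solve(dW, A, G, damping):
    """Return (A ⊗ G + damping * I)^{-1} vec(dW), reshaped back to m x n.

    Eigendecompositions cost O(n^3 + m^3); the division in the eigenbasis is O(nm).
    """
    lam_A, Q_A = np.linalg.eigh(A)               # A = Q_A diag(lam_A) Q_A^T
    lam_G, Q_G = np.linalg.eigh(G)               # G = Q_G diag(lam_G) Q_G^T
    V = Q_G.T @ dW @ Q_A                          # rotate gradient into the eigenbasis
    V = V / (np.outer(lam_G, lam_A) + damping)    # entry (j, i) / (λ_j^G λ_i^A + damping)
    return Q_G @ V @ Q_A.T                        # rotate back


rng = np.random.default_rng(2)
n, m, damping = 6, 4, 0.1
xa, xg = rng.standard_normal((40, n)), rng.standard_normal((40, m))
A, G = xa.T @ xa / 40, xg.T @ xg / 40
dW = rng.standard_normal((m, n))

fast = damped_kfac_solve(dW, A, G, damping)
full = np.kron(A, G) + damping * np.eye(n * m)   # dense (nm x nm) reference
dense = np.linalg.solve(full, dW.flatten(order="F")).reshape((m, n), order="F")
print(np.allclose(fast, dense))
```

In practice the eigendecompositions would be refreshed on the same schedule as the inverses, so the per-step work is only the two rotations and the O(nm) division.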
Key takeaways
- K-FAC approximates a layer's Fisher block as A_l ⊗ G_l.
- Inversion costs O(n^3 + m^3) instead of O((nm)^3).
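- Statistics are refreshed as an exponential moving average every T steps; inverses are recomputed every T' steps.
- In practice K-FAC delivers a 2-10x reduction in iterations at a per-step cost comparable to SGD, because the expensive statistics updates and inversions are amortized over T and T' steps.
- After an O(n^3 + m^3) eigendecomposition of A and G, (A ⊗ G + λI) can be inverted exactly in O(nm) within the eigenbasis.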