Differential Geometry in Machine Learning
Why does Adam work so well across very different architectures? Why do hyperbolic GCNs often beat Euclidean graph neural networks on hierarchical graphs such as citation networks? Much of the answer lies in differential geometry, the mathematics of curved spaces that modern ML quietly builds on.
- **Riemannian optimization**: training covariance matrices in brain-computer interfaces, rotation-equivariant networks in robotics, federated learning on the manifold of model weights
- **Hyperbolic embeddings**: on WordNet, 5-dimensional Poincaré embeddings were reported to reconstruct the hierarchy better than 200-dimensional Euclidean embeddings; HGCN outperforms flat GCN on biological networks
- **Information geometry**: K-FAC (natural gradient with a Kronecker-factored Fisher approximation) reportedly trains ImageNet about 3x faster than SGD; understanding why Adam works helps explain how to tune it for new architectures
Prerequisites
Riemannian Optimization
**Riemannian optimization** generalizes gradient descent to curved manifolds. Standard SGD assumes flat Euclidean space, but on a manifold M the update must stay on M. The fix: use the **exponential map** to move along the geodesic in the negative gradient direction.
**Riemannian gradient**: for a manifold embedded in Euclidean space with the induced metric, project the Euclidean gradient onto the tangent space TₓM at the current point x. Then move via expₓ(-α · grad_M f(x)). The result lies on M exactly, up to floating-point error.
Why does this matter? Rotation matrices SO(n), symmetric positive definite (covariance) matrices Sym⁺(n), and unit spheres Sⁿ are common constraint sets in ML. Euclidean SGD steps off the constraint set and must be projected back after every update, which can be costly (for matrix manifolds it typically means a matrix decomposition) and ignores the geometry of the set. Riemannian SGD moves intrinsically along the manifold.
| Manifold | Application | geoopt class |
|---|---|---|
| SO(n) / Stiefel (orthonormal frames) | Rotation learning, PCA | geoopt.Stiefel |
| Sym⁺(n) | Covariance, attention | geoopt.SymmetricPositiveDefinite |
| Sⁿ | Word embeddings, unit vectors | geoopt.Sphere |
| H^n | Hierarchical data | geoopt.PoincareBall |
Why is Riemannian gradient descent preferred over projected gradient descent on manifolds?
Hyperbolic Embeddings
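Hyperbolic space H^n has constant negative curvature K < 0. Its key property: volume grows **exponentially** with radius, matching the exponential growth of node counts in trees and hierarchical graphs. Euclidean volume grows only polynomially, so a tree of depth d, which can have on the order of 2^d nodes, quickly runs out of room in a low-dimensional Euclidean space and has to be embedded with distortion or in many more dimensions. In H², trees embed with arbitrarily low distortion, so dimension 2 suffices.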
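The growth-rate comparison can be made explicit with the standard formulas for the circumference C(r) and area A(r) of a disk of radius r, here for curvature K = -1. Both hyperbolic quantities grow like e^r, the same kind of growth as the number of nodes at depth d in a binary tree:

```latex
\begin{aligned}
\text{Euclidean plane:} \quad & C(r) = 2\pi r, & A(r) &= \pi r^{2} \\
\text{Hyperbolic plane } (K=-1): \quad & C(r) = 2\pi \sinh r \sim \pi e^{r}, & A(r) &= 2\pi(\cosh r - 1) \sim \pi e^{r} \\
\text{Binary tree:} \quad & \#\{\text{nodes at depth } d\} = 2^{d} = e^{d \ln 2}
\end{aligned}
```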
**Poincaré ball model**: H^n ≅ {x ∈ ℝⁿ : ||x|| < 1} with metric g = 4/(1-||x||²)² · g_Euclidean. Distance formula: d(x,y) = arcosh(1 + 2||x-y||²/((1-||x||²)(1-||y||²))). Points near the boundary represent nodes deep in the hierarchy.
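The distance formula is short enough to evaluate directly. The sketch below uses only the standard library and assumes curvature -1 (the unit ball); `poincare_distance` is an illustrative name, not a library function. It also shows the tree-like behavior near the boundary: two points close to the rim in different directions are almost as far apart as the path through the origin, even though their Euclidean distance is small.

```python
import math

def poincare_distance(x, y):
    """Geodesic distance between two points of the unit Poincare ball (curvature -1)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    sq_norm_x = sum(a * a for a in x)
    sq_norm_y = sum(b * b for b in y)
    if sq_norm_x >= 1.0 or sq_norm_y >= 1.0:
        raise ValueError("points must lie strictly inside the unit ball")
    arg = 1.0 + 2.0 * sq_dist / ((1.0 - sq_norm_x) * (1.0 - sq_norm_y))
    return math.acosh(arg)

origin = (0.0, 0.0)
d_mid = poincare_distance(origin, (0.5, 0.0))          # equals ln 3, about 1.0986
d_edge = poincare_distance(origin, (0.99, 0.0))        # far larger than 0.99
d_siblings = poincare_distance((0.99, 0.0), (0.0, 0.99))

print(d_mid, d_edge, d_siblings)
```

Comparing `d_siblings` with `2 * d_edge` shows that the shortest path between the two rim points is only slightly shorter than the detour through the origin, which is how distances behave between leaves of a tree.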
HGCN (Hyperbolic Graph Convolutional Network) is reported to achieve state-of-the-art results on citation networks with 10x fewer parameters than Euclidean GCN. The intuition: academic citation graphs are approximately trees with occasional cross-links, which is the kind of structure that hyperbolic space encodes efficiently.
Numerical stability: with curvature parameter c > 0 (sectional curvature -c), points must satisfy ||x|| < 1/√c, which is ||x|| < 1 for c = 1. Rounding can push a point onto or past the boundary, so call `ball.projx(x)` after each update. The curvature can also be learned: construct the ball with `geoopt.PoincareBall(c=1.0, learnable=True)`.
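To see what such a projection does, here is a minimal NumPy sketch of norm clipping onto the open ball. It is an illustration of the idea only: `project_to_ball` and its `eps` margin are hypothetical, and geoopt's own `projx` may differ in details such as the margin it uses.

```python
import numpy as np

def project_to_ball(x, c=1.0, eps=1e-5):
    """Hypothetical helper: clip x so its norm is at most (1 - eps) / sqrt(c)."""
    max_norm = (1.0 - eps) / np.sqrt(c)
    norm = np.linalg.norm(x)
    if norm > max_norm:
        return x * (max_norm / norm)   # rescale radially, keep the direction
    return x

x = np.array([0.9, 0.3])
x_after_step = x + np.array([0.2, 0.1])   # a Euclidean-style step that leaves the ball
x_safe = project_to_ball(x_after_step)

print(np.linalg.norm(x_after_step), np.linalg.norm(x_safe))
```

Points already inside the ball are returned unchanged, so the projection is cheap to apply after every update.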
Why does hyperbolic space embed trees more efficiently than Euclidean space?
Information Geometry
**Information geometry** treats the space of probability distributions as a Riemannian manifold. The metric is the **Fisher information matrix**: F_ij(θ) = E_{p(x|θ)}[∂log p(x|θ)/∂θᵢ · ∂log p(x|θ)/∂θⱼ]. This measures how fast the distribution changes with parameters.
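For a one-parameter family the Fisher "matrix" is a single number, which makes the definition easy to check by hand. The sketch below evaluates the expectation exactly for a Bernoulli(θ) distribution; the function name `bernoulli_fisher` is illustrative. The known closed form is 1/(θ(1-θ)), so the metric blows up near θ = 0 and θ = 1, where a small change in θ changes the distribution a lot.

```python
def bernoulli_fisher(theta):
    """Fisher information of Bernoulli(theta), computed from the definition E[score^2]."""
    fisher = 0.0
    for x, prob in ((1, theta), (0, 1.0 - theta)):
        # score = d/dtheta log p(x | theta), with log p = x log(theta) + (1 - x) log(1 - theta)
        score = x / theta - (1 - x) / (1.0 - theta)
        fisher += prob * score ** 2
    return fisher

for theta in (0.5, 0.1, 0.01):
    print(theta, bernoulli_fisher(theta), 1.0 / (theta * (1.0 - theta)))
```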
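**Natural gradient**: instead of θ ← θ - α·∇f(θ) (steepest descent in parameter space), use θ ← θ - α·F(θ)⁻¹·∇f(θ) (steepest descent in distribution space, where step length is measured by KL divergence). The resulting direction does not depend on how the model is parametrized; for finite step sizes this invariance holds to first order.
The **statistical manifold** of univariate Gaussians N(μ, σ²), equipped with the Fisher metric, is the upper half-plane {(μ, σ) : σ > 0} with a hyperbolic metric of constant negative curvature (it matches the standard hyperbolic half-plane after rescaling μ). The Fisher-Rao distance between two Gaussians is their geodesic distance in this space. This is one reason hyperbolic geometry appears naturally in probabilistic models.
Connection to transformers: the attention softmax maps each query's scores to a point on the probability simplex, which is itself a statistical manifold. The 1/√d scaling in scaled dot-product attention keeps the variance of the scores roughly independent of d, so the softmax output stays away from the vertices of the simplex, where it saturates and gradients vanish. Information-geometric analyses describe this in terms of where on the distribution manifold the attention weights live.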
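A small worked case makes the update concrete. The sketch below fits a univariate Gaussian N(μ, σ²) to data summarized by its mean and variance, using the average negative log-likelihood as the loss and the known Fisher matrix diag(1/σ², 2/σ²) in (μ, σ) coordinates. The function name and the default step size are illustrative assumptions.

```python
def fit_gaussian_natural_gradient(data_mean, data_var, mu=0.0, sigma=5.0, lr=0.5, steps=60):
    """Natural-gradient descent on the average Gaussian negative log-likelihood."""
    for _ in range(steps):
        # E[(x - mu)^2] under the data distribution
        second_moment = data_var + (data_mean - mu) ** 2
        # Ordinary gradient of log(sigma) + second_moment / (2 sigma^2)
        grad_mu = (mu - data_mean) / sigma ** 2
        grad_sigma = 1.0 / sigma - second_moment / sigma ** 3
        # Multiply by F^{-1} = diag(sigma^2, sigma^2 / 2)
        nat_mu = sigma ** 2 * grad_mu
        nat_sigma = 0.5 * sigma ** 2 * grad_sigma
        mu -= lr * nat_mu
        sigma -= lr * nat_sigma
    return mu, sigma

mu_hat, sigma_hat = fit_gaussian_natural_gradient(data_mean=3.0, data_var=4.0)
print(mu_hat, sigma_hat)
```

After preconditioning, the step in μ is simply lr · (μ - data_mean), independent of σ. The plain gradient step is scaled by 1/σ², so it would crawl when σ is large and overshoot when σ is small.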
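Written out for the univariate Gaussian in (μ, σ) coordinates, the Fisher matrix and the resulting line element are given below. Substituting x = μ/√2, y = σ turns the line element into twice the standard half-plane metric (dx² + dy²)/y², so the curvature is -1/2 and the half-plane distance formula carries over with a factor of √2:

```latex
\begin{aligned}
F(\mu,\sigma) &= \begin{pmatrix} 1/\sigma^{2} & 0 \\ 0 & 2/\sigma^{2} \end{pmatrix},
\qquad ds^{2} = \frac{d\mu^{2} + 2\,d\sigma^{2}}{\sigma^{2}} \\
d_{F}\big(\mathcal N(\mu_1,\sigma_1^{2}),\, \mathcal N(\mu_2,\sigma_2^{2})\big)
&= \sqrt{2}\,\operatorname{arcosh}\!\left(1 + \frac{\tfrac12(\mu_1-\mu_2)^{2} + (\sigma_1-\sigma_2)^{2}}{2\,\sigma_1\sigma_2}\right)
\end{aligned}
```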
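The effect of the 1/√d factor on where attention weights land in the simplex can be checked numerically. The sketch below is an illustration with random Gaussian queries and keys, an assumption made for the demo rather than a model of trained attention. It compares the entropy of the softmax output with and without scaling: low entropy means the distribution sits near a vertex of the simplex.

```python
import numpy as np

def softmax(z):
    z = z - z.max()            # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def entropy(p):
    p = p[p > 0]               # skip exact zeros to avoid log(0)
    return float(-(p * np.log(p)).sum())

rng = np.random.default_rng(0)
d, n_keys = 256, 64
q = rng.standard_normal(d)
K = rng.standard_normal((n_keys, d))

scores = K @ q                                   # standard deviation grows like sqrt(d)
h_unscaled = entropy(softmax(scores))
h_scaled = entropy(softmax(scores / np.sqrt(d))) # standard deviation close to 1

print(h_unscaled, h_scaled, np.log(n_keys))
```

Dividing the scores by a constant larger than 1 acts like raising the softmax temperature, which can only increase the entropy, so the scaled weights are always at least as spread out as the unscaled ones.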
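**Misconception**: Adam is just a heuristic trick with momentum and adaptive learning rates.
**Information-geometric view**: Adam can be read as a rough diagonal approximation to the natural gradient. The running average v_t of squared gradients estimates the diagonal of the empirical Fisher matrix, and the update m_t/(√v_t + ε) rescales each coordinate by the inverse square root of that estimate, so it behaves like F^(-1/2)∇f on the diagonal rather than the full F⁻¹∇f.
**Why it matters**: this interpretation offers one explanation for why Adam transfers across architectures and why the default hyperparameters (β₁, β₂) are robust.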
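The standard Adam update is short enough to write out, which makes the role of v_t visible. The sketch below is a plain NumPy version of a single step with the usual default hyperparameters; `adam_step` is an illustrative name, not a library call. Note that the step divides by √v_t, so the per-coordinate preconditioner is an inverse square root of the squared-gradient estimate.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1.0 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1.0 - beta2) * grad ** 2    # running mean of squared gradients
    m_hat = m / (1.0 - beta1 ** t)               # bias-corrected first moment
    v_hat = v / (1.0 - beta2 ** t)               # bias-corrected second moment
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.zeros(2)
grad = np.array([100.0, -0.01])                  # coordinates on very different scales
theta_new, m, v = adam_step(theta, grad, np.zeros(2), np.zeros(2), t=1)
print(theta_new)                                 # approximately [-0.001, 0.001]
```

Both coordinates move by about lr despite a 10,000x difference in gradient magnitude. This scale invariance per coordinate is what the diagonal preconditioning buys.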
What does the natural gradient θ ← θ - α·F(θ)⁻¹·∇f amount to geometrically?
Key Ideas
- **Riemannian optimization**: use expₓ(-α · grad_M f) to update parameters on a manifold; the iterate stays on M exactly, with no projection needed. geoopt implements this for orthonormal matrices, Sym⁺(n), and H^n
- **Hyperbolic embeddings**: H^n has exponential volume growth matching tree structure; the 2D Poincaré ball embeds deep hierarchies with low distortion, while low-dimensional Euclidean space runs out of room as the number of nodes grows like 2^d
- **Information geometry**: probability distributions form a Riemannian manifold with the Fisher metric; natural gradient = steepest descent measured by KL divergence; Adam resembles a diagonal, square-root-scaled approximation to it
Related Topics
Differential geometry in ML connects to the entire DG curriculum:
- Geodesics and Exponential Map: the foundation of Riemannian optimization; the exp map is the update step
- Curvature (Riemann Tensor): constant negative curvature K < 0 defines hyperbolic space; the curvature of the Fisher metric helps describe the optimization landscape
- Connections and Parallel Transport: parallel transport moves tangent vectors such as momentum between tangent spaces, which Riemannian optimizers with momentum rely on
Questions for Reflection
- Transformer attention maps tokens to a probability simplex. What does information geometry say about the geometry of that simplex, and how might it guide attention design?
- Riemannian optimization requires computing the exponential map, which can be expensive. When is Euclidean projected gradient descent the better choice instead, despite its theoretical inferiority?
- The Fisher information matrix F is often intractable to compute exactly. KFAC, diagonal approximations, and empirical Fisher are common workarounds. What geometric properties do they sacrifice?