Optimal Transport
Wasserstein Distances
WGAN (2017) replaced the Jensen-Shannon divergence with W₁, which markedly stabilized GAN training and reduced mode collapse. Entropic regularization solved by Sinkhorn iterations (Cuturi 2013) made OT fast, differentiable, and GPU-friendly, opening the era of differentiable OT in deep learning.
- **WGAN-GP (Gulrajani 2017):** a gradient penalty that softly enforces the Lipschitz constraint. A widely used GAN objective, including for high-resolution image generation, before the diffusion era.
- **POT library:** Python Optimal Transport; SciPy-style API for exact (LP) OT, Sinkhorn, and sliced OT. 1M+ downloads/month.
- **GeomLoss (Feydy 2019):** GPU-accelerated Sinkhorn divergences for point clouds. Used in 3D shape matching.
The Kantorovich Problem: W_p Distance Definition
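**Wasserstein GAN (WGAN, 2017) uses the W₁ distance instead of the JS divergence: W₁ still gives useful gradients when the real and generated distributions have disjoint supports (where JS saturates at log 2), which reduces mode collapse and stabilizes training.** The core question: how do we formally measure the distance between two probability measures μ and ν? Kantorovich's answer is the minimal cost of moving the mass of μ onto ν:
W_p(μ, ν) = ( inf_{γ∈Π(μ,ν)} ∫∫ |x−y|^p dγ(x, y) )^{1/p},
where Π(μ, ν) is the set of transport plans, i.e. joint distributions (couplings) whose marginals are μ and ν.
W_p is a metric on the space of probability measures with finite p-th moment. For p = 2 this space additionally carries a formal Riemannian structure (Otto calculus) despite being infinite-dimensional; that geometry underlies Wasserstein geodesics and gradient flows.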
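For discrete measures μ = Σ a_i δ_{x_i} and ν = Σ b_j δ_{y_j}, the infimum above is a finite linear program over the entries γ_ij of the plan. The sketch below solves it directly with `scipy.optimize.linprog` (SciPy and the function name `kantorovich_wp` are our choices, not something the lesson prescribes); it is meant for tiny examples, since the LP has n·m variables.

```python
import numpy as np
from scipy.optimize import linprog


def kantorovich_wp(x, a, y, b, p=1):
    """Exact W_p between sum_i a_i δ_{x_i} and sum_j b_j δ_{y_j} (hypothetical helper).

    Solves the Kantorovich LP: minimize sum_ij C_ij γ_ij over γ >= 0
    with row sums a and column sums b, where C_ij = |x_i - y_j|^p.
    """
    x = np.asarray(x, dtype=float).reshape(len(x), -1)   # (n, d); 1-D input becomes (n, 1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)   # (m, d)
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    n, m = len(a), len(b)

    C = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1) ** p   # cost matrix (n, m)

    # γ is flattened row-major: γ_ij -> index i*m + j
    rows = np.kron(np.eye(n), np.ones((1, m)))   # sum_j γ_ij = a_i
    cols = np.kron(np.ones((1, n)), np.eye(m))   # sum_i γ_ij = b_j
    # drop one column constraint: it is implied by the rest when sum(a) == sum(b)
    A_eq = np.vstack([rows, cols[:-1]])
    b_eq = np.concatenate([a, b[:-1]])

    res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    gamma = res.x.reshape(n, m)
    return max(res.fun, 0.0) ** (1.0 / p), gamma   # clamp tiny negative round-off


# δ₀ vs δ₁: the whole unit of mass travels distance 1
w, plan = kantorovich_wp([0.0], [1.0], [1.0], [1.0], p=1)
print(w, plan)
```

Exact LP solvers scale poorly with the number of support points, which is what motivates the duality and entropic approaches later in the lesson.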
What is W₁(δ₀, δ₁), the distance between two point masses at 0 and 1?
W₁(δ₀, δ₁) = |0−1| = 1: all unit mass must be moved a distance of 1.
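On the real line the answer can be checked without any LP: for the cost |x−y|^p with p ≥ 1, matching sorted samples (the monotone coupling) is optimal. A minimal sketch, assuming two equally weighted empirical measures of the same size (`wasserstein_1d` is a hypothetical helper name):

```python
import numpy as np


def wasserstein_1d(x, y, p=1):
    """W_p between (1/n) sum δ_{x_i} and (1/n) sum δ_{y_i} on ℝ.

    In 1-D the monotone coupling is optimal for |x - y|^p, p >= 1,
    so we pair the i-th smallest x with the i-th smallest y.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes the same number of equally weighted samples")
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))


# δ₀ vs δ₁: the single unit of mass moves distance 1, for every p
print(wasserstein_1d([0.0], [1.0], p=1), wasserstein_1d([0.0], [1.0], p=2))  # 1.0 1.0
```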
Kantorovich Duality: 1-Lipschitz Functions
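For discrete measures the primal Kantorovich problem is a linear program that becomes expensive as the number of support points grows, and continuous measures must be discretized before it can be solved at all. **Kantorovich-Rubinstein duality** reformulates W₁ as a supremum over 1-Lipschitz functions, W₁(μ, ν) = sup_{‖f‖_L≤1} [∫f dμ − ∫f dν], which is exactly the objective WGAN's critic is trained to maximize.
WGAN-GP (Gulrajani 2017) adds the penalty λ·E[(‖∇f(x̂)‖₂ − 1)²] at interpolations x̂ = αx_real + (1−α)x_fake with α ~ U[0, 1] (the paper uses λ = 10). This soft Lipschitz constraint is more stable than weight clipping.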
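To see the duality concretely, the dual can also be written as an LP for discrete measures on ℝ: the unknowns are the critic values f(z_k) on the pooled support, and "1-Lipschitz" becomes the pairwise constraints f_k − f_l ≤ |z_k − z_l|. A minimal sketch using `scipy.optimize.linprog` (our choice of solver; `w1_dual_1d` is a hypothetical name). Its optimal value should match the primal W₁.

```python
import numpy as np
from scipy.optimize import linprog


def w1_dual_1d(x, a, y, b):
    """W₁ between sum_i a_i δ_{x_i} and sum_j b_j δ_{y_j} on ℝ via the K-R dual LP."""
    z = np.concatenate([np.asarray(x, float), np.asarray(y, float)])          # pooled support
    signed = np.concatenate([np.asarray(a, float), -np.asarray(b, float)])    # μ − ν on z
    N = len(z)

    # 1-Lipschitz on the support: f_k − f_l <= |z_k − z_l| for every ordered pair (k, l)
    A_ub, b_ub = [], []
    for k in range(N):
        for l in range(N):
            if k != l:
                row = np.zeros(N)
                row[k], row[l] = 1.0, -1.0
                A_ub.append(row)
                b_ub.append(abs(z[k] - z[l]))

    # μ − ν has zero total mass, so f is defined only up to a constant: pin f(z_0) = 0
    bounds = [(0.0, 0.0)] + [(None, None)] * (N - 1)
    # linprog minimizes, so negate to maximize ∫f dμ − ∫f dν = signed · f
    res = linprog(-signed, A_ub=np.array(A_ub), b_ub=np.array(b_ub),
                  bounds=bounds, method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return -res.fun, res.x   # (W₁, optimal critic values on z)


w, f = w1_dual_1d([0.0], [1.0], [1.0], [1.0])   # δ₀ vs δ₁
print(w, f)
```

The optimal f for δ₀ vs δ₁ drops by exactly 1 between 0 and 1: the critic is "as steep as the Lipschitz constraint allows" along the direction mass has to travel, which is the same intuition behind WGAN-GP's unit-gradient-norm target.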
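The penalty term itself is easy to write down. In a real WGAN-GP the input gradient ∇f(x̂) comes from autodiff (in PyTorch, for example, `torch.autograd.grad` with `create_graph=True` so the penalty can be backpropagated). To stay NumPy-only, this sketch assumes a hypothetical toy critic f(x) = w2 · tanh(W1 x + b1) whose gradient is derived by hand; it only computes the penalty and does not train anything.

```python
import numpy as np


def mlp_critic_grad(x, W1, b1, w2):
    """Input gradient ∇_x f(x) of the toy critic f(x) = w2 · tanh(W1 x + b1), one row per sample."""
    h = np.tanh(x @ W1.T + b1)          # hidden activations, shape (batch, hidden)
    return (w2 * (1.0 - h ** 2)) @ W1   # chain rule through tanh, shape (batch, dim)


def gradient_penalty(grad_f, x_real, x_fake, lam=10.0, rng=None):
    """WGAN-GP term lam * E[(||∇f(x̂)||₂ − 1)²] at random interpolates x̂."""
    rng = np.random.default_rng() if rng is None else rng
    alpha = rng.uniform(size=(x_real.shape[0], 1))     # one α ~ U[0, 1] per sample pair
    x_hat = alpha * x_real + (1.0 - alpha) * x_fake    # points on real-fake segments
    grad_norms = np.linalg.norm(grad_f(x_hat), axis=1)
    return lam * np.mean((grad_norms - 1.0) ** 2)


# Usage with random (untrained) critic weights and toy 2-D data.
rng = np.random.default_rng(0)
W1, b1, w2 = rng.normal(size=(16, 2)), rng.normal(size=16), rng.normal(size=16)
x_real, x_fake = rng.normal(size=(64, 2)), rng.normal(loc=3.0, size=(64, 2))
gp = gradient_penalty(lambda x: mlp_critic_grad(x, W1, b1, w2), x_real, x_fake, rng=rng)
print(f"gradient penalty: {gp:.4f}")
```

The penalty is zero exactly when the critic has unit gradient norm at every sampled x̂, which is why it acts as a soft rather than hard Lipschitz constraint.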
What constraint does WGAN impose on the critic?
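By Kantorovich-Rubinstein duality the supremum is over 1-Lipschitz functions, so the critic must be Lipschitz (a K-Lipschitz critic only rescales W₁ by K). The original WGAN enforces this by clipping the critic's weights to a small box [−c, c]; WGAN-GP replaces clipping with the gradient penalty.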
Sinkhorn Algorithm: Fast OT Computation
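Exact solvers for the Kantorovich LP scale roughly as O(n³) (O(n³ log n) for network simplex). **Entropic regularization** (Cuturi 2013) instead minimizes ⟨C, γ⟩ − εH(γ), with entropy H(γ) = −Σ γ_ij log γ_ij. The optimal plan then has the form γ* = diag(u) K diag(v) with Gibbs kernel K = exp(−C/ε), and the scalings u, v are found by Sinkhorn's alternating matrix-scaling iterations: O(n²) per iteration (two matrix-vector products), fully GPU-parallelizable.
As ε→0 the regularized solution converges to an exact optimal plan. In practice, with costs normalized to order 1, ε = 0.05–0.5 gives a good approximation in 50–200 iterations; smaller ε needs more iterations and log-domain updates to avoid underflow in exp(−C/ε). Sinkhorn is implemented in POT, GeomLoss, and OTT-JAX (GPU-native).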
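The iteration described above fits in a few lines of NumPy. This is a minimal sketch (the name `sinkhorn` and the toy data are ours, not POT's or GeomLoss's API) that implements exactly u ← a/(Kv), v ← b/(Kᵀu) and then assembles γ = diag(u) K diag(v):

```python
import numpy as np


def sinkhorn(a, b, C, eps=0.1, n_iter=200):
    """Entropic OT: approximately minimize <C, γ> − eps·H(γ) s.t. γ1 = a, γᵀ1 = b.

    Returns the plan γ = diag(u) K diag(v) and its transport cost <C, γ>.
    """
    K = np.exp(-C / eps)            # Gibbs kernel; underflows if C/eps is very large
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)             # rescale rows to match marginal a
        v = b / (K.T @ u)           # rescale columns to match marginal b
    gamma = u[:, None] * K * v[None, :]
    return gamma, float(np.sum(gamma * C))


# Toy example: 3 points each, squared cost, uniform weights.
x = np.array([0.0, 1.0, 2.0])
y = x + 0.3                         # exact OT shifts every point by 0.3, so W₂² = 0.09
C = (x[:, None] - y[None, :]) ** 2
a = b = np.full(3, 1.0 / 3.0)
gamma, cost = sinkhorn(a, b, C, eps=0.05)
print(np.round(gamma, 3), cost)
```

Because the loop ends with the column update, the column marginals of γ are matched to rounding error while the row marginals are only approximately matched; more iterations shrink that gap.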
What does ε→0 give in the Sinkhorn algorithm?
As ε→0 the entropic penalty vanishes and the regularized solution converges to a Kantorovich optimal transport plan (the one of maximal entropy when several optimal plans exist).
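The ε→0 limit can be observed numerically, but only if the iteration is run in the log domain, since exp(−C/ε) underflows for small ε. The sketch below rewrites the same Sinkhorn updates for the potentials f = ε·log u and g = ε·log v using `scipy.special.logsumexp` (function and variable names are illustrative) and sweeps ε on a toy problem whose exact OT cost is 0.09.

```python
import numpy as np
from scipy.special import logsumexp


def sinkhorn_log(a, b, C, eps, n_iter=500):
    """Log-domain Sinkhorn: the updates u ← a/(Kv), v ← b/(Kᵀu) written for
    f = eps·log u and g = eps·log v, so exp(−C/eps) is never formed explicitly."""
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
    gamma = np.exp((f[:, None] + g[None, :] - C) / eps)   # γ = diag(u) K diag(v)
    return gamma, float(np.sum(gamma * C))


x = np.array([0.0, 1.0, 2.0])
y = x + 0.3                                  # exact OT cost (W₂²) is 0.09
C = (x[:, None] - y[None, :]) ** 2
a = b = np.full(3, 1.0 / 3.0)
costs = {eps: sinkhorn_log(a, b, C, eps)[1] for eps in (1.0, 0.1, 0.01)}
for eps, c in costs.items():
    print(f"eps={eps:<5} transport cost={c:.6f}")
```

The transport cost ⟨C, γ_ε⟩ of the entropic plan is non-increasing as ε decreases, so the printed costs should fall toward 0.09: large ε blurs the plan (fast, inaccurate), small ε sharpens it (accurate, but slower to converge in general).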
Key Ideas
- **W_p distance:** W_p^p(μ,ν) is the infimum of the transport cost ∫∫|x−y|^p dγ over all plans γ∈Π(μ,ν).
- **K-R duality:** W₁ = sup_{‖f‖_L≤1} [∫f dμ − ∫f dν]. WGAN trains the critic f to attain this supremum under a (soft) Lipschitz constraint.
- **Sinkhorn:** add −εH(γ); optimum γ* = diag(u) K diag(v) with K = exp(−C/ε). Iterate u←μ/(Kv), v←ν/(Kᵀu). O(n²) per iter.
- **ε→0:** regularized OT → exact OT. Trade-off: small ε = accuracy, large ε = speed.
Related lessons
- ot-17-applications — Continues the OT applications thread
- ot-07-wgan — WGAN uses W1 via duality