Optimal Transport

Wasserstein Distances

WGAN (2017) replaced the JS divergence with W₁ and trained far more stably than earlier GANs, with much less mode collapse. Sinkhorn iterations made entropic OT fast and differentiable on GPUs, opening the era of differentiable OT in deep learning.

  • **WGAN-GP (Gulrajani 2017):** a gradient penalty that enforces the Lipschitz constraint softly. Widely used for high-resolution GANs before the diffusion era.
  • **POT library:** Python Optimal Transport, a scipy-style API for Sinkhorn, exact LP, and sliced OT. 1M+ downloads/month.
  • **GeomLoss (Feydy 2019):** GPU-accelerated Sinkhorn divergences for point clouds. Used in 3D shape matching.

The Kantorovich Problem: W_p Distance Definition

**Wasserstein GAN (WGAN, 2017) uses the W₁ distance instead of the JS divergence, which reduces mode collapse and stabilizes training; combined with a gradient penalty, the same loss was later used to train GANs at 1024×1024 resolution.** The core question: how do we formally measure the distance between two probability measures?
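Here is the standard definition written out, consistent with the Key Ideas summary. Π(μ, ν) denotes the set of couplings, i.e. joint measures whose marginals are μ and ν. The discrete form assumes weights a, b on points x_i, y_j; it is the linear program referred to in the duality and Sinkhorn parts of this lesson.

```latex
W_p(\mu,\nu) = \Big( \inf_{\gamma \in \Pi(\mu,\nu)} \int\!\!\int |x-y|^p \, d\gamma(x,y) \Big)^{1/p}

% discrete measures \mu = \sum_i a_i \delta_{x_i}, \quad \nu = \sum_j b_j \delta_{y_j}:
W_p^p(\mu,\nu) = \min_{\gamma \ge 0,\ \gamma \mathbf{1} = a,\ \gamma^\top \mathbf{1} = b} \; \sum_{i,j} \gamma_{ij} \, |x_i - y_j|^p
```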

W_p is a metric on the space of probability measures with finite p-th moment. For p=2 this space carries a formal, infinite-dimensional Riemannian structure (Otto calculus). That structure is the geometry underlying Wasserstein geodesics and gradient flows.

What is W₁(δ₀, δ₁) , the distance between two point masses at 0 and 1?

W₁(δ₀, δ₁) = |0−1| = 1: the only coupling is δ₀ ⊗ δ₁, so the whole unit of mass moves a distance of 1.
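In one dimension W_p has a simple closed form, which makes this example easy to check numerically. The sketch below assumes two uniform empirical measures with the same number of points, and p ≥ 1. Under those assumptions the optimal plan is the monotone matching of the sorted samples. The helper name `wasserstein_1d` is hypothetical.

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """W_p between two uniform empirical measures on R with equal sample counts (p >= 1).

    For convex costs |x - y|^p the optimal plan matches sorted samples in order, so
    W_p^p = mean_i |x_(i) - y_(i)|^p.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal numbers of samples")
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

# W_1(δ0, δ1): a single point mass at 0 versus one at 1
print(wasserstein_1d([0.0], [1.0]))          # 1.0

# Translating a whole point cloud by 3 costs exactly 3 (here measured with W_2)
pts = np.array([0.0, 0.5, 2.0])
print(wasserstein_1d(pts, pts + 3.0, p=2))   # 3.0
```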

Kantorovich Duality: 1-Lipschitz Functions

Solving the primal Kantorovich problem as an LP is expensive, and it is impractical for continuous measures. **Kantorovich-Rubinstein duality** reformulates W₁ as a supremum over 1-Lipschitz functions, which is exactly the objective WGAN's critic is trained on.
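Written out, the duality reads W₁(μ, ν) = sup over f with Lipschitz constant at most 1 of ∫f dμ − ∫f dν. Every 1-Lipschitz f therefore gives a lower bound on W₁. The sketch below checks this numerically on 1D samples. The three hand-picked critics and the helper names are illustrative assumptions, not anything a WGAN would learn. For a pure translation, f(x) = x already attains the supremum.

```python
import numpy as np

def w1_sorted(x, y):
    """Exact W_1 between equal-size uniform empirical measures on R (sorted matching)."""
    return float(np.mean(np.abs(np.sort(x) - np.sort(y))))

def dual_value(f, x_mu, x_nu):
    """Kantorovich-Rubinstein objective ∫f dμ − ∫f dν for empirical measures."""
    return float(np.mean(f(x_mu)) - np.mean(f(x_nu)))

rng = np.random.default_rng(0)
x_nu = rng.normal(size=1000)
x_mu = x_nu + 2.0                       # μ is ν shifted right by 2, so W_1(μ, ν) = 2

critics = {                             # all three functions are 1-Lipschitz
    "identity": lambda t: t,            # attains the supremum for a pure shift
    "sin": np.sin,
    "clipped": lambda t: np.clip(t, -1.0, 1.0),
}

w1 = w1_sorted(x_mu, x_nu)
for name, f in critics.items():
    print(f"{name:>8}: dual value {dual_value(f, x_mu, x_nu):.4f} <= W_1 = {w1:.4f}")
```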

WGAN-GP (Gulrajani 2017): instead of clipping weights, add the penalty λ·E[(‖∇f(x̂)‖₂ − 1)²], evaluated at random interpolates x̂ = αx_real + (1−α)x_fake with α ~ U[0, 1]. This soft Lipschitz constraint trains more stably than weight clipping.
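The following is a minimal NumPy sketch of the penalty term. It assumes a small hypothetical one-hidden-layer critic f(x) = a·tanh(Wx + b) whose input gradient is written out by hand. In a deep-learning framework the gradient would come from automatic differentiation (for example `torch.autograd.grad` with `create_graph=True`). λ = 10 is the value used in the WGAN-GP paper; everything else here is illustrative.

```python
import numpy as np

def critic(x, W, b, a):
    """Hypothetical critic f(x) = a · tanh(W x + b), evaluated row-wise on x of shape (n, d)."""
    return np.tanh(x @ W.T + b) @ a

def critic_input_grad(x, W, b, a):
    """Analytic ∇_x f(x) = Wᵀ (a ⊙ (1 − tanh²(W x + b))), one gradient row per sample."""
    h = np.tanh(x @ W.T + b)            # (n, hidden)
    return ((1.0 - h**2) * a) @ W       # (n, d)

def gradient_penalty(grad_fn, x_real, x_fake, lam=10.0, rng=None):
    """WGAN-GP term lam * E[(||∇f(x̂)||₂ − 1)²] on random interpolates x̂."""
    rng = np.random.default_rng(rng)
    alpha = rng.uniform(size=(x_real.shape[0], 1))       # one α ~ U[0, 1] per sample
    x_hat = alpha * x_real + (1.0 - alpha) * x_fake
    grad_norms = np.linalg.norm(grad_fn(x_hat), axis=1)
    return float(lam * np.mean((grad_norms - 1.0) ** 2))

rng = np.random.default_rng(0)
d, hidden, n = 2, 16, 64
W, b, a = rng.normal(size=(hidden, d)), rng.normal(size=hidden), rng.normal(size=hidden)
x_real = rng.normal(loc=2.0, size=(n, d))
x_fake = rng.normal(loc=-2.0, size=(n, d))

gp = gradient_penalty(lambda x: critic_input_grad(x, W, b, a), x_real, x_fake, rng=1)
# The critic minimises E[f(fake)] − E[f(real)] + penalty.
critic_loss = critic(x_fake, W, b, a).mean() - critic(x_real, W, b, a).mean() + gp
print(f"gradient penalty: {gp:.3f}, critic loss: {critic_loss:.3f}")
```

A linear critic whose weight vector has unit norm has gradient norm 1 everywhere, so its penalty is zero. That is the behaviour the term rewards.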

What constraint does WGAN impose on the critic?

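By Kantorovich-Rubinstein duality the supremum is over 1-Lipschitz functions. The original WGAN enforces this only approximately, by clipping the critic's weights, which guarantees a K-Lipschitz critic for some constant K. WGAN-GP replaces clipping with a gradient penalty.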
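Weight clipping does not make the critic 1-Lipschitz. It only bounds the Lipschitz constant by some K fixed by the clip value c and the architecture. The sketch below illustrates this for the same kind of hypothetical tanh critic, using the bound ‖a‖₂·‖W‖₂, which is valid because tanh is 1-Lipschitz. The clip value c = 0.01 matches the default in the WGAN paper; the network sizes are assumptions.

```python
import numpy as np

def clip_weights(params, c=0.01):
    """WGAN-style clipping: clamp every parameter entry to [-c, c], in place."""
    for p in params:
        np.clip(p, -c, c, out=p)

def lipschitz_upper_bound(W, a):
    """Bound on Lip(f) for the hypothetical critic f(x) = a · tanh(W x + b).

    tanh is 1-Lipschitz, so |f(x) - f(y)| <= ||a||_2 * ||W||_2 * ||x - y||_2,
    where ||W||_2 is the spectral norm. The bias b does not affect the bound.
    """
    return float(np.linalg.norm(a) * np.linalg.norm(W, ord=2))

rng = np.random.default_rng(0)
hidden, d, c = 16, 2, 0.01
W, b, a = rng.normal(size=(hidden, d)), rng.normal(size=hidden), rng.normal(size=hidden)

print("bound before clipping:", lipschitz_upper_bound(W, a))
clip_weights([W, b, a], c=c)
lip_bound = lipschitz_upper_bound(W, a)
print("bound after clipping: ", lip_bound)   # at most c**2 * hidden * sqrt(d)
```

After clipping, K lies far below 1. The optimised critic objective is therefore at most K·W₁, a rescaled estimate whose scale depends on c and the architecture.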

Sinkhorn Algorithm: Fast OT Computation

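Exact solvers for the discrete Kantorovich LP (e.g. network simplex) cost roughly O(n³ log n) for n points. **Entropic regularization** (Cuturi 2013) adds −εH(γ) to the objective. The regularized optimum then has the closed form γ* = diag(u) K diag(v), with Gibbs kernel K = exp(−C/ε) for cost matrix C. The scalings u and v are found by Sinkhorn iterations: O(n²) per iteration and fully GPU-parallelizable.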
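The scaling iteration can be sketched in a few lines of NumPy. This is a plain, non-log-domain version. It assumes C/ε stays moderate so that exp(−C/ε) does not underflow; libraries such as POT also offer numerically stabilized variants. Function and variable names are illustrative.

```python
import numpy as np

def sinkhorn(mu, nu, C, eps=0.1, n_iter=500):
    """Entropic OT by Sinkhorn scaling: returns the plan gamma and its cost <C, gamma>.

    The regularized optimum has the form gamma = diag(u) K diag(v) with K = exp(-C / eps);
    alternately rescaling u and v enforces the row marginal mu and column marginal nu.
    """
    K = np.exp(-C / eps)                  # Gibbs kernel
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(n_iter):               # each step: two O(n^2) matrix-vector products
        u = mu / (K @ v)
        v = nu / (K.T @ u)
    gamma = u[:, None] * K * v[None, :]   # diag(u) K diag(v)
    return gamma, float(np.sum(gamma * C))

rng = np.random.default_rng(0)
n = 50
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(n, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)   # squared Euclidean costs
mu = np.full(n, 1.0 / n)
nu = np.full(n, 1.0 / n)

gamma, cost = sinkhorn(mu, nu, C, eps=0.1)
print("max row-marginal error:", np.abs(gamma.sum(axis=1) - mu).max())
print("entropic transport cost:", cost)
```

The column marginals are exact after every v-update. The row-marginal error is the usual convergence check.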

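As ε→0 the regularized solution converges to an exact optimal plan. In practice, with costs normalized to order 1, ε between 0.05 and 0.5 gives a good approximation within 50 to 200 iterations. Smaller ε needs more iterations, and log-domain updates to avoid underflow. Sinkhorn is implemented in POT, GeomLoss, and OTT-JAX (GPU-native).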
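To see the ε→0 behaviour concretely, the sketch below compares the Sinkhorn transport cost ⟨C, γ_ε⟩ with the exact OT cost on a 1D problem. In 1D the exact value comes from sorted matching, assuming uniform weights and equal sample counts. The sketch also prints the bound 0 ≤ ⟨C, γ_ε⟩ − OT ≤ ε·log n, which holds for uniform marginals once the iterations have converged. The ε grid and iteration count are assumptions.

```python
import numpy as np

def sinkhorn_cost(mu, nu, C, eps, n_iter=2000):
    """Transport cost <C, gamma_eps> of the Sinkhorn plan gamma_eps = diag(u) K diag(v)."""
    K = np.exp(-C / eps)
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(n_iter):
        u = mu / (K @ v)
        v = nu / (K.T @ u)
    gamma = u[:, None] * K * v[None, :]
    return float(np.sum(gamma * C))

rng = np.random.default_rng(0)
n = 30
x, y = rng.uniform(size=n), rng.uniform(size=n)
C = (x[:, None] - y[None, :]) ** 2             # squared costs, all in [0, 1]
mu = np.full(n, 1.0 / n)
nu = np.full(n, 1.0 / n)

# Exact OT cost (W_2^2) in 1D: monotone matching of the sorted samples.
exact = float(np.mean((np.sort(x) - np.sort(y)) ** 2))

results = {}
for eps in [0.5, 0.1, 0.05, 0.02]:
    results[eps] = sinkhorn_cost(mu, nu, C, eps)
    gap = results[eps] - exact
    print(f"eps={eps:<5} cost={results[eps]:.5f} gap={gap:.5f} bound={eps * np.log(n):.5f}")
```

The gap shrinks as ε decreases, which is the accuracy side of the trade-off. The speed side is that smaller ε needs more iterations to converge.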

What does ε→0 give in the Sinkhorn algorithm?

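As ε→0 the entropic penalty vanishes and the regularized plan converges to a Kantorovich optimal plan. If the optimum is not unique, the limit is the maximum-entropy optimal plan. The regularized transport cost therefore converges to the exact OT cost.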

Key Ideas

  • **W_p distance:** W_p^p is the infimum of the transport cost ∫∫|x−y|^p dγ over all plans γ∈Π(μ,ν).
  • **K-R duality:** W₁ = sup_{‖f‖_L≤1} [∫f dμ − ∫f dν]. WGAN trains the critic as a 1-Lipschitz function.
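  • **Sinkhorn:** add −εH(γ); the optimum is γ*=diag(u)Kdiag(v) with K=exp(−C/ε). Iterate u←μ/(Kv), v←ν/(Kᵀu). O(n²) per iteration.
  • **ε→0:** regularized OT → exact OT. Trade-off: small ε is closer to exact OT but needs more iterations and is less numerically stable; large ε converges faster but gives a blurrier plan.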

Related lessons

  • ot-17-applications — Continues the OT applications thread
  • ot-07-wgan — WGAN uses W1 via duality