Optimal Transport
WGAN and Kantorovich-Rubinstein Duality
Lesson objectives
- Derive Kantorovich-Rubinstein duality and understand why it converts W1 into a supremum over 1-Lipschitz functions
- Implement WGAN-GP and spectral normalization critics and compare their Lipschitz enforcement strategies
- Apply W1 and KR duality to domain adaptation, distributional robustness, and fairness
Prerequisites
- Wasserstein distances: primal coupling formulation and basic properties
- GAN training loop: generator-discriminator alternation
- Fenchel-Moreau duality for convex functions
- Gradient computation in PyTorch (autograd)
Classic GANs collapse, oscillate, and produce artifacts in part because the JS divergence saturates, and its gradient vanishes, the moment real and fake distributions stop overlapping. The Wasserstein distance stays informative across disjoint supports and hands the generator a gradient that says: move this way, by this much.
- BigGAN (2018) reached FID 7.4 on ImageNet 128x128 using a hinge loss with spectral normalization (not WGAN-GP), roughly a 2.5x improvement over the previous best FID of 18.65
- Adversarial distillation of diffusion models, as used to speed up Stable Diffusion variants, trains against a discriminator regularized with a gradient penalty, in the same family of ideas as WGAN-GP
- Domain adaptation in medical imaging: aligning feature distributions across hospital scanners reduces false-negative rate by up to 30%
- Distributionally robust optimization over W1 balls is a well-studied tool for high-stakes ML deployments
From Rubinstein's Norm to Wasserstein GANs
The Kantorovich-Rubinstein theorem dates to 1958, but its use in generative modeling took off with Arjovsky et al.'s 2017 WGAN paper. They observed that JS divergence saturates during GAN training when the real and generated distributions lack overlapping support - a common condition early in training. Replacing JS with W1 via KR duality, with weight clipping enforcing the Lipschitz constraint, gave noticeably more stable training curves. Gulrajani et al. soon replaced clipping with a gradient penalty (WGAN-GP, NeurIPS 2017), avoiding the capacity loss caused by aggressive weight clipping. Today optimal transport ideas related to the KR dual also appear in flow matching and other generative modeling methods.
Kantorovich-Rubinstein Duality: W1 as a Supremum
The Wasserstein-1 distance between probability measures admits a dual formulation that transforms a hard optimization over couplings into a supremum over 1-Lipschitz functions. This duality is the mathematical engine behind Wasserstein GANs.
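For reference, the primal and dual forms can be written side by side. The notation here is an assumed convention rather than something fixed by the lesson text: mu and nu are the two probability measures, Pi(mu, nu) is the set of couplings with those marginals, and d is the ground metric. The second display is only the easy (weak duality) direction.

```latex
% Primal (coupling) form and Kantorovich-Rubinstein dual form of W_1
W_1(\mu,\nu)
  \;=\; \inf_{\pi \in \Pi(\mu,\nu)} \int d(x,y)\,\mathrm{d}\pi(x,y)
  \;=\; \sup_{\|f\|_{\mathrm{Lip}} \le 1}
        \Big( \mathbb{E}_{x\sim\mu}[f(x)] - \mathbb{E}_{y\sim\nu}[f(y)] \Big)

% Weak duality: for any coupling \pi and any 1-Lipschitz f
\mathbb{E}_{\mu}[f] - \mathbb{E}_{\nu}[f]
  \;=\; \int \big( f(x) - f(y) \big)\,\mathrm{d}\pi(x,y)
  \;\le\; \int d(x,y)\,\mathrm{d}\pi(x,y)
```

Taking the supremum over f on the left and the infimum over couplings on the right gives sup <= inf; the theorem states that the two sides are in fact equal.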
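KR duality in this form holds when the cost is itself a metric, c(x,y) = d(x,y). For other costs, such as squared Euclidean, the general Kantorovich dual uses a pair of potentials linked by the c-transform; Sinkhorn iterations solve an entropy-regularized version of that dual.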
The dual objective - mean critic score on reals minus mean critic score on fakes - is exactly what the WGAN discriminator maximizes. The critic need not be a probability; it only needs to be 1-Lipschitz.
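A small one-dimensional check can make the dual objective concrete. The sketch below is an illustration under simplifying assumptions (equal-size samples on the real line, where the optimal coupling simply matches sorted samples); the function names are hypothetical. It compares the exact primal value with the dual objective for two 1-Lipschitz test functions.

```python
import random

def w1_empirical_1d(xs, ys):
    """Exact W1 between two equal-size empirical distributions on the real line.

    In 1D the optimal coupling matches sorted samples, so W1 is the mean
    absolute difference between order statistics.
    """
    assert len(xs) == len(ys)
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

def dual_objective(f, real, fake):
    """Mean critic score on reals minus mean critic score on fakes."""
    return sum(map(f, real)) / len(real) - sum(map(f, fake)) / len(fake)

random.seed(0)
real = [random.gauss(0.0, 1.0) for _ in range(1000)]
fake = [x - 3.0 for x in real]          # same shape, shifted left by 3

primal = w1_empirical_1d(real, fake)
dual_best = dual_objective(lambda x: x, real, fake)        # 1-Lipschitz, optimal here
dual_weak = dual_objective(lambda x: 0.5 * x, real, fake)  # 1-Lipschitz, suboptimal
```

Here `primal` and `dual_best` both come out at 3 (up to floating-point error), while `dual_weak` is 1.5: any admissible critic gives a lower bound on W1, and only the best one attains it.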
What constraint must the WGAN critic satisfy for the dual objective to equal W1?
KR duality requires the test function f to be 1-Lipschitz: |f(x) - f(y)| <= ||x - y|| for all x, y. The gradient penalty encourages this softly by penalizing critic gradients whose norm deviates from 1 at points interpolated between real and generated samples.
WGAN Training: Gradient Penalty and Spectral Normalization
Enforcing the Lipschitz constraint on a neural network requires explicit regularization. Three approaches dominate: weight clipping (original WGAN), gradient penalty (WGAN-GP), and spectral normalization (SN-GAN). Each trades computational cost against constraint tightness.
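The gradient-penalty variant can be sketched in PyTorch as follows. This is a minimal sketch, not a reference implementation: the toy critic, the batch shapes, and the helper names `gradient_penalty` and `critic_loss` are assumptions made for illustration, and the penalty weight of 10 is the default reported in the WGAN-GP paper.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """Two-sided penalty E[(||grad_x critic(x_hat)|| - 1)^2] on interpolates."""
    batch = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    alpha = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty itself is differentiable.
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()

def critic_loss(critic, real, fake, gp_weight=10.0):
    """Negative KR dual estimate plus the weighted gradient penalty."""
    wasserstein_estimate = critic(real).mean() - critic(fake).mean()
    return -wasserstein_estimate + gp_weight * gradient_penalty(critic, real, fake)

torch.manual_seed(0)
# Toy critic and toy batches, hypothetical stand-ins for a real model and data.
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real = torch.randn(8, 2) + 2.0   # stand-in for a batch of real samples
fake = torch.randn(8, 2)         # stand-in for generator output
loss = critic_loss(critic, real, fake)
loss.backward()                  # gradients flow through the penalty term
```

The penalty is a soft constraint: it pushes gradient norms toward 1 along lines between real and generated samples but does not guarantee a 1-Lipschitz critic everywhere.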
WGAN-GP typically uses 5 critic updates per generator update. Spectral normalization is often trained at a 1:1 ratio and is cheaper per step because it needs no second-order gradients, which usually makes SN-GAN faster in practice.
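The mechanism behind spectral normalization is a power iteration that estimates each weight matrix's largest singular value and then divides the weights by it. The pure-Python sketch below illustrates that idea on a tiny matrix; the helper names are hypothetical, and it runs many iterations for clarity, whereas practical implementations usually run a single iteration per training step and carry the vector `u` over between steps.

```python
import math

def matvec(W, v):
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]

def transpose(W):
    return [list(col) for col in zip(*W)]

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def spectral_norm(W, n_iters=50):
    """Estimate the largest singular value of W by power iteration."""
    Wt = transpose(W)
    v = normalize([1.0] * len(W[0]))      # arbitrary nonzero start vector
    for _ in range(n_iters):
        u = normalize(matvec(W, v))       # u <- W v / ||W v||
        v = normalize(matvec(Wt, u))      # v <- W^T u / ||W^T u||
    Wv = matvec(W, v)
    return sum(ui * x for ui, x in zip(u, Wv))   # sigma ~= u^T W v

W = [[3.0, 0.0], [0.0, 0.5]]              # singular values 3 and 0.5
sigma = spectral_norm(W)
W_sn = [[w / sigma for w in row] for row in W]   # rescaled layer weight
```

After rescaling, `W_sn` has spectral norm 1, so the linear map it defines is 1-Lipschitz with respect to the Euclidean norm.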
When the critic is close to the optimal 1-Lipschitz function, the critic objective (mean score on real samples minus mean score on generated samples) approximates the W1 distance between the data distribution and the generator distribution. This provides a meaningful training signal even when the two distributions have disjoint supports - a regime where JS divergence saturates at log 2.
Why does JS divergence fail when real and generated distributions have disjoint support?
When supports are disjoint, JS divergence equals log 2 everywhere and the gradient of the generator loss vanishes - training stalls. W1 remains finite and informative because it measures geometric distance even across disjoint supports.
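The contrast is easy to reproduce on a toy example. The sketch below, with hypothetical helper names, places two point masses on an integer grid and moves one of them away from the other: the JS divergence is the same for every shift, while W1 grows with the distance.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between discrete distributions."""
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_on_grid(p, q):
    """W1 between distributions on the integer grid 0..n-1 via CDF differences."""
    total, cdf_gap = 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        total += abs(cdf_gap)
    return total

def point_mass(n, k):
    return [1.0 if i == k else 0.0 for i in range(n)]

n = 20
real = point_mass(n, 0)
for shift in (1, 5, 10):
    fake = point_mass(n, shift)
    print(shift, round(js_divergence(real, fake), 4), w1_on_grid(real, fake))
```

Each printed row shows JS at log 2 (about 0.6931) regardless of the shift, while W1 equals the shift itself, which is the quantity a generator can follow downhill.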
Beyond Image Generation: W1 in Modern ML
The WGAN framework extends far beyond face generation. W1 and KR duality appear in domain adaptation, fairness constraints, distributional robustness, and large language model alignment.
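DeepMind trained BigGAN (2018) with a hinge loss and spectral normalization rather than WGAN-GP, reaching FID 7.4 on ImageNet 128x128 against a previous best of 18.65. The gains came largely from scaling up batch size and model capacity, with spectral normalization keeping the discriminator well conditioned.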
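For problems with a small-to-moderate number of samples, the POT library computes exact W1 without neural approximation: `ot.emd` returns the optimal coupling and `ot.emd2` returns the transport cost. Neural critics are needed mainly for high-dimensional data such as images, where sample-based estimates converge slowly.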
In distributional robustness optimization (DRO), what does the W1 ball radius epsilon control?
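Epsilon sets how far, in transport cost, the test distribution is allowed to drift from the training distribution P: the W1 ball {Q : W1(Q,P) <= epsilon} contains every distribution within transport cost epsilon of P. Optimizing the worst-case loss over this ball yields models robust to distribution shifts of that size, at the price of more conservative predictions as epsilon grows.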
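KR duality gives a simple certificate for this worst case: if the loss is L-Lipschitz in the input, then loss / L is an admissible test function, so the expected loss under any Q in the ball exceeds the expected loss under P by at most L * epsilon. The sketch below checks this on a toy one-dimensional example; the loss, the samples, and the constants are all hypothetical.

```python
def mean(values):
    return sum(values) / len(values)

def loss(x):
    """Toy loss, 2-Lipschitz in x (hypothetical example)."""
    return 2.0 * abs(x - 1.0)

LIPSCHITZ = 2.0
EPSILON = 0.3

train = [0.2, 0.9, 1.4, 2.5, -0.7]                 # samples from P
# A shifted distribution Q: moving every point by at most EPSILON
# keeps W1(Q, P) <= EPSILON (the identity matching is a valid coupling).
shifted = [x + EPSILON if x >= 1.0 else x - EPSILON for x in train]

nominal = mean([loss(x) for x in train])
shifted_loss = mean([loss(x) for x in shifted])
certified_bound = nominal + LIPSCHITZ * EPSILON   # upper bound over the whole ball
```

In this example the shift moves every point in the direction that hurts most, so `shifted_loss` meets `certified_bound` exactly (up to rounding); for other shifts inside the ball the loss stays below it.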
KR Duality as a Universal Distance Tool
Kantorovich-Rubinstein duality transforms an intractable optimization over joint distributions into a tractable optimization over functions. This pattern - primal geometric problem dualized to a functional supremum - recurs throughout optimal transport: the c-transform in Sinkhorn, the Schrödinger bridge dual, and the multi-marginal dual. Mastering KR duality is mastering the algebraic language of Wasserstein geometry.
- Optimal Transport — Related topic
Summary
- W1(mu, nu) = sup over 1-Lipschitz f of E_mu[f] - E_nu[f] - this is Kantorovich-Rubinstein duality
- WGAN trains a neural critic to approximate the KR dual, providing stable generator gradients across disjoint supports
- WGAN-GP enforces Lipschitz via gradient penalty on interpolated samples; spectral normalization is a cheaper per-step alternative
- KR duality applies beyond generation: domain adaptation and fairness use W1 as a discrepancy to minimize, and distributionally robust optimization (DRO) uses W1 balls as constraint sets
Questions for reflection
- Weight clipping (original WGAN) constrains all weights to [-c, c]. Why does this kill model capacity, and how does gradient penalty solve this without the same pathology?
- Spectral normalization makes each layer individually 1-Lipschitz. Does composing n such layers give a 1-Lipschitz network overall? What is the actual Lipschitz constant?
- The W1 ball in distributional robustness has radius epsilon chosen by the practitioner. How would one calibrate epsilon to the expected covariate shift in a production deployment?
Related lessons
- ot-07-wgan — WGAN is the direct application of W1 distance in generative modeling
- ot-01-monge — Primal formulation underpins KR duality
- ot-04-sinkhorn — Sinkhorn provides an alternative to gradient penalty for W1 approximation
- ot-21 — Entropic regularization connects to W1 via epsilon -> 0 limit
- ot-03-wasserstein
- ot-25-flow-matching