Optimal Transport

OT in Generative Models and NLP

Stable Diffusion 3 and Flux generate images with Rectified Flow, a flow-matching method that moves noise to data along straight, OT-style paths in continuous time. WGAN stabilized GAN training with a Wasserstein loss. Some model-merging methods used for Mistral and Llama derivatives build on OT ideas such as barycenters and permutation alignment. Optimal transport is no longer purely academic theory; it now sits behind production image and text generation systems.

  • **Stable Diffusion 3 and FLUX.1:** Rectified Flow / Flow Matching, with nearly straight paths from noise to image. About 25 sampling steps instead of DDPM's 1000. Openly released weights, widely downloaded.
  • **WGAN-GP:** a standard GAN objective in the years before diffusion models took over. StyleGAN2 uses R1 regularization, a gradient penalty in the same spirit as WGAN-GP that penalizes the discriminator's gradient norm on real data.
  • **MBR decoding in Google Translate and NLLB at Meta:** Minimum Bayes Risk decoding with a utility metric such as chrF, or an OT-based distance, improves translation quality without changing the model.

Flow Matching: Stable Diffusion via OT in Continuous Time

**Stable Diffusion 3 and Flux use Rectified Flow, i.e. flow matching along straight (OT-style) paths.** The idea is to connect noise x_0 ~ N(0, I) to data x_1 (an image) by a straight line. The network learns to predict the velocity along this line. Training is simpler than DDPM, and sampling needs far fewer steps.

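Flow matching trains a vector field v_theta(x, t) such that the ODE dx/dt = v_theta(x, t) transports p_0 (noise) into p_1 (data). The OT (straight-line) choice uses the path x_t = t * x_1 + (1 - t) * x_0. Its velocity is the constant x_1 - x_0, and this constant is the regression target for v_theta at x_t. For a fixed pair (x_0, x_1), the constant-speed straight line is the shortest path and the one with minimal kinetic energy. Moving every pair of an OT coupling along such lines gives McCann's displacement interpolation, the W_2 geodesic between p_0 and p_1.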
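
The link between straight lines and W_2 geodesics can be checked numerically in 1D. There the W_2-optimal coupling of two equal-size samples pairs them in sorted order. The sketch below moves each sorted pair along x_t = t * x_1 + (1 - t) * x_0 and checks that the distance from p_0 grows exactly linearly in t. The Gaussian sample parameters and the function names are illustrative assumptions.

```python
import random


def w2_1d(a, b):
    """W_2 between two equal-size 1D empirical distributions with uniform weights.

    In 1D the optimal coupling matches the i-th smallest point of `a`
    with the i-th smallest point of `b`.
    """
    assert len(a) == len(b)
    pairs = zip(sorted(a), sorted(b))
    return (sum((x - y) ** 2 for x, y in pairs) / len(a)) ** 0.5


def straight_interpolation(x0, x1, t):
    """Move each sorted (OT-coupled) pair along x_t = t * x_1 + (1 - t) * x_0."""
    return [t * b + (1 - t) * a for a, b in zip(sorted(x0), sorted(x1))]


random.seed(0)
noise = [random.gauss(0.0, 1.0) for _ in range(1000)]   # samples from p_0
data = [random.gauss(4.0, 0.5) for _ in range(1000)]    # toy samples from p_1

total = w2_1d(noise, data)
for t in (0.25, 0.5, 0.75):
    xt = straight_interpolation(noise, data, t)
    # Along straight OT paths, W_2(p_0, p_t) = t * W_2(p_0, p_1), so the ratio equals t.
    print(t, round(w2_1d(noise, xt) / total, 6))
```

With an independent (non-OT) pairing, the straight segments would cross, and the intermediate distributions would no longer lie on this geodesic.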

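**Flow Matching (Lipman 2022, Liu 2022):** training solves min_theta E_{t, x_0, x_1} ||v_theta(t * x_1 + (1 - t) * x_0, t) - (x_1 - x_0)||^2, where t ~ Uniform(0, 1), x_0 ~ p_0 (noise), x_1 ~ p_1 (data).

**Rectified Flow (Liu 2022):** sample by solving the ODE dx/dt = v_theta(x, t). Start from x_0 ~ N(0, I), integrate to t = 1, and take the endpoint x_1 as an approximate sample from p_data.

**Advantage vs DDPM:**

  • DDPM: 1000 ancestral sampling steps (a discretized reverse-time SDE)
  • Rectified Flow: 10-25 ODE steps, because the paths are nearly straight
  • Stable Diffusion 3 and Flux use this rectified-flow formulation

**OT connection:** the OT coupling minimizes the transport cost E[||x_0 - x_1||^2]. If (x_0, x_1) were drawn from it, the straight paths would not cross, and the learned flow could follow them exactly. Training pairs noise and data independently, so straight lines do cross and the learned trajectories bend. The 'Reflow' iteration retrains on pairs (x_0, x_1) generated by the current flow. This makes the paths straighter without increasing the transport cost.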
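
Below is a minimal, dependency-free sketch of the two pieces above. The first is the flow-matching regression loss on a batch of (x_0, x_1, t) triples. The second is a fixed-step Euler solver for dx/dt = v(x, t). It works on scalars to stay short, and the `velocity` functions are hypothetical stand-ins for the network v_theta. The toy coupling x_1 = x_0 + 3 is an assumption; it happens to be the OT map between N(0, 1) and N(3, 1).

```python
import random


def flow_matching_loss(velocity, x0s, x1s, ts):
    """Mean of (v(x_t, t) - (x1 - x0))^2 over a batch, with x_t = t * x1 + (1 - t) * x0."""
    total = 0.0
    for x0, x1, t in zip(x0s, x1s, ts):
        xt = t * x1 + (1 - t) * x0   # point on the straight path at time t
        target = x1 - x0             # constant velocity of that path
        total += (velocity(xt, t) - target) ** 2
    return total / len(ts)


def euler_sample(velocity, x0, steps):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with `steps` Euler steps."""
    x, dt = x0, 1.0 / steps
    for i in range(steps):
        x = x + dt * velocity(x, i * dt)
    return x


# Toy assumption: p_0 = N(0, 1), p_1 = N(3, 1), coupled by x1 = x0 + 3,
# so every straight path has the same velocity, 3.
random.seed(0)
x0s = [random.gauss(0.0, 1.0) for _ in range(256)]
x1s = [x0 + 3.0 for x0 in x0s]
ts = [random.random() for _ in x0s]


def exact_velocity(x, t):
    return 3.0


def biased_velocity(x, t):
    return 2.0  # hypothetical stand-in for an imperfectly trained network


print(flow_matching_loss(exact_velocity, x0s, x1s, ts))   # ~0 (rounding only)
print(flow_matching_loss(biased_velocity, x0s, x1s, ts))  # ~1.0
print(euler_sample(exact_velocity, 0.5, steps=1))         # 3.5
```

With independent pairs, different targets x_1 - x_0 can share the same x_t. The squared loss then drives v_theta toward their conditional mean, which is why the learned paths bend.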

Why does Rectified Flow need approximately 40x fewer steps than DDPM (25 vs 1000)?

DDPM's sampling trajectories are curved (and stochastic), so approximating the reverse SDE accurately takes many small steps. RF learns a velocity field whose trajectories are nearly straight lines. A single Euler step integrates an exactly straight, constant-speed path without error, and nearly straight paths need only a few steps.
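
The step-count argument can be checked with two toy 2D velocity fields, both chosen only for illustration. The first is a constant field, whose trajectories are straight lines. The second is a rotation field, whose trajectories are circular arcs and stand in for curved paths. Euler is exact on the straight field after one step. On the curved field, its error shrinks only gradually as steps are added.

```python
import math


def euler(velocity, x, steps):
    """Fixed-step Euler integration of dx/dt = v(x, t) on t in [0, 1]; x is a 2-tuple."""
    dt = 1.0 / steps
    for i in range(steps):
        vx, vy = velocity(x, i * dt)
        x = (x[0] + dt * vx, x[1] + dt * vy)
    return x


def straight(x, t):
    # Constant velocity: every trajectory is a straight line.
    return (2.0, -1.0)


def rotation(x, t):
    # Rotation by pi/2 radians over t in [0, 1]: trajectories are quarter circles.
    w = math.pi / 2
    return (-w * x[1], w * x[0])


start = (1.0, 0.0)
exact_straight = (3.0, -1.0)   # start + (2, -1)
exact_rotation = (0.0, 1.0)    # (cos(pi/2), sin(pi/2))

for steps in (1, 5, 25, 1000):
    err_s = math.dist(euler(straight, start, steps), exact_straight)
    err_r = math.dist(euler(rotation, start, steps), exact_rotation)
    print(f"{steps:5d} steps  straight error {err_s:.1e}  curved error {err_r:.1e}")
```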

WGAN: Wasserstein as Adversarial Loss

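With an optimal discriminator, the original GAN objective minimizes the Jensen-Shannon divergence. Early in training, the real and generated distributions often have disjoint supports. JS is then stuck at log 2, the generator's gradient vanishes, and training stalls. **WGAN** (Arjovsky 2017) uses Wasserstein-1 as the loss instead. W_1 changes continuously as the generated distribution moves, even when the supports are disjoint, so it always provides an informative gradient.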
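
The disjoint-support failure can be seen with two point masses, one at 0 (real) and one at theta (generated). The sketch computes JS with natural logarithms and W_1 through the 1D CDF formula, W_1 = integral of |F_p - F_q|. JS stays at log 2 for every theta != 0, so it carries no information about how close the generator is. W_1 equals |theta| and shrinks as the generator approaches the data. Representing distributions as `{point: probability}` dicts is an assumption made to keep the example short.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between discrete distributions {point: prob}."""
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}

    def kl_to_m(a):
        return sum(a[x] * math.log(a[x] / m[x]) for x in a if a[x] > 0)

    return 0.5 * kl_to_m(p) + 0.5 * kl_to_m(q)


def w1_1d(p, q):
    """Wasserstein-1 in 1D: integral of |F_p - F_q| over the real line."""
    points = sorted(set(p) | set(q))
    cdf_p = cdf_q = total = 0.0
    for left, right in zip(points, points[1:]):
        cdf_p += p.get(left, 0.0)
        cdf_q += q.get(left, 0.0)
        total += abs(cdf_p - cdf_q) * (right - left)
    return total


real = {0.0: 1.0}                # real data: a point mass at 0
for theta in (0.1, 1.0, 10.0):
    fake = {theta: 1.0}          # generator output: a point mass at theta
    # JS is flat at log 2 (about 0.693); W_1 equals theta.
    print(theta, js_divergence(real, fake), w1_1d(real, fake))
```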

**Wasserstein-1 via Kantorovich-Rubinstein duality:** W_1(p_r, p_g) = sup over ||f||_L <= 1 of E_{x ~ p_r}[f(x)] - E_{x ~ p_g}[f(x)], where the supremum runs over 1-Lipschitz functions f.

**WGAN algorithm:**

  • Critic f_w (instead of a discriminator): max over w of E[f_w(real)] - E[f_w(fake)]
  • Generator G_theta: min over theta of -E[f_w(G_theta(z))]
  • Lipschitz constraint: weight clipping to [-c, c] (original WGAN) or a gradient penalty pushing ||grad f|| toward 1 (WGAN-GP)

**WGAN-GP (Gulrajani 2017):** adds lambda * E[(||grad_xhat f(xhat)||_2 - 1)^2] to the critic loss. The points xhat are sampled uniformly on straight lines between real and fake samples, and the paper uses lambda = 10. This is more stable than weight clipping, which tends to push weights to the clipping boundaries and limit the critic's capacity.
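
The following is a dependency-free sketch of the WGAN-GP critic objective under stated assumptions. Critics are plain Python functions on 2D points. Central finite differences replace automatic differentiation. In PyTorch one would typically get the gradient with `torch.autograd.grad(..., create_graph=True)` so the penalty itself can be backpropagated. The two linear critics are hypothetical test cases with known gradient norms 1 and 2.

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    """Central-difference gradient of a scalar function f at x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def gradient_penalty(critic, reals, fakes, lam=10.0, rng=random):
    """lam * mean of (||grad critic(xhat)||_2 - 1)^2 at random interpolates xhat."""
    total = 0.0
    for real, fake in zip(reals, fakes):
        eps = rng.random()  # eps ~ Uniform(0, 1), one per (real, fake) pair
        xhat = [eps * r + (1 - eps) * g for r, g in zip(real, fake)]
        norm = math.sqrt(sum(d * d for d in numerical_grad(critic, xhat)))
        total += (norm - 1.0) ** 2
    return lam * total / len(reals)


def critic_loss(critic, reals, fakes, lam=10.0):
    """Quantity the critic minimizes: E[f(fake)] - E[f(real)] + gradient penalty."""
    mean_fake = sum(critic(x) for x in fakes) / len(fakes)
    mean_real = sum(critic(x) for x in reals) / len(reals)
    return mean_fake - mean_real + gradient_penalty(critic, reals, fakes, lam)


random.seed(0)
reals = [[random.gauss(2.0, 1.0), random.gauss(2.0, 1.0)] for _ in range(64)]
fakes = [[random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)] for _ in range(64)]


def unit_critic(x):
    # Linear critic with gradient norm exactly 1: the penalty is ~0.
    return (x[0] + x[1]) / math.sqrt(2.0)


def steep_critic(x):
    # Linear critic with gradient norm 2: the penalty is lam * (2 - 1)^2 = 10.
    return 2.0 * x[0]


print(gradient_penalty(unit_critic, reals, fakes))
print(gradient_penalty(steep_critic, reals, fakes))
print(critic_loss(unit_critic, reals, fakes))
```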

Why does the WGAN critic not use sigmoid and binary cross-entropy?

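A classical GAN minimizes JS divergence via BCE on sigmoid outputs. When supports do not overlap, JS = log 2 and gradients vanish. The WGAN critic instead outputs an unbounded score f(x) in R. Its score gap E[f(real)] - E[f(fake)] approximates W_1 = sup over ||f||_L <= 1 of E[f(real)] - E[f(fake)]. The Lipschitz constraint, not a sigmoid, keeps the critic's scale in check.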

OT in LLMs: Beam Search, Decoding, and Distributed Training

OT appears in LLMs in three places:

(1) **Minimum Bayes Risk decoding** picks the candidate with the lowest expected distance (for example, a Wasserstein distance) to the other sampled hypotheses.

(2) **Knowledge distillation** can use a Wasserstein loss between teacher and student token distributions.

(3) **OT barycenters** serve federated learning by averaging models as a Wasserstein barycenter.
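
Here is a toy sketch of MBR selection with an OT-based distance. Each hypothesis is treated as a bag of token embeddings, so word order is ignored. The distance is W_1 between their 1D empirical distributions, and the hypothesis with the lowest mean distance to the others wins. The scalar embedding table `EMBED` and the sample sentences are hypothetical. Real systems sample hypotheses from the model and use metrics such as chrF or distances over learned, high-dimensional embeddings.

```python
def w1_1d(a, b):
    """W_1 between uniform empirical distributions over two lists of reals (any sizes)."""
    points = sorted(set(a) | set(b))
    total = 0.0
    for left, right in zip(points, points[1:]):
        cdf_a = sum(x <= left for x in a) / len(a)
        cdf_b = sum(x <= left for x in b) / len(b)
        total += abs(cdf_a - cdf_b) * (right - left)
    return total


# Hypothetical 1D "embeddings": words with similar meanings get nearby values.
EMBED = {"the": 0.0, "a": 0.1, "cat": 5.0, "kitten": 5.2,
         "dog": 8.0, "sat": 2.0, "sits": 2.1, "ran": 3.5}


def sentence_distance(h1, h2):
    """OT distance between the bags of token embeddings of two hypotheses."""
    return w1_1d([EMBED[w] for w in h1.split()], [EMBED[w] for w in h2.split()])


def mbr_select(hypotheses, distance):
    """Pick the hypothesis with the lowest mean distance to the other hypotheses."""
    def risk(i):
        others = [h for j, h in enumerate(hypotheses) if j != i]
        return sum(distance(hypotheses[i], h) for h in others) / len(others)
    return hypotheses[min(range(len(hypotheses)), key=risk)]


samples = ["the cat sat", "a kitten sits", "the dog ran", "the cat sits"]
print(mbr_select(samples, sentence_distance))  # the cat sits
```

The winner is the "consensus" candidate between the two cat/kitten variants, while the outlier "the dog ran" has the highest risk.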

**OT Barycenter for federated learning:** K clients each train their own model theta_k.

  • Standard averaging (FedAvg): theta_avg = (1/K) sum_k theta_k. Problem: weight averaging works poorly when client data are heterogeneous.
  • Wasserstein barycenter of neural networks: theta* = argmin over theta of sum_k w_k * W_2^2(nu_theta, nu_theta_k), where nu_theta is the activation distribution of model theta.
  • Practical solution: align neuron permutations (activation matching) before averaging.
  • Model merging (2023): Git Re-Basin (permutation alignment, then averaging) and SLERP interpolation merge models without retraining. SLERP is widely used for LLM merges, but only the alignment step solves an OT problem.
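
Below is a minimal sketch of "align, then average" for a single hidden layer shared by two clients with equal weights w_k = 1/2. It solves the neuron matching, a small assignment problem and the discrete special case of OT, by brute force over permutations. The cost is the squared distance between each neuron's activations on a shared probe set. For realistic layer widths, a Hungarian-algorithm solver such as `scipy.optimize.linear_sum_assignment` would replace the brute force. In a full network, the next layer's incoming weights must be permuted the same way. The probe inputs and weights are hypothetical.

```python
from itertools import permutations


def relu_acts(weights, probes):
    """Activation vector of each neuron (one row of incoming weights) on the probe inputs."""
    return [[max(0.0, sum(w * x for w, x in zip(row, p))) for p in probes] for row in weights]


def align_and_average(acts_a, acts_b, weights_a, weights_b):
    """Permute B's neurons to best match A's activations, then average the weights."""
    n = len(acts_a)

    def cost(perm):
        # Assignment cost: sum_i ||acts_a[i] - acts_b[perm[i]]||^2
        return sum(sum((x - y) ** 2 for x, y in zip(acts_a[i], acts_b[perm[i]]))
                   for i in range(n))

    best = min(permutations(range(n)), key=cost)   # brute force: fine for tiny n only
    aligned_b = [weights_b[j] for j in best]
    return [[(wa + wb) / 2 for wa, wb in zip(row_a, row_b)]
            for row_a, row_b in zip(weights_a, aligned_b)]


probes = [[1.0, 0.0], [0.0, 1.0], [1.0, 2.0]]
w_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_b = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]   # same neurons as A, stored in a different order

naive = [[(x + y) / 2 for x, y in zip(ra, rb)] for ra, rb in zip(w_a, w_b)]
merged = align_and_average(relu_acts(w_a, probes), relu_acts(w_b, probes), w_a, w_b)
print(naive)   # [[1.0, 0.5], [0.5, 0.5], [0.5, 1.0]]
print(merged)  # [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
```

Naive FedAvg-style averaging blends unrelated neurons. After alignment, averaging recovers the shared layer exactly.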

Model merging is used in practice: many community fine-tunes of Mistral and Llama models are merges of several checkpoints made with SLERP or TIES. The mergekit library implements SLERP interpolation and TIES merging, among other methods. SLERP and TIES act on weights that are already aligned and are not Wasserstein barycenters themselves. OT enters through permutation matching, as in Git Re-Basin.
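
The next sketch shows what SLERP does to two weight vectors. It interpolates along the arc between them rather than the chord. For vectors of equal norm this preserves the norm, whereas plain averaging shrinks it. Treating each weight tensor as one flattened vector and falling back to linear interpolation for nearly parallel vectors are assumptions of this sketch. It is not mergekit's implementation, and it does not handle exactly opposite vectors.

```python
import math


def slerp(a, b, t, eps=1e-8):
    """Spherical linear interpolation between two flattened weight vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    cos_omega = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)
    omega = math.acos(max(-1.0, min(1.0, cos_omega)))  # angle between a and b
    if omega < eps:
        # Nearly parallel vectors: fall back to linear interpolation.
        return [(1 - t) * x + t * y for x, y in zip(a, b)]
    sa = math.sin((1 - t) * omega) / math.sin(omega)
    sb = math.sin(t * omega) / math.sin(omega)
    return [sa * x + sb * y for x, y in zip(a, b)]


a = [1.0, 0.0]
b = [0.0, 1.0]
mid_slerp = slerp(a, b, 0.5)
mid_linear = [(x + y) / 2 for x, y in zip(a, b)]
# SLERP keeps the unit norm; the linear average shrinks it to sqrt(0.5).
print(math.hypot(*mid_slerp), math.hypot(*mid_linear))
```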

Why does FedAvg behave poorly with heterogeneous data, and how does a Wasserstein barycenter address this?

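A hidden layer computes the same function if its neurons are permuted, together with the matching rows and columns of the adjacent weight matrices. In FedAvg, neuron 1 of client A is averaged with neuron 1 of client B, even though the two may compute different features. The Wasserstein-barycenter approach first finds the best neuron permutation via OT by matching activations, then averages the aligned weights.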

Key ideas

  • **Flow Matching:** train v_theta(x_t, t) to regress x_1 - x_0 along straight paths x_t = t * x_1 + (1 - t) * x_0. The ODE dx/dt = v_theta then transports noise to data in about 25 steps. Stable Diffusion 3, Flux.
  • **WGAN:** the critic f_w maximizes E[f(real)] - E[f(fake)] under ||f||_L <= 1. The optimal critic's score gap equals W_1, up to a constant factor when weight clipping only enforces a K-Lipschitz bound. WGAN-GP uses a gradient penalty instead of weight clipping.
  • **MBR decoding:** pick the hypothesis with min E_{h'}[dist(h, h')]. Wasserstein between token distributions gives a semantically meaningful metric.
  • **OT Barycenters:** theta* = argmin sum_k w_k * W_2^2(nu_theta, nu_theta_k). Used for federated learning and for model merging without retraining.
  • **Rectified Flow vs DDPM:** nearly straight, OT-like paths -> less curvature -> fewer ODE-solver steps. Reflow iterations make paths even straighter.

Related topics

OT in modern ML:

  • Flow Matching (basics) — Theory of flow matching and CFM
  • Wasserstein Gradient Flows — Diffusion as gradient flow - the mathematical foundation of diffusion models

Related lessons

  • ot-11-flow-matching — Flow matching is OT in continuous time
  • ot-07-wgan — WGAN uses W1 as the adversarial loss
  • ot-14-gradient-flows — Gradient flows explain the dynamics of diffusion models