Optimal Transport
Wasserstein Barycenters
Average two handwritten '2's pixel by pixel and the result is a blurry grey smear. Average them in Wasserstein space and the result is a third '2', sharp and intermediate. Agueh and Carlier (2011) formalized this operation as the Frechet mean in $(\mathcal{P}_2, W_2)$: the Wasserstein barycenter.
- **Shape interpolation**: geodesics in $W_2$ between organ shapes (cardiac MRI) move mass along shortest paths instead of cross-fading intensities. Used to build shape atlases in medical imaging
- **Color transfer and grading**: the Sinkhorn barycenter of RGB histograms across frames aligns color palettes. Color-grading tools such as DaVinci Resolve address the same palette-matching problem
- **Federated learning**: Wasserstein aggregation of client model weights can be more robust than FedAvg under heterogeneous data (different class distributions across clients)
- **Diffusion equations**: in the JKO scheme (Jordan-Kinderlehrer-Otto 1998) each time step of the Fokker-Planck equation is a proximal step in $W_2$. Applications include molecular dynamics and Langevin optimization of neural networks
Prerequisites
- Wasserstein distance $W_2$ and its interpretation as optimal transport cost
- Sinkhorn algorithm: entropy regularization and matrix iterations
Frechet Mean in Metric Space
Average two handwritten images of the digit '2' pixel by pixel and the result is a blurry grey smear. The pixels of the two images are not geometrically aligned, so arithmetic averaging blends them like mixed paint. Average those same images in Wasserstein space and something different happens: the mass of each digit flows smoothly into an intermediate shape, preserving the contour. That is not blur - it is morphing.
The formalization goes back to Frechet (1948): in $\mathbb{R}^d$ the mean of several points is $\arg\min_x \sum_i \|x - x_i\|^2$. In an arbitrary metric space $(M, d)$ the same definition applies: $\arg\min_{p \in M} \sum_i d^2(p, p_i)$. This is the **Frechet mean**. In the space of probability measures with finite second moment $(\mathcal{P}_2, W_2)$ it becomes the Wasserstein barycenter problem (Agueh & Carlier, 2011):

$$\bar{\mu} = \arg\min_{\mu \in \mathcal{P}_2} \sum_{i=1}^n \lambda_i W_2^2(\mu, \mu_i), \qquad \lambda_i \ge 0, \quad \sum_{i=1}^n \lambda_i = 1.$$
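Weights $\lambda_i$ control the pull toward each measure: $\lambda_1 = 1$ (all other weights zero) recovers $\mu_1$, and equal weights $\lambda_i = 1/n$ give the symmetric average. For measures on the real line the barycenter has a closed form: its quantile function is the weighted average of the quantile functions $Q_i$ of the inputs:

$$Q_{\bar{\mu}}(s) = \sum_{i=1}^n \lambda_i Q_i(s), \qquad s \in (0, 1).$$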
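A minimal sketch of the quantile formula, assuming each measure is given as an empirical sample; the function name `barycenter_1d` and the grid of levels are illustrative choices, not part of the original text. `np.quantile` estimates each $Q_i$ at the chosen levels $s$, and the barycenter's quantiles are their $\lambda$-weighted average.

```python
import numpy as np

def barycenter_1d(samples, lambdas, levels):
    """Quantiles of the W2 barycenter of 1D empirical measures.

    samples: list of 1D arrays (sizes may differ), one per measure mu_i
    lambdas: barycentric weights, nonnegative and summing to 1
    levels:  quantile levels s in [0, 1] at which Q_bar(s) is evaluated
    """
    lambdas = np.asarray(lambdas, dtype=float)
    # Row i holds the empirical quantile function Q_i evaluated at the levels
    Q = np.stack([np.quantile(np.asarray(x, dtype=float), levels) for x in samples])
    # Q_bar(s) = sum_i lambda_i Q_i(s)
    return lambdas @ Q

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=20_000)   # sample from N(0, 1)
b = rng.normal(0.0, 5.0, size=20_000)   # sample from N(0, 5^2)
levels = np.linspace(0.001, 0.999, 999)
q_bar = barycenter_1d([a, b], [0.5, 0.5], levels)
```

Read as an equally weighted sample of the barycenter, `q_bar` has a spread close to $\frac{1}{2}(1 + 5) = 3$ (slightly less, because the level grid stops at 0.001 and 0.999), while the pooled sample `np.concatenate([a, b])`, i.e. the pointwise mixture, has standard deviation near $\sqrt{13} \approx 3.61$.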
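For Gaussians $\mu_i = \mathcal{N}(m_i, \sigma_i^2)$ on the line the barycenter is also Gaussian: $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{\sigma}^2)$ where $\bar{m} = \sum \lambda_i m_i$ and $\bar{\sigma} = \sum \lambda_i \sigma_i$. The average is over $\sigma$, not $\sigma^2$: averaging the variances would give a larger spread whenever the $\sigma_i$ differ, and averaging the densities pointwise would give a mixture that is not Gaussian at all. This is one concrete way Wasserstein geometry differs from Euclidean geometry on densities.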
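The Gaussian closed form follows directly from the quantile formula $\bar{Q}(s) = \sum_i \lambda_i Q_i(s)$. Writing $\Phi$ for the standard normal CDF (a symbol introduced here for the derivation), each Gaussian quantile function is affine in $\Phi^{-1}$, so the weighted average stays in the same family:

```latex
\begin{aligned}
Q_i(s) &= m_i + \sigma_i \, \Phi^{-1}(s), \qquad s \in (0, 1), \\
\bar{Q}(s) &= \sum_{i=1}^n \lambda_i Q_i(s)
  = \Big(\sum_{i=1}^n \lambda_i m_i\Big) + \Big(\sum_{i=1}^n \lambda_i \sigma_i\Big) \Phi^{-1}(s)
  = \bar{m} + \bar{\sigma} \, \Phi^{-1}(s),
\end{aligned}
```

which is the quantile function of $\mathcal{N}(\bar{m}, \bar{\sigma}^2)$. The scale enters linearly through $\sigma_i$, never through $\sigma_i^2$.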
**Application to generative models.** A VAE encoder maps each image to a Gaussian posterior in latent space, so each image is represented by a measure rather than a single point. Interpolating two such posteriors via a Wasserstein barycenter ($\lambda_1 = t$, $\lambda_2 = 1-t$) produces a smooth morph over $t \in [0,1]$. Linear interpolation of latent mean vectors captures only the means; the Frechet mean in $(\mathcal{P}_2, W_2)$ also interpolates the posterior spread. FID (Frechet Inception Distance) uses the same geometry explicitly: it is the squared $W_2$ distance between Gaussians fitted to Inception activation statistics.
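For diagonal covariances, such as typical VAE posteriors $\mathcal{N}(m, \operatorname{diag}\sigma^2)$, the barycenter is again a diagonal Gaussian, and each coordinate follows the 1D rule $\bar{m} = \sum \lambda_i m_i$, $\bar{\sigma} = \sum \lambda_i \sigma_i$. A minimal sketch; `interpolate_posteriors` is a hypothetical helper and the encoder outputs are made-up numbers:

```python
import numpy as np

def interpolate_posteriors(mu_1, sigma_1, mu_2, sigma_2, t):
    """W2 barycenter of two diagonal Gaussians with weights lambda_1 = t, lambda_2 = 1 - t.

    Per coordinate this is the 1D Gaussian rule: means and standard deviations
    (not variances) are averaged with the barycentric weights.
    """
    mu_1, sigma_1 = np.asarray(mu_1, dtype=float), np.asarray(sigma_1, dtype=float)
    mu_2, sigma_2 = np.asarray(mu_2, dtype=float), np.asarray(sigma_2, dtype=float)
    mu = t * mu_1 + (1.0 - t) * mu_2
    sigma = t * sigma_1 + (1.0 - t) * sigma_2
    return mu, sigma

# Hypothetical encoder outputs (mean, std) for two images in a 2D latent space
mu_a, sigma_a = [0.0, 0.0], [1.0, 0.5]
mu_b, sigma_b = [2.0, 4.0], [2.0, 1.5]
# t = 0 gives image b's posterior, t = 1 gives image a's
path = [interpolate_posteriors(mu_a, sigma_a, mu_b, sigma_b, t) for t in np.linspace(0.0, 1.0, 5)]
```

At $t = 1/2$ the first coordinate gets $\bar{\sigma} = 1.5$, whereas averaging variances would give $\sqrt{2.5} \approx 1.58$.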
For 1D Gaussians $\mathcal{N}(0, 1^2)$ and $\mathcal{N}(0, 3^2)$ with equal weights: what is the standard deviation of the $W_2$-barycenter?
Agueh-Carlier Theorem: Multi-Marginal OT
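Agueh & Carlier (2011) proved a fundamental equivalence: the Wasserstein barycenter problem is equivalent to a **multi-marginal** optimal transport problem. Instead of $n$ pairwise OT problems $W_2(\mu, \mu_i)$, one looks for a single plan $\gamma$ on the product space $\mathbb{R}^d \times \cdots \times \mathbb{R}^d$ whose $n$ marginals are $\mu_1, \ldots, \mu_n$, minimizing the cost $\int \sum_i \lambda_i \|x_i - \sum_j \lambda_j x_j\|^2 \, d\gamma(x_1, \ldots, x_n)$.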
The barycenter is recovered from the plan as the push-forward of the optimal $\gamma$ under the weighted-mean map $(x_1, \ldots, x_n) \mapsto \sum_i \lambda_i x_i$: each atom $(x_1, \ldots, x_n)$ of the plan contributes the point $\sum_i \lambda_i x_i$ to the barycenter with weight $\gamma(x_1, \ldots, x_n)$.
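On the real line the optimal multi-marginal plan for this cost is the comonotone (quantile) coupling, which makes the push-forward easy to demonstrate. A minimal sketch for discrete 1D measures with arbitrary atom weights; the restriction to 1D is an assumption made for illustration (in higher dimensions the plan has to be found by solving the multi-marginal linear program), and `comonotone_barycenter` is a hypothetical name.

```python
import numpy as np

def comonotone_barycenter(atoms, masses, lambdas):
    """Barycenter of discrete 1D measures via the comonotone multi-marginal plan.

    atoms[i], masses[i]: support points and probabilities of mu_i.
    Returns (plan_atoms, plan_weights, support): row k of plan_atoms is an atom
    (x_1, ..., x_n) of gamma, and support[k] = sum_i lambda_i x_i is where it lands.
    """
    sorted_atoms, cdfs = [], []
    for x, w in zip(atoms, masses):
        order = np.argsort(x)
        sorted_atoms.append(np.asarray(x, dtype=float)[order])
        cdf = np.cumsum(np.asarray(w, dtype=float)[order])
        cdfs.append(cdf / cdf[-1])            # normalize so every CDF ends exactly at 1
    # Breakpoints of all CDFs split (0, 1] into intervals on which every Q_i is constant
    levels = np.unique(np.concatenate([[0.0], *cdfs]))
    plan_weights = np.diff(levels)
    mids = 0.5 * (levels[:-1] + levels[1:])
    keep = plan_weights > 1e-12               # drop slivers created by round-off
    plan_weights, mids = plan_weights[keep], mids[keep]
    # Q_i(s) = first atom whose CDF reaches s (generalized inverse of the CDF)
    plan_atoms = np.column_stack(
        [xs[np.searchsorted(cdf, mids)] for xs, cdf in zip(sorted_atoms, cdfs)]
    )
    support = plan_atoms @ np.asarray(lambdas, dtype=float)  # push-forward of gamma
    return plan_atoms, plan_weights, support

plan_atoms, plan_weights, support = comonotone_barycenter(
    atoms=[[0.0, 1.0], [0.0, 2.0]],
    masses=[[0.3, 0.7], [0.5, 0.5]],
    lambdas=[0.5, 0.5],
)
# plan atoms (0, 0), (1, 0), (1, 2) with weights 0.3, 0.2, 0.5 land at 0.0, 0.5, 1.5
```

The plan here has at most $\sum_i m_i$ atoms rather than $\prod_i m_i$, which is special to 1D; in higher dimensions no such ordering exists.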
**Existence and uniqueness.** A barycenter exists for any measures $\mu_i \in \mathcal{P}_2$. If at least one measure $\mu_i$ with $\lambda_i > 0$ is absolutely continuous (has a density), the barycenter is also unique. When the barycenter itself has a density, the optimal maps $T_i: \bar{\mu} \to \mu_i$ exist and are gradients of convex functions (a consequence of Brenier's theorem). Without absolute continuity uniqueness may fail - the discrete case is substantially harder.
For Gaussians $\mu_i = \mathcal{N}(m_i, \Sigma_i)$ the barycenter is also Gaussian $\mathcal{N}(\bar{m}, S)$ where $\bar{m} = \sum \lambda_i m_i$ and $S$ is the fixed point of the Bures equation:

$$S = \sum_{i=1}^n \lambda_i \left(S^{1/2} \Sigma_i S^{1/2}\right)^{1/2}.$$
This is a nonlinear matrix equation solved by simple iterations. Its connection to the Bures metric: $W_2^2(\mathcal{N}(m_1,\Sigma_1), \mathcal{N}(m_2,\Sigma_2)) = \|m_1 - m_2\|^2 + B^2(\Sigma_1, \Sigma_2)$ where $B^2(\Sigma_1, \Sigma_2) = \text{tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2})^{1/2})$.
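A minimal NumPy sketch of both formulas, assuming positive definite covariances. Matrix square roots are computed with `np.linalg.eigh`. For the fixed point $S = \sum_i \lambda_i (S^{1/2}\Sigma_i S^{1/2})^{1/2}$ the sketch uses the rescaled update $S \leftarrow S^{-1/2}\big(\sum_i \lambda_i (S^{1/2}\Sigma_i S^{1/2})^{1/2}\big)^2 S^{-1/2}$, which has the same fixed points; this choice of update rule and the helper names are assumptions of the sketch.

```python
import numpy as np

def spd_sqrt(A):
    """Square root of a symmetric positive semidefinite matrix via eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def gaussian_w2_sq(m1, S1, m2, S2):
    """W2^2(N(m1, S1), N(m2, S2)) = ||m1 - m2||^2 + B^2(S1, S2)."""
    S1, S2 = np.asarray(S1, dtype=float), np.asarray(S2, dtype=float)
    r1 = spd_sqrt(S1)
    bures_sq = np.trace(S1 + S2 - 2.0 * spd_sqrt(r1 @ S2 @ r1))
    mean_sq = np.sum((np.asarray(m1, dtype=float) - np.asarray(m2, dtype=float)) ** 2)
    return float(mean_sq + bures_sq)

def bures_barycenter(covs, lambdas, n_iter=200, tol=1e-12):
    """Covariance S of the W2 barycenter of the Gaussians N(., covs[i])."""
    covs = [np.asarray(C, dtype=float) for C in covs]
    lambdas = np.asarray(lambdas, dtype=float)
    S = sum(l * C for l, C in zip(lambdas, covs))  # start from the Euclidean mean
    for _ in range(n_iter):
        r = spd_sqrt(S)
        r_inv = np.linalg.inv(r)
        M = sum(l * spd_sqrt(r @ C @ r) for l, C in zip(lambdas, covs))
        S_new = r_inv @ M @ M @ r_inv
        S_new = 0.5 * (S_new + S_new.T)            # keep symmetric despite round-off
        if np.linalg.norm(S_new - S) < tol:
            return S_new
        S = S_new
    return S

covs = [np.diag([1.0, 4.0]), np.diag([9.0, 16.0])]
S = bures_barycenter(covs, [0.5, 0.5])   # commuting case: (mean of sqrt(Sigma_i))^2 = diag(4, 9)
euclid = 0.5 * (covs[0] + covs[1])       # Euclidean mean diag(5, 10), a different matrix
```

For commuting (here diagonal) covariances the barycenter is $\big(\sum_i \lambda_i \Sigma_i^{1/2}\big)^2$, which makes the difference from the Euclidean mean easy to see. `gaussian_w2_sq` is the same expression that FID evaluates on Inception activation statistics.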
The matrix Bures barycenter is used in neuroimaging: each fMRI subject is represented by a covariance matrix of brain region activations. The group barycenter in the Bures metric gives a group average defined by the geometry of the SPD manifold rather than by entrywise averaging of matrices.
The W2 barycenter of Gaussians is also Gaussian. Why can the covariance barycenter not be computed as $\bar{\Sigma} = \sum \lambda_i \Sigma_i$ (Euclidean mean of matrices)?
Sinkhorn Barycenters (Cuturi-Doucet 2014)
Exact multi-marginal OT is a linear program over a plan with $m^n$ entries, so its complexity is at least $O(m^n \cdot d)$, where $m$ is the number of atoms per measure, $n$ is the number of measures and $d$ is the dimension. With $n = 10$ distributions of $m = 64$ pixels each, the plan alone has $64^{10} = 2^{60} \approx 1.2 \cdot 10^{18}$ entries - intractable. Regularization is needed.
Cuturi & Doucet (2014) proposed the entropy-regularized barycenter: replace $W_2$ with $W_\varepsilon$ (Sinkhorn distance) in the Frechet mean problem. This turns the problem into an iterative algorithm where each step is a standard Sinkhorn computation.
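The Sinkhorn barycenter algorithm alternates two steps. Step 1: for each pair $(\bar{\mu}, \mu_i)$, perform a Sinkhorn scaling update and obtain the scaling (dual) vectors $(u_i, v_i)$, where $u_i$ acts on the barycenter side and $v_i$ on the side of $\mu_i$. Step 2: update $\bar{\mu}$ as the weighted geometric mean $\prod_i (K v_i)^{\lambda_i}$ of the barycenter-side factors (in log space, a weighted mean of $\log K v_i$), where $K = e^{-C/\varepsilon}$ is the Gibbs kernel. Each update is a matrix-vector product with $K$, so the complexity per outer iteration is $O(m^2 \cdot n)$ instead of $O(m^n)$.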
**Parallelism.** For discrete measures on a shared grid, all $n$ Sinkhorn problems share a single kernel matrix $K = e^{-C/\varepsilon}$. This allows batching all $n$ transport problems into one matrix operation. On GPU, the Sinkhorn barycenter of $n = 100$ histograms on a $64 \times 64$ grid runs in seconds.
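The two steps above, with all $n$ problems batched through the shared kernel $K = e^{-C/\varepsilon}$, can be sketched in NumPy as follows. The barycenter update is the weighted geometric mean of the vectors $K v_i$, i.e. the iterative Bregman projection form of the entropic barycenter. The 1D grid, the value of $\varepsilon$, the fixed iteration count and the function name are illustrative assumptions; a practical implementation would switch to the log domain for small $\varepsilon$.

```python
import numpy as np

def sinkhorn_barycenter(B, C, lambdas, eps=1e-2, n_iter=1000):
    """Entropic W2 barycenter of histograms on a shared grid.

    B: (m, n) array whose columns are the input histograms mu_i
    C: (m, m) cost matrix on the grid (here squared distances)
    """
    K = np.exp(-C / eps)                 # one Gibbs kernel shared by all n problems
    lambdas = np.asarray(lambdas, dtype=float)
    v = np.ones_like(B)                  # scaling vectors v_i stored as columns
    for _ in range(n_iter):
        Kv = K @ v                       # all n problems in one matrix product
        p = np.exp(np.log(Kv) @ lambdas) # weighted geometric mean of K v_i
        u = p[:, None] / Kv              # barycenter-side marginals set to p
        v = B / (K.T @ u)                # input-side marginals set to mu_i
    return np.exp(np.log(K @ v) @ lambdas)

def bump(x, center, width=0.05):
    """Unnormalized Gaussian bump on the grid x."""
    return np.exp(-((x - center) ** 2) / (2 * width**2))

x = np.linspace(0.0, 1.0, 101)
C = (x[:, None] - x[None, :]) ** 2
B = np.column_stack([bump(x, 0.2), bump(x, 0.8)])
B /= B.sum(axis=0)                       # each column is a probability histogram
p = sinkhorn_barycenter(B, C, [0.5, 0.5])
```

The result `p` is a single bump centered at 0.5, widened by the entropic regularization, whereas the Euclidean average `B.mean(axis=1)` keeps two bumps at 0.2 and 0.8.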
Color transfer through the Sinkhorn barycenter: take R, G, B histograms from three frames shot under different lighting conditions, compute the barycenter, and obtain a compromise color palette. In filmmaking this is used for color grading: footage across multiple shooting days with different light is normalized to a single tone through the barycenter of pixel distributions.
In the Sinkhorn barycenter algorithm each outer iteration runs $n$ Sinkhorn sub-algorithms. What is the total complexity of one outer iteration?
Applications: Shape Analysis, Federated Learning, Gradient Flows
**Shape interpolation.** Represent a 2D shape as a measure $\mu$ (uniform on the boundary). The geodesic in $(\mathcal{P}_2, W_2)$ between $\mu_0$ and $\mu_1$ gives a sequence of intermediate shapes $\mu_t = \bar{\mu}(1-t, t)$, the barycenter with weight $1-t$ on $\mu_0$ and $t$ on $\mu_1$, so that $\mu_t$ runs from $\mu_0$ at $t = 0$ to $\mu_1$ at $t = 1$. Mass flows along the shortest path rather than blending pixel by pixel. Applications: medical imaging (cardiac shape atlases), animation, morphometric statistics in biology.
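For two shapes sampled with the same number $m$ of equally weighted points, the optimal plan is a permutation and the geodesic moves each point in a straight line to its partner. A minimal sketch, assuming tiny point clouds so that brute-force search over permutations is affordable (a real implementation would call an assignment or OT solver); `displacement_interpolation` is a hypothetical name, with `t = 0` returning the first shape and `t = 1` the second.

```python
import itertools
import numpy as np

def displacement_interpolation(X0, X1, t):
    """Points of the W2 geodesic mu_t between two uniform point clouds of equal size."""
    X0, X1 = np.asarray(X0, dtype=float), np.asarray(X1, dtype=float)
    m = len(X0)
    cost = ((X0[:, None, :] - X1[None, :, :]) ** 2).sum(axis=-1)  # squared distances
    rows = np.arange(m)
    # Brute-force optimal assignment: only viable for very small m
    best = min(itertools.permutations(range(m)), key=lambda perm: cost[rows, list(perm)].sum())
    partners = X1[list(best)]
    # Each point travels in a straight line: x -> (1 - t) x + t T(x)
    return (1.0 - t) * X0 + t * partners

square = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
diamond = np.array([[3.0, 0.5], [3.5, 0.0], [4.0, 0.5], [3.5, 1.0]])
midway = displacement_interpolation(square, diamond, 0.5)
```

Every intermediate shape is a set of $m$ points, so mass is transported rather than faded; a pixel-wise blend of the two shapes would instead show both outlines at half intensity.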
**Federated learning through barycenters.** Standard FedAvg aggregates neural network weights as a Euclidean mean: $\bar{\theta} = \sum \lambda_i \theta_i$. When clients have heterogeneous data (different class distributions), their models drift apart and the coordinate-wise mean can blur them into a model that fits no client well. Wasserstein aggregation instead treats the weights of each layer of client $i$ as a distribution $p_i(\theta)$ and computes their barycenter:

$$\bar{p} = \arg\min_{p} \sum_i \lambda_i W_2^2(p, p_i).$$
**Sliced Wasserstein Barycenter.** Project the measures onto random 1D directions, where the 1D barycenter is available analytically through quantile functions, and move the support points of $\bar{\mu}$ toward these 1D targets by gradient steps on the sliced objective. Complexity per iteration: $O(mn \log m \cdot L)$, where $L$ is the number of projections (typically 100-500). At $L = 200$ the error is below 3% of the exact barycenter, with 50-100x speedup over Sinkhorn. Used in fast color transfer and shape averaging.
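A minimal sketch of the sliced approach for point clouds of equal size $m$, assuming the formulation in which the barycenter's support points are updated by gradient steps: along each random direction the 1D barycenter is the weighted average of sorted projections, and each point is nudged toward its 1D target. The parameter names, step size, iteration count and the initialization at the first cloud are illustrative assumptions.

```python
import numpy as np

def sliced_barycenter(clouds, lambdas, n_iter=300, n_proj=50, step=1.0, seed=0):
    """Sliced W2 barycenter of equal-size point clouds by projected gradient steps.

    clouds: list of (m, d) arrays; lambdas: barycentric weights summing to 1.
    """
    rng = np.random.default_rng(seed)
    clouds = [np.asarray(Y, dtype=float) for Y in clouds]
    lambdas = np.asarray(lambdas, dtype=float)
    X = clouds[0].copy()                                       # arbitrary initialization
    d = X.shape[1]
    for _ in range(n_iter):
        theta = rng.normal(size=(n_proj, d))
        theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # L random unit directions
        px = X @ theta.T                                       # (m, L) projections of X
        # 1D barycenter per direction: weighted average of sorted projections (quantiles)
        target_sorted = sum(l * np.sort(Y @ theta.T, axis=0) for l, Y in zip(lambdas, clouds))
        # The point of rank k along a direction is matched to the k-th 1D target
        target = np.empty_like(px)
        np.put_along_axis(target, np.argsort(px, axis=0), target_sorted, axis=0)
        # Move each point along every direction toward its target, averaged over directions
        X += step * ((target - px) @ theta) / n_proj
    return X

Y1 = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 2.0], [3.0, 1.0]])
X_bar = sliced_barycenter([Y1, Y1 + [4.0, 0.0]], [0.5, 0.5])
```

Because the second cloud is a translated copy of the first, `X_bar` converges to the copy translated halfway, `Y1 + [2, 0]`. Each iteration costs a sort per direction and per input, which is where the $O(mn \log m \cdot L)$ term comes from.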
**Gradient flows in $W_2$ (Jordan-Kinderlehrer-Otto 1998).** The diffusion equation $\partial_t \rho = \Delta \rho$ is the gradient flow of the negative entropy $H[\rho] = \int \rho \log \rho$ under the $W_2$ metric. The Fokker-Planck equation $\partial_t \rho = \nabla \cdot (\rho \nabla V) + \Delta \rho$ is the gradient flow of the free energy $F[\rho] = \int V \rho + \int \rho \log \rho$. The JKO scheme discretizes these equations as a sequence of proximal steps in $(\mathcal{P}_2, W_2)$ with step size $\tau$: $\rho_{k+1} = \arg\min_\rho \frac{1}{2\tau} W_2^2(\rho, \rho_k) + F[\rho]$. With entropic regularization, each step can be solved by Sinkhorn-type scaling iterations closely related to those used for Sinkhorn barycenters. Used for simulating diffusive processes in molecular dynamics and for Langevin dynamics in neural network optimization.
Misconception: a Wasserstein barycenter is just a weighted mean of distribution parameters
In fact, the barycenter minimizes the sum of $W_2^2$ distances in measure space. For 1D Gaussians this averages the standard deviation $\sigma$ linearly (not the variance $\sigma^2$), while for multivariate Gaussians the covariance is averaged nonlinearly through the Bures equation
Naively averaging covariances $\bar{\Sigma} = \sum \lambda_i \Sigma_i$ minimizes $\sum \lambda_i \|\Sigma - \Sigma_i\|_F^2$ - the Euclidean mean of matrices. The Bures equation minimizes $\sum \lambda_i W_2^2(\mathcal{N}(0,\Sigma), \mathcal{N}(0,\Sigma_i))$. These are different objectives giving different points on the SPD manifold.
FedAvg: $\bar{\theta} = \sum \lambda_i \theta_i$. In which scenario does the Wasserstein barycenter of weight distributions give a principled improvement?
Key ideas
- **Frechet mean in $W_2$**: $\bar{\mu} = \arg\min_{\mu} \sum_i \lambda_i W_2^2(\mu, \mu_i)$. For 1D Gaussians: $\bar{\sigma} = \sum \lambda_i \sigma_i$ - average over $\sigma$, not $\sigma^2$
- **Agueh-Carlier 2011**: barycenter is equivalent to multi-marginal OT. It exists for measures in $\mathcal{P}_2$ and is unique if at least one $\mu_i$ with $\lambda_i > 0$ is absolutely continuous. Optimal maps $T_i$ are gradients of convex functions (Brenier)
- **Bures equation**: barycenter of Gaussians $\mathcal{N}(m_i, \Sigma_i)$ has mean $\sum \lambda_i m_i$ and covariance $S$, the fixed point of $S = \sum \lambda_i (S^{1/2} \Sigma_i S^{1/2})^{1/2}$, solved iteratively
- **Sinkhorn barycenter** (Cuturi-Doucet 2014): $O(m^2 n)$ instead of $O(m^n)$. Alternate Sinkhorn iterations and geometric mean of dual variables
- **Gradient flow in $W_2$**: JKO scheme - diffusion as a sequence of proximal steps in $(\mathcal{P}_2, W_2)$. Bridges barycenters with partial differential equations
Related topics
Wasserstein barycenters connect the geometry of measures with practical ML algorithms:
- Brenier theorem — Optimal maps $T_i$ to the barycenter are gradients of convex functions, ensuring uniqueness
- Domain adaptation via OT — Multi-source DA - barycenter of several source domains in feature space
- Sinkhorn algorithm — Sinkhorn barycenter is built as a loop of standard Sinkhorn iterations
Questions for reflection
- The Sinkhorn barycenter converges to the exact W2 barycenter as $\varepsilon \to 0$ and becomes increasingly blurred as $\varepsilon$ grows. Why? What happens to the transport plan under large regularization, and what does the barycenter tend to as $\varepsilon \to \infty$?
- The JKO scheme discretizes the diffusion equation as proximal steps in W2 with step size $\tau$. How does $\tau$ affect approximation accuracy? At what $\tau$ does the scheme lose stability?
- Wasserstein federated learning requires knowing the weight distribution $p_i(\theta)$ of each client. How can $p_i$ be approximated given only the point estimate $\theta_i$?
Related lessons
- ot-03-wasserstein — W2 distance is the metric being minimized in the barycenter problem
- ot-04-sinkhorn — Sinkhorn is the core building block of the iterative barycenter algorithm
- ot-06-brenier — Brenier theorem: optimal maps T_i are gradients of convex functions, ensuring barycenter uniqueness
- ot-08-domain-adaptation — Multi-source domain adaptation is a special case of barycenter over several source domains
- calc-01-sequences