Optimal Transport
Gromov-Wasserstein and Applications
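Gromov-Wasserstein, Sliced Wasserstein, and OT domain adaptation are three different answers to one question: how to make OT work on real-world problems. GW compares data that do not share a common space, SW makes OT cheap enough for large point clouds, and OT domain adaptation transfers labels from one distribution to another. Novartis uses GW for drug discovery; OT domain adaptation is a standard component of ML pipelines.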
- **Drug discovery (Novartis, 2023):** GW for aligning molecular graphs without known atom correspondence. Accelerates search for drug analogs.
- **Point cloud generation (PointFlow):** SW₂ as a differentiable loss for training 3D shape generative models. Used in autonomous driving.
- **Domain adaptation:** POT's SinkhornTransport is a standard baseline for unsupervised DA on the Office-31 and VisDA benchmarks.
Gromov-Wasserstein Distance: Comparing Metric Measure Spaces
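**Gromov-Wasserstein distance compares molecular graphs in drug discovery (Novartis, 2023), aligning 3D protein structures without known atom correspondence.** Classical OT needs a cost c(x, y) between a point of one space and a point of the other, which in practice means both measures live in a shared metric space. GW instead compares the internal geometries of the two spaces without embedding them into a common one: a coupling γ is cheap when every pair of matches (x, y) and (x′, y′) it contains satisfies d_X(x, x′) ≈ d_Y(y, y′), i.e. when it approximately preserves distances.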
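To make the GW objective concrete, here is a small NumPy sketch. The helper names pairwise_dist, gw_objective and gw_over_permutations are illustrative, not from any library. It evaluates the discrete objective Σ (C1[i,k] − C2[j,l])² T[i,j] T[k,l] using only within-space distance matrices, and, purely as an illustration, minimizes it by brute force over permutation couplings. That search is exponential in n and assumes equal-size, uniformly weighted point sets; general solvers such as POT's ot.gromov.gromov_wasserstein optimize over all couplings.

```python
import itertools

import numpy as np


def pairwise_dist(X):
    # Euclidean distance matrix within ONE space: C[i, k] = d(x_i, x_k)
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def gw_objective(C1, C2, T):
    # Discrete GW cost: sum_{i,j,k,l} (C1[i,k] - C2[j,l])^2 * T[i,j] * T[k,l]
    diff = C1[:, None, :, None] - C2[None, :, None, :]  # shape (n, m, n, m)
    return float(np.einsum("ijkl,ij,kl->", diff ** 2, T, T))


def gw_over_permutations(C1, C2):
    # Brute force over permutation couplings (uniform weights, equal sizes).
    # Exponential in n: only meant to illustrate the objective on tiny inputs.
    n = C1.shape[0]
    best_cost, best_perm = np.inf, None
    for perm in itertools.permutations(range(n)):
        perm = np.array(perm)
        T = np.zeros((n, n))
        T[np.arange(n), perm] = 1.0 / n  # x_i is matched to y_{perm[i]}
        cost = gw_objective(C1, C2, T)
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    return best_cost, best_perm


# X lives in R^2; Y is an isometric copy embedded in R^3, then shuffled.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal 3x3 matrix
Y = (np.hstack([X, np.zeros((5, 1))]) @ Q.T)[rng.permutation(5)]

C1, C2 = pairwise_dist(X), pairwise_dist(Y)  # no cross-space distance is ever computed
cost, perm = gw_over_permutations(C1, C2)
print(cost, perm)  # cost is ~0 up to floating-point noise
```

The recovered matching preserves every pairwise distance even though X and Y live in spaces of different dimension.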
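Fused Gromov-Wasserstein (Vayer 2019) combines GW with classical OT for spaces with node features: (1−α)·OT_cost + α·GW_cost, where OT_cost compares node features across the two graphs and GW_cost compares their structures. Setting α = 0 recovers classical OT on the features and α = 1 recovers pure GW. Used for aligning attributed graphs.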
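The fused objective can be evaluated directly for a fixed coupling. This sketch assumes a hypothetical helper fgw_objective, the square loss for the structure term, and hand-picked couplings rather than an optimized one. It compares two matchings of a pair of 3-node path graphs with scalar node features:

```python
import numpy as np


def fgw_objective(M, C1, C2, T, alpha):
    # Feature term: classical OT cost <M, T>, with M[i, j] = cost between node features
    feature_cost = float((M * T).sum())
    # Structure term: GW square-loss cost comparing within-graph distances
    diff = C1[:, None, :, None] - C2[None, :, None, :]
    structure_cost = float(np.einsum("ijkl,ij,kl->", diff ** 2, T, T))
    return (1 - alpha) * feature_cost + alpha * structure_cost


# Two path graphs 0-1-2 with identical shortest-path distance matrices.
C1 = np.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]])
C2 = C1.copy()
# Scalar node features; squared feature differences form the cross-graph cost.
F1 = np.array([0.0, 1.0, 2.0])
F2 = np.array([0.0, 1.0, 2.5])
M = (F1[:, None] - F2[None, :]) ** 2

T_same = np.eye(3) / 3        # node i -> node i
T_flip = np.eye(3)[::-1] / 3  # node i -> node 2 - i (a graph automorphism)

for alpha in (1.0, 0.5):
    print(alpha, fgw_objective(M, C1, C2, T_same, alpha), fgw_objective(M, C1, C2, T_flip, alpha))
```

With α = 1 both matchings cost 0, because reversing a path is an automorphism that structure alone cannot distinguish; with α = 0.5 the feature term makes the identity matching cheaper. To optimize the coupling itself, POT provides ot.gromov.fused_gromov_wasserstein.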
What is the key advantage of Gromov-Wasserstein over classical Wasserstein distance?
Classical W requires μ and ν to live in the same metric space. GW compares the internal geometries of two arbitrary spaces (graphs, molecules, point clouds) without embedding them into a common ambient space.
Sliced Wasserstein: Scalable OT via 1D Projections
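Exact W₂ between two empirical measures with n points each is a linear program costing roughly O(n³), prohibitive for generative models with millions of points. **Sliced Wasserstein (SW)** projects both measures onto random 1D lines, where W₂ costs only O(n log n) via sorting, then averages over directions: SW₂²(μ,ν) = E_θ[W₂²(θ#μ, θ#ν)], where θ is uniform on the unit sphere and θ#μ is the distribution of ⟨θ, x⟩ for x ~ μ. In practice the expectation is estimated with L random directions.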
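A minimal NumPy Monte Carlo estimator of SW₂, assuming both point clouds have the same number of points with uniform weights, so that each 1D W₂ reduces to matching sorted samples. The name sliced_w2 and its defaults (500 directions, fixed seed) are illustrative choices.

```python
import numpy as np


def sliced_w2(X, Y, n_projections=500, seed=0):
    # Assumes equal sample sizes with uniform weights: each 1D W2 is a sorted match.
    assert X.shape == Y.shape
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    # Random directions, uniform on the unit sphere S^{d-1}
    thetas = rng.normal(size=(n_projections, d))
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)
    # Column l holds the projected 1D samples <x_i, theta_l>, sorted
    px = np.sort(X @ thetas.T, axis=0)
    py = np.sort(Y @ thetas.T, axis=0)
    # 1D W2^2 per direction = mean squared gap of sorted samples; then average over directions
    sw2_squared = np.mean((px - py) ** 2)
    return float(np.sqrt(sw2_squared))


rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 2))
shift = np.array([3.0, 0.0])
print(sliced_w2(X, X + shift))  # Monte Carlo estimate of 3 / sqrt(2)
```

For a cloud and its translate by v, every 1D distance equals |⟨θ, v⟩|, so SW₂² = |v|²/d in expectation; the printed estimate should therefore be close to 3/√2 ≈ 2.12.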
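SW₂ is used in point cloud generation (PointFlow, ShapeGF), domain adaptation, and as a generative model loss in place of adversarial (GAN) objectives. It is differentiable almost everywhere: sorting only permutes the projected values, so gradients flow through the sorted values returned by torch.sort().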
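To illustrate SW₂ as a training loss, the sketch below moves a source cloud toward a fixed target cloud by plain gradient descent. To avoid a PyTorch dependency it writes the gradient by hand in NumPy: with the rank matching π held fixed, the derivative of the 1D cost with respect to x_i is 2(p_i − q_{π(i)})θ/n, the same gradient obtained by backpropagating through sorted values. The step size, iteration counts and cloud shapes are arbitrary assumptions.

```python
import numpy as np


def random_directions(k, d, rng):
    # k directions drawn uniformly on the unit sphere S^{d-1}
    th = rng.normal(size=(k, d))
    return th / np.linalg.norm(th, axis=1, keepdims=True)


def sw2_squared_and_grad(X, Y, thetas):
    # Monte Carlo SW2^2 over the given directions and its gradient w.r.t. X.
    # Assumes len(X) == len(Y) with uniform weights. Sorting is piecewise constant,
    # so almost everywhere the gradient treats the rank matching as fixed.
    n = X.shape[0]
    loss, grad = 0.0, np.zeros_like(X)
    for theta in thetas:
        p, q = X @ theta, Y @ theta
        matched = np.empty(n)
        matched[np.argsort(p)] = np.sort(q)  # k-th smallest p paired with k-th smallest q
        resid = p - matched
        loss += np.mean(resid ** 2)
        grad += (2.0 / n) * resid[:, None] * theta[None, :]  # d/dX of mean(resid^2)
    return loss / len(thetas), grad / len(thetas)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))                     # source cloud, moved by descent
Y = 0.5 * rng.normal(size=(200, 2)) + [3.0, 3.0]  # target cloud, fixed

eval_thetas = random_directions(200, 2, rng)      # fixed directions for reporting
initial, _ = sw2_squared_and_grad(X, Y, eval_thetas)
lr = 50.0  # arbitrary step size (lr * 2 / n = 0.5)
for _ in range(200):
    _, g = sw2_squared_and_grad(X, Y, random_directions(30, 2, rng))
    X = X - lr * g
final, _ = sw2_squared_and_grad(X, Y, eval_thetas)
print(initial, final)
```

Drawing fresh directions at every step makes this stochastic gradient descent on SW₂²; the final loss should be a small fraction of the initial one.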
Why is Sliced Wasserstein computationally more efficient than exact W₂?
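In 1D, W₂ is computed in O(n log n) via sorting, using the quantile formula W₂²(μ,ν) = ∫₀¹ |F_μ⁻¹(t) − F_ν⁻¹(t)|² dt; for two samples of equal size this amounts to pairing the k-th smallest point of one sample with the k-th smallest point of the other. SW averages L such computations: total complexity O(L·n log n) (plus O(L·n·d) for the projections) instead of O(n³).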
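The quantile formula also covers samples of different sizes. A sketch, assuming uniform weights on each sample (w2_1d is a hypothetical helper): both empirical quantile functions are step functions, so the integral becomes a finite sum over the merged breakpoints k/n and k/m.

```python
import numpy as np


def w2_1d(a, b):
    # W2 between uniform empirical measures on samples a (size n) and b (size m),
    # via W2^2 = integral_0^1 (F_a^{-1}(t) - F_b^{-1}(t))^2 dt.
    a, b = np.sort(np.asarray(a, dtype=float)), np.sort(np.asarray(b, dtype=float))
    n, m = len(a), len(b)
    # Both quantile functions are step functions; merge their breakpoints k/n and k/m
    t = np.unique(np.concatenate([np.arange(1, n + 1) / n, np.arange(1, m + 1) / m]))
    widths = np.diff(t, prepend=0.0)
    mids = t - widths / 2
    # On ((k-1)/n, k/n] the quantile function of a equals a[k-1]; evaluate at midpoints
    qa = a[np.minimum((mids * n).astype(int), n - 1)]
    qb = b[np.minimum((mids * m).astype(int), m - 1)]
    return float(np.sqrt(np.sum(widths * (qa - qb) ** 2)))


print(w2_1d([0, 1], [0, 0.5, 1]) ** 2)  # 1/12
print(w2_1d([1, 2, 3], [3, 4, 5]))      # 2.0: a pure shift by 2
```

For a = [0, 1] and b = [0, 0.5, 1] the two quantile functions differ by 0.5 on two intervals of width 1/6, giving 2 · (1/6) · 0.25 = 1/12. For equal sizes the result coincides with the sorted-matching formula used by SW.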
Applications: Domain Adaptation and Shape Matching
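OT solves a fundamental ML problem: **transferring knowledge from one data distribution to another**. Domain adaptation via OT computes the optimal transport plan γ* between labeled source samples and unlabeled target samples, uses γ* to map the source samples into the target domain, and trains a classifier on the mapped, still-labeled samples.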
POT (Python Optimal Transport) implements OT-based DA: ot.da.SinkhornTransport, ot.da.EMDTransport, label propagation. Standard baseline on Office-31 and VisDA benchmarks.
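A self-contained sketch of this pipeline on a toy 2D shift, using a hand-written Sinkhorn solver instead of POT. The function names, the regularization reg = 1.0 and the synthetic data are all assumptions. It computes γ between labeled source and unlabeled target samples, maps each source point to the γ-weighted average of target points, then fits a 1-nearest-neighbour classifier on the mapped source points.

```python
import numpy as np


def sinkhorn(a, b, C, reg, n_iter=1000):
    # Entropic OT: rescale K = exp(-C / reg) so the plan has marginals a (rows), b (columns)
    K = np.exp(-C / reg)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def barycentric_map(gamma, Xt):
    # T(x_s) = sum_j gamma[s, j] x_t[j] / sum_j gamma[s, j]
    return (gamma @ Xt) / gamma.sum(axis=1, keepdims=True)


def nn_predict(X_train, y_train, X_test):
    # 1-nearest-neighbour classifier
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(-1)
    return y_train[d2.argmin(axis=1)]


rng = np.random.default_rng(0)
n = 30
# Labeled source: class 0 near (0, 0), class 1 near (4, 0)
Xs = np.vstack([rng.normal([0, 0], 0.5, (n, 2)), rng.normal([4, 0], 0.5, (n, 2))])
ys = np.array([0] * n + [1] * n)
# Unlabeled target: same classes shifted by (3, 0); yt is used only for scoring
Xt = np.vstack([rng.normal([3, 0], 0.5, (n, 2)), rng.normal([7, 0], 0.5, (n, 2))])
yt = ys.copy()

C = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
a = np.full(2 * n, 1 / (2 * n))
b = np.full(2 * n, 1 / (2 * n))
gamma = sinkhorn(a, b, C, reg=1.0)

naive_acc = np.mean(nn_predict(Xs, ys, Xt) == yt)
ot_acc = np.mean(nn_predict(barycentric_map(gamma, Xt), ys, Xt) == yt)
print(naive_acc, ot_acc)
```

The classifier trained on raw source points mislabels most shifted class-0 target points, while the transported source points should classify the target almost perfectly. With POT installed, ot.da.SinkhornTransport(reg_e=1.0) with .fit(Xs=Xs, Xt=Xt) and .transform(Xs=Xs) plays the role of sinkhorn plus barycentric_map here, though its defaults may differ from this sketch.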
What is barycentric projection in OT domain adaptation?
Barycentric projection: T̂(x_s) = Σ_j γ*_{sj}/μ_s · x_{t,j}, where μ_s = Σ_j γ*_{sj} is the mass of source point x_s. Each source point is mapped to the weighted average of target points given by its row of γ*, normalized by that row's total mass.
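A tiny worked example of the formula with a made-up 2×2 coupling (rows are source points, columns are target points; the numbers are invented for illustration). It also shows why the 1/μ_s factor matters: without it, each mapped point is scaled by its row mass and pulled toward the origin.

```python
import numpy as np

# Hypothetical coupling between 2 source points (rows) and 2 target points (columns)
gamma = np.array([[0.25, 0.25],
                  [0.00, 0.50]])
Xt = np.array([[0.0, 0.0],
               [2.0, 0.0]])

mu = gamma.sum(axis=1)               # row masses mu_s
T_hat = (gamma @ Xt) / mu[:, None]   # barycentric projection
unnormalized = gamma @ Xt            # same sum without the 1/mu_s factor
print(T_hat)         # rows (1, 0) and (2, 0)
print(unnormalized)  # rows (0.5, 0) and (1, 0)
```

The first source point splits its mass evenly, so it lands halfway between the two targets; the second sends all its mass to (2, 0) and lands exactly there.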
Key Ideas
- **Gromov-Wasserstein:** GW²(μ,ν) = min_{γ∈Π(μ,ν)} ∬ (d_X(x,x′) − d_Y(y,y′))² dγ(x,y) dγ(x′,y′), minimizing over couplings Π(μ,ν) of μ and ν. Compares spaces without a shared embedding.
- **Sliced Wasserstein:** SW₂² averages W₂² over random 1D projections. O(L·n log n) vs O(n³). Differentiable via sorting.
- **OT domain adaptation:** find γ* between source and target, apply barycentric projection T̂(x_s) = Σ_j γ*_{sj}/μ_s · x_{t,j}.
- **Fused GW:** (1−α)·W + α·GW for spaces with features. Aligns attributed graphs.