Optimal Transport

Domain adaptation via OT

A chest X-ray classifier trained on 10,000 scans from Mayo Clinic is deployed at a hospital in India. Accuracy drops from 91% to 67%. Different machines, different brightness calibrations. The same disease looks different. OT knows how to drag one point cloud onto another - not scan by scan, but distribution to distribution.

  • **Medical imaging cross-hospital transfer**: DeepJDOT on CT scans from Siemens vs GE scanners - +14% tumor segmentation accuracy without a single labeled target image
  • **Sim-to-real robotics**: OpenAI Rubik's (2019) and Boston Dynamics Spot - OT aligns the feature spaces of simulator and real world, so a policy trained in PhysX transfers directly onto the physical robot
  • **NLP cross-domain**: BERT trained on Bloomberg (finance) adapts to clinical notes via OT alignment of sentence embeddings. F1 +18% over baseline
  • **Satellite imagery**: Sentinel-2 imagery, Europe to Africa. Land-use classification without target labels. JDOT finds the correspondence between 'wheat field' and 'sorghum field' through embedding geometry

Prerequisites

  • Wasserstein distance and its interpretation as optimal transport cost
  • Sinkhorn algorithm: regularized OT and its matrix form

Covariate shift: formalizing the problem

A chest X-ray classifier trained on 10,000 scans from Mayo Clinic is deployed at a hospital in India. Accuracy drops from 91% to 67%. Different machines, different brightness calibrations. The same disease looks different. The classifier is not broken - it simply never saw that point cloud.

Formally: the source domain has distribution $\mu_S$ over features $X$, the target domain has $\mu_T$. **Covariate shift** occurs when $\mu_S(X) \neq \mu_T(X)$ but the conditional $P(Y \mid X)$ is identical in both domains. The label depends on the object, not on the hospital. The task: train a classifier on source that works on target, where no labels exist.

**Scale of the problem in NLP**. BERT trained on Wikipedia loses F1 from 91% to 73% when transferred to clinical notes (Romanov & Shivade 2018). GPT-3 trained on Reddit shows a different token distribution than the same GPT-3 on legal documents. CLIP sidesteps covariate shift through text-image alignment - implicit domain adaptation via an OT-inspired shared space.

Classic approaches: importance weighting (reweight source samples that look like target), fine-tuning (minimal labeled target), adversarial DA (DANN, Ganin 2016 - learn domain-agnostic features). All of them either require target labels or provide no bounds on the distance between domains. OT gives both.
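A short derivation of why importance weighting works under covariate shift, and where it stops working. The weight function $w$ is notation introduced here for illustration, and the identity assumes $\mu_T$ is absolutely continuous with respect to $\mu_S$:

$$
R_T(f) = \mathbb{E}_{x \sim \mu_T} \mathbb{E}_{y \sim P(\cdot \mid x)} \big[ \mathcal{L}(f(x), y) \big] = \mathbb{E}_{x \sim \mu_S} \Big[ w(x) \, \mathbb{E}_{y \sim P(\cdot \mid x)} \big[ \mathcal{L}(f(x), y) \big] \Big], \qquad w(x) = \frac{d\mu_T}{d\mu_S}(x).
$$

The second equality uses the shared conditional $P(Y \mid X)$. When the target cloud lies where $\mu_S$ has little or no mass, $w$ is unbounded or undefined and reweighting has nothing to work with; moving the points, as OT does, remains possible in that regime.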

Covariate shift means $\mu_S(X) \neq \mu_T(X)$ but $P(Y \mid X)$ is identical. What does this imply for transferring a classifier?

OT map as a bridge between domains

OT offers a clean answer: find a map $T: \mathcal{X}_S \to \mathcal{X}_T$ that pushes $\mu_S$ onto $\mu_T$ at minimum cost. Formally, solve $\min_T \int c(x, T(x)) \, d\mu_S(x)$ subject to $T_\# \mu_S = \mu_T$, with the quadratic cost $c(x, y) = \|x - y\|^2$. The minimizer $T$ is the optimal-transport (Monge) map.

The idea of Courty et al. (2017, IEEE TPAMI): apply $T$ to source points to get an adapted cloud, then train a classifier on adapted data with the original labels. Three steps: (1) estimate the transport plan $\gamma$ via Sinkhorn, (2) compute $T(x_i^S)$ for each source point as the barycentric projection $T(x_i^S) = \sum_j \gamma_{ij} x_j^T / \sum_j \gamma_{ij}$, (3) train $f$ on $\{T(x_i^S), y_i^S\}$.
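The three steps can be sketched in NumPy on a toy problem. Everything below is an illustrative assumption rather than the authors' code: the two-blob data, the helper names (`sinkhorn_plan`, `barycentric_map`, `fit_centroids`, `predict`), the regularization value, and the nearest-centroid classifier standing in for $f$. The map is read off the Sinkhorn plan by barycentric projection.

```python
import numpy as np

def sinkhorn_plan(a, b, C, reg=0.05, n_iter=500):
    """Entropic OT plan between weights a (n,) and b (m,) for cost C (n, m)."""
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)   # rescale to match column marginals
        u = a / (K @ v)     # rescale to match row marginals
    return u[:, None] * K * v[None, :]

def barycentric_map(gamma, Xt):
    """Send each source point to the plan-weighted mean of its target matches."""
    return (gamma @ Xt) / gamma.sum(axis=1, keepdims=True)

def fit_centroids(X, y):
    """Nearest-centroid classifier: one mean vector per class (labels 0..K-1)."""
    return np.stack([X[y == k].mean(axis=0) for k in np.unique(y)])

def predict(centroids, X):
    d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    return d.argmin(axis=1)

def make_blobs(rng, n_per_class, shift):
    """Two Gaussian blobs (classes 0 and 1), translated by `shift`."""
    y = np.repeat([0, 1], n_per_class)
    centers = np.array([[-1.0, 0.0], [1.0, 0.0]])[y]
    X = centers + 0.3 * rng.normal(size=(2 * n_per_class, 2)) + shift
    return X, y

rng = np.random.default_rng(0)
Xs, ys = make_blobs(rng, 50, shift=np.zeros(2))
Xt, yt = make_blobs(rng, 50, shift=np.array([3.0, 2.0]))  # yt only used for scoring

# Step 1: plan via Sinkhorn on the (normalized) squared Euclidean cost.
C = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)
C /= C.max()
a = np.full(len(Xs), 1.0 / len(Xs))
b = np.full(len(Xt), 1.0 / len(Xt))
gamma = sinkhorn_plan(a, b, C)

# Step 2: move the source points into the target domain.
Xs_adapted = barycentric_map(gamma, Xt)

# Step 3: train on adapted points with the original source labels.
acc_baseline = (predict(fit_centroids(Xs, ys), Xt) == yt).mean()
acc_adapted = (predict(fit_centroids(Xs_adapted, ys), Xt) == yt).mean()
print(f"no adaptation: {acc_baseline:.2f}, OT-adapted: {acc_adapted:.2f}")
```

On a pure translation like this one the unadapted classifier should sit near chance, while the adapted one should recover the blobs. Real shifts are rarely this clean, so the toy only illustrates the mechanics.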

**Why Wasserstein, not MMD or KL**. Maximum Mean Discrepancy (MMD) gives no constructive map $T$ - only a distance number. KL divergence requires density estimation, which is unreliable in high dimensions. Wasserstein via the OT plan $\gamma$ directly gives pairwise correspondences: $\gamma(i,j)$ tells how much mass of point $x_i^S$ flows to $x_j^T$. That is a map, not just a metric.
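For reference, the discrete problem behind the plan $\gamma$ can be written as follows. The weights $a_i$, $b_j$ are notation introduced here (uniform weights $a_i = 1/n_S$, $b_j = 1/n_T$ are the usual choice for empirical samples):

$$
\gamma^\star = \arg\min_{\gamma \in \Pi(a, b)} \sum_{i,j} \gamma_{ij} \, \|x_i^S - x_j^T\|^2, \qquad \Pi(a, b) = \Big\{ \gamma \geq 0 : \sum_j \gamma_{ij} = a_i, \ \sum_i \gamma_{ij} = b_j \Big\}.
$$

The optimal value is $W_2^2$ between the two empirical measures, so one solve yields both the distance and the correspondences. Sinkhorn solves this problem with an added entropy term, which makes $\gamma$ dense rather than sparse.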

Sim-to-real in robotics is a direct application. OpenAI Rubik's (2019): policy trained in simulation (PhysX) - perfect friction, perfect fingers. Real robot - different domain. Domain randomization plus OT alignment in feature space transfers the policy without additional real-world training. Boston Dynamics Spot follows the same scheme for terrain adaptation.

OT map $T$ is computed via Sinkhorn. How is it applied for domain adaptation without target labels?

JDOT: joint optimization of transport and classifier

Basic OT-DA is two-stage: first $T$, then $f$. The problem - errors accumulate: a poor $T$ produces a poor training set for $f$. Courty et al. (2017, NeurIPS) proposed **Joint Distribution Optimal Transport (JDOT)**: optimize $T$ and $f$ jointly, folding the prediction loss directly into the transport cost.

The transport cost is now two-part: $c_{ij} = \alpha \|x_i^S - x_j^T\|^2 + \lambda \mathcal{L}(f(x_j^T), y_i^S)$, geometric distance plus prediction loss, with weights $\alpha, \lambda \geq 0$. When classifier $f$ correctly predicts source label $y_i^S$ on target point $x_j^T$, the pair $(i,j)$ gets low cost and high weight in plan $\gamma$. Transport and classifier teach each other.
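A minimal sketch of the joint cost matrix, assuming a squared loss between one-hot source labels and the classifier's scores on target points. The function name `jdot_cost` and the weights `alpha` and `lam` are illustrative choices, not a library API.

```python
import numpy as np

def jdot_cost(Xs, Ys_onehot, Xt, Pt, alpha=1.0, lam=1.0):
    """C[i, j] = alpha * ||xs_i - xt_j||^2 + lam * ||ys_i - f(xt_j)||^2.

    Xs: (n, d) source features, Ys_onehot: (n, K) source labels,
    Xt: (m, d) target features, Pt: (m, K) classifier scores f(xt_j).
    """
    geom = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(-1)
    pred = ((Ys_onehot[:, None, :] - Pt[None, :, :]) ** 2).sum(-1)
    return alpha * geom + lam * pred

# One source point of class 0 and two target points at the same distance.
Xs = np.array([[0.0, 0.0]])
Ys = np.array([[1.0, 0.0]])
Xt = np.array([[1.0, 0.0], [-1.0, 0.0]])
Pt = np.array([[1.0, 0.0],    # f predicts class 0 here: agrees with the source label
               [0.0, 1.0]])   # f predicts class 1 here: disagrees
C = jdot_cost(Xs, Ys, Xt, Pt)
print(C)
```

Geometry alone cannot separate the two candidates in this example, since both are at squared distance 1. The prediction term breaks the tie in favour of the target point where the classifier agrees with the source label.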

**Alternating optimization**. JDOT is solved by alternating two steps: (1) fix $f$ and update $\gamma$ via Sinkhorn with the new joint cost; (2) fix $\gamma$ and update $f$ by minimizing the weighted prediction loss $\sum_{i,j} \gamma_{ij} \mathcal{L}(f(x_j^T), y_i^S)$. When each step is solved exactly the objective never increases, so the scheme converges to a local minimum, not necessarily the global one. It mirrors EM: an E-step (plan $\gamma$) and an M-step (classifier $f$).
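The alternating scheme can be sketched as below, under several assumptions that are not in the text: a linear classifier with squared loss on one-hot labels, fitted by ridge regression; a zero predictor as the starting point; and the same toy two-blob data shape used for illustration. With squared loss, the $f$-step reduces to a weighted regression of target points onto plan-averaged source labels.

```python
import numpy as np

def sinkhorn_plan(a, b, C, reg=0.05, n_iter=300):
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

def sq_dists(A, B):
    return ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)

def add_bias(X):
    return np.hstack([X, np.ones((len(X), 1))])

def fit_ridge(X, Y, weights, l2=1e-3):
    """Weighted ridge regression; returns W such that add_bias(X) @ W ~ Y."""
    Xb = add_bias(X)
    A = Xb.T @ (weights[:, None] * Xb) + l2 * np.eye(Xb.shape[1])
    return np.linalg.solve(A, Xb.T @ (weights[:, None] * Y))

def jdot(Xs, Ys, Xt, alpha=1.0, lam=1.0, reg=0.05, n_outer=10):
    n, m = len(Xs), len(Xt)
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    geom = sq_dists(Xs, Xt)
    geom = geom / geom.max()
    Pt = np.zeros((m, Ys.shape[1]))          # assumption: f starts as the zero predictor
    for _ in range(n_outer):
        # Step 1: fix f, update the plan on the joint cost.
        C = alpha * geom + lam * sq_dists(Ys, Pt)
        gamma = sinkhorn_plan(a, b, C, reg)
        # Step 2: fix the plan, update f. For squared loss this is a weighted
        # regression onto the plan-averaged source labels of each target point.
        col = gamma.sum(axis=0)
        Y_soft = (gamma.T @ Ys) / col[:, None]
        W = fit_ridge(Xt, Y_soft, col)
        Pt = add_bias(Xt) @ W
    return W, gamma

rng = np.random.default_rng(0)
y = np.repeat([0, 1], 50)
centers = np.array([[-1.0, 0.0], [1.0, 0.0]])[y]
Xs = centers + 0.3 * rng.normal(size=(100, 2))
Xt = centers + 0.3 * rng.normal(size=(100, 2)) + np.array([3.0, 2.0])

W, gamma = jdot(Xs, np.eye(2)[y], Xt)
target_acc = ((add_bias(Xt) @ W).argmax(axis=1) == y).mean()  # y reused only for scoring
print(f"target accuracy: {target_acc:.2f}")
```

The classifier here is trained directly on target points with transported labels, which is the design difference from the two-stage pipeline: no adapted source cloud is ever built.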

NLP cross-domain transfer is a direct target. A model trained on Bloomberg financial news adapts to clinical notes. Features $X$ are sentence embeddings. JDOT cost: cosine distance plus NER prediction loss. JDOT finds the correspondence 'stock yield' - 'glucose level' through geometry of the embedding space, not through vocabulary.

In JDOT the transport cost includes $\lambda \cdot \mathcal{L}(f(x_j^T), y_i^S)$. What happens when $\lambda \to 0$?

DeepJDOT and transport in feature space

JDOT operates in input space. Take satellite imagery adaptation: the source is Sentinel-2 imagery from Europe, the target is Sentinel-2 imagery from Africa. The pixel-space distance is enormous - different terrain, different vegetation - and OT in pixel space loses semantic correspondence. The fix is to transport in the neural network's feature space rather than pixel space.

**DeepJDOT** (Damodaran et al. 2018, ECCV): map $T$ is parametrized by a neural network $g_\theta$. Instead of OT in $\mathcal{X}$, optimal transport is computed in feature space $g_\theta(\mathcal{X})$. JDOT losses now include both embedding alignment and prediction loss - all differentiable through $\theta$.

Three objects are optimized jointly: feature extractor $g_\theta$, plan $\gamma$ (via Sinkhorn), classifier $f$. Gradients flow through the whole computation graph. $g_\theta$ learns to build a feature space where both domains are easy to align - domain-invariant representations.
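Written in this lesson's notation, the joint objective can be sketched as follows. The source classification term and the exact placement of the weights $\alpha$, $\lambda$ are stated here as assumptions about the usual formulation, not quoted from the paper:

$$
\min_{\theta, f, \gamma \in \Pi(a, b)} \ \frac{1}{n_S} \sum_i \mathcal{L}\big(f(g_\theta(x_i^S)), y_i^S\big) + \sum_{i,j} \gamma_{ij} \Big( \alpha \, \|g_\theta(x_i^S) - g_\theta(x_j^T)\|^2 + \lambda \, \mathcal{L}\big(f(g_\theta(x_j^T)), y_i^S\big) \Big).
$$

With $\theta$ and $f$ fixed, the problem in $\gamma$ is a standard OT problem on a fixed cost matrix, solved per mini-batch. With $\gamma$ fixed, every term is differentiable in $\theta$ and $f$, so a gradient step applies. In that step $\gamma$ acts as a constant set of pairwise weights.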

**Sliced Wasserstein for efficiency**. Sinkhorn on an $n \times n$ cost matrix needs $O(n^2)$ memory and $O(n^2)$ time per iteration. For large batches this is a bottleneck. Fix: Sliced Wasserstein Distance (SWD) - project both distributions onto random 1D lines and compute 1D OT analytically by sorting, at $O(n \log n)$ per projection. SWD is a cheap surrogate for $W_2$ rather than the same quantity: it bounds $W_2$ from below and yields no coupling $\gamma$. In DeepJDOT on medical datasets, SWD achieves the same accuracy at 8x lower memory.
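A minimal Monte Carlo sketch of the squared sliced 2-Wasserstein distance. It assumes both clouds have the same number of points with uniform weights, so 1D OT reduces to matching sorted projections; the function name and the number of projections are illustrative.

```python
import numpy as np

def sliced_wasserstein2(X, Y, n_proj=200, seed=0):
    """Estimate SW_2^2 between equal-size point clouds X, Y of shape (n, d)."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(n_proj, X.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # random unit directions
    Xp = np.sort(X @ theta.T, axis=0)  # (n, n_proj): sorted 1D projections
    Yp = np.sort(Y @ theta.T, axis=0)
    # In 1D the optimal plan pairs the k-th smallest with the k-th smallest.
    return np.mean((Xp - Yp) ** 2)

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))
t = np.array([3.0, 2.0])
print(sliced_wasserstein2(X, X))       # identical clouds
print(sliced_wasserstein2(X, X + t))   # translated cloud
```

For a pure translation by $t$ in dimension $d$, $W_2^2 = \|t\|^2$ while the sliced value concentrates around $\|t\|^2 / d$. This makes concrete that SWD tracks $W_2$ without equalling it. Memory is $O(n \cdot n_{\text{proj}})$, with no $n \times n$ matrix.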

LoRA admits a loose OT-style reading. When adapting a large language model to a new domain, LoRA learns a low-rank update $\Delta W = BA$. Geometrically this can be viewed as transport in weight space, moving the model from the source output distribution toward the target output distribution; this is an analogy, not a formal equivalence. Courty 2017 for data, LoRA 2022 for weights - one idea at two scales.

Medical imaging cross-hospital transfer is the flagship DeepJDOT use case. CT scans from hospital A (Siemens scanner) and hospital B (GE scanner): different reconstruction kernels, different noise levels. DeepJDOT on ResNet-50: encoder aligns feature distributions via Sliced Wasserstein, classifier transfers directly. Tumor segmentation accuracy: +14% over the no-adaptation baseline.

**Misconception**: OT-based domain adaptation is an academic toy; production uses simple fine-tuning on target.

**Reality**: Fine-tuning requires target labels. OT-DA works with zero target labels and provides theoretical bounds via Wasserstein distance.

In medical imaging, annotating target data costs USD 300-500 per scan (radiologist time). DeepJDOT delivers +14% over no-adaptation at zero target-label cost. In sim-to-real robotics, collecting real-world data is physically dangerous - the agent breaks during adaptation. OT-DA is not an option but a necessity in these scenarios.

DeepJDOT computes OT in feature space $g_\theta(\mathcal{X})$, not in input space. What is the main advantage?

Key ideas

  • **Covariate shift**: $\mu_S(X) \neq \mu_T(X)$, but $P(Y \mid X)$ is identical. OT-DA aligns feature clouds without target labels
  • **OT map $T$**: $T_\# \mu_S = \mu_T$ at minimum quadratic cost. Estimated from the Sinkhorn plan $\gamma$ by barycentric projection. $T(x_i^S)$ is the adapted source point in target space
  • **JDOT** (Courty 2017): joint optimization of transport and classifier. Cost = $\alpha \|x^S - x^T\|^2 + \lambda \mathcal{L}(f(x^T), y^S)$. EM-style alternating scheme
  • **DeepJDOT** (Damodaran 2018): OT in feature space of neural network $g_\theta$. Sliced Wasserstein for $O(n \log n)$ efficiency. Three objects ($g_\theta$, $\gamma$, $f$) optimized jointly
  • **Applications**: medicine (cross-hospital), robotics (sim-to-real), NLP (cross-domain NER), satellite imagery - wherever target labels are absent or expensive

Related topics

Domain adaptation via OT connects transport geometry with ML practice:

  • Wasserstein GAN — DeepJDOT optimizes Wasserstein in feature space; WGAN does so in output space via duality
  • Wasserstein barycenters — Multi-source DA is a barycenter of several source domains
  • Brenier's theorem — OT map $T$ between domains is the Brenier construction at $c = \|x-y\|^2$

Questions for reflection

  • JDOT requires running Sinkhorn at each iteration. With a batch of 1000 points that is a 1000x1000 matrix. How does Sliced Wasserstein solve the scalability problem and what accuracy is lost in the process?
  • Covariate shift assumes $P(Y \mid X)$ is identical in both domains. What happens to JDOT when this assumption is violated - for example, different hospitals have different treatment protocols and different outcomes?
  • DeepJDOT jointly optimizes $g_\theta$, $\gamma$, and $f$. Draw the computation graph: through which variables does the gradient flow when updating $\theta$, and why is $\gamma$ updated separately via Sinkhorn?

Related lessons

  • ot-03-wasserstein - Wasserstein distance measures domain incompatibility in JDOT
  • ot-04-sinkhorn - Sinkhorn computes the OT plan between source and target embeddings at $O(n^2)$ cost per iteration
  • ot-07-wgan - WGAN and DeepJDOT both train a network to minimize a Wasserstein cost: WGAN through the dual, DeepJDOT through the primal plan $\gamma$
  • ot-09-barycenters - Wasserstein barycenter of multiple domains generalizes DA to multi-source settings
  • ot-06-brenier - The Monge map $T$ between domains is the Brenier construction at $c = \|x-y\|^2$
  • ml-01