Distributed Optimization
GPT-4 does not fit on a single GPU. An MRI dataset cannot leave the hospital. Billions of smartphones hold user data that could improve a shared keyboard model. Each of these situations calls for distributed optimization, and each comes with its own constraints: privacy, bandwidth, fault tolerance. Three techniques cover most industrial practice from the last decade: FedAvg for privacy-preserving training, ADMM for cluster computation with convergence guarantees, and Ring All-Reduce for scaling deep learning across thousands of GPUs. This lesson covers how they work and when to choose each.
- **Google Gboard** - federated learning for a keyboard model across billions of phones; data never leaves the device
- **PyTorch DDP + NCCL** - industry standard for multi-GPU training; gradients are synchronized with NCCL all-reduce, which includes a ring algorithm
- **DeepSpeed ZeRO-3** - memory optimization for 100B+ parameter models; shards parameters, gradients, and optimizer states across GPUs
Federated Learning: FedAvg
In 2017 Google solved what had seemed impossible: training a shared Gboard keyboard model across billions of phones without collecting user messages into a data center. The solution is **federated learning**: the model is shipped to the device, trained locally on the user's data, and returned to the server, where it is averaged with updates from other devices. Raw data never leaves the phone; only model updates are sent. This is the FedAvg algorithm (McMahan et al., 2017), and it became the foundation of most modern federated learning systems.
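Formalization: there are K clients (devices), each holding a dataset D_k of size n_k. The global objective is to minimize F(w) = sum_k (n_k/n) * F_k(w), where F_k is the local loss on D_k and n = sum_k n_k. In round t the server selects a subset of clients S_t and broadcasts the current weights w_{t-1}; each selected client trains for E local epochs and returns its updated weights w_k; the server then averages them: w_t = sum_{k in S_t} (n_k/n_S) * w_k, where n_S = sum_{k in S_t} n_k (this reduces to n_k/n when every client participates). The special case of a single full-batch local gradient step is FedSGD, which is equivalent to averaging gradients. More local work (E > 1) saves communication rounds at the cost of client drift: on non-IID data each local model moves toward its own optimum.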
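The round structure described above can be sketched in a few lines of plain Python. This is a toy illustration, not the production algorithm: the model is a single scalar weight fitted by least squares, every client takes part in every round, and the names `local_sgd`, `fedavg_round`, and `make_client` are hypothetical helpers introduced here. The averaging weights are normalized over the clients that took part in the round.

```python
import random


def local_sgd(w, data, epochs, lr):
    """Client side: per-sample SGD on the loss 0.5 * (w * x - y)^2."""
    for _ in range(epochs):
        for x, y in data:
            grad = (w * x - y) * x
            w -= lr * grad
    return w


def fedavg_round(w_global, clients, epochs, lr):
    """Server side: broadcast w_global, collect local results, average by n_k."""
    n_selected = sum(len(data) for data in clients)
    w_new = 0.0
    for data in clients:
        w_k = local_sgd(w_global, data, epochs, lr)
        w_new += len(data) / n_selected * w_k
    return w_new


def make_client(slope, n, rng):
    """Hypothetical noise-free client dataset with y = slope * x."""
    data = []
    for _ in range(n):
        x = rng.uniform(-1.0, 1.0)
        data.append((x, slope * x))
    return data


rng = random.Random(0)
# Three clients of different sizes; all share the same true slope (IID case).
clients = [make_client(2.0, 20, rng), make_client(2.0, 50, rng), make_client(2.0, 30, rng)]

w = 0.0
for t in range(50):
    w = fedavg_round(w, clients, epochs=2, lr=0.1)

print(f"learned weight after 50 rounds: {w:.4f}")
```

Giving each client a different slope in `make_client` is a quick way to see client drift: the local results then disagree, and the averaged weight settles on a compromise between them.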
In FedAvg, what does the server do after every communication round?
ADMM for Distributed Problems
FedAvg is a simple heuristic: the original paper gives no convergence proof, and later analyses for non-IID data rely on additional assumptions. ADMM (Alternating Direction Method of Multipliers) is a more principled approach with guaranteed convergence for convex problems. The idea: the problem decomposes into local sub-problems plus a global consensus constraint. Formally: minimize sum_k f_k(x_k) subject to x_k = z for all k (all local variables equal a common z). ADMM alternates three steps. The x-step runs locally and in parallel: x_k := argmin_x f_k(x) + (rho/2) * ||x - z + u_k||^2. The z-step runs on the server: z := average of (x_k + u_k). The u-step updates the scaled dual multipliers: u_k := u_k + x_k - z.
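A minimal sketch of the three steps, assuming each local objective is a scalar quadratic f_k(x) = 0.5 * a_k * (x - b_k)^2 so that the x-step has a closed form. The function name `consensus_admm` and the values of `a`, `b`, and `rho` are illustrative choices, not part of the lesson. For this problem the consensus optimum is known exactly, sum(a_k * b_k) / sum(a_k), which makes the result easy to check.

```python
def consensus_admm(a, b, rho=1.0, iters=300):
    """Scaled-form consensus ADMM for f_k(x) = 0.5 * a_k * (x - b_k)^2."""
    num_workers = len(a)
    x = [0.0] * num_workers   # local variables, one per worker
    u = [0.0] * num_workers   # scaled dual variables, one per worker
    z = 0.0                   # global consensus variable
    for _ in range(iters):
        # x-step: each worker minimizes f_k(x) + (rho/2) * (x - z + u_k)^2.
        x = [(a[k] * b[k] + rho * (z - u[k])) / (a[k] + rho) for k in range(num_workers)]
        # z-step: the server averages x_k + u_k.
        z = sum(x[k] + u[k] for k in range(num_workers)) / num_workers
        # u-step: each worker accumulates its consensus violation.
        u = [u[k] + x[k] - z for k in range(num_workers)]
    return z, x


a = [1.0, 2.0, 3.0]
b = [1.0, 4.0, -2.0]
z, x = consensus_admm(a, b)
expected = sum(ak * bk for ak, bk in zip(a, b)) / sum(a)

print(f"ADMM consensus value: {z:.6f}")
print(f"closed-form optimum:  {expected:.6f}")
```

The parameter `rho` only affects how fast the iterates agree, not where they end up; trying a few values is a simple way to see why tuning it matters.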
FedAvg vs ADMM: FedAvg is a heuristic whose guarantees need extra assumptions, while ADMM converges for convex problems. FedAvg transmits only weights; ADMM also has to track dual variables, which enter the exchanged messages. FedAvg tolerates client drop-outs (the server simply skips them); ADMM is sensitive to failures because every worker's dual state matters. FedAvg is simple; ADMM requires tuning rho (the augmented-Lagrangian penalty parameter). In practice FedAvg dominates edge scenarios (smartphones), while ADMM is better suited to HPC clusters with reliable networking.
Where does the z-step run in distributed ADMM?
Gradient Compression and Ring All-Reduce
Training GPT-3 (175B parameters) with naive data parallelism: a single gradient exchange has to move 175B * 4 bytes = 700 GB between nodes in FP32. Over 400 Gb/s InfiniBand (50 GB/s) that is 14 seconds per step, which is unacceptable. There are two classes of solutions: **gradient compression** (quantization to 8/4/1 bits, Top-K sparsification, error feedback) and **communication architecture** (Ring All-Reduce instead of Parameter Server). Ring All-Reduce is the industry standard: for a gradient of N bytes, every node transmits exactly 2N * (K-1)/K bytes, which stays below 2N no matter how many nodes K take part.
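The arithmetic above is easy to reproduce. The sketch below recomputes the 700 GB and 14 second figures and shows how the payload shrinks at the quantization widths mentioned in the text. It is a back-of-the-envelope estimate only: it assumes the whole gradient crosses a single 400 Gb/s link and ignores quantization metadata, protocol overhead, and overlap with computation. The helper names are hypothetical.

```python
def gradient_bytes(num_params, bits_per_value=32):
    """Payload size of one full gradient at the given precision."""
    return num_params * bits_per_value // 8


def transfer_seconds(num_bytes, link_gbit_per_s):
    """Time to push num_bytes through a link rated in gigabits per second."""
    return num_bytes * 8 / (link_gbit_per_s * 1e9)


GPT3_PARAMS = 175 * 10**9
LINK_GBIT_PER_S = 400  # InfiniBand figure quoted in the text

for bits in (32, 8, 4, 1):
    size = gradient_bytes(GPT3_PARAMS, bits)
    seconds = transfer_seconds(size, LINK_GBIT_PER_S)
    print(f"{bits:>2}-bit gradients: {size / 1e9:8.1f} GB, {seconds:6.2f} s per exchange")
```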
Parameter Server (PS) is the simpler architecture: every worker sends gradients to a central server, which averages them and broadcasts the updated weights. Downside: the server's link carries traffic proportional to K, so PS becomes the bottleneck at large K. Ring All-Reduce: nodes form a ring and split the gradient into K chunks; scatter-reduce (K-1 steps) sums the chunks around the ring, then all-gather (K-1 steps) circulates the finished sums. Total: every node transmits 2N * (K-1)/K < 2N bytes regardless of K, so the per-node bandwidth cost does not grow with cluster size. NCCL (the default GPU backend for PyTorch DDP) and Horovod implement ring-based all-reduce.
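The two phases can be simulated in a single process to check both the result and the traffic count. The sketch below is a toy model under stated assumptions: "nodes" are plain Python lists, a "send" is a list copy, the vector length is assumed divisible by the number of nodes, and `ring_all_reduce` is a name made up for this example rather than any library's API.

```python
def ring_all_reduce(node_vectors):
    """Simulate ring all-reduce; return (per-node results, elements sent per node)."""
    k = len(node_vectors)
    n = len(node_vectors[0])
    assert n % k == 0, "toy version: vector length must be divisible by node count"
    c = n // k
    # chunks[i][j] is node i's current copy of chunk j.
    chunks = [[list(v[j * c:(j + 1) * c]) for j in range(k)] for v in node_vectors]
    sent = [0] * k

    # Phase 1, scatter-reduce: in step s node i sends chunk (i - s) mod k to node i + 1,
    # which adds it to its own copy.
    for step in range(k - 1):
        outgoing = [((i - step) % k, list(chunks[i][(i - step) % k])) for i in range(k)]
        for i, (j, payload) in enumerate(outgoing):
            dst = (i + 1) % k
            chunks[dst][j] = [p + q for p, q in zip(chunks[dst][j], payload)]
            sent[i] += len(payload)

    # Node i now owns the fully reduced chunk (i + 1) mod k.
    # Phase 2, all-gather: finished chunks travel around the ring and overwrite stale copies.
    for step in range(k - 1):
        outgoing = [((i + 1 - step) % k, list(chunks[i][(i + 1 - step) % k])) for i in range(k)]
        for i, (j, payload) in enumerate(outgoing):
            dst = (i + 1) % k
            chunks[dst][j] = payload
            sent[i] += len(payload)

    results = [[value for chunk in node for value in chunk] for node in chunks]
    return results, sent


vectors = [[10 * i + t for t in range(8)] for i in range(4)]  # 4 nodes, 8 values each
results, sent = ring_all_reduce(vectors)
print(results[0])
print(sent)
```

Each node ends up with the element-wise sum of all inputs, and the send counter equals 2 * (K-1) * N/K elements per node, matching the formula in the text.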
Misconception: All-Reduce scales poorly with the number of nodes.
Reality: Ring All-Reduce transmits exactly 2*(K-1)/K*N bytes per node, which approaches the constant 2N as K grows. This matches the bandwidth lower bound for all-reduce.
Why it arises: the K-1 ring steps in each phase look long. But each step transmits only N/K bytes, so the two phases together total just under 2N. The real bottleneck at scale is latency and the straggler problem, not bandwidth.
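The latency point can be made concrete with the common alpha-beta cost model, where each message costs a fixed latency plus its size divided by bandwidth. The sketch below is a simplified estimate, not a benchmark: the 5 microsecond latency, the 1 GB gradient, and the single-server parameter server model (one link that receives K gradients and sends K copies) are all assumptions chosen for illustration.

```python
def ring_time(n_bytes, k, bandwidth, latency):
    """2 * (k - 1) sequential steps, each moving n_bytes / k over one link."""
    steps = 2 * (k - 1)
    return steps * latency + steps * (n_bytes / k) / bandwidth


def parameter_server_time(n_bytes, k, bandwidth, latency):
    """Single server link: receive k gradients, then send k weight copies."""
    return 2 * latency + 2 * k * n_bytes / bandwidth


N_BYTES = 1e9        # assumed 1 GB gradient
BANDWIDTH = 50e9     # 400 Gb/s expressed in bytes per second
LATENCY = 5e-6       # assumed per-message latency in seconds

for k in (2, 8, 64, 1024):
    ring = ring_time(N_BYTES, k, BANDWIDTH, LATENCY)
    ps = parameter_server_time(N_BYTES, k, BANDWIDTH, LATENCY)
    print(f"K={k:>4}: ring {ring * 1e3:8.2f} ms, parameter server {ps * 1e3:10.2f} ms")
```

In this model the ring's bandwidth term never exceeds 2N / bandwidth, while its latency term grows linearly with K, which is why small messages and very large rings are where ring all-reduce starts to hurt.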
Why is Ring All-Reduce preferred over Parameter Server for synchronous training across many nodes?
Key Ideas
- **FedAvg** - clients train locally for E epochs, and the server averages their weights in proportion to dataset size (n_k/n); the basis of privacy-preserving ML
- **Client drift** under non-IID data is mitigated by FedProx (proximal term) or FedNova (step-count normalization)
- **ADMM** - consensus formulation with guaranteed convergence; the x-step is parallel, the z-step is centralized, the u-step updates duals
- **Ring All-Reduce** - bandwidth-optimal communication pattern: just under 2N bytes per node regardless of K; available in NCCL (used by PyTorch DDP) and Horovod
- **Gradient compression** - quantization (8/4/1 bits) + Top-K sparsification + error feedback; essential at trillion-parameter scale
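The Top-K plus error feedback combination from the last bullet can be sketched as follows. This is a minimal illustration on plain Python lists, assuming a dense gradient and a fixed budget `k`; the names `top_k_sparsify` and `ErrorFeedbackCompressor` are hypothetical. Entries that are not sent are not discarded: they are kept in a local residual and added back before the next selection.

```python
def top_k_sparsify(vec, k):
    """Keep the k largest-magnitude entries and zero out the rest."""
    order = sorted(range(len(vec)), key=lambda i: abs(vec[i]), reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(vec)]


class ErrorFeedbackCompressor:
    """Top-K compression that carries dropped entries over to later steps."""

    def __init__(self, dim, k):
        self.residual = [0.0] * dim
        self.k = k

    def compress(self, grad):
        # Add back what was dropped in earlier steps before selecting.
        corrected = [g + r for g, r in zip(grad, self.residual)]
        sparse = top_k_sparsify(corrected, self.k)
        # Remember everything that was not transmitted this time.
        self.residual = [c - s for c, s in zip(corrected, sparse)]
        return sparse


compressor = ErrorFeedbackCompressor(dim=4, k=2)
grad = [0.5, -3.0, 0.25, 2.0]
first = compressor.compress(grad)
second = compressor.compress(grad)
print(first, second, compressor.residual)
```

Small entries that never make the cut on their own keep accumulating in the residual until they are large enough to be selected, so nothing is lost permanently.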
Related Topics
Distributed optimization is not a separate discipline but an evolution of classical methods under network and privacy constraints. Links to earlier lessons:
- Multi-Objective Optimization - ADMM decomposes a sum of objectives into sub-problems coupled by a constraint; distributed consensus ADMM applies the same decomposition across machines
- Stochastic Gradient Descent - FedSGD (one local gradient step per round) is just SGD with gradient aggregation; FedAvg generalizes the idea to E local epochs
- Subgradient Methods - Many practical distributed problems are non-differentiable (L1 regularization); ADMM handles them elegantly via soft-thresholding
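The soft-thresholding operator mentioned in the last item is short enough to write out. It is the proximal operator of the absolute value: shrink the input toward zero by a threshold and clip at zero. In ADMM for an L1-regularized problem, the z-step is typically this operator applied to x + u with threshold lambda / rho; the values of `lam` and `rho` below are illustrative assumptions.

```python
def soft_threshold(v, kappa):
    """Proximal operator of kappa * |z|: shrink v toward zero by kappa."""
    if v > kappa:
        return v - kappa
    if v < -kappa:
        return v + kappa
    return 0.0


lam, rho = 1.0, 2.0        # assumed regularization weight and ADMM penalty
x_plus_u = [3.0, -0.25, 0.5, -2.0]
z = [soft_threshold(v, lam / rho) for v in x_plus_u]
print(z)
```

Entries whose magnitude is at most the threshold become exactly zero, which is how the L1 term produces sparse solutions without needing a gradient at zero.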
Questions for Reflection
- If FedAvg is so simple and effective, why does ADMM still exist? Which practical scenarios require the convergence guarantees that justify the extra complexity?
- Ring All-Reduce is bandwidth-optimal, but the straggler problem (a slow node holds up everyone) limits real throughput. Which strategies exist to handle stragglers?
- Back to motivation: imagine training a model on medical data from three hospitals without sharing raw data. Which three or four algorithmic decisions must be made, and in what order?
Related Lessons
- opt-13 — Distributed optimization generalizes single-node methods
- par-06 — MPI as infrastructure for distributed gradient descent
- dl-12 — Distributed training is the key application of distributed optimization
- ml-09-gradient-descent — SGD is the foundation of AllReduce and parallel training
- ds-01-intro — CAP theorem tradeoffs also appear in distributed optimization
- calc-01-sequences