Design Large-Scale Distributed LLM Training Infrastructure
A Staff-level walkthrough of the distributed LLM training infrastructure that Anthropic, OpenAI, and Google DeepMind operate at frontier scale. It frames training as a distributed-systems reliability problem wearing an ML hat: a hybrid nD parallelism plan that localizes communication, memory-frugal sharding plus mixed precision to beat the memory wall, a data pipeline that never starves the GPUs, and fault tolerance — fast async checkpoint/restart plus hot spares — that keeps MFU high despite constant failures.
Scope & ambiguity
Let me frame this up front: this is a distributed-systems reliability problem wearing an ML hat. The goal is to pretrain a 100B–700B parameter model on a fixed fleet of tens of thousands of GPUs over weeks, maximizing MFU (model FLOPs utilization) and minimizing time-to-converge. This is not an inference-serving system — there’s no per-request latency target, no autoscaling, no traffic spikes. It’s one long-lived, synchronous, gang-scheduled job whose enemy is the stop-the-world stall. I’ll lead with the failure model and the network fabric, treat MFU as goodput, and flag the two things I’d verify with an ML expert: the exact precision recipe and the convergence impact of any parallelism change.
This is the class of system Anthropic, OpenAI, Google DeepMind, Meta, and NVIDIA build internally and probe for in AI-infra interviews. I’m using it as industry context, not as a leaked question — the public reference points (Llama 3.1 405B, DeepSeek-V3, the Megatron/NeMo and FSDP stacks) are all documented, and I’ll anchor my numbers to them.
The job, precisely: Train a dense 100B–700B model (or a sparse MoE of similar active size) on a fixed cluster — say 16,384 H100s — for several weeks against a fixed token budget. Keep every GPU fed, keep MFU in the 35–45% band, and keep the run alive through a hardware failure roughly every few hours.
The three enemies
1. The memory wall. A model plus its Adam optimizer state and activations does not fit in one GPU’s 80GB HBM. We must shard state across devices.
2. Communication overhead. Every training step is a synchronous barrier — gradients must be aggregated across all ranks. Naive collectives stall the GPUs and crater MFU.
3. Hardware failure. At this scale something breaks every few hours. Without fast checkpoint/restart and hot spares, effective training time collapses.
Who asks this & what they probe
The switcher’s credibility move: name MFU as goodput, lead with the failure model and the network, and explicitly say “the precision recipe and the convergence impact of reshaping parallelism are the two things I’d confirm with an ML specialist.” That signals you know the boundary of your lane.
Requirements
Functional requirements
- Launch and resume a distributed training run from a declarative job spec.
- Partition the model so its state fits in aggregate HBM (nD parallelism).
- Keep every GPU fed — a data pipeline that never starves 10k+ workers.
- Checkpoint periodically and recover deterministically after failure.
- Track metrics (loss, grad-norm, throughput, per-rank MFU) in real time.
- Support elastic resize — drop a failed gang, optionally re-plan parallelism on a smaller/larger fleet.
Non-functional targets
The convergence constraint (the callout the MLE wants)
Throughput is necessary but not sufficient — the run has to converge well. The knob that couples infra to convergence is global batch size:
global_batch = micro_batch × grad_accum_steps × DP_width
If you change the data-parallel width (e.g., after losing nodes), you change the effective global batch unless you compensate with gradient accumulation. Global batch size interacts with the learning-rate schedule and gradient clipping; getting it wrong wastes the entire multi-week run. So a hard requirement: any parallelism reshape must preserve effective global batch size, or be paired with an LR adjustment vetted by ML. The token-budget target (e.g., ~15T tokens) is fixed; we trade DP width and grad-accum to hit batch size at whatever fleet size survives.
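The compensation rule can be sketched in a few lines. This is a minimal sketch that assumes a fixed micro-batch and a target global batch (counted in sequences) that must be hit exactly. The function name and the numbers are hypothetical and chosen only to show the arithmetic.

```python
def grad_accum_for(global_batch: int, micro_batch: int, dp_width: int) -> int:
    """Return grad_accum_steps so that micro_batch * accum * dp_width == global_batch.

    Raises ValueError when the surviving DP width cannot hit the target exactly;
    that is the case that needs an ML-vetted LR adjustment instead.
    """
    per_step = micro_batch * dp_width
    if global_batch % per_step != 0:
        raise ValueError(
            f"global batch {global_batch} not divisible by micro_batch*dp_width={per_step}"
        )
    return global_batch // per_step


# Hypothetical numbers: 2048 sequences per optimizer step, micro-batch of 1.
print(grad_accum_for(2048, micro_batch=1, dp_width=128))  # 16
print(grad_accum_for(2048, micro_batch=1, dp_width=64))   # 32 after losing half the DP replicas
```

A DP width such as 120 makes the function raise. That is deliberate: silently rounding the batch size is exactly the convergence risk described above.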
Back-of-envelope estimation
Compute budget
Standard rule of thumb: training costs ~6 FLOPs per token per parameter (2 forward + 4 backward).
This matches the public record: Llama 3.1 405B was trained on 16,384 H100s over roughly two months. The estimate is internally consistent, which is the point of doing it.
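Here is the arithmetic behind that consistency check, as a back-of-envelope sketch. Two inputs are assumptions not stated above: an H100 dense BF16 peak of about 989 TFLOP/s, and 40% MFU as the midpoint of the 35–45% band. The parameter count, token budget, and GPU count come from the text.

```python
# Back-of-envelope training time from the 6 * N * D rule of thumb.
# Assumptions: ~989 TFLOP/s dense BF16 peak per H100, 40% MFU.
N_PARAMS = 405e9             # model parameters
TOKENS = 15e12               # token budget (~15T)
GPUS = 16_384
PEAK_FLOPS_PER_GPU = 989e12  # FLOP/s per GPU at peak
MFU = 0.40                   # realized fraction of peak

total_flops = 6 * N_PARAMS * TOKENS                      # roughly 3.6e25 FLOPs
cluster_flops_per_s = GPUS * PEAK_FLOPS_PER_GPU * MFU    # useful FLOP/s across the fleet
days = total_flops / cluster_flops_per_s / 86_400

print(f"total compute: {total_flops:.2e} FLOPs")
print(f"wall clock at {MFU:.0%} MFU: {days:.0f} days")
```

This comes out to about 65 days, i.e. roughly two months. Shifting MFU within the 35–45% band moves the answer by a week or so in either direction.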
Memory wall
Activations come on top of the weight, gradient, and optimizer state. They scale with batch × seqlen × hidden × layers, which is why activation recomputation exists: spend extra compute to re-derive activations in the backward pass instead of storing them. Conclusion: state must be sharded; the only question is along which axes.
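To see why sharding is forced, here is a sketch of the state-memory arithmetic. It assumes the common mixed-precision Adam breakdown of about 16 bytes per parameter (BF16 weights and gradients, plus FP32 master weights and two FP32 Adam moments). For round numbers it treats 80 GB of HBM as 80e9 bytes.

```python
import math

# Bytes per parameter for mixed-precision Adam (a common breakdown, assumed here).
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment (m)": 4,
    "fp32 Adam second moment (v)": 4,
}


def min_gpus_for_state(n_params: float, hbm_bytes: float = 80e9) -> tuple[float, int]:
    """Total weight + gradient + optimizer bytes, and the GPU count needed just to hold them."""
    total = n_params * sum(BYTES_PER_PARAM.values())
    return total, math.ceil(total / hbm_bytes)


total, gpus = min_gpus_for_state(405e9)
print(f"{total / 1e12:.2f} TB of state -> at least {gpus} x 80GB GPUs, before activations")
```

At 405B parameters that is about 6.5 TB, or at least 81 GPUs' worth of HBM with zero room left for activations. The state therefore has to be spread across many more devices than that.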
Network
The ~18–20× gap between NVLink and a single IB rail is the single most important number for the parallelism plan: tensor parallelism, which all-reduces every layer, must stay inside a node on NVLink. Anything chattier than the link can carry becomes the bottleneck.
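A rough way to feel that gap is to time one tensor-parallel all-reduce with the standard ring cost model, in which each rank moves 2(n−1)/n of the buffer. The bandwidths are assumptions: 900 GB/s is the commonly quoted H100 NVLink aggregate (the same basis as the ~18× figure), and one 400 Gb/s rail is 50 GB/s. The activation shape is hypothetical, and latency terms are ignored.

```python
def ring_allreduce_seconds(nbytes: float, ranks: int, link_bytes_per_s: float) -> float:
    # Ring all-reduce sends 2*(n-1)/n of the buffer over each rank's link
    # (reduce-scatter phase + all-gather phase); per-hop latency is ignored.
    return 2 * (ranks - 1) / ranks * nbytes / link_bytes_per_s


# Hypothetical per-layer TP all-reduce: micro-batch 1, 8192 tokens, hidden 16384, BF16 (2 bytes).
activation_bytes = 1 * 8192 * 16384 * 2
NVLINK = 900e9   # assumed NVLink aggregate bandwidth, bytes/s
IB_RAIL = 50e9   # one 400 Gb/s InfiniBand rail, bytes/s

t_nvlink = ring_allreduce_seconds(activation_bytes, 8, NVLINK)
t_ib = ring_allreduce_seconds(activation_bytes, 8, IB_RAIL)
print(f"NVLink: {t_nvlink * 1e3:.2f} ms, one IB rail: {t_ib * 1e3:.2f} ms "
      f"({t_ib / t_nvlink:.0f}x slower)")
```

Multiply the IB figure by dozens of TP all-reduces per step and it alone exceeds a typical step budget. That is the quantitative reason TP stays on NVLink.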
Checkpoint interval (Young/Daly)
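With an MTBF of ~3h, Young/Daly (optimal interval τ ≈ √(2·C·MTBF), where C is the cost of taking one checkpoint) lands in the 15–30 minute range for C of roughly 40–150 s. An async checkpoint that stalls the GPUs for only seconds makes checkpointing at that cadence cheap. A failure then loses at most one interval of work, about half an interval on average, or roughly 4–8% of the 3h MTBF. That is why we checkpoint this often, not hourly.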
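The first-order model behind that interval can be sketched directly. The checkpoint costs in the loop are hypothetical inputs. The overhead term adds the checkpoint cost per interval to the expected half-interval of lost work per failure. It ignores restart time and is only accurate when the interval is well below the MTBF.

```python
import math


def young_interval(ckpt_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order optimum checkpoint interval: sqrt(2 * C * MTBF)."""
    return math.sqrt(2 * ckpt_cost_s * mtbf_s)


def wasted_fraction(interval_s: float, ckpt_cost_s: float, mtbf_s: float) -> float:
    """First-order overhead: checkpoint cost per interval + half an interval lost per failure."""
    return ckpt_cost_s / interval_s + interval_s / (2 * mtbf_s)


MTBF = 3 * 3600  # ~3 h across the fleet, in seconds
for cost in (10, 40, 150):  # effective per-checkpoint cost in seconds (hypothetical values)
    tau = young_interval(cost, MTBF)
    print(f"C={cost:>3}s -> interval {tau / 60:5.1f} min, "
          f"overhead {wasted_fraction(tau, cost, MTBF):.1%}")
```

At the optimum, the two overhead terms are equal. Shrinking C is the lever that lets the interval shrink too without paying more in checkpoint stalls.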
API design
These are control-plane and framework interfaces, not request/response APIs. There is no client calling in — there’s a job spec, a scheduler, and inter-rank collectives.
Job specification
The product tp × pp × dp × cp × ep must equal the total GPU count (here 8 × 16 × 128 = 16,384, with cp = ep = 1).
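A job-spec validator can enforce both that product rule and the TP-inside-a-node rule before any GPU is allocated. The class and field names below are hypothetical and do not correspond to any real framework's schema.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelismSpec:
    """Hypothetical job-spec fragment; field names are illustrative only."""
    tp: int
    pp: int
    dp: int
    cp: int = 1
    ep: int = 1
    gpus_per_node: int = 8

    def validate(self, total_gpus: int) -> None:
        product = self.tp * self.pp * self.dp * self.cp * self.ep
        if product != total_gpus:
            raise ValueError(f"tp*pp*dp*cp*ep = {product}, expected {total_gpus}")
        if self.gpus_per_node % self.tp != 0:
            # TP must stay inside a node so its per-layer all-reduce rides NVLink.
            raise ValueError(f"tp={self.tp} does not fit within a {self.gpus_per_node}-GPU node")


spec = ParallelismSpec(tp=8, pp=16, dp=128)
spec.validate(16_384)  # passes: 8 * 16 * 128 = 16,384
```

This follows the document's convention of multiplying in `ep`. Some frameworks instead carve expert parallelism out of the data-parallel group, in which case `ep` would not appear in the product.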
Scheduler request (gang + topology-aware)
Framework contracts
The real inter-rank API: collectives
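The actual “API” between GPUs is the NCCL collective set, and everything in the step reduces to a handful of them: all-reduce (TP activations within a node, DP gradients), all-gather and reduce-scatter (FSDP/ZeRO parameter gathers and gradient sharding), point-to-point send/recv (activations at PP stage boundaries), and all-to-all (MoE token routing).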
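The semantics of these collectives fit in a single-process simulation, where list index r stands for rank r. This is a sketch of what each collective computes, not NCCL's API. It assumes buffer lengths divide evenly by the world size, and it omits point-to-point send/recv.

```python
from typing import List

Vec = List[float]


def all_reduce(bufs: List[Vec]) -> List[Vec]:
    """Every rank ends with the element-wise sum of all ranks' buffers."""
    total = [sum(xs) for xs in zip(*bufs)]
    return [list(total) for _ in bufs]


def reduce_scatter(bufs: List[Vec]) -> List[Vec]:
    """Sum element-wise, then rank r keeps only the r-th contiguous chunk."""
    n = len(bufs)
    total = [sum(xs) for xs in zip(*bufs)]
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]


def all_gather(chunks: List[Vec]) -> List[Vec]:
    """Every rank ends with the concatenation of all ranks' chunks."""
    full = [x for c in chunks for x in c]
    return [list(full) for _ in chunks]


def all_to_all(sends: List[List[Vec]]) -> List[List[Vec]]:
    """sends[src][dst] is what src sends to dst; rank dst receives [sends[0][dst], sends[1][dst], ...]."""
    n = len(sends)
    return [[sends[src][dst] for src in range(n)] for dst in range(n)]


grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]  # two ranks
# All-reduce equals reduce-scatter followed by all-gather: the decomposition FSDP/ZeRO exploits.
assert all_reduce(grads) == all_gather(reduce_scatter(grads))
```

The final assertion is the key design point. FSDP keeps only the reduce-scattered shard of the gradients on each rank and defers the all-gather of parameters until they are needed.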
Data model
There are two distinct state planes, and conflating them is a classic mistake.
Plane 1 — Training state (the thing you checkpoint)
A consistent checkpoint must capture all four: model + optimizer + RNG + dataloader position. Miss the dataloader position and you re-feed or skip tokens on resume, silently corrupting the token budget and convergence. Miss RNG and you lose bit-reproducibility. This is the “write-ahead snapshot” the switcher already understands — it just spans the whole nD-sharded layout.
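A toy version shows why all four pieces matter. Plain dicts stand in for model and optimizer shards, Python's `random.Random` stands in for the framework RNG, and an integer stands in for the dataloader position. All names are illustrative. The run is checkpointed after 10 steps and then replayed from the checkpoint, and it must match an uninterrupted run exactly.

```python
import copy
import random
from dataclasses import dataclass


@dataclass
class TrainingState:
    """The four things a consistent checkpoint must capture (toy stand-ins)."""
    model: dict
    optimizer: dict
    rng_state: tuple    # random.Random.getstate()
    data_position: int  # index of the next sample this rank will consume


def snapshot(model, optimizer, rng, pos) -> TrainingState:
    return TrainingState(copy.deepcopy(model), copy.deepcopy(optimizer), rng.getstate(), pos)


def restore(state: TrainingState):
    rng = random.Random()
    rng.setstate(state.rng_state)
    return copy.deepcopy(state.model), copy.deepcopy(state.optimizer), rng, state.data_position


data = list(range(100))  # a fixed token stream


def train_step(model, opt, rng, pos):
    model["w"] += data[pos] * rng.random()  # stand-in for a stochastic update
    opt["step"] += 1
    return pos + 1


rng = random.Random(1234)
model, opt, pos = {"w": 0.0}, {"step": 0}, 0
for _ in range(10):
    pos = train_step(model, opt, rng, pos)
ckpt = snapshot(model, opt, rng, pos)

for _ in range(5):  # path A: keep training, no failure
    pos = train_step(model, opt, rng, pos)

m2, o2, rng2, pos2 = restore(ckpt)  # path B: "crash", restore, replay the same 5 steps
for _ in range(5):
    pos2 = train_step(m2, o2, rng2, pos2)

assert m2 == model and o2 == opt and pos2 == pos  # identical: nothing re-fed, nothing skipped
```

Resuming with `data_position` reset to 0, or with a freshly seeded RNG, would break the final assertion. That is the silent token-budget corruption described above.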
Plane 2 — Data plane (the corpus pipeline)
Tokenized output is stored as sharded memory-mapped token files so the dataloader does sequential, prefetchable reads with near-zero CPU cost. Shards are assigned to DP ranks deterministically by (epoch, seed, dp_rank) so resume is exact.
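One way to make shard assignment a pure function of (epoch, seed, dp_rank) is sketched below: shuffle shard ids with an epoch-specific seeded permutation, then stride them across DP ranks. The function name and the string-seed scheme are assumptions. Seeding `random.Random` with a string is deterministic across processes, and the shuffle order is stable within a given Python version.

```python
import random


def shards_for_rank(num_shards: int, epoch: int, seed: int, dp_rank: int, dp_width: int) -> list[int]:
    """Deterministically shuffle shard ids per epoch, then stride them across DP ranks.

    Same (epoch, seed, dp_rank, dp_width) -> same list on any host, after any restart.
    """
    order = list(range(num_shards))
    random.Random(f"{seed}:{epoch}").shuffle(order)  # epoch-specific, seed-derived permutation
    return order[dp_rank::dp_width]


parts = [shards_for_rank(64, epoch=3, seed=42, dp_rank=r, dp_width=4) for r in range(4)]
# Every shard is owned by exactly one rank, and recomputing after a restart gives the same answer.
assert sorted(s for p in parts for s in p) == list(range(64))
assert shards_for_rank(64, 3, 42, 1, 4) == parts[1]
```

Shard ownership alone is not enough to resume mid-shard. The per-rank offset within the current shard is the dataloader position that the checkpoint has to carry.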
Storage tiering
This is a write-back cache hierarchy: the training loop writes to L0 instantly, async tiers down to L2/L3. Recovery reads from the closest tier that has a good copy.
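Choosing a restore source can be read as two steps, and that reading is an assumption: restore the newest verified checkpoint that exists anywhere, and fetch it from the nearest tier that holds it. The tier labels below are illustrative.

```python
from typing import Dict, Optional, Set, Tuple

# Hypothetical tier layout, nearest (fastest) first.
TIERS = ["L0 host RAM", "L1 peer NVMe", "L2 parallel FS", "L3 object store"]


def pick_restore_source(copies: Dict[str, Set[int]]) -> Optional[Tuple[str, int]]:
    """copies maps tier -> steps that have a checksum-verified copy on that tier.

    Returns (tier, step): the newest step available anywhere, read from its nearest tier.
    """
    steps = [s for tier in TIERS for s in copies.get(tier, set())]
    if not steps:
        return None
    newest = max(steps)
    nearest = next(tier for tier in TIERS if newest in copies.get(tier, set()))
    return nearest, newest


# A node died, so its own L0 copy is gone; a neighbor's NVMe still holds step 1200.
copies = {
    "L0 host RAM": set(),
    "L1 peer NVMe": {1200},
    "L2 parallel FS": {1000, 1200},
    "L3 object store": {1000},
}
print(pick_restore_source(copies))  # ('L1 peer NVMe', 1200)
```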
High-level architecture
Four planes. Control plane decides what runs where; compute plane runs the step; data plane feeds it; storage plane remembers it.
Per-step flow
Steps 1–5 are a synchronous barrier across all 16,384 ranks: the slowest rank sets the step time, which is why straggler screening (Step 7) matters so much.
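The barrier arithmetic is worth seeing once. In a simplified model where every rank does the same work and then waits at the barrier, the step takes as long as the slowest rank, and the idle fraction is the GPU-time everyone else spends waiting. The per-rank timings are hypothetical.

```python
import statistics


def step_time(per_rank_s: list[float]) -> float:
    # A synchronous step finishes only when the slowest rank reaches the barrier.
    return max(per_rank_s)


def idle_fraction(per_rank_s: list[float]) -> float:
    # Share of total GPU-time spent waiting for the slowest rank.
    return 1 - statistics.fmean(per_rank_s) / max(per_rank_s)


healthy = [1.00] * 16_383
times = healthy + [1.05]  # a single rank running 5% slow
print(step_time(healthy + [1.00]))  # 1.0
print(step_time(times))             # 1.05: one slow rank taxes all 16,384
print(f"{idle_fraction(times):.1%} of fleet GPU-time idle")
```

One GPU out of 16,384 running 5% slow costs nearly 5% of the whole fleet's MFU.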
Deep dives
This is where Staff is won. I’ll go deep on four: (1) the nD parallelism plan and comm-overlap schedule, (2) mixed precision, (3) fault tolerance, and (4) MoE expert parallelism.
Deep dive 1 — The nD parallelism plan
The art is assigning each parallelism axis to the network tier whose bandwidth matches that axis’s communication intensity. Localize the chattiest traffic.
The placement rules that prove you understand the fabric:
- TP stays inside the node. It all-reduces activations after every layer — dozens of times per step. Cross-node TP over a single 400G rail (~20× slower than NVLink) would dominate step time. Keep TP ≤ node size (8 for an H100 box).
- PP crosses nodes. Pipeline only exchanges activations at stage boundaries — sparse, point-to-point traffic that tolerates IB latency. The cost of PP is the bubble (idle time while the pipeline fills/drains), mitigated by 1F1B scheduling or DeepSeek's DualPipe bidirectional schedule that overlaps the two directions.
- FSDP/ZeRO is the outer wrapper for optimizer-state sharding — it gathers params just-in-time per layer (all-gather), then reduce-scatters gradients. This is the "sharded state with on-demand gather" the switcher knows.
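The PP bubble can be quantified with the standard first-order model for synchronous GPipe/1F1B schedules: with p stages and m micro-batches, the idle fraction is (p − 1)/(m + p − 1). The model assumes equal per-stage time and ignores communication. Interleaved schedules and DualPipe reduce the bubble below this baseline.

```python
def bubble_fraction(pp_stages: int, microbatches: int) -> float:
    """Idle fraction of a synchronous GPipe/1F1B schedule: (p - 1) / (m + p - 1)."""
    return (pp_stages - 1) / (microbatches + pp_stages - 1)


for m in (16, 64, 256):
    print(f"PP=16, {m:>3} micro-batches -> bubble {bubble_fraction(16, m):.1%}")
```

The bubble shrinks only as micro-batches per step grow. That couples PP depth back to the global-batch constraint: more stages require more micro-batches to keep the bubble small.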
Communication overlap is what turns the plan into MFU. While the GPU computes layer N’s backward, NCCL is already all-gathering layer N−1’s params and reduce-scattering its gradients on a separate CUDA stream. Done right, nearly all DP/FSDP communication hides behind compute.
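A two-stream timeline makes the overlap effect concrete. This simplified model covers only the gradient side. It assumes each layer's reduce-scatter can start once that layer's backward compute is done and the comm stream is free. Per-layer timings are hypothetical.

```python
def backward_time(compute_s: list[float], comm_s: list[float], overlap: bool) -> float:
    """Time for a backward pass in which each layer's gradient collective follows its compute.

    overlap=False: comm runs on the compute stream, so everything serializes.
    overlap=True: comm runs on a separate stream; a layer's collective starts once its
    gradients exist and the comm stream is free, while compute moves on to the next layer.
    """
    if not overlap:
        return sum(compute_s) + sum(comm_s)
    compute_done = 0.0
    comm_free = 0.0
    for c, m in zip(compute_s, comm_s):
        compute_done += c                                # this layer's gradients are ready
        comm_free = max(comm_free, compute_done) + m     # queue its reduce-scatter
    return max(compute_done, comm_free)


layers = 80
compute = [10.0] * layers  # ms of backward compute per layer (hypothetical)
comm = [8.0] * layers      # ms of reduce-scatter per layer (hypothetical)
print(backward_time(compute, comm, overlap=False))  # 1440.0
print(backward_time(compute, comm, overlap=True))   # 808.0: only the last layer's comm is exposed
```

Overlap only hides communication that is shorter than the compute it runs behind. Once per-layer comm exceeds per-layer compute, the comm stream becomes the critical path again.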
The trap to name out loud: “FSDP scales linearly to 16K GPUs” is false — the all-gather/reduce-scatter cost grows with world size and eventually the IB fabric saturates, so beyond a few thousand ranks you compose FSDP with TP+PP rather than scaling DP alone. The other trap is putting TP across nodes — instant MFU death. A 16,384-GPU plan looks like TP=8 (intra-node) × PP=16 × DP=128.
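One way to realize that TP=8 × PP=16 × DP=128 plan is to make TP the fastest-varying axis of the rank numbering. The sketch assumes ranks are numbered node by node with 8 consecutive ranks per node, and it uses a tp-pp-dp ordering for illustration. Real frameworks choose their own axis order, but they keep TP innermost for the same reason.

```python
def coords(rank: int, tp: int, pp: int, dp: int) -> tuple[int, int, int]:
    """Map a global rank to (tp, pp, dp) coordinates with TP as the fastest-varying axis."""
    assert 0 <= rank < tp * pp * dp
    return rank % tp, (rank // tp) % pp, rank // (tp * pp)


TP, PP, DP, GPUS_PER_NODE = 8, 16, 128, 8

# Every TP group (ranks sharing pp and dp coordinates) should sit on a single node.
groups: dict[tuple[int, int], set[int]] = {}
for r in range(TP * PP * DP):
    _, p, d = coords(r, TP, PP, DP)
    groups.setdefault((p, d), set()).add(r // GPUS_PER_NODE)  # node id hosting rank r
assert all(len(nodes) == 1 for nodes in groups.values())
print(f"{len(groups)} TP groups, each confined to one node")
```

With this ordering, adjacent PP stages (rank r and r + 8) land on different nodes. That is fine, because stage-boundary traffic is sparse point-to-point.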
Deep dive 2 — Mixed precision recipe
This is the MLE’s lane, and I’d verify the exact thresholds with an ML specialist, but the structure is well established:
Why BF16 over FP16: FP16’s narrow exponent forces dynamic loss scaling to avoid gradient underflow — a fragile feedback loop that can blow up a run. BF16 has the same exponent range as FP32, so we skip loss scaling entirely and keep an FP32 master copy for accurate updates.
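Two small round-trip helpers demonstrate the exponent-range argument. FP16 goes through Python's IEEE half-precision `struct` format `'e'`. BF16 is emulated by keeping the top 16 bits of an FP32 with round-to-nearest-even, an emulation written for this sketch that does not handle NaN. The gradient value is hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round-trip through bfloat16: top 16 bits of an FP32, round-to-nearest-even (NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


grad = 1e-8                  # a small but meaningful gradient value
print(to_fp16(grad))         # 0.0: below FP16's smallest subnormal (~6e-8), flushed to zero
print(to_bf16(grad))         # about 1e-8: BF16 keeps FP32's 8-bit exponent, so it survives

# FP16's workaround is loss scaling: scale up before the cast, unscale in higher precision.
scale = 1024.0
print(to_fp16(grad * scale) / scale)  # nonzero again, at the cost of managing `scale` dynamically
```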
FP8 (the frontier move): DeepSeek-V3 demonstrated FP8 GEMMs in production for roughly 2× matmul throughput, using fine-grained (per-tile/per-block) scaling to keep the limited FP8 dynamic range from clipping, while keeping the sensitive ops (normalization, attention softmax, the master optimizer state) at higher precision. The judgment call — which ops can go FP8 without hurting convergence — is exactly what I’d confirm with ML, because it’s a quality risk, not a systems risk.
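Per-tensor and per-block scaling can be compared with a simplified FP8 E4M3 emulator. The emulator is written for this sketch and is not a library function: it keeps 3 mantissa bits, a minimum normal exponent of −6, subnormal steps of 2⁻⁹, and saturation at ±448, with no NaN handling. The two blocks of values are hypothetical, one containing outliers and one containing small magnitudes.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude


def quantize_e4m3(x: float) -> float:
    """Simplified E4M3 rounding: 3 mantissa bits, saturating at +/-448, subnormals down to 2**-9."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)         # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)         # unbiased exponent, clamped into the subnormal range
    step = 2.0 ** (exp - 3)      # spacing of representable values in this binade
    return math.copysign(round(a / step) * step, x)


def fake_quant(values, scale):
    """Scale into FP8 range, quantize, and scale back (what the matmul effectively sees)."""
    return [quantize_e4m3(v * scale) / scale for v in values]


outlier_block = [900.0, -1000.0, 750.0, 800.0]
small_block = [0.001, -0.002, 0.0015, 0.0008]

# Per-tensor scaling: one scale from the global max |x|, dominated by the outliers.
t_scale = E4M3_MAX / max(abs(v) for v in outlier_block + small_block)
per_tensor = fake_quant(small_block, t_scale)
print(per_tensor)  # all zeros: the small block underflows

# Per-block scaling: each block gets its own scale from its own max |x|.
b_scale = E4M3_MAX / max(abs(v) for v in small_block)
per_block = fake_quant(small_block, b_scale)
print(per_block)   # within a few percent of the original values
```

This illustrates only the scaling idea. Which tensors and ops tolerate FP8 at all remains the convergence question to settle with ML.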
Deep dive 3 — Fault tolerance (the heart of the system)
At one failure every ~3 hours across 16K GPUs, the run is defined by how fast it heals.
Detection — multiple signals, fast:
Silent data corruption is the scariest because the run keeps going while quietly poisoning the weights — periodic checksum/recompute spot-checks and per-rank loss outlier detection are the defense.
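A per-rank loss outlier check is one of those defenses, sketched here under assumptions. It uses a median/MAD rule so that a single corrupted rank cannot drag the threshold toward itself. The threshold `k` and the loss values are hypothetical, and in practice this would be one signal among several.

```python
import statistics


def suspect_ranks(per_rank_loss: dict[int, float], k: float = 6.0) -> list[int]:
    """Flag ranks whose local loss is more than k median-absolute-deviations from the fleet median."""
    losses = list(per_rank_loss.values())
    med = statistics.median(losses)
    mad = statistics.median(abs(x - med) for x in losses) or 1e-12  # avoid divide-by-zero
    return sorted(r for r, x in per_rank_loss.items() if abs(x - med) / mad > k)


# Ranks see different micro-batches, so their losses differ a little; rank 5 is far off.
per_rank = {0: 2.31, 1: 2.33, 2: 2.32, 3: 2.34, 4: 2.33, 5: 2.90, 6: 2.32, 7: 2.35}
print(suspect_ranks(per_rank))  # [5]
```

A flagged rank is a candidate, not a verdict. The deterministic-recompute diff in the divergence playbook confirms which GPU is flaky before it is evicted.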
Recovery — in-place gang restart:
Hot spares (a few hundred warm nodes) make this minutes, not hours — no waiting on the scheduler to find and boot fresh hardware. The restart is in-place: only the affected gang’s mapping changes; the parallelism plan is preserved so the effective batch size is unchanged.
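The "only the mapping changes" property can be sketched as a rank-to-host remap. Rank ids, and therefore their (tp, pp, dp) coordinates, stay fixed, and only the host behind the failed node's ranks is replaced. The function and host names are hypothetical.

```python
def swap_in_spare(rank_to_node: dict[int, str], failed: str, spares: list[str]) -> dict[int, str]:
    """Return a new rank->node map with the failed node's ranks moved onto one hot spare.

    Consumes the spare from the pool. Ranks keep their ids, so the parallelism plan and the
    effective global batch are unchanged; only the physical host behind those ranks changes.
    """
    if not spares:
        raise RuntimeError("no hot spare available; fall back to an elastic re-plan")
    spare = spares.pop(0)
    return {r: (spare if node == failed else node) for r, node in rank_to_node.items()}


rank_to_node = {r: f"node{r // 8}" for r in range(16)}  # 2 nodes x 8 GPUs
spares = ["spare0", "spare1"]
new_map = swap_in_spare(rank_to_node, "node1", spares)
print(new_map[8], new_map[0], spares)  # spare0 node0 ['spare1']
```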
Async tiered checkpointing is what keeps the overhead invisible:
The GPU only stalls for the in-memory snapshot (sub-second); everything durable happens off the critical path. The peer-replica (L1) tier means a single-node failure usually restores from a neighbor’s NVMe rather than the parallel FS — seconds instead of minutes. This is precisely a write-back cache hierarchy with the GPUs as the write source.
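A single-process sketch shows the two-phase write-back pattern. A `deepcopy` stands in for the sub-second in-memory snapshot, and a background thread persists it to a local directory, which stands in for the durable tiers. The class and file names are hypothetical, and the peer-replica tier is not modeled.

```python
import copy
import pickle
import tempfile
import threading
from pathlib import Path
from typing import Optional


class AsyncCheckpointer:
    """Fast snapshot on the training thread; durable write in the background (write-back)."""

    def __init__(self, directory: Path):
        self.directory = directory
        self._pending: Optional[threading.Thread] = None

    def save(self, step: int, state: dict) -> None:
        self.wait()                      # at most one persist in flight: back-pressure if storage lags
        snapshot = copy.deepcopy(state)  # the only part on the training critical path
        self._pending = threading.Thread(target=self._persist, args=(step, snapshot))
        self._pending.start()

    def _persist(self, step: int, snapshot: dict) -> None:
        tmp = self.directory / f"step_{step}.tmp"
        tmp.write_bytes(pickle.dumps(snapshot))
        tmp.rename(self.directory / f"step_{step}.ckpt")  # publish only complete files

    def wait(self) -> None:
        if self._pending is not None:
            self._pending.join()
            self._pending = None


with tempfile.TemporaryDirectory() as d:
    ck = AsyncCheckpointer(Path(d))
    state = {"w": [1.0, 2.0], "step": 100}
    ck.save(100, state)
    state["w"][0] = 999.0  # training resumes immediately and keeps mutating its state
    ck.wait()
    restored = pickle.loads((Path(d) / "step_100.ckpt").read_bytes())

print(restored)  # {'w': [1.0, 2.0], 'step': 100}: the snapshot, not the later mutation
```

The write-then-rename step means a reader scanning the directory during recovery never sees a half-written checkpoint.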
Deep dive 4 — MoE expert parallelism
Frontier models are increasingly sparse. DeepSeek-V3 is 671B total parameters but only ~37B active per token — you get a huge parameter count at a fraction of the FLOPs.
The systems cost of MoE is the routing all-to-all: every token is dispatched to its top-k experts (which live on different GPUs under expert parallelism), then the results are combined back. That’s two all-to-all collectives per MoE layer in the forward pass (dispatch and combine) and two more in the backward pass. It is structured permutation traffic that the switcher can map onto a distributed shuffle.
The trap here: treating MoE as “free” capacity. The routing is dynamic and data-dependent, so load can skew badly within a step — a single overloaded expert stalls the synchronous barrier for everyone. Load balancing is a convergence-and-systems coupling: too aggressive a balancing loss hurts quality, too loose and you get stragglers. Another thing to co-design with ML.
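Top-k routing, and what a skewed router does to load, can be sketched in a few lines. The per-expert token counts are the send sizes of the dispatch all-to-all. The router logits are synthetic Gaussians with a hypothetical +2 bias toward expert 0, and no balancing loss is modeled.

```python
import random
from collections import Counter


def route_top_k(router_logits: list[list[float]], k: int) -> list[list[int]]:
    """For each token, pick the k experts with the highest router score."""
    return [sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k] for row in router_logits]


def load_imbalance(assignments: list[list[int]], num_experts: int) -> float:
    """Max tokens on any expert divided by the mean; 1.0 is perfectly balanced.
    In a synchronous step, the GPU holding the busiest expert sets the pace."""
    counts = Counter(e for experts in assignments for e in experts)
    loads = [counts.get(e, 0) for e in range(num_experts)]
    return max(loads) / (sum(loads) / num_experts)


rng = random.Random(0)
tokens, experts, k = 1024, 8, 2
balanced = [[rng.gauss(0, 1) for _ in range(experts)] for _ in range(tokens)]
skewed = [[x + (2.0 if e == 0 else 0.0) for e, x in enumerate(row)] for row in balanced]

bal = load_imbalance(route_top_k(balanced, k), experts)
skw = load_imbalance(route_top_k(skewed, k), experts)
print(f"balanced router: {bal:.2f}x mean load")
print(f"skewed router:   {skw:.2f}x mean load")
```

With the skewed router, one expert receives several times its share of tokens, so its GPU stalls the whole barrier. Tuning the balancing term that pushes against this skew is the part to co-design with ML.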
Multi-team rollout
You don’t “deploy” a multi-week run — you nurse it. The operational discipline is half the battle.
Pre-launch burn-in
- NCCL bandwidth tests across every rail and node pair — confirm each link hits ~400G; a single degraded cable silently halves a collective.
- Straggler screening — run a synthetic step and evict any GPU more than a few % slower; in a synchronous barrier, the slowest rank taxes all 16K.
- Numerical smoke test — a few hundred steps at small scale to confirm loss decreases and grad-norm is sane before committing the fleet.
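The straggler-screening rule can be sketched as a median comparison over synthetic-step timings. The 3% tolerance is a hypothetical reading of "more than a few %", and the GPU labels and timings are made up.

```python
import statistics


def stragglers(step_ms: dict[str, float], tolerance: float = 0.03) -> list[str]:
    """Return GPUs whose synthetic-step time exceeds the fleet median by more than `tolerance`."""
    med = statistics.median(step_ms.values())
    return sorted(g for g, t in step_ms.items() if t > med * (1 + tolerance))


times = {"n0/g0": 1000.0, "n0/g1": 1004.0, "n1/g0": 1002.0, "n1/g1": 1060.0, "n2/g0": 1001.0}
print(stragglers(times))  # ['n1/g1']
```

Comparing against the median rather than the mean keeps a few very slow GPUs from raising the bar and hiding each other.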
Observability
The NCCL flight recorder is the single best debugging tool: when a collective hangs, it shows exactly which rank was waiting on which op, turning a cluster-wide freeze into a pinpointed node.
Divergence playbooks
- Loss spike: if grad-norm exceeds threshold, gradient clipping should absorb it; if loss diverges, roll back to a pre-spike checkpoint, optionally skip the offending data batch, and resume — sometimes with a briefly lowered LR. Loss spikes are often a single bad data shard or a numerical edge case, not a fundamental problem.
- Silent corruption suspected: checkpoint, run a deterministic recompute on a held-out batch across ranks, diff outputs to find the flaky GPU, evict it.
- Slow creep in step time: suspect a degrading link or thermal throttle; cross-reference per-rank MFU.
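The first line of the loss-spike playbook is global-norm gradient clipping; the rollback decision comes next. Below is a sketch of clipping plus a hypothetical spike trigger. The trigger's window and threshold `k` are assumptions, not an established rule.

```python
import math
import statistics


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> tuple[list[list[float]], float]:
    """Scale all gradients by one factor so their combined L2 norm is at most max_norm.
    Returns (clipped_grads, pre_clip_norm); the pre-clip norm is the grad-norm on the dashboard."""
    total = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    factor = min(1.0, max_norm / (total + 1e-6))
    return [[g * factor for g in tensor] for tensor in grads], total


def should_roll_back(loss_history: list[float], new_loss: float, k: float = 4.0, window: int = 50) -> bool:
    """Hypothetical divergence trigger: loss jumps k standard deviations above its recent mean."""
    recent = loss_history[-window:]
    if len(recent) < 2:
        return False
    mu, sigma = statistics.fmean(recent), statistics.stdev(recent)
    return new_loss > mu + k * max(sigma, 1e-8)


clipped, norm = clip_by_global_norm([[3.0, 4.0], [0.0]], max_norm=1.0)
history = [2.3 + 0.01 * ((i % 3) - 1) for i in range(60)]  # a steady, slightly noisy loss
print(norm, should_roll_back(history, 2.31), should_roll_back(history, 3.5))  # 5.0 False True
```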
Reproducibility & the human loop
Deterministic resume (fixed seeds + captured RNG + exact dataloader position) means a resumed run is bit-identical to one that never failed, provided the kernels and the collective reduction order are deterministic too. That is essential for debugging and for trusting the token budget. A human on-call rotation watches the dashboards around the clock; for a run worth millions of dollars, an engineer is always one page away from the loss curve.
Bottlenecks & evolution
Where it breaks next
At 100K+ GPUs the MTBF drops to tens of minutes, and without strong fault tolerance the effective training time falls below 50% — you’d spend more than half the cluster’s life recovering. The fixes: elastic/partial recovery (heal the affected gang without a global rollback), in-network reduction like NVIDIA SHARP (the switch fabric performs part of the all-reduce, halving collective traffic), and ever-tighter compute/comm overlap.
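A first-pass model of that cliff can be sketched with two assumptions. First, failures are independent with a constant per-GPU rate, so fleet MTBF scales as 1/N from the ~3h at 16,384 GPUs. Second, failures are exponential, the run restarts from the last checkpoint, and no failure strikes during a restart; under those assumptions the expected time to finish one interval-plus-checkpoint segment is (MTBF + restart)·(e^(w/MTBF) − 1). The recovery settings are hypothetical.

```python
import math


def fleet_mtbf_minutes(ref_mtbf_min: float, ref_gpus: int, gpus: int) -> float:
    """Constant per-GPU failure rate: fleet MTBF scales as 1/N."""
    return ref_mtbf_min * ref_gpus / gpus


def efficiency(mtbf: float, interval: float, ckpt_cost: float, restart: float) -> float:
    """Useful-work fraction of wall clock under the exponential restart-from-checkpoint model."""
    w = interval + ckpt_cost
    expected = (mtbf + restart) * math.expm1(w / mtbf)  # expected time to finish one segment
    return interval / expected


mtbf = fleet_mtbf_minutes(180, 16_384, 100_000)
print(f"MTBF at 100K GPUs: {mtbf:.1f} min")
# Hypothetical settings: slow global restart vs. hot-spare, in-place recovery (all in minutes).
print(f"slow recovery (20 min interval, 20 min restart): {efficiency(mtbf, 20, 0.1, 20):.0%}")
print(f"fast recovery (5 min interval, 2 min restart):   {efficiency(mtbf, 5, 0.05, 2):.0%}")
```

With a MTBF of about 30 minutes, slow recovery lands below 50% goodput, while fast checkpoint and restart keep it around 85%. Both shorter intervals and faster restarts are needed at this scale.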
Evolution
1. Lower precision: FP8 today on Hopper; FP4 on Blackwell for another throughput step — gated by convergence validation.
2. Bigger sparsity: larger MoE with more experts, longer context (more context parallelism).
3. Multi-datacenter / async training: when one building can’t power 100K+ GPUs, train across DCs with relaxed synchronization.
4. Post-training reuse: the same fleet, checkpoint store, and fault tolerance run RLHF and other post-training — the infra amortizes across the model lifecycle.
The cost/MFU frontier
DeepSeek-V3 is the reference for algorithm–framework–hardware co-design: roughly 2.79M H800 GPU-hours (~$5.6M at the $2/GPU-hour rental price its report assumes, covering the final training run only, not prior research and ablations) for a frontier-class model, achieved via FP8 training, DualPipe overlap, and MoE. It is proof that MFU and cost are won by co-designing the model, the framework, and the hardware together, not by throwing more GPUs at a fixed recipe.
Summary
1. This is a distributed-systems reliability problem wearing an ML hat. Lead with the failure model and the network fabric, not the math of attention.
2. Hybrid nD parallelism that localizes communication is the central win. TP intra-node on NVLink (all-reduce every layer), PP/DP/FSDP inter-node on IB, with comm overlapped behind compute. Never put TP across nodes; never assume FSDP scales linearly to 16K GPUs.
3. Beat the memory wall with sharding plus mixed precision. ~16 B/param of Adam state must be sharded; BF16 compute + FP32 master avoids loss-scaling fragility; FP8 GEMMs buy ~2× throughput when applied carefully.
4. Fault tolerance keeps goodput high. Detection (NCCL timeouts, ECC, SDC checks) → hot-spare swap → in-place gang restart from an async, tiered, write-back checkpoint store, with deterministic idempotent resume. Checkpoint every 15–30 min per Young/Daly.
5. A data pipeline that never starves the GPUs — dedup, filter, tokenize into sharded mmap files, streamed deterministically per rank.
6. MFU is the metric that ties it all together — frame it as goodput: realized FLOPs over peak, eroded by stalls, stragglers, bubbles, and recovery.
7. The two things to verify with an ML expert: the exact precision recipe (which ops can drop to FP8/FP4) and the convergence implications of any parallelism reshape (because it changes effective global batch size and interacts with the LR schedule). Naming these boundaries is the credibility move, especially for an SDE moving into AI.
Rubric — Senior vs Staff