AI System Design · Staff · Distributed Training · GPU Clusters

Design Large-Scale Distributed LLM Training Infrastructure

A Staff-level walkthrough of the distributed LLM training infrastructure that Anthropic, OpenAI, and Google DeepMind operate at frontier scale. It frames training as a distributed-systems reliability problem wearing an ML hat: a hybrid nD parallelism plan that localizes communication, memory-frugal sharding plus mixed precision to beat the memory wall, a data pipeline that never starves the GPUs, and fault tolerance — fast async checkpoint/restart plus hot spares — that keeps MFU high despite constant failures.

Level: Staff
Category: AI Infrastructure · Distributed Training
Interview time: 60 min
WHAT THIS QUESTION TESTS
·Hybrid nD parallelism: TP intra-node, PP inter-node, FSDP/ZeRO data-parallel, context parallelism
·Mixed precision (BF16 + FP32 master, FP8 GEMMs) to beat the GPU memory wall
·Async tiered checkpointing with an MTBF-driven (Young/Daly) interval
·Fault tolerance: hot spares and fast in-place gang restart, not full-job kill
★ STAFF-LEVEL SIGNALS
·Sizes training state (weights, gradients, Adam) at ~16–18 bytes/param and proves a 70B model's ~1.1TB state must be sharded
·Keeps tensor parallelism inside an NVLink node and pipeline parallelism across the InfiniBand fabric
·Quotes realistic ~35–45% MFU, not peak 989 TFLOP/s, and treats MFU as the goodput metric
·Treats the data pipeline (dedup, tokenization, sharded loader) as first-class so GPUs never starve
0

Scope & ambiguity

Let me frame this up front: this is a distributed-systems reliability problem wearing an ML hat. The goal is to pretrain a 100B–700B parameter model on a fixed fleet of tens of thousands of GPUs over weeks, maximizing MFU (model FLOPs utilization) and minimizing time-to-converge. This is not an inference-serving system — there’s no per-request latency target, no autoscaling, no traffic spikes. It’s one long-lived, synchronous, gang-scheduled job whose enemy is the stop-the-world stall. I’ll lead with the failure model and the network fabric, treat MFU as goodput, and flag the two things I’d verify with an ML expert: the exact precision recipe and the convergence impact of any parallelism change.

This is the class of system Anthropic, OpenAI, Google DeepMind, Meta, and NVIDIA build internally and probe for in AI-infra interviews. I’m using it as industry context, not as a leaked question — the public reference points (Llama 3.1 405B, DeepSeek-V3, the Megatron/NeMo and FSDP stacks) are all documented, and I’ll anchor my numbers to them.

The job, precisely: Train a dense 100B–700B model (or a sparse MoE of similar active size) on a fixed cluster — say 16,384 H100s — for several weeks against a fixed token budget. Keep every GPU fed, keep MFU in the 35–45% band, and keep the run alive through a hardware failure roughly every few hours.

The three enemies

1. The memory wall. A model plus its Adam optimizer state and activations does not fit in one GPU’s 80GB HBM. We must shard state across devices.

2. Communication overhead. Every training step is a synchronous barrier — gradients must be aggregated across all ranks. Naive collectives stall the GPUs and crater MFU.

3. Hardware failure. At this scale something breaks every few hours. Without fast checkpoint/restart and hot spares, effective training time collapses.

Who asks this & what they probe

| Role | Focus | What they probe |
|---|---|---|
| SDE | Reliability & systems | Gang scheduling, topology-aware placement (bin-packing), MTBF math, checkpoint/restart, hot spares, the network fabric (fat-tree / rail-optimized), collectives as structured all-to-all traffic, tiered checkpoint store as a write-back cache |
| MLE | Model fit & convergence | The TP/PP/DP/CP/EP arithmetic and its memory/comm tradeoffs, mixed-precision recipe (BF16 vs FP8, where FP32 stays), activation recomputation, optimizer-state sharding, global batch size, LR schedule, loss-spike recovery |
| Switcher (SDE to AI) | Mapping vocab to known patterns | Step = synchronous barrier; all-reduce = distributed aggregation; ZeRO/FSDP = sharded state with on-demand gather; checkpoint = write-ahead snapshot. The trap is importing request/response serving intuition |

The switcher’s credibility move: name MFU as goodput, lead with the failure model and the network, and explicitly say “the precision recipe and the convergence impact of reshaping parallelism are the two things I’d confirm with an ML specialist.” That signals you know the boundary of your lane.

1

Requirements

Functional requirements

  • Launch and resume a distributed training run from a declarative job spec.
  • Partition the model so its state fits in aggregate HBM (nD parallelism).
  • Keep every GPU fed — a data pipeline that never starves 10k+ workers.
  • Checkpoint periodically and recover deterministically after failure.
  • Track metrics (loss, grad-norm, throughput, per-rank MFU) in real time.
  • Support elastic resize — drop a failed gang, optionally re-plan parallelism on a smaller/larger fleet.

Non-functional targets

| Dimension | Target | Why |
|---|---|---|
| MFU (goodput) | 35–45% | Realized FLOPs vs peak; the headline efficiency metric |
| Failure tolerance | Survive 1 failure / few hours | Lose at most a few % of work per interruption |
| Resume | Deterministic, idempotent | Bit-reproducible continuation, no double-counted tokens |
| Checkpoint overhead | Single-digit % of wall-clock | Async write off the critical path |
| Multi-tenancy | Fair-share scheduling | Cluster shared across runs/teams |
| Dataloader SLA | Never starve the GPUs | Loader tokens/s must exceed the tokens/s the training step consumes |

The convergence constraint (the callout the MLE wants)

Throughput is necessary but not sufficient — the run has to converge well. The knob that couples infra to convergence is global batch size:

global_batch = micro_batch × grad_accum_steps × DP_width

If you change the data-parallel width (e.g., after losing nodes), you change the effective global batch unless you compensate with gradient accumulation. Global batch size interacts with the learning-rate schedule and gradient clipping; getting it wrong wastes the entire multi-week run. So a hard requirement: any parallelism reshape must preserve effective global batch size, or be paired with an LR adjustment vetted by ML. The token-budget target (e.g., ~15T tokens) is fixed; we trade DP width and grad-accum to hit batch size at whatever fleet size survives.
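A minimal sketch of that compensation, assuming integer batch sizes: `grad_accum_for` is a hypothetical helper (not a framework API) that solves the formula above for grad_accum_steps and refuses when no integer value preserves the batch.

```python
def grad_accum_for(global_batch: int, micro_batch: int, dp_width: int) -> int:
    """Gradient-accumulation steps that keep the effective global batch fixed.

    Uses global_batch = micro_batch * grad_accum_steps * dp_width.
    """
    per_accum_step = micro_batch * dp_width  # sequences consumed per accumulation step
    if global_batch % per_accum_step != 0:
        # No integer grad_accum preserves the batch: re-plan the DP width, or change
        # the batch together with an ML-vetted LR adjustment.
        raise ValueError(
            f"global_batch={global_batch} not divisible by micro_batch*dp_width={per_accum_step}"
        )
    return global_batch // per_accum_step


print(grad_accum_for(1024, micro_batch=1, dp_width=128))  # 8, matching the job spec's grad_accum
print(grad_accum_for(1024, micro_batch=1, dp_width=64))   # 16 after losing half the DP replicas
```

With a global batch of 1024, shrinking DP from 128 to 96 replicas would raise, because 1024 is not a multiple of 96: not every surviving fleet size can hold the batch exactly, and that is when the ML-vetted LR path applies.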

2

Back-of-envelope estimation

Compute budget

Standard rule of thumb: training costs ~6 FLOPs per token per parameter (2 forward + 4 backward).

Llama 3.1 405B reference:
tokens = 15.6e12
params = 405e9
FLOPs = 6 × 15.6e12 × 405e9 ≈ 3.8e25 FLOPs
 
H100 BF16 dense peak ≈ 1e15 FLOP/s (~989 TFLOP/s)
At 40% MFU → 4e14 effective FLOP/s per GPU
 
GPU-seconds = 3.8e25 / 4e14 ≈ 9.5e10 s
GPU-hours ≈ 2.6e7 → ~26–30M GPU-hours
On 16,384 H100s → 2.6e7 / 16,384 ≈ 1,600 hours per GPU ≈ 67 days wall-clock

This matches the public record: Llama 3.1 405B trained on 16,384 H100s over roughly two months. The estimate is internally consistent, which is the point of doing it.
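The same arithmetic as a small script, so the assumptions (989 TFLOP/s BF16 dense peak, 40% MFU, 16,384 GPUs) are explicit parameters rather than buried constants. `training_budget` is an illustrative name, and the 6·N·D rule is the rough approximation stated above, not an exact FLOP count.

```python
SECONDS_PER_HOUR = 3600
HOURS_PER_DAY = 24


def training_budget(params: float, tokens: float, peak_flops: float,
                    mfu: float, num_gpus: int) -> dict:
    """Back-of-envelope cost using the ~6 FLOPs per token per parameter rule."""
    total_flops = 6 * params * tokens        # 2 forward + 4 backward
    effective = peak_flops * mfu             # sustained FLOP/s per GPU
    gpu_hours = total_flops / effective / SECONDS_PER_HOUR
    wall_days = gpu_hours / num_gpus / HOURS_PER_DAY
    return {"total_flops": total_flops, "gpu_hours": gpu_hours, "wall_days": wall_days}


est = training_budget(params=405e9, tokens=15.6e12, peak_flops=989e12,
                      mfu=0.40, num_gpus=16_384)
print(f"{est['total_flops']:.2e} FLOPs, {est['gpu_hours'] / 1e6:.1f}M GPU-hours, "
      f"{est['wall_days']:.0f} days")
# -> 3.79e+25 FLOPs, 26.6M GPU-hours, 68 days
```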

Memory wall

Adam mixed-precision state per parameter:
FP16/BF16 weights ........ 2 B
FP16/BF16 gradients ...... 2 B
FP32 master weights ...... 4 B
FP32 Adam m + v .......... 8 B
----------------------------------
≈ 16 B/param (18 with extras)
 
70B model → ~1.1–1.3 TB of state.
One H100 has 80 GB HBM.
→ State alone needs ~14–16 GPUs' worth of memory before activations.

Activations are extra and scale with batch×seqlen×layers — which is why activation recomputation exists (trade compute to shrink the activation footprint). Conclusion: state must be sharded; the only question is along which axes.
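A sketch of the memory-wall check using the byte counts above. `extra_per_param` is a hypothetical knob standing in for the "18 with extras" case, and the GPU count is only a lower bound because it ignores activations, fragmentation and framework buffers.

```python
import math

GB = 1e9
BYTES_PER_PARAM = {
    "bf16_weights": 2,
    "bf16_grads": 2,
    "fp32_master": 4,
    "fp32_adam_m_v": 8,
}


def state_bytes(params: float, extra_per_param: float = 0.0) -> float:
    """Weights + grads + FP32 master + Adam moments; activations excluded."""
    return params * (sum(BYTES_PER_PARAM.values()) + extra_per_param)


def min_gpus_for_state(params: float, hbm_gb: float = 80, extra_per_param: float = 0.0) -> int:
    """Lower bound on GPUs needed just to hold the fully sharded state."""
    return math.ceil(state_bytes(params, extra_per_param) / (hbm_gb * GB))


print(state_bytes(70e9) / 1e12)                     # 1.12 (TB at 16 B/param)
print(min_gpus_for_state(70e9))                     # 14
print(min_gpus_for_state(70e9, extra_per_param=2))  # 16
```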

Network

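| Link | Bandwidth | Role |
|---|---|---|
| NVLink (intra-node) | ~900 GB/s per GPU (bidirectional; ~450 GB/s each way) | Latency-sensitive collectives |
| 8× 400G InfiniBand rails | 3.2 Tb/s per node (one 400 Gb/s ≈ 50 GB/s rail per GPU, per direction) | Inter-node, rail-optimized |
| NVLink vs IB ratio | ~18× (NVLink bidirectional vs one rail per direction); ~9× like for like | Drives placement |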

The order-of-magnitude gap between NVLink and a single IB rail (~18× if you set NVLink's 900 GB/s bidirectional figure against one rail's 50 GB/s per direction, ~9× like for like) is the single most important number for the parallelism plan: tensor parallelism, which all-reduces every layer, must stay inside a node on NVLink. Anything chattier than the link can carry becomes the bottleneck.
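To make the placement argument concrete, here is a bandwidth-only cost model for one ring all-reduce. Assumptions: ~450 GB/s per direction for H100 NVLink (half of the 900 GB/s bidirectional figure), 50 GB/s per direction for one 400 Gb/s rail, an 8-way TP group that would otherwise be spread over 8 nodes with one rail each, and the job-spec shape (micro_batch 1, seqlen 8192, hidden 16384). Latency, NVSwitch topology and congestion are ignored, so the absolute numbers are illustrative only.

```python
def ring_allreduce_seconds(num_bytes: float, group_size: int, link_bytes_per_s: float) -> float:
    """Bandwidth-only ring all-reduce cost (no latency, no contention).

    Each rank sends 2 * (n - 1) / n * S bytes: n - 1 reduce-scatter steps plus
    n - 1 all-gather steps, each moving S / n bytes.
    """
    n = group_size
    return 2 * (n - 1) / n * num_bytes / link_bytes_per_s


activation_bytes = 1 * 8192 * 16384 * 2   # one BF16 activation tensor, ~268 MB
NVLINK_PER_DIRECTION = 450e9              # assumed H100 NVLink, each way
IB_RAIL_PER_DIRECTION = 50e9              # one 400 Gb/s rail, each way

t_nvlink = ring_allreduce_seconds(activation_bytes, 8, NVLINK_PER_DIRECTION)
t_ib = ring_allreduce_seconds(activation_bytes, 8, IB_RAIL_PER_DIRECTION)
print(f"NVLink {t_nvlink * 1e3:.2f} ms vs IB {t_ib * 1e3:.2f} ms ({t_ib / t_nvlink:.0f}x)")
# -> NVLink 1.04 ms vs IB 9.40 ms (9x)
```

Megatron-style TP issues several of these per layer per micro-batch, and each sits largely on the critical path because the next matmul needs the reduced activations, so a ~9× slower link multiplies directly into step time.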

Checkpoint interval (Young/Daly)

Optimal interval ≈ sqrt(2 × C × MTBF)
C = checkpoint write cost
MTBF = mean time between failures
 
Llama 3.1 405B: 466 interruptions over 54 days
→ MTBF ≈ 54×24 / 466 ≈ one failure every ~2.8 h

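With an MTBF of ~3h, sqrt(2 × C × MTBF) gives about 19 minutes for an effective checkpoint cost of one minute, and only 5–8 minutes if the blocking stall is truly 5–10 seconds. In practice the interval also cannot be shorter than the time the background tiers need to drain one checkpoint before the next starts, which is why this design lands in the 15–30 minute range. Expected lost work per failure is then about half an interval (~10 minutes), roughly 5% of wall-clock at one failure every ~3h, versus ~15% or more if we checkpointed hourly; that is why we checkpoint this often.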
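A small calculator for the trade-off, assuming Young's first-order formula and the standard first-order waste model (checkpoint cost amortized over each interval, plus an expected half interval of lost work per failure). Restart time is ignored, and the per-checkpoint costs in the loop are assumed values, not measurements.

```python
import math


def young_interval_s(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order optimal checkpoint interval: sqrt(2 * C * MTBF)."""
    return math.sqrt(2 * checkpoint_cost_s * mtbf_s)


def wasted_fraction(interval_s: float, checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Checkpoint cost per interval + expected half-interval of lost work per failure."""
    return checkpoint_cost_s / interval_s + interval_s / (2 * mtbf_s)


mtbf_s = 54 * 24 * 3600 / 466        # Llama 3.1 405B snapshot: ~2.8 h between interruptions
for cost_s in (5, 60, 120):          # assumed effective cost of one checkpoint, in seconds
    t = young_interval_s(cost_s, mtbf_s)
    print(f"C={cost_s:>3}s -> every {t / 60:4.1f} min, "
          f"waste {wasted_fraction(t, cost_s, mtbf_s):.1%}")
```

At the optimum the two waste terms are equal, so total waste is sqrt(2C/MTBF): shrinking the blocking part of C (the async snapshot) buys more than tuning the interval itself.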

3

API design

These are control-plane and framework interfaces, not request/response APIs. There is no client calling in — there’s a job spec, a scheduler, and inter-rank collectives.

Job specification

job:
  model:
    arch: transformer-dense   # or moe
    params: 405e9
    layers: 126
    hidden: 16384
    seqlen: 8192
  parallelism:
    tp: 8      # tensor (intra-node, NVLink)
    pp: 16     # pipeline (inter-node)
    dp: 128    # data (FSDP / ZeRO outer)
    cp: 1      # context (long-seq sharding)
    ep: 1      # expert (MoE all-to-all)
  precision:
    compute: bf16
    master: fp32
    gemm: fp8  # optional, fine-grained scaling
  batch:
    micro_batch: 1
    grad_accum: 8
    global_batch: 1024
  checkpoint:
    interval_steps: 250
    tiers: [local_nvme, peer, lustre, object_store]

tp × pp × dp × cp × ep must equal the total GPU count (here 8×16×128 = 16,384).
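A sketch of the invariants a launcher might check before committing the fleet: the product rule just stated, TP fitting inside an 8-GPU NVLink node, and the global-batch identity from the requirements. `ParallelismPlan` and `check` are hypothetical names; a real launcher would also verify layers per pipeline stage and memory fit.

```python
from dataclasses import dataclass

GPUS_PER_NODE = 8  # one H100 NVLink domain


@dataclass(frozen=True)
class ParallelismPlan:
    tp: int
    pp: int
    dp: int
    cp: int = 1
    ep: int = 1

    def check(self, world_size: int, micro_batch: int, grad_accum: int, global_batch: int) -> None:
        """Raise ValueError if the plan violates the job-spec invariants."""
        product = self.tp * self.pp * self.dp * self.cp * self.ep
        if product != world_size:
            raise ValueError(f"tp*pp*dp*cp*ep = {product}, expected {world_size}")
        if GPUS_PER_NODE % self.tp != 0:
            raise ValueError(f"tp={self.tp} must divide the {GPUS_PER_NODE}-GPU NVLink node")
        if micro_batch * grad_accum * self.dp != global_batch:
            raise ValueError("micro_batch * grad_accum * dp must equal global_batch")


plan = ParallelismPlan(tp=8, pp=16, dp=128)
plan.check(world_size=16_384, micro_batch=1, grad_accum=8, global_batch=1024)  # passes silently
```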

Scheduler request (gang + topology-aware)

schedule(
    gang_size = 16384,
    placement = TOPOLOGY_AWARE,   # rack + rail affinity
    tp_group  = SAME_NODE,        # NVLink domain
    pp_group  = SAME_RACK,        # low-hop IB
    spares    = 256,              # hot standby
) -> { ranks: [...], rail_map: {...} }
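One way `tp_group = SAME_NODE` falls out of rank numbering: order the axes with TP innermost, then PP, then DP outermost, and place ranks on nodes 8 at a time. This is a hypothetical sketch that assumes the scheduler hands out nodes in rank order; real systems also align ranks to rails and leaf switches.

```python
GPUS_PER_NODE = 8  # assumed: ranks are packed onto nodes 8 at a time, in rank order


def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> tuple[int, int, int]:
    """(dp_rank, pp_rank, tp_rank) for a layout with TP innermost and DP outermost."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError(f"rank {rank} out of range")
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def coords_to_rank(dp_rank: int, pp_rank: int, tp_rank: int, tp: int, pp: int) -> int:
    """Inverse of rank_to_coords."""
    return (dp_rank * pp + pp_rank) * tp + tp_rank


def node_of(rank: int) -> int:
    return rank // GPUS_PER_NODE


TP, PP, DP = 8, 16, 128
tp_group = [coords_to_rank(0, 0, t, TP, PP) for t in range(TP)]   # ranks 0..7
pp_group = [coords_to_rank(0, p, 0, TP, PP) for p in range(PP)]   # ranks 0, 8, ..., 120
print({node_of(r) for r in tp_group})   # {0}: the whole TP group shares one NVLink node
print([node_of(r) for r in pp_group])   # nodes 0 through 15: each stage boundary crosses IB
```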

Framework contracts

# Reconfiguration / resume
resume(checkpoint_id, new_world_size) -> TrainingState
reshape(state, new_parallelism_plan) -> TrainingState # ML-vetted
 
# Dataloader shard assignment (deterministic by rank)
assign_shards(rank, dp_rank, epoch, seed) -> [shard_ids]
 
# Checkpoint I/O
save_checkpoint(state, tier=LOCAL_NVME, async_write=True) -> handle
load_checkpoint(checkpoint_id, rank_map) -> state

The real inter-rank API: collectives

The actual “API” between GPUs is the NCCL collective set. Everything in the step reduces to these:

| Primitive | Used by | Traffic shape |
|---|---|---|
| all-reduce | TP (per layer), DP grads | Symmetric, every rank |
| all-gather | FSDP param gather | Sharded → full |
| reduce-scatter | FSDP grad shard | Full → sharded |
| all-to-all | MoE expert routing | Structured permutation |
| send/recv | PP stage boundaries | Point-to-point |
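A semantic sketch (not NCCL's ring or tree implementation) of how these primitives relate: an all-reduce is a reduce-scatter followed by an all-gather. FSDP stops halfway and keeps only its reduce-scattered gradient shard, then all-gathers parameters when a layer needs them. Function names and the plain-list buffers are illustrative assumptions.

```python
Vector = list[float]


def reduce_scatter(buffers: list[Vector]) -> list[Vector]:
    """Rank i ends up with the element-wise sum of chunk i across all ranks."""
    n = len(buffers)
    chunk = len(buffers[0]) // n   # assumes length divisible by the number of ranks
    return [
        [sum(buf[i * chunk + k] for buf in buffers) for k in range(chunk)]
        for i in range(n)
    ]


def all_gather(shards: list[Vector]) -> list[Vector]:
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_reduce(buffers: list[Vector]) -> list[Vector]:
    """All-reduce expressed as reduce-scatter followed by all-gather."""
    return all_gather(reduce_scatter(buffers))


grads = [[float(rank + 1)] * 8 for rank in range(4)]   # 4 ranks, 8 gradient elements each
print(reduce_scatter(grads)[0])   # [10.0, 10.0]: rank 0 keeps only its summed shard (FSDP)
print(all_reduce(grads)[0])       # eight 10.0s: every rank holds the full sum (DDP)
```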
4

Data model

There are two distinct state planes, and conflating them is a classic mistake.

Plane 1 — Training state (the thing you checkpoint)

Per-rank shard of:
  params (BF16), sharded by TP×PP×FSDP
  optimizer (FP32 m, v + master weights)
  gradients (BF16; transient and zeroed at each step boundary, so normally not saved)
--------------------------------------------
Plus run-level metadata:
  step_count
  RNG state (per-rank; for dropout/init determinism)
  dataloader_position (which shards/offsets consumed)
  LR scheduler state

A consistent checkpoint must capture all four: model + optimizer + RNG + dataloader position. Miss the dataloader position and you re-feed or skip tokens on resume, silently corrupting the token budget and convergence. Miss RNG and you lose bit-reproducibility. This is the “write-ahead snapshot” the switcher already understands — it just spans the whole nD-sharded layout.
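A toy version of the per-rank snapshot, to show why RNG state and dataloader position belong in it: restoring the captured RNG state replays exactly the random draws the original run made. Python's `random.Random` stands in for the framework's CPU/CUDA generators, and all names and fields are hypothetical.

```python
import copy
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Snapshot:
    step: int
    params: dict          # this rank's parameter shard (toy: name -> list of floats)
    optimizer: dict       # FP32 master weights and Adam m, v for the same shard
    rng_state: tuple      # per-rank RNG state
    data_position: dict   # e.g. {"epoch": 0, "shard": 17, "offset": 4096}
    lr_scheduler: dict


def capture(step: int, params: dict, optimizer: dict, rng: random.Random,
            data_position: dict, lr_scheduler: dict) -> Snapshot:
    # Deep copies, so the live training state can keep changing after the snapshot.
    return Snapshot(step, copy.deepcopy(params), copy.deepcopy(optimizer),
                    rng.getstate(), dict(data_position), dict(lr_scheduler))


def rng_from(snap: Snapshot) -> random.Random:
    rng = random.Random()
    rng.setstate(snap.rng_state)
    return rng


rng = random.Random(1234)
rng.random()                                        # training consumed some randomness
snap = capture(250, {"w": [0.1]}, {"m": [0.0], "v": [0.0]}, rng,
               {"epoch": 0, "shard": 17, "offset": 4096}, {"lr": 3e-4})
original_draws = [rng.random() for _ in range(3)]   # what the run did after step 250
resumed = rng_from(snap)
replayed_draws = [resumed.random() for _ in range(3)]
print(replayed_draws == original_draws)             # True: dropout masks etc. replay exactly
```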

Plane 2 — Data plane (the corpus pipeline)

raw web/code/books
→ dedup (MinHash / LSH near-dup removal)
→ quality filter (classifier + heuristics)
→ tokenize (BPE/SentencePiece)
→ shard (sharded mmap token files, fixed-size)
→ stream (per-rank streaming dataloader, prefetch)

Tokenized output is stored as sharded memory-mapped token files so the dataloader does sequential, prefetchable reads with near-zero CPU cost. Shards are assigned to DP ranks deterministically by (epoch, seed, dp_rank) so resume is exact.
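A minimal sketch of deterministic assignment, assuming shards are numbered 0..N−1. It differs from the `assign_shards` contract listed earlier: it takes `num_shards` and `dp_width`, which it needs, and drops the global rank, since every TP and PP rank inside one DP replica consumes the same data. The seeding scheme is an assumption; Python's shuffle is deterministic for a given seed and Python version.

```python
import random


def assign_shards(num_shards: int, dp_rank: int, dp_width: int, epoch: int, seed: int) -> list[int]:
    """Shards one DP rank reads in an epoch, derived only from (seed, epoch, dp_rank).

    Every rank builds the same seeded permutation of all shards and takes a strided
    slice of it, so assignment needs no coordination and is identical on resume.
    """
    order = list(range(num_shards))
    random.Random(seed * 1_000_003 + epoch).shuffle(order)   # same permutation on every rank
    return order[dp_rank::dp_width]


mine = assign_shards(num_shards=1024, dp_rank=5, dp_width=128, epoch=0, seed=42)
print(len(mine))   # 8 shards for this rank
```

On resume, the checkpointed `dataloader_position` supplies the epoch plus the shard index and offset within this list, so each rank continues exactly where it stopped.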

Storage tiering

| Tier | Medium | Holds | Latency |
|---|---|---|---|
| L0 | Local NVMe | Latest local shard write | seconds |
| L1 | Peer GPU RAM/NVMe | Redundant copy for fast restore | seconds |
| L2 | Parallel FS (Lustre/GPFS) | Durable recent checkpoints | tens of s |
| L3 | Object store (S3/GCS) | Cold history, audit | minutes |

This is a write-back cache hierarchy: the training loop blocks only long enough to snapshot state into host memory, then background threads write it to L0 and tier it down to L1, L2 and L3. Recovery reads from the closest tier that has a good copy.

5

High-level architecture

Four planes. Control plane decides what runs where; compute plane runs the step; data plane feeds it; storage plane remembers it.

┌──────────────────────── CONTROL PLANE ─────────────────────────┐
│ Scheduler (gang, topology-aware)                               │
│ Job Controller (lifecycle, resume, reshape)                    │
│ Fault-Tolerance Controller (detect → evict → restart)          │
└──────────────┬─────────────────────────────────────────────────┘
               │ placement, rank map, spares
┌──────────────▼──────────── COMPUTE PLANE ──────────────────────┐
│ GPU nodes running the training framework (nD parallelism)      │
│   TP groups ── NVLink ── within node                           │
│   PP / DP ── 8×400G IB rails ── across nodes                   │
│   + hot-spare nodes idling, warm                               │
└───────┬──────────────────────────────────────┬─────────────────┘
        │ token shards                         │ checkpoints
┌───────▼──────── DATA PLANE ──────────┐ ┌─────▼ STORAGE PLANE ──┐
│ dedup → filter → tokenize → shard    │ │ L0 NVMe → L1 peer     │
│ parallel FS + streaming dataloader   │ │ → L2 Lustre           │
│                                      │ │ → L3 object store     │
└──────────────────────────────────────┘ └───────────────────────┘

Per-step flow

1. dataloader → next micro-batch (prefetched, no stall)
2. forward    → activations (recompute checkpointed layers)
3. backward   → local gradients
4. collective → TP all-reduce (per layer, NVLink)
              → DP/FSDP reduce-scatter + all-gather (IB)
              → PP send/recv at stage boundaries
5. optimizer  → FP32 master update, LR step
6. checkpoint → every N steps, async to L0 then tier down

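Steps 1–5 form one synchronous step across all 16,384 ranks (the step-4 collectives actually interleave with forward and backward rather than running strictly after them), so the slowest rank sets the step time, which is why straggler screening (Step 7) matters so much.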

6

Deep dives

WHERE STAFF IS WON

I'll go deep on four: (1) the nD parallelism plan and comm-overlap schedule, (2) mixed precision, (3) fault tolerance, and (4) MoE expert parallelism.

Deep dive 1 — The nD parallelism plan

The art is assigning each parallelism axis to the network tier whose bandwidth matches that axis’s communication intensity. Localize the chattiest traffic.

| Axis | What it shards | Comm pattern | Placement | Why |
|---|---|---|---|---|
| TP (tensor) | Within-layer matmuls | all-reduce every layer | Intra-node, NVLink | Latency-critical; needs ~900 GB/s |
| PP (pipeline) | Layers into stages | send/recv at boundaries | Inter-node, IB | Tolerates slower links; bubble is the cost |
| DP / FSDP | Batch + optimizer state | reduce-scatter (grads) + all-gather (params) | Inter-node, IB | Per layer under FSDP/ZeRO-3, once per step under ZeRO-1/2; overlappable |
| CP (context) | Sequence length | all-gather of K/V | Intra-rack | For long context |
| EP (expert) | MoE experts | all-to-all routing | Inter-node, IB | Sparse, load-balance-sensitive |

The placement rules that prove you understand the fabric:

  • TP stays inside the node. Megatron-style TP all-reduces activations twice per transformer layer in the forward pass (after attention and after the MLP) and twice in the backward, on every micro-batch. Cross-node TP over a single 400G rail (roughly an order of magnitude less bandwidth than NVLink) would dominate step time. Keep TP ≤ node size (8 for an H100 box).
  • PP crosses nodes. Pipeline only exchanges activations at stage boundaries: sparse, point-to-point traffic that tolerates IB latency. The cost of PP is the bubble (idle time while the pipeline fills and drains), mitigated by 1F1B scheduling, interleaved virtual stages, or DeepSeek's DualPipe, which feeds micro-batches from both ends of the pipeline and overlaps computation with communication.
  • FSDP/ZeRO is the outer wrapper for optimizer-state sharding — it gathers params just-in-time per layer (all-gather), then reduce-scatters gradients. This is the "sharded state with on-demand gather" the switcher knows.
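The bubble can be estimated with the standard 1F1B analysis from the Megatron-LM pipeline work: idle time relative to useful compute is (pp − 1)/m for m micro-batches per step, divided by v when each GPU holds v interleaved virtual stages. The sketch assumes that model and does not model DualPipe. Applied to the job spec (pp = 16, grad_accum = 8 micro-batches per pipeline), roughly two thirds of each step would be bubble, so a real plan at this depth needs many more micro-batches per step or an interleaved or DualPipe-style schedule.

```python
def bubble_fraction(pp: int, microbatches: int, virtual_stages: int = 1) -> float:
    """Share of a training step each pipeline stage sits idle (fill + drain).

    Idle time relative to useful compute is (pp - 1) / (virtual_stages * microbatches);
    virtual_stages = 1 is plain 1F1B/GPipe, > 1 is the interleaved 1F1B schedule.
    """
    idle_over_busy = (pp - 1) / (virtual_stages * microbatches)
    return idle_over_busy / (1 + idle_over_busy)


for m, v in [(8, 1), (16, 1), (16, 4), (64, 1)]:
    print(f"pp=16, microbatches={m:>2}, virtual={v}: {bubble_fraction(16, m, v):.0%} idle")
# -> 65%, 48%, 19%, 19%
```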

Communication overlap is what turns the plan into MFU. Backward runs from the last layer to the first, so while the GPU computes layer N's backward, NCCL is already all-gathering layer N−1's params (needed next) and reduce-scattering layer N+1's gradients (just produced) on a separate CUDA stream. Done right, nearly all DP/FSDP communication hides behind compute.
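A sketch of the overlap order for an FSDP backward pass. Backward visits layers from last to first, so while layer n computes, the communication stream can prefetch layer n−1's parameters and reduce-scatter layer n+1's freshly produced gradients. The function is hypothetical and simplified: it skips the forward pass and assumes the last layer's parameters are already gathered.

```python
def fsdp_backward_schedule(num_layers: int) -> list[tuple[str, str]]:
    """(compute, overlapped communication) pairs for one FSDP backward pass."""
    steps = []
    for n in reversed(range(num_layers)):
        comm = []
        if n > 0:
            comm.append(f"all_gather(params[{n - 1}])")      # prefetch the next layer backward needs
        if n + 1 < num_layers:
            comm.append(f"reduce_scatter(grads[{n + 1}])")   # gradients the previous op just produced
        steps.append((f"backward(layer {n})", ", ".join(comm) or "-"))
    # Layer 0's gradients exist only after the last backward op, so this tail is exposed.
    steps.append(("-", "reduce_scatter(grads[0])"))
    return steps


for compute, comm in fsdp_backward_schedule(3):
    print(f"{compute:<18} | {comm}")
```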

The trap to name out loud: “FSDP scales linearly to 16K GPUs” is false — the all-gather/reduce-scatter cost grows with world size and eventually the IB fabric saturates, so beyond a few thousand ranks you compose FSDP with TP+PP rather than scaling DP alone. The other trap is putting TP across nodes — instant MFU death. A 16,384-GPU plan looks like TP=8 (intra-node) × PP=16 × DP=128.

Deep dive 2 — Mixed precision recipe

This is the MLE’s lane, and I’d verify the exact thresholds with an ML specialist, but the structure is well established:

| Quantity | Precision | Rationale |
|---|---|---|
| Forward/backward compute | BF16 | Wide exponent → no loss-scaling fragility |
| Master weights | FP32 | Accumulation accuracy; the source of truth |
| Optimizer moments (m,v) | FP32 | Small-magnitude updates need range |
| GEMM inputs (optional) | FP8 (E4M3/E5M2) | ~2× throughput on Hopper tensor cores |
| LayerNorm, softmax, residuals | BF16/FP32 | Numerically sensitive; keep higher |

Why BF16 over FP16: FP16’s narrow exponent forces dynamic loss scaling to avoid gradient underflow — a fragile feedback loop that can blow up a run. BF16 has the same exponent range as FP32, so we skip loss scaling entirely and keep an FP32 master copy for accurate updates.
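A quick way to see the exponent-range difference with only the standard library: `struct`'s `e` format is IEEE half precision (FP16), and BF16 is approximated here by keeping the top 16 bits of a float32, which truncates rather than rounding to nearest even as hardware does. The helper names and the example gradient value are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision (5 exponent bits); overflow -> inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x: float) -> float:
    """Approximate bfloat16: keep the top 16 bits (sign, 8 exponent, 7 mantissa) of float32.

    Truncates toward zero; real hardware rounds to nearest even.
    """
    raw = struct.pack("<f", x)                 # little-endian, so the high half is raw[2:4]
    return struct.unpack("<f", b"\x00\x00" + raw[2:4])[0]


small_grad = 1e-8
print(to_fp16(small_grad), to_bf16(small_grad))   # FP16 underflows to 0.0; BF16 keeps ~1e-8
print(to_fp16(1e5), to_bf16(1e5))                 # FP16 overflows to inf; BF16 stays finite
```

FP16 loses the 1e-8 gradient entirely (its smallest subnormal is about 6e-8), which is exactly what loss scaling exists to prevent; BF16 keeps it to within about 0.5%, paying instead with only 7 mantissa bits.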

FP8 (the frontier move): DeepSeek-V3 demonstrated FP8 GEMMs in production for roughly 2× matmul throughput, using fine-grained (per-tile/per-block) scaling to keep the limited FP8 dynamic range from clipping, while keeping the sensitive ops (normalization, attention softmax, the master optimizer state) at higher precision. The judgment call — which ops can go FP8 without hurting convergence — is exactly what I’d confirm with ML, because it’s a quality risk, not a systems risk.

Deep dive 3 — Fault tolerance (the heart of the system)

At one failure every ~3 hours across 16K GPUs, the run is defined by how fast it heals.

Detection — multiple signals, fast:

| Signal | Catches |
|---|---|
| NCCL collective timeout | Hung/dead rank, network partition |
| ECC / HBM uncorrectable errors | Failing GPU memory |
| Watchdog heartbeat loss | Crashed node |
| Grad-norm / loss anomaly | Loss spike, divergence |
| Silent-data-corruption checks | Bad GEMM output, flaky GPU |

Silent data corruption is the scariest because the run keeps going while quietly poisoning the weights — periodic checksum/recompute spot-checks and per-rank loss outlier detection are the defense.
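A minimal sketch of the recompute spot-check, assuming every rank re-runs the same held-out batch with identical weights and deterministic kernels, so healthy outputs should match bit for bit. Ranks whose fingerprint disagrees with the majority become eviction candidates. Names are hypothetical, and a real check would hash tensors on the GPU rather than Python lists.

```python
import hashlib
from collections import Counter


def fingerprint(values: list[float]) -> str:
    """Bit-exact fingerprint of one rank's recomputed output (repr round-trips floats)."""
    return hashlib.sha256(",".join(map(repr, values)).encode()).hexdigest()


def suspect_ranks(outputs: dict[int, list[float]]) -> list[int]:
    """Ranks whose deterministic recompute disagrees with the majority fingerprint."""
    prints = {rank: fingerprint(vals) for rank, vals in outputs.items()}
    majority, _ = Counter(prints.values()).most_common(1)[0]
    return sorted(rank for rank, fp in prints.items() if fp != majority)


recomputed = {
    0: [0.5, -1.25],
    1: [0.5, -1.25],
    2: [0.5, -1.2500001],   # a flaky GPU returns a slightly different result
    3: [0.5, -1.25],
}
print(suspect_ranks(recomputed))   # [2]
```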

Recovery — in-place gang restart:

1. FT controller detects failed rank R (NCCL timeout / ECC)
2. Quarantine R's node; drain in-flight collectives
3. Pull a HOT SPARE from the warm standby pool
4. Re-map R's rank onto the spare (topology-preserving)
5. All ranks roll back to last good checkpoint (load from
   nearest tier that holds it: peer NVMe before Lustre)
6. Restore params + optimizer + RNG + dataloader position
7. Resume; idempotent — no tokens double-counted or skipped

Hot spares (a warm standby pool; the schedule request above reserves 256) make this minutes, not hours: no waiting on the scheduler to find and boot fresh hardware. The restart is in-place: only the failed node's ranks are re-mapped, and the parallelism plan is preserved, so the effective batch size is unchanged.
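A hypothetical sketch of the remap step: the failed node's logical ranks move onto one spare while the rank numbering (and therefore the parallelism plan and effective batch) stays fixed. `Fleet` and `swap_in_spare` are illustrative names, and "take the first spare" stands in for the topology-preserving choice a real controller would make.

```python
from dataclasses import dataclass, field


@dataclass
class Fleet:
    rank_to_node: dict[int, str]                        # logical rank -> physical node
    spares: list[str]                                   # warm standby nodes
    quarantined: set[str] = field(default_factory=set)


def swap_in_spare(fleet: Fleet, failed_node: str) -> str:
    """Move every rank on `failed_node` onto one hot spare; logical ranks never change."""
    if not fleet.spares:
        raise RuntimeError("spare pool exhausted: fall back to an elastic re-plan")
    spare = fleet.spares.pop(0)   # a real controller would pick a topology-matched spare
    fleet.quarantined.add(failed_node)
    for rank, node in list(fleet.rank_to_node.items()):
        if node == failed_node:
            fleet.rank_to_node[rank] = spare
    return spare


fleet = Fleet({rank: f"node{rank // 8}" for rank in range(16)}, spares=["spare0", "spare1"])
swap_in_spare(fleet, "node1")
print(sorted(r for r, n in fleet.rank_to_node.items() if n == "spare0"))
# -> [8, 9, 10, 11, 12, 13, 14, 15]
```

After the swap, every rank (not just the moved ones) reloads the last good checkpoint, as in steps 5–7 above.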

Async tiered checkpointing is what keeps the overhead invisible:

checkpoint at step N:
t0: snapshot state to pinned host RAM (~ms, blocks briefly)
t1: GPUs resume training immediately
t2: background: RAM → local NVMe (L0)
t3: background: NVMe → peer replica (L1) [fast-restore copy]
t4: background: → Lustre (L2) → object store (L3) [durable]

The GPU only stalls for the in-memory snapshot (sub-second); everything durable happens off the critical path. The peer-replica (L1) tier means a single-node failure usually restores from a neighbor’s NVMe rather than the parallel FS — seconds instead of minutes. This is precisely a write-back cache hierarchy with the GPUs as the write source.
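A minimal sketch of the t0–t4 timeline using threads and local files, assuming JSON-serializable state: `copy.deepcopy` plays the pinned-host-RAM snapshot, and the tier directories stand in for NVMe, the peer replica and Lustre. The class is hypothetical; a production version would copy GPU tensors to pinned memory, write shards in parallel, and surface background write failures.

```python
import copy
import json
import os
import tempfile
import threading
from typing import Optional


class AsyncCheckpointer:
    """Block only for an in-memory snapshot; write the tiers on a background thread."""

    def __init__(self, tier_dirs: list[str]):
        self.tier_dirs = tier_dirs                 # stand-ins for L0 NVMe, L1 peer, L2 FS, ...
        self._worker: Optional[threading.Thread] = None

    def save(self, step: int, state: dict) -> None:
        self.wait()                                # back-pressure: one tier-down at a time
        snapshot = copy.deepcopy(state)            # t0: the only part that stalls training
        self._worker = threading.Thread(target=self._tier_down, args=(step, snapshot))
        self._worker.start()                       # t1: caller returns and keeps training

    def _tier_down(self, step: int, snapshot: dict) -> None:
        payload = json.dumps({"step": step, "state": snapshot})
        for tier in self.tier_dirs:                # t2..t4: nearest tier first
            path = os.path.join(tier, f"step_{step:08d}.json")
            with open(path + ".tmp", "w") as f:
                f.write(payload)
            os.replace(path + ".tmp", path)        # atomic rename: no torn checkpoints

    def wait(self) -> None:
        if self._worker is not None:
            self._worker.join()
            self._worker = None


with tempfile.TemporaryDirectory() as root:
    tiers = [os.path.join(root, name) for name in ("l0_nvme", "l1_peer", "l2_lustre")]
    for tier in tiers:
        os.makedirs(tier)
    ckpt = AsyncCheckpointer(tiers)
    live = {"w": [1.0]}
    ckpt.save(250, live)
    live["w"][0] = 2.0                             # training mutates state immediately
    ckpt.wait()
    with open(os.path.join(tiers[-1], "step_00000250.json")) as f:
        restored = json.load(f)

print(restored["state"]["w"])   # [1.0]: the snapshot, not the mutated live state
```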

Deep dive 4 — MoE expert parallelism

Frontier models are increasingly sparse. DeepSeek-V3 is 671B total parameters but only ~37B active per token — you get a huge parameter count at a fraction of the FLOPs.

The systems cost of MoE is the routing all-to-all: every token is dispatched to its top-k experts (which live on different GPUs under expert parallelism), then results are combined back. That's two all-to-all collectives per MoE layer in the forward pass (dispatch and combine), mirrored in the backward: structured permutation traffic the switcher can map onto a distributed shuffle.

| MoE challenge | Mitigation |
|---|---|
| all-to-all dominates step time | Overlap dispatch/combine with compute (DualPipe) |
| Expert load imbalance | Auxiliary load-balancing loss / bias; capacity factor |
| Stragglers from hot experts | Token dropping or routing bias to even out load |
| Memory: all experts resident | Expert parallelism shards experts across GPUs |

The trap here: treating MoE as “free” capacity. The routing is dynamic and data-dependent, so load can skew badly within a step — a single overloaded expert stalls the synchronous barrier for everyone. Load balancing is a convergence-and-systems coupling: too aggressive a balancing loss hurts quality, too loose and you get stragglers. Another thing to co-design with ML.
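A sketch of the capacity-factor mitigation, simplified to top-1 routing (the text describes top-k). Each expert accepts at most ceil(capacity_factor × tokens / experts) tokens per batch, and overflow tokens are dropped, meaning they skip the expert and pass through the residual connection. `apply_capacity` is a hypothetical helper.

```python
import math
from collections import Counter


def apply_capacity(expert_of_token: list[int], num_experts: int, capacity_factor: float):
    """Enforce a per-expert capacity on top-1 routing decisions.

    Returns (capacity, per-expert load, indices of dropped tokens).
    """
    capacity = math.ceil(capacity_factor * len(expert_of_token) / num_experts)
    load: Counter = Counter()
    dropped = []
    for token, expert in enumerate(expert_of_token):
        if load[expert] < capacity:
            load[expert] += 1
        else:
            dropped.append(token)   # overflow: bypasses the expert via the residual path
    return capacity, dict(load), dropped


# 8 tokens, 4 experts, capacity factor 1.25 -> capacity ceil(2.5) = 3 tokens per expert.
routing = [0, 0, 0, 0, 0, 1, 2, 3]   # expert 0 is "hot"
print(apply_capacity(routing, num_experts=4, capacity_factor=1.25))
# -> (3, {0: 3, 1: 1, 2: 1, 3: 1}, [3, 4])
```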

7

Multi-team rollout

You don’t “deploy” a multi-week run — you nurse it. The operational discipline is half the battle.

Pre-launch burn-in

  • NCCL bandwidth tests across every rail and node pair — confirm each link hits ~400G; a single degraded cable silently halves a collective.
  • Straggler screening — run a synthetic step and evict any GPU more than a few % slower; in a synchronous barrier, the slowest rank taxes all 16K.
  • Numerical smoke test — a few hundred steps at small scale to confirm loss decreases and grad-norm is sane before committing the fleet.
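A sketch of the straggler screen, assuming per-rank timings from one synthetic step and a 3% tolerance (an assumed reading of "a few %"); `find_stragglers` is a hypothetical helper.

```python
from statistics import median


def find_stragglers(step_time_s: dict[int, float], tolerance: float = 0.03) -> list[int]:
    """Ranks slower than the fleet median by more than `tolerance` (0.03 = 3%)."""
    baseline = median(step_time_s.values())
    return sorted(rank for rank, t in step_time_s.items() if t > baseline * (1 + tolerance))


times = {0: 1.00, 1: 1.01, 2: 0.99, 3: 1.10}   # synthetic-step seconds per rank
print(find_stragglers(times))                    # [3]
# A synchronous step runs at the slowest rank's pace: compare max(times) with the median.
print(f"fleet slowdown: {max(times.values()) / median(times.values()) - 1:.1%}")
```

Because the step is a synchronous barrier, fleet step time is the maximum, not the median: in this example one slow rank costs every rank about 9.5%.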

Observability

| Dashboard | Watch for |
|---|---|
| Per-rank MFU / throughput | Stragglers, fabric degradation |
| Loss + grad-norm curves | Spikes, divergence, NaN/Inf |
| NCCL flight recorder | Where a hung collective stalled |
| Checkpoint timing | Tiering falling behind |
| GPU temp / ECC counts | Imminent hardware failure |

The NCCL flight recorder is the single best debugging tool: when a collective hangs, it shows exactly which rank was waiting on which op, turning a cluster-wide freeze into a pinpointed node.

Divergence playbooks

  • Loss spike: if grad-norm exceeds threshold, gradient clipping should absorb it; if loss diverges, roll back to a pre-spike checkpoint, optionally skip the offending data batch, and resume — sometimes with a briefly lowered LR. Loss spikes are often a single bad data shard or a numerical edge case, not a fundamental problem.
  • Silent corruption suspected: checkpoint, run a deterministic recompute on a held-out batch across ranks, diff outputs to find the flaky GPU, evict it.
  • Slow creep in step time: suspect a degrading link or thermal throttle; cross-reference per-rank MFU.

Reproducibility & the human loop

Deterministic resume (fixed seeds + captured RNG + exact dataloader position) means a resumed run can be bit-identical to one that never failed, provided the kernels and the collective reduction order are deterministic as well; that is essential for debugging and for trusting the token budget. A human on-call rotation watches the dashboards around the clock; for a run worth millions of dollars, an engineer is always one page away from the loss curve.

8

Bottlenecks & evolution

Where it breaks next

| Scale | Binding constraint | Response |
|---|---|---|
| ~16K GPUs | Failure every ~3h | Async checkpoint + hot spares (today's design) |
| ~100K GPUs | MTBF falls to tens of min | Partial/elastic recovery, no full stop-the-world |
| ~100K GPUs | Network becomes binding | In-network reduction (SHARP), better overlap |
| Beyond | Single-DC power/space limit | Multi-datacenter + async training |

At 100K+ GPUs the MTBF drops to tens of minutes, and without strong fault tolerance the effective training time falls below 50% — you’d spend more than half the cluster’s life recovering. The fixes: elastic/partial recovery (heal the affected gang without a global rollback), in-network reduction like NVIDIA SHARP (the switch fabric performs part of the all-reduce, halving collective traffic), and ever-tighter compute/comm overlap.

Evolution

1. Lower precision: FP8 today on Hopper; FP4 on Blackwell for another throughput step — gated by convergence validation.

2. Bigger sparsity: larger MoE with more experts, longer context (more context parallelism).

3. Multi-datacenter / async training: when one building can’t power 100K+ GPUs, train across DCs with relaxed synchronization.

4. Post-training reuse: the same fleet, checkpoint store, and fault tolerance run RLHF and other post-training — the infra amortizes across the model lifecycle.

The cost/MFU frontier

DeepSeek-V3 is the reference for algorithm–framework–hardware co-design: roughly 2.79M H800-hours (~$5.5M) for a frontier-class model, a figure DeepSeek reports for the final training run only, excluding prior research and ablation experiments. It was achieved via FP8 training, DualPipe overlap, and MoE, evidence that MFU and cost are won by co-designing the model, the framework, and the hardware together, not by throwing more GPUs at a fixed recipe.

Summary

1. This is a distributed-systems reliability problem wearing an ML hat. Lead with the failure model and the network fabric, not the math of attention.

2. Hybrid nD parallelism that localizes communication is the central win. TP intra-node on NVLink (all-reduce every layer), PP/DP/FSDP inter-node on IB, with comm overlapped behind compute. Never put TP across nodes; never assume FSDP scales linearly to 16K GPUs.

3. Beat the memory wall with sharding plus mixed precision. ~16 B/param of Adam state must be sharded; BF16 compute + FP32 master avoids loss-scaling fragility; FP8 GEMMs buy ~2× throughput when applied carefully.

4. Fault tolerance keeps goodput high. Detection (NCCL timeouts, ECC, SDC checks) → hot-spare swap → in-place gang restart from an async, tiered, write-back checkpoint store, with deterministic idempotent resume. Checkpoint every 15–30 min per Young/Daly.

5. A data pipeline that never starves the GPUs — dedup, filter, tokenize into sharded mmap files, streamed deterministically per rank.

6. MFU is the metric that ties it all together — frame it as goodput: realized FLOPs over peak, eroded by stalls, stragglers, bubbles, and recovery.

7. The two things to verify with an ML expert: the exact precision recipe (which ops can drop to FP8/FP4) and the convergence implications of any parallelism reshape (because it changes effective global batch size and interacts with the LR schedule). Naming these boundaries is the credibility move, especially for an SDE moving into AI.

Rubric — Senior vs Staff

| Dimension | Senior signal | Staff signal |
|---|---|---|
| Parallelism plan | Names data and model parallelism. | Designs a hybrid nD plan (TP intra-node, PP inter-node, FSDP/ZeRO outer, CP for long sequences) chosen to fit HBM and sustain DP width. |
| Memory wall | Knows the model must be split across GPUs. | Sizes training state (~16–18 B/param for weights, grads, and Adam), shards it (ZeRO-1/2/3), and adds activation recomputation to trade compute for memory. |
| Communication & network | Mentions all-reduce gradient sync. | Keeps TP on NVLink, PP across InfiniBand, overlaps comm with compute (DualPipe/1F1B), and knows full sharding (FSDP/ZeRO-3) moves ~1.5× the comm volume of DDP. |
| Mixed precision | Uses FP16/BF16 to save memory. | BF16 compute + FP32 master, FP8 GEMMs with per-tensor or, as in DeepSeek-V3, fine-grained per-block scaling, high-precision sensitive ops; never naive FP8 everywhere. |
| Fault tolerance | Checkpoints periodically. | Async tiered checkpoint (GPU→CPU→NVMe→object store), Young/Daly interval, hot spares, fast in-place gang restart on the affected ranks. |
| Scheduling | Allocates GPUs for the job. | Gang (all-or-nothing) topology-aware scheduling so TP groups land in an NVLink domain and DP groups respect the fat-tree. |
| MFU & data pipeline | Tracks throughput. | Targets ~35–45% MFU as goodput, never starves GPUs (dedup/tokenize/sharded streaming), and quotes sustained not peak FLOPs. |