Distributed Training (PyTorch DDP) ⚡

Class 12 · Age 16-17 · Lesson 03 of 12 · Free

Story
Ishaan's 14-Hour to 2-Hour Win
👨‍💻 Ishaan · Bengaluru · Age 17

Ishaan was training an IndicBERT sentiment classifier on 2M Hindi+Telugu+Tamil tweets. On a single A100 GPU, one epoch took 4.7 hours. With 3 epochs needed, the experiment took 14 hours, and a single typo cost him a day.

His college lab had 8× A100s in one node. Using PyTorch DDP (DistributedDataParallel), he scaled to all 8 GPUs with about 12 lines of code changes. Training dropped to 2 hours. Then he learned about DeepSpeed ZeRO, which handles larger models that don't fit on one GPU at all.

Why DDP
Data Parallel vs Model Parallel

Data Parallel (DDP)

Each GPU has a full model copy + a different data shard. Gradients averaged via all-reduce. Use when model fits on 1 GPU.

Model Parallel

Model split across GPUs. Each GPU runs different layers. Use when model is too big for 1 GPU.

Tensor Parallel

Each layer's weight matrix split across GPUs. Used inside huge LLMs (Megatron-LM).

Pipeline Parallel

Layers in stages, micro-batches flow through. Reduces idle time vs naive model parallel.

ZeRO (DeepSpeed)

Shards optimiser state, gradients, and parameters across GPUs. Enables huge models with little extra code.

FSDP (PyTorch)

PyTorch's native ZeRO-equivalent. Recommended for new projects after PyTorch 2.0.

Ishaan's IndicBERT model has 280M parameters, so it fits easily on one A100 (40GB). DDP is therefore the right choice: close to an 8× speedup (about 7× in Ishaan's run, 14 hours down to 2) with little engineering complexity.
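
As a quick sanity check on the timings in Ishaan's story, the sketch below computes the speedup and scaling efficiency. The function name `scaling_efficiency` is illustrative and not part of any library.

```python
def scaling_efficiency(single_gpu_hours: float, multi_gpu_hours: float, num_gpus: int):
    """Return (speedup, efficiency) for a multi-GPU run versus one GPU."""
    speedup = single_gpu_hours / multi_gpu_hours
    efficiency = speedup / num_gpus  # 1.0 would be perfect linear scaling
    return speedup, efficiency


# Timings from the story: 3 epochs x 4.7 h on one A100, about 2 h on 8 GPUs.
speedup, efficiency = scaling_efficiency(3 * 4.7, 2.0, 8)
print(f"speedup: {speedup:.2f}x, efficiency: {efficiency:.0%}")
```

An efficiency a little below 100% is expected, because gradient synchronisation and data loading add some overhead on top of the pure compute.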

Code
DDP in 12 Lines

Single-GPU training looks like this (simplified):

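# IndicBERTClassifier and dataset come from the rest of the project.
model = IndicBERTClassifier().cuda()
loader = DataLoader(dataset, batch_size=64, shuffle=True)
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
for batch in loader:
    batch = {k: v.cuda() for k, v in batch.items()}  # move inputs to the GPU
    loss = model(**batch).loss
    loss.backward()
    opt.step()
    opt.zero_grad()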

The DDP version needs only about a dozen changed or added lines:

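import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def train(local_rank: int, world_size: int):
    # Single-node setup: the local rank is also the global rank.
    # torchrun supplies MASTER_ADDR and MASTER_PORT through the environment.
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    model = IndicBERTClassifier().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # The sampler replaces shuffle=True and gives each process its own shard.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)
    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # so shuffling differs per epoch
        for batch in loader:
            batch = {k: v.cuda(local_rank) for k, v in batch.items()}
            loss = model(**batch).loss
            loss.backward()  # DDP all-reduces gradients during this call
            opt.step()
            opt.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    # torchrun exports LOCAL_RANK and WORLD_SIZE to every worker process.
    train(int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]))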
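
To see why `sampler.set_epoch(epoch)` matters, here is a simplified pure-Python imitation of how a distributed sampler might split indices. It is a sketch, not PyTorch's implementation: `shard_indices` is a hypothetical helper, and it assumes the dataset size divides evenly by the number of GPUs (the real `DistributedSampler` pads or drops samples otherwise).

```python
import random


def shard_indices(dataset_len: int, world_size: int, rank: int, epoch: int, seed: int = 0):
    """Return the dataset indices one rank would process in a given epoch."""
    indices = list(range(dataset_len))
    # Every rank seeds the shuffle identically, so all ranks agree on the order.
    random.Random(seed + epoch).shuffle(indices)
    # Each rank then takes every world_size-th index, starting at its own rank.
    return indices[rank::world_size]


# Toy setup: 16 samples on 4 GPUs.
epoch0 = [shard_indices(16, 4, rank, epoch=0) for rank in range(4)]
epoch1 = [shard_indices(16, 4, rank, epoch=1) for rank in range(4)]

# The shards are disjoint and together cover the whole dataset.
assert sorted(i for shard in epoch0 for i in shard) == list(range(16))
# Changing the epoch changes the shuffle, which is what set_epoch() triggers.
print("rank 0, epoch 0:", epoch0[0])
print("rank 0, epoch 1:", epoch1[0])
```

If the epoch were never updated, every call would use the same seed and each rank would see its samples in the same order every epoch.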

Launch with torchrun:

# 1 node, 8 GPUs
torchrun --nproc_per_node=8 train_ddp.py
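
`torchrun` passes each worker its identity through environment variables such as `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. The sketch below shows one way a script might read them. `read_dist_env` is a hypothetical helper, and the single-process fallback values are an assumption that keeps a plain `python train_ddp.py` run usable for debugging.

```python
import os


def read_dist_env(environ=os.environ):
    """Read the rank information that torchrun exports to each worker."""
    return {
        # Global rank across all nodes (0 .. WORLD_SIZE-1).
        "rank": int(environ.get("RANK", 0)),
        # Rank within this node; usually used to pick the GPU.
        "local_rank": int(environ.get("LOCAL_RANK", 0)),
        # Total number of worker processes in the job.
        "world_size": int(environ.get("WORLD_SIZE", 1)),
    }


# Simulate what the sixth worker of an 8-GPU, single-node launch would see.
fake_env = {"RANK": "5", "LOCAL_RANK": "5", "WORLD_SIZE": "8"}
print(read_dist_env(fake_env))
# Without torchrun the helper falls back to a single-process configuration.
print(read_dist_env({}))
```
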
Theory
What Happens Inside DDP
  1. Replicate: Each GPU starts with an identical copy of the model.
  2. Shard data: DistributedSampler gives each GPU a different subset of the dataset, so at every step the 8 GPUs process 8 different mini-batches.
  3. Forward + backward locally: Each GPU computes its own gradients on its own data shard.
  4. All-reduce gradients: During loss.backward(), NCCL averages every gradient across all 8 GPUs, so by the time opt.step() runs, every GPU holds identical gradients.
  5. Step optimiser: Each GPU applies the same optimiser update with the same gradients, so all model copies stay in sync without an explicit broadcast.
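
The five steps above can be imitated on the CPU without any GPUs. The sketch below is a toy simulation, not DDP itself: it uses a one-parameter linear model `y = w * x`, plain Python floats in place of tensors, and a hypothetical `all_reduce_mean` function standing in for NCCL. It checks that averaging per-shard gradients gives the same gradient as one pass over the full batch, assuming equal-sized shards.

```python
def grad_mse(w: float, xs, ys) -> float:
    """Gradient of mean squared error for the model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for NCCL all-reduce: every replica receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


world_size = 4
xs = [float(i) for i in range(1, 9)]   # 8 samples
ys = [3.0 * x for x in xs]             # true weight is 3.0
replicas = [0.0] * world_size          # step 1: identical starting weights
lr = 0.01

for step in range(200):
    # Step 2: each replica takes a disjoint, equal-sized shard of the data.
    shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
    # Step 3: local backward pass on each replica's own shard.
    local_grads = [grad_mse(w, sx, sy) for w, (sx, sy) in zip(replicas, shards)]
    # Step 4: all-reduce leaves every replica with the same averaged gradient.
    synced = all_reduce_mean(local_grads)
    # Step 5: identical updates keep the replicas in sync.
    replicas = [w - lr * g for w, g in zip(replicas, synced)]

# With equal shards, the averaged gradient matches the full-batch gradient.
w0 = 1.0
shard_grads = [grad_mse(w0, xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
assert abs(sum(shard_grads) / world_size - grad_mse(w0, xs, ys)) < 1e-9
assert len(set(replicas)) == 1  # all replicas hold exactly the same weight
print("learned weight:", round(replicas[0], 4))
```

Because every replica starts from the same weight and applies the same averaged gradient, no weight broadcast is needed after the first step.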

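The key trick is that all-reduce runs during the backward pass: as soon as a bucket of gradients is ready, DDP starts reducing it while autograd keeps computing the remaining gradients. Most of the communication is hidden behind compute, so the extra wall time is mainly the communication that cannot be overlapped, plus any waiting for the slowest GPU at each synchronisation point.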

Effective batch size: With 8 GPUs × batch_size=64, your effective batch is 512. That changes the optimal learning rate. Linear scaling rule: multiply LR by world_size (so 2e-5 → 1.6e-4), but warm up over the first epoch to avoid divergence. The rule is a heuristic, so confirm the scaled rate on a short run before committing to a full one.
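
A minimal sketch of the linear scaling rule with a linear warm-up, assuming the warm-up ramps from near zero to the scaled rate over a fixed number of steps and then holds it. The helper names `scaled_lr` and `warmup_lr` are hypothetical; in a real run the same shape could be handed to a scheduler such as `torch.optim.lr_scheduler.LambdaLR`.

```python
def scaled_lr(base_lr: float, world_size: int) -> float:
    """Linear scaling rule: the learning rate grows with the effective batch."""
    return base_lr * world_size


def warmup_lr(step: int, warmup_steps: int, peak_lr: float) -> float:
    """Ramp linearly up to peak_lr over warmup_steps, then hold it constant."""
    if step >= warmup_steps:
        return peak_lr
    return peak_lr * (step + 1) / warmup_steps


world_size, per_gpu_batch = 8, 64
peak = scaled_lr(2e-5, world_size)
# One epoch of 2M samples at an effective batch of 512 is about 3,906 steps.
steps_per_epoch = 2_000_000 // (per_gpu_batch * world_size)

print("effective batch:", per_gpu_batch * world_size)
print("peak lr:", peak)
for step in (0, steps_per_epoch // 2, steps_per_epoch):
    print(step, warmup_lr(step, steps_per_epoch, peak))
```
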
Beyond DDP
When You Need DeepSpeed ZeRO or FSDP

If Ishaan wanted to fine-tune Llama-3-70B (140GB in fp16) on 8× A100s, no single GPU could hold the full model. ZeRO-3 / FSDP shards the model itself:

Stage  | Shards           | VRAM Saving           | Comm Overhead
ZeRO-1 | Optimiser state  | ~4× (Adam)            | None vs DDP
ZeRO-2 | + Gradients      | ~8×                   | Slight
ZeRO-3 | + Parameters     | ~N× (linear in GPUs)  | ~50% more
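
The savings in the table can be estimated with the accounting used in the ZeRO paper, which assumes mixed-precision Adam at roughly 16 bytes per parameter: 2 for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimiser state. The sketch below applies that assumption. `zero_bytes_per_gpu` is a hypothetical helper, and activations, buffers, and fragmentation are ignored.

```python
def zero_bytes_per_gpu(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate model-state bytes per GPU for ZeRO stage 0 (plain DDP) to 3."""
    params, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter (assumed)
    if stage >= 1:
        optim /= num_gpus   # ZeRO-1 shards optimiser state
    if stage >= 2:
        grads /= num_gpus   # ZeRO-2 also shards gradients
    if stage >= 3:
        params /= num_gpus  # ZeRO-3 also shards parameters
    return num_params * (params + grads + optim)


GB = 1e9
for stage in range(4):
    small = zero_bytes_per_gpu(280e6, 8, stage) / GB
    print(f"stage {stage}: 280M-parameter model needs about {small:.2f} GB per GPU")
```

With 8 GPUs the stage 1 and stage 2 savings come out below the ~4× and ~8× in the table, because those figures are the limits reached as the GPU count grows.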

The FSDP API is similar. It is native to PyTorch and recommended for new code:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
# FULL_SHARD shards parameters, gradients, and optimiser state (ZeRO-3 style).
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
Indian student tip: Most college labs have one node with 4 to 8 GPUs. DDP is enough for almost everything you'll do as a student. Save FSDP/ZeRO for when you're training a 70B+ model, usually after you join an industry team.

📝 Check Your Understanding (8 Questions)

1. Why does Ishaan choose DistributedDataParallel (DDP) over model parallelism for his IndicBERT classifier?
a) Model parallelism is deprecated in PyTorch 2.0
b) His 280M-parameter model fits comfortably on one A100 (40GB), so the simpler and faster approach is to replicate it on each GPU and shard the data, which is exactly what DDP does
c) DDP supports BERT-style models while model parallel does not
d) Model parallelism requires Kubernetes which is unavailable in his lab
2. What does the all-reduce step in DDP accomplish?
a) It synchronises model weights across GPUs by broadcasting from rank 0
b) It averages gradients across all GPUs so that, after the optimiser step, every model replica applies the same update and stays identical without explicit weight broadcast
c) It reduces the precision of weights from fp32 to fp16 to save memory
d) It removes duplicate samples from each GPU's data shard
3. Why must Ishaan call sampler.set_epoch(epoch) every epoch?
a) It is required to log the current epoch to the distributed logger
b) DistributedSampler uses the epoch number as a shuffle seed; without it, every epoch sees the same shard order, hurting generalisation
c) It tells NCCL to switch communication backends between epochs
d) It resets gradient accumulation buffers between epochs
4. What is the linear scaling rule for learning rate when going from 1 GPU to 8 GPUs?
a) Divide the learning rate by the number of GPUs
b) Multiply the learning rate by the number of GPUs (world size), because effective batch size grew by that factor, but combine with a warmup to prevent early divergence
c) Keep the learning rate the same โ€” DDP handles scaling internally
d) Square the learning rate to compensate for parallel noise reduction
5. When would Ishaan need to use ZeRO-3 (or FSDP) instead of plain DDP?
a) When he wants to train faster on the same model
b) When the model itself is too large to fit on a single GPU: ZeRO-3 shards parameters across GPUs, enabling models like Llama-3-70B on 8× A100s
c) When training data exceeds 1 TB
d) When the model uses attention layers
6. Why is the all-reduce communication overlapped with the backward pass?
a) It is mandated by the NCCL protocol
b) To hide communication latency behind useful compute: as soon as a bucket of gradients is ready, NCCL begins reducing it while the backward pass continues through the remaining layers, so only the communication that cannot be overlapped adds to wall time
c) It is a side effect of how PyTorch implements autograd
d) It compresses gradients into smaller tensors during the backward
7. What does torchrun --nproc_per_node=8 do?
a) It runs the script 8 times sequentially with different random seeds
b) It launches 8 worker processes on the current node, one per GPU, each receiving its own LOCAL_RANK environment variable so the training script can call dist.init_process_group correctly
c) It splits the script into 8 sub-processes that share a single GPU via time-slicing
d) It spawns 8 Docker containers, one per GPU
8. What is the most important practical lesson from Ishaan's DDP work?
a) Always use the maximum number of GPUs available regardless of model size
b) DDP gives near-linear speedup with minimal code change when the model fits on one GPU; choose the simplest distributed strategy that solves your actual bottleneck, and don't over-engineer with FSDP/ZeRO until you need them
c) DDP requires rewriting the entire training loop in JAX
d) DDP is only useful for models with more than 1 billion parameters