Ishaan was training an IndicBERT sentiment classifier on 2M Hindi+Telugu+Tamil tweets. On a single A100 GPU, one epoch took 4.7 hours. With 3 epochs needed, the experiment took 14 hours, and a single typo cost him a day.
His college lab had 8× A100s in one node. Using PyTorch DDP (DistributedDataParallel), he scaled to all 8 GPUs by changing about 12 lines of code. Training dropped to roughly 2 hours. He then learned DeepSpeed ZeRO, which handles models too large to fit on a single GPU.
Data Parallel (DDP)
Each GPU has a full model copy + a different data shard. Gradients averaged via all-reduce. Use when model fits on 1 GPU.
Model Parallel
Model split across GPUs. Each GPU runs different layers. Use when model is too big for 1 GPU.
Tensor Parallel
Each layer's weight matrix split across GPUs. Used inside huge LLMs (Megatron-LM).
Pipeline Parallel
Layers in stages, micro-batches flow through. Reduces idle time vs naive model parallel.
ZeRO (DeepSpeed)
Shards optimiser state, gradients, and parameters across GPUs. Enables huge models with little extra code.
FSDP (PyTorch)
PyTorch's native ZeRO-equivalent. Recommended for new projects after PyTorch 2.0.
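The claim that pipeline parallelism reduces idle time can be made concrete with the usual bubble estimate for a GPipe-style schedule: with p stages and m micro-batches, roughly (p - 1) / (m + p - 1) of each step is spent idle. The sketch below is only an illustration; the function name and the choice of 4 stages and 16 micro-batches are assumptions, not values from this article.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule: (p - 1) / (m + p - 1)."""
    if num_stages < 1 or num_microbatches < 1:
        raise ValueError("num_stages and num_microbatches must be >= 1")
    return (num_stages - 1) / (num_microbatches + num_stages - 1)

# Hypothetical setup: 4 pipeline stages.
naive = bubble_fraction(num_stages=4, num_microbatches=1)    # whole batch at once
piped = bubble_fraction(num_stages=4, num_microbatches=16)   # 16 micro-batches

print(f"naive model parallel idle: {naive:.0%}")
print(f"pipeline parallel idle:    {piped:.0%}")
```

With a single micro-batch the formula reduces to naive model parallelism, where only one stage works at a time.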
Ishaan's IndicBERT model has 280M parameters, so it fits easily on one A100 (40GB). DDP is therefore the right choice: a near-8× speedup with very little engineering complexity.
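A rough way to check the "fits on one GPU" claim is to estimate the memory taken by weights, gradients, and optimiser state. The sketch below assumes the common mixed-precision Adam accounting of 16 bytes per parameter, which is an assumption and not a figure from this article, and it ignores activations, which depend on batch size and sequence length.

```python
def training_state_gb(num_params: float, bytes_per_param: int = 16) -> float:
    """Weights + gradients + Adam state in GB (10^9 bytes); activations excluded."""
    return num_params * bytes_per_param / 1e9

# Assumed accounting, 16 bytes per parameter:
# 2 (fp16 weights) + 2 (fp16 grads) + 12 (fp32 weights, momentum, variance).
state_gb = training_state_gb(280e6)
a100_gb = 40

print(f"model state: {state_gb:.2f} GB of {a100_gb} GB")
print("fits on one GPU (before activations):", state_gb < a100_gb)
```

Even with generous room for activations, the model state is a small fraction of the card, which is why no model sharding is needed here.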
Single-GPU training looks like this (simplified):
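import torch
from torch.utils.data import DataLoader

model = IndicBERTClassifier().cuda()
loader = DataLoader(dataset, batch_size=64, shuffle=True)
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
for batch in loader:
    batch = {k: v.cuda() for k, v in batch.items()}  # move inputs to the GPU
    loss = model(**batch).loss
    loss.backward()
    opt.step()
    opt.zero_grad()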
The DDP version changes only about a dozen lines:
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# IndicBERTClassifier and dataset are defined elsewhere, as in the single-GPU version.
def train(local_rank: int, world_size: int):
    # On a single node the global rank equals the local rank.
    dist.init_process_group("nccl", rank=local_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    model = IndicBERTClassifier().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])  # all-reduces gradients during backward()

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)  # 64 per GPU
    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # so shuffling differs per epoch
        for batch in loader:
            batch = {k: v.cuda(local_rank) for k, v in batch.items()}
            loss = model(**batch).loss
            loss.backward()
            opt.step()
            opt.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    # torchrun sets LOCAL_RANK and WORLD_SIZE for every process it launches.
    train(int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]))
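To see what the sampler and the set_epoch call in the listing above are doing, here is a simplified pure-Python sketch of the sharding idea: every rank builds the same epoch-seeded permutation, pads it so it divides evenly, and takes a strided slice. The function name shard_indices and the 100-sample dataset are hypothetical, and the real DistributedSampler uses PyTorch's own random generator, so the actual index order will differ.

```python
import math
import random

def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0):
    """Return the sample indices one rank would see in a given epoch."""
    indices = list(range(dataset_len))
    random.Random(seed + epoch).shuffle(indices)   # same permutation on every rank
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices += indices[: total - dataset_len]      # pad so every rank gets equal work
    return indices[rank:total:num_replicas]        # strided slice for this rank

world_size = 8
epoch0 = [shard_indices(100, world_size, r, epoch=0) for r in range(world_size)]
epoch1 = [shard_indices(100, world_size, r, epoch=1) for r in range(world_size)]

covered = set().union(*epoch0)
print(len(epoch0[0]), "samples per rank;", len(covered), "distinct samples covered")
print("rank 0 order changes between epochs:", epoch0[0] != epoch1[0])
```

Without the epoch in the seed, every epoch would replay the same permutation, which is why forgetting set_epoch silently removes shuffling between epochs.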
Launch with torchrun:
# 1 node, 8 GPUs
torchrun --nproc_per_node=8 train_ddp.py
- Replicate: Each GPU starts with an identical copy of the model.
- Shard data: DistributedSampler gives each GPU a different, non-overlapping subset of the dataset, so at every step the 8 GPUs process 8 different mini-batches.
- Forward + backward locally: Each GPU computes its own gradients on its own data shard.
- All-reduce gradients: Before opt.step(), NCCL averages every gradient across all 8 GPUs. After this, every GPU has identical gradients.
- Step optimiser: Each GPU applies the same optimiser update with the same gradients, so all model copies stay in sync without an explicit broadcast.
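The five steps above can be imitated in a few lines of plain Python. This is a toy simulation under stated assumptions, not how DDP is implemented: the "model" is a single weight w fitting y = w * x, the "GPUs" are list entries, and all_reduce_mean is a hypothetical stand-in for the NCCL collective.

```python
def local_gradient(w, shard):
    """d/dw of mean squared error for the model y = w * x on one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for NCCL all-reduce: every rank receives the same average."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

world_size, lr = 4, 0.01
data = [(float(x), 3.0 * x) for x in range(1, 9)]          # true weight is 3.0
shards = [data[r::world_size] for r in range(world_size)]  # shard data
weights = [0.0] * world_size                               # replicate

for step in range(200):
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]  # local backward
    grads = all_reduce_mean(grads)                                   # all-reduce
    weights = [w - lr * g for w, g in zip(weights, grads)]           # step optimiser

print("replicas in sync:", len(set(weights)) == 1)
print("learned weight:", round(weights[0], 3))
```

Because every replica starts from the same weight and applies the same averaged gradient, the copies never diverge, which is the property that lets DDP skip broadcasting weights after each step.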
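The key trick is that DDP starts all-reducing gradients during the backward pass, as soon as each bucket of gradients is ready, so communication overlaps with compute. The extra wall time is then only the communication that cannot be hidden, which is limited by the slowest link between GPUs.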
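To get a feel for how much traffic has to be hidden behind the backward pass, a back-of-the-envelope estimate helps. The sketch below uses the standard ring all-reduce cost, where each of N GPUs sends 2 * (N - 1) / N times the payload. It assumes fp32 gradients (4 bytes each) and a ring algorithm; NCCL may choose a different algorithm, so treat the result as an order-of-magnitude figure.

```python
def ring_allreduce_bytes_per_gpu(num_params: float, world_size: int,
                                 bytes_per_grad: int = 4) -> float:
    """Bytes each GPU sends in one ring all-reduce: 2 * (N - 1) / N * payload."""
    payload = num_params * bytes_per_grad
    return 2 * (world_size - 1) / world_size * payload

# 280M parameters on 8 GPUs, assuming fp32 gradients.
sent = ring_allreduce_bytes_per_gpu(280e6, world_size=8)
print(f"each GPU sends about {sent / 1e9:.2f} GB per step")
```

Dividing this volume by the measured inter-GPU bandwidth gives a lower bound on communication time per step, which can then be compared with the backward-pass time to judge how much overlap is possible.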
If Ishaan wanted to fine-tune Llama-3-70B (140GB in fp16) on 8× A100s, no single GPU could hold the full model. ZeRO-3 / FSDP shards the model itself:
| Stage | Shards | VRAM Saving | Comm Overhead |
|---|---|---|---|
| ZeRO-1 | Optimiser state | ~4× (Adam) | None vs DDP |
| ZeRO-2 | + Gradients | ~8× | Slight |
| ZeRO-3 | + Parameters | ~N× (linear in GPUs) | ~50% more |
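The savings in the table can be reproduced with simple per-GPU accounting. The sketch below assumes mixed-precision Adam (2 bytes per parameter for fp16 weights, 2 for fp16 gradients, 12 for optimiser state) and ignores activations and buffers. The function name and the 7.5B-parameter, 64-GPU example are illustrative assumptions, not measurements from this article.

```python
def zero_state_gb(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU memory for weights + grads + optimiser state, in GB (10^9 bytes).

    Assumes 2 bytes/param for fp16 weights, 2 for fp16 gradients, and
    k bytes/param of optimiser state (k = 12 for mixed-precision Adam).
    """
    weights, grads, opt_state = 2 * num_params, 2 * num_params, k * num_params
    if stage >= 1:
        opt_state /= num_gpus   # ZeRO-1: shard optimiser state
    if stage >= 2:
        grads /= num_gpus       # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= num_gpus     # ZeRO-3: also shard parameters
    return (weights + grads + opt_state) / 1e9

# Illustrative numbers (not from the article): 7.5B parameters on 64 GPUs.
for stage in range(4):
    label = "DDP baseline" if stage == 0 else f"ZeRO-{stage}"
    print(f"{label:>12}: {zero_state_gb(7.5e9, 64, stage):6.1f} GB per GPU")
```

Under these assumptions the reductions relative to the baseline are about 3.8×, 7.2×, and 64×. The first two approach the table's ~4× and ~8× as the GPU count grows, and the third is exactly the GPU count.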
The FSDP API is similar, native to PyTorch, and recommended for new code:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
# FULL_SHARD shards parameters, gradients, and optimiser state (ZeRO-3 style).
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
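To make the idea behind full sharding more concrete, here is a toy sketch of what happens to one layer's parameters: they are flattened, padded so they divide evenly, and split so each rank stores only its slice. The full layer is rebuilt with an all-gather just before it is needed. The helper names and the 10-weight layer are hypothetical, and this illustrates the concept only, not FSDP's actual implementation.

```python
def shard_flat_params(flat_params, world_size):
    """Pad a flat parameter list and split it into equal per-rank shards."""
    shard_len = -(-len(flat_params) // world_size)           # ceiling division
    padded = flat_params + [0.0] * (shard_len * world_size - len(flat_params))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

def all_gather(shards, original_len):
    """Stand-in for the all-gather that rebuilds full parameters before a layer runs."""
    full = [p for shard in shards for p in shard]
    return full[:original_len]                               # drop the padding

layer_params = [0.1 * i for i in range(10)]                  # toy layer with 10 weights
shards = shard_flat_params(layer_params, world_size=4)

print("params held per rank:", [len(s) for s in shards])
print("all-gather restores the layer:",
      all_gather(shards, len(layer_params)) == layer_params)
```

Each rank permanently stores only about 1/N of the layer. The price is the extra all-gather traffic, which is where the additional communication overhead of ZeRO-3 comes from.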