5. Distributed Training

Serious training jobs eventually outgrow a single process or a single GPU. This section focuses on the invariants that stay the same from a local one-process example to a multi-GPU node: distributed samplers, rank-aware logging, and rank-zero checkpointing.
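Rank-aware logging is one of those invariants. The sketch below defines two helpers, `current_rank` and `log_rank0`. Both are hypothetical names, not PyTorch APIs. `log_rank0` prints only on rank 0. When no process group is initialized it treats the process as rank 0, so the same call works in a plain script, in this notebook, and in a multi-process job.

```python
import torch.distributed as dist


def current_rank() -> int:
    """Return this process's rank, or 0 when no process group is running."""
    # dist.is_available() is False on builds without distributed support.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def log_rank0(*args, **kwargs) -> bool:
    """Print only on rank 0; return True if this process printed."""
    if current_rank() == 0:
        print(*args, **kwargs)
        return True
    return False


# No process group has been initialized here, so this process counts as rank 0.
printed = log_rank0('starting training')
```

Without this guard, a job with N ranks prints every line N times, interleaved in whatever order the processes happen to flush.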

5.1. Initialize a local process group

This notebook uses world_size=1 and a file-based init method (file://). That exercises the real DDP code path without needing a launcher such as torchrun.

[ ]:
import os
import tempfile

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

torch.manual_seed(31)

backend = 'gloo'  # runs on CPU; multi-GPU jobs usually use 'nccl'
rank = 0
world_size = 1
init_file = tempfile.NamedTemporaryFile(delete=False)
init_method = f'file://{init_file.name}'
init_file.close()

dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)
print('initialized:', dist.is_initialized(), 'backend:', dist.get_backend())
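Under a launcher such as torchrun, the rank and world size are not hard-coded. They arrive through the environment variables RANK, WORLD_SIZE and LOCAL_RANK, and init_method='env://' reads MASTER_ADDR and MASTER_PORT. The sketch below resolves those values with single-process defaults, so it still runs in this notebook. `resolve_dist_config` is a hypothetical helper. Choosing nccl when CUDA is available and gloo otherwise is a common convention, not a requirement.

```python
import os

import torch


def resolve_dist_config(env=None):
    """Read launcher-provided settings, defaulting to a single local process."""
    env = os.environ if env is None else env
    rank = int(env.get('RANK', 0))
    world_size = int(env.get('WORLD_SIZE', 1))
    local_rank = int(env.get('LOCAL_RANK', 0))
    if not 0 <= rank < world_size:
        raise ValueError(f'rank {rank} is outside [0, {world_size})')
    # NCCL requires CUDA devices; gloo also runs on CPU.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    return {'rank': rank, 'world_size': world_size,
            'local_rank': local_rank, 'backend': backend}


# Simulate what a launcher would export for the third of four workers.
cfg = resolve_dist_config({'RANK': '2', 'WORLD_SIZE': '4', 'LOCAL_RANK': '2'})
print(cfg['rank'], cfg['world_size'], cfg['local_rank'])
```

In a real job you would pass cfg['backend'] to dist.init_process_group(backend=..., init_method='env://'). On GPUs, you would also call torch.cuda.set_device(cfg['local_rank']) before wrapping the model.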

5.2. Wrap the model and use a distributed sampler

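With world_size=1 the sampler simply yields every index, but it keeps the code path identical to multi-process training. Call sampler.set_epoch(epoch) at the start of each epoch. Otherwise every epoch reuses the same shuffled order.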
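To see what the sampler does with more than one replica, you can construct it with explicit num_replicas and rank. When both are given, it does not need an initialized process group. The sketch below uses a 10-item dataset and a pretend world of 4 ranks.

- With shuffle=False, each rank takes every fourth index.
- Because 10 does not divide evenly by 4, the sampler pads by repeating indices from the start. This keeps every rank at the same count, since drop_last=False is the default.
- With shuffle=True, all ranks seed the same permutation with seed + epoch, so their shards still partition the data.

The `epoch_order` helper is hypothetical.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
world_size = 4  # pretend world size; explicit rank/num_replicas need no process group

# Unshuffled: each rank takes every world_size-th index of the padded list.
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=r, shuffle=False))
    for r in range(world_size)
]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]


def epoch_order(rank, epoch, seed=0):
    """Indices one rank would visit in a given epoch with shuffling on."""
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=seed)
    sampler.set_epoch(epoch)  # the permutation is seeded with seed + epoch
    return list(sampler)


# Every rank draws the same permutation, so together the shards cover all indices.
covered = [i for r in range(world_size) for i in epoch_order(r, epoch=0)]
print(sorted(set(covered)) == list(range(10)))  # True
```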

[ ]:
x = torch.randn(32, 6)
y = (x[:, 0] + x[:, 1] > 0).long()
dataset = TensorDataset(x, y)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

model = nn.Sequential(nn.Linear(6, 16), nn.ReLU(), nn.Linear(16, 2))
ddp_model = DDP(model)
optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()

for epoch in range(2):
    sampler.set_epoch(epoch)
    for batch_x, batch_y in loader:
        optimizer.zero_grad(set_to_none=True)
        logits = ddp_model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()
    # Only rank 0 logs, so a multi-process run prints each line once.
    if rank == 0:
        print(f'epoch {epoch}: last-batch loss {loss.item():.4f}')

with torch.inference_mode():
    logits = ddp_model(x)
assert logits.shape == (32, 2)
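Why does splitting the data across ranks leave the optimization unchanged? DDP all-reduces gradients during the backward pass and averages them, so every rank ends up holding the mean of the per-rank gradients. The sketch below simulates that averaging in a single process, without DDP or a process group. With a mean-reduced loss and equal-sized shards, the averaged shard gradients match the full-batch gradient. The two-way strided split is an illustrative assumption that mimics DistributedSampler(shuffle=False).

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 6)
y = torch.randint(0, 2, (8,))
model = nn.Linear(6, 2)
criterion = nn.CrossEntropyLoss()  # default reduction='mean'


def gradients(batch_x, batch_y):
    """Return a copy of every parameter gradient for one batch."""
    model.zero_grad(set_to_none=True)
    criterion(model(batch_x), batch_y).backward()
    return [p.grad.clone() for p in model.parameters()]


full_batch = gradients(x, y)

world_size = 2
# Each simulated rank sees a strided shard of equal size.
per_rank = [gradients(x[r::world_size], y[r::world_size]) for r in range(world_size)]
# What the all-reduce average leaves on every rank: the mean of per-rank gradients.
averaged = [torch.stack(gs).mean(dim=0) for gs in zip(*per_rank)]

match = all(torch.allclose(a, f, atol=1e-6) for a, f in zip(averaged, full_batch))
print(match)  # True
```

If the shards had unequal sizes, the average of the shard means would weight samples unevenly, and the two results would no longer coincide exactly.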

5.3. Rank-zero checkpointing

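In a real multi-process job every rank holds an identical copy of the model, so only rank 0 needs to write the checkpoint. If every rank wrote to the same path, the concurrent writes would duplicate I/O and could corrupt the file. dist.barrier() then makes the other ranks wait until rank 0 has finished writing before any rank loads the file.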

[ ]:
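checkpoint_path = './output/ddp-checkpoint.pt'
if rank == 0:
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    # Save the unwrapped module so state_dict keys carry no 'module.' prefix.
    torch.save({'model': ddp_model.module.state_dict(), 'epoch': 1}, checkpoint_path)

# Every rank waits here until rank 0 has finished writing.
dist.barrier()
checkpoint = torch.load(checkpoint_path, map_location='cpu')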
print(checkpoint['epoch'])
assert checkpoint['epoch'] == 1

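dist.destroy_process_group()
# The file store may already have removed the rendezvous file on shutdown.
if os.path.exists(init_file.name):
    os.unlink(init_file.name)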
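Long jobs usually extend the basic rank-zero pattern above in two ways. They save optimizer state, so training can resume. They also write atomically, so a crash mid-save cannot leave a truncated file at the real path. The sketch below is one way to do both. It assumes a hypothetical `save_checkpoint`/`load_checkpoint` pair, not a PyTorch API.

- Rank 0 writes to a temporary file in the same directory, then renames it into place with os.replace. The rename is atomic on POSIX filesystems when both paths live on the same filesystem.
- The helpers work with or without an initialized process group.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _dist_ready() -> bool:
    return dist.is_available() and dist.is_initialized()


def save_checkpoint(path, model, optimizer, epoch):
    """Atomically write model and optimizer state from rank 0 only."""
    if not _dist_ready() or dist.get_rank() == 0:
        # Unwrap DDP so the saved keys have no 'module.' prefix.
        module = model.module if isinstance(model, DDP) else model
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        tmp_path = path + '.tmp'
        torch.save({'model': module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch}, tmp_path)
        os.replace(tmp_path, path)  # atomic rename on the same filesystem
    if _dist_ready():
        dist.barrier()  # other ranks wait until the file is complete


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer state on every rank; return the saved epoch."""
    checkpoint = torch.load(path, map_location='cpu')
    module = model.module if isinstance(model, DDP) else model
    module.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']


# Round trip with plain (unwrapped) models in a scratch directory.
model = nn.Linear(6, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'ckpt', 'latest.pt')
    save_checkpoint(path, model, optimizer, epoch=3)
    fresh = nn.Linear(6, 2)
    fresh_optimizer = torch.optim.AdamW(fresh.parameters(), lr=0.05)
    restored_epoch = load_checkpoint(path, fresh, fresh_optimizer)

same = torch.equal(fresh.weight, model.weight) and torch.equal(fresh.bias, model.bias)
print(restored_epoch, same)  # 3 True
```

Loading on every rank with map_location='cpu' avoids placing tensors on the device that rank 0 saved from. The later load_state_dict call copies them onto whatever device each rank's model already lives on.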