2. Memory Optimization
Memory pressure is one of the first hard limits you hit in deep learning. Focus on which tricks trade memory for compute, which shrink the parameter footprint, and which merely change how work is scheduled.
[ ]:
import os
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
torch.manual_seed(47)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'
2.1. Parameter footprint
[ ]:
small = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
large = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
small_params = sum(p.numel() for p in small.parameters())
large_params = sum(p.numel() for p in large.parameters())
print('small params:', small_params)
print('large params:', large_params)
assert large_params > small_params
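Parameter counts translate directly into bytes: fp32 weights take 4 bytes each, gradients double that, and Adam-style optimizers keep two extra fp32 state tensors per parameter. A rough sketch of that arithmetic (the 4-byte and two-state figures assume fp32 weights and AdamW defaults; `estimate_memory_mb` is a helper introduced here, not a PyTorch API):

```python
import torch
from torch import nn

def estimate_memory_mb(model, bytes_per_param=4, optimizer_states=2):
    """Rough fp32 training footprint: weights + gradients + optimizer state."""
    n = sum(p.numel() for p in model.parameters())
    # weights (1x) + gradients (1x) + exp_avg / exp_avg_sq for AdamW (2x)
    total_bytes = n * bytes_per_param * (2 + optimizer_states)
    return total_bytes / 2**20

large = nn.Sequential(nn.Linear(128, 512), nn.ReLU(),
                      nn.Linear(512, 512), nn.ReLU(),
                      nn.Linear(512, 10))
print(f'{estimate_memory_mb(large):.2f} MiB')
```

Note this counts only parameters and their training state; activation memory depends on batch size and is usually the larger term, which is what the next two sections target.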
2.2. Gradient accumulation
Accumulate gradients over several micro-batches before stepping the optimizer. This lowers peak activation memory, since each forward/backward pass only sees one micro-batch, at the cost of running more forward/backward passes per optimizer step.
[ ]:
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 2)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
x = torch.randn(32, 64, device=device)
y = torch.randint(0, 2, (32,), device=device)
micro_batch_size = 8
accumulation_steps = x.shape[0] // micro_batch_size
optimizer.zero_grad(set_to_none=True)
for step in range(accumulation_steps):
    start = step * micro_batch_size
    end = start + micro_batch_size
    loss = criterion(model(x[start:end]), y[start:end]) / accumulation_steps
    loss.backward()
optimizer.step()
print('micro-batches:', accumulation_steps)
assert torch.isfinite(loss)
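A quick way to convince yourself the arithmetic works out: because each micro-batch loss is divided by `accumulation_steps`, the accumulated gradients should match a single full-batch backward. A minimal check (a sketch on CPU; the `atol` tolerance is an assumption for fp32):

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(64, 2)
criterion = nn.CrossEntropyLoss()
x = torch.randn(32, 64)
y = torch.randint(0, 2, (32,))

# full-batch gradient for reference
model.zero_grad(set_to_none=True)
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# accumulated micro-batch gradient: 4 micro-batches of 8
model.zero_grad(set_to_none=True)
steps = 4
for i in range(steps):
    chunk = slice(i * 8, (i + 1) * 8)
    (criterion(model(x[chunk]), y[chunk]) / steps).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))
```

The equivalence holds because `CrossEntropyLoss` defaults to mean reduction: four means over 8 samples, each scaled by 1/4, sum to one mean over 32.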
2.3. Activation checkpointing
Checkpointing trades compute for memory: activations inside the checkpointed region are freed after the forward pass and recomputed during backward.
[ ]:
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU()).to(device)
head = nn.Linear(64, 2).to(device)
input_tensor = torch.randn(16, 64, device=device, requires_grad=True)
target = torch.randint(0, 2, (16,), device=device)
def run_block(tensor):
    return block(tensor)
features = checkpoint(run_block, input_tensor, use_reentrant=False)
loss = criterion(head(features), target)
loss.backward()
print(round(loss.item(), 4))
assert input_tensor.grad is not None
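For deeper `nn.Sequential` stacks, `torch.utils.checkpoint.checkpoint_sequential` splits the model into segments and checkpoints each one, so only the activations at segment boundaries stay live. A minimal sketch (the choice of 4 blocks and 2 segments is arbitrary):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
deep = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU())
                       for _ in range(4)])
inp = torch.randn(16, 64, requires_grad=True)

# only activations at the 2 segment boundaries are stored; everything
# inside a segment is recomputed during backward
out = checkpoint_sequential(deep, segments=2, input=inp, use_reentrant=False)
out.sum().backward()
print(inp.grad.shape)
```

More segments means lower peak activation memory but more recomputation; a common heuristic is roughly sqrt(depth) segments.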