2. Memory Optimization

This matters because memory pressure is one of the first hard limits in deep learning. Focus on which tricks trade memory for compute, which reduce parameter footprint, and which simply change how work is scheduled.

[ ]:
import os

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(47)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

2.1. Parameter footprint

[ ]:
small = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
large = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
small_params = sum(p.numel() for p in small.parameters())
large_params = sum(p.numel() for p in large.parameters())
print('small params:', small_params)
print('large params:', large_params)
assert large_params > small_params
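Parameter counts turn into memory through bytes per element. The sketch below assumes fp32 weights and AdamW with its defaults (amsgrad off). It does four things:

- checks the nn.Linear count formula in_features * out_features + out_features,
- converts the count to bytes with element_size() for fp32 and bf16,
- measures the gradient tensors,
- measures AdamW's exp_avg and exp_avg_sq buffers after one step.

The helper tensor_bytes is hypothetical and exists only for this illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
large = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# nn.Linear(in, out) holds an (out, in) weight plus an out-sized bias.
linear_shapes = [(128, 512), (512, 512), (512, 10)]
formula_count = sum(i * o + o for i, o in linear_shapes)
param_count = sum(p.numel() for p in large.parameters())

def tensor_bytes(tensors):
    # Hypothetical helper: bytes = number of elements * bytes per element.
    return sum(t.numel() * t.element_size() for t in tensors)

fp32_bytes = tensor_bytes(large.parameters())
# Same count stored as bfloat16 (2 bytes per element instead of 4).
bf16_bytes = param_count * torch.empty((), dtype=torch.bfloat16).element_size()

# AdamW allocates its moment buffers lazily, on the first optimizer step.
optimizer = torch.optim.AdamW(large.parameters(), lr=1e-3)
large(torch.randn(4, 128)).sum().backward()
optimizer.step()
grad_bytes = tensor_bytes(p.grad for p in large.parameters())
adam_state_bytes = tensor_bytes(
    state[key] for state in optimizer.state.values() for key in ('exp_avg', 'exp_avg_sq')
)

print('params:', param_count, 'formula:', formula_count)
print('fp32 weights (bytes):', fp32_bytes)
print('bf16 weights (bytes):', bf16_bytes)
print('fp32 grads (bytes):', grad_bytes)
print('AdamW moments (bytes):', adam_state_bytes)
```

Under these assumptions, training this model with AdamW needs about four times the fp32 weight size: one copy for weights, one for gradients and two for the moments. Activations come on top of that.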

2.2. Gradient accumulation

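Accumulate gradients over several micro-batches before stepping the optimizer. Only one micro-batch's activations are alive at a time, which lowers peak activation memory. The cost is more sequential forward and backward passes per optimizer step; the optimizer still steps only once per full batch.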

[ ]:
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 2)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
x = torch.randn(32, 64, device=device)
y = torch.randint(0, 2, (32,), device=device)
micro_batch_size = 8
accumulation_steps = x.shape[0] // micro_batch_size

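optimizer.zero_grad(set_to_none=True)
for step in range(accumulation_steps):
    start = step * micro_batch_size
    end = start + micro_batch_size
    # Scale each micro-batch loss so the accumulated gradients equal the full-batch mean.
    loss = criterion(model(x[start:end]), y[start:end]) / accumulation_steps
    loss.backward()  # gradients add up in each parameter's .grad across iterations
optimizer.step()  # a single parameter update for the whole batch of 32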
print('micro-batches:', accumulation_steps)
assert torch.isfinite(loss)
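Dividing each micro-batch loss by accumulation_steps makes accumulation equivalent to one large batch. This is a minimal check that relies on three assumptions:

- the loss is mean-reduced,
- the micro-batches are equal-sized,
- there are no batch-dependent layers such as BatchNorm, which would see different batch statistics.

It compares the gradients from one full-batch backward pass with the accumulated micro-batch gradients. Tensor.split replaces the manual start/end slicing here; the two are interchangeable.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 2))
criterion = nn.CrossEntropyLoss()  # mean reduction over the batch
x = torch.randn(32, 64)
y = torch.randint(0, 2, (32,))

# Reference: a single backward pass over the full batch.
model.zero_grad(set_to_none=True)
criterion(model(x), y).backward()
full_grads = [p.grad.clone() for p in model.parameters()]

# Accumulation: four micro-batches of 8, each loss scaled by 1 / accumulation_steps.
micro_batch_size = 8
accumulation_steps = x.shape[0] // micro_batch_size
model.zero_grad(set_to_none=True)
for xb, yb in zip(x.split(micro_batch_size), y.split(micro_batch_size)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
accum_grads = [p.grad.clone() for p in model.parameters()]

# The two gradient sets should agree up to float32 rounding.
max_diff = max((a - b).abs().max().item() for a, b in zip(full_grads, accum_grads))
print('max gradient difference:', max_diff)
```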

2.3. Activation checkpointing

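Activation checkpointing trades compute for memory. The intermediate activations inside the checkpointed region are not kept after the forward pass. Instead, they are recomputed from the region's saved inputs when the backward pass needs them.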

[ ]:
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU()).to(device)
head = nn.Linear(64, 2).to(device)
input_tensor = torch.randn(16, 64, device=device, requires_grad=True)
target = torch.randint(0, 2, (16,), device=device)

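def run_block(tensor):
    return block(tensor)

# Activations inside block are not stored during this forward pass;
# they are recomputed from input_tensor when backward needs them.
features = checkpoint(run_block, input_tensor, use_reentrant=False)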
loss = criterion(head(features), target)
loss.backward()
print('loss:', round(loss.item(), 4))
assert input_tensor.grad is not None
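The recomputation can be made visible. This sketch registers a forward pre-hook on block that counts how often it is entered. It runs the same model with and without checkpoint and compares the resulting gradients. The counter dictionary and the block_grads helper are hypothetical and exist only for this illustration. Because the block has no dropout, both runs should produce the same gradients. With checkpointing, block should run once in the forward pass and once more during backward.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
head = nn.Linear(64, 2)
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 64)
target = torch.randint(0, 2, (16,))

calls = {'n': 0}

def count_call(module, inputs):
    # Pre-hooks fire each time block is entered, including during recomputation.
    calls['n'] += 1

block.register_forward_pre_hook(count_call)

def block_grads(use_checkpoint):
    # Hypothetical helper: one forward/backward pass, returning call count and block gradients.
    calls['n'] = 0
    block.zero_grad(set_to_none=True)
    head.zero_grad(set_to_none=True)
    features = checkpoint(block, x, use_reentrant=False) if use_checkpoint else block(x)
    criterion(head(features), target).backward()
    return calls['n'], [p.grad.clone() for p in block.parameters()]

plain_calls, plain_grads = block_grads(use_checkpoint=False)
ckpt_calls, ckpt_grads = block_grads(use_checkpoint=True)
same = all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads))

print('block calls without checkpoint:', plain_calls)
print('block calls with checkpoint:', ckpt_calls)
print('identical gradients:', same)
```

The extra call is the compute paid for not holding the block's activations between the forward and backward passes.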