8. Debugging Training
Most failed training runs come down to simple bugs in shapes, labels, devices, or gradients. Focus on the shortest checks that catch those bugs before you scale up.
[ ]:
import os
from collections import Counter

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Fix the seed so debugging runs are reproducible.
torch.manual_seed(37)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shortens the overfitting loop when the notebook runs in automated checks.
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
dataset = datasets.ImageFolder('./shapes/train', transform=transform)
dataloader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=2)

# dataset.samples is a list of (path, class_index) pairs; count each class.
class_counts = Counter(label for _, label in dataset.samples)
{dataset.classes[label]: count for label, count in sorted(class_counts.items())}
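The class counts above are worth a closer look: a heavily skewed distribution lets the model reach a low loss by always predicting the majority class. A quick sanity check, sketched here with an arbitrary 10x imbalance threshold:
[ ]:
most_common = max(class_counts.values())
rarest = min(class_counts.values())
# The 10x threshold is an arbitrary choice; tune it to your dataset.
assert most_common <= 10 * rarest, f'imbalanced classes: {rarest} vs {most_common}'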
8.1. Check one batch
Before training for hours, inspect one batch: shape, dtype, label range, and device.
[ ]:
images, labels = next(iter(dataloader))
images, labels = images.to(device), labels.to(device)
print('images:', images.shape, images.dtype, images.device, images.min().item(), images.max().item())
print('labels:', labels.shape, labels.dtype, labels.device, labels.tolist())
# Images should be a 4D batch: (batch, channels, height, width).
assert images.ndim == 4
# Class indices must stay inside [0, num_classes) for CrossEntropyLoss.
assert labels.min() >= 0 and labels.max() < len(dataset.classes)
# Inputs and targets must live on the same device as the model.
assert images.device.type == device.type and labels.device.type == device.type
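Two more cheap checks fit the same batch. In this notebook's setup, ToTensor scales pixels into [0, 1] and shuffle=True should mix classes within a batch; both assumptions are verified in this sketch:
[ ]:
# ToTensor maps uint8 pixels to floats in [0, 1]; values outside that range
# suggest a missing or doubled normalization step.
assert 0.0 <= images.min().item() and images.max().item() <= 1.0
# A shuffled batch of 12 should almost always mix classes; a single repeated
# label often means shuffle=False or data sorted by class.
assert len(set(labels.tolist())) > 1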
8.2. Overfit one batch
A model should be able to overfit one small batch. If the loss does not move down, check the labels, loss function, optimizer, learning rate, and gradient flow.
[ ]:
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 64),
    nn.ReLU(),
    nn.Linear(64, len(dataset.classes)),
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# A tiny fixed batch: the model should be able to memorize it.
debug_images = images[:6]
debug_labels = labels[:6]
losses = []
steps = 120 if check_mode else 300
for step in range(steps):
    optimizer.zero_grad(set_to_none=True)
    logits = model(debug_images)
    loss = criterion(logits, debug_labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print('first loss:', losses[0])
print('last loss:', losses[-1])
assert losses[-1] < losses[0]
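A falling loss is a weak signal on its own. After the loop above, the model should have memorized all six examples, so a stricter follow-up check (a sketch assuming the overfit succeeded; if it fails while the loss fell, suspect a label or logit mismatch):
[ ]:
with torch.no_grad():
    predictions = model(debug_images).argmax(dim=1)
accuracy = (predictions == debug_labels).float().mean().item()
print('debug batch accuracy:', accuracy)
# Six examples should be trivially memorable after hundreds of steps.
assert accuracy == 1.0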
8.3. Measure gradients
A grad of None usually means the parameter was never used to compute the loss. A zero gradient norm can instead mean saturation, a detached part of the graph, or a loss that does not depend on that parameter.
[ ]:
def gradient_report(model):
    # Collect (parameter name, gradient norm) pairs; None marks a missing gradient.
    report = []
    for name, parameter in model.named_parameters():
        if parameter.grad is None:
            report.append((name, None))
        else:
            report.append((name, parameter.grad.detach().norm().item()))
    return report

optimizer.zero_grad(set_to_none=True)
loss = criterion(model(images), labels)
loss.backward()
report = gradient_report(model)
for name, norm in report:
    print(name, norm)
# Every parameter should receive a non-negative gradient norm.
# (A NaN norm also fails this check, because NaN >= 0 is False.)
assert all(norm is not None and norm >= 0 for _, norm in report)
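Beyond None and zero, gradients can also explode or turn into NaN, for example after a too-high learning rate or a log of zero. A short finiteness check, sketched with torch.isfinite:
[ ]:
for name, parameter in model.named_parameters():
    if parameter.grad is not None:
        # isfinite is False for both NaN and +/-inf entries.
        assert torch.isfinite(parameter.grad).all(), f'non-finite gradient in {name}'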