8. Debugging Training

This matters because most failed training runs trace back to simple bugs in shapes, labels, devices, or gradients. Focus on the shortest checks that catch those bugs before you scale up.

[ ]:
import os
from collections import Counter

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

torch.manual_seed(37)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
dataset = datasets.ImageFolder('./shapes/train', transform=transform)
dataloader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=2)

class_counts = Counter(label for _, label in dataset.samples)
{dataset.classes[label]: count for label, count in sorted(class_counts.items())}
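A dictionary of counts is only useful if you act on it. As a sketch, you can turn the count into an assertion that fails loudly on severe imbalance; the inline sample list and the 10x threshold below are illustrative assumptions, not values from this dataset:

```python
from collections import Counter

# hypothetical (path, label) pairs in the same format as ImageFolder's .samples
samples = [('a.png', 0), ('b.png', 0), ('c.png', 1), ('d.png', 1), ('e.png', 2), ('f.png', 2)]
counts = Counter(label for _, label in samples)

# flag severe imbalance: fail if the largest class is more than 10x the smallest
assert max(counts.values()) <= 10 * min(counts.values()), counts
print(dict(counts))
```

With a real ImageFolder dataset you would build `counts` from `dataset.samples`, as in the cell above.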

8.1. Check one batch

Before training for hours, inspect one batch: shape, dtype, label range, and device.

[ ]:
images, labels = next(iter(dataloader))
images, labels = images.to(device), labels.to(device)

print('images:', images.shape, images.dtype, images.device, images.min().item(), images.max().item())
print('labels:', labels.shape, labels.dtype, labels.device, labels.tolist())

assert images.ndim == 4
assert labels.min() >= 0 and labels.max() < len(dataset.classes)
assert images.device.type == device.type and labels.device.type == device.type
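Label dtype is worth checking alongside label range. A common bug is loading class indices as floats; a minimal sketch of what happens (the tensors here are made up for illustration):

```python
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 3)

# class indices accidentally stored as floats, e.g. after reading labels from a CSV
float_labels = torch.tensor([0.0, 2.0, 1.0, 2.0])

try:
    criterion(logits, float_labels)
    print('float labels accepted')
except Exception as error:
    # a 1-D floating target does not match either mode CrossEntropyLoss supports
    print('float labels rejected:', type(error).__name__)

# class-index targets must be integer tensors
loss = criterion(logits, float_labels.long())
print('long labels work, loss =', loss.item())
```

The fix is usually a single `.long()` cast where the labels are created, not where the loss is computed.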

8.2. Overfit one batch

A model should be able to overfit one small batch. If the loss does not move down, check the labels, loss function, optimizer, learning rate, and gradient flow.

[ ]:
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 64),
    nn.ReLU(),
    nn.Linear(64, len(dataset.classes)),
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

debug_images = images[:6]
debug_labels = labels[:6]

losses = []
steps = 120 if check_mode else 300

for step in range(steps):
    optimizer.zero_grad(set_to_none=True)
    logits = model(debug_images)
    loss = criterion(logits, debug_labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print('first loss:', losses[0])
print('last loss:', losses[-1])
assert losses[-1] < losses[0]
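The first loss value is itself a diagnostic: with cross-entropy, a freshly initialized model should score close to the loss of a uniform predictor, ln(num_classes). A sketch of that check on a toy model (the linear model and random data here are stand-ins, not the shapes model above):

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

num_classes = 3
model = nn.Linear(10, num_classes)

x = torch.randn(8, 10)
y = torch.randint(0, num_classes, (8,))

first_loss = nn.CrossEntropyLoss()(model(x), y).item()
baseline = math.log(num_classes)  # loss of a predictor that outputs uniform probabilities

print('first loss:', first_loss)
print('uniform baseline:', baseline)

# at init the loss should sit near the baseline; far above suggests a label
# or loss bug, far below suggests leakage
assert abs(first_loss - baseline) < 1.0
```

If the first loss printed by the overfitting loop above is far from ln(3) ≈ 1.10 for this three-class dataset, inspect the labels and the loss before touching the model.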

8.3. Measure gradients

A grad of None usually means the parameter never participated in computing the loss, so autograd never produced a gradient for it. A zero gradient norm is different: the parameter is in the graph, but the gradient vanished, which can mean saturated activations, an accidentally detached tensor, or a loss that is locally flat in that parameter.

[ ]:
def gradient_report(model):
    report = []
    for name, parameter in model.named_parameters():
        if parameter.grad is None:
            report.append((name, None))
        else:
            report.append((name, parameter.grad.detach().norm().item()))
    return report

optimizer.zero_grad(set_to_none=True)
loss = criterion(model(images), labels)
loss.backward()

report = gradient_report(model)
for name, norm in report:
    print(name, norm)

assert all(norm is not None and norm > 0 for _, norm in report)
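To see the None case concretely, here is a minimal sketch with a deliberately unused submodule (the TwoHeads module is invented for this demonstration):

```python
import torch
from torch import nn

torch.manual_seed(0)

class TwoHeads(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # registered, but never called in forward

    def forward(self, x):
        return self.used(x)

model = TwoHeads()
loss = model(torch.randn(5, 4)).sum()
loss.backward()

print('used grad is None:  ', model.used.weight.grad is None)
print('unused grad is None:', model.unused.weight.grad is None)
```

The unused head shows up in `named_parameters()` and in the optimizer, but its grad stays None after backward, which is exactly the signature the gradient report above is designed to surface.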