5. Training Recipe

A training recipe matters because most good training runs follow the same small set of habits. A repeatable recipe replaces random experimentation with a baseline process that is easier to debug, compare, and improve.

5.1. Setup

The example uses the small shape image dataset included with this book.

[ ]:
import os
from collections import defaultdict

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

torch.manual_seed(37)  # fixed seed so runs are reproducible

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'  # when set, fit() below runs a single epoch
batch_size = 8

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

datasets_by_phase = {
    'train': datasets.ImageFolder('./shapes/train', transform=transform),
    'valid': datasets.ImageFolder('./shapes/valid', transform=transform),
}

dataloaders = {
    phase: DataLoader(dataset, batch_size=batch_size, shuffle=(phase == 'train'), num_workers=2)
    for phase, dataset in datasets_by_phase.items()
}

idx_to_class = {v: k for k, v in datasets_by_phase['train'].class_to_idx.items()}
idx_to_class
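
Before building a model, it is worth confirming that the pipeline yields what you expect. A minimal sanity check, using only the objects defined above, pulls one batch and inspects its shapes and labels:

[ ]:
images, labels = next(iter(dataloaders['train']))  # grab a single training batch
print(images.shape)  # expected: torch.Size([8, 3, 64, 64])
print(images.min().item(), images.max().item())  # ToTensor() scales pixels to [0, 1]
print([idx_to_class[label.item()] for label in labels])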

5.2. Model

Use a small convolutional model first. A tiny model is faster to debug than a pretrained model with millions of parameters.

[ ]:
class SmallConvNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Two 2x2 max pools halve 64x64 twice, leaving a 32 x 16 x 16 feature map.
            nn.Linear(32 * 16 * 16, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

model = SmallConvNet(num_classes=len(idx_to_class)).to(device)
sum(p.numel() for p in model.parameters())
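
The hard-coded 32 * 16 * 16 in the classifier is easy to get wrong after changing the input size or pooling. A dummy forward pass, assuming the 64x64 inputs used above, verifies it before any training:

[ ]:
# Push a zero batch through the network to confirm intermediate shapes.
with torch.inference_mode():
    dummy = torch.zeros(1, 3, 64, 64, device=device)
    print(model.features(dummy).shape)  # expected: torch.Size([1, 32, 16, 16])
    print(model(dummy).shape)           # expected: (1, num_classes)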

5.3. One-epoch functions

Separate train_one_epoch() from evaluate(). The difference is small but important: training tracks gradients and applies optimizer updates; evaluation disables gradient tracking with torch.inference_mode() and puts the model in evaluation mode with model.eval(), which matters once layers like dropout or batch norm appear.

[ ]:
def batch_accuracy(logits, labels):
    # Returns the number of correct predictions (a count, not a ratio);
    # callers divide by the running sample count at the end of the epoch.
    predictions = logits.argmax(dim=1)
    return (predictions == labels).sum().item()

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    totals = defaultdict(float)

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the mean batch loss by batch size so the epoch average
        # stays exact even when the last batch is smaller.
        totals['loss'] += loss.item() * images.size(0)
        totals['correct'] += batch_accuracy(logits, labels)
        totals['count'] += images.size(0)

    return {'loss': totals['loss'] / totals['count'], 'accuracy': totals['correct'] / totals['count']}

@torch.inference_mode()
def evaluate(model, dataloader, criterion, device):
    model.eval()
    totals = defaultdict(float)

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)

        totals['loss'] += loss.item() * images.size(0)
        totals['correct'] += batch_accuracy(logits, labels)
        totals['count'] += images.size(0)

    return {'loss': totals['loss'] / totals['count'], 'accuracy': totals['correct'] / totals['count']}
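
Before committing to a full run, evaluate the untrained model as a smoke test. With random weights, cross-entropy loss should sit near ln(num_classes) and accuracy near chance; a wildly different number usually points at a data or labeling bug. A sketch using the pieces defined above:

[ ]:
import math

# An untrained classifier should score close to the random baseline.
metrics = evaluate(model, dataloaders['valid'], nn.CrossEntropyLoss(), device)
print(metrics)
print(math.log(len(idx_to_class)))  # expected loss under random weights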

5.4. Fit loop

A useful training loop returns its history. Structured metrics are easier to plot, write to TensorBoard, serialize, and compare across experiments than numbers scraped from console output.

[ ]:
def fit(model, dataloaders, criterion, optimizer, scheduler, device, num_epochs):
    history = []
    best_valid_accuracy = -1.0
    best_state = None

    for epoch in range(num_epochs):
        train_metrics = train_one_epoch(model, dataloaders['train'], criterion, optimizer, device)
        valid_metrics = evaluate(model, dataloaders['valid'], criterion, device)
        scheduler.step()

        if valid_metrics['accuracy'] > best_valid_accuracy:
            best_valid_accuracy = valid_metrics['accuracy']
            # state_dict() holds references to the live parameters, so take
            # an independent CPU snapshot that later updates cannot touch.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        record = {
            'epoch': epoch,
            'lr': scheduler.get_last_lr()[0],
            'train_loss': train_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
            'valid_loss': valid_metrics['loss'],
            'valid_accuracy': valid_metrics['accuracy'],
        }
        history.append(record)
        print(record)

    # Restore the weights from the best validation epoch before returning.
    model.load_state_dict(best_state)
    return history

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# With step_size=1, StepLR decays the learning rate by 5% after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

history = fit(model, dataloaders, criterion, optimizer, scheduler, device, num_epochs=1 if check_mode else 5)
history[-1]
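
Because fit() returns plain dictionaries, persisting a run or comparing it with a later one takes only a few lines. A minimal sketch; the file names are just examples:

[ ]:
import json

# Persist the metrics and the best weights for later comparison.
with open('shapes_history.json', 'w') as f:
    json.dump(history, f, indent=2)
torch.save(model.state_dict(), 'shapes_model.pt')

best = max(history, key=lambda record: record['valid_accuracy'])
print(best['epoch'], best['valid_accuracy'])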