5. Training Recipe
This matters because most good training runs follow the same small set of habits. A repeatable recipe reduces random experimentation and gives you a baseline process that is easier to debug, compare, and improve.
5.1. Setup
The example uses the small shape image dataset included with this book.
[ ]:
import os
from collections import defaultdict
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Fix the RNG seed for reproducibility, then pick the fastest available device.
torch.manual_seed(37)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# When this env var is '1' (e.g. in CI), later cells run a single quick epoch.
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

batch_size = 8

# Resize every image to 64x64 and convert it to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# One ImageFolder per phase; the folder names under each split become classes.
train_dataset = datasets.ImageFolder('./shapes/train', transform=transform)
valid_dataset = datasets.ImageFolder('./shapes/valid', transform=transform)
datasets_by_phase = {'train': train_dataset, 'valid': valid_dataset}

# Only the training loader shuffles, so validation sees a fixed batch order.
dataloaders = {}
for phase, dataset in datasets_by_phase.items():
    dataloaders[phase] = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(phase == 'train'),
        num_workers=2,
    )

# Invert class_to_idx so integer predictions map back to class names.
idx_to_class = {index: name for name, index in datasets_by_phase['train'].class_to_idx.items()}
idx_to_class
5.2. Model
Use a small convolutional model first. A tiny model is faster to debug than a pretrained model with millions of parameters.
[ ]:
class SmallConvNet(nn.Module):
    """Two conv/ReLU/pool stages followed by a small MLP classifier.

    Args:
        num_classes: Size of the output logit vector.
        in_channels: Number of input image channels (default 3 for RGB).
        input_size: Height/width of the square input images (default 64,
            matching the 64x64 resize in the data pipeline). Must be
            divisible by 4, since each of the two pooling stages halves
            the spatial resolution.
    """

    def __init__(self, num_classes, in_channels=3, input_size=64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Two MaxPool2d(2) layers shrink each spatial dim by a factor of 4,
        # so the flattened feature size is 32 * (input_size // 4) ** 2.
        # (With the defaults this is the original 32 * 16 * 16.)
        feature_side = input_size // 4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * feature_side * feature_side, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.features(x)
        return self.classifier(x)
# Build the model with one output per class and move it to the chosen device.
model = SmallConvNet(num_classes=len(idx_to_class)).to(device)
# Total parameter count — a quick sanity check that the model really is small.
parameter_count = sum(parameter.numel() for parameter in model.parameters())
parameter_count
5.3. One epoch functions
Separate train_one_epoch() from evaluate(). The difference is small but important: training enables gradient tracking and optimizer updates; evaluation disables gradient tracking and switches the model to evaluation mode, which changes the behavior of layers such as dropout and batch normalization.
[ ]:
def batch_accuracy(logits, labels):
    """Return the number of correct predictions in the batch.

    Note: despite the name, this is a raw correct *count*, not a ratio —
    callers divide by the number of examples to get an accuracy.
    """
    return (logits.argmax(dim=1) == labels).sum().item()
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one full training pass over `dataloader`.

    Performs the forward/backward/step cycle on every batch and returns a
    dict with the example-weighted mean 'loss' and the 'accuracy' over the
    whole epoch.
    """
    model.train()
    running_loss = 0.0
    running_correct = 0.0
    examples_seen = 0.0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        batch_count = images.size(0)
        # Weight the batch loss by batch size so a smaller final batch
        # does not skew the epoch average.
        running_loss += loss.item() * batch_count
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        examples_seen += batch_count
    return {'loss': running_loss / examples_seen, 'accuracy': running_correct / examples_seen}
@torch.inference_mode()
def evaluate(model, dataloader, criterion, device):
    """Score `model` on `dataloader` without tracking gradients.

    Runs in eval mode under torch.inference_mode() and returns a dict with
    the example-weighted mean 'loss' and the 'accuracy' over the dataset.
    """
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0.0, 0.0
    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        size = batch_images.size(0)
        # Loss is weighted by batch size so partial batches average correctly.
        total_loss += batch_loss.item() * size
        total_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
        total_seen += size
    return {'loss': total_loss / total_seen, 'accuracy': total_correct / total_seen}
5.4. Fit loop
A useful training loop returns history. Returning structured metrics is easier to plot, write to TensorBoard, serialize, and compare across experiments.
[ ]:
def fit(model, dataloaders, criterion, optimizer, scheduler, device, num_epochs):
    """Train `model` for `num_epochs`, keeping the best validation weights.

    Runs one training and one validation pass per epoch, steps the LR
    scheduler once per epoch, and snapshots the model weights whenever
    validation accuracy improves. After training, the best snapshot (if
    any epoch ran) is restored into `model`.

    Returns:
        A list with one metrics dict per epoch: 'epoch', 'lr' (the rate
        used for that epoch), and train/valid loss and accuracy.
    """
    history = []
    best_valid_accuracy = -1.0
    best_state = None
    for epoch in range(num_epochs):
        # Capture the LR *before* stepping the scheduler, so the record
        # reflects the rate actually used for this epoch's updates (the
        # original logged the next epoch's rate).
        epoch_lr = scheduler.get_last_lr()[0]
        train_metrics = train_one_epoch(model, dataloaders['train'], criterion, optimizer, device)
        valid_metrics = evaluate(model, dataloaders['valid'], criterion, device)
        scheduler.step()
        if valid_metrics['accuracy'] > best_valid_accuracy:
            best_valid_accuracy = valid_metrics['accuracy']
            # Clone to CPU so the snapshot is decoupled from in-place
            # parameter updates in later epochs.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        record = {
            'epoch': epoch,
            'lr': epoch_lr,
            'train_loss': train_metrics['loss'],
            'train_accuracy': train_metrics['accuracy'],
            'valid_loss': valid_metrics['loss'],
            'valid_accuracy': valid_metrics['accuracy'],
        }
        history.append(record)
        print(record)
    # Guard against num_epochs == 0, where no snapshot was ever taken and
    # load_state_dict(None) would crash.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history
# Standard multi-class classification loss; expects raw logits, not softmax.
criterion = nn.CrossEntropyLoss()
# AdamW = Adam with decoupled weight decay; 1e-3 is a common starting LR.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.95 after every epoch (step_size=1).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
# In check mode (CI smoke test) run a single epoch; otherwise train for 5.
history = fit(model, dataloaders, criterion, optimizer, scheduler, device, num_epochs=1 if check_mode else 5)
history[-1]