3. Checkpoint and Resume

This matters because long runs fail, experiments get interrupted, and good training code must resume without guesswork. Focus on what state belongs in the checkpoint and how to verify that the resumed run is actually continuing the same optimization process.

[ ]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Fixed seed so weight init and DataLoader shuffling are repeatable run to run.
torch.manual_seed(37)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# All checkpoints written below land in this directory.
checkpoint_dir = Path('./output/checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Resize every image to 32x32 so the flattened input size (3 * 32 * 32) used by
# the model below is fixed regardless of the source image dimensions.
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
train_dataset = datasets.ImageFolder('./shapes/train', transform=transform)
valid_dataset = datasets.ImageFolder('./shapes/valid', transform=transform)

# Shuffle only the training split; validation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_dataset, batch_size=10, shuffle=False, num_workers=2)

def build_model():
    """Build a small MLP classifier for 32x32 RGB images and place it on `device`.

    The output width follows the number of classes discovered by the
    module-level `train_dataset` ImageFolder.
    """
    num_classes = len(train_dataset.classes)
    layers = [
        nn.Flatten(),
        nn.Linear(3 * 32 * 32, 64),
        nn.ReLU(),
        nn.Linear(64, num_classes),
    ]
    return nn.Sequential(*layers).to(device)
[ ]:
def run_epoch(model, loader, criterion, optimizer=None):
    """Run one full pass over `loader` and return epoch-level metrics.

    Trains when `optimizer` is given (gradients enabled, one step per batch);
    otherwise evaluates with gradients disabled.

    Args:
        model: the network to run; batches are moved to its parameters' device.
        loader: iterable yielding (images, labels) batches.
        criterion: loss function mapping (logits, labels) -> scalar loss.
        optimizer: optional optimizer; its presence switches training mode on.

    Returns:
        dict with sample-weighted mean 'loss' and overall 'accuracy'.

    Raises:
        ValueError: if `loader` yields no batches (previously this crashed
        with an opaque ZeroDivisionError at the final division).
    """
    is_training = optimizer is not None
    model.train(is_training)

    # Derive the target device from the model itself rather than a hidden
    # module-level `device` global: identical in practice (every model here is
    # moved to that device) but makes the function self-contained and testable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss, total_correct, total_count = 0.0, 0, 0

    for images, labels in loader:
        if target_device is not None:
            images, labels = images.to(target_device), labels.to(target_device)
        with torch.set_grad_enabled(is_training):
            logits = model(images)
            loss = criterion(logits, labels)
            if is_training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

        # Weight the batch loss by batch size so the final mean is per-sample
        # even when the last batch is smaller.
        total_loss += loss.item() * images.size(0)
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_count += images.size(0)

    if total_count == 0:
        raise ValueError('run_epoch received an empty loader')

    return {'loss': total_loss / total_count, 'accuracy': total_correct / total_count}

def save_checkpoint(path, model, optimizer, scheduler, epoch, best_accuracy):
    """Atomically persist the full training state to `path`.

    Writes to a sibling temp file first, then renames it over `path`.
    `os.replace`-style rename is atomic on POSIX, so a crash mid-write can
    never leave a truncated/corrupt checkpoint at `path` — critical for a
    file like 'latest.pt' that is overwritten every epoch.

    Args:
        path: destination file (str or Path).
        model / optimizer / scheduler: objects whose state_dicts are saved.
        epoch: index of the epoch that just completed.
        best_accuracy: best validation accuracy observed so far.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
        'best_accuracy': best_accuracy,
    }, tmp_path)
    tmp_path.replace(path)

def load_checkpoint(path, model, optimizer, scheduler, device):
    """Restore model/optimizer/scheduler state from the checkpoint at `path`.

    Tensors are mapped onto `device` while loading, then copied into the
    supplied objects in place.

    Returns:
        (epoch, best_accuracy) as recorded in the checkpoint.
    """
    state = torch.load(path, map_location=device)
    for target, key in ((model, 'model'), (optimizer, 'optimizer'), (scheduler, 'scheduler')):
        target.load_state_dict(state[key])
    return state['epoch'], state['best_accuracy']

3.1. Train, save, reload, resume

[ ]:
criterion = nn.CrossEntropyLoss()

model = build_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

best_accuracy = -1.0  # sentinel below any real accuracy in [0, 1]
for epoch in range(1):
    train_metrics = run_epoch(model, train_loader, criterion, optimizer)
    valid_metrics = run_epoch(model, valid_loader, criterion)
    scheduler.step()
    # NOTE: best_accuracy is updated BEFORE the comparison below, so the `>=`
    # check is equivalent to "valid accuracy >= previous best" (ties included).
    # The ordering is load-bearing — keep it in mind when editing.
    best_accuracy = max(best_accuracy, valid_metrics['accuracy'])
    # 'latest.pt' is overwritten every epoch; 'best.pt' only on a new best/tie.
    save_checkpoint(checkpoint_dir / 'latest.pt', model, optimizer, scheduler, epoch, best_accuracy)
    if valid_metrics['accuracy'] >= best_accuracy:
        save_checkpoint(checkpoint_dir / 'best.pt', model, optimizer, scheduler, epoch, best_accuracy)

# Standard resume recipe: rebuild fresh model/optimizer/scheduler objects with
# the same construction arguments, then overwrite their state from the checkpoint.
reloaded_model = build_model()
reloaded_optimizer = torch.optim.AdamW(reloaded_model.parameters(), lr=1e-3)
reloaded_scheduler = torch.optim.lr_scheduler.StepLR(reloaded_optimizer, step_size=1, gamma=0.9)

last_epoch, best_accuracy = load_checkpoint(
    checkpoint_dir / 'latest.pt',
    reloaded_model,
    reloaded_optimizer,
    reloaded_scheduler,
    device,
)

# One training epoch on the restored state; a real resumed run would continue
# its loop at last_epoch + 1.
# NOTE(review): RNG state is not part of the checkpoint, so DataLoader
# shuffling after resume will not match an uninterrupted run — confirm whether
# bitwise-identical resumption is required here.
resume_metrics = run_epoch(reloaded_model, train_loader, criterion, reloaded_optimizer)
print('resumed from epoch:', last_epoch)
print('best accuracy:', best_accuracy)
print('resume metrics:', resume_metrics)

# Sanity checks: both checkpoint files were actually written to disk.
assert (checkpoint_dir / 'latest.pt').exists()
assert (checkpoint_dir / 'best.pt').exists()