6. Early Stopping

This matters because training longer is not the same as training better. Early stopping gives you a simple rule for quitting when validation performance stops improving, which saves time and usually produces a model that generalizes better than the last checkpoint.

This notebook keeps the model and dataset small on purpose. The goal is to understand the stopping logic, not to chase a benchmark.

6.1. Setup

[ ]:
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

# Fix the global RNG seed so data generation and weight init are reproducible.
torch.manual_seed(7)
# Run on GPU when available; tensors and the model are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

6.2. Data

We build a small noisy binary classification problem so validation loss has room to flatten out and then get worse.

[ ]:
# Synthetic binary classification task: a nonlinear score of a few features
# plus Gaussian label noise, thresholded at zero. The noise keeps the labels
# imperfect, so validation loss has room to bottom out and then rise.
n_samples = 1200
x = torch.randn(n_samples, 8)
score = (
    1.2 * x[:, 0]
    - 0.8 * x[:, 1]
    + 0.5 * x[:, 2] ** 2
    - 0.3 * x[:, 3] * x[:, 4]
)
score += 0.8 * torch.randn(n_samples)
y = (score > 0).float().reshape(-1, 1)

dataset = TensorDataset(x, y)
train_size, val_size = 800, 200
test_size = len(dataset) - train_size - val_size
# A dedicated, seeded generator makes the split reproducible independently
# of the global RNG state.
split_gen = torch.Generator().manual_seed(7)
train_ds, val_ds, test_ds = random_split(
    dataset, [train_size, val_size, test_size], generator=split_gen
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=128)
test_loader = DataLoader(test_ds, batch_size=128)
len(train_ds), len(val_ds), len(test_ds)

6.3. Model

[ ]:
# A small MLP with two hidden layers: enough capacity to overfit the noisy
# labels, which is exactly the failure mode early stopping should catch.
hidden = 32
layers = [
    nn.Linear(8, hidden),
    nn.ReLU(),
    nn.Linear(hidden, hidden),
    nn.ReLU(),
    nn.Linear(hidden, 1),
]
model = nn.Sequential(*layers).to(device)

# BCEWithLogitsLoss fuses the sigmoid with binary cross-entropy, so the
# model outputs raw logits (no final activation).
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

6.4. Training With Patience

We stop if validation loss fails to improve for a few epochs in a row and keep the best checkpoint instead of the last one.

[ ]:
def evaluate_loss(model, loader):
    """Return the average per-sample loss of `model` over `loader`.

    Runs in eval mode with gradients disabled; does not switch the model
    back to train mode (callers do that at the top of each epoch).
    """
    model.eval()
    running, seen = 0.0, 0
    with torch.no_grad():
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)
            batch = len(features)
            # Weight each batch's mean loss by its size so the final
            # average is exact even when the last batch is smaller.
            running += criterion(model(features), targets).item() * batch
            seen += batch
    return running / seen

# Early-stopping state: best validation loss seen so far, a snapshot of the
# weights at that point, and how many epochs have passed without improvement.
best_state = None
best_val = float('inf')
patience = 4  # consecutive non-improving epochs tolerated before stopping
wait = 0      # epochs since the last improvement
history = []  # (epoch, train_loss, val_loss) per epoch, for inspection/plotting

for epoch in range(40):
    # --- one epoch of optimization over the training set ---
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

    # Re-evaluate both splits in eval mode so the numbers are comparable.
    train_loss = evaluate_loss(model, train_loader)
    val_loss = evaluate_loss(model, val_loader)
    history.append((epoch + 1, train_loss, val_loss))

    # "Improvement" must beat the best by a small margin (1e-4) so that
    # floating-point noise does not keep resetting the patience counter.
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        # deepcopy: state_dict tensors alias the live weights, which keep
        # changing as training continues — snapshot them now.
        best_state = copy.deepcopy(model.state_dict())
        wait = 0
    else:
        wait += 1

    print(f"epoch={epoch+1:02d} train_loss={train_loss:.4f} val_loss={val_loss:.4f} wait={wait}")
    if wait >= patience:
        print(f"Stopping early at epoch {epoch+1}")
        break

# Restore the best checkpoint instead of keeping the last (possibly overfit)
# weights. best_state is always set: the first epoch's val_loss beats inf.
model.load_state_dict(best_state)

6.5. Test Performance

[ ]:
def accuracy(model, loader):
    """Return the fraction of examples in `loader` classified correctly.

    Applies a 0.5 probability threshold to the sigmoid of the model's
    logits; comparisons happen on CPU against the untouched labels.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for features, targets in loader:
            probs = torch.sigmoid(model(features.to(device))).cpu()
            predicted = (probs > 0.5).float()
            hits += (predicted == targets).sum().item()
            seen += targets.numel()
    return hits / seen

# Report the checkpoint that early stopping selected: its validation loss
# and its accuracy on the held-out test split.
print(f"best validation loss: {best_val:.4f}")
print(f"test accuracy: {accuracy(model, test_loader):.3f}")