6. Early Stopping
This matters because training longer is not the same as training better. Early stopping gives you a simple rule for quitting when validation performance stops improving, which saves time and usually produces a model that generalizes better than the last checkpoint.
This notebook keeps the model and dataset small on purpose. The goal is to understand the stopping logic, not to chase a benchmark.
6.1. Setup
[ ]:
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
# Seed the CPU RNG so the synthetic data and the weight init are reproducible.
torch.manual_seed(7)
# Prefer the GPU when present; tensors and the model are moved explicitly later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6.2. Data
We build a small noisy binary classification problem so validation loss has room to flatten out and then get worse.
[ ]:
n_samples = 1200
# 8 standard-normal features per sample.
x = torch.randn(n_samples, 8)
# True decision function: a mix of linear, squared, and interaction terms,
# so a linear model cannot fit it perfectly and an MLP can overfit.
logits = 1.2 * x[:, 0] - 0.8 * x[:, 1] + 0.5 * x[:, 2] ** 2 - 0.3 * x[:, 3] * x[:, 4]
# Additive label noise ensures validation loss eventually degrades (irreducible error).
logits = logits + 0.8 * torch.randn(n_samples)
# Binary targets shaped (n_samples, 1) to match the model's single output logit.
y = (logits > 0).float().unsqueeze(1)
dataset = TensorDataset(x, y)
train_size = 800
val_size = 200
test_size = len(dataset) - train_size - val_size  # remaining 200 samples
# Seeded generator keeps the train/val/test split identical across runs.
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(7))
# Only the training loader shuffles; eval loaders use larger batches since no grads are kept.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=128)
test_loader = DataLoader(test_ds, batch_size=128)
len(train_ds), len(val_ds), len(test_ds)
6.3. Model
[ ]:
# Small 2-hidden-layer MLP: 8 features -> 32 -> 32 -> 1 logit.
# Deliberately over-parameterized for this dataset so overfitting shows up.
model = nn.Sequential(
nn.Linear(8, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 1),  # raw logit; sigmoid is applied by the loss / at prediction time
).to(device)
# BCEWithLogitsLoss fuses sigmoid + BCE for numerical stability.
criterion = nn.BCEWithLogitsLoss()
# Fairly high learning rate speeds up the demo; fine for a toy problem.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
6.4. Training With Patience
We stop if validation loss fails to improve for a few epochs in a row and keep the best checkpoint instead of the last one.
[ ]:
def evaluate_loss(model, loader):
    """Return the sample-weighted mean loss of `model` over `loader`.

    Uses the module-level `criterion` and `device`. Puts the model in eval
    mode and disables gradient tracking for the whole pass. Weighting by
    batch size makes the result exact even when the last batch is smaller.
    """
    model.eval()
    running, seen = 0.0, 0
    with torch.no_grad():
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)
            batch = len(features)
            running += criterion(model(features), targets).item() * batch
            seen += batch
    return running / seen
# Early-stopping state: track the best checkpoint seen so far and how many
# consecutive epochs have passed without a meaningful improvement.
best_state = None
best_val = float('inf')
patience = 4  # consecutive non-improving epochs tolerated before stopping
wait = 0
history = []  # (epoch, train_loss, val_loss) per completed epoch
for epoch in range(40):
    # One pass of SGD over the training set.
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
    # Re-evaluate both splits with the model frozen in eval mode.
    train_loss = evaluate_loss(model, train_loader)
    val_loss = evaluate_loss(model, val_loader)
    history.append((epoch + 1, train_loss, val_loss))
    # Require a minimum delta (1e-4) so tiny noise does not reset patience.
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        # Deep copy: state_dict tensors alias live parameters, so a plain
        # reference would be silently overwritten by later training steps.
        best_state = copy.deepcopy(model.state_dict())
        wait = 0
    else:
        wait += 1
    print(f"epoch={epoch+1:02d} train_loss={train_loss:.4f} val_loss={val_loss:.4f} wait={wait}")
    if wait >= patience:
        print(f"Stopping early at epoch {epoch+1}")
        break
# Restore the best checkpoint rather than keeping the last epoch's weights.
# Guard against the corner case where no epoch ever improved on inf (e.g. a
# NaN loss from the first epoch on), which would leave best_state as None
# and make load_state_dict raise.
if best_state is not None:
    model.load_state_dict(best_state)
6.5. Test Performance
[ ]:
def accuracy(model, loader):
    """Fraction of examples in `loader` classified correctly at a 0.5 threshold.

    Inputs are moved to the module-level `device`; sigmoid probabilities are
    brought back to CPU so they can be compared against the (CPU) labels.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for features, labels in loader:
            probs = torch.sigmoid(model(features.to(device))).cpu()
            predicted = (probs > 0.5).float()
            hits += (predicted == labels).sum().item()
            seen += labels.numel()
    return hits / seen
# Report the restored (best) checkpoint's validation loss and its held-out accuracy.
print(f"best validation loss: {best_val:.4f}")
print(f"test accuracy: {accuracy(model, test_loader):.3f}")