7. Class Imbalance
This matters because a model can look accurate while failing on the class you actually care about. Imbalanced data breaks naive metrics and often needs a different loss or sampling strategy before training results mean anything.
This notebook uses a small synthetic example so the failure mode is obvious and fast to reproduce.
7.1. Setup
[ ]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
# Fixed seed so the synthetic data and weight init are reproducible run-to-run.
torch.manual_seed(11)
# Use the GPU when one is available; everything below moves tensors via .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7.2. Create an Imbalanced Dataset
[ ]:
# Two-dimensional toy dataset with a deliberate 95/5 class imbalance:
# 950 majority (negative) points centred at the origin, 50 minority
# (positive) points centred at (2, 2).
n_negative = 950
n_positive = 50
cluster_neg = torch.randn(n_negative, 2) + torch.tensor([0.0, 0.0])
cluster_pos = torch.randn(n_positive, 2) + torch.tensor([2.0, 2.0])
x = torch.cat([cluster_neg, cluster_pos], dim=0)
# Targets as a (N, 1) float column, matching the single-logit model output.
y = torch.cat([torch.zeros(n_negative), torch.ones(n_positive)]).unsqueeze(1)
# Shuffle so the sequential train/test split below mixes both classes.
shuffle_idx = torch.randperm(len(x))
x, y = x[shuffle_idx], y[shuffle_idx]
split = 800
train_ds = TensorDataset(x[:split], y[:split])
test_ds = TensorDataset(x[split:], y[split:])
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)
# Fraction of positives in the training split -- expected to be near 0.05.
positive_rate = y[:split].mean().item()
positive_rate
7.3. Helper Functions
[ ]:
def make_model():
    """Build a fresh 2 -> 16 -> 1 MLP (single raw logit out) on `device`."""
    layers = [nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1)]
    return nn.Sequential(*layers).to(device)
def train_model(loss_fn, epochs=20):
    """Train a freshly initialised model on `train_loader`.

    Args:
        loss_fn: criterion applied to (raw logits, float targets).
        epochs: number of full passes over the training loader.

    Returns:
        The trained model (left in train mode).
    """
    net = make_model()
    opt = torch.optim.Adam(net.parameters(), lr=1e-2)
    for _ in range(epochs):
        net.train()
        for features, targets in train_loader:
            features = features.to(device)
            targets = targets.to(device)
            opt.zero_grad()
            batch_loss = loss_fn(net(features), targets)
            batch_loss.backward()
            opt.step()
    return net
def confusion_counts(model, loader):
    """Return (tp, tn, fp, fn) for `model` over `loader` at threshold 0.5."""
    model.eval()
    tp = tn = fp = fn = 0
    with torch.no_grad():
        for features, labels in loader:
            # Sigmoid turns the raw logit into a probability; 0.5 is the
            # default decision threshold.
            probs = torch.sigmoid(model(features.to(device))).cpu()
            predicted = (probs > 0.5).float()
            pred_pos = predicted == 1
            pred_neg = predicted == 0
            tp += (pred_pos & (labels == 1)).sum().item()
            tn += (pred_neg & (labels == 0)).sum().item()
            fp += (pred_pos & (labels == 0)).sum().item()
            fn += (pred_neg & (labels == 1)).sum().item()
    return tp, tn, fp, fn
def summarize(name, model):
    """Print `name` and accuracy/precision/recall plus raw confusion counts
    for `model` on the held-out test set.

    The max(..., 1) denominators avoid division by zero when a count is empty.
    """
    tp, tn, fp, fn = confusion_counts(model, test_loader)
    total = max(tp + tn + fp + fn, 1)
    metrics = {
        'accuracy': round((tp + tn) / total, 3),
        'precision': round(tp / max(tp + fp, 1), 3),
        'recall': round(tp / max(tp + fn, 1), 3),
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn,
    }
    print(name)
    print(metrics)
7.4. Naive Training
[ ]:
# Train with the plain, unweighted BCE loss. Given the 95/5 imbalance built
# above, expect high accuracy driven by the majority class while recall on
# the rare positive class suffers.
baseline = train_model(nn.BCEWithLogitsLoss())
summarize('unweighted loss', baseline)
7.5. Weighted Loss
pos_weight scales the loss on positive examples, so misclassifying a rare positive costs more than misclassifying a common negative.
[ ]:
# Count each class in the training split to derive the loss weighting.
neg = (y[:split] == 0).sum().item()
pos = (y[:split] == 1).sum().item()
# Weight positive examples by the negative-to-positive ratio of the training
# split; max(pos, 1) guards against division by zero if no positives landed
# in the split.
pos_weight = torch.tensor([neg / max(pos, 1)], device=device)
weighted = train_model(nn.BCEWithLogitsLoss(pos_weight=pos_weight))
summarize('weighted loss', weighted)
print(f'pos_weight={pos_weight.item():.2f}')