7. Class Imbalance

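Class imbalance matters because a model can look accurate while failing on the class you actually care about. For example, if only 5% of examples are positive, a model that always predicts "negative" scores 95% accuracy yet never finds a single positive. Imbalanced data therefore breaks naive metrics such as accuracy, and often needs a different loss or sampling strategy before training results mean anything.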

This notebook uses a small synthetic example so the failure mode is obvious and fast to reproduce.

7.1. Setup

[ ]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Seed torch's global RNG so reruns draw the same synthetic data and random values.
torch.manual_seed(11)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

7.2. Create an Imbalanced Dataset

[ ]:
n_major = 950
n_minor = 50
major = torch.randn(n_major, 2) + torch.tensor([0.0, 0.0])
minor = torch.randn(n_minor, 2) + torch.tensor([2.0, 2.0])

x = torch.cat([major, minor], dim=0)
y = torch.cat([torch.zeros(n_major), torch.ones(n_minor)]).unsqueeze(1)

# Stratified split. A purely random 80/20 split of 1000 rows leaves only ~10
# positives in the test set on average (sometimes far fewer), which makes the
# test precision/recall swing wildly between runs. Splitting each class
# separately keeps the positive rate identical in train and test.
train_frac = 0.8
train_parts, test_parts = [], []
for label in (0.0, 1.0):
    idx = (y[:, 0] == label).nonzero().squeeze(1)
    idx = idx[torch.randperm(len(idx))]
    n_train = round(train_frac * len(idx))
    train_parts.append(idx[:n_train])
    test_parts.append(idx[n_train:])
train_idx = torch.cat(train_parts)
train_idx = train_idx[torch.randperm(len(train_idx))]  # mix classes within train
test_idx = torch.cat(test_parts)

# Reorder so that x[:split] / y[:split] is exactly the training set, which the
# pos_weight computation later in the notebook relies on.
order = torch.cat([train_idx, test_idx])
x, y = x[order], y[order]
split = len(train_idx)

train_ds = TensorDataset(x[:split], y[:split])
test_ds = TensorDataset(x[split:], y[split:])
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256)

positive_rate = y[:split].mean().item()
positive_rate  # shown as the cell output in a notebook

7.3. Helper Functions

[ ]:
def make_model():
    """Build a fresh 2 -> 16 -> 1 MLP that emits one raw logit per sample.

    The returned network is moved to the module-level ``device``.
    """
    layers = [
        nn.Linear(2, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ]
    net = nn.Sequential(*layers)
    return net.to(device)

def train_model(loss_fn, epochs=20, *, lr=1e-2, loader=None):
    """Train a freshly initialised model from ``make_model()`` and return it.

    Args:
        loss_fn: Criterion called as ``loss_fn(logits, targets)``.
        epochs: Number of full passes over ``loader``.
        lr: Adam learning rate (defaults to the notebook's original 1e-2).
        loader: Iterable of ``(xb, yb)`` batches; defaults to the module-level
            ``train_loader``. Passing a different loader (e.g. a resampled
            one) lets sampling strategies be compared with loss weighting.

    Returns:
        The trained model, left in training mode.
    """
    if loader is None:
        loader = train_loader
    model = make_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()  # nothing inside the loop changes the mode, so set it once
    for _ in range(epochs):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
    return model

def confusion_counts(model, loader, threshold=0.5):
    """Count binary confusion-matrix cells for ``model`` over ``loader``.

    A sample is predicted positive when ``sigmoid(model(x)) > threshold``.
    Lowering ``threshold`` is a common alternative to loss reweighting for
    imbalanced data. The default 0.5 reproduces the original behaviour.
    Assumes the model emits one logit per sample, shaped like the targets
    (TODO confirm for models other than ``make_model``).

    Note: puts ``model`` into eval mode and leaves it there.

    Args:
        model: Module mapping an input batch to logits.
        loader: Iterable of ``(xb, yb)`` batches with 0/1 labels.
        threshold: Probability cut-off for predicting the positive class.

    Returns:
        ``(tp, tn, fp, fn)`` as Python ints.
    """
    model.eval()
    tp = tn = fp = fn = 0
    with torch.no_grad():
        for xb, yb in loader:
            probs = torch.sigmoid(model(xb.to(device))).cpu()
            pred_pos = probs > threshold
            pred_neg = ~pred_pos
            true_pos = yb == 1
            true_neg = yb == 0
            tp += (pred_pos & true_pos).sum().item()
            tn += (pred_neg & true_neg).sum().item()
            fp += (pred_pos & true_neg).sum().item()
            fn += (pred_neg & true_pos).sum().item()
    return tp, tn, fp, fn

def summarize(name, model, loader=None):
    """Print and return classification metrics for ``model``.

    Args:
        name: Label printed above the metrics.
        model: Model passed to ``confusion_counts``.
        loader: Evaluation batches; defaults to the module-level ``test_loader``.

    Returns:
        Dict with rounded ``accuracy``, ``precision``, ``recall`` and ``f1``
        plus the raw ``tp``/``fp``/``fn``/``tn`` counts. This is the same dict
        that is printed.
    """
    if loader is None:
        loader = test_loader
    tp, tn, fp, fn = confusion_counts(model, loader)
    # max(..., 1): a metric whose denominator is empty is reported as 0
    # instead of raising ZeroDivisionError.
    acc = (tp + tn) / max(tp + tn + fp + fn, 1)
    recall = tp / max(tp + fn, 1)
    precision = tp / max(tp + fp, 1)
    # F1 written in count form; equals the harmonic mean of precision/recall.
    f1 = 2 * tp / max(2 * tp + fp + fn, 1)
    metrics = {
        'accuracy': round(acc, 3),
        'precision': round(precision, 3),
        'recall': round(recall, 3),
        'f1': round(f1, 3),
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'tn': tn,
    }
    print(name)
    print(metrics)
    return metrics

7.4. Naive Training

[ ]:
# Baseline: unweighted BCE treats every example equally, so the far more
# numerous negatives dominate the loss.
plain_bce = nn.BCEWithLogitsLoss()
baseline = train_model(plain_bce)
summarize('unweighted loss', baseline)

7.5. Weighted Loss

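`pos_weight` multiplies the loss term for positive examples, so misclassifying a rare positive costs more than misclassifying a common negative. Setting it to the ratio of negatives to positives in the training set makes both classes contribute roughly equally to the total loss.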

[ ]:
# Training labels (train_ds stores y[:split] as its second tensor).
train_labels = train_ds.tensors[1]
neg = int((train_labels == 0).sum().item())
pos = int((train_labels == 1).sum().item())
# Up-weight positives by #neg / #pos; max(pos, 1) avoids dividing by zero
# if the training labels contained no positives.
ratio = neg / max(pos, 1)
pos_weight = torch.tensor([ratio], device=device)
weighted_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
weighted = train_model(weighted_bce)
summarize('weighted loss', weighted)
print(f'pos_weight={pos_weight.item():.2f}')