3. Uncertainty

A confidently wrong prediction is often more damaging than one that is plainly uncertain, so this section focuses on simple uncertainty signals you can compute without redesigning the model.

[ ]:
import torch
from torch import nn

torch.manual_seed(67)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.4), nn.Linear(8, 3)).to(device)
x = torch.randn(5, 4, device=device)

3.1. Predictive entropy

Predictive entropy, H(p) = -Σ_c p_c log p_c, is larger when the class distribution is spread out and smaller when one class dominates.

[ ]:
model.eval()  # disable dropout for a standard deterministic pass
with torch.inference_mode():
    probabilities = torch.softmax(model(x), dim=-1)
# clamp_min guards against log(0) for near-zero probabilities
entropy = -(probabilities * probabilities.clamp_min(1e-8).log()).sum(dim=-1)
print(entropy)
assert entropy.shape == (5,)
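Raw entropy depends on the number of classes, which makes thresholds hard to reuse across tasks. One common convention (an assumption here, not part of the source setup) is to divide by log(K), the maximum entropy of a K-class distribution, to get a score in [0, 1]:

[ ]:
```python
import torch

torch.manual_seed(0)
probs = torch.softmax(torch.randn(5, 3), dim=-1)
entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)
# Entropy of a K-class distribution is at most log(K); dividing by it
# yields a normalized score comparable across different class counts.
max_entropy = torch.log(torch.tensor(float(probs.shape[-1])))
normalized = entropy / max_entropy
print(normalized)
```

A value near 1 then means "close to uniform" regardless of whether the model has 3 classes or 300.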

3.2. Monte Carlo dropout

Keep dropout active at inference time, run several forward passes, and measure prediction spread.

[ ]:
model.train()  # keep dropout active; note this also flips BatchNorm layers (none in this model)
samples = []
with torch.inference_mode():
    for _ in range(10):
        samples.append(torch.softmax(model(x), dim=-1))
samples = torch.stack(samples)  # shape: (num_passes, batch, classes)
mean_probability = samples.mean(dim=0)
variance = samples.var(dim=0)  # per-class spread across dropout masks
print(mean_probability.shape, variance.shape)
assert mean_probability.shape == variance.shape == (5, 3)
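The dropout samples also support a decomposition the source does not show: the entropy of the mean prediction (total uncertainty) minus the mean entropy of the individual samples gives the mutual information, often read as model (epistemic) uncertainty; this is the BALD-style split, sketched here under the same toy model:

[ ]:
```python
import torch
from torch import nn

torch.manual_seed(67)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.4), nn.Linear(8, 3))
x = torch.randn(5, 4)

model.train()  # keep dropout active for Monte Carlo sampling
with torch.inference_mode():
    samples = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(10)])

eps = 1e-8
mean_p = samples.mean(dim=0)
# Total uncertainty: entropy of the averaged predictive distribution.
total = -(mean_p * mean_p.clamp_min(eps).log()).sum(dim=-1)
# Aleatoric part: average entropy of each individual sampled distribution.
aleatoric = -(samples * samples.clamp_min(eps).log()).sum(dim=-1).mean(dim=0)
# Epistemic part (mutual information): what dropout disagreement adds on top.
epistemic = total - aleatoric
print(epistemic)
```

By Jensen's inequality the epistemic term is non-negative (up to the clamp tolerance); inputs where it is large are the ones on which the dropout sub-networks disagree most.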