3. Uncertainty
Uncertainty matters because a confidently wrong prediction is often more damaging than an openly uncertain one. This section focuses on simple uncertainty signals you can compute without redesigning the model.
[ ]:
import torch
from torch import nn
torch.manual_seed(67)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.4), nn.Linear(8, 3)).to(device)
x = torch.randn(5, 4, device=device)
3.1. Predictive entropy
Entropy is larger when the class distribution is spread out and smaller when one class dominates.
[ ]:
model.eval()  # disable dropout for a single deterministic pass
with torch.inference_mode():
    probabilities = torch.softmax(model(x), dim=-1)
# Clamp before log to avoid -inf on zero probabilities.
entropy = -(probabilities * probabilities.clamp_min(1e-8).log()).sum(dim=-1)
print(entropy)
assert entropy.shape == (5,)
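To make the spread-out-versus-dominant intuition concrete, here is a minimal check (the helper name is illustrative, not from the original) comparing a peaked distribution against a uniform one. Dividing by log K normalizes entropy into [0, 1], so a uniform distribution over any number of classes scores close to 1:

```python
import torch

peaked = torch.tensor([[0.98, 0.01, 0.01]])  # one class dominates
uniform = torch.tensor([[1/3, 1/3, 1/3]])    # maximally spread out

def normalized_entropy(p: torch.Tensor) -> torch.Tensor:
    # Entropy in nats, divided by log(K) so 0 = certain, 1 = uniform.
    h = -(p * p.clamp_min(1e-8).log()).sum(dim=-1)
    return h / torch.log(torch.tensor(float(p.shape[-1])))

print(normalized_entropy(peaked), normalized_entropy(uniform))
```

The normalized form is handy when you compare uncertainty across heads or datasets with different class counts.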
3.2. Monte Carlo dropout
Keep dropout active at inference time, run several forward passes, and measure prediction spread.
[ ]:
model.train()  # keep dropout active so each pass samples a different mask
samples = []
with torch.inference_mode():
    for _ in range(10):
        samples.append(torch.softmax(model(x), dim=-1))
samples = torch.stack(samples)          # (passes, batch, classes)
mean_probability = samples.mean(dim=0)  # averaged prediction
variance = samples.var(dim=0)           # spread across passes signals uncertainty
print(mean_probability.shape, variance.shape)
assert mean_probability.shape == variance.shape == (5, 3)
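The stacked samples also support a common decomposition (a sketch beyond the original cell, using random probabilities as a stand-in for MC dropout output): the entropy of the mean prediction is total uncertainty, the mean entropy of the individual passes approximates the aleatoric part, and their difference, the mutual information, captures the epistemic part due to disagreement between passes:

```python
import torch

torch.manual_seed(0)
# Stand-in for MC dropout output with shape (passes, batch, classes).
samples = torch.softmax(torch.randn(10, 5, 3), dim=-1)
mean_probability = samples.mean(dim=0)

# Total uncertainty: entropy of the averaged prediction.
total = -(mean_probability * mean_probability.clamp_min(1e-8).log()).sum(dim=-1)

# Aleatoric part: average entropy of the individual passes.
aleatoric = -(samples * samples.clamp_min(1e-8).log()).sum(dim=-1).mean(dim=0)

# Epistemic part (mutual information): what disagreement between passes adds.
# Non-negative by Jensen's inequality, up to floating-point error.
epistemic = total - aleatoric
print(total, epistemic)
```

High epistemic uncertainty flags inputs the model could resolve with more data, while high aleatoric uncertainty flags inputs that are inherently ambiguous.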