5. Testing PyTorch Code

This matters because untested tensor code tends to fail late and opaquely: a wrong shape or a missing gradient often surfaces only mid-training. Focus on tiny checks that lock down shapes, gradients, checkpoint reloads, and numerical sanity.

5.1. Tiny forward and backward checks

A minimal test can assert output shape, finite loss, and the presence of gradients.

[ ]:
import torch
from torch import nn

torch.manual_seed(71)
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(6, 5)
y = torch.randint(0, 2, (6,))
criterion = nn.CrossEntropyLoss()

logits = model(x)
loss = criterion(logits, y)
loss.backward()
print(logits.shape, loss.item())
assert logits.shape == (6, 2)
assert torch.isfinite(loss)
assert all(parameter.grad is not None for parameter in model.parameters())
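A numerical-sanity check in the same spirit verifies that optimization actually moves the loss: a few gradient steps on one fixed batch should reduce it. The sketch below uses plain SGD; the learning rate and step count are illustrative assumptions, not tuned values.

```python
import torch
from torch import nn

# Sanity check: repeated steps on one fixed batch should reduce the loss.
torch.manual_seed(71)
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(6, 5)
y = torch.randint(0, 2, (6,))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # illustrative values

initial_loss = criterion(model(x), y).item()
for _ in range(25):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
final_loss = criterion(model(x), y).item()
print(initial_loss, final_loss)
assert final_loss < initial_loss  # the model can overfit a tiny batch
```

If the loss does not drop on a batch this small, something upstream (a detached graph, a frozen parameter, a mislabeled target) is usually broken.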

5.2. Checkpoint reload test

A reload test verifies that the state-dict round trip works before a long training run depends on it.

[ ]:
import os
os.makedirs('./output', exist_ok=True)  # torch.save does not create directories
path = './output/testing-state-dict.pt'
torch.save(model.state_dict(), path)
reloaded = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2))
reloaded.load_state_dict(torch.load(path, map_location='cpu'))
with torch.inference_mode():
    original = model(x)
    copied = reloaded(x)
assert torch.allclose(original, copied)
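The round trip above checks the happy path; it is also worth checking that a mismatched architecture fails loudly instead of loading garbage. The sketch below is self-contained (it writes a state dict to a temporary file) and asserts that `load_state_dict` raises on a size mismatch.

```python
import os
import tempfile
import torch
from torch import nn

# Save weights from a model with hidden width 8.
torch.manual_seed(71)
source = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 2))
path = os.path.join(tempfile.mkdtemp(), 'state.pt')
torch.save(source.state_dict(), path)

# Same layer names but hidden width 4: loading must raise a RuntimeError.
wrong = nn.Sequential(nn.Linear(5, 4), nn.ReLU(), nn.Linear(4, 2))
try:
    wrong.load_state_dict(torch.load(path, map_location='cpu'))
    mismatch_caught = False
except RuntimeError:
    mismatch_caught = True
print(mismatch_caught)
assert mismatch_caught
```

By default `load_state_dict` is strict about keys and shapes; this test pins that behavior down so a later refactor of the architecture cannot silently load stale checkpoints.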

5.3. What to test in real projects

Add unit tests for data collation, model heads, custom losses, metric calculations, checkpoint metadata, and any shape assumptions that would otherwise fail late.
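As one concrete example of a collation test, the sketch below checks a padding helper on variable-length sequences. The `pad_batch` function is hypothetical, standing in for whatever collate function a real project defines; the test locks down the output shape and the padding positions.

```python
import torch

def pad_batch(sequences, pad_value=0):
    """Hypothetical collate helper: right-pad 1-D tensors to equal length."""
    max_len = max(seq.shape[0] for seq in sequences)
    batch = torch.full((len(sequences), max_len), pad_value,
                       dtype=sequences[0].dtype)
    for i, seq in enumerate(sequences):
        batch[i, :seq.shape[0]] = seq
    return batch

def test_pad_batch_shapes():
    batch = pad_batch([torch.tensor([1, 2, 3]), torch.tensor([4])])
    assert batch.shape == (2, 3)          # padded to the longest sequence
    assert batch[1, 1:].eq(0).all()       # padding filled with pad_value

test_pad_batch_shapes()
print('ok')
```

Tests like this are cheap to run under pytest on every commit, and they fail at test time rather than hours into a training run.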