4. Model Zoo
Transfer learning starts with understanding what the library already gives you. This section covers discovering architectures, swapping heads, freezing the right layers, and reading weights metadata.
4.1. Discover torchvision models
The Docker checker avoids network downloads, so this notebook builds architectures with weights=None. For real transfer learning, pass an official weights enum such as ResNet18_Weights.DEFAULT.
[ ]:
import torch
from torch import nn
from torchvision import models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
available = models.list_models(module=models)
print('number of torchvision model builders:', len(available))
print('sample:', available[:10])
assert 'resnet18' in available
4.2. Replace the classifier head
Most model-zoo classifiers end with a task-specific linear layer. Replace that layer when the number of classes changes.
[ ]:
model = models.resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 3)
model = model.to(device)
images = torch.randn(2, 3, 224, 224, device=device)
with torch.inference_mode():
    logits = model(images)
print(logits.shape)
assert logits.shape == (2, 3)
4.3. Freeze the backbone
Freeze the feature extractor when you only want to train the new classifier head. Always print the trainable parameter count after changing requires_grad.
[ ]:
for parameter in model.parameters():
    parameter.requires_grad = False
for parameter in model.fc.parameters():
    parameter.requires_grad = True
trainable = sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
frozen = sum(parameter.numel() for parameter in model.parameters() if not parameter.requires_grad)
print('trainable parameters:', trainable)
print('frozen parameters:', frozen)
assert trainable == sum(parameter.numel() for parameter in model.fc.parameters())
assert frozen > trainable
4.4. Weight metadata
Weights enums document preprocessing, categories, metrics, and recipes. Accessing the enum does not download weights; passing it to a model builder downloads them into the torch hub cache under TORCH_HOME (by default ~/.cache/torch) unless they are already cached.
[ ]:
weights = models.ResNet18_Weights.DEFAULT
print(weights)
print('categories:', weights.meta['categories'][:5])
print('recipe:', weights.meta.get('recipe'))
assert callable(weights.transforms)