4. Model Zoo

Transfer learning starts with knowing what the library already provides. This section covers discovering architectures, swapping classifier heads, freezing the right layers, and reading weights metadata.

4.1. Discover torchvision models

The Docker checker avoids network downloads, so this notebook builds architectures with weights=None, which leaves the parameters randomly initialized. For real transfer learning, pass an official weights enum such as ResNet18_Weights.DEFAULT.

[ ]:
import torch
from torch import nn
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

available = models.list_models(module=models)
print('number of torchvision model builders:', len(available))
print('sample:', available[:10])
assert 'resnet18' in available

4.2. Replace the classifier head

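Most model-zoo classifiers end with a task-specific linear layer. Replace that layer when the number of classes changes. On ResNet the layer is model.fc; other families expose it under a different attribute, such as classifier.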

[ ]:
model = models.resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, 3)
model = model.to(device)

images = torch.randn(2, 3, 224, 224, device=device)
model.eval()  # inference behavior: BatchNorm uses its running statistics
with torch.inference_mode():
    logits = model(images)
print(logits.shape)
assert logits.shape == (2, 3)

4.3. Freeze the backbone

Freeze the feature extractor when you only want to train the new classifier head. Print the trainable parameter count after changing requires_grad to confirm that only the head will be updated.

[ ]:
# Freeze every parameter first...
for parameter in model.parameters():
    parameter.requires_grad = False

# ...then re-enable gradients for the new head only.
for parameter in model.fc.parameters():
    parameter.requires_grad = True

trainable = sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
frozen = sum(parameter.numel() for parameter in model.parameters() if not parameter.requires_grad)
print('trainable parameters:', trainable)
print('frozen parameters:', frozen)
assert trainable == sum(parameter.numel() for parameter in model.fc.parameters())
assert frozen > trainable

4.4. Weight metadata

Weights enums document preprocessing, categories, metrics, and training recipes. Accessing the enum does not download weights. Passing it to a model builder downloads the checkpoint into the cache under TORCH_HOME unless it is already cached.

[ ]:
weights = models.ResNet18_Weights.DEFAULT
print(weights)
print('categories:', weights.meta['categories'][:5])
print('recipe:', weights.meta.get('recipe'))
assert callable(weights.transforms)