6. Fine-Tuning Patterns
Most transfer learning work is not “train everything from scratch”. The real decisions are how much of a pretrained backbone to freeze, when to unfreeze it, and how aggressively to update the head relative to the feature extractor.
This notebook compares three common patterns with a small synthetic setup so the mechanics stay clear.
6.1. Setup
[ ]:
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
torch.manual_seed(17)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6.2. Data and a Pretend Pretrained Backbone
[ ]:
# labels depend on two linear terms plus one interaction, with label noise
x = torch.randn(1200, 10)
score = 1.5 * x[:, 0] - 0.7 * x[:, 1] + 0.4 * x[:, 2] * x[:, 3] + 0.4 * torch.randn(1200)
y = (score > 0).float().unsqueeze(1)
dataset = TensorDataset(x, y)
train_ds, test_ds = random_split(dataset, [900, 300], generator=torch.Generator().manual_seed(17))
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=128)
backbone = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())

# perturb the random init so the "pretrained" weights are a distinct, imperfect starting point
with torch.no_grad():
    for p in backbone.parameters():
        p.add_(0.1 * torch.randn_like(p))
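In a real project the backbone would come from a checkpoint or a model hub rather than perturbed random weights. A minimal sketch, assuming torchvision is available (nothing later in this notebook depends on it):
[ ]:
import torchvision

# a real pretrained backbone: resnet18 with its ImageNet classifier removed
resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
resnet.fc = nn.Identity()  # the model now outputs 512-dim features for a new head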
6.3. Helpers
[ ]:
class Model(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(16, 1)

    def forward(self, x):
        return self.head(self.backbone(x))

def evaluate(model):
    # accuracy on the held-out split
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in test_loader:
            preds = (torch.sigmoid(model(xb.to(device))).cpu() > 0.5).float()
            correct += (preds == yb).sum().item()
            total += yb.numel()
    return correct / total
def fit(model, optimizer, epochs):
    loss_fn = nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
    return evaluate(model)
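A quick sanity check for each pattern below is to count trainable versus total parameters. A small helper for that (the name count_trainable is ours, not part of the original helpers):
[ ]:
def count_trainable(model):
    # parameters with requires_grad=True are the ones the optimizer can update
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total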
6.4. Linear Probe
Freeze the backbone and only learn the final head.
[ ]:
linear_probe = Model(copy.deepcopy(backbone)).to(device)
for p in linear_probe.backbone.parameters():
    p.requires_grad = False
probe_acc = fit(linear_probe, torch.optim.Adam(linear_probe.head.parameters(), lr=1e-2), epochs=10)
print(f'linear probe accuracy: {probe_acc:.3f}')
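Passing only linear_probe.head.parameters() to Adam keeps the frozen backbone out of the optimizer state entirely. When the trainable subset is less tidy, filtering on requires_grad is the generic version of the same idea:
[ ]:
# equivalent optimizer construction that works for any freeze pattern
optimizer = torch.optim.Adam(
    (p for p in linear_probe.parameters() if p.requires_grad), lr=1e-2
)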
6.5. Partial Unfreeze
Train the head first, then unfreeze the last backbone layer with a smaller learning rate.
[ ]:
partial = Model(copy.deepcopy(backbone)).to(device)
for p in partial.backbone.parameters():
    p.requires_grad = False
fit(partial, torch.optim.Adam(partial.head.parameters(), lr=1e-2), epochs=5)
for name, p in partial.backbone.named_parameters():
    if name.startswith('2.'):  # module '2' is the last Linear in the Sequential backbone
        p.requires_grad = True
partial_acc = fit(
    partial,
    torch.optim.Adam([
        {'params': partial.head.parameters(), 'lr': 1e-2},
        {'params': [p for p in partial.backbone.parameters() if p.requires_grad], 'lr': 1e-3},
    ]),
    epochs=8,
)
print(f'partial unfreeze accuracy: {partial_acc:.3f}')
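Gradual unfreezing extends this pattern: thaw one backbone block per phase, deepest first, usually lowering the backbone learning rate as earlier layers open up. A sketch of that schedule, reusing Model and fit from above (the phase list is illustrative, not tuned):
[ ]:
gradual = Model(copy.deepcopy(backbone)).to(device)
for p in gradual.backbone.parameters():
    p.requires_grad = False

# phases: (backbone prefix to thaw, backbone lr, epochs); None means head-only
for prefix, backbone_lr, epochs in [(None, None, 5), ('2.', 1e-3, 4), ('0.', 3e-4, 4)]:
    if prefix is not None:
        for name, p in gradual.backbone.named_parameters():
            if name.startswith(prefix):
                p.requires_grad = True
    groups = [{'params': gradual.head.parameters(), 'lr': 1e-2}]
    thawed = [p for p in gradual.backbone.parameters() if p.requires_grad]
    if thawed:
        # all thawed layers share the current phase's backbone lr, for simplicity
        groups.append({'params': thawed, 'lr': backbone_lr})
    gradual_acc = fit(gradual, torch.optim.Adam(groups), epochs=epochs)
print(f'gradual unfreeze accuracy: {gradual_acc:.3f}')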
6.6. Full Fine-Tune
Update every parameter. In real projects this usually calls for more care, in particular a smaller learning rate for the backbone than for the freshly initialized head.
[ ]:
full = Model(copy.deepcopy(backbone)).to(device)
full_acc = fit(full, torch.optim.Adam(full.parameters(), lr=3e-3), epochs=10)
print(f'full fine-tune accuracy: {full_acc:.3f}')
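The smaller-backbone-learning-rate variant mentioned above is one parameter-group change away (full_disc and disc_acc are our names for this variant):
[ ]:
full_disc = Model(copy.deepcopy(backbone)).to(device)
disc_acc = fit(
    full_disc,
    torch.optim.Adam([
        {'params': full_disc.head.parameters(), 'lr': 3e-3},
        {'params': full_disc.backbone.parameters(), 'lr': 3e-4},
    ]),
    epochs=10,
)
print(f'full fine-tune (lower backbone lr) accuracy: {disc_acc:.3f}')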
6.7. Comparison
[ ]:
{
'linear_probe': round(probe_acc, 3),
'partial_unfreeze': round(partial_acc, 3),
'full_fine_tune': round(full_acc, 3),
}
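The accuracy numbers alone do not show how different the three regimes are; counting trainable parameters with the count_trainable helper from above makes that explicit:
[ ]:
for name, model in [('linear_probe', linear_probe), ('partial', partial), ('full', full)]:
    trainable, total = count_trainable(model)
    print(f'{name}: {trainable}/{total} trainable parameters')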