6. Fine-Tuning Patterns

This matters because most transfer learning work is not “train everything from scratch”. The real decision is how much of a pretrained backbone to freeze, when to unfreeze it, and how aggressively to update the head versus the feature extractor.

This notebook compares three common patterns with a small synthetic setup so the mechanics stay clear.

6.1. Setup

[ ]:
# Imports, a fixed RNG seed so the notebook is reproducible, and device
# selection (GPU when available, CPU otherwise).
import copy
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(17)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

6.2. Data and a Pretend Pretrained Backbone

[ ]:
# Synthetic binary-classification data: 10 features; the label is driven by
# two linear terms, one interaction term (x2 * x3), and Gaussian noise.
# NOTE: statement order matters here — each torch.randn draw advances the
# global RNG stream seeded above, so reordering changes the dataset.
x = torch.randn(1200, 10)
score = 1.5 * x[:, 0] - 0.7 * x[:, 1] + 0.4 * x[:, 2] * x[:, 3] + 0.4 * torch.randn(1200)
y = (score > 0).float().unsqueeze(1)  # threshold at 0 -> roughly balanced classes
dataset = TensorDataset(x, y)
# Dedicated generator with a fixed seed so the 900/300 split is reproducible
# independently of how much of the global RNG stream has been consumed.
train_ds, test_ds = random_split(dataset, [900, 300], generator=torch.Generator().manual_seed(17))
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=128)

# "Pretend pretrained" backbone: a small MLP whose freshly initialized weights
# are perturbed in place with noise so it stands in for a real checkpoint.
backbone = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())
with torch.no_grad():
    for p in backbone.parameters():
        p.add_(0.1 * torch.randn_like(p))

6.3. Helpers

[ ]:
class Model(nn.Module):
    """A pretrained feature extractor followed by a fresh 1-logit head.

    The backbone is taken as-is (expected to emit 16-dim features); the head
    is a newly initialized linear layer producing a single logit.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(16, 1)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

def evaluate(model):
    """Return test-set accuracy for a binary classifier that outputs logits."""
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for features, labels in test_loader:
            logits = model(features.to(device))
            # Sigmoid + 0.5 threshold turns logits into hard 0/1 predictions.
            predicted = (torch.sigmoid(logits).cpu() > 0.5).float()
            hits += (predicted == labels).sum().item()
            seen += labels.numel()
    return hits / seen

def fit(model, optimizer, epochs):
    """Train with BCE-with-logits for `epochs` passes, then return test accuracy."""
    criterion = nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        model.train()
        for features, labels in train_loader:
            features = features.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            criterion(model(features), labels).backward()
            optimizer.step()
    return evaluate(model)

6.4. Linear Probe

Freeze the backbone and only learn the final head.

[ ]:
# Linear probe: clone the backbone, freeze every one of its parameters, and
# learn only the final head on top of the frozen features.
linear_probe = Model(copy.deepcopy(backbone)).to(device)
for param in linear_probe.backbone.parameters():
    param.requires_grad = False
head_optimizer = torch.optim.Adam(linear_probe.head.parameters(), lr=1e-2)
probe_acc = fit(linear_probe, head_optimizer, epochs=10)
print(f'linear probe accuracy: {probe_acc:.3f}')

6.5. Partial Unfreeze

Train the head first, then unfreeze the last backbone layer with a smaller learning rate.

[ ]:
partial = Model(copy.deepcopy(backbone)).to(device)

# Stage 1: freeze the entire backbone and adapt only the head.
for p in partial.backbone.parameters():
    p.requires_grad = False
fit(partial, torch.optim.Adam(partial.head.parameters(), lr=1e-2), epochs=5)

# Stage 2: unfreeze only the last backbone Linear (index 2 in the Sequential).
# Address the submodule directly rather than substring-matching parameter
# names ("'2' in name"), which would also match indices 12, 20, ... in a
# deeper backbone and silently unfreeze the wrong layers.
for p in partial.backbone[2].parameters():
    p.requires_grad = True

partial_acc = fit(
    partial,
    # Discriminative learning rates: the head keeps moving fast while the
    # newly thawed backbone layer is updated an order of magnitude more gently.
    torch.optim.Adam([
        {'params': partial.head.parameters(), 'lr': 1e-2},
        {'params': [p for p in partial.backbone.parameters() if p.requires_grad], 'lr': 1e-3},
    ]),
    epochs=8,
)
print(f'partial unfreeze accuracy: {partial_acc:.3f}')

6.6. Full Fine-Tune

Update every parameter. In real projects this is usually done with extra care — most often by giving the backbone a smaller learning rate than the head.

[ ]:
# Full fine-tune: every parameter (backbone and head) is trainable and shares
# a single moderate learning rate.
full = Model(copy.deepcopy(backbone)).to(device)
full_optimizer = torch.optim.Adam(full.parameters(), lr=3e-3)
full_acc = fit(full, full_optimizer, epochs=10)
print(f'full fine-tune accuracy: {full_acc:.3f}')

6.7. Comparison

[ ]:
# Side-by-side accuracies; the dict is the cell's last expression, so the
# notebook renders it as output.
dict(
    linear_probe=round(probe_acc, 3),
    partial_unfreeze=round(partial_acc, 3),
    full_fine_tune=round(full_acc, 3),
)