10. Multi-Task Learning

This matters because many production models share one backbone across several objectives. Focus on how losses are combined and on making the weighting explicit rather than accidental.

[ ]:
import os

import torch
from torch import nn

# Seed torch's RNGs so the synthetic batch (and the model weights created
# later) are reproducible from run to run.
torch.manual_seed(43)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# When this env var is "1" the notebook is being run by an automated check and
# later cells shorten their work accordingly.
check_mode = os.environ.get("PYTORCH_INTRO_CHECK_MODE") == "1"

# One synthetic batch: 64 rows of 6 random features, shared by both tasks.
x = torch.randn(64, 6, device=device)
# Classification target: class 1 when feature 0 + feature 1 is positive, else 0.
class_target = torch.gt(x[:, 0] + x[:, 1], 0).long()
# Regression target: feature 2 - feature 3, shaped as a (64, 1) column to
# match a single-output regression head.
regression_target = (x[:, 2] - x[:, 3]).reshape(-1, 1)

[ ]:
class MultiTaskModel(nn.Module):
    """Shared MLP backbone feeding a classification head and a regression head.

    Both heads read the same feature vector. Gradients from both task losses
    therefore flow into the same backbone parameters.

    Args:
        in_features: Width of each input row. Defaults to 6.
        hidden_features: Width of the shared feature vector produced by the
            backbone. Defaults to 16.
        num_classes: Number of logits from the classification head.
            Defaults to 2.
        regression_outputs: Number of values from the regression head.
            Defaults to 1.
    """

    def __init__(
        self,
        *,
        in_features: int = 6,
        hidden_features: int = 16,
        num_classes: int = 2,
        regression_outputs: int = 1,
    ):
        super().__init__()
        # Layers are created in the original order (backbone, class head,
        # regression head), so a fixed seed still yields identical weights.
        self.backbone = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
        )
        self.class_head = nn.Linear(hidden_features, num_classes)
        self.regression_head = nn.Linear(hidden_features, regression_outputs)

    def forward(self, x):
        """Return ``(class_logits, regression_prediction)`` for input ``x``.

        For ``x`` of shape ``(batch, in_features)``, the outputs have shapes
        ``(batch, num_classes)`` and ``(batch, regression_outputs)``.
        """
        features = self.backbone(x)  # computed once, shared by both heads
        return self.class_head(features), self.regression_head(features)

model = MultiTaskModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.03)
classification_loss = nn.CrossEntropyLoss()
regression_loss = nn.MSELoss()

# Explicit per-task weights for the combined objective. The two losses are on
# different scales, so their relative weight is a deliberate tuning choice.
# Keeping it in named variables makes it visible and easy to change.
classification_weight = 1.0
regression_weight = 0.5

steps = 5 if check_mode else 60
for step in range(steps):
    optimizer.zero_grad(set_to_none=True)
    class_logits, regression_prediction = model(x)
    loss_a = classification_loss(class_logits, class_target)
    loss_b = regression_loss(regression_prediction, regression_target)
    # Weighted sum: backward() propagates each task's gradient into the shared
    # backbone scaled by that task's weight.
    loss = classification_weight * loss_a + regression_weight * loss_b
    loss.backward()
    optimizer.step()
    if step in (0, steps - 1):
        # Columns: step, weighted total, classification loss, regression loss.
        # Values come from this step's forward pass (before optimizer.step()).
        print(step, round(loss.item(), 4), round(loss_a.item(), 4), round(loss_b.item(), 4))

assert class_logits.shape == (64, 2)
assert regression_prediction.shape == (64, 1)