3. Sequence Classification

This matters because many practical text tasks still reduce to classifying a sequence. Focus on how lengths, padding, packing, and the final hidden state interact.

[ ]:
import os
import re
from collections import Counter

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

torch.manual_seed(11)  # fixed seed so shuffling/initialization are reproducible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# When set to '1' in the environment, check mode shortens training (see the epoch count below).
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

# Tiny synthetic corpus: each string is one labelled training sentence.
positive_templates = [
    'training accuracy improved after fixing the batch',
    'the model checkpoint loaded correctly',
    'gpu utilization increased during training',
    'validation loss decreased for another epoch',
]
negative_templates = [
    'training loss became nan after the update',
    'the labels were shuffled by mistake',
    'the notebook failed before saving a checkpoint',
    'validation accuracy collapsed on the next run',
]
# (text, label) pairs: label 1 = positive, 0 = negative.
rows = [(text, 1) for text in positive_templates] + [(text, 0) for text in negative_templates]

# Special vocabulary tokens: PAD fills short sequences in a batch, UNK maps out-of-vocab tokens.
PAD = '<pad>'
UNK = '<unk>'

def tokenize(text):
    """Lowercase *text* and return its tokens (runs of letters, digits, apostrophes)."""
    lowered = text.lower()
    return re.findall(r"[a-z0-9']+", lowered)

# Token frequencies over the whole corpus (labels are ignored here).
counts = Counter(token for text, _ in rows for token in tokenize(text))
# Vocabulary: PAD and UNK come first (ids 0 and 1), then corpus tokens in sorted order.
itos = [PAD, UNK] + sorted(counts)
stoi = {token: i for i, token in enumerate(itos)}

def encode(text):
    """Tokenize *text* and map each token to its vocabulary id (unknowns -> UNK id)."""
    unk_id = stoi[UNK]
    ids = [stoi.get(token, unk_id) for token in tokenize(text)]
    return torch.tensor(ids, dtype=torch.long)

class SentenceDataset(Dataset):
    """Exposes the in-memory (text, label) rows as a torch Dataset of tensor pairs."""

    def __len__(self):
        """Number of labelled sentences."""
        return len(rows)

    def __getitem__(self, index):
        """Return (token_id_tensor, label_tensor) for row *index*."""
        sentence, target = rows[index]
        return encode(sentence), torch.tensor(target, dtype=torch.long)

def collate(batch):
    """Pad a batch of (token_ids, label) pairs to a common length.

    Returns (tokens, lengths, labels):
      tokens  -- (batch, max_len) long tensor; short rows are filled with the PAD id
      lengths -- true (unpadded) length of each sequence, needed later for packing
      labels  -- (batch,) long tensor of class ids
    """
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    # pad_sequence replaces the hand-rolled torch.full + copy loop; it keeps the
    # long dtype of the inputs and fills the tail of short rows with the PAD id (0).
    tokens = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=stoi[PAD])
    return tokens, lengths, torch.stack(labels)

# batch_size=4 over the 8 rows gives two shuffled batches per epoch; collate pads each batch.
loader = DataLoader(SentenceDataset(), batch_size=4, shuffle=True, collate_fn=collate)

3.1. GRU classifier

pack_padded_sequence lets the GRU skip pad tokens entirely. Passing enforce_sorted=False means the batch does not have to be pre-sorted by length; pack_padded_sequence sorts it internally and keeps track of the original order.

[ ]:
class GRUClassifier(nn.Module):
    """Sentence classifier: embed token ids, run a packed GRU, classify the final hidden state."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):
        super().__init__()
        # padding_idx pins the PAD embedding to zeros and excludes it from gradient updates
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=stoi[PAD])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, tokens, lengths):
        """Map padded (batch, max_len) token ids plus true lengths to (batch, num_classes) logits."""
        vectors = self.embedding(tokens)
        # Pack so the GRU never sees pad positions; lengths must live on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            vectors, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, final_hidden = self.gru(packed)
        # final_hidden is (num_layers, batch, hidden); index -1 selects the last (only) layer.
        return self.classifier(final_hidden[-1])

# Small model sized for the tiny corpus; moved to GPU when available.
model = GRUClassifier(len(itos), embedding_dim=16, hidden_dim=24, num_classes=2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
criterion = nn.CrossEntropyLoss()  # takes raw logits plus integer class labels

# Standard training loop; check mode trims the epoch count for quick smoke runs.
epochs = 4 if check_mode else 40
for epoch in range(epochs):
    n_correct, n_seen = 0, 0
    for tokens, lengths, labels in loader:
        tokens, lengths, labels = (
            tokens.to(device),
            lengths.to(device),
            labels.to(device),
        )
        optimizer.zero_grad(set_to_none=True)
        logits = model(tokens, lengths)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Track training accuracy over the epoch from the same forward pass.
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.numel()
    print(epoch, 'accuracy', round(n_correct / n_seen, 3))

# Sanity check: one batch through the trained model yields (batch, num_classes) logits.
tokens, lengths, labels = next(iter(loader))
logits = model(tokens.to(device), lengths.to(device))
assert logits.shape == (tokens.shape[0], 2)