3. Sequence Classification
This matters because many practical text tasks still reduce to classifying a sequence. Focus on how lengths, padding, packing, and the final hidden state interact.
[ ]:
import os
import re
from collections import Counter
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Fixed seed so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(11)
# Run on GPU when available; tensors and the model are moved here below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# When this env var is '1', the training loop runs only a few epochs (CI/smoke mode).
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'
# Tiny synthetic corpus: ML-log-style sentences, label 1 = positive outcome.
positive_templates = [
    'training accuracy improved after fixing the batch',
    'the model checkpoint loaded correctly',
    'gpu utilization increased during training',
    'validation loss decreased for another epoch',
]
# Label 0 = negative outcome.
negative_templates = [
    'training loss became nan after the update',
    'the labels were shuffled by mistake',
    'the notebook failed before saving a checkpoint',
    'validation accuracy collapsed on the next run',
]
# (text, label) pairs for the whole eight-sentence dataset.
rows = [(text, 1) for text in positive_templates] + [(text, 0) for text in negative_templates]
# Special vocabulary tokens: PAD fills short sequences, UNK covers out-of-vocab words.
PAD = '<pad>'
UNK = '<unk>'
def tokenize(text):
    """Lowercase *text* and return each run of letters, digits, or apostrophes."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"[a-z0-9']+", lowered)]
# Token frequencies across the corpus (used here only to enumerate the vocab).
counts = Counter(token for text, _ in rows for token in tokenize(text))
# Index-to-string table: PAD fixed at index 0, UNK at 1, then sorted corpus tokens.
itos = [PAD, UNK] + sorted(counts)
# String-to-index lookup, the inverse of itos.
stoi = {token: i for i, token in enumerate(itos)}
def encode(text):
    """Map *text* to a 1-D LongTensor of vocab ids; unknown tokens become UNK."""
    unk_index = stoi[UNK]
    ids = [stoi.get(token, unk_index) for token in tokenize(text)]
    return torch.tensor(ids, dtype=torch.long)
class SentenceDataset(Dataset):
    """Dataset yielding (token-id tensor, label tensor) pairs.

    By default wraps the module-level ``rows`` corpus, so ``SentenceDataset()``
    behaves exactly as before; pass ``data`` (a sequence of ``(text, label)``
    pairs) to use a different corpus.
    """

    def __init__(self, data=None):
        super().__init__()
        # Fall back to the module-level corpus for backward compatibility.
        self.data = rows if data is None else data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        text, label = self.data[index]
        # Encode lazily per item; the corpus is tiny, so no caching is needed.
        return encode(text), torch.tensor(label, dtype=torch.long)
def collate(batch, pad_index=None):
    """Collate (sequence, label) pairs into padded batch tensors.

    Args:
        batch: iterable of ``(token_tensor, label_tensor)`` pairs, where each
            token tensor is 1-D and of varying length.
        pad_index: id used to fill the tail of short sequences; defaults to
            the module-level ``stoi[PAD]`` so DataLoader's ``collate(batch)``
            call is unchanged.

    Returns:
        ``(tokens, lengths, labels)``: ``tokens`` is ``(batch, max_len)``,
        left-aligned with padding on the right; ``lengths`` holds each row's
        true length (needed later for pack_padded_sequence); ``labels`` is
        the stacked label tensor.
    """
    if pad_index is None:
        pad_index = stoi[PAD]
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(x) for x in sequences], dtype=torch.long)
    tokens = torch.full((len(sequences), lengths.max().item()), pad_index, dtype=torch.long)
    for i, sequence in enumerate(sequences):
        tokens[i, :len(sequence)] = sequence
    return tokens, lengths, torch.stack(labels)
loader = DataLoader(SentenceDataset(), batch_size=4, shuffle=True, collate_fn=collate)
3.1. GRU classifier
pack_padded_sequence lets the GRU skip pad tokens entirely. Passing enforce_sorted=False tells PyTorch to sort the batch by length internally, so the caller does not need to pre-sort it.
[ ]:
class GRUClassifier(nn.Module):
    """Single-layer GRU classifier over padded token batches.

    Args:
        vocab_size: number of embedding rows.
        embedding_dim: width of each token embedding.
        hidden_dim: GRU hidden state size.
        num_classes: number of output logits per sequence.
        padding_idx: embedding row kept at zero and skipped by the packed GRU;
            defaults to the module-level ``stoi[PAD]``, preserving the
            original call signature.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            padding_idx = stoi[PAD]
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, tokens, lengths):
        """Return (batch, num_classes) logits for padded ``tokens``.

        ``lengths`` must hold each row's true (unpadded) length; packing keeps
        the GRU from reading pad positions at all.
        """
        embedded = self.embedding(tokens)
        # pack_padded_sequence requires lengths on CPU regardless of model device;
        # enforce_sorted=False lets PyTorch sort the batch by length internally.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, hidden = self.gru(packed)
        # hidden is (num_layers, batch, hidden_dim); use the last layer's final state.
        return self.classifier(hidden[-1])
# Tiny dims are plenty for an eight-sentence corpus; keeps CPU runs fast.
model = GRUClassifier(len(itos), embedding_dim=16, hidden_dim=24, num_classes=2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
criterion = nn.CrossEntropyLoss()
# Abbreviated run when check mode is enabled (see env-var flag at top of file).
epochs = 4 if check_mode else 40
for epoch in range(epochs):
    # Running training-accuracy tally for a cheap per-epoch progress signal.
    correct = 0
    total = 0
    for tokens, lengths, labels in loader:
        tokens = tokens.to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(tokens, lengths)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Accuracy is measured on each batch's pre-update logits.
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    print(epoch, 'accuracy', round(correct / total, 3))
# Sanity check: a single forward pass yields one logit pair per batch row.
tokens, lengths, labels = next(iter(loader))
logits = model(tokens.to(device), lengths.to(device))
assert logits.shape == (tokens.shape[0], 2)