2. Embeddings
Embeddings are the entry point for most NLP, recommendation, and ID-based models. This section focuses on what an embedding layer learns and how pooling turns per-token vectors into example-level predictions.
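Before the full pipeline, a quick standalone illustration of the core idea (a minimal sketch, not part of the classifier below): nn.Embedding is a trainable lookup table, and indexing it with integer ids returns the matching rows, which pooling can then average into one vector per example.
[ ]:
import torch
from torch import nn

lookup = nn.Embedding(num_embeddings=10, embedding_dim=4)  # 10-row trainable table
ids = torch.tensor([[1, 2, 5]])       # one sequence of three token ids
vectors = lookup(ids)                 # (1, 3, 4): one vector per token
print(vectors.shape)
print(vectors.mean(dim=1).shape)      # (1, 4): mean pooling yields one vector per example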
2.1. Mean-pooled text classifier
This baseline embeds each token, masks out padding positions, averages the remaining token vectors, and trains a linear classifier on the pooled result.
[ ]:
import os
import re
from collections import Counter
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
torch.manual_seed(7)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'
PAD = '<pad>'
UNK = '<unk>'
rows = [
    ('gpu training is fast and useful', 1),
    ('checkpoint the model after validation improves', 1),
    ('the loss exploded and produced nan values', 0),
    ('the data cache prevents repeated downloads', 1),
    ('a shuffled label bug ruins accuracy', 0),
    ('tensorboard helps inspect training curves', 1),
    ('shape mismatch errors stop the run', 0),
    ('small batches make debugging easier', 1),
]

def tokenize(text):
    # Lowercase, then keep runs of letters, digits, and apostrophes.
    return re.findall(r"[a-z0-9']+", text.lower())
counts = Counter(token for text, _ in rows for token in tokenize(text))
itos = [PAD, UNK] + sorted(counts)
stoi = {token: index for index, token in enumerate(itos)}
def encode(text):
    # Map tokens to ids, falling back to UNK for out-of-vocabulary words.
    return torch.tensor([stoi.get(token, stoi[UNK]) for token in tokenize(text)], dtype=torch.long)

class TinyText(Dataset):
    def __len__(self):
        return len(rows)

    def __getitem__(self, index):
        text, label = rows[index]
        return encode(text), torch.tensor(label, dtype=torch.long)

def collate(batch):
    # Pad each sequence to the longest in the batch and record true lengths.
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(x) for x in sequences], dtype=torch.long)
    tokens = torch.full((len(sequences), lengths.max().item()), stoi[PAD], dtype=torch.long)
    for i, sequence in enumerate(sequences):
        tokens[i, :len(sequence)] = sequence
    return tokens, lengths, torch.stack(labels)
loader = DataLoader(TinyText(), batch_size=4, shuffle=True, collate_fn=collate)
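A quick sanity check on the collate output (optional, just to confirm the padding logic): every position past an example's true length should hold the PAD id.
[ ]:
tokens, lengths, labels = next(iter(loader))
print(tokens.shape, lengths, labels)
# Padded cells per row equal max_len minus the true length.
assert (tokens == stoi[PAD]).sum() == (tokens.shape[1] - lengths).sum()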
[ ]:
class MeanEmbeddingClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, padding_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, tokens, lengths):
        embedded = self.embedding(tokens)                    # (batch, seq, dim)
        mask = (tokens != stoi[PAD]).unsqueeze(-1)           # zero out padding positions
        summed = (embedded * mask).sum(dim=1)
        pooled = summed / lengths.clamp_min(1).unsqueeze(1)  # mean over real tokens only
        return self.classifier(pooled)
model = MeanEmbeddingClassifier(len(itos), embedding_dim=16, num_classes=2, padding_idx=stoi[PAD]).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.03)
criterion = nn.CrossEntropyLoss()
epochs = 3 if check_mode else 20
for epoch in range(epochs):
    total_loss = 0.0
    for tokens, lengths, labels in loader:
        tokens = tokens.to(device)
        lengths = lengths.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(tokens, lengths), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(epoch, round(total_loss / len(loader), 4))
tokens, lengths, labels = next(iter(loader))
with torch.inference_mode():
    logits = model(tokens.to(device), lengths.to(device))
print(logits.shape)
assert logits.shape == (tokens.shape[0], 2)
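To probe what the embedding layer learned, one option is to look at the cosine neighbors of a token's learned vector (illustrative only; with eight training sentences the neighborhoods are noisy):
[ ]:
import torch.nn.functional as F

with torch.inference_mode():
    weights = model.embedding.weight                   # (vocab_size, embedding_dim)
    query = weights[stoi['loss']].unsqueeze(0)
    sims = F.cosine_similarity(query, weights, dim=1)  # similarity to every vocab row
    print([itos[i] for i in sims.topk(4).indices.tolist()])  # 'loss' plus its nearest tokens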
2.2. EmbeddingBag
nn.EmbeddingBag pools variable-length token sequences without first building a padded matrix. It expects a single flat tensor of token ids plus an offsets tensor that marks where each example begins.
[ ]:
embedding_bag = nn.EmbeddingBag(len(itos), 8, mode='mean').to(device)
encoded = [encode(text) for text, _ in rows[:3]]
flat_tokens = torch.cat(encoded).to(device)
# Offsets are the starting index of each example within flat_tokens.
offsets = torch.tensor([0] + [len(x) for x in encoded[:-1]], dtype=torch.long).cumsum(0).to(device)
bagged = embedding_bag(flat_tokens, offsets)
print(bagged.shape)
assert bagged.shape == (3, 8)
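With mode='mean', the bag output should match manually embedding each sequence and averaging, which makes a handy correctness check (reusing embedding_bag, encoded, and bagged from the cell above):
[ ]:
manual = torch.stack([embedding_bag.weight[seq.to(device)].mean(dim=0) for seq in encoded])
assert torch.allclose(bagged, manual, atol=1e-6)
print('EmbeddingBag mean matches manual mean pooling')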