4. Transformer Language Model

Transformers are now the default model family for sequence tasks, so this section builds one from its minimum moving parts: token embeddings, position information, causal masking, and the next-token loss.

[ ]:
import os

import torch
from torch import nn

# Fix the global RNG so parameter init and batch sampling are reproducible.
torch.manual_seed(17)
# Use the GPU when one is available; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Setting PYTORCH_INTRO_CHECK_MODE=1 selects a much shorter training run below.
check_mode = os.environ.get('PYTORCH_INTRO_CHECK_MODE') == '1'

# A tiny, highly repetitive corpus: four sentences repeated eight times.
corpus = 8 * (
    'tensors move to cuda. '
    'gradients flow through models. '
    'checkpoints resume training. '
    'small batches help debugging. '
)

# Character-level vocabulary: each distinct character, in sorted order.
chars = sorted({*corpus})
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))
# Encode the whole corpus once as a 1-D tensor of token ids.
data = torch.tensor(list(map(stoi.__getitem__, corpus)), dtype=torch.long)
block_size = 24  # context length: tokens per training window / position-table size

def sample_batch(batch_size):
    """Draw ``batch_size`` random (input, target) windows from ``data``.

    Each input row holds ``block_size`` consecutive token ids. The matching
    target row is the same window shifted one position to the right, so
    ``y[:, t]`` is the token that follows ``x[:, t]``.

    Returns:
        Tuple ``(x, y)`` of ``torch.long`` tensors, each shaped
        ``(batch_size, block_size)`` and moved to ``device``.
    """
    # torch.randint's upper bound is exclusive. The largest valid start is
    # len(data) - block_size - 1, whose target window ends on the final token.
    starts = torch.randint(0, len(data) - block_size, (batch_size,))
    # Build every row's indices at once instead of slicing in a Python loop.
    offsets = starts[:, None] + torch.arange(block_size)
    x = data[offsets]
    y = data[offsets + 1]
    return x.to(device), y.to(device)

print('vocab size:', len(stoi))  # stoi has one entry per distinct character

4.1. Causal transformer

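The causal mask prevents each token from attending to future tokens. The prediction at position t therefore depends only on tokens 0..t, which is what makes it valid to train every position to predict the next token in parallel.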

[ ]:
class TinyCausalLM(nn.Module):
    """Minimal causal (decoder-only style) character language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    ``nn.TransformerEncoderLayer`` blocks run under a causal mask. A linear
    head then maps every position back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        block_size: Maximum sequence length; sizes the position-embedding table.
        embedding_dim: Model width (``d_model``); must be divisible by ``num_heads``.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked transformer layers (default 1).
        feedforward_dim: Hidden width of each layer's feed-forward network (default 64).
    """

    def __init__(self, vocab_size, block_size, embedding_dim=32, num_heads=4,
                 num_layers=1, feedforward_dim=64):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = nn.Embedding(block_size, embedding_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=feedforward_dim,
            dropout=0.0,
            batch_first=True,
            activation='gelu',
        )
        # An "encoder" stack becomes a causal LM once forward() applies a causal mask.
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, tokens):
        """Return next-token logits for every position.

        Args:
            tokens: 2-D ``(batch, seq)`` tensor of token ids with ``seq <= block_size``.

        Returns:
            ``(batch, seq, vocab_size)`` tensor of unnormalized logits.

        Raises:
            ValueError: if ``seq`` exceeds ``block_size``; no position embedding
                exists for such positions.
        """
        # Unpacking both dims keeps the original requirement that tokens is 2-D.
        _, sequence_length = tokens.shape
        if sequence_length > self.block_size:
            raise ValueError(
                f'sequence length {sequence_length} exceeds block_size {self.block_size}; '
                'crop the context before calling the model'
            )
        positions = torch.arange(sequence_length, device=tokens.device)
        # (seq, dim) position vectors broadcast across the batch dimension.
        x = self.token_embedding(tokens) + self.position_embedding(positions)[None, :, :]
        # -inf above the diagonal: position t may only attend to positions <= t.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(sequence_length, device=tokens.device)
        x = self.encoder(x, mask=causal_mask)
        return self.head(x)

# Model, optimizer and loss for next-token prediction over the character vocabulary.
model = TinyCausalLM(len(chars), block_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# The check mode only needs a handful of steps to exercise the pipeline.
steps = 200 if not check_mode else 12
for step in range(steps):
    x, y = sample_batch(batch_size=16)
    optimizer.zero_grad(set_to_none=True)
    logits = model(x)
    # Collapse (batch, time, vocab) to (batch*time, vocab): every position is an
    # independent next-token classification against the shifted targets.
    loss = criterion(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))
    loss.backward()
    optimizer.step()
    # Report only the first and the last step to keep the output short.
    if step == 0 or step == steps - 1:
        print(step, round(loss.item(), 4))

# Shape sanity check on a fresh batch, without autograd tracking.
x, _ = sample_batch(batch_size=2)
with torch.inference_mode():
    logits = model(x)
print(logits.shape)
assert logits.shape == torch.Size([2, block_size, len(chars)])

4.2. Generate characters

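Autoregressive generation works one token at a time. It feeds the current context, cropped to the last `block_size` tokens, back into the model and takes the logits at the last time step. It then samples the next token from their softmax and appends it to the context.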

[ ]:
def generate(prefix, new_tokens=30, temperature=1.0):
    """Autoregressively extend ``prefix`` by ``new_tokens`` sampled characters.

    Each step feeds at most the last ``block_size`` ids to ``model``, because
    the model has no position embedding beyond that. It takes the logits of
    the final position and samples the next id from their softmax.

    Args:
        prefix: Seed text; every character must be a key of ``stoi``.
        new_tokens: Number of characters to append.
        temperature: Logits are divided by this before the softmax. Values
            below 1 sharpen the distribution and values above 1 flatten it;
            1.0 samples from the model's raw distribution. Must be positive.

    Returns:
        ``prefix`` followed by the generated characters.

    Raises:
        KeyError: if ``prefix`` contains a character outside the vocabulary.
        ValueError: if ``temperature`` is not positive, or ``prefix`` is empty
            while ``new_tokens > 0`` (there is no context to condition on).
    """
    if not temperature > 0:
        raise ValueError(f'temperature must be positive, got {temperature}')
    if not prefix and new_tokens > 0:
        raise ValueError('prefix must contain at least one character')
    ids = torch.tensor([[stoi[ch] for ch in prefix]], dtype=torch.long, device=device)
    # Remember the caller's mode so generation does not leave the shared model in eval.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            for _ in range(new_tokens):
                context = ids[:, -block_size:]
                logits = model(context)[:, -1, :]
                probs = torch.softmax(logits / temperature, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                ids = torch.cat([ids, next_id], dim=1)
    finally:
        model.train(was_training)
    return ''.join(itos[i] for i in ids[0].tolist())

print(generate(prefix='tensor', new_tokens=20))  # seed with a word from the corpus