4. Transformer Language Model
This matters because transformers are now the default sequence model family. Focus on the minimum moving parts: token embeddings, position information, causal masking, and next-token loss.
[ ]:
import os

import torch
from torch import nn

torch.manual_seed(17)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

# Tiny repeated character corpus: enough structure to learn, small enough to train fast.
_sentences = [
    'tensors move to cuda. ',
    'gradients flow through models. ',
    'checkpoints resume training. ',
    'small batches help debugging. ',
]
corpus = ''.join(_sentences) * 8

# Character-level vocabulary plus the two id<->char lookup tables.
chars = sorted(set(corpus))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
data = torch.tensor([stoi[ch] for ch in corpus], dtype=torch.long)
block_size = 24


def sample_batch(batch_size):
    """Sample random (input, target) windows; targets are the inputs shifted by one."""
    starts = torch.randint(0, len(data) - block_size - 1, (batch_size,))
    inputs = [data[s:s + block_size] for s in starts]
    targets = [data[s + 1:s + 1 + block_size] for s in starts]
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)


print('vocab size:', len(chars))
4.1. Causal transformer
The causal mask prevents each token from attending to future tokens.
[ ]:
class TinyCausalLM(nn.Module):
    """Single-layer causal (decoder-style) character language model.

    Combines learned token and position embeddings, one transformer
    encoder layer restricted by a causal attention mask, and a linear
    head producing next-token logits for every position.
    """

    def __init__(self, vocab_size, block_size, embedding_dim=32, num_heads=4):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        # One learned vector per position, up to the maximum context length.
        self.position_embedding = nn.Embedding(block_size, embedding_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=64,
            dropout=0.0,
            batch_first=True,
            activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, tokens):
        """Map (batch, seq) token ids to (batch, seq, vocab_size) logits."""
        _, seq_len = tokens.shape
        pos_ids = torch.arange(seq_len, device=tokens.device)
        hidden = self.token_embedding(tokens) + self.position_embedding(pos_ids).unsqueeze(0)
        # Additive float mask: -inf above the diagonal blocks attention to future tokens.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=tokens.device)
        hidden = self.encoder(hidden, mask=mask)
        return self.head(hidden)
model = TinyCausalLM(len(chars), block_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Run only a handful of steps under check mode so automated checks stay fast.
steps = 12 if check_mode else 200
for step in range(steps):
    inputs, targets = sample_batch(batch_size=16)
    optimizer.zero_grad(set_to_none=True)
    logits = model(inputs)
    # CrossEntropyLoss expects (N, C) logits against (N,) class ids,
    # so fold the time dimension into the batch dimension.
    loss = criterion(logits.reshape(-1, len(chars)), targets.reshape(-1))
    loss.backward()
    optimizer.step()
    if step in {0, steps - 1}:
        print(step, round(loss.item(), 4))

# Sanity check: one logit vector per position, covering the full vocabulary.
inputs, _ = sample_batch(batch_size=2)
with torch.inference_mode():
    logits = model(inputs)
print(logits.shape)
assert logits.shape == (2, block_size, len(chars))
4.2. Generate characters
Autoregressive generation feeds the current context back into the model, takes the last time step, samples the next token, and appends it.
[ ]:
def generate(prefix, new_tokens=30):
    """Autoregressively extend *prefix* by *new_tokens* sampled characters.

    Args:
        prefix: seed string; every character must exist in the corpus
            vocabulary (unknown characters raise ``KeyError``).
        new_tokens: number of characters to sample and append.

    Returns:
        The prefix plus the sampled characters as one string.
    """
    # Remember the caller's train/eval state so generation has no lasting
    # side effect on the shared model (the original left it in eval mode).
    was_training = model.training
    model.eval()
    ids = torch.tensor([[stoi[ch] for ch in prefix]], dtype=torch.long, device=device)
    try:
        with torch.inference_mode():
            for _ in range(new_tokens):
                # The model only has block_size position embeddings, so crop the context.
                context = ids[:, -block_size:]
                logits = model(context)[:, -1, :]  # logits for the next token only
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                ids = torch.cat([ids, next_id], dim=1)
    finally:
        model.train(was_training)  # restore the caller's mode even on error
    return ''.join(itos[i] for i in ids[0].tolist())


print(generate('tensor', new_tokens=20))