9. Ragged Sequences
Batches of variable-length sequences must be padded to a common length, and any step that treats the padding as real data introduces subtle bugs in sequence models. This section covers the two standard strategies: packing for recurrent models and masking for attention models.
[ ]:
import torch
from torch import nn
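device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# True (unpadded) length of each sequence; kept on the CPU for packing later.
lengths = torch.tensor([5, 3, 2], dtype=torch.long)
# Right-padded batch of token ids; 0 is the padding id.
tokens = torch.tensor([
    [3, 4, 5, 6, 7],
    [8, 9, 10, 0, 0],
    [11, 12, 0, 0, 0],
], dtype=torch.long, device=device)
# Boolean mask: True for real tokens, False for padding.
mask = (tokens != 0)
print(mask.int())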
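Hard-coded lengths can drift out of sync with the token tensor. As a minimal sketch, assuming 0 is only ever used as the padding id, the lengths can be derived from the mask, and the right-padding layout can be checked at the same time. The names `expected_mask` and `is_right_padded` are illustrative and not part of the original cell.

```python
import torch

# Same toy batch as above; 0 is assumed to be the padding id.
tokens = torch.tensor([
    [3, 4, 5, 6, 7],
    [8, 9, 10, 0, 0],
    [11, 12, 0, 0, 0],
], dtype=torch.long)

mask = tokens != 0
# Number of real tokens per row, derived instead of hard-coded.
lengths = mask.sum(dim=1)

# For right-padded data, position j of row i is real exactly when j < lengths[i].
positions = torch.arange(tokens.shape[1])
expected_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
is_right_padded = torch.equal(mask, expected_mask)

print(lengths.tolist(), is_right_padded)
```

If `is_right_padded` is False, the batch has padding in the middle or at the start of a row, and the packing call in the next subsection would not see the intended tokens.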
9.1. Packed RNN input
[ ]:
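# padding_idx=0 keeps the padding embedding at zero and excludes it from gradient updates.
embedding = nn.Embedding(20, 8, padding_idx=0).to(device)
gru = nn.GRU(8, 12, batch_first=True).to(device)
embedded = embedding(tokens)
# lengths must be a CPU tensor; enforce_sorted=False lets PyTorch sort by length internally.
packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
# hidden holds the state at each sequence's last real token, not at the last padded step.
packed_output, hidden = gru(packed)
# Unpack to a padded (batch, time, features) tensor; padded positions are zero-filled.
padded_output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
print(padded_output.shape, hidden.shape)
assert hidden.shape == (1, 3, 12)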
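To see the padding bug that packing avoids, the sketch below compares three final hidden states for the second sequence (true length 3): one from the packed batch, one from the padded batch fed to the GRU directly, and a reference from running that sequence alone with the padding trimmed off. It rebuilds the toy setup on the CPU with a fixed seed so that it runs on its own; the variable names and the tolerance are illustrative choices.

```python
import torch
from torch import nn

torch.manual_seed(0)

tokens = torch.tensor([
    [3, 4, 5, 6, 7],
    [8, 9, 10, 0, 0],
    [11, 12, 0, 0, 0],
], dtype=torch.long)
lengths = torch.tensor([5, 3, 2], dtype=torch.long)

embedding = nn.Embedding(20, 8, padding_idx=0)
gru = nn.GRU(8, 12, batch_first=True)

with torch.no_grad():
    embedded = embedding(tokens)

    # Packed: the GRU stops at each sequence's true length.
    packed = nn.utils.rnn.pack_padded_sequence(
        embedded, lengths, batch_first=True, enforce_sorted=False
    )
    _, hidden_packed = gru(packed)

    # Unpacked: the GRU keeps stepping through the padding positions.
    _, hidden_padded = gru(embedded)

    # Reference: row 1 on its own, trimmed to its 3 real tokens.
    _, hidden_ref = gru(embedded[1:2, :3])

matches_reference = torch.allclose(hidden_packed[0, 1], hidden_ref[0, 0], atol=1e-5)
differs_when_unpacked = not torch.allclose(hidden_padded[0, 1], hidden_ref[0, 0], atol=1e-5)
print(matches_reference, differs_when_unpacked)
```

The padding embedding is a zero vector, but the GRU still updates its state on zero input through its biases and recurrent weights, which is why the unpacked final state drifts away from the reference.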
9.2. Attention mask
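Transformers operate on the padded batch directly. Instead of packing, they apply a mask to the attention scores so that no query attends to a padding position.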
[ ]:
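# Random stand-in for query-key scores, shape (batch, query, key).
attention_scores = torch.randn(tokens.shape[0], tokens.shape[1], tokens.shape[1], device=device)
# Broadcast the key mask over the query dimension: (batch, key) -> (batch, query, key).
attention_mask = mask.unsqueeze(1).expand(-1, tokens.shape[1], -1)
# Padding keys get -inf, so softmax assigns them exactly zero weight.
masked_scores = attention_scores.masked_fill(~attention_mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)
print(weights.shape)
assert weights.shape == attention_scores.shape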
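The same idea carries over to PyTorch's built-in attention layer. The sketch below is an illustration under stated assumptions rather than part of the original notebook: it passes a padding mask to `nn.MultiheadAttention` through its `key_padding_mask` argument and checks that padded keys receive no weight. Note the inverted convention: `key_padding_mask` marks positions to ignore with True, so the `mask` defined in this section has to be negated. The random input `x` stands in for embedded tokens.

```python
import torch
from torch import nn

torch.manual_seed(0)

tokens = torch.tensor([
    [3, 4, 5, 6, 7],
    [8, 9, 10, 0, 0],
    [11, 12, 0, 0, 0],
], dtype=torch.long)
mask = tokens != 0  # True for real tokens

# Random stand-in for embedded tokens: (batch, time, features).
x = torch.randn(tokens.shape[0], tokens.shape[1], 8)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

with torch.no_grad():
    # key_padding_mask uses True for positions to IGNORE, hence ~mask.
    _, attn_weights = mha(x, x, x, key_padding_mask=~mask)

# attn_weights is averaged over heads: (batch, query, key).
padded_keys = (~mask).unsqueeze(1).expand_as(attn_weights)
padded_weight_total = attn_weights.masked_select(padded_keys).sum().item()
row_sums = attn_weights.sum(dim=-1)
print(attn_weights.shape, padded_weight_total)
```

Rows that correspond to padded queries still receive a valid distribution over the real keys; their outputs are usually discarded or excluded from the loss downstream.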