3. Data Pipelines at Scale
Once the model is fast enough, data loading becomes the systems bottleneck. This section covers worker sharding, deterministic worker seeding, and the difference between map-style and iterable datasets; a short contrast of the two dataset styles follows the imports.
[ ]:
import itertools
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
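For contrast, here is a minimal map-style dataset (a sketch; `SquaresDataset` is a hypothetical name). A map-style dataset implements `__len__` and `__getitem__`, so the DataLoader's sampler controls ordering and splits indices across workers for you; an iterable dataset streams values and must shard itself, as 3.1 shows.
[ ]:
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    # Map-style: random access by index; the DataLoader's sampler
    # decides which indices each worker fetches.
    def __len__(self):
        return 12

    def __getitem__(self, idx):
        return torch.tensor([float(idx * idx)])

map_loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=True)
print(next(iter(map_loader)).squeeze(1).tolist())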
3.1. Shard an iterable dataset by worker
[ ]:
class CountingDataset(IterableDataset):
    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # num_workers=0: the loader runs in the main process, no sharding.
            start, step = 0, 1
        else:
            # Each worker takes every num_workers-th value, offset by its id,
            # so the workers partition the stream without overlap.
            start, step = info.id, info.num_workers
        for value in range(start, 12, step):
            yield torch.tensor([value], dtype=torch.float32)

loader = DataLoader(CountingDataset(), batch_size=2, num_workers=0)
batches = list(itertools.islice(loader, 3))
print([batch.squeeze(1).tolist() for batch in batches])
assert sum(batch.numel() for batch in batches) == 6  # 3 batches of 2
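To see the sharding take effect, raise `num_workers`. A sketch, assuming a platform where worker processes can reach the notebook-defined class (e.g. fork-based start on Linux); the DataLoader interleaves batches across workers, so we compare the values as a sorted list rather than relying on order:
[ ]:
# With two workers, worker 0 yields 0, 2, 4, ... and worker 1 yields 1, 3, 5, ...
sharded = DataLoader(CountingDataset(), batch_size=2, num_workers=2)
values = sorted(v for batch in sharded for v in batch.squeeze(1).tolist())
assert values == [float(v) for v in range(12)]  # every value exactly once
print(values)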
3.2. Worker seeding
Use a worker-init function so random transforms differ across workers while staying reproducible for a fixed global seed; the wiring into DataLoader is sketched after the cell below.
[ ]:
def seed_worker(worker_id):
    # Offset the base seed by the worker id so each worker draws a
    # different but reproducible random stream.
    seed = 1234 + worker_id
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

# Re-seeding with the same worker_id must reproduce the same draws.
seed_worker(0)
sample_a = (random.random(), float(np.random.rand()), float(torch.rand(1)))
seed_worker(0)
sample_b = (random.random(), float(np.random.rand()), float(torch.rand(1)))
print(sample_a)
assert sample_a == sample_b
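In practice, pass the function to the loader via `worker_init_fn`. PyTorch's documentation recommends deriving the seed from `torch.initial_seed()`, which already differs per worker and per epoch, and pinning the loader's own `generator` to keep batch order reproducible. A sketch of that wiring (`seed_worker_from_torch` is a hypothetical name):
[ ]:
def seed_worker_from_torch(worker_id):
    # torch.initial_seed() is already worker- and epoch-specific;
    # fold it into the other RNGs (NumPy seeds must fit in 32 bits).
    seed = torch.initial_seed() % 2**32
    random.seed(seed)
    np.random.seed(seed)

g = torch.Generator()
g.manual_seed(1234)  # fixes the sampler's batch order across runs
reproducible_loader = DataLoader(
    CountingDataset(),
    batch_size=2,
    num_workers=2,
    worker_init_fn=seed_worker_from_torch,
    generator=g,
)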