3. Data Pipelines at Scale

Data loading becomes a systems problem once the model is fast enough that training stalls waiting on input. This section covers worker sharding, deterministic worker seeding, and the difference between map-style and iterable datasets (contrasted in a short sketch after the imports).

[ ]:
import itertools
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
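
For contrast with the iterable datasets used in the rest of this section, here is a minimal map-style sketch (the SquaresDataset name and contents are illustrative only): it implements __len__ and __getitem__, so the DataLoader's sampler, not the dataset, decides ordering and sharding.

[ ]:
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Map-style: random access by index; the sampler picks the order."""

    def __len__(self):
        return 12

    def __getitem__(self, index):
        return torch.tensor([float(index ** 2)])

# shuffle=True installs a RandomSampler; the dataset itself stays order-agnostic.
map_loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=True)
print(next(iter(map_loader)).squeeze(1).tolist())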

3.1. Shard an iterable dataset by worker

Each DataLoader worker runs its own copy of an IterableDataset's __iter__, so without explicit sharding every worker would yield the full stream and duplicate each sample.

[ ]:
class CountingDataset(IterableDataset):
    def __iter__(self):
        info = get_worker_info()
        if info is None:
            # Single-process loading: the main process yields everything.
            start, step = 0, 1
        else:
            # Worker i yields values i, i + num_workers, ...: a stride shard
            # that keeps any sample from being duplicated across workers.
            start, step = info.id, info.num_workers
        for value in range(start, 12, step):
            yield torch.tensor([value], dtype=torch.float32)

# num_workers=0 loads in the main process, so get_worker_info() returns None
# and the single-process branch runs.
loader = DataLoader(CountingDataset(), batch_size=2, num_workers=0)
batches = list(itertools.islice(loader, 3))
print([batch.squeeze(1).tolist() for batch in batches])
assert sum(batch.numel() for batch in batches) == 6  # 3 batches of 2 samples
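
The single-process run above never exercises the sharding branch. A sketch of the multi-worker case, assuming a platform where DataLoader workers can unpickle a class defined in the notebook (e.g., a fork-based start method on Linux; under spawn the class must live in an importable module):

[ ]:
# Worker 0 yields 0, 2, 4, ... and worker 1 yields 1, 3, 5, ...; batch order
# across workers is not guaranteed, but every value appears exactly once.
sharded = DataLoader(CountingDataset(), batch_size=2, num_workers=2)
values = sorted(v for batch in sharded for v in batch.squeeze(1).tolist())
print(values)
assert values == [float(v) for v in range(12)]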

3.2. Worker seeding

Use a worker-init function to seed every RNG a transform might touch (Python's random, NumPy, and torch) so that augmentations differ across workers yet stay reproducible for a fixed base seed. The check below demonstrates reproducibility; the DataLoader wiring follows it.

[ ]:
def seed_worker(worker_id):
    # Derive a per-worker seed from a fixed base so every worker gets a
    # distinct but reproducible stream in all three RNGs transforms may use.
    seed = 1234 + worker_id
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

# Re-seeding with the same worker id must reproduce the exact same draws.
seed_worker(0)
sample_a = (random.random(), float(np.random.rand()), float(torch.rand(1)))
seed_worker(0)
sample_b = (random.random(), float(np.random.rand()), float(torch.rand(1)))
print(sample_a)
assert sample_a == sample_b
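
The check above shows reproducibility; the other half of the claim is divergence across workers, and then the actual wiring into a DataLoader. Both are sketches, and the multi-worker loader carries the same platform assumption as in 3.1:

[ ]:
# Different worker ids produce different (individually reproducible) streams.
seed_worker(0)
stream_w0 = [random.random() for _ in range(3)]
seed_worker(1)
stream_w1 = [random.random() for _ in range(3)]
assert stream_w0 != stream_w1

# Wiring: DataLoader calls worker_init_fn(worker_id) once in each worker
# process before it loads any data; a seeded generator additionally pins
# any sampler/shuffling randomness in the main process.
g = torch.Generator()
g.manual_seed(1234)
seeded_loader = DataLoader(
    CountingDataset(),
    batch_size=2,
    num_workers=2,
    worker_init_fn=seed_worker,
    generator=g,
)
print(next(iter(seeded_loader)).squeeze(1).tolist())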