2. Datasets and Caching

Flaky downloads and repeated dataset preparation often waste more time than the model code itself. This section focuses on three habits: making data locations explicit, making caching deterministic, and keeping sampling choices visible.

[ ]:
import os
from collections import Counter
from pathlib import Path

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms

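# Resolve data and cache directories from the environment, falling back to ./output.
DATA_DIR = Path(os.getenv('PYTORCH_INTRO_DATA_DIR', './output/data'))
TORCH_HOME = Path(os.getenv('TORCH_HOME', './output/torch'))

DATA_DIR.mkdir(parents=True, exist_ok=True)
TORCH_HOME.mkdir(parents=True, exist_ok=True)

# torch.hub reads TORCH_HOME from the environment when it resolves its cache,
# so export the resolved path; otherwise the ./output/torch fallback is ignored.
os.environ['TORCH_HOME'] = str(TORCH_HOME.resolve())

print('DATA_DIR =', DATA_DIR.resolve())
print('TORCH_HOME =', TORCH_HOME.resolve())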
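One way to keep dataset preparation deterministic is to derive the cache file name from the preparation settings themselves: identical settings always reuse the same file, and changed settings never pick up a stale one. The sketch below is a stdlib-only illustration under stated assumptions: cache_path_for, load_or_build and build_split are hypothetical helpers (not part of PyTorch or torchvision), the settings are assumed to be JSON-serializable, and pickle is used only for brevity, so it should load only files you wrote yourself.

```python
import hashlib
import json
import pickle
import random
import tempfile
from pathlib import Path


def cache_path_for(root: Path, name: str, settings: dict) -> Path:
    """Hypothetical helper: map (name, settings) to a stable cache file path."""
    # sort_keys=True makes the JSON text, and so the hash, independent of key order.
    canonical = json.dumps(settings, sort_keys=True)
    digest = hashlib.sha256(canonical.encode('utf-8')).hexdigest()[:16]
    return root / f'{name}-{digest}.pkl'


def load_or_build(root: Path, name: str, settings: dict, build):
    """Hypothetical helper: reuse a cached result, or build and save it on a miss."""
    path = cache_path_for(root, name, settings)
    if path.exists():
        with path.open('rb') as f:
            return pickle.load(f)
    result = build(**settings)
    root.mkdir(parents=True, exist_ok=True)
    # Write to a temporary file, then rename, so an interrupted run leaves no half-written cache.
    tmp_path = path.with_suffix('.tmp')
    with tmp_path.open('wb') as f:
        pickle.dump(result, f)
    tmp_path.replace(path)
    return result


# Usage: the second call uses the same settings in a different key order, so it hits the cache.
build_calls = []


def build_split(size, seed):
    """Hypothetical builder: a seeded shuffle of sample indices."""
    build_calls.append((size, seed))
    order = list(range(size))
    random.Random(seed).shuffle(order)
    return order


with tempfile.TemporaryDirectory() as tmp:
    first = load_or_build(Path(tmp), 'split', {'size': 5, 'seed': 0}, build_split)
    second = load_or_build(Path(tmp), 'split', {'seed': 0, 'size': 5}, build_split)

print(first == second, len(build_calls))  # True 1
```

Hashing a canonical JSON form is a design choice: it keeps cache keys readable in code while still distinguishing every combination of settings.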

2.1. Local datasets

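ImageFolder fits datasets stored as one subdirectory per class: each subdirectory name becomes a class label, and class indices follow the sorted order of those names.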

[ ]:
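# Resize every image to a fixed size so samples can be stacked into batches.
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
# Each subdirectory of ./shapes/train (and ./shapes/valid) is one class.
train_dataset = datasets.ImageFolder('./shapes/train', transform=transform)
valid_dataset = datasets.ImageFolder('./shapes/valid', transform=transform)
# Both splits must contain the same class folders, or the label indices will disagree.
assert valid_dataset.class_to_idx == train_dataset.class_to_idx, 'train/valid class folders differ'

print('classes:', train_dataset.classes)
print('train size:', len(train_dataset))
print('valid size:', len(valid_dataset))

# samples is a list of (path, class_index) pairs; count images per class.
targets = [label for _, label in train_dataset.samples]
class_counts = Counter(targets)
# Show the counts keyed by class name rather than by index.
{train_dataset.classes[index]: count for index, count in sorted(class_counts.items())}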
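To make the label assignment concrete, the stdlib-only sketch below mimics how ImageFolder derives labels: class names are the sorted subdirectory names, and each file's label is the index of its parent directory. It is an illustration, not torchvision's implementation: index_image_folder is a hypothetical helper that skips ImageFolder's file-extension filtering and image loading, and the class names and empty placeholder files are made up.

```python
import tempfile
from collections import Counter
from pathlib import Path


def index_image_folder(root: Path):
    """Hypothetical sketch of ImageFolder-style labelling (not torchvision's code)."""
    # Class names are the subdirectory names, sorted so indices are stable across runs.
    classes = sorted(entry.name for entry in root.iterdir() if entry.is_dir())
    class_to_idx = {name: index for index, name in enumerate(classes)}
    samples = []
    for name in classes:
        # Files are sorted too, so the sample order is deterministic.
        for path in sorted((root / name).rglob('*')):
            if path.is_file():
                samples.append((path, class_to_idx[name]))
    return classes, class_to_idx, samples


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Made-up layout, created in non-sorted order on purpose: 2 triangles, 3 circles, 1 square.
    for name, n_files in [('triangle', 2), ('circle', 3), ('square', 1)]:
        (root / name).mkdir()
        for i in range(n_files):
            (root / name / f'{i}.png').touch()
    classes, class_to_idx, samples = index_image_folder(root)
    counts = Counter(label for _, label in samples)
    by_name = {classes[index]: count for index, count in sorted(counts.items())}

print(classes)  # ['circle', 'square', 'triangle']
print(by_name)  # {'circle': 3, 'square': 1, 'triangle': 2}
```

Because indices come from sorted names, renaming or adding a class folder can shift every label after it, which is why the training and validation splits need identical folder sets.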

2.2. Balanced sampling

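When a dataset is imbalanced, oversample the rare classes with a sampler instead of copying files into duplicate folders. WeightedRandomSampler draws indices in proportion to per-sample weights; weighting each sample by the inverse of its class count gives every class the same total weight, so each class is equally likely on every draw.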

[ ]:
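# Weight each sample by 1 / (size of its class) so every class has the same total weight.
weights_by_class = {label: 1.0 / count for label, count in class_counts.items()}
sample_weights = [weights_by_class[label] for label in targets]
# Draw with replacement; a seeded generator makes the sampled indices reproducible.
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)

# A sampler replaces shuffle=True; DataLoader rejects passing both.
balanced_loader = DataLoader(train_dataset, batch_size=12, sampler=sampler, num_workers=2)
images, labels = next(iter(balanced_loader))

print(images.shape)  # (batch, channels, height, width)
print('sampled labels:', labels.tolist())
assert images.ndim == 4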
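The balancing argument can be checked without any images. The sketch below applies the same inverse-count weighting to a made-up 90/10 label split and uses Python's random.choices as a stand-in for WeightedRandomSampler (both draw indices with replacement in proportion to the weights). The toy_ prefix keeps these names from overwriting the notebook's targets and class_counts; the split, seed and draw count are assumptions chosen for illustration.

```python
import random
from collections import Counter

# Made-up imbalanced labels: 90 samples of class 0, 10 of class 1.
toy_targets = [0] * 90 + [1] * 10
toy_class_counts = Counter(toy_targets)

# Same weighting as the notebook cell: 1 / (number of samples in the class).
toy_weights_by_class = {label: 1.0 / count for label, count in toy_class_counts.items()}
toy_sample_weights = [toy_weights_by_class[label] for label in toy_targets]

# Total weight per class: 90 * (1/90) and 10 * (1/10), both (up to rounding) 1.0.
toy_mass_per_class = Counter()
for label, weight in zip(toy_targets, toy_sample_weights):
    toy_mass_per_class[label] += weight

# Draw 10,000 indices with replacement in proportion to the weights, with a fixed seed.
rng = random.Random(0)
toy_drawn = rng.choices(range(len(toy_targets)), weights=toy_sample_weights, k=10_000)
toy_drawn_counts = Counter(toy_targets[i] for i in toy_drawn)

print(dict(toy_mass_per_class))
print(dict(toy_drawn_counts))  # roughly 5,000 draws per class
```

The rare class is still represented by only ten distinct samples, so balanced sampling repeats them often; augmentation is a common companion to this technique.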

2.3. Predownload remote datasets

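The checker Docker image includes a helper script that preloads selected torchvision datasets into the same cache that is mounted during notebook execution. The first command below downloads MNIST, CIFAR-10 and USPS; the second executes data.ipynb with the remote dataset examples enabled.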

docker/pytorch-intro-check/download-datasets.sh mnist cifar10 usps
PYTORCH_INTRO_RUN_REMOTE_DATASETS=1 docker/pytorch-intro-check/execute-notebooks.sh data.ipynb

By default, the book's examples skip remote dataset downloads, so a full notebook run needs no network access. Setting PYTORCH_INTRO_RUN_REMOTE_DATASETS=1 enables them.

[ ]:
# Remote examples run only when PYTORCH_INTRO_RUN_REMOTE_DATASETS is exactly '1'.
remote_examples_enabled = os.getenv('PYTORCH_INTRO_RUN_REMOTE_DATASETS') == '1'
print('remote dataset examples enabled:', remote_examples_enabled)
print('remote dataset root would be:', DATA_DIR / 'mnist')
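A remote-dataset cell can branch on the flag and fail fast when it expects a pre-downloaded cache that is not there. The stdlib-only sketch below shows that gating pattern; should_run_remote and require_cached are hypothetical helpers, and treating a non-empty directory as "already downloaded" is an assumption, not something the download helper guarantees. In a real cell, the guarded branch would then construct the torchvision dataset with download=False, for example datasets.MNIST(root, download=False).

```python
import os
import tempfile
from pathlib import Path

FLAG = 'PYTORCH_INTRO_RUN_REMOTE_DATASETS'


def should_run_remote(env=None) -> bool:
    """Hypothetical helper: remote examples are opt-in and run only when the flag is exactly '1'."""
    env = os.environ if env is None else env
    return env.get(FLAG) == '1'


def require_cached(root: Path) -> Path:
    """Hypothetical helper: fail early if the expected cache directory is missing or empty."""
    # Assumption: a non-empty directory means the download helper has already run.
    if not root.is_dir() or not any(root.iterdir()):
        raise FileNotFoundError(f'{root} is missing or empty; run the download helper first')
    return root


print(should_run_remote({}))           # False: flag unset
print(should_run_remote({FLAG: '1'}))  # True

missing_error = None
with tempfile.TemporaryDirectory() as tmp:
    mnist_root = Path(tmp) / 'mnist'
    try:
        require_cached(mnist_root)
    except FileNotFoundError as err:
        missing_error = err  # nothing has been downloaded yet
    mnist_root.mkdir()
    (mnist_root / 'placeholder.bin').touch()  # stand-in for downloaded files
    cached_name = require_cached(mnist_root).name

print('before download:', missing_error)
print('after download:', cached_name)
```

Failing with a clear message when the cache is empty is usually preferable to silently falling back to a network download inside a run that is supposed to be offline.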