4. Data

4.1. Dataset

[1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms import *
from PIL import Image
import pathlib

class ShapeDataset(Dataset):
    """Dataset of shape images laid out as ``<root_dir>/<class>/<name>.jpg``.

    Classes are ``circle``, ``poly`` and ``rect``, mapped to indices 0, 1, 2.
    Each item is a ``(transformed_image, class_index)`` tuple.
    """

    def __init__(self, root_dir, transform=None):
        """
        :param root_dir: directory containing one sub-directory per class.
        :param transform: callable applied to each PIL image; defaults to
            ``Resize(256) -> RandomCrop(224) -> ToTensor()``.
        """
        self.root_dir = root_dir
        # Build the default pipeline lazily instead of using a complex
        # default-argument expression evaluated at definition time.
        if transform is None:
            transform = transforms.Compose([Resize(256), RandomCrop(224), ToTensor()])
        self.transform = transform
        self.__init()

    def __init(self):
        # Collect every jpg under the root, skipping Jupyter checkpoint copies.
        self.jpg_files = [f for f in pathlib.Path(self.root_dir).glob('**/*.jpg')
                          if '.ipynb_checkpoints' not in f.parts]
        self.class_to_idx = {clazz: i for i, clazz in enumerate(['circle', 'poly', 'rect'])}

    def __len__(self):
        return len(self.jpg_files)

    def __getitem__(self, idx):
        img_path = self.jpg_files[idx]

        # Close the file handle promptly; convert() forces a full load and
        # normalizes grayscale/palette jpgs to 3 channels.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        # The class is the name of the file's parent directory. The original
        # ``img_path.parts[2]`` only worked when root_dir == './shapes/train';
        # any other root depth produced wrong labels or a KeyError.
        clazz = self.class_to_idx[img_path.parent.name]

        return image, clazz
[2]:
dataset = ShapeDataset('./shapes/train')
[3]:
len(dataset)
[3]:
30
[4]:
dataset.jpg_files
[4]:
[PosixPath('shapes/train/rect/0026.jpg'),
 PosixPath('shapes/train/rect/0029.jpg'),
 PosixPath('shapes/train/rect/0024.jpg'),
 PosixPath('shapes/train/rect/0023.jpg'),
 PosixPath('shapes/train/rect/0021.jpg'),
 PosixPath('shapes/train/rect/0028.jpg'),
 PosixPath('shapes/train/rect/0025.jpg'),
 PosixPath('shapes/train/rect/0022.jpg'),
 PosixPath('shapes/train/rect/0027.jpg'),
 PosixPath('shapes/train/rect/0020.jpg'),
 PosixPath('shapes/train/poly/0012.jpg'),
 PosixPath('shapes/train/poly/0016.jpg'),
 PosixPath('shapes/train/poly/0011.jpg'),
 PosixPath('shapes/train/poly/0018.jpg'),
 PosixPath('shapes/train/poly/0015.jpg'),
 PosixPath('shapes/train/poly/0017.jpg'),
 PosixPath('shapes/train/poly/0010.jpg'),
 PosixPath('shapes/train/poly/0014.jpg'),
 PosixPath('shapes/train/poly/0013.jpg'),
 PosixPath('shapes/train/poly/0019.jpg'),
 PosixPath('shapes/train/circle/0008.jpg'),
 PosixPath('shapes/train/circle/0007.jpg'),
 PosixPath('shapes/train/circle/0006.jpg'),
 PosixPath('shapes/train/circle/0005.jpg'),
 PosixPath('shapes/train/circle/0009.jpg'),
 PosixPath('shapes/train/circle/0001.jpg'),
 PosixPath('shapes/train/circle/0000.jpg'),
 PosixPath('shapes/train/circle/0002.jpg'),
 PosixPath('shapes/train/circle/0004.jpg'),
 PosixPath('shapes/train/circle/0003.jpg')]
[5]:
# Fetch each sample exactly once per iteration: every dataset[i] call
# re-opens and re-transforms the image, so the original loop (five
# dataset[i] lookups per row) did five times the I/O and transform work,
# and RandomCrop made the printed shape come from a different crop than
# the other columns.
for i, (image, label) in enumerate(dataset):
    print(i, type((image, label)),
          type(image),
          type(label),
          label,
          image.shape)
0 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
1 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
2 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
3 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
4 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
5 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
6 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
7 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
8 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
9 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
10 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
11 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
12 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
13 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
14 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
15 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
16 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
17 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
18 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
19 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
20 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
21 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
22 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
23 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
24 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
25 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
26 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
27 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
28 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
29 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])

4.2. ImageFolder

[6]:
from torchvision import datasets

# Same pipeline via the built-in ImageFolder, which infers class names
# (and their label indices) from the sub-directory names under the root.
transform = transforms.Compose([Resize(256), RandomCrop(224), ToTensor()])
image_folder = datasets.ImageFolder('./shapes/train', transform=transform)
[7]:
image_folder.classes
[7]:
['circle', 'poly', 'rect']
[8]:
# Show the label index ImageFolder assigned to each class name; the
# mapping iterates in the same order as image_folder.classes.
for name, index in image_folder.class_to_idx.items():
    print(name, index)
circle 0
poly 1
rect 2
[9]:
len(image_folder)
[9]:
30
[10]:
# Fetch each sample exactly once per iteration: every image_folder[i]
# call re-opens and re-transforms the image, so the original loop (five
# image_folder[i] lookups per row) did five times the work, and
# RandomCrop made the printed shape come from a different crop than the
# other columns.
for i, (image, label) in enumerate(image_folder):
    print(i, type((image, label)),
          type(image),
          type(label),
          label,
          image.shape)
0 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
1 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
2 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
3 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
4 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
5 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
6 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
7 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
8 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
9 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 0 torch.Size([3, 224, 224])
10 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
11 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
12 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
13 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
14 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
15 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
16 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
17 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
18 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
19 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 1 torch.Size([3, 224, 224])
20 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
21 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
22 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
23 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
24 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
25 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
26 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
27 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
28 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])
29 <class 'tuple'> <class 'torch.Tensor'> <class 'int'> 2 torch.Size([3, 224, 224])

4.3. DataLoader

4.3.1. Using a custom dataset

[11]:
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=4)
[12]:
# Each batch is a [images, labels] list produced by the default collate_fn.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch,
          type(sample_batched), len(sample_batched),
          type(sample_batched[0]))
0 <class 'list'> 2 <class 'torch.Tensor'>
1 <class 'list'> 2 <class 'torch.Tensor'>
2 <class 'list'> 2 <class 'torch.Tensor'>
3 <class 'list'> 2 <class 'torch.Tensor'>
4 <class 'list'> 2 <class 'torch.Tensor'>
5 <class 'list'> 2 <class 'torch.Tensor'>
6 <class 'list'> 2 <class 'torch.Tensor'>
7 <class 'list'> 2 <class 'torch.Tensor'>
8 <class 'list'> 2 <class 'torch.Tensor'>
9 <class 'list'> 2 <class 'torch.Tensor'>
10 <class 'list'> 2 <class 'torch.Tensor'>
11 <class 'list'> 2 <class 'torch.Tensor'>
12 <class 'list'> 2 <class 'torch.Tensor'>
13 <class 'list'> 2 <class 'torch.Tensor'>
14 <class 'list'> 2 <class 'torch.Tensor'>
[13]:
# Unpack batches directly: inputs is (batch, C, H, W), labels is (batch,).
for inputs, labels in dataloader:
    print(type(inputs), type(labels), ':', inputs.shape, labels.shape)
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
[14]:
# With shuffle=False the labels come out grouped by class directory.
for inputs, labels in dataloader:
    print(labels)
tensor([2, 2])
tensor([2, 2])
tensor([2, 2])
tensor([2, 2])
tensor([2, 2])
tensor([1, 1])
tensor([1, 1])
tensor([1, 1])
tensor([1, 1])
tensor([1, 1])
tensor([0, 0])
tensor([0, 0])
tensor([0, 0])
tensor([0, 0])
tensor([0, 0])

4.3.2. Using ImageFolder

[15]:
dataloader = DataLoader(image_folder, batch_size=2, shuffle=False, num_workers=4)
[16]:
# Each batch is a [images, labels] list produced by the default collate_fn.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch,
          type(sample_batched), len(sample_batched),
          type(sample_batched[0]))
0 <class 'list'> 2 <class 'torch.Tensor'>
1 <class 'list'> 2 <class 'torch.Tensor'>
2 <class 'list'> 2 <class 'torch.Tensor'>
3 <class 'list'> 2 <class 'torch.Tensor'>
4 <class 'list'> 2 <class 'torch.Tensor'>
5 <class 'list'> 2 <class 'torch.Tensor'>
6 <class 'list'> 2 <class 'torch.Tensor'>
7 <class 'list'> 2 <class 'torch.Tensor'>
8 <class 'list'> 2 <class 'torch.Tensor'>
9 <class 'list'> 2 <class 'torch.Tensor'>
10 <class 'list'> 2 <class 'torch.Tensor'>
11 <class 'list'> 2 <class 'torch.Tensor'>
12 <class 'list'> 2 <class 'torch.Tensor'>
13 <class 'list'> 2 <class 'torch.Tensor'>
14 <class 'list'> 2 <class 'torch.Tensor'>
[17]:
# Unpack batches directly: inputs is (batch, C, H, W), labels is (batch,).
for inputs, labels in dataloader:
    print(type(inputs), type(labels), ':', inputs.shape, labels.shape)
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
<class 'torch.Tensor'> <class 'torch.Tensor'> : torch.Size([2, 3, 224, 224]) torch.Size([2])
[18]:
# With shuffle=False, ImageFolder labels come out in class order 0, 1, 2.
for inputs, labels in dataloader:
    print(labels)
tensor([0, 0])
tensor([0, 0])
tensor([0, 0])
tensor([0, 0])
tensor([0, 0])
tensor([1, 1])
tensor([1, 1])
tensor([1, 1])
tensor([1, 1])
tensor([1, 1])
tensor([2, 2])
tensor([2, 2])
tensor([2, 2])
tensor([2, 2])
tensor([2, 2])

4.3.3. Loading tensors to device

[19]:
import torch

# Prefer the GPU when one is visible to this process; fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device = {device}')

# Batches arrive on the CPU; move each one to the target device before use.
for inputs, labels in dataloader:
    inputs, labels = inputs.to(device), labels.to(device)
    print(labels)
device = cuda
tensor([0, 0], device='cuda:0')
tensor([0, 0], device='cuda:0')
tensor([0, 0], device='cuda:0')
tensor([0, 0], device='cuda:0')
tensor([0, 0], device='cuda:0')
tensor([1, 1], device='cuda:0')
tensor([1, 1], device='cuda:0')
tensor([1, 1], device='cuda:0')
tensor([1, 1], device='cuda:0')
tensor([1, 1], device='cuda:0')
tensor([2, 2], device='cuda:0')
tensor([2, 2], device='cuda:0')
tensor([2, 2], device='cuda:0')
tensor([2, 2], device='cuda:0')
tensor([2, 2], device='cuda:0')

4.4. Datasets

Here are a few datasets accessible through the PyTorch API that you may load.

4.4.1. MNIST

A handwritten digit database.

[ ]:
from torchvision import datasets

# Download MNIST (handwritten digits) into data_root on first use.
data_root = './output/mnist'
num_workers = 4
batch_size = 64

dataset = datasets.MNIST(root=data_root, download=True, transform=None)
# NOTE(review): with transform=None each item is a PIL image, which the
# default DataLoader collate cannot batch -- confirm a ToTensor transform
# is supplied before actually iterating this loader.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.2. Fashion MNIST

An article (clothing) database.

[ ]:
# Download Fashion-MNIST (28x28 grayscale clothing images, 10 classes).
# Fixed the misspelled output directory ('fasionmnist' -> 'fashionmnist').
data_root = './output/fashionmnist'
dataset = datasets.FashionMNIST(root=data_root, download=True, transform=None)
# NOTE(review): with transform=None items are PIL images; the default
# collate cannot batch them -- add ToTensor before iterating.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.3. KMNIST

Kuzushiji database of classic Japanese characters.

[ ]:
# Download KMNIST (Kuzushiji, classic Japanese characters).
data_root = './output/kmnist'
dataset = datasets.KMNIST(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.4. EMNIST

A database of handwritten letters and digits, extending MNIST.

[ ]:
# Download EMNIST; split='byclass' covers all 62 digit+letter classes.
data_root = './output/emnist'
dataset = datasets.EMNIST(root=data_root, split='byclass', download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.5. QMNIST

A handwritten digit database that reconstructs and extends MNIST from the original NIST data.

[ ]:
# Download QMNIST (an MNIST reconstruction with extra test digits).
data_root = './output/qmnist'
dataset = datasets.QMNIST(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.6. FakeData

A fake dataset that returns randomly generated images.

[ ]:
# Randomly generated 3x224x224 images with 10 fake classes; no download.
# NOTE(review): with transform=None items are PIL images; add ToTensor
# before iterating with the default collate.
dataset = datasets.FakeData(size=1000,
                            image_size=(3, 224, 224),
                            num_classes=10, transform=None,
                            target_transform=None, random_offset=0)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.7. LSUN

A dataset of 10 scene categories and 20 object categories.

[ ]:
# LSUN scene/object categories; classes='train' loads every training split.
# NOTE(review): LSUN has no download option -- the lmdb files are presumably
# expected to already exist under data_root; verify before running.
data_root = './output/lsun'
dataset = datasets.LSUN(root=data_root, classes='train', transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.8. ImageNet

An image database organized according to the WordNet hierarchy (nouns).

[ ]:
# ImageNet training split.
# NOTE(review): recent torchvision releases reject download=True for
# ImageNet (the archives must be obtained manually) -- confirm against the
# installed torchvision version.
data_root = './output/imagenet'
dataset = datasets.ImageNet(root=data_root, split='train', download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.9. CIFAR10

A dataset of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 testing images.

[ ]:
# Download CIFAR-10 (60k 32x32 color images, 10 classes).
data_root = './output/cifar10'
dataset = datasets.CIFAR10(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.10. STL10

An image recognition dataset.

[ ]:
# Download the STL10 labeled training split.
data_root = './output/stl10'
dataset = datasets.STL10(root=data_root, split='train', download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.11. SVHN

An image dataset for object recognition.

[ ]:
# Download the SVHN (street-view house numbers) training split.
data_root = './output/svhn'
dataset = datasets.SVHN(root=data_root, split='train', download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.12. PhotoTour

A dataset consisting of 1024x1024 bitmap images, each with a 16x16 array of image patches. The available dataset names are:

  • notredame_harris

  • yosemite_harris

  • liberty_harris

  • notredame

  • yosemite

  • liberty

[ ]:
# Download one named PhotoTour patch set (here: notredame_harris).
data_root = './output/phototour'
dataset = datasets.PhotoTour(root=data_root, name='notredame_harris', download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.13. SBU

A captioned photo dataset.

[ ]:
# Download the SBU captioned-photo dataset.
data_root = './output/sbu'
dataset = datasets.SBU(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.14. VOC

A dataset for segmentation and object recognition.

[ ]:
# Download Pascal VOC for semantic segmentation (default year/image_set).
data_root = './output/voc'
dataset = datasets.VOCSegmentation(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.15. SBD

A semantic boundary dataset.

[ ]:
# Download the Semantic Boundaries Dataset.
# NOTE(review): torchvision's SBDataset takes a joint `transforms=` argument
# rather than `transform=` -- this call may raise a TypeError; verify
# against the installed torchvision version.
data_root = './output/sbd'
dataset = datasets.SBDataset(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.16. USPS

An image classification dataset.

[ ]:
# Download USPS (16x16 handwritten digit scans).
data_root = './output/usps'
dataset = datasets.USPS(root=data_root, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

4.4.17. Kinetics-400

An action recognition video dataset.

[ ]:
# Kinetics-400 action-recognition videos, 10 frames per clip.
# NOTE(review): torchvision's Kinetics400 class has no download parameter
# (videos must be fetched separately) -- this call may raise a TypeError;
# verify against the installed torchvision version.
data_root = './output/kinetics400'
dataset = datasets.Kinetics400(root=data_root,
                               frames_per_clip=10, download=True, transform=None)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)