# DATASETS & DATALOADERS
PyTorch provides two data primitives, `torch.utils.data.DataLoader` and `torch.utils.data.Dataset`, that let you use pre-loaded datasets as well as your own data. `Dataset` stores the samples and their corresponding labels, and `DataLoader` wraps an iterable around the `Dataset` to give easy access to the samples.

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass `torch.utils.data.Dataset` and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets.
# Loading a Dataset
Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando's article images, consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST dataset with the following parameters:
- `root` is the path where the train/test data is stored.
- `train` specifies which split to load: `True` selects the training set, `False` the test set.
- `download=True` downloads the data from the internet if it is not available at `root`.
- `transform` and `target_transform` specify the feature and label transformations, e.g. converting the image to a tensor, or converting the target label into the representation your loss expects (see the one-hot sketch after the download output below).
```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
```
Out:
```
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
```
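The `transform` and `target_transform` arguments above accept any callable. As a hedged illustration (the 10-class one-hot encoding is an assumption about what your loss might expect, not something FashionMNIST requires), `Lambda` from `torchvision.transforms` can map the integer label to a one-hot vector:

```python
from torchvision.transforms import Lambda

# One-hot encode the integer label into a length-10 float vector.
one_hot = Lambda(
    lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
)

onehot_training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=one_hot,
)
```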
# Iterating and Visualizing the Dataset
We can index `Datasets` manually like a list: `training_data[index]`. We use `matplotlib` to visualize some samples in our training data.
```python
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```
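A brief aside on the `.squeeze()` call above: with `ToTensor()`, each sample comes back as a `(1, 28, 28)` float tensor, so dropping the channel dimension gives `imshow` the 2-D array it expects. A minimal check, using only the `training_data` object defined earlier:

```python
# Each transformed FashionMNIST sample is a single-channel 28x28 tensor.
img, label = training_data[0]
print(img.shape)            # torch.Size([1, 28, 28])
print(img.squeeze().shape)  # torch.Size([28, 28])
```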
# Creating a Custom Dataset for your files
A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`. Take a look at this implementation; the FashionMNIST images are stored in a directory `img_dir`, and their labels are stored separately in a CSV file `annotations_file`.

In the next sections, we'll break down what's happening in each of these functions.
```python
import os
import pandas as pd
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```
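As a hedged usage sketch (the `annotations_file` and `img_dir` paths below are placeholders, not files shipped with this tutorial), the class can be instantiated and wrapped in a `DataLoader` like any other `Dataset`:

```python
from torch.utils.data import DataLoader

# Hypothetical paths: substitute your own CSV of "filename, label" rows and image folder.
dataset = CustomImageDataset(
    annotations_file="data/labels.csv",
    img_dir="data/images",
)
loader = DataLoader(dataset, batch_size=64, shuffle=True)
```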
# __init__
The `__init__` function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).
The labels.csv file looks like:
```
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
```
```python
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
```
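One hedged caveat about `pd.read_csv` above: if the annotations CSV has no header row, as in the sample shown earlier, pandas will by default treat the first data row as column names. Passing `header=None` (optionally with explicit column names) avoids silently losing that first sample; the path below is a placeholder:

```python
import pandas as pd

# Assumption: a header-less CSV with lines like "tshirt1.jpg, 0".
img_labels = pd.read_csv("data/labels.csv", header=None, names=["file_name", "label"])
print(img_labels.iloc[0, 0], img_labels.iloc[0, 1])  # e.g. tshirt1.jpg 0
```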
# __len__
The `__len__` function returns the number of samples in our dataset.
Example:
```python
def __len__(self):
    return len(self.img_labels)
```
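As a quick aside, Python's built-in `len()` dispatches to this method, so the split sizes quoted earlier can be checked directly on the `training_data` and `test_data` objects loaded above:

```python
print(len(training_data))  # 60000
print(len(test_data))      # 10000
```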
# __getitem__
The `__getitem__` function loads and returns a sample from the dataset at the given index `idx`. Based on the index, it identifies the image's location on disk, converts it to a tensor using `read_image`, retrieves the corresponding label from the CSV data in `self.img_labels`, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label as a tuple.
```python
def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
```
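Indexing the dataset is what invokes this method. A minimal sketch, reusing the hypothetical `dataset` object from the usage example above:

```python
# dataset[0] calls __getitem__(0); read_image yields a uint8 tensor of shape (C, H, W).
image, label = dataset[0]
print(image.shape, label)
```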
# Preparing your data for training with DataLoaders
`Dataset` retrieves our dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in minibatches, reshuffle the data at every epoch to reduce model overfitting, and use Python's multiprocessing to speed up data retrieval.

`DataLoader` is an iterable that abstracts this complexity for us behind a simple API.
```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
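The multiprocessing mentioned above is opted into via the `num_workers` argument of `DataLoader`; a hedged variant of the training loader above (the worker count is just an illustrative choice, tune it to your machine):

```python
# num_workers > 0 prepares batches in background worker processes.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=4)
```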
# Iterate through the DataLoader
We have loaded that dataset into the `DataLoader` and can iterate through the dataset as needed. Each iteration below returns a batch of `train_features` and `train_labels` (containing `batch_size=64` features and labels, respectively). Because we specified `shuffle=True`, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers).
```python
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
```
Out:
```
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 4
```
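Beyond pulling a single batch with `next(iter(...))`, the loader is normally consumed in a `for` loop. A minimal sketch of the data-consumption side of one epoch (the loop body is a placeholder, not a full training step):

```python
for batch_idx, (X, y) in enumerate(train_dataloader):
    # X: (64, 1, 28, 28) image batch; y: (64,) label batch.
    # A real training step would run the model, compute the loss, and backpropagate here.
    if batch_idx == 0:
        print(X.shape, y.shape)
```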