PyTorch 提供了强大的工具来加载和处理数据集。PyTorch 的 torch.utils.data 模块提供了 Dataset 和 DataLoader 类,帮助我们简化数据处理和批次加载的工作。

1. Dataset 类

Dataset 是 PyTorch 数据加载的基础类,用于包装数据集。你可以通过继承 torch.utils.data.Dataset 类来定义自己的数据集。最常用的方法是 __len__ 和 __getitem__

示例:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 假设有一些数据
data = torch.randn(100, 3, 32, 32)  # 100张 3x32x32 的图片
labels = torch.randint(0, 10, (100,))  # 100个标签,范围是 0 到 9

# 创建数据集
dataset = MyDataset(data, labels)
print(len(dataset))  # 输出:100
print(dataset[0])  # 输出: (数据, 标签)

2. DataLoader 类

DataLoader 是一个迭代器,允许我们按批次加载数据。它可以自动分配批次大小、打乱数据和并行处理。

示例:

from torch.utils.data import DataLoader

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 迭代数据
for batch_data, batch_labels in dataloader:
    print(batch_data.shape)  # 输出:torch.Size([32, 3, 32, 32])
    print(batch_labels.shape)  # 输出:torch.Size([32])
    break  # 只查看第一批

3. 常用的数据集

PyTorch 提供了一些常用的数据集,如 torchvision.datasets 中的标准图像数据集。通过 torchvision,你可以轻松加载 MNIST、CIFAR 等数据集。

示例(加载 CIFAR-10 数据集):

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为 Tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化
])

# 下载并加载 CIFAR-10 数据集
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# 迭代数据
for images, labels in trainloader:
    print(images.shape)  # 输出:torch.Size([64, 3, 32, 32])
    print(labels.shape)  # 输出:torch.Size([64])
    break

4. 自定义数据预处理

有时你需要在加载数据时进行自定义的预处理操作,可以通过编写 transforms 来实现。例如,图像增强、裁剪、旋转等。

示例(自定义预处理):

from torchvision import transforms

# 自定义转换:裁剪、翻转、归一化
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.ToTensor(),  # 转为 Tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化
])

# 使用自定义的 transform
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

总结

  • Dataset 用于定义数据集,重写 __len__ 和 __getitem__ 方法。
  • DataLoader 用于处理批次数据、打乱数据和多线程加载。
  • 可以使用 torchvision 提供的标准数据集,也可以自定义数据集和预处理步骤。

你在构建和训练模型时,通常会利用这些工具高效地加载数据。