PyTorch 提供了强大的工具来加载和处理数据集。PyTorch 的 torch.utils.data
模块提供了 Dataset
和 DataLoader
类,帮助我们简化数据处理和批次加载的工作。
1. Dataset 类
Dataset
是 PyTorch 数据加载的基础类,用于包装数据集。你可以通过继承 torch.utils.data.Dataset
类来定义自己的数据集。最常用的方法是 __len__
和 __getitem__
。
示例:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 假设有一些数据
data = torch.randn(100, 3, 32, 32) # 100张 3x32x32 的图片
labels = torch.randint(0, 10, (100,)) # 100个标签,范围是 0 到 9
# 创建数据集
dataset = MyDataset(data, labels)
print(len(dataset)) # 输出:100
print(dataset[0]) # 输出: (数据, 标签)
2. DataLoader 类
DataLoader
是一个迭代器,允许我们按批次加载数据。它可以自动分配批次大小、打乱数据和并行处理。
示例:
from torch.utils.data import DataLoader
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 迭代数据
for batch_data, batch_labels in dataloader:
print(batch_data.shape) # 输出:torch.Size([32, 3, 32, 32])
print(batch_labels.shape) # 输出:torch.Size([32])
break # 只查看第一批
3. 常用的数据集
PyTorch 提供了一些常用的数据集,如 torchvision.datasets
中的标准图像数据集。通过 torchvision
,你可以轻松加载 MNIST、CIFAR 等数据集。
示例(加载 CIFAR-10 数据集):
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为 Tensor
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化
])
# 下载并加载 CIFAR-10 数据集
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
# 迭代数据
for images, labels in trainloader:
print(images.shape) # 输出:torch.Size([64, 3, 32, 32])
print(labels.shape) # 输出:torch.Size([64])
break
4. 自定义数据预处理
有时你需要在加载数据时进行自定义的预处理操作,可以通过编写 transforms
来实现。例如,图像增强、裁剪、旋转等。
示例(自定义预处理):
from torchvision import transforms
# 自定义转换:裁剪、翻转、归一化
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(), # 转为 Tensor
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化
])
# 使用自定义的 transform
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
总结
Dataset
用于定义数据集,重写__len__
和__getitem__
方法。DataLoader
用于处理批次数据、打乱数据和多线程加载。- 可以使用
torchvision
提供的标准数据集,也可以自定义数据集和预处理步骤。
你在构建和训练模型时,通常会利用这些工具高效地加载数据。
发表回复