🧩 一、核心组件简介
组件 | 作用 |
---|
Dataset | 表示一个数据集,每个样本可以被单独索引访问 |
DataLoader | 负责按批(batch)加载数据,支持多线程与打乱顺序 |
Transform | 数据预处理操作,如归一化、裁剪、旋转等(常用于图像) |
🛠️ 二、加载内置数据集(以 MNIST 为例)
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转为 Tensor
transforms.Normalize((0.5,), (0.5,)) # 标准化到 [-1, 1]
])
# 加载训练集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
# 访问一个 batch
images, labels = next(iter(train_loader))
print(images.shape) # 输出: [64, 1, 28, 28]
🧠 三、自定义 Dataset(如 CSV、图片路径、JSON)
如果你不是用官方数据集,可以自定义 Dataset
类:
from torch.utils.data import Dataset
import pandas as pd
class MyDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
features = torch.tensor(self.data.iloc[idx, :-1].values, dtype=torch.float32)
label = torch.tensor(self.data.iloc[idx, -1], dtype=torch.long)
return features, label
# 使用
my_data = MyDataset("mydata.csv")
my_loader = DataLoader(my_data, batch_size=32, shuffle=True)
🖼️ 四、图像数据增强与转换(transforms)
常见于图像任务,图像预处理通常使用 torchvision.transforms
:
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((32, 32)), # 调整大小
transforms.RandomHorizontalFlip(), # 随机翻转
transforms.RandomRotation(15), # 随机旋转
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
🔄 五、DataLoader 进阶技巧
功能 | 参数 | 示例 |
---|
批量大小 | batch_size | batch_size=32 |
是否打乱数据 | shuffle | shuffle=True |
多线程并行加载 | num_workers | num_workers=4 (注意在 Windows 设置为 0) |
是否丢弃最后不足 | drop_last | drop_last=True |
自定义 batch 函数 | collate_fn | 自定义复杂数据处理逻辑时使用 |
🧪 六、完整小示例(分类任务)
from torchvision.datasets import CIFAR10
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 3通道
])
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
for images, labels in trainloader:
print(images.shape, labels.shape)
break
🧰 七、可视化加载后的图像数据(可选)
import matplotlib.pyplot as plt
import torchvision
# 取一个 batch 并可视化
dataiter = iter(trainloader)
images, labels = next(dataiter)
img = torchvision.utils.make_grid(images[:8])
img = img / 2 + 0.5 # 反归一化
plt.imshow(img.permute(1, 2, 0)) # 变换通道顺序为 (H, W, C)
plt.show()
✅ 八、小结与建议
要点 | 说明 |
---|
Dataset | 定义数据读取规则 |
DataLoader | 管理批量加载与并发 |
Transform | 提高模型鲁棒性的利器 |
自定义 Dataset | 可扩展至任何类型的数据 |
📘 推荐练习
- 读取本地图片文件夹并构建分类任务 Dataset。
- 使用
torchvision.transforms
构造不同的数据增强方案。
- 尝试在非图像数据(如文本、CSV)中使用
DataLoader
加载。
发表回复