🧩 一、核心组件简介

组件作用
Dataset表示一个数据集,每个样本可以被单独索引访问
DataLoader负责按批(batch)加载数据,支持多线程与打乱顺序
Transform数据预处理操作,如归一化、裁剪、旋转等(常用于图像)

🛠️ 二、加载内置数据集(以 MNIST 为例)

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),                    # 转为 Tensor
    transforms.Normalize((0.5,), (0.5,))      # 标准化到 [-1, 1]
])

# 加载训练集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 访问一个 batch
images, labels = next(iter(train_loader))
print(images.shape)  # 输出: [64, 1, 28, 28]


🧠 三、自定义 Dataset(如 CSV、图片路径、JSON)

如果你不是用官方数据集,可以自定义 Dataset 类:

from torch.utils.data import Dataset
import pandas as pd

class MyDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = torch.tensor(self.data.iloc[idx, :-1].values, dtype=torch.float32)
        label = torch.tensor(self.data.iloc[idx, -1], dtype=torch.long)
        return features, label

# 使用
my_data = MyDataset("mydata.csv")
my_loader = DataLoader(my_data, batch_size=32, shuffle=True)


🖼️ 四、图像数据增强与转换(transforms)

常见于图像任务,图像预处理通常使用 torchvision.transforms

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((32, 32)),              # 调整大小
    transforms.RandomHorizontalFlip(),        # 随机翻转
    transforms.RandomRotation(15),            # 随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


🔄 五、DataLoader 进阶技巧

功能参数示例
批量大小batch_sizebatch_size=32
是否打乱数据shuffleshuffle=True
多线程并行加载num_workersnum_workers=4(注意在 Windows 设置为 0)
是否丢弃最后不足drop_lastdrop_last=True
自定义 batch 函数collate_fn自定义复杂数据处理逻辑时使用

🧪 六、完整小示例(分类任务)

from torchvision.datasets import CIFAR10

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 3通道
])

trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

for images, labels in trainloader:
    print(images.shape, labels.shape)
    break


🧰 七、可视化加载后的图像数据(可选)

import matplotlib.pyplot as plt
import torchvision

# 取一个 batch 并可视化
dataiter = iter(trainloader)
images, labels = next(dataiter)

img = torchvision.utils.make_grid(images[:8])
img = img / 2 + 0.5  # 反归一化
plt.imshow(img.permute(1, 2, 0))  # 变换通道顺序为 (H, W, C)
plt.show()


✅ 八、小结与建议

要点说明
Dataset定义数据读取规则
DataLoader管理批量加载与并发
Transform提高模型鲁棒性的利器
自定义 Dataset可扩展至任何类型的数据

📘 推荐练习

  1. 读取本地图片文件夹并构建分类任务 Dataset。
  2. 使用 torchvision.transforms 构造不同的数据增强方案。
  3. 尝试在非图像数据(如文本、CSV)中使用 DataLoader 加载。