🧠 一、线性回归的基本原理

线性回归目标是学习下面的关系:y=wx+b

其中:

  • x 是输入特征
  • w 是权重(slope)
  • b 是偏置项(bias)
  • y 是模型预测值

我们用 PyTorch 来让模型学会如何拟合这个公式。


🧪 二、模拟一组数据

import torch
import matplotlib.pyplot as plt

# 构造数据 y = 2x + 1 + 噪声
torch.manual_seed(0)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # shape: (100, 1)
y = 2 * x + 1 + 0.2 * torch.rand(x.size())

# 可视化数据
plt.scatter(x.numpy(), y.numpy())
plt.title("训练数据分布")
plt.show()


🛠️ 三、构建线性回归模型

import torch.nn as nn

model = nn.Linear(in_features=1, out_features=1)
print(model)  # 查看结构:Linear(in_features=1, out_features=1, bias=True)


🎯 四、定义损失函数和优化器

import torch.optim as optim

criterion = nn.MSELoss()  # 均方误差
optimizer = optim.SGD(model.parameters(), lr=0.05)


🔁 五、训练模型

num_epochs = 100

for epoch in range(num_epochs):
    # 前向传播
    outputs = model(x)
    loss = criterion(outputs, y)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


📊 六、可视化拟合结果

predicted = model(x).detach()  # 不计算梯度
plt.scatter(x.numpy(), y.numpy(), label='真实数据')
plt.plot(x.numpy(), predicted.numpy(), 'r-', label='预测直线')
plt.legend()
plt.title("线性回归拟合效果")
plt.show()


💾 七、保存与加载模型

# 保存模型参数
torch.save(model.state_dict(), 'linear_model.pth')

# 加载模型参数
model.load_state_dict(torch.load('linear_model.pth'))
model.eval()


✅ 八、小结:线性回归完整流程

步骤操作
1. 模拟数据构造 x 和 y
2. 定义模型使用 nn.Linear(1, 1)
3. 定义损失函数nn.MSELoss()
4. 优化器optim.SGD()
5. 训练循环.backward() + .step()
6. 可视化matplotlib
7. 模型保存torch.save() 和 torch.load()

🎓 建议练习

  1. 修改数据生成公式为 y = -3x + 4 + 噪声 看拟合效果是否正常。
  2. 使用 nn.Sequential 来构建模型。
  3. 将输入维度扩展到二维、三维(多特征回归)。