🧠 一、线性回归的基本原理
线性回归目标是学习下面的关系:y=wx+b
其中:
x
是输入特征w
是权重(slope)b
是偏置项(bias)y
是模型预测值
我们用 PyTorch 来让模型学会如何拟合这个公式。
🧪 二、模拟一组数据
import torch
import matplotlib.pyplot as plt
# 构造数据 y = 2x + 1 + 噪声
torch.manual_seed(0)
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # shape: (100, 1)
y = 2 * x + 1 + 0.2 * torch.rand(x.size())
# 可视化数据
plt.scatter(x.numpy(), y.numpy())
plt.title("训练数据分布")
plt.show()
🛠️ 三、构建线性回归模型
import torch.nn as nn
model = nn.Linear(in_features=1, out_features=1)
print(model) # 查看结构:Linear(in_features=1, out_features=1, bias=True)
🎯 四、定义损失函数和优化器
import torch.optim as optim
criterion = nn.MSELoss() # 均方误差
optimizer = optim.SGD(model.parameters(), lr=0.05)
🔁 五、训练模型
num_epochs = 100
for epoch in range(num_epochs):
# 前向传播
outputs = model(x)
loss = criterion(outputs, y)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
📊 六、可视化拟合结果
predicted = model(x).detach() # 不计算梯度
plt.scatter(x.numpy(), y.numpy(), label='真实数据')
plt.plot(x.numpy(), predicted.numpy(), 'r-', label='预测直线')
plt.legend()
plt.title("线性回归拟合效果")
plt.show()
💾 七、保存与加载模型
# 保存模型参数
torch.save(model.state_dict(), 'linear_model.pth')
# 加载模型参数
model.load_state_dict(torch.load('linear_model.pth'))
model.eval()
✅ 八、小结:线性回归完整流程
步骤 | 操作 |
---|---|
1. 模拟数据 | 构造 x 和 y |
2. 定义模型 | 使用 nn.Linear(1, 1) |
3. 定义损失函数 | nn.MSELoss() |
4. 优化器 | optim.SGD() |
5. 训练循环 | .backward() + .step() |
6. 可视化 | matplotlib |
7. 模型保存 | torch.save() 和 torch.load() |
🎓 建议练习
- 修改数据生成公式为
y = -3x + 4 + 噪声
看拟合效果是否正常。 - 使用
nn.Sequential
来构建模型。 - 将输入维度扩展到二维、三维(多特征回归)。
发表回复