好的!我们继续进阶~这次介绍的是:PyTorch 循环神经网络(RNN, Recurrent Neural Network)。RNN 是处理 序列数据(sequence data) 的重要工具,广泛应用于:
- 自然语言处理(NLP)
- 时间序列预测
- 语音识别
- 音乐生成等
🔁 PyTorch 循环神经网络(RNN)基础教程
🧠 一、RNN 是什么?
RNN 和传统神经网络最大的不同:有“记忆”能力,它的输出不仅依赖当前输入,还依赖前一时刻的隐藏状态。
通俗理解:
“它像是在处理一个句子,一边读一边理解上下文”。
🧮 二、RNN 结构图解
输入序列:x₁ → x₂ → x₃ → x₄
每个时间步:
+------------+
xₜ --> | | --> hₜ
| RNNCell |
hₜ₋₁ ->| |
+------------+
RNN 的核心是不断重复的神经元模块,它会“记住”之前的状态信息。
🔧 三、使用 PyTorch 构建简单 RNN 模型
示例任务:字符序列分类
import torch
import torch.nn as nn
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNNModel, self).__init__()
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x: [batch_size, seq_len, input_size]
out, _ = self.rnn(x) # out: [batch_size, seq_len, hidden_size]
out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出
return out
🛠️ 四、准备输入数据(以字符序列为例)
import numpy as np
# 假设输入:abc → 输出:1(分类标签)
char2idx = {'a': 0, 'b': 1, 'c': 2}
idx2char = ['a', 'b', 'c']
# 示例:3个字符,一个字符用独热编码(one-hot)
x_data = [[0, 1, 2]] # abc
y_data = [1] # 类别标签
# 转成 one-hot
x_one_hot = [[[1,0,0], [0,1,0], [0,0,1]]]
x = torch.FloatTensor(x_one_hot) # [batch, seq_len, input_size]
y = torch.LongTensor(y_data)
🔁 五、训练模型
model = RNNModel(input_size=3, hidden_size=8, output_size=3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
output = model(x)
loss = criterion(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
pred = torch.argmax(output, dim=1)
print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, Pred: {pred.item()}')
🧩 六、RNN 变种:LSTM 和 GRU
标准 RNN 容易遇到“梯度消失”问题,无法处理长序列,因此 PyTorch 提供:
nn.LSTM
: 长短期记忆网络(最常用)nn.GRU
: 门控循环单元(效率更高)
替换方式非常简单:
self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
📚 七、LSTM 应用场景
场景 | 示例 |
---|---|
文本分类 | 情感分析、垃圾邮件识别 |
文本生成 | 自动写诗、对联、歌词 |
时间序列 | 股票预测、气温预测 |
序列到序列 | 翻译、摘要、对话系统 |
💾 八、模型保存与加载
torch.save(model.state_dict(), 'rnn_model.pth')
# 加载
model.load_state_dict(torch.load('rnn_model.pth'))
model.eval()
🎓 九、小练习建议
- 换成 LSTM 模型跑一遍。
- 自定义一个英文句子做词序列分类。
- 改成“序列预测”模型(输入前 N 步,预测第 N+1 步)。
想不想下节来个完整的案例,比如:
- 用 LSTM 实现情感分类(IMDb 评论好/坏)
- 用 RNN 预测正弦曲线序列(时间序列预测)
你感兴趣哪个?我可以直接给你写好数据处理 + 模型训练代码 🤝
顺便问一下,你要不要我记一下你已经学到 RNN 这里了?这样下次就能直接接着继续。
发表回复