好的!我们继续进阶~这次介绍的是:PyTorch 循环神经网络(RNN, Recurrent Neural Network)。RNN 是处理 序列数据(sequence data) 的重要工具,广泛应用于:

  • 自然语言处理(NLP)
  • 时间序列预测
  • 语音识别
  • 音乐生成等

🔁 PyTorch 循环神经网络(RNN)基础教程


🧠 一、RNN 是什么?

RNN 和传统神经网络最大的不同:有“记忆”能力,它的输出不仅依赖当前输入,还依赖前一时刻的隐藏状态。

通俗理解:

“它像是在处理一个句子,一边读一边理解上下文”。


🧮 二、RNN 结构图解

输入序列:x₁ → x₂ → x₃ → x₄

每个时间步:
        +------------+
xₜ --> |            | --> hₜ
      |    RNNCell  |
hₜ₋₁ ->|            |
        +------------+

RNN 的核心是不断重复的神经元模块,它会“记住”之前的状态信息。


🔧 三、使用 PyTorch 构建简单 RNN 模型

示例任务:字符序列分类

import torch
import torch.nn as nn

class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # x: [batch_size, seq_len, input_size]
        out, _ = self.rnn(x)     # out: [batch_size, seq_len, hidden_size]
        out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出
        return out


🛠️ 四、准备输入数据(以字符序列为例)

import numpy as np

# 假设输入:abc → 输出:1(分类标签)
char2idx = {'a': 0, 'b': 1, 'c': 2}
idx2char = ['a', 'b', 'c']

# 示例:3个字符,一个字符用独热编码(one-hot)
x_data = [[0, 1, 2]]  # abc
y_data = [1]  # 类别标签

# 转成 one-hot
x_one_hot = [[[1,0,0], [0,1,0], [0,0,1]]]
x = torch.FloatTensor(x_one_hot)  # [batch, seq_len, input_size]
y = torch.LongTensor(y_data)


🔁 五、训练模型

model = RNNModel(input_size=3, hidden_size=8, output_size=3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    output = model(x)
    loss = criterion(output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        pred = torch.argmax(output, dim=1)
        print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, Pred: {pred.item()}')


🧩 六、RNN 变种:LSTM 和 GRU

标准 RNN 容易遇到“梯度消失”问题,无法处理长序列,因此 PyTorch 提供:

  • nn.LSTM: 长短期记忆网络(最常用)
  • nn.GRU: 门控循环单元(效率更高)

替换方式非常简单:

self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)


📚 七、LSTM 应用场景

场景示例
文本分类情感分析、垃圾邮件识别
文本生成自动写诗、对联、歌词
时间序列股票预测、气温预测
序列到序列翻译、摘要、对话系统

💾 八、模型保存与加载

torch.save(model.state_dict(), 'rnn_model.pth')
# 加载
model.load_state_dict(torch.load('rnn_model.pth'))
model.eval()


🎓 九、小练习建议

  1. 换成 LSTM 模型跑一遍。
  2. 自定义一个英文句子做词序列分类。
  3. 改成“序列预测”模型(输入前 N 步,预测第 N+1 步)。

想不想下节来个完整的案例,比如:

  • 用 LSTM 实现情感分类(IMDb 评论好/坏)
  • 用 RNN 预测正弦曲线序列(时间序列预测)

你感兴趣哪个?我可以直接给你写好数据处理 + 模型训练代码 🤝

顺便问一下,你要不要我记一下你已经学到 RNN 这里了?这样下次就能直接接着继续。