LSTM和GRU的介绍以及Pytorch源码解析

介绍一下LSTM模型的结构以及源码,用作自己复习的材料。 

LSTM模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

上次上一篇文章介绍了RNN序列模型,但是RNN模型存在比较严重的梯度爆炸和梯度消失问题。

本文介绍的LSTM模型解决的RNN的大部分缺陷。

首先展示LSTM的模型框架:

下面是LSTM模型的数学推导公式:

\begin{array}{ll} \\ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ h_t = o_t \odot \tanh(c_t) \\ \end{array}

h_t表示t时刻的隐藏状态,c_t表示t时刻的记忆细胞状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。

i_t,f_t,g_t,o_t 分别是输入门、遗忘门、单元门和输出门。

这张图片比较好的介绍了各个门之间的交互关系以及输入输出,大家可以放大看一下。

接下来展示GRU的框架模型:

下面是GRU的数学推导公式:

r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}

h_t表示t时刻的隐藏状态,x_t表示t时刻的输入,h_{t-1}表示在时间t-1的隐藏状态或在时间0的初始隐藏状态。r_t,n_t,z_t分别表示重置门更新门和新建门

上面的图片可以更直观的看到GRU中是如何迭代的。

接下来我们看一下源码中LSTM和GRU类的初始化(只介绍几个重要的参数):

torch.nn.LSTM(self, input_size, hidden_size, num_layers=1,
              bias=True, batch_first=False, dropout=0.0, 
              bidirectional=False, proj_size=0, device=None,
              dtype=None)
torch.nn.GRU(self, input_size, hidden_size, num_layers=1,
             bias=True, batch_first=False, dropout=0.0, 
             bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向LSTM(GRU),默认关闭。

LSTM与GRU都是特殊的RNN,因此输入输出可以参考的上一篇介绍RNN的文章,在这里直接进行代码举例。

lstm1 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
lstm2 = nn.LSTM(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)

gru1 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=True)
gru2 = nn.GRU(input_size=20,hidden_size=25,num_layers=4,bidirectional=False)

tensor1 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)
tensor2 = torch.randn(5,10,20)  # (batch_size * seq_len * emb_dim)

out_lstm1,(hn, cn) = lstm1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_lstm2,(hn, cn) = lstm2(tensor2)  # (batch_size * seq_len * (hidden_size * bidirectional))

out_gru1,h_n = gru1(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))
out_gru2,h_n = gru2(tensor1)  # (batch_size * seq_len * (hidden_size * bidirectional))

print(out_lstm1.shape)  # torch.Size([5, 10, 80])
print(out_lstm2.shape)  # torch.Size([5, 10, 40])

print(out_gru1.shape)  # torch.Size([5, 10, 50])
print(out_gru2.shape)  # torch.Size([5, 10, 25])

维度已经在注释中给大家标注上了!

相关推荐

  1. LSTMGRU区别

    2023-12-14 19:28:04       38 阅读
  2. SpringBoot

    2023-12-14 19:28:04       42 阅读
  3. ConcurrentHashMap

    2023-12-14 19:28:04       43 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-14 19:28:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-14 19:28:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-14 19:28:04       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-14 19:28:04       20 阅读

热门阅读

  1. CDN加速:社会服务的必备利器

    2023-12-14 19:28:04       35 阅读
  2. LeetCode 2697. 字典序最小回文串

    2023-12-14 19:28:04       44 阅读
  3. leetcode 最大和的连续子数组 C语言

    2023-12-14 19:28:04       31 阅读
  4. 敏捷开发项目管理流程及scrum工具

    2023-12-14 19:28:04       37 阅读
  5. K8S(七)—污点、容忍

    2023-12-14 19:28:04       41 阅读
  6. k8s-Pod

    k8s-Pod

    2023-12-14 19:28:04      32 阅读
  7. hive客户机执行sql脚本无法显示表头

    2023-12-14 19:28:04       37 阅读
  8. 客户端注册账号-服务器-存入数据库..

    2023-12-14 19:28:04       30 阅读
  9. 【算法】【动规】单词拆分

    2023-12-14 19:28:04       37 阅读
  10. RESTful API

    2023-12-14 19:28:04       37 阅读
  11. 线程上下文设计模式

    2023-12-14 19:28:04       31 阅读
  12. Shiro框架权限控制

    2023-12-14 19:28:04       35 阅读
  13. 在ubuntu上rmp打包:准备工作

    2023-12-14 19:28:04       40 阅读