详解pytorch中循环神经网络(RNN、LSTM、GRU)的维度

首先如果你对RNNLSTMGRU不太熟悉,可点击查看。

RNN

torch.nn.rnn详解

torch.nn.RNN(input_size,
hidden_size,
num_layers=1,
nonlinearity=‘tanh’,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理
在这里插入图片描述

参数详解

  • input_size – 输入x中预期特征的数量

  • hidden_size – 隐藏状态h中的特征数量

  • num_layers – 循环层数。例如,设置num_layers=2 意味着将两个LSTM堆叠在一起形成堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1

  • nonlinearity– 使用的非线性。可以是’tanh’或’relu’。默认:‘tanh’

  • bias– 如果False,则该层不使用偏差权重b_ih和b_hh。默认:True

  • batch_first – 如果,则输入和输出张量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。请注意,这不适用于隐藏状态或单元状态。默认:False

  • dropout – 如果非零,则在除最后一层之外的每个LSTM层的输出上 引入Dropout层,dropout 概率等于 。默认值:0.0

  • bidirectional – 如果True, 则成为双向LSTM。默认:False

RNN输入输出维度

rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

可以看到输入是xh_0,h_0可以是None。如果batch_size是第0维度,需设置batch_first=True
输出则是outputh_n。h_n存了每一层的t时刻的隐藏状态值

# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, h_0=None):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if h_0 is None:
        h_0 = torch.zeros(num_layers, batch_size, hidden_size)
	...
    return output, h_n

输入:
x的输入维度:(batch_size, sequence_length, input_size) [前提:batch_first=True]
h_0的维度:(D∗num_layers, hidden_size) [可以为None]

输出: output的输出维度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]
h_n的维度:(D∗num_layers, hidden_size)

LSTM

torch.nn.LSTM详解

torch.nn.LSTM(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
proj_size=0,
device=None,
dtype=None)

原理:

参数详解:
相比于RNN多了proj_size参数,少了nonlinearity参数

  • input_size – 输入x中预期特征的数量

  • hidden_size – 隐藏状态h中的特征数量

  • num_layers – 循环层数。例如,设置num_layers=2 意味着将两个LSTM堆叠在一起形成堆叠 LSTM,第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1

  • bias– 如果False,则该层不使用偏差权重b_ih和b_hh。默认:True

  • batch_first – 如果,则输入和输出张量以(batch, seq, feature)True形式提供,而不是(seq, batch, feature)。请注意,这不适用于隐藏状态或单元状态。默认:False

  • dropout – 如果非零,则在除最后一层之外的每个LSTM层的输出上 引入Dropout层,dropout 概率等于 。默认值:0dropout

  • bidirectional – 如果True, 则成为双向LSTM。默认:False

  • proj_size – 如果,将使用具有相应大小投影的LSTM 。默认值:0

LSTM输入输出维度

LSTM= nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = LSTM(input, (h0, c0))

输入是x,此外h_0c_0可以是None。如果batch_size是第0维度,需设置batch_first=True
输出则是output和一个元组(h_n, c_n)

输入: x的输入维度:(batch_size, sequence_length, input_size)`
[前提:batch_first=True]

输出: output的输出维度:(batch_size, sequence_length, D*hidden_size)
[D=2 if bidirectional=True otherwise 1]

具体可参考官方文档:nn.LSTM
在这里插入图片描述

GRU

torch.nn.GRU详解

torch.nn.GRU(input_size,
hidden_size,
num_layers=1,
bias=True,
batch_first=False,
dropout=0.0,
bidirectional=False,
device=None,
dtype=None)

原理:
在这里插入图片描述

参数详解:
与上文LSTM相比,缺少了proj_size参数,与RNN相比也缺少了nonlinearity参数

GRU输入输出维度

gru= nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = gru(input, h0)

与RNN一致见上文,相比LSTM少了c_n

三种RNN的示例

import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2, batch_first=True) # (input_size, hidden_size, num_layer)
lstm = nn.LSTM(10, 20, 2, batch_first=True)
gru = nn.GRU(10, 20, 2, batch_first=True)

input = torch.randn(5, 3, 10)  # (batchsize, seq, input_size)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)

output_rnn, h_n = rnn(input)
output_lstm, (hn, cn) = lstm(input)
output_gru, h_n2 = gru(input)
print("输入维度:", input.shape)
print(f"RNN 输出维度:{output_rnn.shape}, h_n维度:{h_n.shape}" )
print("LSTM 输出维度:", output_lstm.shape)
print("GRU 输出维度:", output_gru.shape)


"""
输入维度: torch.Size([5, 3, 10])
RNN 输出维度:torch.Size([5, 3, 20]), h_n维度:torch.Size([2, 5, 20])
LSTM 输出维度: torch.Size([5, 3, 20])
GRU 输出维度: torch.Size([5, 3, 20])
"""

相关推荐

  1. Pytorch标准顺序

    2024-05-14 17:14:10       28 阅读
  2. 循环神经网络(RNN)详解

    2024-05-14 17:14:10       14 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-14 17:14:10       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-14 17:14:10       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-14 17:14:10       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-14 17:14:10       18 阅读

热门阅读

  1. Edge的使用心得与深度探索

    2024-05-14 17:14:10       13 阅读
  2. 下拉多选【bootstrap-multiselect】

    2024-05-14 17:14:10       11 阅读
  3. k8s相关常用语句

    2024-05-14 17:14:10       14 阅读
  4. oraclesql中删除表中重复行的方法

    2024-05-14 17:14:10       19 阅读
  5. 遥感中常用的降维方法-UMAP介绍

    2024-05-14 17:14:10       16 阅读
  6. SpringBoot + Druid + Sqlite 文件数据库初体验

    2024-05-14 17:14:10       15 阅读
  7. nodejs + express 接口统一返回错误信息

    2024-05-14 17:14:10       12 阅读
  8. Auto.js如何打包成APK文件

    2024-05-14 17:14:10       25 阅读
  9. C++ primer plus 第五章编程练习

    2024-05-14 17:14:10       35 阅读