【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

最近更新

  1. TCP协议是安全的吗?

    2024-03-30 13:50:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-30 13:50:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-30 13:50:04       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-30 13:50:04       20 阅读

热门阅读

  1. 2024年道路运输安全员考试真题题库

    2024-03-30 13:50:04       13 阅读
  2. 怎么在循环List的时候删除List的元素

    2024-03-30 13:50:04       13 阅读
  3. 类模板分文件编写

    2024-03-30 13:50:04       20 阅读
  4. [C++提高编程](三):STL-string容器

    2024-03-30 13:50:04       20 阅读
  5. 高等代数复习:矩阵秩的基本公式

    2024-03-30 13:50:04       16 阅读
  6. Manticore Search 中文分词搜索入门

    2024-03-30 13:50:04       17 阅读