Transformer详解:从放弃到入门(完结)

  前几篇文章中,我们已经拆开并讲解了Transformer中的各个组件。现在我们尝试使用这些方法实现Transformer的编码器。

相关文章:
Transformer详解:从放弃到入门(一)
Transformer详解:从放弃到入门(二)
Transformer详解:从放弃到入门(三)

在这里插入图片描述  如图所示,编码器(Encoder)由N个编码器块(Encoder Block)堆叠而成,我们依次实现。

class EncoderBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float,
        norm_first: bool = False,
    ) -> None:
        super().__init__()

        self.norm_first = norm_first

        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = LayerNorm(d_model)

        self.ff = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    # self attention sub layer
    def _sa_sub_layer(
        self, x: Tensor, attn_mask: Tensor, keep_attentions: bool
    ) -> Tensor:
        x = self.attention(x, x, x, attn_mask, keep_attentions)
        return self.dropout1(x)

    def _ff_sub_layer(self, x: Tensor) -> Tensor:
        x = self.ff(x)
        return self.dropout2(x)

    def forward(
        self, src: Tensor, src_mask: Tensor = None, keep_attentions: bool = False
    ) -> Tuple[Tensor, Tensor]:
        # pass througth multi-head attention
        # src (batch_size, seq_length, d_model)
        # attn_score (batch_size, n_heads, seq_length, k_length)
        x = src
        if self.norm_first:
            x = x + self._sa_sub_layer(self.norm1(x), src_mask, keep_attentions)
            x = x + self._ff_sub_layer(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_sub_layer(x, src_mask, keep_attentions))
            x = self.norm2(x + self._ff_sub_layer(x))

        return x

  需要注意的是,层归一化的位置通过参数norm_first控制,默认norm_first=False,这种实现方式称为Post-LN,是Transformer的默认做法。但这种方式很难从零开始训练,把层归一化放到残差块之间,接近输出层的参数的梯度往往较大。然后在那些梯度上使用较大的学习率会使得训练不稳定。通常需要用到学习率预热(warm-up)技巧,在训练开始时学习率需要设成一个极小的值,但是一旦训练好之后的效果要优于Pre-LN的方式。而如果采用norm_first=True的方式,被称为Pre-LN,它的区别在于对于子层(*_sub_layer)的输入先进行层归一化,再输入到子层中。最后进行残差连接。
在这里插入图片描述  即实际上由上图左变成了图右,注意最后在每个Encoder或Decoder的输出上再接了一个层归一化。
  有了编码器块,我们再来实现编码器。

class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layers: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        norm_first: bool = False,
    ) -> None:
        super().__init__()
        # stack n_layers encoder blocks
        self.layers = nn.ModuleList(
            [
                EncoderBlock(d_model, n_heads, d_ff, dropout, norm_first)
                for _ in range(n_layers)
            ]
        )

        self.norm = LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self, src: Tensor, src_mask: Tensor = None, keep_attentions: bool = False
    ) -> Tensor:
        x = src
        # pass through each layer
        for layer in self.layers:
            x = layer(x, src_mask, keep_attentions)

        return self.norm(x)

  这里要注意的是,最后对编码器和输出进行一次层归一化。至此,我们的编码器完成了,在其forward()中src是词嵌入加上位置编码,那么src_mask是什么?它是用来指示非填充标记的。我们知道,对于文本序列批数据,一个批次内序列长短不一,因此需要以一个指定的最长序列进行填充,而我们的注意力不需要在这些填充标记上进行。
  创建src_mask很简单,假设输入是填充后的批数据:

def make_src_mask(src: Tensor, pad_idx: int = 0) -> Tensor:
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask

  输出维度变成(batch_size, 1, 1, seq_length)为了与缩放点积注意力分数适配维度。

相关推荐

  1. cka入门放弃

    2024-05-09 23:14:06       55 阅读
  2. Django入门放弃

    2024-05-09 23:14:06       52 阅读
  3. Docker入门放弃

    2024-05-09 23:14:06       37 阅读
  4. 入门放弃之「ClickHouse」

    2024-05-09 23:14:06       68 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-05-09 23:14:06       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-05-09 23:14:06       101 阅读
  3. 在Django里面运行非项目文件

    2024-05-09 23:14:06       82 阅读
  4. Python语言-面向对象

    2024-05-09 23:14:06       91 阅读

热门阅读

  1. Pytorch基础:torch.expand() 和 torch.repeat()

    2024-05-09 23:14:06       37 阅读
  2. C数据结构:链表高级篇 约瑟夫环(杀人游戏)

    2024-05-09 23:14:06       25 阅读
  3. 第21天 反射

    2024-05-09 23:14:06       35 阅读
  4. 学习笔记:IEEE 1003.13-2003【POSIX PSE51接口列表】

    2024-05-09 23:14:06       29 阅读
  5. 数据结构(三)算法

    2024-05-09 23:14:06       31 阅读
  6. 为什么 IP 地址通常以 192.168 开头?

    2024-05-09 23:14:06       34 阅读
  7. vue3引入vant完整步骤

    2024-05-09 23:14:06       34 阅读
  8. Mybatis Plus ActiveRecord 模式

    2024-05-09 23:14:06       27 阅读