The Attention Mechanism Explained

1. Introduction

The attention mechanism is widely used in fields such as natural language processing (NLP) and computer vision (CV). Its core idea is to assign different weights to different parts of the input so that the model can focus on the most relevant information. This article explains how attention works, covering the Self-Attention and Cross-Attention mechanisms, their formulas, and code implementations, and discusses practical applications.

2. Principles of the Attention Mechanism

2.1 Basic Concepts

The basic idea of attention is to compute an importance score (the attention weight) for each element of the input sequence and then take a weighted sum of the values using these weights to produce the output. The formula is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where:

  • $Q$ (Query): the query matrix
  • $K$ (Key): the key matrix
  • $V$ (Value): the value matrix
  • $d_k$: the dimension of the keys
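
To make the formula concrete, here is a minimal sketch of scaled dot-product attention in PyTorch (the function name and tensor shapes below are illustrative choices, not part of the original formulation):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q: (batch, query_len, d_k), K: (batch, key_len, d_k), V: (batch, key_len, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)  # (batch, query_len, key_len)
    weights = F.softmax(scores, dim=-1)              # each row sums to 1 over the keys
    return weights @ V                               # (batch, query_len, d_v)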

2.2 Self-Attention

Self-Attention is a special form of attention in which the queries, keys, and values all come from the same input sequence. It is computed in the following steps:

  1. Project the input sequence into query, key, and value matrices.
  2. Compute the dot product of the queries and keys, and scale the result.
  3. Apply the softmax function to obtain the attention weights.
  4. Use these weights to compute a weighted sum of the values, producing the output.

Formula

Let the input sequence be $X$; then:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

$$\text{Self-Attention}(X) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $W_Q$, $W_K$, and $W_V$ are learnable weight matrices.
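
As a minimal sketch of these formulas (reusing the hypothetical scaled_dot_product_attention helper from Section 2.1; the dimensions are illustrative):

import torch

d_model, d_k = 256, 64                       # illustrative dimensions
W_Q = torch.nn.Linear(d_model, d_k, bias=False)
W_K = torch.nn.Linear(d_model, d_k, bias=False)
W_V = torch.nn.Linear(d_model, d_k, bias=False)

X = torch.randn(32, 10, d_model)             # (batch, seq_len, d_model)
Q, K, V = W_Q(X), W_K(X), W_V(X)             # all three projections share the same input X
out = scaled_dot_product_attention(Q, K, V)  # (32, 10, 64)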

2.3 Cross-Attention

Cross-Attention is similar to Self-Attention, except that the queries come from one input sequence while the keys and values come from another. It is typically used to combine information from different sources.

Formula

Let the query sequence be $X$ and the key/value sequence be $Y$; then:

$$Q = XW_Q, \quad K = YW_K, \quad V = YW_V$$

$$\text{Cross-Attention}(X, Y) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
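
The only change relative to self-attention is where the keys and values come from. A sketch reusing the projections and helper from the previous sections (shapes are illustrative):

X = torch.randn(32, 10, d_model)             # query sequence, e.g. text tokens
Y = torch.randn(32, 20, d_model)             # key/value sequence, e.g. image features
Q, K, V = W_Q(X), W_K(Y), W_V(Y)             # queries from X, keys and values from Y
out = scaled_dot_product_attention(Q, K, V)  # (32, 10, 64): one output per query position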

3. Code Implementation

3.1 Self-Attention Implementation

import torch
import torch.nn.functional as F

class SelfAttention(torch.nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head projections: each head_dim slice is mapped back to head_dim
        # (mapping to embed_size here would break the reshape after attention)
        self.values = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = torch.nn.Linear(heads * self.head_dim, self.embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k) = sqrt(head_dim), matching the formula in Section 2
        attention = torch.nn.functional.softmax(energy / (self.head_dim ** 0.5), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

3.2 Cross-Attention Implementation

import torch
import torch.nn.functional as F

class CrossAttention(torch.nn.Module):
    def __init__(self, embed_size, heads):
        super(CrossAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head projections: each head_dim slice is mapped back to head_dim
        self.values = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = torch.nn.Linear(heads * self.head_dim, self.embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k) = sqrt(head_dim), matching the formula in Section 2
        attention = torch.nn.functional.softmax(energy / (self.head_dim ** 0.5), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

3.3 Usage Example

Yes, you read that correctly: these two code blocks are identical, so the code can be reused as-is. The only difference between self- and cross-attention is which sequences $Q$, $K$, and $V$ come from.

# Suppose we have the following input tensors
embed_size = 256
heads = 8
seq_length = 10
N = 32  # batch size

values = torch.randn(N, seq_length, embed_size)
keys = torch.randn(N, seq_length, embed_size)
queries = torch.randn(N, seq_length, embed_size)
mask = None  # no mask is used in this example

# SelfAttention example
self_attention = SelfAttention(embed_size, heads)
self_attention_output = self_attention(values, keys, queries, mask)
print("SelfAttention Output Shape:", self_attention_output.shape)

# CrossAttention example
cross_attention = CrossAttention(embed_size, heads)
cross_attention_output = cross_attention(values, keys, queries, mask)
print("CrossAttention Output Shape:", cross_attention_output.shape)

4. Applications

4.1 Natural Language Processing

In NLP, attention is used in a wide range of tasks such as machine translation, text generation, and question answering. For example, the Transformer model uses multi-head self-attention to perform sequence-to-sequence transduction efficiently, which greatly improved translation quality.

4.2 Computer Vision

In CV, attention is used for tasks such as image recognition, object detection, and image generation. Self-attention helps the model focus on the important regions of an image, improving recognition accuracy.

4.3 Multimodal Tasks

In multimodal tasks, Cross-Attention is used to combine data from different modalities, for example image-text matching and video captioning.

5. Conclusion

By dynamically reweighting the input, the attention mechanism lets a model focus on the most relevant information. Self-Attention and Cross-Attention play key roles in single-sequence and multimodal settings, respectively. As research continues, attention is finding applications across an ever-wider range of domains.

I hope this article helps you understand the principles and implementation of attention and apply it in practice. If you have any questions or suggestions, feel free to leave a comment.

6. References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
  2. Lin, Z., Feng, M., Santos, C. N. D., Yu, M., Xiang, B., Zhou, B., & Bengio, Y. (2017). A structured self-attentive sentence embedding. arXiv preprint arXiv:1703.03130.

7. Appendix

7.1 Multi-Head Self-Attention

Multi-head self-attention (Multi-Head Self-Attention) runs several self-attention operations in parallel over different subspaces, allowing the model to capture different kinds of information. Its formula is:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O$$

where:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

7.2 Example Implementation

class MultiHeadSelfAttention(torch.nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head projections: each head_dim slice is mapped back to head_dim
        self.values = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = torch.nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = torch.nn.Linear(heads * self.head_dim, self.embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k) = sqrt(head_dim)
        attention = torch.nn.functional.softmax(energy / (self.head_dim ** 0.5), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

That's all. Enjoy your learning!
