Pytorch:Attention理解和代码实现


知乎:Attention原理理解
CSDN:参考代码

一、Attention原理核心点

1、Self-Attention

a.核心原始形态

A t t e n t i o n ( X , X , X ) = s o f t m a x ( X X T ) X Attention(X,X,X)=softmax(XX^T)X Attention(X,X,X)=softmax(XXT)X

  • A = X X T A=XX^T A=XXT,它表示的矩阵是 每一个行向量 与 所有行向量的内积值。 A ( i , j ) A(i,j) A(i,j)表示向量 i i i与向量 j j j的内积值。实际上这里就是求的相关度,相关度在这里本质是由向量的内积度量的。
  • s o f t m a x ( X X T ) softmax(XX^T) softmax(XXT):归一化,使得每一行的内积值 之和为1。
  • s o f t m a x ( X X T ) X softmax(XX^T)X softmax(XXT)X:根据归一化后的相关度,进行加权求和,新的行向量。新的行向量都是对所有向量每一维分别加权求和后的结果。

b.self-Attention

自注意力机制指的是输入都是同一个矩阵 X X X,并不代表 Q , K , V Q,K,V Q,K,V是相同的。

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

  • Q , K , V Q,K,V Q,K,V其来源是 X X X 与矩阵的乘积,本质上都是 X X X 的线性变换
    为什么不直接使用 X X X 而要对其进行线性变换?
    当然是为了提升模型的拟合能力,矩阵 W W W都是可以训练的,起到一个缓冲的效果。
  • 在自注意力和多头注意力机制中,使用 d k \sqrt{d_k} dk 作为缩放因子进行缩放操作是为了防止在计算点积时由于维度较高导致的数值稳定性问题。这里的 d k d_k dk 是键向量的维度。解释如下:
    • 缩放的必要性

      • 如果不进行缩放,当 d k d_k dk 较大时,点积的结果可能会变得非常大,这会导致在应用 s o f t m a x softmax softmax 函数时产生的梯度非常小。因为 s o f t m a x softmax softmax 函数是通过指数函数计算的,大的输入值会使得部分输出接近于1,而其他接近于0,从而导致梯度消失,这会在反向传播过程中造成梯度非常小,使得学习变得非常缓慢。
    • 使用 d k \sqrt{d_k} dk 的效果

      • 通过点积结果除以 d k \sqrt{d_k} dk ,我们可以调整这些值的范围,使得它们不会太大。这样, s o f t m a x softmax softmax 的输入在一个合适的范围内,有助于避免极端的指数运算结果,从而保持数值稳定性和更有效的梯度流。这个操作确保了即使在 d k d_k dk 很大的情况下,注意力机制也能稳定并有效地学习。

在注意力机制中,尤其是在自注意力和Transformer模型中,输入通常是一个统一的输入矩阵 X X X,而这个矩阵后续会通过乘以不同的权重矩阵来转换成三个不同的向量集合:查询向量 Q Q Q、键向量 K K K和值向量 V V V。这三组向量是通过如下方式生成的:

  1. 查询向量 (Q) Q = X W Q Q = XW^Q Q=XWQ
  2. 键向量 (K) K = X W K K = XW^K K=XWK
  3. 值向量 (V) V = X W V V = XW^V V=XWV

这里的 W Q W^Q WQ, W K W^K WK, 和 W V W^V WV 是可学习的权重矩阵,分别对应于查询、键和值。这些矩阵的维度取决于模型的设计,通常它们的输出维度(列数)是预先定义的,以满足特定的模型架构要求。

在Transformer模型中,使用不同的权重矩阵 W Q W^Q WQ, W K W^K WK, 和 W V W^V WV来分别生成查询向量 Q Q Q、键向量 K K K 和值向量 V V V 的目的是为了允许模型在不同的表示空间中学习和抽取特征。这样做增加了模型的灵活性和表达能力,允许模型分别优化用于匹配(Q 和 K)和用于输出信息合成(V)的表示。

使用相同的矩阵是否可行?

理论上,可以使用同一个权重矩阵来生成 Q Q Q K K K V V V,但这样做会限制模型的能力。如果使用相同的矩阵,那么 Q 、 K 、 V Q、K、V QKV三者将无法在不同的表示空间中进行优化,从而可能导致模型无法充分捕捉到输入数据中的复杂关系。实际上,这种设计减少了模型的灵活性,可能导致性能下降。

2、常见的注意力机制

注意力机制并不全是自注意力机制,因此输入并不全是相同的。

注意力机制多种多样,各有特点,适用于不同的应用和任务。下面列举几种常见的注意力机制,并说明它们的输入及相应的权重矩阵情况:

1. 自注意力机制(Self-Attention)

  • 输入:序列或一组特征向量 X X X
  • 权重矩阵
    • W Q W^Q WQ:生成查询向量 Q Q Q
    • W K W^K WK:生成键向量 K K K
    • W V W^V WV:生成值向量 V V V
  • 应用:常见于Transformer模型,用于处理文本、图像等序列数据。

2. 多头注意力(Multi-Head Attention)

  • 输入:与自注意力相同,一组特征向量 X X X
  • 权重矩阵:逻辑上:每个“头”分别有其 W i Q W^Q_i WiQ, W i K W^K_i WiK, W i V W^V_i WiV,多头并行处理。
  • 应用:增强模型的表示能力,用于复杂任务如机器翻译、文本摘要等。在Transformer中,多头输出的不同 V i V_i Vi,通过拼接操作拼在一起,随后通过一个线性层,得到最终输出。

3. 序列到序列的注意力(Seq2Seq Attention)

  • 输入:一个编码器输出的序列 H H H(来自输入序列 X X X的变换)和一个解码器的当前状态。
  • 权重矩阵
    • W Q W^Q WQ:通常从解码器状态生成 Q Q Q
    • W K W^K WK W V W^V WV:从编码器输出 H H H 生成 K K K V V V
  • 应用:机器翻译、语音识别等任务。

4. 点积注意力(Dot-Product Attention)

  • 输入:查询 Q Q Q,键 K K K 和值 V V V(可以是相同的输入或不同的输入)。
  • 权重矩阵:可能不使用显式的权重矩阵,直接使用输入。
  • 应用:简化版的自注意力,用于需要较快运算的场景。

5. 加性注意力(Additive Attention)或串联注意力(Concat Attention)

  • 输入:通常是两个不同的输入序列,例如编码器输出和解码器状态。
  • 权重矩阵:使用一个额外的参数 W W W 来计算未标准化的注意力分数,通常通过一个小型的前馈网络实现。
  • 应用:在处理不同长度输入时比点积注意力更灵活。

6. 卷积注意力(Convolutional Attention)

  • 输入:图像或视频数据。
  • 权重矩阵:通常与卷积层结合,使用卷积核作为权重参数。
  • 应用:视觉任务中,如图像分类、目标检测。

7. 跨模态注意力(Cross-Modal Attention)

  • 输入:来自不同模态的数据,如文本和图像。
  • 权重矩阵
    • W Q W^Q WQ:从一种模态生成 Q Q Q
    • W K W^K WK W V W^V WV :从另一种模态生成 K K K V V V
  • 应用:图像字幕生成、视频问答等。

二、手撕Self-Attenion代码

理解了上述 X X X Q , K , V Q,K,V Q,K,V的转换,实际上就很容易实现自注意力机制了,转换实际上是一个线性变换,因此只需要使用三个Linear层分别表示三个权重矩阵即可。Linear层使用解释。
根据原理,我们可以得出以下结论:

  • W Q W^Q WQ W K W^K WK的形状必然是 [ d m o d e l , d k ] [d_{model},d_k] [dmodel,dk](其中 d m o d e l d_{model} dmodel是输入向量的长度, d k d_k dk并不固定)
  • Q = X W Q Q=XW^Q Q=XWQ的形状是 [ s e q _ l e n g t h , d k ] [seq\_length,d_k] [seq_length,dk]
  • Q K T QK^T QKT的形状必然是 [ b a t c h _ s i z e , s e q _ l e n g t h , s e q _ l e n g t h ] [batch\_size,seq\_length,seq\_length] [batch_size,seq_length,seq_length]
  • W V W^V WV的形状必然是 [ s e q _ l e n g t h , o u t p u t _ s i z e ] [seq\_length,output\_size] [seq_length,output_size]

理论上,由于 Q Q Q K K K求的是不同向量的相关度(点积),因此 W Q , W K W^Q,W^K WQ,WK的形状必须相同。而 W V W^V WV是用来定义输出的形状的,可以不同。但是我们这里将输出的形状定义为与输入 X X X形状相同,因此线性层输入输出相同。

以下代码是Self-Attention示例,可以发现它和公式高度匹配。

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self,seq_length):
        super(SelfAttention,self).__init__()
        self.input_size = seq_length
		# 定义三个权重矩阵
        self.Wq=nn.Linear(seq_length,seq_length)#不改变形状的线性变换
        self.Wk=nn.Linear(seq_length,seq_length)
        self.Wv=nn.Linear(seq_length,seq_length)
        
    def forward(self,input):
    	# 计算Q,K,V 三个矩阵
        q = self.Wq(input)
        k = self.Wk(input)
        v = self.Wv(input)
        
        # 计算QK^T,即向量之间的相关度 ; 这里可以理解dk了:torch.tensor(float(self.input_size)),是Wk的维度。
        attention_scores = torch.matmul(q, k.transpose(-1,-2))/torch.sqrt(torch.tensor(float(self.input_size)))
        # 计算向量权重,softmax归一化
        attention_weight = F.softmax(attention_scores, dim=-1)
        # 计算输出
        output = torch.matmul(attention_weight, v)
        return output
    
x = torch.randn(2,2,3)
Self_Attention = SelfAttention(3) # 这里的3表示输入向量的维度。
output = Self_Attention(x)
print(output.shape) #[2,2,3]

通过以上代码可以发现,自注意力机制在这里体现为,输入一个x,得到一个output,这里输入和输出是相同形状的。换句话说,自注意力机制通过将 输入向量与整个输入向量加权求和,得到的每个输出向量是包含所有加权后输入向量信息的向量。

相关推荐

  1. Pytorch:Attention理解代码实现

    2024-04-28 03:36:01       33 阅读
  2. 理解实现 LRU 缓存置换算法

    2024-04-28 03:36:01       28 阅读
  3. 数据结构算法:链表构造相关代码理解

    2024-04-28 03:36:01       44 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-28 03:36:01       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-28 03:36:01       106 阅读
  3. 在Django里面运行非项目文件

    2024-04-28 03:36:01       87 阅读
  4. Python语言-面向对象

    2024-04-28 03:36:01       96 阅读

热门阅读

  1. Linux 内核深入理解 - 绪论

    2024-04-28 03:36:01       34 阅读
  2. day04--react中批量传递props

    2024-04-28 03:36:01       37 阅读
  3. 随手记:vue2 filters this指向undefined

    2024-04-28 03:36:01       30 阅读
  4. Qt——代码崩溃 free() invalid pointer

    2024-04-28 03:36:01       41 阅读
  5. Nacos

    Nacos

    2024-04-28 03:36:01      35 阅读
  6. ruoyi-cloud-plus的bom

    2024-04-28 03:36:01       31 阅读
  7. 【软考】面向对象设计

    2024-04-28 03:36:01       37 阅读
  8. 创建PAM配置文件

    2024-04-28 03:36:01       27 阅读
  9. C语言例题27:打印99乘法口诀表

    2024-04-28 03:36:01       34 阅读