【PyTorch][chapter 26][李宏毅深度学习][attention-1]

前言:

          attention 在自然语言处理,声音处理里面是一个很重要的技巧.

   attention 要解决的是输入的向量长度不定.

  

   根据输入输出的不同,分为三种场景:

        输入N个向量,输出N个向量,这是本章的重点

        输入N个向量,输出向量不定

        输入N个向量, 输出M个向量


目录:

  1.     相关方案
  2.    self-attention
  3.    code实现


一   相关方案

    1.1  全连接网络

      输入: N个向量

      模型: N个全连接网络,每个FC模型对应一个向量

       输出: N个向量

       缺点:

       是当前向量无法获得其他向量的信息

    

   1.2 问题

       输入: N个向量

       模型: N个全连接网络,每个全连接网络,输入N个向量.

       输出:  N个向量

       缺点:

                    向量的个数定义Windows窗口.如果窗口特别大,计算量特别大。

如果windows 窗口特别小,无法采集到整个Input sequence Labeling 

     

    需要开的窗口特别大


二 self-attention

     3.1   模型架构

      

      输入   N个向量

       输出:  N 个向量

      模型: Self-attention

  3.2  主要流程

    

   1.1 计算相关系数\alpha

        两个向量的相似度有很多表达方式,例如余弦

   attention 是通过self-attention 来计算,比如要计算a^1,a^2之间的相似度

   

   q^1=a^1W_Q

  k^2=a^wW_k

  \alpha_{1,2}=q^1 \odot k^2

其中:W_Q,W_k 是代表query,key 矩阵通过训练出来的

Query:查询向量,表示要关注或检索的目标

Key:    键向量,表示要与查询向量进行匹配或比较的源

      还有种Additive 结构

2.2   通过相关系数 \alpha,计算attention-score

            同理依次算出来跟其它向量之间的相似度

       

      对相似度矩阵,通过softmax 归一化后,得到attention-score.

     

    attention-score,本质上是代表权重系数

2.3  根据attention-score , 重新计算向量

      Value:值向量,表示要根据查询向量和键向量的匹配程度来加权求和的信息

      v^1=a^1W_v 

     通过attention-score 加权求和得到b^1


三  代码

  • Query:查询向量,表示要关注或检索的目标 W_{Q}
  • Key:键向量,表示要与查询向量进行匹配或比较的源W_K
  • Value:值向量,表示要根据查询向量和键向量的匹配程度来加权求和的信息W_V
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  9 21:15:05 2024

@author: cxf
"""

# -*- coding: utf-8 -*-
"""
Created on Thu Jul  4 10:37:27 2024

@author: chengxf2
"""

import torch
import torch.nn.functional as F
import torch.nn as nn


class Attention(nn.Module):
    
    
    def __init__(self, in_features,query_features,out_features):
        
        super(Attention, self).__init__()
        
        self.QUERY = nn.Linear(in_features,  query_features)
        self.KEY  =  nn.Linear(in_features,  query_features)
        self.VALUE = nn.Linear(in_features,  out_features)
        
    
    
    def forward(self,inputs):


        Q = self.QUERY(inputs)
        K = self.KEY(inputs)     
        V = self.VALUE(inputs)
        
        #计算attention
        d_k= Q.shape[-1]
        
        alpha = torch.matmul(Q, K.T)/d_k**0.5
        attention_score =F.softmax(alpha,dim=1)
        print("\n attention_score:",attention_score)
        
     
        out = torch.matmul(attention_score, V)
        
        row_index =1
        row_sum = torch.sum(attention_score[row_index,:])
        print("\n row_sum ",row_sum)
    
        return out

seq_len =5
in_features = 7
query_features =4
out_features = 3

X = torch.randn((seq_len, in_features))
net =Attention(in_features, query_features, out_features)

out = net(X)


        
        

参考:

Transformer终于有拿得出手得教程了! 台大李宏毅自注意力机制和Transformer详解!通俗易懂,草履虫都学的会!_哔哩哔哩_bilibili

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-09 22:52:09       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-09 22:52:09       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-09 22:52:09       57 阅读
  4. Python语言-面向对象

    2024-07-09 22:52:09       68 阅读

热门阅读

  1. svn常用命令

    2024-07-09 22:52:09       24 阅读
  2. 面向对象——继承、封装、多态

    2024-07-09 22:52:09       20 阅读
  3. CoppeliaSim的简单教程

    2024-07-09 22:52:09       22 阅读
  4. cadence许可管理策略

    2024-07-09 22:52:09       19 阅读
  5. Android动态设置系统音量最大值

    2024-07-09 22:52:09       26 阅读
  6. Android Enable 和clickable

    2024-07-09 22:52:09       22 阅读
  7. 0. python面试常见问题

    2024-07-09 22:52:09       23 阅读
  8. 配置linux的yum镜像为阿里镜像源

    2024-07-09 22:52:09       19 阅读
  9. Docker一键部署PostGIS

    2024-07-09 22:52:09       20 阅读
  10. C语言编程2:常用的数据类型

    2024-07-09 22:52:09       22 阅读
  11. 秒验 iOS端授权页添加自定义按钮

    2024-07-09 22:52:09       21 阅读