【Transformer】single self-attention的Pytorch实现

在这里插入图片描述

import torch.nn as nn
import torch
import matplotlib.pyplot as plt


class Self_Attention(nn.Module):
    def __init__(self, dim, dk, dv):
        super(Self_Attention, self).__init__()
        self.scale = dk ** -0.5
        self.q = nn.Linear(dim, dk)
        self.k = nn.Linear(dim, dk)
        self.v = nn.Linear(dim, dv)


    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = attn @ v
        return x


att = Self_Attention(dim=2, dk=2, dv=3)
x = torch.rand((1, 4, 2))
output = att(x)



代码中首先创建了一个self-attention的类,然后在随机出输入x,(1,4,2) 1是batch_size 4是4个token,2是每个token的长度

然后把x传入对象,在self-attention的初始化函数中定义了 1 d k \frac{1}{\sqrt{d_k}} dk 1,并且从输入中提出q,k,v,其中q,k的维度是一定要保持一致的,

这里面 X的输入是(batchsize,num,dim_in)num是一维序列中token的个数,这里a1到a4就4个,dim_in是每个token的特征维数,这里每一个a都是1*2的向量,特征维度为2,dim_in就为2

对于Q、K、V的维度,W1 W2 W3分别是(dim_in,dq) (dim_in,dk) (dim_in,dv) 只不过dq肯定等于dk
这样X(batchsize,num,dim_in)才能分别与W1 、W2、W3相乘得到Q、K、V
Q(batchsize,num,dq) K(batchsize,num,dk) V(batchsize,num,dv)
这样 Q K T QK^T QKT (batchsize,num,dq)*(batchsize,dk,num) = (batchsize,num,num)

K T K^T KT是依靠k.transpose(-2,-1)实现的

self.scale = dk ** -0.5这一步是计算 1 d k \frac{1}{\sqrt{d_k}} dk 1

self.q = nn.Linear(dim, dk)
self.k = nn.Linear(dim, dk)
self.v = nn.Linear(dim, dv)

torch.nn.Linear的用法可以参考官方的文档

torch.nn.Linear

在这里插入图片描述
这里的三个操作实际上就构建了W1,W2,W3矩阵,这三个矩阵分别与X相乘就得到了Q K V

def forward(self, x):
    q = self.q(x)
    k = self.k(x)
    v = self.v(x)

这里就完成了Q K V的计算

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)

    x = attn @ v

这三步就分别完成了相似度分数的计算、相似度分数的归一化和最终计算

相关推荐

  1. PytorchAttention理解和代码实现

    2024-03-10 01:36:04       32 阅读
  2. bottom-up-attention.pytorch

    2024-03-10 01:36:04       33 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-10 01:36:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-10 01:36:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-10 01:36:04       82 阅读
  4. Python语言-面向对象

    2024-03-10 01:36:04       91 阅读

热门阅读

  1. 用 reduce 实现 map 的功能

    2024-03-10 01:36:04       48 阅读
  2. 【C#语言入门】13. 表达式、语句详解(3)

    2024-03-10 01:36:04       50 阅读
  3. 基于单片机的输液监测系统设计与实现

    2024-03-10 01:36:04       41 阅读
  4. 鸿蒙崛起:能否颠覆安卓霸主地位?

    2024-03-10 01:36:04       47 阅读
  5. mongodb的备份与恢复

    2024-03-10 01:36:04       43 阅读
  6. python中的模块和包

    2024-03-10 01:36:04       50 阅读
  7. el-aside中添加el-menu设置collapse宽度自适应

    2024-03-10 01:36:04       42 阅读
  8. 2021年CCCC天梯赛

    2024-03-10 01:36:04       40 阅读