A demo implementation of Self-Attention and Multi-Head Attention

# encoding: utf-8

from math import sqrt

import torch
import torch.nn as nn


class Self_Attention(nn.Module):
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        # one projection each for query, key and value
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.norm_fact = 1 / sqrt(dim_k)  # scaling factor 1/sqrt(d_k)

    def forward(self, x):
        # x: (batch, seq_len, input_dim)
        print("x.shape:", x.shape)
        Q = self.q(x)  # (batch, seq_len, dim_k)
        print("Q.shape:", Q.shape)
        K = self.k(x)  # (batch, seq_len, dim_k)
        print("K.shape:", K.shape)
        V = self.v(x)  # (batch, seq_len, dim_v)
        print("V.shape:", V.shape)
        # scale the scores *before* the softmax: softmax(QK^T / sqrt(d_k))
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)) * self.norm_fact)
        output = torch.bmm(atten, V)  # (batch, seq_len, dim_v)
        return output

   

print("\n")

print("self attention:")

x = torch.randn(4,3,1024)

# print(x)

print("input size:", x.size())

self_attention = Self_Attention(1024,128,5)

res = self_attention(x)

# print("\n")

# print(res)

print("output size:", res.size())

print("\n")


 

class Self_Attention_Muti_Head(nn.Module):
    def __init__(self, input_dim, dim_k, dim_v, nums_head):
        super(Self_Attention_Muti_Head, self).__init__()
        assert dim_k % nums_head == 0
        assert dim_v % nums_head == 0
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.nums_head = nums_head
        self.dim_k = dim_k
        self.dim_v = dim_v
        # scale by the per-head key dimension, 1/sqrt(d_k / h)
        self._norm_fact = 1 / sqrt(dim_k // nums_head)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        # project, split the last dim into heads, and move the head dim forward:
        # (batch, seq_len, dim) -> (batch, nums_head, seq_len, dim_per_head)
        Q = self.q(x).reshape(batch, seq_len, self.nums_head, self.dim_k // self.nums_head).permute(0, 2, 1, 3)
        K = self.k(x).reshape(batch, seq_len, self.nums_head, self.dim_k // self.nums_head).permute(0, 2, 1, 3)
        V = self.v(x).reshape(batch, seq_len, self.nums_head, self.dim_v // self.nums_head).permute(0, 2, 1, 3)
        print("x.shape:", x.shape)
        print("Q.shape:", Q.size())
        # scaled dot-product attention, computed independently per head
        atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0, 1, 3, 2)) * self._norm_fact)
        # concatenate the heads back: (batch, seq_len, dim_v)
        output = torch.matmul(atten, V).permute(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return output

   

print("\n")

print("multi head attention:")

x = torch.randn(4,3,1024)

# print(x)

print(x.size())

self_attention = Self_Attention_Muti_Head(1024,128,6,2)

res = self_attention(x)

print("\n")

# print(res)

print(res.size())
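For the common case where queries, keys and values share a single embedding dimension, PyTorch also ships nn.MultiheadAttention. A minimal usage sketch follows (my own addition, assuming PyTorch >= 1.9 for batch_first=True; note the demo's separate dim_k/dim_v cannot be mapped onto it directly):

# built-in multi-head self-attention with embed_dim = 1024 and 8 heads
mha = nn.MultiheadAttention(embed_dim=1024, num_heads=8, batch_first=True)
out, attn_weights = mha(x, x, x)   # self-attention: query = key = value = x
print(out.size())                  # torch.Size([4, 3, 1024])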

-----------------------------------------------------------------

A question:

According to the paper https://arxiv.org/pdf/1911.02150.pdf, it feels like the Multi-Head Attention described here means the same thing as Grouped-Query Attention,

i.e. the same as the "Grouped-query" variant in the classic comparison figure below:

What am I misunderstanding?
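For reference, here is a minimal sketch (my own, not from the paper; grouped_query_attention, num_q_heads and num_kv_heads are hypothetical names) showing that the variants differ only in how many key/value heads the query heads share: standard multi-head attention gives every query head its own K/V head, multi-query attention (the one-write-head scheme of 1911.02150) shares a single K/V head across all query heads, and grouped-query attention sits in between:

def grouped_query_attention(Q, K, V, num_q_heads, num_kv_heads):
    # Q: (batch, seq, num_q_heads * head_dim); K, V: (batch, seq, num_kv_heads * head_dim)
    assert num_q_heads % num_kv_heads == 0
    b, s, _ = Q.shape
    head_dim = Q.shape[-1] // num_q_heads
    Q = Q.reshape(b, s, num_q_heads, head_dim).transpose(1, 2)    # (b, hq, s, d)
    K = K.reshape(b, s, num_kv_heads, head_dim).transpose(1, 2)   # (b, hkv, s, d)
    V = V.reshape(b, s, num_kv_heads, head_dim).transpose(1, 2)
    # each group of num_q_heads // num_kv_heads query heads shares one K/V head
    K = K.repeat_interleave(num_q_heads // num_kv_heads, dim=1)   # (b, hq, s, d)
    V = V.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
    atten = torch.softmax(Q @ K.transpose(-2, -1) / sqrt(head_dim), dim=-1)
    return (atten @ V).transpose(1, 2).reshape(b, s, -1)

# num_kv_heads == num_q_heads    -> multi-head attention (MHA)
# num_kv_heads == 1              -> multi-query attention (MQA)
# 1 < num_kv_heads < num_q_heads -> grouped-query attention (GQA)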
