self_attention Python code

self_attention interview code
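The code implements single-head scaled dot-product self-attention. The input X is projected into queries, keys and values, and each output row is a softmax-weighted average of the value rows. In the notation here, W_Q, W_K and W_V stand for the three bias-free linear layers (nn.Linear stores its weight transposed, so this is the math rather than the exact tensor layout), d_k is the query/key dimension, and the softmax is taken row-wise over the key axis (dim=-1):

$$
Q = XW_Q,\quad K = XW_K,\quad V = XW_V,\qquad
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$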

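from math import sqrt
import torch
import torch.nn as nn

class SA(nn.Module):
    """Single-head scaled dot-product self-attention."""

    def __init__(self, dim_in, dim_k, dim_v):
        super().__init__()

        self.dim_in = dim_in  # feature size of each input token
        self.dim_k = dim_k    # query/key size (Q and K must match for Q @ K^T)
        self.dim_v = dim_v    # value size = output feature size

        # Q, K and V are all projected from the same input x ("self"-attention)
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)

        # scale scores by 1/sqrt(d_k), where d_k is the query/key dimension
        self.scale = 1 / sqrt(dim_k)

    def forward(self, x):
        # x: (batch, n, dim_in)
        batch, n, dim = x.shape
        assert dim == self.dim_in

        Q = self.linear_q(x)  # (batch, n, dim_k)
        K = self.linear_k(x)  # (batch, n, dim_k)
        V = self.linear_v(x)  # (batch, n, dim_v)

        # attention scores for every query/key pair: (batch, n, n)
        dist = torch.bmm(Q, K.transpose(1, 2)) * self.scale
        # normalize over the key axis so each query's weights sum to 1
        W = torch.softmax(dist, dim=-1)

        # weighted sum of value vectors: (batch, n, dim_v)
        output = torch.bmm(W, V)
        return output

if __name__ == "__main__":
    x = torch.tensor([[[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]],
                      [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]], dtype=torch.float)
    print(x.shape)  # torch.Size([2, 4, 3])

    sa_model = SA(3, 3, 3)
    output = sa_model(x)  # shape (2, 4, 3); values depend on the random init
    print(output)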
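To see what the two torch.bmm calls and the softmax compute without any framework, here is a minimal pure-Python sketch of the same formula for a single (unbatched) sequence. It assumes Q, K and V have already been projected; the helper names (softmax, matmul, transpose, scaled_dot_product_attention) and the toy numbers are hypothetical and are not part of PyTorch.

```python
from math import exp, sqrt


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Multiply an (n x k) list-of-lists by a (k x m) list-of-lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    """Swap rows and columns of a list-of-lists."""
    return [list(col) for col in zip(*a)]


def scaled_dot_product_attention(Q, K, V):
    """Return softmax(Q K^T / sqrt(d_k)) V and the weight matrix, for one sequence."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))            # (n, n), like torch.bmm(Q, K^T)
    scores = [[s / sqrt(d_k) for s in row] for row in scores]
    weights = [softmax(row) for row in scores]  # row-wise softmax, like dim=-1
    return matmul(weights, V), weights          # (n, d_v), like torch.bmm(W, V)


if __name__ == "__main__":
    # Toy 3-token sequence with d_k = d_v = 2 (made-up numbers)
    Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    out, W = scaled_dot_product_attention(Q, K, V)
    for row in W:
        print([round(w, 3) for w in row])
    print(out)
```

Each row of W holds one query's weights over all tokens and sums to 1. If every key is identical, the weights become uniform and each output row is simply the mean of the value rows, which is a quick sanity check for any attention implementation.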

