DEiT中如何处理mask数据的?与MAE的不同

在DeiT里面,是通过mask的方式,将mask+unmasked的patches输出进ViT中,但其实在下游任务输入的patches还是和训练时patches的数量N是一致的(encoder所有的patches)。

而MAE是在encoder中只encoder未被mask的patches

通过什么方式支持的?

  • 在处理文本时,可以根据最长的句子在批次中动态padding或截断长句子
  • 而在处理图像(如使用ViT)时,可以将图像划分为大小相等的patches,数量可以根据图像的大小动态变化。

在训练阶段,部分patches被mask为0,但是处理的所有patches加起来的总长度还是一样的。被mask的位置在模型内部仍然占位,保持了输入序列的“框架”。这样,即使实际参与计算的只是部分元素,模型也能够适应在推理时使用全部元素的情况。

具体的计算步骤如下:

  1. 确定mask哪些patches
  2. 将mask的patches位置设置为0
  3. 这些被mask和未被mask的所有patches一起被输入进attention模块
  4. 将被mask的patches的注意力分数手动设置为“无穷大负数”(-inf)
  5. 这些被mask的patches的softmax值就会变为0,也就意味着这些patches并未参与注意力的计算
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedSelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(MaskedSelfAttention, self).__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # 计算自注意力得分
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))

        # 将mask值为0的位置在attention_scores中设置为一个非常大的负数
        attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # 使得这些位置的softmax结果接近0
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # 算最终的注意力加权和
        output = torch.matmul(attention_weights, V)

        return output


# 假设嵌入大小为512
embed_size = 512
# 创建一个mask,假设我们有4个patches,我们想要mask掉第2个和第4个patches
mask = torch.tensor([[1, 0, 1, 0]])
# 扩展mask维度以适应attention_scores的形状(假设批大小为1,序列长度为4),mask需要与attention_scores形状匹配,即(batch_size, 1, 1, seq_length)
mask = mask.unsqueeze(1).unsqueeze(2)

# 初始化模型和数据
sa = MaskedSelfAttention(embed_size)
x = torch.randn(1, 4, embed_size)  # 假设有一个批大小为1,序列长度为4的输入

output = sa(x, mask)
print(output)

相关推荐

  1. webscoket mask 细节

    2024-03-18 08:04:03       41 阅读
  2. nlptransformermask

    2024-03-18 08:04:03       63 阅读
  3. 华纳云:ApacheBeam延迟数据处理如何处理

    2024-03-18 08:04:03       41 阅读
  4. 数据分析-Pandas如何处理表格文本数据

    2024-03-18 08:04:03       61 阅读
  5. 什么是图像掩膜(Mask),如何使用掩码

    2024-03-18 08:04:03       68 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-18 08:04:03       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-18 08:04:03       106 阅读
  3. 在Django里面运行非项目文件

    2024-03-18 08:04:03       87 阅读
  4. Python语言-面向对象

    2024-03-18 08:04:03       96 阅读

热门阅读

  1. 云计算基础(一)

    2024-03-18 08:04:03       40 阅读
  2. python --阿里云(智能媒体管理/视频点播)

    2024-03-18 08:04:03       42 阅读
  3. p2p原理

    2024-03-18 08:04:03       43 阅读
  4. XSS基础知识

    2024-03-18 08:04:03       38 阅读
  5. 实验7-1-5 交换最小值和最大值(PTA)

    2024-03-18 08:04:03       38 阅读
  6. python--剑指offer--题目目录-学习计划

    2024-03-18 08:04:03       35 阅读
  7. mybatis mapper.xml获取insert后的自增ID

    2024-03-18 08:04:03       42 阅读
  8. 网络安全主题

    2024-03-18 08:04:03       40 阅读
  9. 利用适配器模式使用第三方库

    2024-03-18 08:04:03       40 阅读