大核注意力 LKA | Visual Attention Network

在这里插入图片描述

论文名称:《Visual Attention Network》

论文地址:2202.09741 (arxiv.org)


尽管最初是为自然语言处理任务而设计的,但自注意力机制最近在各个计算机视觉领域迅速崭露头角。然而,图像的二维特性给计算机视觉中的自注意力应用带来了三大挑战。(1) 将图像视为一维序列忽视了它们的二维结构。(2) 二次复杂性对于高分辨率图像来说成本太高。(3) 它只能捕捉空间适应性而忽略了通道适应性。在本文中,我们提出了一种新的线性注意力机制,称为大核注意力 (LKA),以实现自适应和长距离关联,同时避免自注意力的缺点。此外,我们还基于 LKA 提出了一种神经网络,即视觉注意力网络 (VAN)。虽然极其简单,VAN 在各种任务中超越了类似尺寸的视觉转换器 (ViTs) 和卷积神经网络 (CNNs),包括图像分类、目标检测、语义分割、全景分割、姿势估计等。例如,VAN-B6ImageNet 基准测试中取得了 87.8% 的准确率,并在全景分割方面创造了新的最先进性能(58.2 PQ)。此外,VAN-B2ADE20K 基准上的语义分割上比 Swin-T 高出 4%mIoU50.146.1),在 COCO 数据集上的目标检测中高出 2.6%AP48.846.2)。这为社区提供了一种新方法和简单但强大的基准。


问题背景

自注意力机制虽然在自然语言处理(NLP)领域取得了巨大成功,但在应用于计算机视觉时存在几个主要挑战。首先,它将二维图像视为一维序列,忽略了图像的二维结构。其次,二次复杂度的计算在高分辨率图像处理时非常昂贵。此外,自注意力仅捕捉空间适应性,而忽略了通道适应性。为了克服这些缺点,LKA提出了一种新的线性注意力机制,旨在提供长距离关联,同时避免传统自注意力机制的不足。


核心概念

LKA的核心概念在于融合卷积和自注意力的优势,包括局部结构信息、长距离依赖和适应性。它通过将大核卷积分解为三个部分来实现这一点:一个深度卷积(Depthwise Convolution)、一个深度扩张卷积(Depthwise Dilation Convolution)、以及一个点卷积(Pointwise Convolution)。这种分解使得模型可以在保持较低计算复杂度的同时,捕获长距离关系和通道维度的适应性。


模块的操作步骤

在这里插入图片描述

不同模块的结构:(a) 提出的大核注意力 (LKA);(b) 非注意力模块;© 用加法替代 LKA 中的乘法;(d) 自注意力。值得注意的是,(d) 是为一维序列设计的。

在这里插入图片描述

大核卷积的分解示意图。标准卷积可以分解为三部分:深度卷积 (DW-Conv)、深度扩张卷积 (DW-D-Conv) 和逐点卷积 (1×1 Conv)。彩色网格表示卷积核的位置,黄色网格表示中心点。示意图显示,一个 13×13 卷积被分解为一个 5×5 的深度卷积、一个膨胀率为 3 5×5 深度扩张卷积和一个逐点卷积。注意:上图中省略了零填充。


LKA模块的操作步骤包括以下几个关键部分:

  1. 深度卷积:处理局部结构信息。
  2. 深度扩张卷积:捕获长距离依赖。
  3. 点卷积:进行通道间的交互。

这一过程通过三部分来增强自注意力的长距离关联,同时保持局部信息。通过融合不同类型的卷积,LKA为特征的自适应调整提供了灵活性。


文章贡献

本文的主要贡献在于提出了一种新的线性注意力机制LKA,并基于此构建了视觉骨干网络VANLKA吸收了卷积和自注意力的优点,并避免了它们的缺点。此外,VAN在多种视觉任务中表现出色,包括图像分类、目标检测、语义分割、全景分割、姿态估计等。


实验结果与应用

实验结果表明,VANImageNet-1KCOCO等多个基准测试中取得了优异的表现。例如,VAN-B6ImageNet基准测试中取得了87.8%的准确率,并在COCO的全景分割任务中创造了新的记录。它在多个视觉任务上超过了许多传统的CNN和自注意力模型。


对未来工作的启示

LKA的成功展示了新的线性注意力机制在计算机视觉中的潜力。未来的研究可以进一步探索这种机制在其他领域的应用,或将其与其他注意力机制相结合。此外,LKA的低计算复杂度特点使其在移动设备和资源受限环境中的应用前景广阔。


代码

import torch
from torch import nn


class LKA(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
        )
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        return u * attn


if __name__ == "__main__":
    input = torch.rand(3, 64, 64, 64)
    model = LKA(64)
    output = model(input)
    print(output.size())

相关推荐

  1. CUDA | 函数编写的注意事项

    2024-04-29 01:24:03       20 阅读
  2. mysql表ddl注意

    2024-04-29 01:24:03       6 阅读
  3. LSKNet:选择网络在遥感目标检测中的应用

    2024-04-29 01:24:03       41 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-29 01:24:03       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-29 01:24:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-29 01:24:03       20 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-29 01:24:03       20 阅读

热门阅读

  1. springboot实现同时批量新增和批量修改数据

    2024-04-29 01:24:03       16 阅读
  2. it运维管理平台:设备管理与网络监控方案

    2024-04-29 01:24:03       12 阅读
  3. kotlin语法快速入门-接口与接口实现(8)

    2024-04-29 01:24:03       12 阅读
  4. C++每日一练——只出现一次的数字

    2024-04-29 01:24:03       11 阅读
  5. 小程序中的生命周期函数

    2024-04-29 01:24:03       17 阅读
  6. 二次封装搜索组件

    2024-04-29 01:24:03       10 阅读
  7. Ollama+Open WebUI部署大模型在linux平台

    2024-04-29 01:24:03       11 阅读
  8. Vue 3 组合式API深度剖析:工具函数详解

    2024-04-29 01:24:03       18 阅读
  9. 06 华三防火墙的如何进入web页面?

    2024-04-29 01:24:03       12 阅读
  10. milvus datacoord启动源码分析

    2024-04-29 01:24:03       12 阅读