CBAM Explained with Code (PyTorch)

CBAM, short for Convolutional Block Attention Module, is an attention mechanism that strengthens the feature representations of convolutional neural networks (CNNs). It consists of two parts, a channel attention module and a spatial attention module, which attend to the channel information and the spatial information of the input feature map respectively, increasing the model's focus on the features that matter.

In the channel attention module, CBAM captures inter-channel dependencies with global average pooling and global max pooling, producing two channel descriptors. Both descriptors pass through a shared MLP (fully connected layers with a ReLU in between), their outputs are summed, and a sigmoid yields the channel attention weights. Multiplying these weights with the original feature map recalibrates the features along the channel dimension.
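In the notation of the CBAM paper, where $F$ is the input feature map, $\sigma$ the sigmoid, and MLP the shared two-layer network, the channel attention map is

$$M_c(F) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(F)) + \mathrm{MLP}(\mathrm{MaxPool}(F))\big)$$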

The spatial attention module instead attends to spatial positions in the feature map. It first applies average pooling and max pooling along the channel dimension, producing two spatial descriptors. These are concatenated and fused by a convolutional layer, and a sigmoid yields the spatial attention weights. Multiplying these weights with the feature map recalibrates the features across spatial positions.
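Again in the paper's notation, with $f^{7\times 7}$ a convolution with a $7\times 7$ kernel and $[\cdot\,;\cdot]$ channel-wise concatenation, the spatial attention map is

$$M_s(F) = \sigma\big(f^{7\times 7}([\mathrm{AvgPool}(F)\,;\mathrm{MaxPool}(F)])\big)$$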

The CBAM module can be dropped into existing CNN architectures such as ResNet and VGG with little effort; by strengthening the model's attention, it improves performance on tasks such as image classification and object detection. CBAM is also fairly interpretable: the attention maps help reveal what the model focuses on when making decisions, making it a useful tool for visualizing and explaining deep models.

1. Channel Attention

From the same (channel) view, two different pooling results are taken, then the channels are squeezed and re-expanded, and finally a sigmoid produces the weights. The pattern is very close to SE, but rather than copying SE outright, CBAM uses two parallel branches (average pooling and max pooling).

import torch
import torch.nn as nn


class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = channel // reduction
        # Adaptive pooling shrinks the map while keeping the channel count;
        # the argument (1) sets both output height and width to 1.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP: squeeze channels to channel // reduction, then expand back.
        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=channel, out_features=mid_channel),
            nn.ReLU(),
            nn.Linear(in_features=mid_channel, out_features=channel)
        )
        self.sigmoid = nn.Sigmoid()
        # self.act=SiLU()

    def forward(self, x):
        # Flatten (B, C, 1, 1) to (B, C), run the shared MLP, restore (B, C, 1, 1).
        avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(avgout + maxout)
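A quick shape check (a minimal sketch; the tensor sizes here are arbitrary) confirms that the module returns one weight per channel, shape (B, C, 1, 1), which broadcasts over the spatial dimensions when multiplied with the input:

x = torch.randn(2, 64, 32, 32)           # (B, C, H, W)
ca = ChannelAttentionModule(channel=64)  # reduction defaults to 16
w = ca(x)                                # per-channel weights in (0, 1)
print(w.shape)        # torch.Size([2, 64, 1, 1])
print((w * x).shape)  # torch.Size([2, 64, 32, 32]), channel-recalibrated features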

2. Spatial Attention

The previous module squeezed and rescaled along the channel axis; here the same idea is applied across the spatial map, again using both pooling strategies.

# Spatial attention module
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        # A 7x7 conv fuses the two pooled maps into one; padding=3 keeps H and W unchanged.
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        # self.act=SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Keep the map size, reduce the channels: pool along the channel dimension.
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out
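The same kind of check (again a minimal sketch with arbitrary sizes) shows one weight per spatial location, shape (B, 1, H, W):

x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
sa = SpatialAttentionModule()
w = sa(x)                       # per-location weights in (0, 1)
print(w.shape)        # torch.Size([2, 1, 32, 32])
print((w * x).shape)  # torch.Size([2, 64, 32, 32]), spatially recalibrated features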

The paper's contribution is to build attention from two complementary views, which is a sound starting point. How well it actually performs in practice, for example on small datasets, is something you will have to evaluate for yourself.

3. CBAM_ResNet (PyTorch)

# ------------------------#
# PyTorch implementation of the CBAM module
# ------------------------#

# Channel attention module
import torch.nn as nn
import torch
class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = channel // reduction
        # Adaptive pooling shrinks the map to 1x1 while keeping the channel count.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # (1) sets output height and width to 1.
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=channel, out_features=mid_channel),
            nn.ReLU(),
            nn.Linear(in_features=mid_channel, out_features=channel)
        )
        self.sigmoid = nn.Sigmoid()
        # self.act=SiLU()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(avgout + maxout)


# Spatial attention module
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        # self.act=SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Keep the map size, reduce the channels: pool along the channel dimension.
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out


# CBAM module
class CBAM(nn.Module):
    # reduction is accepted here because the blocks below call CBAM(planes, reduction);
    # the original one-argument constructor would raise a TypeError on that call.
    def __init__(self, channel, reduction=16):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel, reduction)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out


# -----------------------------------------------#
# CBAM_ResNet: CBAM built into a torchvision ResNet
# (assumes the three modules above are saved as CBAM.py)
# -----------------------------------------------#
from CBAM import CBAM
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet


def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding=1, so the spatial size only changes via stride
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class CBAMBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # The bare * ends the positional parameters; everything after it is keyword-only.
        # groups, base_width, dilation and norm_layer are accepted only for
        # compatibility with torchvision's ResNet and are unused here.
        super(CBAMBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.cbam = CBAM(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.cbam(out)  # apply CBAM to the residual branch before the addition

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class CBAMBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # The bare * ends the positional parameters; everything after it is keyword-only.
        # groups, base_width, dilation and norm_layer are accepted only for
        # compatibility with torchvision's ResNet and are unused here.
        super(CBAMBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.cbam = CBAM(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.cbam(out)  # apply CBAM to the residual branch before the addition

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


def cbam_resnet18(num_classes=1_000):
    model = ResNet(CBAMBasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def cbam_resnet34(num_classes=1_000):
    model = ResNet(CBAMBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def cbam_resnet50(num_classes=1_000, pretrained=False):
    model = ResNet(CBAMBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    if pretrained:
        state_dict = load_state_dict_from_url(
            "https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl")
        # Note: this checkpoint is SE-ResNet-50 from senet.pytorch, not a CBAM model;
        # its keys do not fully match, so strict=False loads only the overlapping weights.
        model.load_state_dict(state_dict, strict=False)
    return model


def cbam_resnet101(num_classes=1_000):
    model = ResNet(CBAMBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def cbam_resnet152(num_classes=1_000):
    model = ResNet(CBAMBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


if __name__ == "__main__":
    inputs = torch.randn(2, 3, 224, 224)
    model = cbam_resnet50(pretrained=False)
    # outputs = model(inputs)
    # print(outputs.size())
    print(model)
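As a final smoke test (a minimal sketch; the toy stage below is hypothetical and only meant to show that CBAM preserves tensor shapes), the module can be dropped after any convolutional stage:

import torch
import torch.nn as nn
from CBAM import CBAM

# Hypothetical toy stage: a conv layer followed by CBAM; shapes are unchanged.
stage = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    CBAM(64),  # channel + spatial attention over the 64-channel feature map
)
x = torch.randn(2, 3, 56, 56)
print(stage(x).shape)  # expected: torch.Size([2, 64, 56, 56])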
