【DeepLearning-6】实现倒置残差块(Inverted Residual Block)

倒置残差块(Inverted Residual Block),是MobileNetV2网络中提升效率的关键结构。

类定义和构造函数

class IRBlock(nn.Module): 
    def __init__(self, inp, oup, stride=1, expansion=4):
  • IRBlock 类继承自 nn.Module,是一个神经网络模块。
  • __init__ 方法是类的构造函数,用于初始化实例。
  • inp: 输入通道数。
  • oup: 输出通道数。
  • stride: 卷积的步长,决定了输出特征图的大小。
  • expansion: 扩展因子,用于控制内部隐藏层通道数的扩展。

内部变量和断言

self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
  • 存储步长信息。
  • 断言确保步长是1或2,这是典型的设计选择,用于控制特征图的下采样。
  • 计算隐藏层的通道数,通过输入通道数乘以扩展因子。
  • 判断是否使用残差连接,当步长为1且输入输出通道数相等时使用。

在编程中,断言(assert)是一种检查代码是否满足某些条件的方式。如果条件不成立,程序会抛出异常。在 IRBlock 类中,使用断言有几个目的:

assert stride in [1, 2]

这行代码确保传递给类的 stride 参数只能是1或2。步长(stride)在卷积神经网络中是一个关键参数,它决定了卷积层如何在输入数据上移动。步长为1意味着卷积核每次移动一个像素,步长为2则意味着卷积核每次移动两个像素。这对于输出特征图的尺寸有直接影响。

在这个特定的 MV2Block 中,只设计了处理步长为1或2的情况,因为这在大多数应用场景中是最常见和最有效的。步长为1通常用于保持特征图的尺寸,而步长为2用于减半特征图的宽度和高度,实现下采样。

通过使用断言,如果有人试图用不支持的步长值(比如3或更大的数)初始化 MV2Block,程序会立即抛出错误,这有助于及早发现问题,避免更深层次的错误和混乱。

卷积块的构建

根据扩展因子的值构建不同的卷积块结构:

根据扩展因子的值构建不同的卷积块结构是MobileNetV2架构中的一个关键设计决策,目的在于优化网络的性能和效率。让我们来详细解释这个设计选择的原因:

  1. 扩展因子(Expansion Factor): 扩展因子是MobileNetV2中引入的一个概念,用于控制倒置残差块(Inverted Residual Block)内部的中间扩展层的大小。具体来说,它决定了第一层逐点卷积(Pointwise Convolution)增加的通道数。

  2. 扩展层的作用:

    • 扩展因子 > 1: 当扩展因子大于1时,卷积块的第一部分是一个扩展层,它通过逐点卷积增加通道数。这样做的目的是在后续的深度可分离卷积(Depthwise Convolution)中提供更多的特征表示空间。这对于捕捉更复杂的特征是有益的,尤其是在网络的深层部分。
    • 扩展因子 = 1: 如果扩展因子等于1,就意味着不需要扩展层。这通常用于网络的输入层或者当输入和输出维度相同时。在这种情况下,直接使用深度可分离卷积就足够了,因为没有必要进一步增加通道数。
  3. 效率与性能的权衡:

    • 减少计算成本: 通过调整扩展因子,可以有效控制网络的计算复杂度。深度可分离卷积本身就是一种减少计算成本的设计,通过在其中加入适当的扩展层,可以进一步平衡网络的性能和计算成本。
    • 适应不同层的需求: 网络不同层对特征的需求不同。在某些层,可能需要更丰富的特征表示来捕捉复杂的模式,而在其他层,简单的特征提取可能就足够了。通过调整扩展因子,MobileNetV2可以灵活地适应这些不同层的需求。
if expansion == 1:
    # 构建没有扩展层的卷积块
else:
    # 构建包含扩展层的卷积块

代码实现:

        if expansion == 1: # 构建没有扩展层的卷积块
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:  # 构建包含扩展层的卷积块
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
  • 没有扩展层(expansion == 1: 这种情况下,块由深度可分离卷积(Depthwise Convolution)和逐点卷积(Pointwise Convolution)组成。
  • 包含扩展层(expansion != 1: 这种情况下,块先通过一个逐点卷积扩展特征维度,然后进行深度可分离卷积,最后再通过逐点卷积压缩维度到输出通道数。

每个卷积操作后都跟有批量归一化(Batch Normalization)和SiLU(也称为Swish)激活函数。

前向传播函数 forward

def forward(self, x):
    if self.use_res_connect:
        return x + self.conv(x)
    else:
        return self.conv(x)
  • forward 方法定义了数据通过块的方式。
  • self.use_res_connectTrue 时,将执行残差连接,即将块的输入加到块的输出上。否则,只返回卷积块的输出。

完整代码:

class IRBlock(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup
        if expansion == 1: # 构建没有扩展层的卷积块
            self.conv = nn.Sequential(
                # 深度可分离卷积(Depthwise Convolution)
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # “线性”逐点卷积 (Pointwise-Linear Convolution)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:  # 构建包含扩展层的卷积块
            self.conv = nn.Sequential(
                # 逐点卷积 (Pointwise Convolution)
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # 深度可分离卷积 (Depthwise Convolution)
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # “线性”逐点卷积 (Pointwise-Linear Convolution)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

 

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-26 09:46:03       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-26 09:46:03       106 阅读
  3. 在Django里面运行非项目文件

    2024-01-26 09:46:03       87 阅读
  4. Python语言-面向对象

    2024-01-26 09:46:03       96 阅读

热门阅读

  1. day31_CSS

    day31_CSS

    2024-01-26 09:46:03      43 阅读
  2. 面试 Vue 框架八股文十问十答第十一期

    2024-01-26 09:46:03       55 阅读
  3. 考研机试 阶乘的和

    2024-01-26 09:46:03       53 阅读
  4. 五种单例模式

    2024-01-26 09:46:03       47 阅读
  5. Midjourney 生成图片教程

    2024-01-26 09:46:03       90 阅读
  6. C++(1) 命名空间

    2024-01-26 09:46:03       52 阅读
  7. 牛刀小试 - C++ 推箱子小游戏

    2024-01-26 09:46:03       58 阅读
  8. Android - 持久化方案

    2024-01-26 09:46:03       45 阅读
  9. MOJO中导入python模块

    2024-01-26 09:46:03       50 阅读
  10. 人工智能相关的政策文件都有哪些?--九五小庞

    2024-01-26 09:46:03       54 阅读