Transformer实战-系列教程18:DETR 源码解读5(BackboneBase类/Backbone类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

7、BackboneBase类

位置:models/backbone.py/BackboneBase类

7.1 构造函数

class BackboneBase(nn.Module):
    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {
   "layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {
   'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels
  1. 定义一个继承nn.Module的类
  2. 构造函数,传入4个参数:
    • backbone:一个nn.Module对象,代表用于特征提取的骨架网络
    • train_backbone:是否训练backbone
    • num_channels:backbone通道数
    • return_interm_layers:是否返回backbone的中间层输出
  3. 初始化
  4. 遍历backbone的所有参数,named_parameters()方法返回网络中所有参数的迭代器,包括参数的名称和值
  5. 如果train_backbone设置为False,且不训练layer2layer3layer4,也就是说如果train_backbone为False,backbone的所有层的所有参数都不需要训练,即所有层都被冻住
  6. 不需要训练的参数的requires_grad属性设置为False
  7. 根据return_interm_layers的值
  8. 选择性地设置return_layers字典
  9. 一个层对应一个值
  10. 这个字典定义了哪些层的输出将被返回
  11. 创建IntermediateLayerGetter实例,它封装了backbone,根据return_layers字典决定返回哪些层的输出,IntermediateLayerGetter来自torchvision
  12. num_channels

7.2 前向传播

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {
   }
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out
  1. 前向传播函数,接收NestedTensor对象作为输入
  2. xs ,获取指定层的输出
  3. out,初始化一个字典,存储每个返回层的输出及其对应的新掩码
  4. 遍历xsitems
  5. 获取mask
  6. 确认mask存在
  7. 计算新的掩码
  8. 将输出和新掩码封装为NestedTensor对象
  9. 返回out字典

8、Backbone类

8.1 Backbone类

位置:models/backbone.py/Backbone类

class Backbone(BackboneBase):
    def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
  1. 定义一个继承BackboneBase的类
  2. 初始化方法,接受四个参数:
    • name:字符串,指定要使用的ResNet模型的名称(如resnet50resnet101等)
    • train_backbone:布尔值,指示是否训练backbone
    • return_interm_layers:布尔值,指示是否返回backbone的中间层输出
    • dilation:布尔值,指示在网络的最后几层是否应用空洞卷积(dilation)以增加感受野
  3. 通过torchvision.models动态获取指定名称的ResNet模型
  4. replace_stride_with_dilation,最后一个stage应用空洞卷积
  5. pretrained,根据is_main_process()的返回值决定是否加载预训练权重,norm_layer设置为FrozenBatchNorm2d,在backbone中使用冻结的批归一化
  6. 根据ResNet模型的不同,设置不同的输出通道数
  7. 调用基类BackboneBase的初始化方法,传递创建的backbone实例和其他参数

这个Backbone类通过提供对ResNet模型的封装,允许用户灵活地选择不同的配置,例如是否训练Backbone、是否返回中间层输出以及是否在网络后段应用空洞卷积。同时,通过使用冻结的批量归一化层,可以在不调整BN层参数的情况下,利用预训练的模型进行特征提取

8.2 build_backbone()函数

位置:models/backbone.py/build_backbone()函数

本项目的backbone,主要是调用resnet,用来提取图像特征,进而构建图像序列做Transformer的输入,backbone的构建主要通过这个函数来实现:

def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model

这段代码定义了一个名为build_backbone的函数,用于根据提供的参数构建一个含有位置编码的骨架网络模型。以下是对这段代码的逐行解释:

  1. 函数build_backbone,接收命令行参数
  2. position_embedding ,调用build_position_encoding,函数构建位置编码
  3. 通过lr_backbone(backbone的学习率)是否大于0来决定是否训练backbone
  4. args.masks指示是否需要骨架网络返回中间层的输出
  5. 通过Backbone类构建backbone
  6. 通过Joiner类传入backbone和位置编码,建立backbone模型

DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)

最近更新

  1. TCP协议是安全的吗?

    2024-02-14 19:58:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-14 19:58:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-14 19:58:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-14 19:58:02       18 阅读

热门阅读

  1. LeetCode面试题54. 二叉搜索树的第k大节点

    2024-02-14 19:58:02       35 阅读
  2. mysql全国省市县三级联动创表sql(一)

    2024-02-14 19:58:02       39 阅读
  3. 机器视觉技术:提升安全与效率的关键

    2024-02-14 19:58:02       40 阅读
  4. Python爬虫:安全与会话管理

    2024-02-14 19:58:02       42 阅读
  5. Oracle数据库

    2024-02-14 19:58:02       30 阅读
  6. 深入解析MySQL 8:事务数据字典的变革

    2024-02-14 19:58:02       30 阅读