迁移学习修改最后把一层类别数

  参考以下代码

def create_model(aux, num_classes, pretrain=True):
    model = deeplabv3_resnet50(aux=aux, num_classes=num_classes)

    if pretrain:
        weights_dict = torch.load("./deeplabv3_resnet50_coco.pth", map_location='cpu')

        if num_classes != 21:
            # 官方提供的预训练权重是21类(包括背景)
            # 如果训练自己的数据集,将和类别相关的权重删除,防止权重shape不一致报错
            for k in list(weights_dict.keys()):
                if "classifier.4" in k:
                    del weights_dict[k]

        missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
        if len(missing_keys) != 0 or len(unexpected_keys) != 0:
            print("missing_keys: ", missing_keys)
            print("unexpected_keys: ", unexpected_keys)

    return model

参考视频07:33

DeepLabV3源码讲解(Pytorch)_哔哩哔哩_bilibili

相关推荐

  1. 迁移学习修改最后类别

    2024-04-29 09:20:04       34 阅读
  2. 迁移学习最新进展和挑战

    2024-04-29 09:20:04       47 阅读
  3. 迁移强化学习论文笔记()(Successor Features)

    2024-04-29 09:20:04       185 阅读
  4. Pytorch学习-调整torchvision.models中模型输出类别

    2024-04-29 09:20:04       38 阅读
  5. 组的学习

    2024-04-29 09:20:04       57 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-29 09:20:04       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-29 09:20:04       106 阅读
  3. 在Django里面运行非项目文件

    2024-04-29 09:20:04       87 阅读
  4. Python语言-面向对象

    2024-04-29 09:20:04       96 阅读

热门阅读

  1. 智能家居如何融合人工智能技术

    2024-04-29 09:20:04       35 阅读
  2. Spring Cloud Gateway直接管理Vue.js的静态资源

    2024-04-29 09:20:04       32 阅读
  3. js之探索浏览器对象模型

    2024-04-29 09:20:04       27 阅读
  4. django运行配置

    2024-04-29 09:20:04       31 阅读
  5. centos常用命令

    2024-04-29 09:20:04       30 阅读
  6. 【面经】4月22日 腾讯云智/手图服务/一面/1h

    2024-04-29 09:20:04       28 阅读
  7. 自动化密码填充:使用Python提高日常工作效率

    2024-04-29 09:20:04       41 阅读