昇思学习打卡-14-ResNet50迁移学习

  • 迁移学习:在一个很大的数据集上训练得到一个预训练模型,然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务中。
  • 本章学习使用的是前面学过的ResNet50,使用迁移学习的方法对ImageNet数据集中的狼和狗图像进行分类。

数据集可视化

在这里插入图片描述

预训练模型的使用

  • 搭建好模型框架后,通过将pretrained参数设置为True来下载ResNet50的预训练模型,并将权重参数加载到网络中。
  • 使用固定特征进行训练的时候,需要冻结除最后一层之外的所有网络层。通过设置 requires_grad == False 冻结参数,以便不在反向传播中计算梯度。

部分实现

import matplotlib.pyplot as plt
import os
import time
# 修改参数1pretrained=True
net_work = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool

# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
    # 修改参数2
        param.requires_grad = False

# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss

# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

推理

在这里插入图片描述
此章节学习到此结束,感谢昇思平台。

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-14 11:02:04       50 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-14 11:02:04       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-14 11:02:04       43 阅读
  4. Python语言-面向对象

    2024-07-14 11:02:04       54 阅读

热门阅读

  1. 中文科技核心论文

    2024-07-14 11:02:04       21 阅读
  2. 解决npm install 安装报错记录贴

    2024-07-14 11:02:04       23 阅读
  3. 山洪灾害研究

    2024-07-14 11:02:04       21 阅读
  4. 小白C语言基础详解:函数

    2024-07-14 11:02:04       20 阅读
  5. 【2024最新】C++扫描线算法介绍+实战例题

    2024-07-14 11:02:04       19 阅读
  6. 基于MacOS系统Sonoma 14.5的SSH服务禁止密码登录

    2024-07-14 11:02:04       22 阅读
  7. 【Druid 未授权访问漏洞】解决办法

    2024-07-14 11:02:04       22 阅读
  8. 电子版pdf格式标书怎么加盖公章?

    2024-07-14 11:02:04       24 阅读
  9. not enough information C#

    2024-07-14 11:02:04       20 阅读