昇思25天学习打卡营第23天|ResNet50图像分类

课程打卡凭证

ResNet网络

ResNet(Residual Networks,残差网络)是一种深度神经网络结构,它的核心思想是引入了“残差学习”来解决深度网络中的退化问题。在深度神经网络中,当网络层数增加到一定程度时,网络的性能可能会出现退化,即准确率不再提升甚至下降,这并不是由于过拟合引起的,如下图所示,20层网络比56层网络的训练误差和测试误差更大。

ResNet通过引入残差连接,允许网络在训练过程中跳过某些层,直接将输入传递到输出,从而保留了原始特征,并使得网络在深度增加时仍然能够保持较好的性能,如下图所示。

数据集准备与加载

下载数据集。

导入必要的库和模块,设置参数的初始值。

根据不同的用途(训练或测试)加载 CIFAR-10 数据集,并设置并行线程和随机打乱数据。如果是训练集,添加随机裁剪和随机水平翻转。再对所有数据都进行统一的操作,包括图像大小调整、归一化、标准化和维度转换(从 HWC 转为 CHW)。最后将数据集按批量大小进行批量处理。

从训练数据集中获取前六张图片及其对应的标签,并将其可视化,结果如图所示。

构建网络

ResNet的网络结构主要由多个残差块(Residual Building Block)组成,每个残差块包含多个级联的卷积层和一个残差连接。残差连接是跳过一层或多层的连接,它将输入直接加到残差块的输出上,形成残差学习的基本单元。它主要由以下几个部分组成:

输入层:接收输入图像,并进行初步处理。

卷积层:使用较大的卷积核进行卷积操作,以提取图像的基本特征。

卷积组:每个卷积组包含多个残差块,每个残差块由多个级联的卷积层和一个残差连接组成。随着网络深度的增加,卷积核的数量和大小也会相应调整。

输出层:对卷积组的输出进行全局平均池化,并连接全连接层进行分类或其他任务。

导入必要的模块,初始化重要参数。

定义标准的残差块,它带有两个 3x3 的卷积层和一个可选的下采样层,用以构建ResNet网络。

通过ResidualBlock类实现了Bottleneck残差块,通过引入 1x1 卷积层,显著减少了计算量,同时保持了较高的特征提取能力。

该函数用于构建由多个残差块(Residual Block)堆叠而成的网络层,它通过在第一层中添加下采样层,可以在输入维度和输出维度不匹配时调整输入维度,从而确保残差连接的正确性。

ResNet类实现了一个典型的深度卷积神经网络结构,采用了残差块来缓解深层网络中的梯度消失问题。通过在每个残差块中引入shortcut connection,可以直接将输入信息传递到输出,从而更有效地训练深层神经网络。

_resnet函数可以初始化一个ResNet模型,它根据指定的残差块类型和层数构建模型,并且可以选择加载预训练模型参数。resnet50函数通过调用_resnet函数,使用特定的参数配置,构建并返回一个ResNet50模型实例。

模型训练与评估

定义ResNet50网络并加载预训练模型,获取全连接层输入层的大小并重新定义全连接层,再替换网络中的全连接层。

设置一个动态的学习率策略,使用动量优化器和交叉熵损失函数进行训练,定义前向计算和梯度计算函数,执行训练步骤,更新模型参数并返回损失值。

为模型训练和验证准备数据加载器,并设置保存最佳模型的路径。

定义训练函数,用于训练模型,每个批次计算损失并更新模型参数,同时打印训练进度,并返回每个epoch的平均损失值。评估函数用于验证模型,在验证集上计算准确率,返回模型的预测准确率。

开始训练模型,结果如下图所示。

可视化模型预测

结果如下图所示。

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-22 00:42:01       51 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-22 00:42:01       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-22 00:42:01       44 阅读
  4. Python语言-面向对象

    2024-07-22 00:42:01       55 阅读

热门阅读

  1. 数组指针跟指针数组的区别

    2024-07-22 00:42:01       16 阅读
  2. OpenWRT/iStoreOS 安装 qemu-guest-agent

    2024-07-22 00:42:01       15 阅读
  3. 计算机学院——秋招的总结

    2024-07-22 00:42:01       17 阅读
  4. go中map

    go中map

    2024-07-22 00:42:01      16 阅读
  5. 计算并输出杨辉三角形的前10行

    2024-07-22 00:42:01       20 阅读
  6. 线程局部变量共享 -- 使用ThreadLocal解决该需求

    2024-07-22 00:42:01       15 阅读
  7. 内联汇编清楚变量指定位

    2024-07-22 00:42:01       18 阅读
  8. 信竞2024年csp-j模拟赛第二场赛后总结

    2024-07-22 00:42:01       20 阅读
  9. 《C++并发编程实战》笔记(三)

    2024-07-22 00:42:01       19 阅读