【自学记录7】【Pytorch2.0深度学习从零开始学 王晓华】第七章 实战ResNet

7.1.4 ResNet网络的实现

遇到的问题:
1、nn.ReLU(inplace=True)
在PyTorch中,nn.ReLU(inplace=True)中的inplace=True参数表示该ReLU激活函数将直接修改输入张量,而不是创建一个新的输出张量。这意味着它会在原地(in-place)执行操作,不会占用额外的内存空间来存储输出。
具体地说,当你使用inplace=True时,输入张量x在经过ReLU激活函数后,其值会直接被ReLU的结果所替换。这样做可以节省内存,但需要注意,由于输入张量被修改了,因此在后续的计算中,如果你还需要原始张量的值,就可能会遇到问题。
2、ResNet的四个主要阶段
ResNet(残差神经网络)的四个主要阶段是指其网络结构中的四个部分,每个部分都包含多个残差块(Residual Block)。这些阶段的设计旨在逐步提取输入图像的特征,并随着网络的深入,逐渐加深和提升这些特征。以下是关于这四个阶段的详细解释:

第一阶段(Stage 1):

通常包含几个卷积层(例如3个)和池化层(例如2个)。
这一阶段的主要任务是对输入的图像进行初步的特征提取。卷积层通过卷积操作捕捉图像的局部特征,而池化层则用于减小特征图的空间尺寸,降低计算量,并增加特征的鲁棒性。
第二阶段(Stage 2):

包含更多的卷积层(例如4个)和池化层(例如2个)。
这一阶段进一步加深和提升了第一阶段提取的特征。通过增加卷积层的数量,网络能够学习到更复杂的特征表示。
第三阶段(Stage 3):

包含更多的卷积层(例如6个)和池化层(例如2个)。
在这一阶段,网络继续对特征进行加深和提升,进一步抽象和提取更高层次的特征。
第四阶段(Stage 4):
通常包含较少的卷积层(例如3个)和一个池化层。
这一阶段主要是对第三阶段提取的特征进行最终的加深和提升,为后续的分类或回归任务做准备。
除了这四个主要阶段外,ResNet还包含一个全连接层,用于将提取的特征映射到输出类别上。在每个阶段中,残差块的使用是关键。残差块通过引入恒等映射(identity mapping),使得网络在加深时能够更容易地优化,从而解决了深度神经网络训练中的梯度消失或爆炸问题。
总的来说,ResNet的四个主要阶段构成了一个深度卷积神经网络,通过逐步提取和加深特征,实现了对输入图像的高效处理。这种结构使得ResNet在图像识别、目标检测、人脸识别等领域取得了显著的成果。

源码\第七章\ resnet.py

import torch
import torch.nn as nn

#BasicBlock 是 ResNet 中使用的基本残差块类型,它是构成整个 ResNet 网络的基本组件。
#BasicBlock包含两部分:a、学习残差。b、捷径连接
class BasicBlock(nn.Module):

    expansion = 1
    #in_channels 和 out_channels 分别代表输入和输出通道数。stride 是卷积操作的步长。
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        #a、residual function 执行【卷积】操作以学习残差。
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),#批量正则化
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        #b、shortcut:捷径连接,用于将输入直接连接到输出,以进行残差学习。
        self.shortcut = nn.Sequential()#里面啥也没有
       
        
        #判定输出的维度是否和输入相一致
        #shortcut 是一个序列。如果输入和输出的维度不匹配(比如由于步长不为1或通道数不同),则shortcut会包含一个1x1的卷积来确保维度匹配。
        #BasicBlock.expansion 是一个类属性,它表示基础块(BasicBlock)的输出通道数与输入通道数之间的比例
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):#将residual_function(x) 残差和self.shortcut(x) 捷径连接 两者相加,并通过ReLU激活函数得到输出。
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))

class ResNet(nn.Module):
    #block 是用于构建网络的块类型(这里是BasicBlock)。
    #num_block 是一个列表,指定了每个阶段(conv2_x, conv3_x, conv4_x, conv5_x)中block的数量。
    #num_classes 是分类任务的类别数。
    def __init__(self, block, num_block, num_classes=100):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),#卷积层,输入通道数为3(对应于RGB图像),输出通道数为64,卷积核大小为3x3,使用padding=1以保持空间维度不变,不使用偏置(bias=False)。
            nn.BatchNorm2d(64),#批量归一化层,用于加速训练并增加模型稳定性。
            nn.ReLU(inplace=True)#原地(in-place)直接修改输入张量,不会占用额外的内存空间来存储输出。
        )

        #3、定义ResNet的四个主要阶段:
        #self.conv2_x、self.conv3_x、self.conv4_x、self.conv5_x:每个阶段都是由一系列的block(这里指的是BasicBlock)组成。
        #每个阶段的空间维度(宽度和高度)可能会通过stride来减小。

        #block:要使用的残差块类型(例如 BasicBlock 或 Bottleneck)。
        #64:该层中每个残差块的输出通道数。
        #num_block[0]:该层中要创建的残差块的数量。
        #1:第一个残差块中卷积层的步长。
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        
        #4、自适应平均池化层,用于将特征图转化为1*1的特征向量。
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        #5、全连接层,用于将提取的特征映射到输出类别上
        #全连接层,输入特征数为512(最后一个阶段的输出通道数)乘以block.expansion(对于BasicBlock,这个值为1),输出特征数为分类任务的类别数num_classes。
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    #用于创建一系列block(这里是BasicBlock)。它接收输出通道数、block的数量以及stride作为参数。
    def _make_layer(self, block, out_channels, num_blocks, stride):

        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)#x通过自适应平均池化层avg_pool,将空间维度减小到1x1。
        output = output.view(output.size(0), -1)#使用view方法将x展平为一维张量,以便它可以作为全连接层fc的输入。
        output = self.fc(output)#x通过全连接层fc,输出分类结果。

        return output

#18层的resnet
def resnet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

#34层的resnet
def resnet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

if __name__ == '__main__':
    #创建一个随机的图像张量 image,这表示有 5 张图像,每张图像有 3 个颜色通道(RGB),且每张图像的高度和宽度都是 32 像素。
    image = torch.randn(size=(5,3,32,32))
    #例化一个 18 层的 ResNet 模型 resnet(通过调用 ResNet(BasicBlock, [2, 2, 2, 2]))。
    resnet = ResNet(BasicBlock, [2, 2, 2, 2])

    img_out = resnet(image)
    print(img_out.shape)#torch.Size([5, 100])

7.2 实战ResNet

跟之前的训练代码一样,不贴了。

最近更新

  1. TCP协议是安全的吗?

    2024-04-08 06:14:01       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-08 06:14:01       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-08 06:14:01       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-08 06:14:01       20 阅读

热门阅读

  1. 机器视觉系统-什么是颜色/波长

    2024-04-08 06:14:01       14 阅读
  2. C#-实现软删除

    2024-04-08 06:14:01       15 阅读
  3. mybatis知识点

    2024-04-08 06:14:01       12 阅读
  4. 2022-04-24_数组的定义和初始化等_作业

    2024-04-08 06:14:01       12 阅读
  5. Kubernetes(K8s)运维实战:案例解析与代码实践

    2024-04-08 06:14:01       15 阅读
  6. 【flutter和android原生的异步】

    2024-04-08 06:14:01       14 阅读
  7. 绘图工具 draw.io / diagrams.net 免费在线图表编辑器

    2024-04-08 06:14:01       14 阅读
  8. Linux k8s相关命令

    2024-04-08 06:14:01       12 阅读
  9. 高通项目-TCP/UDP 心跳 Offload 如何实现省电

    2024-04-08 06:14:01       14 阅读
  10. npm 命令及其详细解释

    2024-04-08 06:14:01       14 阅读