2.ResNet——使用resnet101实例化一个具有101层的卷积神经网络

  我们将使用resnet101来实例化一个具有101层的卷积神经网络。下面所有的代码和数据我都会进行上传,并且也会整个可以运行的代码,前期是前面的环境配置好了。(相关文件代码正在上传中...)

 

一、加载预训练模型

书上所用指示函数下载resnet101在ImageNet数据集上训练好的权重的代码为,但是运行之后会报下面的错误。

resnet = models.resnet101(pretrained=True)

  报错原因为这些警告表明 torchvision 中使用 pretrained 参数来加载预训练模型的方式已经过时,并将在未来版本中移除。现在应该使用 weights 参数来代替 pretrained 参数。 所有解决方法为导入相关库函数并修改原来的代码。

from torchvision import models, transforms
from torchvision.models import ResNet101_Weights

# 加载预训练模型
resnet = models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

  其中我们可以打印一下resnet的信息,可以看到这个resnet的网络结构非常的庞大,下面只展示部分

二、预处理 

1.预处理输入图像

  要想可以正常调用resnet,在此之前需要对输入图像进行预处理,使其大小正确,使其值(颜色)大致处于相同的数值范围。在这里使用torchvision模块提供的转换功能进行转换。

from torchvision import models, transforms

# 定义预处理步骤
preprocess = transforms.Compose([
    transforms.Resize(256),   # 调整短边为256像素
    transforms.CenterCrop(224),   # 从中心裁剪224x224图像
    transforms.ToTensor(),   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # 归一化
])

 2.进行图像的导入

  现在我们需要选择一张我们喜欢的图片,对其进行预处理,然后查看ResNet对它的识别结果。我们可以使用一个Python的图像操作模块Pillow从本地加载一张图片。

from PIL import Image

# 打开和预处理图像
image = Image.open("./pic/bobby.jpg")
image_t = preprocess(image)

 3.转化为张量并进行处理

  然后我们可以按照网络期望的方式对输入的张量(这里不懂没关系,后面对说,想象为线性代数里面的n维矩阵就行)进行重塑、裁剪和归一化处理。在这里就要开始使用Pytorch模块了。

import torch

# 批处理维度
batch_t = torch.unsqueeze(image_t, dim=0)

三、推理模型  

  在深度学习中,在新数据上运行训练过的模型的过程被称为推理(inference)。为了进行推理,我们需要将网络置于eval模式。

# 评估模式
resnet.eval()

  现在eval设置好了,我们准备进行推理(也就是逆推,向前传播)

# 前向传播
out = resnet(batch_t)

四、读取标签并得到结果  

  通过上面的步骤可以得出1000个分数向量,也就相当于1000个标签,我们现在需要找出得分高的类的标签,也就是符合预期的标签。要查看预测标签的列表,我们需要加载一个文本文件,按照训练中呈现给网络的顺序列出标签。

   让我们为ImageNet数据集类加载一个包含1000个标签的文件。相关的文件到时会引用在文章中。

# 读取标签
with open('./imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

  此时,我们需要确定不同分高对应的索引。可以使用pytorch的max()函数来做到这一点,它可以输出一个张量中的最大值以及最大值所在的索引。

_, index = torch.max(out, 1)

  现在我们可以使用索引来访问标签。在这里并非普通的python索引,而是张量索引,如torch([20])。因此我们需要使用index[0]获得实际的数字作为标签列表的索引,还可以使用 torch.nn.functional.softmax()将输出归一化带[0,1]之间然后除以总和,就可以得到大致的预测置信度。

# 获取预测结果
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(labels[index[0]], percentage[index[0]].item())

    最后我们获取前五个预测结果来看一下。

# 获取前5个预测结果
_, indices = torch.sort(out, descending=True)
result = [(labels[idx], percentage[idx].item()) for idx in indices[0][:5]]
print(result)

五、整体代码  

整体代码如下:

import torch
from torchvision import models, transforms
from torchvision.models import ResNet101_Weights
from PIL import Image

# 加载预训练模型
resnet = models.resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)

# 定义预处理步骤
preprocess = transforms.Compose([
    transforms.Resize(256),   # 调整短边为256像素
    transforms.CenterCrop(224),   # 从中心裁剪224x224图像
    transforms.ToTensor(),   # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # 归一化
])

# 打开和预处理图像
image = Image.open("./pic/bobby.jpg")
image_t = preprocess(image)

# 批处理维度
batch_t = torch.unsqueeze(image_t, dim=0)

# 评估模式
resnet.eval()

# 前向传播
out = resnet(batch_t)

# 读取标签
with open('./imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

# 获取预测结果
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(labels[index[0]], percentage[index[0]].item())

# 获取前5个预测结果
_, indices = torch.sort(out, descending=True)
result = [(labels[idx], percentage[idx].item()) for idx in indices[0][:5]]
print(result)


相关推荐

  1. 利用pytorch实现形式ResNet

    2024-06-11 20:34:01       58 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-11 20:34:01       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-11 20:34:01       106 阅读
  3. 在Django里面运行非项目文件

    2024-06-11 20:34:01       87 阅读
  4. Python语言-面向对象

    2024-06-11 20:34:01       96 阅读

热门阅读

  1. 02. fastLed 基本用法

    2024-06-11 20:34:01       22 阅读
  2. angular2网页前端执行流程

    2024-06-11 20:34:01       31 阅读
  3. 制作手机IOS苹果ipa应用的重签名工具

    2024-06-11 20:34:01       30 阅读
  4. golang生成根证书,服务端证书,用于 tls

    2024-06-11 20:34:01       31 阅读