【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(模型预测)

介绍

这篇文章是《【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(如何使用自定义数据集)》的最后一部分内容:模型预测。

在本文中,我们将介绍如何测试模型的预测效果——让已训练好模型对一张新的图片进行分类;最后将整个流程打包,写成一个可以被直接调用的函数。

整个预测流程包括:

  • 图片下载
  • 图像转张量、图像数据变换
  • 使用训练好的模型进行预测
  • 预测结果输出

通过这些步骤,读者将能够进一步了解如何对已经训练好模型进行测试,以及了解模型是如何完成对图像的分类工作。


其他相关文章:


图像处理和预测分步骤详解

1. 图片下载&路径设置

import requests

# 设置文件路径
custom_image_path = data_path / "04-pizza-dad.jpeg"

# 文件下载
if not custom_image_path.is_file():
    with open(custom_image_path, "wb") as f:
        request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/04-pizza-dad.jpeg")
        print(f"Downloading {custom_image_path}...")
        f.write(request.content)
else:
    print(f"{custom_image_path} already exists, skipping download.")

2. 将图像转换成张量

import torchvision

# 将图像转换成张量(未指定格式)
custom_image_uint8 = torchvision.io.read_image(str(custom_image_path))

# 打印结果
print(f"Custom image tensor:\n{custom_image_uint8}\n")
print(f"Custom image shape: {custom_image_uint8.shape}\n")
print(f"Custom image dtype: {custom_image_uint8.dtype}")

![[04.1 8. 模型预测-20240605181741722.webp]]

图像数据的格式是torch.uint8, 表示范围在(0,255),通常用于表示图像的像素值。

而在深度学习模型中,通常使用 torch.float32 格式的输入,因为模型训练和推理时需要更高的数值精度和更广泛的表示范围。

因此,需要把格式转成精度更高的float32,这是模型所需要的格式。

# 载入图像,并将张量值转换为float32
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32)

# 将torch.uint8张量转换为torch.float32,并归一化到[0, 1]
custom_image = custom_image / 255. 

# 检查转换后的张量的数据
print(f"Custom image tensor:\n{custom_image}\n")
print(f"Custom image shape: {custom_image.shape}\n")
print(f"Custom image dtype: {custom_image.dtype}")

![[04.1 8. 模型预测-20240606124914487.webp]]在这里插入图片描述

# 图片展示
plt.imshow(custom_image.permute(1, 2, 0))
plt.title(f"Image shape: {custom_image.shape}")
plt.axis(False);

在这里插入图片描述
![[04.1 8. 模型预测-20240605181832312.webp]]
数据形状现在是[3,4032,4032], 我们还需要对它进行进一步的处理,使其能够匹配模型训练时使用的数据形状。

3. 图像变换

# 设置图像变换过程
custom_image_transform = transforms.Compose([
    transforms.Resize((64, 64)),
])

# 图片转换
custom_image_transformed = custom_image_transform(custom_image)

# 打印图片形状
print(f"Original shape: {custom_image.shape}")
print(f"New shape: {custom_image_transformed.shape}")

![[04.1 8. 模型预测-20240605181913674.webp]]

经过Transform过程,图片形状变成[3,64,64]。

原始形状为torch.Size([3, 4032, 3024]),这表示图像的高度为4032像素,宽度为3024像素,并且有3个通道(通常表示RGB通道)。

新的形状为torch.Size([3, 64, 64]),这表示经过调整后,图像的高度和宽度都变成了64像素,依然保持3个通道。

transforms.ToTensor():

  • 输入格式:对于彩色图像(RGB),输入通常是形状为 (H, W, 3) 的 numpy 数组或 PIL 图像,其中 H 是高度,W 是宽度,3 表示颜色通道(红、绿、蓝)。
  • 输出格式:形状为 (C, H, W)的PyTorch 张量,其中 C 是颜色通道数(通常为 3),H 是高度,W 是宽度。
  • 示例:假设有一张 RGB 图像,原始大小为 256x256,转换后为形状为 (3, 256, 256) 的张量,其中 3 表示 RGB 通道。如果是灰度图像,转换为 (1, H, W),因为灰度图像只有一个通道。

4. 模型预测

model_0.eval()

with torch.inference_mode():
    # 给图像增加一个维度:batch size
    custom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)  

    # 打印结果
    print(f"Custom image transformed shape: {custom_image_transformed.shape}")
    print(f"Unsqueezed custom image shape: {custom_image_transformed_with_batch_size.shape}")

    # 使用模型对图像进行分类预测
    custom_image_pred = model_0(custom_image_transformed.unsqueeze(dim=0).to(device))

![[04.1 8. 模型预测-20240605182651951.webp]]

  • custom_image_transformed.unsqueeze(dim=0) 因为在模型训练过程中,图像张量数据是按照批次导入模型训练的,模型适应的维度/形状是(N, C, H, W), 这里的N是批次的意思。因此, torch.unsqueeze(dim=0)给图像价

5. 预测结果输出

# 打印原始预测值logits
print(f"Prediction logits: {custom_image_pred}")

# 将logits转换为预测概率-->模型预测的概率
custom_image_pred_probs = torch.softmax(custom_image_pred, dim=1)
print(f"Prediction probabilities: {custom_image_pred_probs}")

# 将预测概率转换为预测标签
custom_image_pred_label = torch.argmax(custom_image_pred_probs, dim=1)
print(f"Prediction label: {custom_image_pred_label}")

![[04.1 8. 模型预测-20240605182220141.webp]]

  • torch.softmax(custom_image_pred, dim=1): 使用Softmax函数将logits转换为概率。Softmax函数将logits转换为0到1之间的概率值,并且所有概率值的总和为1。dim=1表示在类别维度上进行计算。
  • torch.argmax(custom_image_pred_probs, dim=1): 在概率最大的类别索引上取最大值,这个索引对应于模型预测的类别标签。
# 找出预测标签
custom_image_pred_class = class_names[custom_image_pred_label.cpu()] # put pred label to CPU, otherwise will error

custom_image_pred_class

![[04.1 8. 模型预测-20240606123453519.webp]]

  • .cpu() 这里代码是在GPU上运行的,所以需要把预测标签移回CPU上。

创建预测函数(打包整个预测过程)

我们复习一下上面的步骤:

  1. 设置目标图像路径,并将其转换为适合我们模型的数据类型(torch.float32)。
  2. 确保目标图像的像素值在范围 [0, 1] 之内。
  3. 如有必要,对目标图像进行变换。
  4. 确保模型在指定的设备上。
  5. 使用训练好的模型对目标图像进行预测(确保图像尺寸正确,并与模型在同一设备上)。
  6. 将模型的输出logits转换为预测概率。
  7. 将预测概率转换为预测标签。
  8. 绘制目标图像,并显示模型的预测结果和预测概率。

接下来我们需要把这些步骤都打包到一个函数中,这样就能通过函数实现模型的预测功能。

def pred_and_plot_image(model: torch.nn.Module,
                        image_path: str,
                        class_names: List[str] = None,
                        transform=None,
                        device: torch.device = device):

    # 1. 载入图像,并将张量值转换为float32
    target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)

    # 2. 将图像像素值除以255,使其在[0, 1]之间
    target_image = target_image / 255.

    # 3. 如有必要,进行图像变换
    if transform:
        target_image = transform(target_image)

    # 4. 确保模型在指定设备上
    model.to(device)

    # 5. 启用模型评估模式和推理模式
    model.eval()
    with torch.inference_mode():
        # 为图像添加一个维度
        target_image = target_image.unsqueeze(dim=0)

        # 对图像进行预测,并将其发送到指定设备
        target_image_pred = model(target_image.to(device))

    # 6. 将logits转换为预测概率(使用softmax进行多分类)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # 7. 将预测概率转换为预测标签
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # 8. 绘制图像,并显示预测结果和预测概率
    plt.imshow(target_image.squeeze().permute(1, 2, 0))  # 调整图像以适应matplotlib
    if class_names:
        title = f"预测: {class_names[target_image_pred_label.cpu()]} | 概率: {target_image_pred_probs.max().cpu():.3f}"
    else:
        title = f"预测: {target_image_pred_label} | 概率: {target_image_pred_probs.max().cpu():.3f}"
    plt.title(title)
    plt.axis(False);
pred_and_plot_image(model=model_0,
                    image_path=custom_image_path,
                    class_names=class_names,
                    transform=custom_image_transform,
                    device=device)

![[04.1 8. 模型预测-20240606123753105.webp]]

最后结果展示了分类的标签,概率,以及经过处理后的图片。


最近更新

  1. TCP协议是安全的吗?

    2024-06-07 10:36:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-07 10:36:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-07 10:36:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-07 10:36:02       18 阅读

热门阅读

  1. 什么是shell脚本?

    2024-06-07 10:36:02       7 阅读
  2. MySQL和Redis的区别

    2024-06-07 10:36:02       9 阅读
  3. shell脚本对编码和行尾符敏感吗

    2024-06-07 10:36:02       9 阅读
  4. 2-链表-61-相交节点-LeetCode160

    2024-06-07 10:36:02       7 阅读
  5. GaussDB 数据库的事务管理

    2024-06-07 10:36:02       8 阅读
  6. Python语言回归:深入探索与实战应用

    2024-06-07 10:36:02       9 阅读
  7. 8086 汇编笔记(十一):内中断

    2024-06-07 10:36:02       9 阅读
  8. OC和Swift的区别,发送消息和执行方法的区别

    2024-06-07 10:36:02       6 阅读
  9. AWS Load Balancer Controller 实践

    2024-06-07 10:36:02       7 阅读