PyTorch计算机视觉入门:测试模型与评估,对单帧图片进行推理

在完成模型的训练之后,对模型进行测试与评估是至关重要的一步,它能帮助我们理解模型在未知数据上的泛化能力。本篇指南将带您了解如何使用PyTorch进行模型测试,并对测试结果进行分析。我们将基于之前训练好的模型,演示如何加载数据、进行预测、计算指标以及可视化结果。

准备工作

假设您已经有一个训练好的模型,保存在.pth文件中,以及一个用于测试的自定义数据集。我们将继续使用前文提到的自定义数据集CustomDataset类,并引入一些新的概念和代码。

加载测试数据集

与训练过程类似,首先需要加载测试数据集,并对其进行适当的预处理。确保您的测试集遵循与训练集相同的数据结构和预处理步骤。

test_dataset = CustomImageDataset(data_path="./data/", model= "test", transform = transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

测试模型

训练完成后,使用测试数据集来评估模型的性能

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target)
            # max 函数返回两个值,一个是是数值,一个是index
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标 

            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
Test set: Average loss: 0.0018, Accuracy: 9671/10000 (97%)

单帧图片进行测试

# test single image 
img = Image.open("./data/data_test/1.jpg")
img_t = transform(img)
img_t = img_t.unsqueeze(0)  # 变为[1, 1, 28, 28]
img_t = img_t.to(device)
model.eval()
output = model(img_t)
_, predicted_class = torch.max(output, 1)
print(predicted_class)
tensor([2], device='cuda:0')

通过以上步骤,我们可以全面地评估和分析PyTorch模型在计算机视觉任务中的表现,从而确保模型在实际应用中的有效性和可靠性。

关注我的公众号Ai fighting, 第一时间获取更新内容。

相关推荐

  1. 利用C++进行图像处理计算机视觉

    2024-06-17 03:56:01       34 阅读
  2. 基于C++和OpenCv视频进行

    2024-06-17 03:56:01       12 阅读
  3. python视频进行处理以及裁减部分区域

    2024-06-17 03:56:01       8 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-17 03:56:01       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-17 03:56:01       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-17 03:56:01       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-17 03:56:01       18 阅读

热门阅读

  1. 2024前端面试准备6-TS基础

    2024-06-17 03:56:01       8 阅读
  2. vue3 如何给表单添加表单效验+正则表达式

    2024-06-17 03:56:01       6 阅读
  3. LeetCode热题1. 两数之和

    2024-06-17 03:56:01       6 阅读
  4. git diff

    2024-06-17 03:56:01       8 阅读
  5. windows用脚本编译qt的项目

    2024-06-17 03:56:01       7 阅读
  6. Window上ubuntu子系统编译Android

    2024-06-17 03:56:01       6 阅读
  7. react捡起来了

    2024-06-17 03:56:01       7 阅读
  8. python判断一个数是不是偶数

    2024-06-17 03:56:01       9 阅读
  9. 编程机器人的参数表怎么看

    2024-06-17 03:56:01       6 阅读
  10. AI芯片战场的迁徙:从训练到推理的深度剖析

    2024-06-17 03:56:01       6 阅读
  11. Linux部署FTP服务

    2024-06-17 03:56:01       5 阅读