PyTorch搭建LeNet测试集实现

搭建神经网络请看PyTorch搭建LeNet神经网络-CSDN博客

实现训练集请看PyTorch搭建LeNet训练集详细实现-CSDN博客

测试集比较简单,直接上代码。

代码实现

# 导包 不必多说
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

# 详细解释见下面
transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


# 与训练集一样的分类
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 模型实例化
net = LeNet()
net.load_state_dict(torch.load('LeNet.pth'))  # 载入在训练时保存的权重文件

im = Image.open('3.jpg')
im = transform(im)  # 预处理数据
im = torch.unsqueeze(im, dim=0)  # 将数据中增加一个batch维度

with torch.no_grad():
    outputs = net(im)
    # 寻找最大值所在的index索引值
    predict = torch.max(outputs, dim=1)[1].data.numpy()
# 最后打印预测结果
print(classes[int(predict)])

 预处理数据函数

transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

 这一段是将数据预处理,相比于训练集中的数据预处理多了transforms.Resize((32, 32)),因为导入的图片尺寸不一定正确,需要先将图片的尺寸重新定义。

运行结果

我测试了飞机、汽车、鸟,飞机、汽车都可以识别出来。但鸟不行,可能是图片的像素太小,训练不到位。

把鸟给预测成猫了

总结

三天!从0开始,实现了LeNet。跟着b站上的视频,反复观看并记笔记,再自己手敲代码,再写出笔记。代码都能跑通实现,中间遇到的问题也靠自己独立解决了。对于自己来说还是比较有成就感的。但是我知道这点知识对于想要学好深度学习是远远远远不够的。还是要继续不断地学习。这样一篇一篇笔记也是我努力学习的见证!要努力成为很厉害的人!希望大家也是!

相关推荐

  1. PyTorchInformer实现长序列时间序列预测

    2024-03-11 13:38:06       16 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-03-11 13:38:06       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-11 13:38:06       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-11 13:38:06       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-11 13:38:06       20 阅读

热门阅读

  1. 举例说明计算机视觉(CV)技术的优势和挑战。

    2024-03-11 13:38:06       25 阅读
  2. Ubuntu系统开发环境搭建和常用软件

    2024-03-11 13:38:06       16 阅读
  3. Unity3D 基于AStar地图的摇杆控制角色详解

    2024-03-11 13:38:06       20 阅读
  4. Debian系APT源通用镜像加速配置

    2024-03-11 13:38:06       21 阅读
  5. NLP技术

    2024-03-11 13:38:06       22 阅读
  6. Go语言聊天室demo

    2024-03-11 13:38:06       19 阅读
  7. 【golang】二叉树的遍历

    2024-03-11 13:38:06       22 阅读
  8. Go语法之函数 defer使用

    2024-03-11 13:38:06       20 阅读
  9. 大数据开发(Hadoop面试真题-卷六)

    2024-03-11 13:38:06       24 阅读
  10. Node.js概述与安装和运行

    2024-03-11 13:38:06       20 阅读
  11. springboot文件上传

    2024-03-11 13:38:06       20 阅读