基于BP网络识别MNIST数据集

1.对于MNIST数据的下载和预处理

 1.1torchvison.transforms

此方法为常见的处理图像数据的工具本次使用其中的两个方法

1.2transforms.Compose

是pytouch图像预处理包。一般把多个步骤整合到一起

transform = transforms.Compose([接收transfroms的操作])

1.3transforms.ToTensor()

shape(H, W, C)nump.ndarrayimg转为shape(C, H, W)tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255。

1.4transfrom.Normalize

标准化操作用于将数据映射到-1到1或0-1区间方便计算

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])

1.5torchvision.datasets

获取数据集

Torchvision在 torchvision.datasets 模块中提供了许多内置的数据集,以及用于构建自己的数据集的实用程序类。

主要数据集

•    1.1 CIFAR10
•    1.2 Fashion-MNIST
•    1.3 ImageNet

•    1.4 MNIST

数据集选项

root(string):数据集的根目录,其中目录 cifar-10-batches-py 存在或将保存到(如果下载设置为True)。
train(bool,可选):如果为True,则从训练集创建数据集,否则从测试集创建。
transform(可调用,可选):接受PIL图像并返回转换版本的函数/转换。例如, transforms.RandomCrop
target_transform(可调用,可选):接收目标并对其进行转换的函数/transform。
download(bool,可选):如果为true,则从Internet下载数据集并将其放在根目录中。如果数据集已下载,则不会再次下载。

traindata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=True, download=True,transform=transform)
testdata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=False, download=True,transform=transform)

 1.6DataLoader

是一个可迭代的数据装载器,组合了数据集和采样器,并在给定数据集上提供可迭代对象。可以完成对数据集中多个对象的集成。

主要参数

dataset(数据集):需要提取数据的数据集,Dataset对象 通常为图片和分类的组合
batch_size(批大小):每一次装载样本的个数,int型
 shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

trainload = DataLoader(dataset=traindata, shuffle=True, batch_size=64)
testLoad = DataLoader(dataset=testdata, shuffle=False, batch_size=64)

2网络的构建

我们采用BP网络进行图片的分类训练

由于迭代器中一次给的tensor是(64,1,28,28)的向量,所以先将nn.Flatten()将后三为展成一维

(64,784)后将数据经过多层感知基和激活函数转化为(64,10)的输出,最后再将数据经过logsoftmax得到输出向量(64,10)

net=nn.Sequential(nn.Flatten(),nn.Linear(784,128),nn.ReLU(),nn.Linear(128,64),nn.ReLU(),nn.Linear(64,10),nn.LogSoftmax(dim=1))

注:dim参数为1对于行做softmax dim为0对列做softmax

损失函数使用交叉熵损失

loss=nn.NLLLoss()

NLLoss与CrossEntropyLoss的区别:

CrossEntropyLoss不需要对数据做logsoftmax操作,NLLoss需要先对数据进行softmax操作

本质都是交叉熵求导

优化算法采用平均梯度下降

updater=torch.optim.SGD(net.parameters(),lr=0.003,momentum=0.88)

lr为学习率 

mementum为动量 可以更快下降至损失最低点附近 有关动量怎么来的以及什么作用

 net.parameters为网络参数

请转至使用动量(Momentum)的SGD、使用Nesterov动量的SGD_sgd动量-CSDN博客

训练函数 

        for X,y in trainload:
            shuchu=net(X)
            l=loss(shuchu,y)
            loss_data=loss_data+l.item()
            updater.zero_grad()
            l.backward()
            updater.step()

基本操作没啥讲的

测试模型正确率

def test(tai=0):
    print("测试单元")
    with torch.no_grad():
        for X, y in testLoad:
            shuchu = net(X)
            index = shuchu.argmax(axis=1).tolist()
            y = y.tolist()
            ji = 0
            for f in range(64):
                if index[f] == y[f]:
                    ji += 1
            print(f"正确率: {ji * 100 / 64}%")

注argmax可以获取指定维度最大值的标号,获取后使用.tolist方法将tensor转化为列表

图像显示 

            if tai==1:
                ji=0
                plt.figure(figsize=(9, 7))
                for x in  range(4):
                    for p in range(4):
                        plt.subplot2grid((4, 4), (x,p))
                        plt.title(f'预测值: {index[ji]}, 实际值: {y[ji]}')
                        plt.imshow(X[ji].squeeze(), cmap='gray')
                        plt.axis('off')
                        ji+=1
                plt.show()

 imshow为plt内置的显示图片函数,接受图片的tensor数据 

cmap:颜色设置。常用的值有’viridis’、‘gray’、'hot’等。可以通过plt.colormaps()查看可用的颜色映射。

aspect:调整坐标轴。这将根据图像数据自动调整坐标轴的比例。常用的值有’auto’、'equal’等。设置为’auto’时会根据图像数据自动调整纵横比,而设置为’equal’时则会强制保持纵横比相等。

interpolation:插值方法。它定义了图像在放大或缩小时的插值方式。常用的值有’nearest’、‘bilinear’、'bicubic’等。较高的插值方法可以使图像看起来更平滑,但计算成本更高。

alpha:透明度。它允许您设置图像的透明度,取值范围为0(完全透明)到1(完全不透明)之间。

vmin和vmax:用于设置显示的数据值范围。当指定了这两个参数时,imshow()将会根据给定的范围显示图像,超出范围的值会被截断显示

3.运行结果及总代码

3.1运行结果

 

3.2 代码


import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torch import nn
import pylab as pl
pl.rcParams['font.sans-serif'] = ['SimHei']
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
traindata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=True, download=True,transform=transform)
testdata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=False, download=True,transform=transform)
# 利用DataLoader加载数据集
trainload = DataLoader(dataset=traindata, shuffle=True, batch_size=64)
testLoad = DataLoader(dataset=testdata, shuffle=False, batch_size=64)
net=nn.Sequential(nn.Flatten(),nn.Linear(784,128),nn.ReLU(),nn.Linear(128,64),nn.ReLU(),nn.Linear(64,10),nn.LogSoftmax(dim=1))
loss=nn.NLLLoss()
updater=torch.optim.SGD(net.parameters(),lr=0.003,momentum=0.88)
ls=list()
ci=10
def train():
    net.train()
    for x in range(ci):
        print(f"第{x+1}次训练")
        loss_data = 0
        for X,y in trainload:
            shuchu=net(X)
            l=loss(shuchu,y)
            loss_data=loss_data+l.item()
            updater.zero_grad()
            l.backward()
            updater.step()
        print(f"总损失{loss_data}")
        ls.append(loss_data)
    test()
    x=[x for x in range(1,ci+1)]
    print(ls)
    plt.plot(x,ls)
    plt.show()
import matplotlib.pyplot as plt

def test(tai=0):
    print("测试单元")
    with torch.no_grad():
        for X, y in testLoad:
            shuchu = net(X)
            index = shuchu.argmax(axis=1).tolist()
            y = y.tolist()
            ji = 0
            for f in range(64):
                if index[f] == y[f]:
                    ji += 1
            print(f"正确率: {ji * 100 / 64}%")
            if tai==1:
                ji=0
                plt.figure(figsize=(9, 7))
                for x in  range(4):
                    for p in range(4):
                        plt.subplot2grid((4, 4), (x,p))
                        plt.title(f'预测值: {index[ji]}, 实际值: {y[ji]}')
                        plt.imshow(X[ji].squeeze(), cmap='gray')
                        plt.axis('off')
                        ji+=1
                plt.show()
            break

train()
test(1)


如有侵权必删

参考

使用动量(Momentum)的SGD、使用Nesterov动量的SGD_sgd动量-CSDN博客

pytorch初学笔记(六):DataLoader的使用_pytorch dataloader-CSDN博客

Pytorch学习:常见数据集torchvision.datasets—CIFAR10、Fashion-MNIST和ImageNet,以及数据集的使用DataLoader_torchvision.datasets.imagenet-CSDN博客

深度学习之使用BP神经网络识别MNIST数据集_bp神经网络完成minist数据集识别-CSDN博客

plt.imshow()的用法和参数介绍-CSDN博客 

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-10 23:02:02       5 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-10 23:02:02       5 阅读
  3. 在Django里面运行非项目文件

    2024-07-10 23:02:02       4 阅读
  4. Python语言-面向对象

    2024-07-10 23:02:02       8 阅读

热门阅读

  1. 深入理解Spring Cloud中的服务注册

    2024-07-10 23:02:02       9 阅读
  2. SIFT代码,MATLAB

    2024-07-10 23:02:02       9 阅读
  3. Scala 数据类型

    2024-07-10 23:02:02       11 阅读
  4. DP学习——简单工厂模式

    2024-07-10 23:02:02       9 阅读
  5. 从 Spark 离线数仓到 Flink 实时数仓:实战指南

    2024-07-10 23:02:02       9 阅读
  6. 浅析DDoS高防数据中心网络

    2024-07-10 23:02:02       10 阅读
  7. 奇幻的Python

    2024-07-10 23:02:02       10 阅读
  8. 记录一些简单的linux运维命令

    2024-07-10 23:02:02       9 阅读
  9. python--del

    2024-07-10 23:02:02       9 阅读
  10. BiLSTM模型实现

    2024-07-10 23:02:02       10 阅读