007.卷积网络-FashionMNIST-正确率90.180

一、卷积网络的核心概念

卷积神经网络(CNN):
  • 深度学习最为成功和流行的模型之一,尤其在图像识别视频分析和计算机视觉等领域取得了突出的应用成果。
局部感受野:
  • 与传统的全连接神经网络不同,CNN中的每个神经元并不与前一层的所有神经元相连,而是仅仅连接到一个小区域,即局部感受野。
  • 这样可以大幅减少需要训练的参数数量,并且提高了网络对于空间层次结构的学习能力。
权重共享:
  • 在CNN中,同一层的多个神经元可以共享相同的权重,这意味着在进行卷积操作时使用的滤波器(卷积核)对整个输入数据执行相同的操作。
  • 这不仅进一步减少了模型的复杂性,而且也使得网络具有平移不变性(即不管物体在图像中的位置如何,网络都能够识别它)。
卷积层(Convolutional Layer):
  • 负责对输入图像执行卷积操作,该操作通过在图像上滑动滤波器,并计算滤波器与图像各个区域的内积来完成。输出结果称为特征图(Feature Map)。
激活层(Activation Layer):
  • 跟随卷积层的是激活函数,通常是非线性的激活函数,如ReLU(Rectified Linear Unit),它的作用是增加网络的非线性以处理复杂问题。
池化层(Pooling Layer):
  • 减少特征图的空间大小,同时使得网络对小的位置变化更棒。
全连接层(Fully Connected Layer):
  • 在多个卷积和池化层之后,全连接层用来将学到的“高级特征”映射到最终的分类结果,比如通过Softmax函数来实现多分类。

二、FashionMNIST数据集简介

  • 之前的博客已经较为细致的介绍了FashionMNIST数据集:插眼传送

注意:了解数据集是机器学习的所有环节中最重要的一步,没有之一。

三、代码实现

1.导包
from torchvision.datasets import FashionMNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
2.加载数据
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 从"./dataset/"目录加载FashionMNIST数据集,如果没有则会自动下载。
train_data = FashionMNIST(root='./dataset/', train=True, transform=transforms.ToTensor(), download=True)
test_data = FashionMNIST(root='./dataset/', train=False, transform=transforms.ToTensor(), download=True)
train_batch = DataLoader(dataset=train_data, batch_size=256,  shuffle=True,  drop_last=False)
3.定义模型
class Model(torch.nn.Module):
    def __init__(self,in_features,out_features):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(1)
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
        self.pool1 = torch.nn.MaxPool2d(2) # 13
        self.conv2 = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3) 
        self.pool2 = torch.nn.MaxPool2d(2) # 5
        self.linear1 = torch.nn.Linear(16*5*5, 128)
        self.output = torch.nn.Linear(128, out_features, bias=False)
        
    def forward(self,x):
        x = self.bn(x)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.relu(x.view(x.shape[0], -1))
        x = torch.relu(self.linear1(x))
        x = self.output(x)
        x = F.log_softmax(x,dim=1)
        return x

这个模型是使用PyTorch框架实现的一个简单的卷积神经网络(CNN)。它继承自torch.nn.Module,其中定义了网络的架构和前向传播的流程。下面我来逐行进行分析。

  1. __init__函数:

    • super().__init__()调用父类的构造函数,必须在开始执行子类构造函数逻辑之前完成。
    • self.bn定义了一个Batch Normalization层,用于正则化处理,这里针对的是2D数据,通常应用于卷积层之后,全连接层之前。
    • self.conv1定义了第一个卷积层,输入通道in_channels=1(假设是灰度图片),输出通道out_channels=8,卷积核大小kernel_size=3
    • self.pool1定义了第一个池化层,使用2x2的最大池化。
    • self.conv2定义第二个卷积层,输入通道为上一卷积层的输出通道8,输出通道16,核大小同为3
    • self.pool2定义了第二个池化层,同样是2x2的最大池化。
    • self.linear1定义了一个全连接层,输入特征是16通道的5x5的特征图展平后得到的一维向量,输出特征维度是128。
    • self.output定义了最后一个全连接层作为输出层,它将128维的向量映射到out_features维,bias=False表示此层不使用偏置项。
  2. forward函数:

    • forward函数中定义了如何计算从输入到输出的方式。
    • 首先通过先前定义的Batch Normalization层标准化输入x
    • 然后通过第一个卷积层,接着应用ReLU激活函数。
    • 经过第一个池化层降维。
    • 重复同样的操作(卷积、ReLU、池化)对经过第一层处理后的特征图进一步提取特征。
    • 将多维的数据展平成一维数据以便全连接层处理。
    • 数据通过全连接层,再次应用ReLU激活函数。
    • 最后通过输出层的全连接层,输出最终的预测。
    • 应用log_softmax函数处理输出,以获得最终的类别概率分布。

此网络适合用于处理单通道(例如灰度)图像的分类任务。每一层卷积和池化都设计来逐渐减小空间尺寸,增加特征复杂性。全连接层则用于根据提取的特征进行最终的分类判断。整体而言,这是一个基本的CNN架构。

4.定义损失函数、优化器
from torch.optim import Adam
from torch.nn import functional as F

# 设置随机种子,确保实验可重复性
torch.random.manual_seed(420)

# 初始化一个模型,输入图片通道数为1,输出特征为10
model = Model(in_features=1, out_features=10).to(device)
# 使用负对数似然损失函数
criterion = torch.nn.NLLLoss()
# 初始化Adam优化器,设定学习率为0.005
opt = Adam(model.parameters(), lr=0.005)

5.开始训练
# 进行9次迭代
for _ in range(9):
    # 遍历数据批次
    for batch in train_batch:
        # 将输入数据X调整形状并输入到模型
        X = batch[0].to(device)
        # y为真实标签
        y = batch[1].to(device)

        # 前向传播,获取模型输出
        sigma = model.forward(X)
        # 计算损失
        loss = criterion(sigma, y)
        # 计算预测的标签
        y_hat = torch.max(sigma, dim=1)[1]
        # 计算预测正确的数量
        correct_count = torch.sum(y_hat == y)
        # 计算准确率
        accuracy = correct_count / len(y) * 100
        # 反向传播,计算梯度
        loss.backward()
        # 更新模型参数
        opt.step()
        # 清除之前的梯度
        model.zero_grad()
    # 打印当前批次的损失和准确率
    print('loss:', loss.item(), 'accuracy:', accuracy.item())

输出:

loss: 0.42387986183166504 accuracy: 83.33333587646484
loss: 0.2386537343263626 accuracy: 86.45833587646484
loss: 0.25280213356018066 accuracy: 90.625
loss: 0.25218477845191956 accuracy: 91.66667175292969
loss: 0.17148371040821075 accuracy: 92.70833587646484
loss: 0.15541903674602509 accuracy: 94.79167175292969
loss: 0.188282772898674 accuracy: 90.625
loss: 0.14084762334823608 accuracy: 93.75
loss: 0.32792240381240845 accuracy: 89.58333587646484
6.验证测试集
test_X = test_data.data.unsqueeze(dim=1).to(device)
test_y = test_data.targets.to(device)
sigma = model.forward(torch.tensor(test_X,dtype=torch.float32))
y_hat = torch.max(sigma, dim=1)[1]
correct_count = torch.sum(y_hat == test_y)
accuracy = correct_count / 10000 * 100
print('accuracy:', accuracy.item())

输出:

accuracy: 90.18000030517578
  • 很明显:卷积网络对于图像的处理,对比全连接网络,要强大很多。

相关推荐

  1. 007.网络-FashionMNIST-正确90.180

    2024-06-08 03:44:04       12 阅读
  2. 008.googleNet-FashionMNIST-正确90.510

    2024-06-08 03:44:04       11 阅读
  3. 神经网络

    2024-06-08 03:44:04       23 阅读
  4. 神经网络

    2024-06-08 03:44:04       13 阅读
  5. 神经网络

    2024-06-08 03:44:04       7 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-08 03:44:04       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-08 03:44:04       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-08 03:44:04       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-08 03:44:04       18 阅读

热门阅读

  1. C#面:解释什么是扩展方法

    2024-06-08 03:44:04       8 阅读
  2. html 添加元素如何能提升速度

    2024-06-08 03:44:04       8 阅读
  3. 致远OA(A8) REST接口 python版

    2024-06-08 03:44:04       8 阅读
  4. Git 保留空文件夹结构

    2024-06-08 03:44:04       9 阅读
  5. Flink Rest Basic Auth - 安全认证

    2024-06-08 03:44:04       9 阅读
  6. 安卓手机APP开发___设备管理概述

    2024-06-08 03:44:04       10 阅读
  7. Gnu/Linux 系统编程 - 如何获取帮助及一个演示

    2024-06-08 03:44:04       9 阅读
  8. C#朗读语音

    2024-06-08 03:44:04       8 阅读
  9. 第3章 列表简介

    2024-06-08 03:44:04       11 阅读