008. GoogleNet - FashionMNIST - 90.510% accuracy

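I. Introduction to GoogleNet

  • GoogleNet (officially spelled GoogLeNet), also known as Inception v1, is a deep convolutional neural network (CNN) architecture designed by Google researchers in 2014; it won the classification task of ILSVRC 2014.
  • GoogleNet strikes a good balance between accuracy and computational cost. Its core innovation is the "Inception module", which runs 1×1, 3×3, and 5×5 convolutions and a pooling branch in parallel and concatenates their outputs. The architecture has been applied widely in image recognition, video analysis, and other computer vision tasks.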
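
To make the efficiency point concrete, the sketch below counts the weights of a 5×5 branch with and without a 1×1 reduction layer in front of it. The channel sizes (192 in, 16 after the reduction, 32 out) are those of the inception3a module in the code later in this post; the helper name conv_weights is hypothetical, introduced only for this illustration, and biases and batch-norm parameters are ignored.

```python
def conv_weights(kernel_h, kernel_w, in_channels, out_channels):
    """Number of weights in a bias-free 2D convolution."""
    return kernel_h * kernel_w * in_channels * out_channels


# Channel sizes of the 5x5 branch of inception3a in this post's code
in_ch, reduce_ch, out_ch = 192, 16, 32

# A 5x5 convolution applied directly to all 192 input channels
direct = conv_weights(5, 5, in_ch, out_ch)

# A 1x1 reduction to 16 channels, followed by the 5x5 convolution
reduced = conv_weights(1, 1, in_ch, reduce_ch) + conv_weights(5, 5, reduce_ch, out_ch)

print("direct 5x5:  ", direct)
print("1x1 then 5x5:", reduced)
print("ratio:", round(direct / reduced, 2))
```

Under these assumptions the reduced branch needs roughly a tenth of the weights, which is the main reason the module can afford several parallel branches.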

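II. Module structure and later versions

The Inception module has gone through several later versions, each improving on performance or efficiency.

  • Inception v2: introduced batch normalization and replaced each 5×5 convolution with two stacked 3×3 convolutions, which speeds up training and makes it more stable.
  • Inception v3: added further factorized convolutions, including asymmetric ones (an n×n convolution split into a 1×n followed by an n×1), improving efficiency further.
  • Inception v4 / Inception-ResNet: showed the potential of combining Inception modules with residual connections, a combination that speeds up training noticeably and improves accuracy.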
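
The saving from factorized convolutions can be checked with simple arithmetic. The sketch below is only an illustration: the channel count of 64 is an arbitrary assumption, conv_weights is a hypothetical helper, and biases are ignored.

```python
def conv_weights(kernel_h, kernel_w, channels):
    """Weights of a bias-free convolution with equal input and output channels."""
    return kernel_h * kernel_w * channels * channels


c = 64  # illustrative channel count (an assumption, not taken from any paper)

# One 5x5 convolution versus two stacked 3x3 convolutions (same receptive field)
five_by_five = conv_weights(5, 5, c)
two_three_by_three = 2 * conv_weights(3, 3, c)

# One 3x3 convolution versus an asymmetric 1x3 followed by a 3x1
three_by_three = conv_weights(3, 3, c)
asymmetric_pair = conv_weights(1, 3, c) + conv_weights(3, 1, c)

print("5x5 vs 2 x 3x3:", five_by_five, two_three_by_three)
print("3x3 vs 1x3+3x1:", three_by_three, asymmetric_pair)
```

The ratios are 18/25 and 6/9 regardless of the channel count, since the channel term cancels.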

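III. The FashionMNIST dataset

  • An earlier post in this series already describes the FashionMNIST dataset in some detail.

Note: understanding the dataset is the most important step in any machine learning workflow, bar none.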

IV. Code implementation

1. Imports
from torchvision.datasets import FashionMNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import numpy as np
import random
2. Load the data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
generator = torch.Generator()

# Fix the random seeds so the experiment is repeatable
torch.random.manual_seed(420)
torch.cuda.manual_seed(420)
generator.manual_seed(420)
np.random.seed(420)
random.seed(420)

# Load FashionMNIST from "./dataset/"; it is downloaded automatically if missing.
train_data = FashionMNIST(root='./dataset/', train=True, transform=transforms.ToTensor(), download=True)
test_data = FashionMNIST(root='./dataset/', train=False, transform=transforms.ToTensor(), download=True)
# Batches of 256; only the training set is shuffled, using the seeded generator.
train_batch = DataLoader(dataset=train_data, batch_size=256,  shuffle=True, num_workers=0, drop_last=False, generator=generator)
test_batch = DataLoader(dataset=test_data, batch_size=256,  shuffle=False, num_workers=0, drop_last=False, generator=generator)
3. Define the model
import torch
import torch.nn as nn
import torch.nn.functional as F

# Basic building block: convolution + batch normalization + ReLU
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

# Inception module: four parallel branches concatenated along the channel dimension
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)
# The GoogleNet model, taking 1-channel FashionMNIST images as input
class GoogleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(GoogleNet, self).__init__()
        self.conv1 = BasicConv2d(1, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
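
Two properties of the model above are easy to check by hand: how the 28×28 input shrinks through the stem and pooling layers, and whether each Inception module outputs the channel count the next one expects. The sketch below does both in plain Python. The helpers conv_out and pool_out_ceil are hypothetical names, and the size formulas follow the output-shape rules for Conv2d and MaxPool2d with ceil_mode=True as I understand them from the PyTorch documentation; treat it as a sanity check rather than a substitute for running the model.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out_ceil(size, kernel, stride, padding=0):
    """Output size of max pooling with ceil_mode=True."""
    out = -(-(size + 2 * padding - kernel) // stride) + 1  # ceiling division
    # The last window must start inside the (left-padded) input.
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


# Layers that change the spatial size; Inception modules preserve it.
steps = [
    ("conv1", conv_out, dict(kernel=7, stride=2, padding=3)),
    ("maxpool1", pool_out_ceil, dict(kernel=3, stride=2)),
    ("conv2", conv_out, dict(kernel=1)),
    ("conv3", conv_out, dict(kernel=3, padding=1)),
    ("maxpool2", pool_out_ceil, dict(kernel=3, stride=2)),
    ("maxpool3", pool_out_ceil, dict(kernel=3, stride=2)),
    ("maxpool4", pool_out_ceil, dict(kernel=2, stride=2)),
]

size = 28  # FashionMNIST images are 28x28
sizes = {"input": size}
for name, fn, kwargs in steps:
    size = fn(size, **kwargs)
    sizes[name] = size
print(sizes)

# (in_channels, ch1x1, ch3x3, ch5x5, pool_proj) for each module, in order
inception_cfg = [
    ("3a", 192, 64, 128, 32, 32),
    ("3b", 256, 128, 192, 96, 64),
    ("4a", 480, 192, 208, 48, 64),
    ("4b", 512, 160, 224, 64, 64),
    ("4c", 512, 128, 256, 64, 64),
    ("4d", 512, 112, 288, 64, 64),
    ("4e", 528, 256, 320, 128, 128),
    ("5a", 832, 256, 320, 128, 128),
    ("5b", 832, 384, 384, 128, 128),
]

# Output channels of a module = sum of the four concatenated branches
out_channels = [sum(cfg[2:]) for cfg in inception_cfg]
for (name, in_ch, *_), prev_out in zip(inception_cfg[1:], out_channels):
    assert in_ch == prev_out, f"inception{name} expects {in_ch}, gets {prev_out}"
print("final channels:", out_channels[-1])  # must match nn.Linear(1024, ...)
```

Under these formulas the feature map is already 1×1 after maxpool3, so inception4a through inception5b operate on single-pixel maps when the input is 28×28. Upscaling the images or removing some of the early downsampling is a common way to adapt ImageNet-sized architectures to small inputs.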

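4. Define the loss function and optimizer
from torch.optim import Adam

# Instantiate the model (1 input channel, 10 output classes) and move it to the device
model = GoogleNet(num_classes=10).to(device)
# Cross-entropy loss: applies log-softmax and negative log-likelihood to the raw logits
criterion = torch.nn.CrossEntropyLoss()
# Adam optimizer with a learning rate of 0.005
opt = Adam(model.parameters(), lr=0.005)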
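
CrossEntropyLoss takes raw logits and is equivalent to a log-softmax followed by the negative log-likelihood of the true class. The plain-Python sketch below spells this out for a single sample; log_softmax and cross_entropy are hypothetical helper names and the logits are made up for illustration.

```python
import math


def log_softmax(logits):
    """Log of the softmax probabilities, computed stably."""
    m = max(logits)  # subtract the maximum to avoid overflow in exp
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def cross_entropy(logits, target):
    """Negative log-likelihood of the target class under softmax(logits)."""
    return -log_softmax(logits)[target]


logits = [2.0, 0.5, -1.0]  # made-up scores for a 3-class example
loss = cross_entropy(logits, target=0)
print("loss:", loss)

# With 10 classes and identical logits, the loss equals ln(10)
uniform_loss = cross_entropy([0.0] * 10, target=3)
print("uniform 10-class loss:", uniform_loss)
```

The second value, ln(10) ≈ 2.3026, is a handy reference point: a 10-class classifier whose loss sits near it is doing no better than uniform guessing.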

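5. Train
# Train for 9 epochs
for _ in range(9):
    # Iterate over the mini-batches
    for batch in train_batch:
        # Move the input images to the device
        X = batch[0].to(device)
        # y holds the ground-truth labels
        y = batch[1].to(device)

        # Forward pass: raw class scores (logits)
        sigma = model(X)
        # Compute the loss
        loss = criterion(sigma, y)
        # Predicted label = index of the largest logit
        y_hat = torch.max(sigma, dim=1)[1]
        # Number of correct predictions in this batch
        correct_count = torch.sum(y_hat == y)
        # Batch accuracy in percent
        accuracy = correct_count / len(y) * 100
        # Backward pass: compute the gradients
        loss.backward()
        # Update the model parameters
        opt.step()
        # Clear the gradients before the next batch
        model.zero_grad()
    # Print the loss and accuracy of the last batch of this epoch
    print('loss:', loss.item(), 'accuracy:', accuracy.item())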

Output:

loss: 0.40873321890830994 accuracy: 86.45833587646484
loss: 0.32110244035720825 accuracy: 87.5
loss: 0.3983750343322754 accuracy: 84.375
loss: 0.2500341534614563 accuracy: 92.70833587646484
loss: 0.25288426876068115 accuracy: 90.625
loss: 0.23409247398376465 accuracy: 93.75
loss: 0.21931307017803192 accuracy: 90.625
loss: 0.19851894676685333 accuracy: 90.625
loss: 0.18098406493663788 accuracy: 94.79167175292969
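
The accuracies printed above come from the last mini-batch of each epoch only, which is why they jump around. The arithmetic below shows how small that batch is, using the 60,000 training images and the batch size of 256 from the loader, and recovers the number of correct predictions behind the first printed value.

```python
train_size = 60000  # FashionMNIST training images
batch_size = 256    # as set in the DataLoader above

num_batches = -(-train_size // batch_size)  # ceiling division
last_batch = train_size - (num_batches - 1) * batch_size
print(num_batches, last_batch)

# Each printed accuracy is k / last_batch * 100 for some integer k
first_epoch_accuracy = 86.45833587646484
k = round(first_epoch_accuracy / 100 * last_batch)
print(k, k / last_batch * 100)
```

With drop_last=False the final batch holds only 96 images, so one prediction moves the printed figure by about one percentage point. Accumulating correct counts over the whole epoch would give a steadier number.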
6. Evaluate on the test set
correct_count = 0
# Gradients are not needed for evaluation
with torch.no_grad():
    for batch in test_batch:
        test_X = batch[0].to(device)
        test_y = batch[1].to(device)
        # test_X is already a float32 tensor, so it can be passed in directly
        sigma = model(test_X)
        y_hat = torch.max(sigma, dim=1)[1]
        correct_count += torch.sum(y_hat == test_y)

# The test set contains 10,000 images
accuracy = correct_count / 10000 * 100
print('accuracy:', accuracy.item())

Output:

accuracy: 90.51000213623047
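  • As the result shows, GoogleNet takes far longer to train and run than a plain CNN, yet its accuracy is not much better (90.51% here, against the 90.18% of the convolutional network in post 007 of this series).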
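
One thing worth checking when reproducing this result: the code in this post never calls model.eval(), so dropout and batch normalization remain in training mode during the test loop. The sketch below shows one possible evaluation helper that switches modes explicitly. The function name evaluate and the toy model and data are assumptions made for illustration; they are not the GoogleNet or the loaders defined above.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Return top-1 accuracy in percent, with the model in eval mode."""
    was_training = model.training
    model.eval()  # dropout off, batch norm uses its running statistics
    correct, total = 0, 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        pred = model(X).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    model.train(was_training)  # restore the previous mode
    return 100.0 * correct / total


# Tiny stand-in model and random data, purely to show the call
toy_model = nn.Sequential(nn.Flatten(), nn.Dropout(0.5), nn.Linear(4, 2))
toy_batches = [(torch.randn(8, 1, 2, 2), torch.randint(0, 2, (8,))) for _ in range(3)]
toy_accuracy = evaluate(toy_model, toy_batches)
print(toy_accuracy)
```

With the objects from this post, the call would be evaluate(model, test_batch, device). Dividing by the number of samples actually seen also avoids hard-coding the test set size.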
