【pytorch23】MNIST测试实战

理解

训练完之后也需要做测试
为什么要做test?
在这里插入图片描述
上图蓝色代表train的accuracy
下图蓝色代表train的loss
基本上符合预期,随着epoch增大,train的accuracy也会上升,loss也会一直下降,下降到一个较小的程度

但是如果只看train的情况的话,就会被欺骗,虽然accuracy高并且loss很低,你会以为这个算法就很好了,但是做其他的事情就不是特别好,这就是过拟合(overfitting)

deep learning表达能力非常强表达效果很好,在模型在训练数据上表现非常好,但在新的、未见过的数据上表现不佳的情况。这是因为模型学习到了训练数据中的特定噪声和细节,而不是更通用的特征。

如何缓解这种情况?
在train的时候做一个test,这个test使用validation set 验证集做的,在刚开始的阶段,蓝色的线在上升的时候,验证集的accuracy也会上升loss也与train基本一致,只不过是在训练集上面train在验证集上测试不一定完全符合,所以波动会有点大,很明显train会的更好,validation的表现(包括accuracy和loss)也会变的更好

说明在刚开始的阶段确实学到了一些通用的特征,随着时间的推移,就开始over fitting了,开始去记住一些噪声和细节,这样的话泛化能力会变差,所以在训练集上训练后,在验证集上测试的时候,accuracy会保持不变或者可能下降同样的loss也会巨幅的波动

深度学习所以并不是越训练越好,数据量和架构是核心问题,有一个好的结构再加上足够的数据才能取得一个好的结果

在这里插入图片描述
logits是一个是十个节点的向量,经过cross entropy loss(包含softmax和log和nll_loss)训练,得到loss和accuracy(经过softmax之后就变成了Y=i,i代表第i号节点的概率,只需要argmax之后就能得到概率最大所在的位置),这里对softmax之前和之后都做了一下argmax,其实是一样的效果,因为softmax不会改变单调性,即原来大的数据在softmax之后也会大

这是计算accuracy的基本流程

在这里插入图片描述
什么时候计算test的accuracy和loss

不能够每做一个batch就训练一次,这样就会花大量的时间做测试,不合理,尤其是对于大型数据集

一般情况:

  • 训练若干个batch做一次测试
  • 训练一个epoch做一次测试

如何做测试
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


def load_data(batch_size):
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist_data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist_data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x


def training(train_loader, net, device):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.item()))


def testing(test_loader, net, device):
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).float().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


global net

if __name__ == '__main__':
    batch_size = 200
    learning_rate = 0.01
    epochs = 10

    train_loader, test_loader = load_data(batch_size)

    device = torch.device('cuda:0')
    net = MLP().to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    criteon = nn.CrossEntropyLoss().to(device)

    for epoch in range(epochs):
        training(train_loader, net, device)
        testing(test_loader, net, device)

相关推荐

  1. network_api_pytorch_mnist

    2024-07-10 14:08:07       18 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-10 14:08:07       4 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-10 14:08:07       5 阅读
  3. 在Django里面运行非项目文件

    2024-07-10 14:08:07       4 阅读
  4. Python语言-面向对象

    2024-07-10 14:08:07       5 阅读

热门阅读

  1. Jupyter Notebook详尽安装教程

    2024-07-10 14:08:07       7 阅读
  2. 实现淘客返利系统中的用户登录与权限管理

    2024-07-10 14:08:07       6 阅读
  3. 【力扣】每日一题—第70题,爬楼梯

    2024-07-10 14:08:07       8 阅读
  4. mysql快速精通(一)DQL数据查询语言

    2024-07-10 14:08:07       10 阅读
  5. 408第二轮复习 数据结构 第七章查找

    2024-07-10 14:08:07       10 阅读
  6. Python中的迭代器与可迭代对象的概念及其关系

    2024-07-10 14:08:07       10 阅读
  7. 大数据面试题之Greenplum(2)

    2024-07-10 14:08:07       7 阅读
  8. 准备GPU H20机器k8s环境时用到的链接

    2024-07-10 14:08:07       8 阅读
  9. 数据库的优点和缺点分别是什么

    2024-07-10 14:08:07       10 阅读
  10. SQL语句分类

    2024-07-10 14:08:07       10 阅读