将图像增广应用于Mnist数据集

将图像增广应用于Mnist数据集

不用到cifar-10的原因是要下载好久。。我就直接用在Mnist上了,先学会用

首先我们得了解一下图像增广的基本内容,这是我的一张猫图片,以下为先导入需要的包和展示图片

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d2l.set_figsize()
img = Image.open('cat.png')
d2l.plt.imshow(img)

在这里插入图片描述
之后呢,我们先定义几个函数,以后方便调用,第一个函数show_images,他是用来展示多张图片的

def show_images(imgs, num_rows, num_cols, scale=2):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize = figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j])
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    return axes

然后将图像展示函数和图像增广函数结合起来展示,也用一个函数来集成

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale)

接下来,就可以开始我们的图像增广之路啦

左右翻转

torchvision.transforms.RandomHorizontalFlip()这个函数有百分之五十的概率实现左右翻转

apply(img, torchvision.transforms.RandomHorizontalFlip()) # torchvision.transforms.RandomHorizontalFlip() 百分之50的概率左右翻转

在这里插入图片描述

上下翻转

torchvision.transforms.RandomVerticalFlip() 百分之50的概率上下翻转
在这里插入图片描述

随机裁剪

随机裁剪出一块面积为原面积10%100%的区域,且该区域的宽和高之比随机取自0.52,然后将该区域的宽高缩放到200像素

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

在这里插入图片描述
自然,我们也可以变换颜色,有亮度(brightness),对比度(contrast),饱和度(saturation),色调(hue)
我就直接一起写了,也可以只变单个
0.5的意思是比如对于亮度来说,他会在50%的范围内随机选择,即亮度为原来的0.5~1.5

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, saturation=0.5, contrast=0.5) 
apply(img, color_aug)

在这里插入图片描述
那么当然,我们也可以把上述的那些进行叠加
用到torchvision.transforms.Compose

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])
apply(img, augs)

在这里插入图片描述
之后呢,就可以用增广后的图像进行训练啦,这里给大家一个例子用Resnet18进行训练Mnist数据集,Resnet18就不带着大家写了,直接调用别人写好的函数,写网络并不是本节的重点,如果以后有时间或者大家有需要我可以再来写~
(为什么是Mnist数据集,其实他在Mnist数据集上的效果并没有很明显,比较比较简单,最好是在cifar上,但是cifar要下太久了,懒,大家可以在cifar上测一下)
先写两个augs,训练集我就将他随机翻转,测试集就不动了

flip_aug = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()           # 记得转换成tensor 以便训练
])
no_aug = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

然后是load_mnist,就加一个transform就好啦,集成到一个函数里

def load_mnist(is_train, augs, batch_size, root="~/Datasets/MNIST"):
    dataset = torchvision.datasets.MNIST(train=is_train, root=root, transform=augs, download=True)
    return DataLoader(dataset, batch_size = batch_size, shuffle=is_train)

再之后就是模型的训练了,这个大家应该都写腻了,我也不多说什么了,反正就是模型前向传播+反向传播,然后再记录点值

def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

再最后,就定义一个函数,把前面的都用上啦!

def train_with_data_aug(train_augs, test_augs, lr=0.001):
    batch_size, net = 256, d2l.resnet18(output=10, in_channels=1)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()
    train_iter = load_mnist(True, train_augs, batch_size)
    test_iter = load_mnist(False, test_augs, batch_size)
    train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

值得注意的是,这边调用别人的d2l.resnet18,要注意in_channels=1记得写,他默认是3通道的,改成1通道对于我们的mnist,如果你要是cifar-10就不用变了,把in_channel=1给删掉就好~,至此,调用我们的函数就行
在这里插入图片描述

训练还是很快滴

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2023-12-08 09:26:04       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-08 09:26:04       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-08 09:26:04       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-08 09:26:04       18 阅读

热门阅读

  1. vue项目中如何引入zip压缩包之解决方案

    2023-12-08 09:26:04       42 阅读
  2. Installing GDS

    2023-12-08 09:26:04       40 阅读
  3. 【1day】金和OA某接口存在未授权访问漏洞

    2023-12-08 09:26:04       31 阅读
  4. ARM虚拟化与车联网安全应用

    2023-12-08 09:26:04       39 阅读
  5. 【RabbitMQ高级功能详解以及常用插件实战】

    2023-12-08 09:26:04       37 阅读