机器学习入门【经典的CIFAR10分类】

模型

神经网络采用下图
在这里插入图片描述

我按照上图的网络训练后发现,即使增加迭代次数,正确率最高也只有65%左右。于是我在网络中加入了ReLU激活函数、批归一化(BatchNorm)和Dropout正则化,现在正确率可以达到80%左右。

模型代码

import torch
from torch import nn


class YmModel(nn.Module):
    """Small convolutional classifier producing 10 class scores.

    The network is a single flat ``nn.Sequential`` stored in ``self.model``:

    * three feature stages, each ``Conv2d(5x5, padding=2) -> BatchNorm2d ->
      ReLU -> MaxPool2d(2)``, with channel widths 3 -> 32 -> 32 -> 64;
    * ``Flatten``;
    * two hidden ``Linear -> ReLU -> Dropout(0.5)`` stages (1024 -> 512 -> 64);
    * a final ``Linear(64, 10)`` that emits raw scores (no softmax).

    The first linear layer expects ``64 * 4 * 4`` features, i.e. a 4x4 map
    after three 2x poolings, which corresponds to 32x32 input images.
    """

    def __init__(self):
        super().__init__()

        stack = []

        # Convolutional feature extractor: (in_channels, out_channels) per stage.
        for in_channels, out_channels in ((3, 32), (32, 32), (32, 64)):
            stack.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])

        stack.append(nn.Flatten())

        # Fully connected head: (in_features, out_features) per hidden stage.
        for in_features, out_features in ((64 * 4 * 4, 512), (512, 64)):
            stack.extend([
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Dropout(0.5),
            ])

        stack.append(nn.Linear(64, 10))

        # Kept as one flat Sequential so layer indices (and therefore the
        # state_dict keys "model.<idx>.*") stay the same.
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the class scores."""
        scores = self.model(x)
        return scores


训练

需要说明的是,CIFAR10数据集本身并没有划分验证集,你可以从训练集中划出1万张作为验证集;下面的代码为了简单,直接使用测试集来评估每一轮的效果。

import torch
import torchvision
from torchvision import transforms

from models.YMModel import YmModel
from torch.utils.data import DataLoader


# Training script: fits YmModel on the CIFAR-10 training split, evaluates on
# the test split after every epoch, and keeps the best-scoring weights in
# './best_model.pth'.

# Shared normalization: ToTensor yields values in [0, 1]; subtracting 0.5 and
# dividing by 0.5 per channel rescales them to [-1, 1].
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: light augmentation (random crop with 4px padding and
# random horizontal flip) followed by tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no augmentation, but the SAME normalization as the
# training data, so the model sees inputs on the scale it was trained on.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Datasets
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform_test, download=True)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
print(len(train_loader), len(test_loader))

print(len(train_dataset), len(test_dataset))

model = YmModel()
# Number of epochs
train_epochs = 300
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Loss function
loss_fn = torch.nn.CrossEntropyLoss()

best_accuracy = 0.

for epoch in range(train_epochs):
    # Training mode: Dropout is active and BatchNorm uses batch statistics.
    model.train()
    print(f'Epoch is {epoch}')
    for train_epochs_step, (images, labels) in enumerate(train_loader):
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log the loss every 100 mini-batches.
        if train_epochs_step % 100 == 0:
            print(f'Train_Epoch is {train_epochs_step}\t Loss is {loss.item()}')

    # Evaluation mode: disables Dropout and makes BatchNorm use its running
    # statistics, so the test data neither gets randomly masked nor leaks into
    # the BatchNorm buffers. model.train() at the top of the loop switches back.
    model.eval()
    with torch.no_grad():
        loss_running_total = 0.
        acc_running_total = 0.
        for images, labels in test_loader:
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            # Sum of per-batch mean losses over the whole test set.
            loss_running_total += loss.item()
            # Count of correctly classified samples in this batch.
            acc_running_total += (outputs.argmax(1) == labels).sum().item()
    # Convert the correct-prediction count into a fraction of the test set.
    acc_running_total /= len(test_dataset)
    if acc_running_total > best_accuracy:
        best_accuracy = acc_running_total
        # Keep only the weights of the best epoch so far.
        torch.save(model.state_dict(), './best_model.pth')
    print('accuracy is {}'.format(acc_running_total))
    print('total loss is {}'.format(loss_running_total))
    print('best accuracy is {}'.format(best_accuracy))


验证

import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from models.TestColor import TextColor
from models.YMModel import YmModel

# Validation script: loads the saved weights, reports accuracy on the CIFAR-10
# test split, then classifies every image found in './images' and prints the
# prediction next to the label derived from the file name.

# The training data is normalized with mean/std 0.5 per channel, so every
# input fed to the trained model here has to be normalized the same way.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transforms.Compose([transforms.ToTensor(), normalize]),
    download=True,
)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
model = YmModel()

model.load_state_dict(torch.load('best_model.pth'))


# Evaluation mode: Dropout off, BatchNorm uses its running statistics.
model.eval()
with torch.no_grad():
    correct = 0.
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
    print('Accuracy : {}'.format(100 * correct / len(test_dataset)))
folder_path = './images'
files_names = sorted(os.listdir(folder_path))
# Custom images can have any size: resize to the 32x32 input the model
# expects, then apply the same tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    normalize,
])

# No gradients are needed for inference on the custom images either.
with torch.no_grad():
    for file_name in files_names:
        image_path = os.path.join(folder_path, file_name)
        # Skip sub-directories; only regular files are treated as images.
        if not os.path.isfile(image_path):
            continue
        # Force 3-channel RGB so grayscale, palette and RGBA files all produce
        # a (3, 32, 32) tensor; the context manager closes the file handle.
        with Image.open(image_path) as raw_image:
            image = transform_test(raw_image.convert('RGB'))
        # Add the batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
        image = image.unsqueeze(0)
        output = model(image)
        _, predicted = torch.max(output, 1)
        # The file name without its extension is used as the true class name.
        source_name = os.path.splitext(file_name)[0]
        predicted_class = classes[predicted.item()]
        # Green for a correct prediction, red for a wrong one.
        colors = TextColor.GREEN if predicted_class == source_name else TextColor.RED
        print(f"Source is {TextColor.BLUE}{source_name}{TextColor.RESET}, and predicted is {colors}{predicted_class}{TextColor.RESET}")

结果

TextColor是自定义的终端字体颜色类;images文件夹中放的是自己准备的图片,文件名(去掉扩展名)即为图片的真实类别,例如cat.jpg。
结果如下:测试集的正确率有82.7%

在这里插入图片描述

相关推荐

  1. 一起深度学习——CIFAR10

    2024-07-19 02:54:01       35 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-19 02:54:01       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-19 02:54:01       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-19 02:54:01       58 阅读
  4. Python语言-面向对象

    2024-07-19 02:54:01       69 阅读

热门阅读

  1. 代码随想录学习 54day 图论 Bellman_ford 算法精讲

    2024-07-19 02:54:01       20 阅读
  2. 锁升级过程中的两次自旋 面试重点

    2024-07-19 02:54:01       23 阅读
  3. electron 应用的生命周期

    2024-07-19 02:54:01       23 阅读
  4. SQL基础

    2024-07-19 02:54:01       22 阅读
  5. 【Unity C#优化】业务逻辑代码方面的优化

    2024-07-19 02:54:01       21 阅读
  6. 【Linux】微基准测试

    2024-07-19 02:54:01       21 阅读