Simple Image Classification

This post classifies traffic-sign images into 62 classes for the current iFLYTEK (科大讯飞) competition; accuracy is currently around 0.94. The main difficulties are:

1. Severe class imbalance: some classes have more than 200 images, while others have fewer than 10 (see the sampler sketch after this list);

2. Image resolutions vary widely, so some images are sharp while others are very blurry.
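
The training code below does not explicitly handle the imbalance; one common remedy is to oversample rare classes with a WeightedRandomSampler. A minimal sketch, assuming the same ImageFolder layout as the training script (the transform and loader settings here are illustrative, not tuned for the competition):

import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms

# Same folder layout as the training script below
dataset = datasets.ImageFolder(root='../all/data/train_set/',
                               transform=transforms.ToTensor())

# Count samples per class, then weight each sample by the inverse class frequency
targets = np.array(dataset.targets)
class_counts = np.bincount(targets)
sample_weights = 1.0 / class_counts[targets]

# replacement=True lets rare-class images be drawn multiple times per epoch
sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(dataset),
                                replacement=True)

# Pass the sampler instead of shuffle=True (the two are mutually exclusive)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)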

Straight to the code:

import os
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import models, datasets, transforms

from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd

warnings.filterwarnings("ignore")

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
n_classes = 62  # number of classes
pretrain = False  # whether to download and use pretrained weights (True requires internet access)
epochs = 10  # number of training epochs
traindataset = datasets.ImageFolder(root='../all/data/train_set/', transform=transforms.Compose([
    transforms.Resize((224, 224)),
    # transforms.RandomHorizontalFlip(),  # disabled: flipping changes the meaning of directional signs
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))


# Train/validation split: 80% of the data for training, 20% for validation
train_val_ratio = 0.8
train_size = int(len(traindataset) * train_val_ratio)
val_size = len(traindataset) - train_size
train_dataset, val_dataset = random_split(traindataset, [train_size, val_size])


classes = traindataset.classes
print(classes)
 
model = models.resnext50_32x4d(pretrained=pretrain)
# model = models.resnet34(pretrained=pretrain)

if pretrain:
    # Freeze the pretrained backbone; only the new classification head will be trained
    for param in model.parameters():
        param.requires_grad = False

# Replace the final fully connected layer with a 62-class head
model.fc = nn.Linear(in_features=2048, out_features=n_classes, bias=True)
model = model.to(device)
 
 
def train_model(model, train_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        total_corrects += torch.sum(preds.eq(labels)).item()
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    total_loss = total_loss / total
    acc = 100 * total_corrects / total
    print("Epoch:%4d | train loss:%.5f | train acc:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc
 
 
def test_model(model, test_loader, loss_fn, optimizer, epoch):
    model.eval()  # evaluation mode: disables dropout and uses running batch-norm statistics
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            total_loss += loss.item() * inputs.size(0)
            total_corrects += torch.sum(preds.eq(labels)).item()

        loss = total_loss / total
        accuracy = 100 * total_corrects / total
        print("Epoch:%4d | val loss:%.5f | val acc:%6.2f%%" % (epoch + 1, loss, accuracy))
        return loss, accuracy
 
 
loss_fn = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
for epoch in range(epochs):
    loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)
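
The loop above only logs the metrics; if you also want to keep the weights that score best on the validation split, a minimal sketch of the usual pattern is shown below (best_acc and the checkpoint path are illustrative additions, not part of the original script):

best_acc = 0.0
for epoch in range(epochs):
    train_model(model, train_loader, loss_fn, optimizer, epoch)
    _, val_acc = test_model(model, test_loader, loss_fn, optimizer, epoch)
    if val_acc > best_acc:
        best_acc = val_acc
        # Save only the weights of the best epoch so far
        torch.save(model.state_dict(), "best_resnext50.pth")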

Model prediction:

sub = pd.read_csv("../all/data/example.csv")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model.eval()
for path in os.listdir("../all/data/test_set/"):
    try:
        # convert to RGB so grayscale/RGBA images also become 3-channel tensors
        img = Image.open("../all/data/test_set/" + path).convert('RGB')
        img_p = transform(img).unsqueeze(0).to(device)
        output = model(img_p)
        pred = output.argmax(dim=1).item()
        # show very small (and usually blurry) images for manual inspection
        if img.size[0] * img.size[1] < 2000:
            plt.imshow(img)
            plt.show()
        # class probabilities in percent (kept for inspection, not used further)
        p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
        sub.loc[sub['ImageID'] == path, 'label'] = classes[pred]
        print(f'{path} size = {img.size}, predicted class:', classes[pred])
    except Exception:
        print(f'error {path}')
sub.loc[sub['ImageID']=='e57471de-6527-4b9b-90a8-4f1d93909216.png','label'] = 'Under Construction'
sub.loc[sub['ImageID']=='ff38d59e-9a11-41e4-901b-67097bb0e960.png','label'] = 'Keep Left'
sub.columns = ['ImageID','Sign Name']
label_map = pd.read_excel("../all/data/label_map.xlsx")
sub_all = pd.merge(left=sub,right=label_map,on='Sign Name',how='left')
#sub_all[['ImageID','label']].to_csv('./sub_resnet34_add_img_ratio_drop_dire.csv',index=False)

Personal takeaways:

1. How to do image augmentation, and what to watch out for (orientation matters for traffic signs; see the sketch after this list);

2. How to choose an appropriately sized model.
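
On the orientation point: horizontal flips turn left-pointing signs into right-pointing ones (e.g. "Keep Left"), so flips should be avoided here. A minimal orientation-safe augmentation sketch (the rotation, jitter, and crop values are illustrative, not tuned for this competition):

from torchvision import transforms

# Orientation-safe augmentation: no flips, only small rotations,
# colour jitter, and random crops that keep the sign readable
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])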
