【kaggle代码】Plant Seedlings Classification (使用Resnet-18完成分类任务)

比赛地址:植物种子分类

注意的点:

  1. 使用datasets.ImageFolder读取数据,并且制作数据集。分类任务与图像分割任务不同。分类任务的数据是:【图片,标签(字符串类型)】,所以两者的数据读取方式不同。在分割任务中,常常需要重写Dataset便于图像预处理,而在该分类任务中,不需要重写Dataset,在datasets.ImageFolder中,可以接收transform参数对读入的图像进行处理,而不对标签(字符串)处理,且会将标签自动转为标签索引形式。关于datasets.ImageFolde

  2. torch的Dataloader接受的是(data, labels)的元组形式,在 PyTorch 的 DataLoader 中,元组列表中元素的数据类型要求相对较松。每个元组的第一个元素通常是输入数据,第二个元素是对应的标签。这两个元素可以是任何 PyTorch 支持的数据类型,例如张量(torch.Tensor)、NumPy 数组、PIL 图像等。

  3. 对于使用预训练好的Resnet-18,可以通过更改网络最后一层,来适应该分类任务。对于很多模型,model.fc 是最后一层的全连接层。
    在这里插入图片描述

  4. 在这个比赛中,最初得分总是很低。最后发现原因是:在提交submission中,图片名称是按照顺序读入的,但是在使用Dataloader读入测试集数据时,使用了shuffle=True,导致读入的顺序被打乱,从而使得图片名称和预测标签不对应,导致得分很低。改为shuffle=Flase问题解决。

代码,按照ipynb顺序排列:

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
import numpy as np
import os
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from torch import optim
from torch import nn
import cv2 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import random_split
from tqdm import tqdm
import imageio
from torchvision import datasets
from PIL import Image # Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。
work_dir = '/kaggle/input/plant-seedlings-classification'
os.listdir(work_dir)
`import glob
#读取数据,用于后续制作数据集
train_path = os.path.join(work_dir,'train')

# 使用glob列出train文件夹下的所有文件夹
folders = glob.glob(os.path.join(train_path, '*'))

print(f'总的类别数量:{len(folders)}')``

```python
# values from ImageNet, recommended by PyTorch
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]

transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std),
])

dataset = datasets.ImageFolder(root=train_path,transform=transforms)

# self.classes:用一个 list 保存类别名称
# self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
# self.imgs:保存(img-path, class) tuple的 list
#查看有多少个样例和多少个类别
print('samples',len(dataset))
print('classes',len(dataset.classes))
print(dataset)
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs[0])
label_counts = []

# 遍历每个类别的文件夹
for d in glob.glob(os.path.join(train_path, '*')):
#glob.glob 返回一个包含匹配指定模式的所有文件或文件夹的列表。在这里,它返回了所有子文件夹的路径列表。    
    
    
    # 获取类别名称
    label = os.path.basename(d)
    
    # 计算该类别中图像的数量
    count = len(glob.glob(os.path.join(d, '*')))
    
    # 将类别名称和图像数量添加到列表中
    label_counts.append({'label': label, 'count': count})

# 创建一个 Pandas DataFrame
label_counts_df = pd.DataFrame(label_counts)

# 打印 DataFrame
print(label_counts_df)
## 划分训练集和验证集
train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [3750, 1000])
print(len(train_dataset))
print(len(valid_dataset))
#DataLoader 返回的是一个迭代器(iterator),每次迭代都会产生一个包含小批次数据的元组。
#这个元组的内容取决于你在创建 DataLoader 时指定的数据集的格式。
#通常情况下,这个元组包含两个元素,分别是输入数据和对应的标签。
#例如,如果你的数据集是一个 TensorDataset,那么每个小批次的元组就是 (inputs, targets)。

# 这里是分类任务,和分割任务不同。
#dataset使用ImageFolder就对image已经进行了transform,而label使用的是索引(0,1...),所以不需要重写Dataloader

train_loader = DataLoader(train_dataset,batch_size=16,shuffle=True,num_workers=4)
valid_loader = DataLoader(valid_dataset,batch_size=16,shuffle=True,num_workers=4)

train_features_batch, train_labels_batch = next(iter(train_loader))

print(train_features_batch.shape, train_labels_batch.shape)
print(train_features_batch[0])a
print(train_labels_batch)
import torchvision.models as models

#修改最后一层
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, len(dataset.classes), device=device)
# 模型中的关键部分:
# model.features:

# 这是模型的特征提取部分,通常包含卷积层和池化层。
# model.avgpool:

# 模型中的平均池化层,用于对特征进行全局平均池化。
# model.classifier:

# 这是模型的分类部分,通常包含全连接层。

# model.fc:

# 对于很多模型,model.fc 是最后一层的全连接层。


model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=1e-4)


# 假设输入大小为 (batch_size, num_channels, height, width)
# 这里使用随机数据,你需要根据你的模型和数据进行适当的调整
batch_size = 1
num_channels = 3  # 通常是3,表示RGB图像
height, width = 224, 224  # 这可能需要根据你的数据集进行调整

# 创建随机输入数据
random_input = torch.randn(batch_size, num_channels, height, width)

# 将数据移动到设备(GPU或CPU)
random_input = random_input.to(device)

# 模型推断

with torch.no_grad():
    output = model(random_input)

# 打印模型输出
print("Model Output Shape:", output.shape)
print("Model Output Values:", output)
# 训练
epochs=10
train_loss_all = [] #定义一个列表用于保存总的训练集loss,方便后续打印
val_loss_all =[]   #定义一个列表用于保存总的验证集loss,方便后续打印
best_loss = 1e10   #记录最佳的loss
for epoch in range(epochs):
    train_loss =0
    val_loss =0
    train_num=0
    val_num =0
    correct = 0
    model.train()
    loop = tqdm(train_loader)
    for idx,(image,label) in enumerate(loop):
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()
     
        pre_lab = torch.argmax(output,1)
        output = model(image)
        
        loss = loss_fn(output ,label)
        
        loss.backward()
        optimizer.step() #梯度更新
        
        train_loss +=loss.item()
        train_num +=1
    train_loss_all.append(train_loss / train_num)
    print('{} *****Train Loss:{:.4f}'.format(epoch,train_loss_all[-1]))
    
    
    with torch.no_grad():
        loop=tqdm(valid_loader)
        for idx,(image,label) in enumerate(loop):
            image = image.to(device)
            label = label.to(device)
            
            output = model(image)
            pre_lab = torch.argmax(output,1)
            
            loss = loss_fn(output,label)
            val_loss += loss.item()
            val_num +=1
            correct += (pre_lab == label.data).sum().item()
        correct /= len(valid_loader.dataset)
        val_loss_all.append(val_loss / val_num)
        print(f'{epoch} *****Valid Loss:{val_loss_all[-1]:.4f}  Accuracy={(100 * correct):>0.1f}%')
        
    ##保存模型
    if val_loss_all[-1] < best_loss :
        best_loss = val_loss_all[-1]
        check_points = model.state_dict()
        torch.save(check_points, '/kaggle/working/BestSave.pt')
#可视化模型训练过程中的loss曲线
epochs = list(range(1, 11))  # 或者任何你实际的 epochs 数量

plt.figure(figsize=(10,6))
plt.plot(epochs,train_loss_all,"ro-",label = "Train Loss")
plt.plot(epochs,val_loss_all,"bs-",label = "Valid Loss")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()
#单张图像查看


from PIL import Image # Image模块是在Python PIL图像处理中常见的模块,对图像进行基础操作的功能基本都包含于此模块内。


#读取数据,用于后续制作数据集
tem_path = os.path.join(train_path,'Black-grass')


# 使用glob列出train文件夹下的所有文件夹
fold = glob.glob(os.path.join(tem_path, '*'))

print(fold[0])


input_image = Image.open(fold[0])

print(input_image.size)


input_tensor = transforms(input_image).unsqueeze(0).to(device)  # 添加 batch 维度
print(input_tensor.shape)

model = model.to(device)
model.load_state_dict(torch.load('/kaggle/working/BestSave.pt'))
input_tensor = input_tensor.to(device)
with torch.no_grad():
    output = model(input_tensor)
print(output)
print(torch.argmax(output,1))
#Dataloader默认是返回(输入数据,标签),但是测试集中没有标签,故重写一个Dataloader

class TestDataset(Dataset):
    def __init__(self,test_path,transform=None):
        self.test_path = test_path
        self.test_images = os.listdir(self.test_path)
        self.transform = transform
        
    def __len__(self):
        return len(self.test_images)
    
    def __getitem__(self,idx):
        self.image_path = os.path.join(self.test_path,os.listdir(self.test_path)[idx])
        img = Image.open(self.image_path)
        if self.transform is not None:
            img = self.transform(img)
            
        return img
        
#单张图像进行验证
import glob
#读取数据,用于后续制作数据集
test_path = os.path.join(work_dir,'test')

# 使用glob列出train文件夹下的所有文件夹
folders = glob.glob(os.path.join(test_path, '*'))
print(folders[:2])
print(folders[1])

from torchvision import transforms
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(transform_mean, transform_std)
])

test_dataset = TestDataset(test_path, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False,num_workers=4)

print('Test:', len(test_dataset), 'samples')
from tqdm import tqdm
labels = []
model = model.to(device)
model.load_state_dict(torch.load('/kaggle/working/BestSave.pt'))
model.eval()

with torch.no_grad():
    loop = tqdm(test_loader)
    for idx ,(image)in enumerate(loop):
        image = image.to(device)
        
        output = model(image)

        preds = torch.argmax(output,1)
        labels.extend(preds.cpu().numpy().tolist())

species = [dataset.classes[label] for label in labels]

submission = pd.DataFrame({'file': os.listdir(test_path), 'species': species})
submission.to_csv('submission.csv', index=False)

最近更新

  1. TCP协议是安全的吗?

    2024-03-10 20:34:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-10 20:34:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-10 20:34:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-10 20:34:02       18 阅读

热门阅读

  1. 初识C语言—字符串、转义字符、注释

    2024-03-10 20:34:02       22 阅读
  2. vue3注册全局组件

    2024-03-10 20:34:02       18 阅读
  3. Docker Register 搭建私有镜像仓库

    2024-03-10 20:34:02       20 阅读
  4. Linux 系统上卸载 Docker

    2024-03-10 20:34:02       21 阅读
  5. 在 Docker 环境下安装 OpenWrt

    2024-03-10 20:34:02       25 阅读
  6. Docker修改网段

    2024-03-10 20:34:02       22 阅读
  7. Kotlin 中的数据类

    2024-03-10 20:34:02       21 阅读
  8. lvs集群

    lvs集群

    2024-03-10 20:34:02      21 阅读
  9. sklearn随机森林实现(备忘版)

    2024-03-10 20:34:02       20 阅读
  10. Docker

    2024-03-10 20:34:02       21 阅读
  11. Flink命令行提交时参数的传递

    2024-03-10 20:34:02       18 阅读
  12. Redis的HyperLogLog原理介绍

    2024-03-10 20:34:02       19 阅读