pytorch学习(八)Dataset加载分类数据集

我们之前用torchvision加载了pytorch的网络数据集,现在我们用Dataset加载自己的数据集,并且使用DataLoader做成训练数据集。

图像是从网上下载的,网址是 点这里,标签是图像文件夹名字。下载完成后作为自己的数据集。

1.加载自己的数据集的思路

    1)要完成继承自Dataset的类的构建

          由于Dataset是一个包含了虚函数的类,因此继承Dataset后,必须实现这些虚函数。

   2)第一个要完成的是__init__的构建,一般的方法是在__init__(self,root_dir, label_dir)中设置数据集的根目录root_dir,和类别数据集label_dir,然后用os.listdir得到label_dir中的图像名字

    3)第二个要完成的就是

__getitem__(self, item):

       item就是所要取数据的索引,这个函数主要是返回一个训练数据(比如一个图像),和一个结果数据,比如(该图像的分类结果是一个ant),因此用到刚os.listdir所列出的文件名字,用os.path.join加入路径,得到图像的绝对路径,用PIL导入图像,并给label赋值,返回图像和;abel即可。

   4)第三个要实现的就是数据集的长度

  __len__(self):

可以直接len(os.listdir所列出的文件名的数组),就可以得到数据集的长度。

2.需要注意的问题

   我在调试的时候发现

for imgs, labels  in train_loader:

一直报错,查找原因,发现是该数据集中的图像存在两个问题,第一个是大小不一,第二个貌似通道个数也不一致。

大小不一

因此使用transform做了处理

transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),
                                transforms.Grayscale(),
                                transforms.ToTensor()])

3.代码如下:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch
from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter("logs")
transform=transforms.Compose([ transforms.Resize((320,320),interpolation=Image.BILINEAR),
                                transforms.Grayscale(),
                                transforms.ToTensor()])


class MyDataLoader(Dataset):
    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, item):
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        img = transform(img)
        label = self.label_dir
        return img,label
    def __len__(self):
        return len(self.img_path)

root_dir = "E:/TOOLE/slam_evo/pythonProject/data/hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"

ants_dataset = MyDataLoader(root_dir,ants_label_dir)
bees_dataset = MyDataLoader(root_dir,bees_label_dir)
train_data = ants_dataset + bees_dataset

img0, label0 = train_data[12]
# img0.show()
img1, label1 = train_data[124]
# img1.show()
# 一次处理数据10个
BATCH_SIZE = 10
# 把数据集装载到DataLoader里
train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)

A = len(train_loader)
num_iter = 0
for imgs, labels  in train_loader:

    print(imgs.shape)
    print(labels)
    # print(train_data.classes)
    writer.add_images("ant-bees",imgs,num_iter)
    num_iter = num_iter +1

writer.close()


用tensorboard显示,batch_size= 10,因此每次迭代有10张图像

标签为:

相关推荐

  1. 深度学习-Pytorch数据构造和分批

    2024-07-20 07:26:01       51 阅读
  2. 深度学习-4-PyTorch中的数据Dataset和DataLoader

    2024-07-20 07:26:01       14 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-20 07:26:01       52 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-20 07:26:01       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-20 07:26:01       45 阅读
  4. Python语言-面向对象

    2024-07-20 07:26:01       55 阅读

热门阅读

  1. [React]利用Webcomponent封装React组件

    2024-07-20 07:26:01       13 阅读
  2. CSS3 教程

    2024-07-20 07:26:01       14 阅读
  3. [python] 利用opencv显示对比试验效果

    2024-07-20 07:26:01       13 阅读
  4. vue中的some方法使用@1@

    2024-07-20 07:26:01       14 阅读
  5. RK3328 Debian安装OpenMediaVault

    2024-07-20 07:26:01       16 阅读
  6. list容器

    2024-07-20 07:26:01       14 阅读
  7. http 协议中GET如何传递参数(Query String)?

    2024-07-20 07:26:01       12 阅读
  8. 浏览器的缓存

    2024-07-20 07:26:01       17 阅读
  9. 记录贴-idea导入别人的项目

    2024-07-20 07:26:01       14 阅读
  10. 【SpringBoot】分页查询

    2024-07-20 07:26:01       17 阅读
  11. 第九十六周周报

    2024-07-20 07:26:01       15 阅读
  12. Webserver笔记

    2024-07-20 07:26:01       16 阅读