【Pytorch】1.读取训练数据集

导入Dataset类

from torch.utils.data import Dataset
# 注意是Dataset(大写)的才是类

通过jupyter我们可以阅读一下Dataset类的具体使用方法

help(Dataset)
# 或者直接
Dataset??

在这里插入图片描述
我们可以看到具体对Dataset类的解释
从蓝色字体我们可以得出

  • 所有的代表map的数据集应该继承这个类
  • 所有继承的子类都重写__getitem__这个方法,这个方法支持获取数据样本中的指定键
  • 同时子类也要重写__len__这个方法返回数据集大小
  • 子类可以重写__getitem__,来加速样本生成
    也就是说我们要重写__getitem__方法与__len__方法

其他导入包

from PIL import Image  # 主要用于图像的操作
import os  # 文件操作

Image用于将目标路径的文件转化为可以打开的图片变量
os用于文件操作

  • listdir对目标文件夹中的文件名称列成列表
  • os.path.join用于将两个地址进行拼接

MyData类的定义

class MyData(Dataset):  # 创建一个MyData类,同时继承Dataset类
    def __init__(self, root_dir, label_dir):  # 类似于c++的构造函数
        # root_dir 一般设置为训练集文件夹的地址(train)
        # label_dir 一般设置为分类文件夹的地址(ants)
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)  # 这个函数的作用是将root_dir的地址与label_dir的地址拼接起来
        self.img_path = os.listdir(self.path)  # 将特定文件夹地址(path)中的所有文件列成一个list

    def __getitem__(self, index):  # 重写父类的方法
        img_name = self.img_path[index]  # 获取对应下标的图片名
        img_item_path = os.path.join(self.path, img_name)  # 获取图片路径
        img = Image.open(img_item_path)  # 根据图片路径打开图片
        # img.show()    展示图片
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

类的实例化

# root_dir 一般设置为训练集文件夹的地址(train)
# label_dir 一般设置为分类文件夹的地址(ants)
root_dir = "hymenoptera_data/train"
ant_label_dir = "ants"
bee_label_dir = "bees"
# 生成对应训练集的图片、标签列表
ants_dataset = MyData(root_dir, ant_label_dir)
bees_dataset = MyData(root_dir, bee_label_dir)

# 列表相加,前提是必须重载__len__方法
train_dataset = ants_dataset + bees_dataset

在这里插入图片描述

源码链接

github

参考资料

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关推荐

  1. python/pytorch读取数据

    2024-05-09 21:14:04       56 阅读
  2. 大语言模型训练数据1

    2024-05-09 21:14:04       59 阅读
  3. 读书笔记】训练自己的数据yolov8

    2024-05-09 21:14:04       27 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-05-09 21:14:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-05-09 21:14:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-05-09 21:14:04       82 阅读
  4. Python语言-面向对象

    2024-05-09 21:14:04       91 阅读

热门阅读

  1. git stash技巧

    2024-05-09 21:14:04       32 阅读
  2. Docker技能

    2024-05-09 21:14:04       40 阅读
  3. 软件开发的必备步骤

    2024-05-09 21:14:04       29 阅读
  4. 哈夫曼编码python算法实现(代码版)

    2024-05-09 21:14:04       35 阅读
  5. 鸿蒙原生应用元服务开发-Web上传文件

    2024-05-09 21:14:04       31 阅读
  6. 2023-2024年电力行业报告合集(精选69份)

    2024-05-09 21:14:04       42 阅读
  7. linux不同引号的含义(随手记)

    2024-05-09 21:14:04       33 阅读
  8. 单选按钮选中后取消

    2024-05-09 21:14:04       37 阅读