PyTorch中的 Dataset、DataLoader 和 enumerate()

PyTorch:关于Dataset,DataLoader 和 enumerate()

本博文主要参考了 Pytorch中DataLoader的使用方法详解pytorch:关于enumerate,Dataset和Dataloader 两篇文章进行总结和归纳。

DataLoader 隶属 PyTorch 中 torch.utils.data 下的一个类,任何继承 torch.utils.data.Data 类的子类均需要重载__getitem__()及__len__()两个函数,且子类在__init__()函数产生的数据路径,将作为 DataLoader 参数 DataSets 的实参。该类将自定义的 Dataset 根据 batch size 大小、是否 shuffle 等封装成一个 Batch Size 大小的 Tensor,用于后面的训练。

Dataset 类构建

在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。这里的 Dateset 可以指整个数据集,也可以是训练集,测试集等。

class Dataset:
    def __init__(self,...):
        ...
    def __len__(self,...):
        return n
    def __getitem__(self,item):
        return data[item]

正常情况下,该数据集是要继承 Pytorch 中 Dataset 类的,但实际操作中,即使不继承,数据集类构建后仍可以用 Dataloader() 加载的。

在dataset类中,len(self)返回数据集中数据的总个数,getitem(self,item)表示每次返回第 item 条(个)数据。
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个 item
③__getitem__:返回一条(个)训练样本的数据,并将其转换成 tensor

在 dataset 实例化时一般要传入数据集的路径,一般在__init__() 函数中指定数据集路径等相关信息(可以通过相关路径读取包含图像名称、标签等相关信息的 json 或者 csv 等类型的文件);通过__getitem__(self,item) 得到对应的图像并将进行 transform 转换(缩放、裁剪、转换成 tensor 等操作),最终以 tensor 的形式返回。

DataLoader 使用

在构建 Dataset 类后,即可使用 DataLoader 加载。DataLoader 中常用参数如下:

  1. dataset:需要载入的数据集,如前面构造的 dataset 类。
  2. batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个 batch 进行训练。
  3. shuffle:是否在打乱数据集样本顺序。True 为打乱,False 反之。
  4. num_workers:这个参数决定了有几个进程来处理 data loading。0 意味着所有的数据都会被 load 进主进程。(默认num_workers=0,在 Windows 系统下需要设置为 0
  5. drop_last:是否舍去最后一个batch的数据(很多情况下数据总数 N 与 batch size 不整除,导致最后一个 batch 不为 batch size)。True 为舍去,False 反之。

注意:使用 DataLoader 读取数据时,为了加快效率,所以使用了多个线程,即 num_workers 不为0,在 windows 系统下报如下的错误。
RuntimeError: Couldn’t open shared file mapping: <torch_16716_3565374679>, error code: <1455>

DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 

参照 DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support() 教程中提到,在 https://github.com/pytorch/pytorch/pull/5585 中给出了一些官方解释,应该是 Windows下的一些线程文件读写的问题。
在 Windows 上,FileMapping 对象应必须在所有相关进程都关闭后,才能释放。启用多线程处理时,子进程将创建 FileMapping,然后主进程将打开它。 之后当子进程将尝试释放它的时候,因为父进程还在引用,所以它的引用计数不为零,无法释放。 但是当前代码没有提供在可能的情况下再次关闭它的机会。这个版本官方说 num_workers=1 是可以用的,更多的线程还在解决,不过现在即便是用 2 个子进程也已经可以了。

加载数据的过程

pytorch 中加载数据的顺序是:

  1. 创建一个 dataset 对象
  2. 创建一个 dataloader 对象
  3. 循环 dataloader 对象,将 data, label 拿到模型中去训练

enumerate() 函数

在对 Dataloader 进行读取时,通常使用 enumerate() 函数,enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标。调用 enumerate(dataloader) 时每次都会读出一个 batch_size 大小的数据。例如,数据集中总共包含 245 张图像,train_loader = dataloader(dataset, batch_size=32, drop_last=True) 被实例化时,经过以下代码后输出的 count 为 224(正好等于32*7),而多出来的 245-224=21 张图像不够一个 batch 因此被 drop 掉了。下面展示了如何从 dataloader 中通过 enumerate() 返回一个batch_size的数据。

for k, images, target in enumerate(dataloader):

其中,k代表下标值,images, target 代表可遍历的数据对象。因为 enumerate(dataloader) 一次会返回一个 batch 的数据,所以返回的 images 为 batch_size 长度的list,target 也为 batch_size 长度的 list。

通常,dataloader 里包含很多个数据对象,那么我们应该怎么保证 batch 就是我们所需要的数据呢?通过 Dataset 的定义可以实现我们需要的数据。Dataset 是用来定义数据从哪里读取,以及如何读取的问题,通过重写 Dataset 抽象类的__getitem__()函数。enumerate(dataloader) 得到的数据就是 __getitem__() 函数返回的数据,只不过 enumerate(dataloader) 一次会得到 batch_size 个不同 item 的数据组成的 list。

def __getitem__(self, item):
	images = self.data[item]
	target = self.label[item]
	return images, target

返回 item 对应的数据,就是 enumerate(dataloader) 得到的数据的一部分。

def __len__(self):
	return len(self.data)

返回 dataset 中总的数据个数,用于控制返回多少个 batch 的数据,enumerate(dataloader) 一次会返回 batch_size 大小的 list。

Reference

Pytorch中DataLoader的使用方法详解
pytorch:关于enumerate,Dataset和Dataloader
DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support()

相关推荐

  1. PyTorch Dataset、DataLoader enumerate()

    2023-12-30 10:42:02       39 阅读
  2. Pythonenumerate函数详解

    2023-12-30 10:42:02       11 阅读
  3. [AIGC] 深入浅出 Python`enumerate`函数

    2023-12-30 10:42:02       9 阅读
  4. 在 Swift , enumerated() 有哪些常用使用方式 ?

    2023-12-30 10:42:02       21 阅读
  5. Pytorchresizereshape

    2023-12-30 10:42:02       25 阅读
  6. PyTorchAOTAutograd、PrimTorchTorchInductor

    2023-12-30 10:42:02       26 阅读
  7. pytorchdatasetdataloader

    2023-12-30 10:42:02       17 阅读
  8. c# Enumerable<T>GroupJoin方法Join用法区别

    2023-12-30 10:42:02       12 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-30 10:42:02       14 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-30 10:42:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-30 10:42:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-30 10:42:02       18 阅读

热门阅读

  1. Linux添加一个指令代替指定指令

    2023-12-30 10:42:02       37 阅读
  2. 79. Word Search

    2023-12-30 10:42:02       38 阅读
  3. 蓝桥杯python比赛历届真题99道经典练习题 (8-12)

    2023-12-30 10:42:02       30 阅读
  4. 结构体--高考数组

    2023-12-30 10:42:02       35 阅读
  5. STM32传输FPGA业务

    2023-12-30 10:42:02       33 阅读
  6. Windows下Qt使用MSVC编译出现需要转为unicode的提示

    2023-12-30 10:42:02       33 阅读