[PyTorch Basics] 5. Using DataLoader

Introduction

dataset: the dataset, which provides the data
dataloader: the data loader, which fetches data from the dataset and feeds it into the neural network
When taking data from the dataset, the way samples are drawn is controlled by the parameters you set on the DataLoader
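To make this division of labor concrete, here is a minimal sketch (the SquaresDataset class below is invented for illustration): the Dataset only answers "what is sample i", while the DataLoader decides how samples are drawn and batched.

import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):  # hypothetical toy dataset
    def __len__(self):
        return 10  # ten samples in total

    def __getitem__(self, idx):
        # a Dataset only defines how to fetch one sample by index
        return torch.tensor([float(idx)]), idx * idx

loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=True)
for xs, ys in loader:  # the DataLoader handles batching and shuffling
    print(xs.shape, ys)  # torch.Size([4, 1]) on full batches, [2, 1] on the last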

Usage

from torch.utils.data import DataLoader
Parameters (most of them have default values):

 Args:
        dataset (Dataset): the dataset to load from.
        batch_size (int, optional): how many samples per batch to load (default: ``1``). I.e., how much data is loaded at each step.
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).  Whether to reshuffle the data each epoch (defaults to ``False``; usually set to ``True``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)    Whether to load with multiple subprocesses (values > 0 can raise errors on Windows).
        collate_fn (Callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)   Whether to drop the remainder when the dataset size does not divide evenly.
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
            ``None``, the default `multiprocessing context`_ of your operating system will
            be used. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            ``base_seed`` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers batches prefetched across all workers. (default value depends
            on the set value for num_workers. If value of num_workers=0 default is ``None``.
            Otherwise, if value of ``num_workers > 0`` default is ``2``).
        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
            the worker processes after a dataset has been consumed once. This keeps the
            workers' `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
            ``True``.
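Two of these parameters are worth trying on a tiny dataset before moving on. A minimal sketch (using a throwaway TensorDataset, not the CIFAR-10 data used below) showing how drop_last changes the number of batches, and how generator makes shuffling reproducible:

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10).float())  # 10 samples, batch_size 4 -> 4 + 4 + 2

print(len(DataLoader(data, batch_size=4, drop_last=False)))  # 3: the short batch is kept
print(len(DataLoader(data, batch_size=4, drop_last=True)))   # 2: the short batch is dropped

g = torch.Generator().manual_seed(0)  # fixed seed -> identical shuffle order on every run
loader = DataLoader(data, batch_size=4, shuffle=True, generator=g)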

Practice

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# prepare the test dataset
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor())

# wrap the test set in a DataLoader
test_loader = DataLoader(test_data, shuffle=True, batch_size=64, num_workers=0, drop_last=False)

# inspect a single sample taken directly from the dataset
img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32])
print(target)     # class index of the first test image

writer = SummaryWriter('logs')
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        # log the whole batch as an image grid, one tag per epoch
        writer.add_images("epoch: {}".format(epoch), imgs, step)
        step = step + 1

writer.close()  # note the parentheses: without them the writer is never actually closed
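One caveat on the num_workers note above: with num_workers > 0, each worker subprocess re-imports the main script, so on Windows (where multiprocessing uses the spawn start method) the DataLoader loop must live under a main guard or the script errors out. A minimal sketch:

import torchvision
from torch.utils.data import DataLoader

if __name__ == "__main__":  # required on Windows when num_workers > 0
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False,
                                             transform=torchvision.transforms.ToTensor())
    loader = DataLoader(test_data, batch_size=64, num_workers=2)
    for imgs, targets in loader:
        pass  # batches are produced by the worker subprocesses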

With shuffle=True, the order of the data differs from epoch to epoch, which you can see by comparing the image grids of the two epochs in TensorBoard.
With drop_last=False, the samples that do not fill a complete batch are kept, so the last step of each epoch holds fewer than batch_size images: the CIFAR-10 test set has 10000 images, and 10000 = 156 * 64 + 16, so the final batch contains only 16 images.
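A quick way to confirm the batch count and the size of the last batch without TensorBoard (a small extra check, not part of the original walkthrough):

sizes = [imgs.shape[0] for imgs, _ in test_loader]
print(len(sizes))   # 157 batches: 156 full batches plus the leftover one
print(sizes[-1])    # 16, since drop_last=False keeps the incomplete batch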
