Pytorch中Dataset和dadaloader的理解

不同的数据集在形式上千差万别,为了能够统一用于模型的训练,Pytorch框架下定义了一个dataset类和一个dataloader类。

dataset用于获取数据集中的样本,dataloader 用于抽取部分样本用于训练。比如说一个用于分割任务的图像数据集的结构如图1所示,一个样本由原图像和对应的mask组成。

图1 典型数据集的结构

为了获取数据集,典型的代码如下

from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms

# 定义数据集
train_data_dir = 'dataset/train'
train_GT_dir = 'dataset/train_GT'

class MyData(Dataset):
    def __init__(self, imgdir, maskdir,transform):
        self.imgdir = imgdir
        self.maskdir = maskdir
        self.transform = transform
        self.img_list = os.listdir(self.imgdir)
        self.mask_list= os.listdir(self.maskdir)
        self.img_list.sort()
        self.mask_list.sort()

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        mask_name =self.mask_list[idx]
        img_item_path = os.path.join(self.imgdir, img_name)
        mask_item_path =os.path.join(self.maskdir,mask_name)

        img =Image.open(img_item_path)
        mask =Image.open(mask_item_path)

        img = self.transform(img)
        mask = self.transform(mask)

        return img, mask

    def __len__(self):
        assert len(self.img_list) == len(self.mask_list)
        return len(self.img_list)

if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_data_dir = 'dataset/train'
    train_GT_dir = 'dataset/train_GT'
    dataset = MyData(train_data_dir, train_GT_dir ,transform)
    dataloader = DataLoader(dataset, batch_size=4, num_workers=0)
    for step, (img,mask) in enumerate(dataloader):
        print(step)
        print(img.shape)
        print(mask.shape)
        if step>0:
            break

程序运行的结果如下:

返回了一个batch的img 和mask 的尺寸,说明数据集抽取成功了.

在建立数据集的过程中需用重写__getitem()__和__len()__方法即可。

相关推荐

  1. PyTorch DatasetDataLoader enumerate()

    2024-01-28 11:28:02       56 阅读
  2. pytorchdatasetdataloader

    2024-01-28 11:28:02       42 阅读
  3. 深度学习-4-PyTorch数据加载器DatasetDataLoader

    2024-01-28 11:28:02       20 阅读
  4. PytorchDatasetDataLoader注意事项

    2024-01-28 11:28:02       43 阅读
  5. pytorch-分类-检测-分割datasetdataloader创建

    2024-01-28 11:28:02       44 阅读
  6. PytorchDatasetDataLoader

    2024-01-28 11:28:02       35 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-28 11:28:02       91 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-28 11:28:02       97 阅读
  3. 在Django里面运行非项目文件

    2024-01-28 11:28:02       78 阅读
  4. Python语言-面向对象

    2024-01-28 11:28:02       88 阅读

热门阅读

  1. 【算法题】77. 组合

    2024-01-28 11:28:02       49 阅读
  2. org.springframework.util.StringUtils 下StringUtils工具类

    2024-01-28 11:28:02       40 阅读
  3. uniapp-app使用富文本编辑器editor

    2024-01-28 11:28:02       48 阅读
  4. RUST笔记: 动态链接库的创建和使用

    2024-01-28 11:28:02       69 阅读
  5. Springboot多数据源连接

    2024-01-28 11:28:02       70 阅读