pytorch-分类-检测-分割的dataset和dataloader创建

1.前言

        在PyTorch中,DatasetDataLoader是两个重要的工具,用于构建输入数据的管道。

(1)Dataset是一个抽象类,表示数据集,需要实现__len____getitem__方法。

(2)DataLoader是一个可迭代的数据加载器,它封装了数据集的加载、批处理、打乱和并行加载等功能。

2.分类任务创建DatasetDataLoader

        (1)对于分类任务,Dataset需要返回图像和对应的标签

from torch.utils.data import Dataset  
from PIL import Image  
import os  
import torch  
  
class ClassificationDataset(Dataset):  
    def __init__(self, root_dir, transform=None):  
        self.transform = transform  
        self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]  
        self.labels = [...]  # 这里应该是与图像对应的标签列表  
  
    def __len__(self):  
        return len(self.images)  
  
    def __getitem__(self, idx):  
        img_path = self.images[idx]  
        image = Image.open(img_path).convert('RGB')  
        label = self.labels[idx]  
          
        if self.transform:  
            image = self.transform(image)  
          
        return image, label

        (2)DataLoader加载数据

from torch.utils.data import DataLoader  
  
transform = ...  # 这里定义你的数据预处理流程  
dataset = ClassificationDataset(root_dir='path_to_your_data', transform=transform)  
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3.检测任务创建DatasetDataLoader

        (1)Dataset需要返回图像和对应的边界框信息

class DetectionDataset(Dataset):  
    def __init__(self, root_dir, transform=None):  
        self.transform = transform  
        self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]  
        self.annotations = [...]  # 这里应该是与图像对应的边界框信息列表  
  
    def __len__(self):  
        return len(self.images)  
  
    def __getitem__(self, idx):  
        img_path = self.images[idx]  
        image = Image.open(img_path).convert('RGB')  
        boxes = self.annotations[idx]  # 这些是边界框信息  
  
        if self.transform:  
            image, boxes = self.transform(image, boxes)  
          
        return image, boxes

 (2)DataLoader加载数据

dataloader = DataLoader(DetectionDataset(root_dir='path_to_your_data', transform=transform), batch_size=2, shuffle=True)

4.分割任务创建DatasetDataLoader

(1)Dataset需要返回图像和对应的分割掩码

class SegmentationDataset(Dataset):  
    def __init__(self, root_dir, transform=None):  
        self.transform = transform  
        self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]  
        self.masks = [...]  # 这里应该是与图像对应的分割掩码列表  
  
    def __len__(self):  
        return len(self.images)  
  
    def __getitem__(self, idx):  
        img_path = self.images[idx]  
        mask_path = self.masks[idx]  
        image = Image.open(img_path).convert('RGB')  
        mask = Image.open(mask_path).convert('L')  # 假设掩码是灰度图  
  
        if self.transform:  
            image, mask = self.transform(image, mask)  
          
        return image, mask

(2)DataLoader加载数据

dataloader = DataLoader(SegmentationDataset(root_dir='path_to_your_data', transform=transform), batch_size=4, shuffle=True)

在PyTorch的DatasetDataLoader框架中,idx(或称为索引)是通过迭代DataLoader时自动生成的。当你创建一个DataLoader实例,并在训练循环中迭代它时,DataLoader会内部调用Dataset__getitem__方法,并自动为你提供索引idx

相关推荐

  1. pytorch-分类-检测-分割datasetdataloader创建

    2024-04-01 06:40:01       45 阅读
  2. PyTorch DatasetDataLoader enumerate()

    2024-04-01 06:40:01       56 阅读
  3. pytorchdatasetdataloader

    2024-04-01 06:40:01       42 阅读
  4. PytorchDatasetDataLoader注意事项

    2024-04-01 06:40:01       43 阅读
  5. PytorchDatasetDataLoader

    2024-04-01 06:40:01       35 阅读
  6. PyTorch DatasetDataLoader长度

    2024-04-01 06:40:01       43 阅读
  7. 深度学习-4-PyTorch数据加载器DatasetDataLoader

    2024-04-01 06:40:01       20 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-01 06:40:01       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-01 06:40:01       101 阅读
  3. 在Django里面运行非项目文件

    2024-04-01 06:40:01       82 阅读
  4. Python语言-面向对象

    2024-04-01 06:40:01       91 阅读

热门阅读

  1. JVM堆栈详解

    2024-04-01 06:40:01       36 阅读
  2. 20240323-2-决策树面试题DecisionTree

    2024-04-01 06:40:01       27 阅读
  3. ubuntu22.04忘记用户密码

    2024-04-01 06:40:01       33 阅读
  4. 【内网离线环境】搭建本地YUM源

    2024-04-01 06:40:01       42 阅读
  5. Excel中文显示问号

    2024-04-01 06:40:01       44 阅读
  6. 大型语言模型可以“在两年内彻底改变金融业”

    2024-04-01 06:40:01       35 阅读
  7. 【布式事务】分布式事务和分布式指导理论简介

    2024-04-01 06:40:01       36 阅读
  8. 182. 查找重复的电子邮箱

    2024-04-01 06:40:01       38 阅读
  9. 设计模式(10):享元模式

    2024-04-01 06:40:01       30 阅读
  10. 关于rabbitmq的prefetch机制

    2024-04-01 06:40:01       40 阅读