【Pytorch使用自制数据集,Dataloader】

数据集结构
在这里插入图片描述

话不多说,直接上核心代码

myDataset.py

from collections import Counter
from torch.utils.data import Dataset
import os
from PIL import Image

class MyDataset(Dataset):
    """
    读取自制的数据集
    args:
        - image_dir: 图片的地址
        - label_dir: 标签的地址
        - name: 数据集的名称
        - transform: 数据集的预处理
    """
    def __init__(self, image_dir:str,  label_dir:str, name:str, transform=None):
        self.img_dir = os.path.join(image_dir, name)
        self.label_dir = os.path.join(label_dir, name)
        self.name = name
        self.image_path = os.listdir(self.img_dir)
        self.label_path = os.listdir(self.label_dir)
        self.transform = transform
    """
    读取数据集
    args:
        - index: 数据集的索引
    return:
        - image: 图片
        - label: 图片的标签
    """
    def __getitem__(self, index:int)->tuple:
        # 获取图片的地址
        image = self.image_path[index]
        image = os.path.join(self.img_dir, image)
        # 获取图像
        image = Image.open(image)
        # 如果不是彩色图像,将下面的注释解开可以转换成彩色图像,不过图片的模样改变很大
        # if image.mode!= 'RGB':
        #     image = image.convert('RGB')
        # 获取label的地址
        index_path = self.label_path[index]
        index_path = os.path.join(self.label_dir, index_path)
        label = self.parseTxt(index_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
    """
    将txt文件解析成数字
    description:
        > 这里每个txt文件下可能有多个label,选出现最多的,如果你的txt里面只有一个label的话,想办法读取出来返回就行
    args:
        - label: txt文件的地址
    return:
        - label: 图片的标签
    """
    def parseTxt(self, label:str)->int:
        first_column = []
        with open(label, 'r') as f:
            for line in f.readlines():
                first_column.append(int(line.split()[0]))

        counter = Counter(first_column)
        return counter.most_common(1)[0][0]
    """
    获取数据集的长度
    """
    def __len__(self)->int:
        return len(self.image_path)

demo

train.py

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
# 导入加载数据集的类
from dataset import MyDataset
import os

root = os.path.join(os.getcwd(),'courseHomework','datasets')
transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Normalize((0.5), (0.5,))
])

train_dataset = MyDataset(root + '/images', root +'/labels', 'train', transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=False)

for step, data in enumerate(train_loader):
    imgs, labels = data
    print(imgs[0].shape)
    transforms.ToPILImage()(imgs[0]).show()
    break

大家结构和我不一样可以自由发挥

相关推荐

  1. [pytorch] 定义自己dataloader

    2023-12-06 05:26:12       36 阅读
  2. pytorch学习(四):Dataloader使用

    2023-12-06 05:26:12       13 阅读
  3. DataLoader自定义数据制作

    2023-12-06 05:26:12       18 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-06 05:26:12       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-06 05:26:12       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-06 05:26:12       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-06 05:26:12       20 阅读

热门阅读

  1. 【水】pytorch:torch.reshape和torch.Tensor.view的区别

    2023-12-06 05:26:12       43 阅读
  2. 45. 跳跃游戏 II

    2023-12-06 05:26:12       36 阅读
  3. jenkins清理僵尸任务和排队任务

    2023-12-06 05:26:12       35 阅读
  4. LINUX 下部署github仓库

    2023-12-06 05:26:12       32 阅读
  5. CSS基础概念之选择器类型

    2023-12-06 05:26:12       46 阅读
  6. html css 布局layout

    2023-12-06 05:26:12       33 阅读
  7. #Django事务#

    2023-12-06 05:26:12       38 阅读
  8. MyBatis-Plus使用步骤

    2023-12-06 05:26:12       40 阅读
  9. 记录 | linux下互换键盘的Ctrl和CapsLock键

    2023-12-06 05:26:12       35 阅读