从不平凡Image文件夹,自定义读取img和label,构造dataset

问题描述

torchvision.datasets.ImageFolder 假定:子文件名=子文件夹的图像的标签
但在KDEF文件夹中,子文件夹下有所有的类,不宜用ImageFolder读取path来得到dataset

在这里插入图片描述

My 实现:

# how to build dataset?
from torch.utils import data
from torchvision import transforms, utils
import os
from PIL import Image

to_exp={
    'SU':0, 'AF':1, 'DI':2, 'HA':3, 'SA':4, 'AN':5, 'NE':6
}

mytransform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

class PoseDataset(data.Dataset):
    def __init__(self, path):
        super(PoseDataset,self).__init__()
        root_dir = os.path.join(os.getcwd(),path)
        dir_list = os.listdir(root_dir)
        self.data = []
        for dir_name in dir_list:
            tmp_dir = os.path.join(os.path.join(root_dir,dir_name))
            img_list = os.listdir(tmp_dir)
            for img_name in img_list:
                t_img = mytransform(Image.open(os.path.join(tmp_dir,img_name)).convert("RGB"))
                print('\r{},{},{}'.format(img_name,type(t_img),t_img.shape),end='')
                # import pdb; pdb.set_trace()
                id_label1 = to_exp[img_name[4:6]]
                id_label2 = int(img_name[2:4])
                self.data.append((t_img,id_label1,id_label2))
    
    def __getitem__(self,index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

mydataset = PoseDataset('data\\KDEF\\AF')

在这里插入图片描述

  • 其中print('\r{},{},{}'.format(img_name,type(t_img),t_img.shape),end='')在一行连续更新,
  • 用PIL.Image.open()打开image file,并convert为RGB格式

之后就可用torch.util.data.DataLoader成批读取img tensor和label用于训练了

train_loader = data.DataLoader(mydataset,batch_size=16,drop_last=True,shuffle=True)
for image,label1,label2 in train_loader:
    print(image.shape)
    print(label1)
    print(label2)

在这里插入图片描述

  • torch.util.data.DataLoader自动转label list为label tensor~🙂😊

划分dataset with data.random_split

trainset,validset=torch.utils.data.random_split(dataset1,[train_size,valid_size])

相关推荐

  1. CSS中的imgbackground-image

    2024-05-04 22:48:03       35 阅读
  2. 定义shell工具函数之pull_image()

    2024-05-04 22:48:03       30 阅读
  3. 机器学习复习(9)——定义dataset

    2024-05-04 22:48:03       22 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-04 22:48:03       17 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-04 22:48:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-04 22:48:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-04 22:48:03       18 阅读

热门阅读

  1. 高等代数复习:特征值

    2024-05-04 22:48:03       12 阅读
  2. 深入理解Linux 内核 内存管理(上)

    2024-05-04 22:48:03       10 阅读
  3. Linux中快速清空文件而不是删除

    2024-05-04 22:48:03       11 阅读
  4. 深入理解 ICMP 协议

    2024-05-04 22:48:03       12 阅读
  5. Rust 动态数组Vector

    2024-05-04 22:48:03       11 阅读
  6. Ruby递归目录文件的又一种方法

    2024-05-04 22:48:03       10 阅读
  7. 【leetcode】滑动窗口题目总结

    2024-05-04 22:48:03       13 阅读
  8. 初始MySQL

    2024-05-04 22:48:03       9 阅读
  9. Django框架之模板层

    2024-05-04 22:48:03       10 阅读