问题描述
torchvision.datasets.ImageFolder 假定:子文件夹名=该子文件夹中图像的标签
但在KDEF文件夹中,每个子文件夹下都包含所有类别的图像(标签需要从文件名中解析),因此不宜用ImageFolder读取path来得到dataset
我的实现:
# how to build dataset?
from torch.utils import data
from torchvision import transforms, utils
import os
from PIL import Image
# Expression code -> integer class index. PoseDataset reads the code from
# characters [4:6] of each image file name. The order of the tuple below
# fixes which index each code gets.
# Presumably these are the standard KDEF codes (SU surprised, AF afraid,
# DI disgusted, HA happy, SA sad, AN angry, NE neutral) — confirm against
# the dataset documentation.
to_exp = {
    code: index
    for index, code in enumerate(('SU', 'AF', 'DI', 'HA', 'SA', 'AN', 'NE'))
}
# Preprocessing applied to every image loaded by PoseDataset.
mytransform = transforms.Compose([
    # Fixed 224x224 output; both sides are forced, so aspect ratio is not kept.
    transforms.Resize(size=(224, 224)),
    # PIL image -> tensor.
    transforms.ToTensor(),
])
class PoseDataset(data.Dataset):
    """In-memory image dataset labelled by expression code and identity number.

    Expected layout (``path`` is resolved relative to the current working
    directory)::

        <path>/<sub-directory>/<image file>

    Every sub-directory of ``path`` is scanned. Each regular file in it is
    opened with PIL, converted to RGB and passed through ``transform``. All
    images are loaded eagerly in ``__init__`` and kept in memory.

    Labels are parsed from the file name:

    * characters ``[4:6]`` form an expression code looked up in ``to_exp``
      (e.g. ``'HA'`` -> 3);
    * characters ``[2:4]`` are parsed as an integer identity number.

    NOTE(review): this slicing presumably matches KDEF-style names such as
    ``AF01HAS.JPG`` — confirm before using other naming schemes.

    Each item is a tuple ``(image, expression_label, identity_label)``.
    """

    def __init__(self, path, *, transform=None, verbose=True):
        """Scan ``path`` and load every image found in its sub-directories.

        Args:
            path: Dataset root. It is joined onto ``os.getcwd()``, so an
                absolute path is used as is.
            transform: Callable applied to each RGB PIL image. Defaults to
                the module-level ``mytransform``.
            verbose: If true, print a single self-overwriting progress line.

        Raises:
            KeyError: a file name's characters ``[4:6]`` are not in ``to_exp``.
            ValueError: a file name's characters ``[2:4]`` are not an integer.
            PIL errors propagate for files that are not readable images.
        """
        super().__init__()
        if transform is None:
            transform = mytransform
        root_dir = os.path.join(os.getcwd(), path)
        self.data = []
        # os.listdir order is arbitrary; sorting makes dataset indices
        # reproducible across runs and machines.
        for dir_name in sorted(os.listdir(root_dir)):
            sub_dir = os.path.join(root_dir, dir_name)
            # Stray non-directory entries at the root would otherwise crash
            # the os.listdir() call below.
            if not os.path.isdir(sub_dir):
                continue
            for img_name in sorted(os.listdir(sub_dir)):
                img_path = os.path.join(sub_dir, img_name)
                if not os.path.isfile(img_path):
                    continue
                # The context manager closes the file handle. convert() returns
                # an in-memory copy that stays valid after the file is closed.
                with Image.open(img_path) as img:
                    rgb = img.convert("RGB")
                t_img = transform(rgb)
                if verbose:
                    # '\r' with end='' keeps rewriting the same console line.
                    print('\r{},{},{}'.format(img_name, type(t_img),
                                              getattr(t_img, 'shape', None)),
                          end='')
                exp_label = to_exp[img_name[4:6]]
                id_label = int(img_name[2:4])
                self.data.append((t_img, exp_label, id_label))
        if verbose and self.data:
            # Terminate the progress line so later output starts on a new line.
            print()

    def __getitem__(self, index):
        """Return the ``(image, expression_label, identity_label)`` tuple at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data)
# Build the dataset from <cwd>/data/KDEF/AF. os.path.join picks the platform's
# separator; a hard-coded 'data\\KDEF\\AF' string only works as a nested path
# on Windows (on POSIX it is a single file name containing backslashes).
mydataset = PoseDataset(os.path.join('data', 'KDEF', 'AF'))
- 其中下面这行代码:
print('\r{},{},{}'.format(img_name,type(t_img),t_img.shape),end='')
会在同一行持续刷新输出进度(`\r` 把光标移回行首,`end=''` 使其不换行);
- 用PIL.Image.open()打开image file,并convert为RGB格式
之后就可以用torch.utils.data.DataLoader成批读取img tensor和label,用于训练了
# Batches of 16 samples, reshuffled every epoch; a final incomplete batch is
# dropped. The default collate function stacks the images and turns each
# label column into a tensor.
train_loader = data.DataLoader(
    mydataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
)
for batch in train_loader:
    image, label1, label2 = batch
    # Print the image batch shape, then the expression and identity labels.
    for item in (image.shape, label1, label2):
        print(item)
- torch.utils.data.DataLoader(默认的collate_fn)会自动把一个batch的label list转为label tensor~🙂😊
用data.random_split划分dataset(注意:传入的各子集长度之和必须等于len(dataset))
# Split the dataset into training and validation subsets using the `data`
# alias (torch.utils.data) imported above; the bare name `torch` is not
# imported in this file.
# random_split requires the lengths to sum to len(dataset), so valid_size is
# taken as the remainder rather than rounded independently.
# NOTE(review): assumes `mydataset` (built above) is the dataset to split —
# confirm. The 0.8 train fraction is an illustrative choice; adjust as needed.
train_fraction = 0.8
train_size = int(train_fraction * len(mydataset))
valid_size = len(mydataset) - train_size
trainset, validset = data.random_split(mydataset, [train_size, valid_size])