快速入门Torch读取自定义图像数据集

学习新技术当然首先要看官网了

所有数据集都是torch.utils.data.Dataset的子类,即实现了__getitem__和__len__方法。因此,它们都可以传递给torch.utils.data. dataloader,它可以使用torch并行加载多个样本。多处理工人。例如:

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

就这???官方提供了许多内置好的数据集,但是我需要自定义啊!!!

还好官方上面文字说需要继承Dataset这个抽象类,实现__getitem__和__len__方法就ok了。

class CatDogDataSet(Dataset):
	def __init__(self):
		pass
		
    def __getitem__(self, index):
    	pass

    def __len__(self):
    	pass

我是谁?我在哪?我在干什么?完全不知道如何实现好吧

我知道ImageNet是从网上拉下来zip包解压后处理图片读取图片的,不妨看看ImageNet是如何实现的class ImageNet(ImageFolder):ImageFolder!!!这个类让我有预感我很快就可以copy了。果然datasets.ImageFolder(root)传入数据根目录且符合下面的格式就可以读取自定义数据集。

class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

完结撒花?我的数据集格式和ImageFolder需要的格式不一样

在这里插入图片描述

最简单的方法当然是写个脚本整理为官方需求的格式,但是我不忘初心,说自定义就是自定义,copy99%也要自定义,而且移动数据的成本高,改改代码读取逻辑就能完成当然要改代码了

源码中find_classes方法,根据目录名定义classes变量改为classes = list(frozenset([i.split('.')[0] for i in os.listdir(directory)]))就可以了

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {
     directory}.")

    class_to_idx = {
   cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

再看数据集部分

    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        """
        源码是判断目录是否与当前target一致,一直则读取这一目录
        target_dir = os.path.join(directory, target_class)
		if not os.path.isdir(target_dir):
			continue
        """
        for root, _, fnames in sorted(os.walk(directory, followlinks=True)):
            for fname in sorted(fnames):
            	# TODO: 在此处添加判断,当前文件名是否包含target
                if target_class in fname:
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

献上完整自定义数据集代码

import os
from typing import Dict, Optional, Tuple, Callable, List, Union, cast

from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    classes = list(frozenset([i.split('.')[0] for i in os.listdir(directory)]))
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {
     directory}.")

    class_to_idx = {
   cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,
        extensions: Optional[Union[str, Tuple[str, ...]]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        # target_dir = os.path.join(directory, target_class)
        # if not os.path.isdir(target_dir):
        #     continue
        for root, _, fnames in sorted(os.walk(directory, followlinks=True)):
            for fname in sorted(fnames):
                if target_class in fname:
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {
     ', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {
     extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class CatDogLoader(ImageFolder):
    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(root,
                         transform,
                         target_transform,
                         is_valid_file=is_valid_file)
        classes, class_to_idx = self.find_classes(self.root)
        self.samples = self.make_dataset(self.root, class_to_idx, IMG_EXTENSIONS if is_valid_file is None else None,
                                         is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        return find_classes(directory)

    def make_dataset(
            self,
            directory: str,
            class_to_idx: Dict[str, int],
            extensions: Optional[Tuple[str, ...]] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        if class_to_idx is None:
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

相关推荐

  1. pytorch图像数据定义

    2024-01-16 10:30:03       17 阅读
  2. DataLoader定义数据制作

    2024-01-16 10:30:03       18 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-01-16 10:30:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-16 10:30:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-16 10:30:03       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-16 10:30:03       20 阅读

热门阅读

  1. linux centos7 django uwsgi 部署

    2024-01-16 10:30:03       32 阅读
  2. 15.单例模式

    2024-01-16 10:30:03       28 阅读
  3. 重磅!2024版一建新教材开始预售!(新大纲版)

    2024-01-16 10:30:03       30 阅读
  4. 2024年Top 10的人工智能岗位及如何准备

    2024-01-16 10:30:03       33 阅读
  5. Mysql

    2024-01-16 10:30:03       33 阅读
  6. leetcode热题100.两数之和

    2024-01-16 10:30:03       30 阅读
  7. show processlist 显示的MySQL语句不全的解决方法

    2024-01-16 10:30:03       34 阅读
  8. K8s面试题——基础篇1

    2024-01-16 10:30:03       26 阅读
  9. LeetCode——82. 删除排序链表中的重复元素II

    2024-01-16 10:30:03       32 阅读
  10. 【uniapp-小程序-给video添加水印】

    2024-01-16 10:30:03       27 阅读
  11. linux不同场景下修改文件名的五种方法

    2024-01-16 10:30:03       34 阅读