自定义数据集 - Dataset

1. PASCAL VOC格式 划分训练集和验证集

import os
import random

def main():
    random.seed(0)  # 设置随机种子,保证随机结果可复现

    files_path = "./VOCdevkit/VOC2012/Annotations"  # 指定annotations目录
    assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path)

    val_rate = 0.5  # 定义划分比例

    files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
    files_num = len(files_name)
    val_index = random.sample(range(0, files_num), k=int(files_num*val_rate))
    train_files = []
    val_files = []
    for index, file_name in enumerate(files_name):
        if index in val_index:
            val_files.append(file_name)
        else:
            train_files.append(file_name)

    try:
        train_f = open("train.txt", "x")
        eval_f = open("val.txt", "x")
        train_f.write("\n".join(train_files))
        eval_f.write("\n".join(val_files))
    except FileExistsError as e:
        print(e)
        exit(1)

if __name__ == '__main__':
    main()

2. 自定义dataset

自定义自己的数据集,需要继承torch.utils.data.Dataset,并且实现__len____getitem__方法。

  • __len__:获取数据集的大小
  • __getitem__:返回数据信息

如果使用多GPU训练,还需要实现get_height_and_width方法,获取图像的高度和宽度,如果不实现这个方法,它就会载入所有的图片,去计算图片的高和宽,这样的话就比较耗时和占内存。

在这里插入图片描述

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree

# 继承torch.utils.data.Dataset
class VOCDataSet(Dataset):
    """读取解析PASCAL VOC2007/2012数据集"""
    """
    voc_root:数据集所在根目录
    year:数据集年份
    transforms:数据预处理方法
    text_name:读取txt文件名称 train.txt/val.txt
    """
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # 增加容错能力
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{
     year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{
     year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # 读取 train.txt 或 val.txt文件
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            # strip() 方法用于移除字符串头尾指定的字符
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        # 确认文件是否存在
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:
                print(f"Warning: not found '{
     xml_path}', skip this annotation file.")
                continue

            # 排除图像中没有目标的数据
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            if "object" not in data:
                print(f"INFO: no objects in {
     xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # 读取数据集类别信息
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    # 获取数据集大小
    def __len__(self):
        return len(self.xml_list)

    # 根据传入的索引值获取数据信息
    def __getitem__(self, idx):
        # 读取xml文件
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # 将数据转为tensor类型
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
   }
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    # 获取图像高度和宽度
    def get_height_and_width(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    """
    将xml文件解析成字典形式
    """
    def parse_xml_to_dict(self, xml):
        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {
   xml.tag: xml.text}

        result = {
   }
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {
   xml.tag: result}

    def coco_index(self, idx):
        """
        该方法是专门为pycocotools统计标签信息准备,不对图像和标签作任何处理
        由于不用去读取图片,可大幅缩减统计时间

        Args:
            idx: 输入需要获取图像的索引
        """
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        # img_path = os.path.join(self.img_root, data["filename"])
        # image = Image.open(img_path)
        # if image.format != "JPEG":
        #     raise ValueError("Image format not JPEG")
        boxes = []
        labels = []
        iscrowd = []
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            iscrowd.append(int(obj["difficult"]))

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
   }
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        return (data_height, data_width), target

    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))

注意:在进行数据预处理的时候,与进行图像分类数据预处理操作有些不同,如进行图像的翻转时,bbox的坐标信息也应该随之进行改变。

在这里插入图片描述

所以,我们自己定义了一个transform.py

import random
from torchvision.transforms import functional as F

class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target

class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            target["boxes"] = bbox
        return image, target

相关推荐

  1. 机器学习复习(9)——定义dataset

    2024-01-23 10:26:03       40 阅读
  2. DataLoader定义数据制作

    2024-01-23 10:26:03       45 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-23 10:26:03       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-23 10:26:03       100 阅读
  3. 在Django里面运行非项目文件

    2024-01-23 10:26:03       82 阅读
  4. Python语言-面向对象

    2024-01-23 10:26:03       91 阅读

热门阅读

  1. Spring Boot 项目请求参数丢失问题排查与解决

    2024-01-23 10:26:03       52 阅读
  2. MySQL运维实战(4.6) SQL_MODE之NO_BACKSLASH_ESCAPES

    2024-01-23 10:26:03       64 阅读
  3. 【MySQL】索引

    2024-01-23 10:26:03       54 阅读
  4. springboot项目打包jar和war有什么区别

    2024-01-23 10:26:03       53 阅读
  5. 设计模式-命令模式

    2024-01-23 10:26:03       52 阅读
  6. 图论基本知识--->最短路练习--->最小生成树

    2024-01-23 10:26:03       51 阅读
  7. python面试题大全(二)

    2024-01-23 10:26:03       36 阅读
  8. Charles将证书安装到系统的方法(adb)

    2024-01-23 10:26:03       55 阅读
  9. C# 创建多线程的函数

    2024-01-23 10:26:03       58 阅读
  10. webpack从0到1构建Vue3

    2024-01-23 10:26:03       54 阅读