深度学习——yolov5的txt和xml互转

在学习工作的过程中,有时会需要自己新建数据集,向训练数据中添加新的数据,存在已有模型对新数据进行检测,得到yolov5对应的txt文件,之后转成xml,使用标注工具对数据进行校正。后续将xml转成yolov5训练使用的txt格式。
以下是使用多进程和多线程两种方式实现的txt和xml互转的代码。

xml->txt

# -*- coding: UTF-8 -*-
"""
@Project :yolov5_relu_fire_smoke_v1.3 
@IDE     :PyCharm 
@Author  :mufeng
@Date    :2024/2/22 15:58

将xml转为yolo训练使用的txt格式
xml保存时使用的是[x1,y1,x2,y2]坐标格式
yolo训练使用的是[xn,yn,wn,hn]坐标格式
data_root
    |----annotations
    |----images
    |----labels
"""
import os
import multiprocessing
from concurrent import futures
from typing import List, Tuple
from copy import deepcopy

import xml.etree.ElementTree as ET
import cv2
import numpy as np


def decodeVocAnnotation(voc_xml_path, class_index_dict):
    """
    voc数据集格式的文件解析,将一个文件解析成一个list,
    使用空格间隔不同对象
    注意:返回的类别不是整型,而是字符串的类别名称
    注意判断返回值是否为 空,如果是空说明没有目标,是一张背景图
    :param voc_xml_path: xml路径
    :param class_index_dict: 类别字典
    :return:
    """
    assert voc_xml_path.endswith(".xml"), "voc_xml_path must endswith .xml"

    xml_file = open(voc_xml_path, 'r', encoding='utf-8')
    # 打开xml文件,并返回根节点
    root = ET.ElementTree().parse(xml_file)

    # 定义一个列表,专门保存目标
    information = []

    # 查找root节点下所有目标信息
    for obj in root.iter('object'):
        # 目标的名称
        name = obj.find('name').text
        # 目标的bbox坐标,一般voc是保存的corner格式的bbox
        box = obj.find('bndbox')
        xmin = box.find('xmin').text
        ymin = box.find('ymin').text
        xmax = box.find('xmax').text
        ymax = box.find('ymax').text

        # 添加一个目标的信息
        # NOTE:返回值的list
        information.append((class_index_dict[name], int(xmin), int(ymin), int(xmax), int(ymax)))

    xml_file.close()
    return information


def xyxy2xywh(matrix):
    """

    :param matrix: np矩阵, x1, y1, x2, y2
    :return:
    """
    # 确保输入矩阵的形状为 (n, 4),其中 n 是矩阵中矩形的数量
    if matrix.ndim < 2 or matrix.shape[1] != 4:
        raise ValueError("Input matrix must have shape (n, 4)")

    # 计算中心点坐标
    center_x = (matrix[:, 0] + matrix[:, 2]) / 2
    center_y = (matrix[:, 1] + matrix[:, 3]) / 2

    # 计算宽度和高度
    width = np.abs(matrix[:, 2] - matrix[:, 0])
    height = np.abs(matrix[:, 3] - matrix[:, 1])

    # 返回结果,组合为 (center_x, center_y, width, height) 形式
    return np.column_stack((center_x, center_y, width, height))


def run_thread(root, file, class_index_dict):
    """

    :param root:
    :param file: 图片路径
    :param class_index_dict:
    :return:
    """
    image_name, suffix = os.path.splitext(file)
    image_path = os.path.join(root, file)
    xml_path = image_path.replace("images", "annotations").replace(suffix, ".xml")
    txt_path = image_path.replace("images", "labels").replace(suffix, ".txt")

    if os.path.exists(xml_path):
        # cls_index, x1, y1, x2, y2
        bbox = decodeVocAnnotation(xml_path, class_index_dict)
        bbox = np.array(bbox, dtype=np.float32)
    else:
        bbox = np.zeros(shape=(0, 5), dtype=np.float32)
        return

    if len(bbox) == 0:
        bbox = np.zeros(shape=(0, 5), dtype=np.float32)
        return

    image = cv2.imread(image_path)
    if image is None:
        print(f"\n\033[31m{image_path} is None\033[0m")
        return
    else:
        print(f"\r\033[32m{image_path}\033[0m", end='')

    imh, imw = image.shape[:2]

    # # 画框,视为了检查框是否正确
    # for cls_id, xmin, ymin, xmax, ymax in np.array(bbox.copy(), dtype=np.int32):
    #     cv2.putText(image, text=f"{cls_id}", org=(xmin, ymin),
    #                 fontScale=2, fontFace=1, color=(0, 255, 0), thickness=1)
    #     cv2.rectangle(image, pt1=(xmin, ymin), pt2=(xmax, ymax), color=(0, 255, 0), thickness=2)
    # cv2.imwrite(os.path.join(data_root, "temp", image_file), image)

    # 坐标转换 xyxy -> xywh
    bbox[:, 1:] = xyxy2xywh(bbox[:, 1:])
    # 归一化
    bbox[..., [1, 3]] /= imw
    bbox[..., [2, 4]] /= imh

    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
    # 保存结果
    with open(txt_path, "w", encoding="utf-8") as wFile:
        for cls_id, x, y, w, h in bbox:
            wFile.write(f"{int(cls_id)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")


def run_process(root_file, class_index_dict):
    """

    :param root_file: [(root, file), ...] 因为进程的创建花费时间长,所以一个进程处理多个图片
    :param class_index_dict:
    :return:
    """
    for root, file in root_file:
        image_name, suffix = os.path.splitext(file)

        image_path = os.path.join(root, file)
        xml_path = image_path.replace("images", "annotations").replace(suffix, ".xml")
        txt_path = image_path.replace("images", "labels").replace(suffix, ".txt")

        if os.path.exists(xml_path):
            # cls_index, x1, y1, x2, y2
            bbox = decodeVocAnnotation(xml_path, class_index_dict)
            bbox = np.array(bbox, dtype=np.float32)
        else:
            bbox = np.zeros(shape=(0, 5), dtype=np.float32)
            continue

        if len(bbox) == 0:
            bbox = np.zeros(shape=(0, 5), dtype=np.float32)
            continue

        image = cv2.imread(image_path)
        if image is None:
            print(f"\n\033[31m{image_path} is None\033[0m")
            continue
        else:
            print(f"\r\033[32m{image_path}\033[0m", end='')

        imh, imw = image.shape[:2]

        # # 画框,视为了检查框是否正确
        # for cls_id, xmin, ymin, xmax, ymax in np.array(bbox.copy(), dtype=np.int32):
        #     cv2.putText(image, text=f"{cls_id}", org=(xmin, ymin),
        #                 fontScale=2, fontFace=1, color=(0, 255, 0), thickness=1)
        #     cv2.rectangle(image, pt1=(xmin, ymin), pt2=(xmax, ymax), color=(0, 255, 0), thickness=2)
        # cv2.imwrite(os.path.join(data_root, "temp", image_file), image)

        # 坐标转换 xyxy -> xywh
        bbox[:, 1:] = xyxy2xywh(bbox[:, 1:])
        # 归一化
        bbox[..., [1, 3]] /= imw
        bbox[..., [2, 4]] /= imh

        os.makedirs(os.path.dirname(txt_path), exist_ok=True)
        # 保存结果
        with open(txt_path, "w", encoding="utf-8") as wFile:
            for cls_id, x, y, w, h in bbox:
                wFile.write(f"{int(cls_id)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")


if __name__ == '__main__':
    # 需要已知目标的名称和类别索引
    class_index_dict = {
        "fire": 0,
        "smoke": 1,
    }

    # data_root = r"Z:\Datasets\Detection\FireSmoke\TSMFireSmoke"
    # data_root = r"Z:\Datasets\Detection\FireSmoke\TSMCandle"
    # data_root = r"Z:\Datasets\FireSmoke_v4"
    data_root = r"E:\CodeFiles\pycharm\YOLO\yolov5\my_test\data"
    # data_root = r"Z:\Datasets\Detection\FireSmoke\candle-test"
    data_root = os.path.abspath(data_root)

    # 需要跳过的目录
    exclude_dirs = [
        r"background",
    ]

    # NOTE:多线程/多进程 程序不好调试,将线程池/进程池 中的数量改为1,可以调试程序
    max_workers = 6  # 线程/进程 数
    # 使用的类型
    # run_type = "thread"  # 多线程
    run_type = "process"  # 多进程

    print(f"running use run_type={run_type}, max_workers:{max_workers}")

    if run_type == "thread":
        # 使用线程池控制程序执行
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            for root, _, files in os.walk(os.path.join(data_root, "images")):
                # 需要排除的目录
                if any(exclude_dir in root for exclude_dir in exclude_dirs):
                    continue

                for file in files:
                    # 向线程池中提交任务,向线程池中提交任务的时候是一个一个提交的
                    executor.submit(run_thread, *(root, file, class_index_dict))

        print("\nFinish ...")

    elif run_type == "process":
        # 一个进程处理多少图片
        max_file_num = 1000
        # 保存root和file的list
        root_file_list: List[Tuple] = list()

        # 创建进程池,根据自己的设备自行调整,别太多,否则会变慢
        pool = multiprocessing.Pool(processes=max_workers)

        # for image_file in os.listdir(os.path.join(data_root, "images", sub_dir)):
        for root, _, files in os.walk(os.path.join(data_root, "images")):
            # 需要排除的目录
            if any(exclude_dir in root for exclude_dir in exclude_dirs):
                continue

            for file in files:
                root_file_list.append((root, file))

                if len(root_file_list) > max_file_num:
                    # 启动一个进程,开始处理当前list中的信息,使用deepcopy是为了防止下面清除list后导致进程崩溃
                    pool.apply_async(run_process, (deepcopy(root_file_list), class_index_dict))
                    # 清除列表中的存储
                    root_file_list.clear()

        else:
            # for循环正常结束的话,如果剩下的文件数量不足max_file_num,上面不会启动新的进程,
            # 所以为了防止丢掉信息,在for循环正常结束之后,丢掉信息,就将root_file_list中的信息处理掉
            # 启动一个进程,开始处理当前list中的信息
            pool.apply_async(run_process, (deepcopy(root_file_list), class_index_dict))
            # 清除列表中的存储
            root_file_list.clear()

        # 关闭进程池
        pool.close()
        # 等待所有子进程执行结束
        pool.join()

        print("\nFinish ...")

    else:
        print("run_type should be thread or process.")

txt->xml

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :TestCode 
@IDE     :PyCharm 
@Author  :mufeng
@Date    :2023/7/21 17:15

yolov5检测出来的目标结果,转成xml
xml保存时使用的是[x1,y1,x2,y2]坐标格式
yolo检测结果保存使用的是[xn,yn,wn,hn]坐标格式
如果保存txt保存了置信度则txt每一行是:[class_index, xn, yn, wn, hn, conf]
使用线程池实现

data_root
    |----annotations
    |----images
    |----labels
"""
import os
import multiprocessing
from concurrent import futures
from typing import List, Tuple
from copy import deepcopy

import cv2
import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom


def create_voc_xml(image_folder, image_filename, width: int, height: int, labels,
                   save_path, class_name_dict, conf_thresh_dict=None):
    """

    :param image_folder: 图片的相对路径
    :param image_filename: 000001.jpg
    :param width: 图片宽
    :param height: 图片高
    :param labels: 目标框:[[class_index, xmin, ymin, xmax, ymax], ...]
    :param save_path: 保存xml的根目录
    :param class_name_dict: cls_index:cls_name,根据index获取正确的类别name
    :param conf_thresh_dict: cls_index:conf_thresh,根据不同类别设置的阈值获取对应的目标,如果设置为None,则表示保存的txt没有置信度
    :return:
    """
    # 创建 XML 文件的根元素
    root = ET.Element("annotation")

    # 添加图片信息
    folder = ET.SubElement(root, "folder")
    folder.text = str(image_folder)
    # 图片名字
    filename = ET.SubElement(root, "filename")
    filename.text = os.path.join(image_filename)
    # 图片大小
    size = ET.SubElement(root, "size")
    width_element = ET.SubElement(size, "width")
    width_element.text = str(width)
    height_element = ET.SubElement(size, "height")
    height_element.text = str(height)
    depth = ET.SubElement(size, "depth")  # 通道数
    depth.text = "3"

    # 添加目标框信息
    for label in labels:
        # 如果该参数设置为None,表示保存的txt没有None
        if conf_thresh_dict is None:
            # 保证这几项是整数
            class_index, x1, y1, x2, y2 = label.astype(dtype=np.int32)
        else:
            class_index, x1, y1, x2, y2, conf = label
            # 保证这几项是整数
            class_index, x1, y1, x2, y2 = np.array([class_index, x1, y1, x2, y2], dtype=np.int32)

            # 根据置信度过滤是否保存项
            if conf < conf_thresh_dict[class_index]:
                continue

        obj = ET.SubElement(root, "object")

        name = ET.SubElement(obj, "name")
        name.text = class_name_dict[int(class_index)]

        pose = ET.SubElement(obj, "pose")
        pose.text = "Unspecified"

        truncated = ET.SubElement(obj, "truncated")
        truncated.text = "0"

        difficult = ET.SubElement(obj, "difficult")
        difficult.text = "0"

        bndbox = ET.SubElement(obj, "bndbox")
        xmin = ET.SubElement(bndbox, "xmin")
        xmin.text = str(x1)

        ymin = ET.SubElement(bndbox, "ymin")
        ymin.text = str(y1)

        xmax = ET.SubElement(bndbox, "xmax")
        xmax.text = str(x2)

        ymax = ET.SubElement(bndbox, "ymax")
        ymax.text = str(y2)

    # 创建 XML 文件并保存
    xml_str = ET.tostring(root, encoding="utf-8")
    xml_str = minidom.parseString(xml_str)
    # 设置缩进为4个空格,xml可读性提高
    pretty_xml = xml_str.toprettyxml(indent=" " * 4)

    save_path = os.path.join(save_path, f"{os.path.splitext(image_filename)[0]}.xml")
    os.makedirs((os.path.dirname(save_path)), exist_ok=True)
    with open(save_path, "w") as xmlFile:
        xmlFile.write(pretty_xml)


def run_thread(root, image_file, save_root, image_root, txt_root, class_name_dict, conf_thresh_dict=None):
    """

    @param root: ..\images\train
    @param image_file: 0000000.jpg
    @param save_root: ..\annotations\train
    @param image_root: ..\images\train
    @param txt_root: ...\txt\train
    @param class_name_dict:
    @param conf_thresh_dict: 使用yolov5模型跑detect.py没有保存置信度conf,该参数可以不输入
    @return:
    """
    # 获取图片的名称和后缀
    image_name, suffix = os.path.splitext(image_file)
    # 图片路径
    image_path = os.path.join(root, image_file)
    # 设置捕捉异常,防止因为异常导致的代码停止运行
    try:
        # 读图
        image = cv2.imread(image_path)
        if image is None:
            print(f"\n\033[31mError {image_path}\033[0m")
            return

        # 图片的宽高
        imh, imw = image.shape[:2]

        # txt路径
        txt_file = image_path.replace(image_root, txt_root).replace(suffix, ".txt")
        if not os.path.exists(txt_file):
            return

        # class_index xn yn wn hn conf
        labels = np.loadtxt(txt_file, dtype=np.float32)
        # 空txt跳过
        if len(labels) == 0:
            return

        # 确包所有矩阵维度都是2维,方便后续处理
        if labels.ndim == 1:
            labels = np.array([labels])

        # xywhn -> xywh
        labels[:, [1, 3]] = labels[:, [1, 3]] * imw
        labels[:, [2, 4]] = labels[:, [2, 4]] * imh
        center = labels[:, 1:5].copy()

        # xywh - > xyxy
        corner = np.zeros_like(center)
        corner[:, 0] = center[:, 0] - center[:, 2] / 2  # xmin = x - w / 2
        corner[:, 1] = center[:, 1] - center[:, 3] / 2  # ymin = y - h / 2
        corner[:, 2] = center[:, 0] + center[:, 2] / 2  # xmax = x + w / 2
        corner[:, 3] = center[:, 1] + center[:, 3] / 2  # ymax = y + h / 2
        # np.float32
        labels[:, 1:5] = corner[:, :]

        # 创建xml
        create_voc_xml(root.replace(image_root + os.sep, ""),  # Z:\FireData\images\train -> train
                       image_filename=image_file,
                       width=imw,
                       height=imh,
                       labels=labels,
                       save_path=root.replace(image_root, save_root),
                       class_name_dict=class_name_dict,
                       conf_thresh_dict=conf_thresh_dict)
        # 处理完成后打印信息,要不不知道执行到哪里了
        print(f"\r{image_path}", end='')

    except Exception as e:
        print(f"{image_path} \n{e}\n\n")


def run_process(root_file, save_root, image_root, txt_root, class_name_dict, conf_thresh_dict=None):
    """

    @param root_file: [(..\images\train, 0000000.jpg), ...]
    @param image_file:
    @param save_root: ..\annotations\train
    @param image_root: ..\images\train
    @param txt_root: ...\txt\train
    @param class_name_dict:
    @param conf_thresh_dict: 使用yolov5模型跑detect.py没有保存置信度conf,该参数可以不输入
    @return:
    """
    for root, image_file in root_file:
        # 获取图片的名称和后缀
        image_name, suffix = os.path.splitext(image_file)
        # 图片路径
        image_path = os.path.join(root, image_file)
        # 设置捕捉异常,防止因为异常导致的代码停止运行
        try:
            # 读图
            image = cv2.imread(image_path)
            if image is None:
                print(f"\n\033[31mError {image_path}\033[0m")
                continue

            # 图片的宽高
            imh, imw = image.shape[:2]

            # txt路径
            txt_file = image_path.replace(image_root, txt_root).replace(suffix, ".txt")
            if not os.path.exists(txt_file):
                continue

            # class_index xn yn wn hn conf
            labels = np.loadtxt(txt_file, dtype=np.float32)
            # 空txt跳过
            if len(labels) == 0:
                continue

            # 确包所有矩阵维度都是2维,方便后续处理
            if labels.ndim == 1:
                labels = np.array([labels])

            # xywhn -> xywh
            labels[:, [1, 3]] = labels[:, [1, 3]] * imw
            labels[:, [2, 4]] = labels[:, [2, 4]] * imh
            center = labels[:, 1:5].copy()

            # xywh - > xyxy
            corner = np.zeros_like(center)
            corner[:, 0] = center[:, 0] - center[:, 2] / 2  # xmin = x - w / 2
            corner[:, 1] = center[:, 1] - center[:, 3] / 2  # ymin = y - h / 2
            corner[:, 2] = center[:, 0] + center[:, 2] / 2  # xmax = x + w / 2
            corner[:, 3] = center[:, 1] + center[:, 3] / 2  # ymax = y + h / 2
            # np.float32
            labels[:, 1:5] = corner[:, :]

            # 创建xml
            create_voc_xml(root.replace(image_root + os.sep, ""),  # Z:\FireData\images\train -> train
                           image_filename=image_file,
                           width=imw,
                           height=imh,
                           labels=labels,
                           save_path=root.replace(image_root, save_root),
                           class_name_dict=class_name_dict,
                           conf_thresh_dict=conf_thresh_dict)
            # 处理完成后打印信息,要不不知道执行到哪里了
            print(f"\r{image_path}", end='')

        except Exception as e:
            print(f"{image_path} \n{e}\n\n")


if __name__ == '__main__':
    # 类别字典
    class_name_dict = {
        0: "fire",
        1: "smoke"
    }
    # 置信度阈值,不同类别设置不同的阈值,
    CONF_THRESH_DICT = None  # 如果该参数设置为None表示txt没有保存conf这一项
    # CONF_THRESH_DICT = {
    #     0: 0.2,
    #     1: 0.2
    # }
    if CONF_THRESH_DICT != None:
        assert class_name_dict.keys() == CONF_THRESH_DICT.keys(), "class_name_dict.keys() != CONF_THRESH_DICT.keys()."

    # 数据集根目录
    data_root = r"E:\CodeFiles\pycharm\YOLO\yolov5\my_test\data"
    data_root = os.path.abspath(data_root)

    # 指定的子目录
    sub_dir = r""
    # sub_dir = r"\train\fire_smoke"

    # 要保证这三个的目录结构是一致的
    # 保存xml的根路径 save_root\annotations\...
    if sub_dir == '':
        save_root = os.path.join(data_root, "annotations")
        # txt路径
        txt_root = os.path.join(data_root, "labels")  # txt和images不在一个目录下,目录结构应该和images一样
        # 图片路径
        image_root = os.path.join(data_root, "images")

    else:
        save_root = os.path.join(data_root, "annotations", sub_dir)
        # txt路径
        txt_root = os.path.join(data_root, "labels", sub_dir)  # txt和images不在一个目录下,目录结构应该和images一样
        # 图片路径
        image_root = os.path.join(data_root, "images", sub_dir)

    # 需要跳过的目录
    exclude_dirs = [
        r"background",
    ]

    # NOTE:多线程/多进程 程序不好调试,将线程池/进程池 中的数量改为1,可以调试程序
    max_workers = 6  # 线程/进程 数
    # 使用的类型
    run_type = "thread"  # 多线程
    # run_type = "process"  # 多进程

    print(f"running use run_type={run_type}, max_workers:{max_workers}")

    if run_type == "thread":
        # 使用线程池控制程序执行
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            for root, _, files in os.walk(image_root):
                # 需要排除的目录
                if any(exclude_dir in root for exclude_dir in exclude_dirs):
                    continue

                for file in files:
                    # 向线程池中提交任务,向线程池中提交任务的时候是一个一个提交的
                    executor.submit(run_thread,
                                    *(root, file, save_root, image_root, txt_root, class_name_dict, CONF_THRESH_DICT))

        print("\nFinish ...")

    elif run_type == "process":
        # 一个进程处理多少图片
        max_file_num = 1000
        # 保存root和file的list
        root_file_list: List[Tuple] = list()

        # 创建进程池,根据自己的设备自行调整,别太多,否则会变慢
        pool = multiprocessing.Pool(processes=max_workers)

        # for image_file in os.listdir(os.path.join(data_root, "images", sub_dir)):
        for root, _, files in os.walk(image_root):
            # 需要排除的目录
            if any(exclude_dir in root for exclude_dir in exclude_dirs):
                continue

            for file in files:
                root_file_list.append((root, file))

                if len(root_file_list) > max_file_num:
                    # 启动一个进程,开始处理当前list中的信息,使用deepcopy是为了防止下面清除list后导致进程崩溃
                    pool.apply_async(run_process,
                                     (deepcopy(root_file_list), save_root, image_root, txt_root,
                                      class_name_dict, CONF_THRESH_DICT))
                    # 清除列表中的存储
                    root_file_list.clear()

        else:
            # for循环正常结束的话,如果剩下的文件数量不足max_file_num,上面不会启动新的进程,
            # 所以为了防止丢掉信息,在for循环正常结束之后,丢掉信息,就将root_file_list中的信息处理掉
            # 启动一个进程,开始处理当前list中的信息
            pool.apply_async(run_process, (deepcopy(root_file_list), save_root, image_root, txt_root,
                                           class_name_dict, CONF_THRESH_DICT))
            # 清除列表中的存储
            root_file_list.clear()

        # 关闭进程池
        pool.close()
        # 等待所有子进程执行结束
        pool.join()

        print("\nFinish ...")

    else:
        print("run_type should be thread or process.")


相关推荐

  1. 深度学习——yolov5txtxml

    2024-03-18 06:16:05       36 阅读
  2. 深度学习毕设思路--yolov5使用方法

    2024-03-18 06:16:05       51 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-18 06:16:05       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-18 06:16:05       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-18 06:16:05       82 阅读
  4. Python语言-面向对象

    2024-03-18 06:16:05       91 阅读

热门阅读

  1. Python基础学习(12)环境搭建/ATM1

    2024-03-18 06:16:05       44 阅读
  2. 控制工程学 en 2

    2024-03-18 06:16:05       35 阅读
  3. ssh命令——安全远程连接服务

    2024-03-18 06:16:05       44 阅读
  4. 《C缺陷和陷阱》-笔记(5)

    2024-03-18 06:16:05       42 阅读