经典目标检测YOLOV1模型的训练及验证

1、前期准备

准备好目录结构、数据集和关于YOLOv1的基础认知

1.1  创建目录结构

        自己创建项目目录结构,结构目录如下:

network                    CNN Backbone 存放位置
weights                    权重存放的位置
test_images             测试用的图片
utils                          辅助功能的代码存放位置 

models                    保存模型位置

data                         训练的数据集

1.2  数据集介绍与下载

1.2.1 数据集介绍

       首先了解数据集,对数据集了解后方便对数据进行相应处理。数据集详细介绍直通车:https://blog.csdn.net/qq_41946216/article/details/137683750?spm=1001.2014.3001.5501

1.2.1 数据集下载

       本次采用数据集: VOC2012数据集。

       数据集下载方式一http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

       数据集下载方式二:

     下载并构建VOC2012数据集,从:https://gitee.com/ppov-nuc/pascal-vocdataset_-for_-yolo.git, 下载get_data文件generate_csv.py文件到本地,放到创建的目录结构中,修改get_data中下载的内容和相应路径,然后运行批处理文件get_data,在get_dat中会自动执行generate_csv.py,如下图所示。


2. 数据集处理

       在utils目录下创建工具类 generate_txt_file.py,主要用于数据集的划分和解析 Annotations/xxxxx.xml 文件中的类别bbox信息,并将信息存入voctrain.txt和voctest.txt文件,如下图所示:

具体代码:

# author: baiCai
# 1. 导包
from xml.etree import ElementTree as ET
import os
import random

# 2. 定义一些基本的参数
# 定义所有的类名
VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor')

'''
读取所有 xml 文件,存入列表
'''
# 要读取的xml文件路径,记得自己修改路径
Annotations = '../data/VOC2012/Annotations/'
# 列出所有的xml文件
xml_files = os.listdir(Annotations)
# 打乱数据集
random.shuffle(xml_files)
'''
定义训练集和测试比例
划分Annotations中的训练集和测试集文件列表
'''
# 训练集数量
train_num = int(len(xml_files) * 0.7)
# 训练列表
train_file_list = xml_files[:train_num]
# 测测试列表
test_file_list = xml_files[train_num:]

'''
定义 xml 解析后的信息存储路径和写对象
'''
# 训练集和测试集文件名字
train_set_path = './voctrain.txt'
test_set_path = './voctest.txt'

# 3. 定义解析xml文件的函数
'''
主要解析 xml 获取 类别名字和bbox,如 
{'name': 'person','bbox': [174, 101, 349, 351]}
'''
def parse_rec(filename):
    # 参数:输入xml文件名
    # 创建xml对象
    tree = ET.parse(filename)
    objects = []
    # 迭代读取xml文件中的object节点,即物体信息
    for obj in tree.findall('object'):
        obj_struct = {}
        # difficult属性,即这里不需要那些难判断的对象
        difficult = int(obj.find('difficult').text)
        if difficult == 1:  # 若为1则跳过本次循环
            continue
        # 开始收集信息
        obj_struct['name'] = obj.find('name').text
        bbox = obj.find('bndbox')
        obj_struct['bbox'] =\
            [int(float(bbox.find('xmin').text)),
            int(float(bbox.find('ymin').text)),
            int(float(bbox.find('xmax').text)),
            int(float(bbox.find('ymax').text))]
        objects.append(obj_struct)

    return objects

# 4. 把信息保存入文件中
def write_txt(file_list,set_path):
    # # 生成训练集txt
    count = 0
    with  open(set_path, 'w') as wt:
        for xml_file in file_list:
            count += 1
            # 获取图片名字
            image_name = xml_file.split('.')[0] + '.jpg'  # 图片文件名
            # 对xml_file进行解析
            results = parse_rec(Annotations + xml_file)
            # 如果返回的对象为空,表示张图片难以检测,因此直接跳过
            if len(results) == 0:
                print(xml_file)
                continue
            # 否则,则写入文件中
            # 先写入图片名字
            wt.write(image_name)
            # 接着指定下面写入的格式
            for result in results:
                class_name = result['name']
                bbox = result['bbox']
                class_name = VOC_CLASSES.index(class_name)  # 名字在类别中是下标位置
                wt.write(' ' + str(bbox[0]) +
                                ' ' + str(bbox[1]) +
                                ' ' + str(bbox[2]) +
                                ' ' + str(bbox[3]) +
                                ' ' + str(class_name))
            wt.write('\n')
        wt.close()

# 5. 运行
if __name__ == '__main__':
    write_txt(train_file_list,train_set_path)
    write_txt(test_file_list,test_set_path)

3. 构建数据加载器 

      在utils目录下创建工具类 yolo_dataset.py,主要用于数据集的构建,包含初始化、图片增强及归一化等功能。

3.1定义初始化方法

       读取xxxx.xml解析后的文件
       对每行数据(每个图片信息)的所有中心点信息以【x,y,w,h】和标签分别存入box列表和label列表。
       当前图片的边界框和标签信息即box列表和label列表,转换为LongTensor格式添加到对应的boxex列表和labels列表。

3.2 定义增强图片方法

增加方法名称 定义的函数
随机翻转图片和边界框 random_flip(img, boxes)
随机缩放图片和边界框 randomScale(img, boxes)
随机模糊图片 randomBlur(img)
随机调整图片亮度 RandomBrightness(img)
随机调整图片色调 RandomHue(img)
随机调整图片饱和度 RandomSaturation(img)
随机移动图片和边界框 randomShift(img, boxes, labels)        
随机裁剪图片和边界框 randomCrop(img, boxes, labels)
用于从图像中减去均值 subMean(self, bgr, mean)
将BGR图像转换为RGB图像 BGR2RGB(self, img)
将BGR图像转换为HSV图像 BGR2HSV(self, img)
将HSV图像转换为BGR图像 HSV2BGR(self, img)

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-21 15:38:05       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-21 15:38:05       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-21 15:38:05       82 阅读
  4. Python语言-面向对象

    2024-04-21 15:38:05       91 阅读

热门阅读

  1. Python机器学习项目开发实战:如何进行语音识别

    2024-04-21 15:38:05       36 阅读
  2. c++ 去掉小数位后面的零

    2024-04-21 15:38:05       33 阅读
  3. 交换排序:冒泡排序和快速排序

    2024-04-21 15:38:05       34 阅读
  4. k8s调度场景

    2024-04-21 15:38:05       26 阅读
  5. 状态模式(状态和行为分离)

    2024-04-21 15:38:05       39 阅读