Wrapping YOLOv5's detect.py in an object-oriented class

The main goal is to adapt the detection loop to camera RTSP streams.

If the source is an ordinary folder or a single image instead, just remove the while True loop in run (a sketch of the simplified method follows below).
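
For reference, a minimal sketch of run without the loop, using the same method names as the Detect class below (the time throttle and error handling are omitted here):

    def run(self):
        self.client.connect()
        # iterate the dataset once; LoadImages stops after the last file
        for path, img, im0s, vid_cap in self.dataset:
            pred = self.inference(img)
            self.process(im0s, img, pred, path)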

web_client is a client written for this project's needs; it packages the detection results and sends them to the server.
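
The WebClient module itself is not shown; only its connect() and send() methods are used below. As a rough, hypothetical stand-in (assuming the third-party websocket-client package, not the author's actual implementation), it could look like this:

from websocket import create_connection  # pip install websocket-client


class WebClient:
    def __init__(self, host: str, port: int):
        # target websocket server, e.g. ws://192.168.6.28:8000
        self.url = f"ws://{host}:{port}"
        self.ws = None

    def connect(self):
        # open (or reopen) the connection
        self.ws = create_connection(self.url)

    def send(self, data: str):
        # send the JSON string built in Detect.process
        self.ws.send(data)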

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.

Usage:
    $ python path/to/detect.py --source path/to/img.jpg --weights yolov5s.pt --img 640
"""

import argparse
import json
import os
import sys
import time
import moment  # third-party date/time helper used for the detection interval
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
    increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
    strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import load_classifier, select_device, time_sync

from mytools import read_yaml_all, base64_encode_img
from message_base import MessageBase
from websocket_client import WebClient


class Detect:
    def __init__(self, config: dict, client: WebClient):
        self.config = config
        self.weights = self.config.get("weights")  # weights path
        self.source = self.config.get("source")  # source 
        self.imgsz = self.config.get("imgsz")  # imgsz
        self.conf_thres = self.config.get("conf_thres")
        self.iou_thres = self.config.get("iou_thres")
        self.max_det = self.config.get("max_det")
        self.device = self.config.get("device")  # "cpu" or "0,1,2,3"
        self.view_img = self.config.get("view_img")  # show results
        self.save_txt = self.config.get("save_txt")  # save results to *.txt
        self.save_conf = self.config.get("save_conf")  # save confidences in --save-txt labels
        self.save_crop = self.config.get("save_crop")  # save cropped prediction boxes
        self.nosave = self.config.get("nosave")  # do not save images/videos
        self.classes = self.config.get("classes")  # filter by class: --class 0, or --class 0 2 3
        self.agnostic_nms = self.config.get("agnostic_nms")  # class-agnostic NMS
        self.augment = self.config.get("augment")  # augmented inference
        self.visualize = self.config.get("visualize")  # visualize features
        self.update = self.config.get("update")  # update all models
        self.save_path = self.config.get("save_path")  # save results to project/name
        self.line_thickness = self.config.get("line_thickness")  # bounding box thickness (pixels)
        self.hide_labels = self.config.get("hide_labels")  # hide labels
        self.hide_conf = self.config.get("hide_conf")  # hide confidences
        self.half = self.config.get("half")  # use FP16 half-precision inference
        self.dnn = self.config.get("dnn")  # use OpenCV DNN for ONNX inference
        self.func_device = self.config.get("func_device")  # device name for this detection task
        self.save_img = not self.nosave and not self.source.endswith('.txt')  # save inference images
        self.webcam = self.source.isnumeric() or self.source.endswith('.txt') or self.source.lower().startswith(
            ('rtsp://', 'rtmp://', 'http://', 'https://'))
        set_logging()
        self.device = select_device(self.device)
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA
        self.model = attempt_load(self.weights, map_location=self.device)
        self.imgsz = check_img_size(self.imgsz, s=int(self.model.stride.max()))
        self.stride = int(self.model.stride.max())
        self.names = self.model.module.names if hasattr(
            self.model, 'module') else self.model.names
        # load the data source
        if self.webcam:
            self.view_img = check_imshow()
            cudnn.benchmark = True  # set True to speed up constant image size inference
            self.dataset = LoadStreams(self.source, img_size=self.imgsz, stride=self.stride, auto=True)
            self.bs = len(self.dataset)  # batch_size
        else:
            self.dataset = LoadImages(self.source, img_size=self.imgsz, stride=self.stride, auto=True)
            self.bs = 1  # batch_size

        self.client = client  # client used to push detections to the server
        self.last_time = moment.now()
        self.check_time_step = 5  # run detection at most once every N seconds
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)

    def inference(self, img):
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        pred = self.model(img, augment=self.augment)[0]
        # NMS
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                   self.classes, self.agnostic_nms, max_det=self.max_det)
        return pred

    def process(self, im0s, img, pred, path):
        for i, det in enumerate(pred):  # per image
            if self.webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), self.dataset.count
            else:
                p, s, im0, frame = path, '', im0s.copy(), getattr(self.dataset, 'frame', 0)

            p = Path(p)  # to Path
            txt_path = str(self.save_path + "/" + 'labels' + "/" + p.stem) + (
                '' if self.dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if self.save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=self.line_thickness, example=str(self.names))
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)
                    label = self.names[c]
                    # if label == "person":
                    if label:  # handle detections according to their label
                        # annotator.box_label(xyxy, label, color=colors(c, True))  # draw the box
                        t = int(time.time())
                        img_path = f"{self.save_path}/{self.func_device}_{label}_{t}.jpg"
                        crop = save_one_box(xyxy, imc, img_path, BGR=True)
                        x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                        data = {
                            "device": self.func_device,
                            "value": {
                                "label": label,
                                "time": t,
                                "locate": (x1, y1, x2, y2),
                                "crop": base64_encode_img(crop)
                            }
                        }
                        data = json.dumps(data)  # pack the payload
                        try:
                            self.client.send(data)  # push the detection to the server
                        except Exception as err:
                            print("send failed:", err)
                            self.client.connect()  # reconnect and retry once
                            self.client.send(data)
                            print("reconnected successfully")
                        print(data)
                    # if self.save_txt:  # Write to file
                    #     xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(
                    #         -1).tolist()  # normalized xywh
                    #     line = (cls, *xywh, conf) if self.save_conf else (cls, *xywh)  # label format
                    #     with open(txt_path + '.txt', 'a') as f:
                    #         f.write(('%g ' * len(line)).rstrip() % line + '\n')
                    # draw the bounding box on the frame
                    # if self.save_img or self.save_crop or self.view_img:  # Add bbox to image
                    #     c = int(cls)  # integer class
                    #     label = None if self.hide_labels else (self.names[c] if self.hide_conf else
                    #                                            f'{self.names[c]} {conf:.2f}')
                    #     annotator.box_label(xyxy, label, color=colors(c, True))

    def run(self):
        self.client.connect()
        while True:
            for path, img, im0s, vid_cap in self.dataset:
                if self.last_time < moment.now():
                    self.last_time = moment.now().add(seconds=self.check_time_step)
                    try:
                        pred = self.inference(img)
                        self.process(im0s, img, pred, path)
                    except Exception as err:
                        print(err)

            if self.save_txt or self.save_img:
                s = f"\n{len(list(self.save_path.glob('labels/*.txt')))} labels saved to {self.save_path / 'labels'}" if self.save_txt else ''
                print(f"Results saved to {colorstr('bold', self.save_path)}{s}")
            if self.update:
                strip_optimizer(self.weights)  # update model (to fix SourceChangeWarning)


if __name__ == "__main__":
    message_base = MessageBase()
    wc = WebClient("192.168.6.28", 8000)
    configs = read_yaml_all("yolo_configs.yaml")
    config = read_yaml_all("configs.yaml")
    device_name = config.get("DEVICE_LIST")[0]
    device_source = config.get("RTSP_URLS").get(device_name)
    configs["source"] = device_source
    configs["func_device"] = device_name
    print(configs)
    detect = Detect(configs, wc)
    detect.run()
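
The keys that Detect reads in __init__ map one-to-one onto yolo_configs.yaml. A minimal example of such a file might look like the following (the values are illustrative defaults, not the author's actual settings):

weights: yolov5s.pt
source: "0"          # overwritten at startup with the RTSP URL from configs.yaml
imgsz: 640
conf_thres: 0.25
iou_thres: 0.45
max_det: 1000
device: "0"          # "cpu" or CUDA device ids such as "0,1,2,3"
view_img: false
save_txt: false
save_conf: false
save_crop: false
nosave: false
classes:             # empty = keep all classes
agnostic_nms: false
augment: false
visualize: false
update: false
save_path: runs/detect
line_thickness: 3
hide_labels: false
hide_conf: false
half: false
dnn: false
func_device: cam_01  # overwritten at startup with the device name from configs.yaml

configs.yaml, in turn, is expected to provide a DEVICE_LIST of device names and an RTSP_URLS mapping from device name to stream URL, since the __main__ block picks the first device and looks up its URL there.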
