# 主要目标是适应摄像头rtsp流的检测
# 如果是普通文件夹或者图片,run中的while True去掉即可。
# web_client是根据需求创建的客户端,将检测到的数据打包发送给服务器
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.
Usage:
$ python path/to/detect.py --source path/to/img.jpg --weights yolov5s.pt --img 640
"""
import argparse
import json
import os
import sys
import time
import moment
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import load_classifier, select_device, time_sync
from mytools import read_yaml_all, base64_encode_img
from message_base import MessageBase
from websocket_client import WebClient
class Detect:
    """Run a YOLOv5 model on a source and push every detection to a client.

    The detector is tuned for live camera streams (e.g. RTSP): frames are
    read continuously, but inference only runs once every
    ``check_time_step`` seconds.  Each detection is cropped, saved under
    ``save_path`` and sent through ``client`` as a JSON message.
    """

    def __init__(self, config: dict, client: "WebClient"):
        """Load the model, open the data source and prepare the output folder.

        Args:
            config: Detector settings.  Keys read here: ``weights``,
                ``source``, ``imgsz``, ``conf_thres``, ``iou_thres``,
                ``max_det``, ``device``, ``view_img``, ``save_txt``,
                ``save_conf``, ``save_crop``, ``nosave``, ``classes``,
                ``agnostic_nms``, ``augment``, ``visualize``, ``update``,
                ``save_path``, ``line_thickness``, ``hide_labels``,
                ``hide_conf``, ``half``, ``dnn`` and ``func_device``.
            client: Object exposing ``connect()`` and ``send(data)``; used
                to deliver the detection messages.

        Raises:
            ValueError: If ``config`` has no ``source`` entry.
        """
        self.config = config
        self.weights = self.config.get("weights")  # weights path
        source = self.config.get("source")
        if source is None:
            raise ValueError("config must provide a 'source' entry")
        # YAML may parse a camera index such as ``0`` as an int; the checks
        # below need a string.
        self.source = str(source)
        self.imgsz = self.config.get("imgsz")  # inference size
        self.conf_thres = self.config.get("conf_thres")
        self.iou_thres = self.config.get("iou_thres")
        self.max_det = self.config.get("max_det")
        self.device = self.config.get("device")  # "cpu" or "0,1,2,3"
        self.view_img = self.config.get("view_img")  # show results
        self.save_txt = self.config.get("save_txt")  # save results to *.txt
        self.save_conf = self.config.get("save_conf")  # save confidences in --save-txt labels
        self.save_crop = self.config.get("save_crop")  # save cropped prediction boxes
        self.nosave = self.config.get("nosave")  # do not save images/videos
        self.classes = self.config.get("classes")  # filter by class: --class 0, or --class 0 2 3
        self.agnostic_nms = self.config.get("agnostic_nms")  # class-agnostic NMS
        self.augment = self.config.get("augment")  # augmented inference
        self.visualize = self.config.get("visualize")  # visualize features
        self.update = self.config.get("update")  # update all models
        self.save_path = self.config.get("save_path")  # folder that receives the crops
        self.line_thickness = self.config.get("line_thickness")  # bounding box thickness (pixels)
        self.hide_labels = self.config.get("hide_labels")  # hide labels
        self.hide_conf = self.config.get("hide_conf")  # hide confidences
        self.half = self.config.get("half")  # use FP16 half-precision inference
        self.dnn = self.config.get("dnn")  # use OpenCV DNN for ONNX inference
        self.func_device = self.config.get("func_device")  # name of the device this detector serves

        self.save_img = not self.nosave and not self.source.endswith('.txt')  # save inference images
        self.webcam = self._is_stream_source(self.source)

        set_logging()
        self.device = select_device(self.device)
        # FP16 only when it was requested AND the device is not the CPU.
        self.half = bool(self.half) and self.device.type != 'cpu'

        self.model = attempt_load(self.weights, map_location=self.device)
        if self.half:
            # inference() feeds FP16 tensors, so the weights must be FP16 too.
            self.model.half()
        self.stride = int(self.model.stride.max())
        self.imgsz = check_img_size(self.imgsz, s=self.stride)
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names

        # Open the data source.
        if self.webcam:
            self.view_img = check_imshow()
            cudnn.benchmark = True  # set True to speed up constant image size inference
            self.dataset = LoadStreams(self.source, img_size=self.imgsz, stride=self.stride, auto=True)
            self.bs = len(self.dataset)  # batch_size
        else:
            self.dataset = LoadImages(self.source, img_size=self.imgsz, stride=self.stride, auto=True)
            self.bs = 1  # batch_size

        self.client = client  # client used to deliver detections
        self.last_time = moment.now()  # earliest moment the next inference may run
        self.check_time_step = 5  # minimum number of seconds between two inferences
        os.makedirs(self.save_path, exist_ok=True)

    @staticmethod
    def _is_stream_source(source: str) -> bool:
        """Return True for camera indices, ``*.txt`` stream lists and stream URLs."""
        return (source.isnumeric()
                or source.endswith('.txt')
                or source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')))

    def inference(self, img):
        """Run the model and non-max suppression on one pre-processed batch.

        Args:
            img: Numpy array produced by the dataset loader; a 3-dimensional
                input gets a batch dimension prepended.

        Returns:
            The output of ``non_max_suppression`` (iterated by ``process``
            as one detections tensor per image).
        """
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        # Pure inference: no autograd bookkeeping needed.
        with torch.no_grad():
            pred = self.model(img, augment=self.augment)[0]
        # NMS
        return non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                   self.classes, self.agnostic_nms, max_det=self.max_det)

    def process(self, im0s, img, pred, path):
        """Crop, save and report every detection in ``pred``.

        Args:
            im0s: Original frame(s); indexed per image for stream sources,
                used directly otherwise.
            img: The network input tensor/array; only its spatial shape
                (``img.shape[2:]``) is used to rescale the boxes.
            pred: Per-image detections as returned by ``inference``.
            path: Source path(s) from the dataset loader (currently unused).
        """
        for i, det in enumerate(pred):  # per image
            if not len(det):
                continue
            im0 = (im0s[i] if self.webcam else im0s).copy()
            # Rescale boxes from the network input size to the original frame size.
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            for *xyxy, _conf, cls in reversed(det):
                label = self.names[int(cls)]
                # Hook for label-specific handling, e.g. ``label == "person"``.
                if not label:
                    continue
                self._report_detection(xyxy, label, im0)

    def _report_detection(self, xyxy, label, image):
        """Save the crop of one box and send its JSON description to the client."""
        t = int(time.time())
        img_path = f"{self.save_path}/{self.func_device}_{label}_{t}.jpg"
        crop = save_one_box(xyxy, image, img_path, BGR=True)
        x1, y1, x2, y2 = (int(v) for v in xyxy)
        data = json.dumps({
            "device": self.func_device,
            "value": {
                "label": label,
                "time": t,
                "locate": (x1, y1, x2, y2),
                "crop": base64_encode_img(crop),
            },
        })
        self._send(data)
        print(data)

    def _send(self, data):
        """Send ``data``; on failure reconnect once and retry."""
        try:
            self.client.send(data)
        except Exception as err:  # the client's exception types are not known here
            print("发送失败:", err)
            self.client.connect()
            self.client.send(data)
            print("重连成功!")

    def run(self):
        """Connect the client and process the source.

        Stream sources are read in an endless loop; any other source is
        traversed once, after which the save summary is printed and, if
        ``update`` is set, the optimizer is stripped from the weights.
        """
        self.client.connect()
        while True:
            for path, img, im0s, _vid_cap in self.dataset:
                now = moment.now()
                if self.last_time < now:
                    # Schedule the next inference ``check_time_step`` seconds from now.
                    self.last_time = now.add(seconds=self.check_time_step)
                    try:
                        pred = self.inference(img)
                        self.process(im0s, img, pred, path)
                    except Exception as err:
                        # Keep the loop alive; a single bad frame must not stop detection.
                        print(err)
            if not self.webcam:
                break
            # The stream iterator ended; pause briefly instead of spinning at full speed.
            time.sleep(1)

        if self.save_txt or self.save_img:
            save_dir = Path(self.save_path)  # save_path is handled as a string elsewhere
            s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" \
                if self.save_txt else ''
            print(f"Results saved to {colorstr('bold', save_dir)}{s}")
        if self.update:
            strip_optimizer(self.weights)  # update model (to fix SourceChangeWarning)
if __name__ == "__main__":
    # Script entry point: build the websocket client, merge the camera
    # settings into the detector configuration and start detecting.
    message_base = MessageBase()  # constructed as before; not used further in this script
    ws_client = WebClient("192.168.6.28", 8000)

    detector_cfg = read_yaml_all("yolo_configs.yaml")  # model / inference settings
    site_cfg = read_yaml_all("configs.yaml")  # device list and stream URLs

    # The first listed device is the one this process serves.
    device_name = site_cfg.get("DEVICE_LIST")[0]
    rtsp_urls = site_cfg.get("RTSP_URLS")
    detector_cfg.update({
        "source": rtsp_urls.get(device_name),
        "func_device": device_name,
    })
    print(detector_cfg)

    Detect(detector_cfg, ws_client).run()