【利用GroundingDINO裁剪分类任务的数据集】及文本提示检测图像任意目标(Grounding DINO) 的使用

背景

  • 在处理公开数据集ImageNet-21k的时候发现里面有很多的数据有问题,比如,数据目标有很多背景,且部分类别有其他种类的图片。
  • 针对数据目标有很多背景,公开数据集ImageNet-21k的21k种类别进行裁剪。
  • 文本提示检测图像任意目标(Grounding DINO),这更模型可以很好的应用在这个场景。

1.Grounding DINO安装

github地址

  1. 从 GitHub 克隆 GroundingDINO 存储库。
git clone https://github.com/IDEA-Research/GroundingDINO.git
  1. 将当前目录更改为 GroundingDINO 文件夹。
cd GroundingDINO/
  1. 在当前目录中安装所需的依赖项。
pip install -e .
  1. 下载预训练模型权重。
mkdir weights
cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
  1. 下载bert-base-uncased到text_encoder_type(自己创建一个文件夹)

需要下载下面的三个文件,放进text_encoder_type里面就好。
在这里插入图片描述

  1. 修改地址

修改/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py文件中text_encoder_type的路径。

在这里插入图片描述

  1. 如果您有 CUDA 环境,请确保设置了环境变量 CUDA_HOME 。如果没有可用的 CUDA,它将在仅 CPU 模式下编译。

  2. 可能遇到的bug

 Segmentation fault (core dumped)

是因为timm版本和cuda,pytorch等版本不匹配重新安装可以解决这个bug。

pip uninstall timm
pip install timm

2.裁剪指定目标的脚本

  1. 如下是测试的demo
import cv2

print("456")
from groundingdino.util.inference import load_model, load_image, predict, annotate

print("123")
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weight/groundingdino_swint_ogc.pth", "cpu")
IMAGE_PATH = r"images/th.jpg"
TEXT_PROMPT = "dolphins"
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
print("456")
image_source, image = load_image(IMAGE_PATH)

print("789")
boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD
)

print("10")
print(boxes)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
cv2.imwrite("annotated_image.jpg", annotated_frame)

在这里插入图片描述

在这里插入图片描述

  1. 裁剪指定目标的脚本

该脚本指定目录后,会对该目录下子文件夹的不同目标类别,进行裁剪并将裁剪结果放在与原路径对应的相对路径种。

脚本全部代码:

import os
import time
from groundingdino.util.inference import load_model, load_image, predict
import cv2
import torch
from torchvision.ops import box_convert

def save_cropped_images(image, boxes, image_name, output_folder):
    os.makedirs(output_folder, exist_ok=True)
    h, w, _ = image.shape
    boxes = boxes * torch.tensor([w, h, w, h])
    xyxy_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

    for i, box in enumerate(xyxy_boxes):
        x_min, y_min, x_max, y_max = map(int, box)
        cropped_image = image[y_min:y_max, x_min:x_max]
        # Ensure the color channels are in BGR order for OpenCV
        cropped_image_bgr = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f"{output_folder}/{image_name}_cropped_{i}.jpg", cropped_image_bgr)

def process_image(image_path, model, output_folder, box_threshold=0.35, text_threshold=0.25):
    image_source, image = load_image(image_path)
    try:
      boxes, logits, phrases = predict(
          model=model,
          image=image,
          caption=TEXT_PROMPT,
          box_threshold=box_threshold,
          text_threshold=text_threshold
      )
    except RuntimeError as e:
      print(f"RuntimeError: {e}")

    # Get the image name without extension
    image_name = os.path.splitext(os.path.basename(image_path))[0]

    # Save cropped images with image name included
    save_cropped_images(image_source, boxes, image_name, output_folder)

def process_images_in_folder(folder_path, model, box_threshold=0.35, text_threshold=0.25):
    folder_name = os.path.basename(folder_path.rstrip('/'))
    output_folder = os.path.join("/animals_classify/Cropped_Dataset/QuanKe", folder_name)
    print(f"{folder_name}, cropping.")
    # Start timer for processing this folder
    start_time = time.time()
    
    for filename in os.listdir(folder_path):
        if filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".JPEG"):
            image_path = os.path.join(folder_path, filename)
            process_image(image_path, model, output_folder, box_threshold, text_threshold)
    
    # End timer for processing this folder
    folder_processing_time = time.time() - start_time
    process_images_in_folder.total_time += folder_processing_time
    
    print(f"{folder_name}, cropped. Time taken: {folder_processing_time:.2f} seconds")
    print(f"Total time taken so far: {process_images_in_folder.total_time:.2f} seconds")

# Initialize the total time taken to 0
process_images_in_folder.total_time = 0.0

# Configuration and model loading
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weight/groundingdino_swint_ogc.pth")
TEXT_PROMPT = "canine"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25

FOLDERS_PATH = "/animals_classify/Raw_Dataset/QuanKe"
for FOLDER_Name in os.listdir(FOLDERS_PATH):
	FOLDER_PATH = os.path.join(FOLDERS_PATH, FOLDER_Name)
	# Process all images in the folder
	process_images_in_folder(FOLDER_PATH, model, BOX_THRESHOLD, TEXT_THRESHOLD)

裁剪示例:
原图:
在这里插入图片描述

结果:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 07:16:03       101 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 07:16:03       108 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 07:16:03       91 阅读
  4. Python语言-面向对象

    2024-07-11 07:16:03       98 阅读

热门阅读

  1. Python编程实例-处理Linux/UNIX系统中的信号

    2024-07-11 07:16:03       31 阅读
  2. 构造函数语意学(The Semantics of Constructors)

    2024-07-11 07:16:03       28 阅读
  3. PostgreSQL关闭数据库服务的三种模式

    2024-07-11 07:16:03       27 阅读
  4. 聚类方法K-means和DBSCAN,附matlab代码

    2024-07-11 07:16:03       25 阅读
  5. mysql默认开启索引下推,减少回表的数据

    2024-07-11 07:16:03       25 阅读
  6. Spring Boot项目Jar包加密详解

    2024-07-11 07:16:03       30 阅读
  7. 云端足迹:在iCloud中同步您的地图标记和路线

    2024-07-11 07:16:03       28 阅读
  8. Spring Boot(八十):Tesseract实现图片文字自动识别

    2024-07-11 07:16:03       25 阅读
  9. 5-2.模型层

    2024-07-11 07:16:03       20 阅读
  10. 一键安装ros及出现问题的解决方案

    2024-07-11 07:16:03       27 阅读
  11. [PaddlePaddle飞桨] PaddleOCR图像小模型部署

    2024-07-11 07:16:03       28 阅读
  12. 一起来了解深度学习中的“梯度”

    2024-07-11 07:16:03       28 阅读
  13. linux之内存泄漏分析

    2024-07-11 07:16:03       23 阅读
  14. Kotlin Class

    2024-07-11 07:16:03       28 阅读
  15. uniapp vue3微信小程序如何获取dom元素

    2024-07-11 07:16:03       26 阅读
  16. ROI 接口便捷修改

    2024-07-11 07:16:03       22 阅读