🚩🚩🚩Transformer实战-系列教程总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)
3、ConvertCocoPolysToMask类
位置:datasets/coco.py/ConvertCocoPolysToMask类
ConvertCocoPolysToMask类主要是进行数据预处理,主要在CocoDetection类中被调用
从的ConvertCocoPolysToMask
类的代码来看,主要涉及到以下几种计算机视觉任务的数据预处理步骤:
- 物体检测(Object Detection):
- 体现:通过处理
bbox
(边界框)信息。代码中提取和调整bbox
坐标来适应物体检测任务的需求。
- 体现:通过处理
- 实例分割(Instance Segmentation):
- 体现:如果
return_masks
为True,将COCO多边形标注(segmentation
)转换为掩码(mask
)。这对于实例分割任务来说是必要的,因为它需要精确地区分图像中各个对象的形状。
- 体现:如果
- 姿态估计(Pose Estimation):
- 体现:通过处理
keypoints
信息。当标注中包含关键点数据时,代码会提取这些数据,这些数据对于识别和估计图像中人物的姿态非常有用。
- 体现:通过处理
class ConvertCocoPolysToMask(object):
def __init__(self, return_masks=False):
self.return_masks = return_masks
def __call__(self, image, target):
w, h = image.size
image_id = target["image_id"]
image_id = torch.tensor([image_id])
anno = target["annotations"]
anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
boxes = [obj["bbox"] for obj in anno] # x y w h
# guard against no boxes via resizing
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)
classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)
if self.return_masks:
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_to_mask(segmentations, h, w)
keypoints = None
if anno and "keypoints" in anno[0]:
keypoints = [obj["keypoints"] for obj in anno]
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
num_keypoints = keypoints.shape[0]
if num_keypoints:
keypoints = keypoints.view(num_keypoints, -1, 3)
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
if self.return_masks:
masks = masks[keep]
if keypoints is not None:
keypoints = keypoints[keep]
target = {
}
target["boxes"] = boxes
target["labels"] = classes
if self.return_masks:
target["masks"] = masks
target["image_id"] = image_id
if keypoints is not None:
target["keypoints"] = keypoints
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
target["area"] = area[keep]
target["iscrowd"] = iscrowd[keep]
target["orig_size"] = torch.as_tensor([int(h), int(w)])
target["size"] = torch.as_tensor([int(h), int(w)])
return image, target
- 定义ConvertCocoPolysToMask类,用于处理COCO数据集的转换
- 类的初始化方法,参数return_masks用于控制是否返回标注的掩码信息
- return_masks
- 可调用方法,接收两个参数:image(PIL图像对象)和target(包含图像标注信息的字典)
- 图像w、h, 427 ∗ 640 427*640 427∗640,每张图片读进来的长宽都可能不一样
- 获取图像id,image id: 538686]
- 将id转化为Tensor,image id: tensor([538686])
- 获取标签的标注信息,包含面积、bbox框的长宽xy四个值、类别id、图像id、分割的标注信息
- 过滤标注信息,过滤掉有重叠框的,只保留对单个物体的框,包含重叠物体的不要,如果iscrowd为1,表示这个标注包含的是一个对象群,而不是单个对象
- 获取所有框,[[62.37, 135.48, 184.94, 364.52],…, [107.99, 46.17, 101.51, 157.66]]
- 框的数据转化为Tensor
- 将x、y、w、h
- 转化为
- x1、y1、x2、y2,tensor([[ 62.3700, 135.4800, 247.3100, 500.0000],…, [107.9900, 46.1700, 209.5000, 203.8300]])
- 获取当前图像的所有类别标签(可能对应有多个类别),[19, 21, 21, 1]
- 转化为Tensor,tensor([19, 21, 21, 1])
- 是否进行掩码转换
- 提取分割信息
- 调用函数将分割信息转化为掩码
- 初始化 keypoints (姿态估计任务使用)变量
- 判断标注信息中是否包含 keypoints 信息
- 提取所有标注的 keypoints 信息
- 将 keypoints 列表转换为PyTorch张量
- 获取 keypoints 的数量
- 判断是否存在 keypoints
- 重塑 keypoints 张量
- 过滤掉不合逻辑的边界框(即右下角坐标不大于左上角坐标的边界框),因为在标注数据的时候,外包人员如果没有按照标注说明去标,拉框不是从上面往下框住物体,而是从下往上,这会影响两个点的顺序判断
- 使用keep数组过滤边界框,保留有效的边界框
- 同样使用keep数组过滤类别ID,保留与有效边界框对应的类别ID
- 判断是否有掩码
- 使用keep数组过滤掩码,保留与有效边界框对应的掩码
- 如果 keypoints 信息存在
- 使用keep数组过滤 keypoints 信息
- 初始化一个新的字典target,用于存储处理后的标注信息
- 将过滤后的边界框信息添加到target字典
- 将过滤后的类别ID添加到target字典
- 判断是否有掩码
- 将过滤后的掩码添加到target字典
- 将图像ID添加到target字典
- 如果存在关键点信息
- 将过滤后的关键点信息添加到target字典
- 提取所有标注的面积信息,并转换为PyTorch张量
- 将过滤后的iscrowd信息添加到target字典,如果iscrowd为1,表示这个标注包含的是一个对象群,而不是单个对象
- 分别将原始图像的高度和宽度作为orig_size和size添加到target字典。这两个字段通常用于后续的处理或数据恢复步骤
- 返回处理后的图像和更新后的target字典
DETR 算法解读
DETR 源码解读1(项目配置/CocoDetection类)
DETR 源码解读2(ConvertCocoPolysToMask类)
DETR 源码解读3(DETR类)
DETR 源码解读4(Joiner类/PositionEmbeddingSine类/位置编码/backbone)