从DETR到Mask2Former(1):DETR-segmentation结构全解析

网上关于DETR做的detection的解析很多,但是DETR做Segmentation的几乎没有,本文结合DETR的论文与代码,对DETR做一个详细的拆解。理解DETR是理解Mask2Former的基础。

首先得把DETR-segmentation给run起来。Github上DETR的repository,下载了也只能run起来detection,run不起来segmentation功能,但还是下载下来,后面留着有用。我们用torch的hub里集成的DETR segmentation模型,运行下面的代码

import torch
models_list = torch.hub.list('facebookresearch/detr', force_reload=True)
print(models_list)

你可以看到torch.hub中所有关于detr的模型。我们选择  detr_resnet50_panoptic

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', pretrained=True)

再新建一个py文件,把以下代码放进去:

import math
from PIL import Image
import requests
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)

import matplotlib.pyplot as plt


# COCO classes
CLASSES = [
	'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
	'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
	'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
	'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
	'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
	'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
	'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
	'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
	'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
	'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
	'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
	'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
	'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
	'toothbrush'
]

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
		  [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

transform = T.Compose([
	T.Resize(800),
	T.ToTensor(),
	T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
	x_c, y_c, w, h = x.unbind(1)
	b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
		 (x_c + 0.5 * w), (y_c + 0.5 * h)]
	return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
	img_w, img_h = size
	b = box_cxcywh_to_xyxy(out_bbox)
	b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
	return b

def plot_results(pil_img, prob, boxes):
	plt.figure(figsize=(16,10))
	plt.imshow(pil_img)
	ax = plt.gca()
	colors = COLORS * 100
	for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
		ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
								   fill=False, color=c, linewidth=3))
		# cl = p.argmax()
		# text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
		# ax.text(xmin, ymin, text, fontsize=15,
		#         bbox=dict(facecolor='yellow', alpha=0.5))
	plt.axis('off')
	plt.show()


model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', pretrained=True)
model.eval()

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)

# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)

# propagate through the model
outputs = model(img)

# show result of detection
# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9
# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
plot_results(im, probas[keep], bboxes_scaled)


# show segmentation
# compute the scores, excluding the "no-object" class (the last one)
scores = outputs['pred_logits']
scores = scores.softmax(dim=-1)
scores = scores[..., :-1]
scores_onehot = scores.max(-1)
scores_onehot = scores_onehot[0]
# threshold the confidence
keep = scores_onehot > 0.85

scores_selected = scores[keep]
labels = torch.argmax(scores_selected, dim=-1)

masks = outputs['pred_masks'][keep].detach().cpu().numpy()
mask_i = masks[0, :, :]

# plt.imshow(mask_i, cmap="viridis")
# plt.show()

# ## Plot all the remaining masks
ncols = 5
fig, axs = plt.subplots(ncols=ncols, nrows=math.ceil(keep.sum().item() / ncols), figsize=(18, 10))
for line in axs:
	for a in line:
		a.axis('off')
for i, mask in enumerate(outputs['pred_masks'][keep].detach().cpu().numpy()):
	ax = axs[i // ncols, i % ncols]
	ax.imshow(mask, cmap="cividis")
	ax.text(0, 0, labels[i].cpu().numpy().item(0))
	ax.axis('off')
fig.tight_layout()

plt.show()

mask_pred = outputs['pred_masks'].sigmoid()
mask_pred = F.interpolate(mask_pred, size=(480, 640), mode='bilinear', align_corners=False)

semseg = torch.einsum("bqc,bqhw->bchw", scores, mask_pred)
result = torch.argmax(semseg, dim=1)


plt.figure(figsize=(12, 8))
# 第一个子图
plt.subplot(1, 1, 1)
# image1_rgb = cv2.cvtColor(img_color, cv2.COLOR_BGR2RGB)
plt.imshow(result[0, :, :].to("cpu").numpy())
plt.title('Image 1')
plt.axis('off')

plt.show()

解释一下上面的代码,

plot_results(im, probas[keep], bboxes_scaled)

是把detection的结果可视化出来,本文略过的detection部分的内容。  plot_results()  下面的代码就是segmentation内容结果的可视化。模型的输出包括三个内容

pred_logits(1, 100, 251)中三个数字的含义是:1是batchsize,100是query的个数,251是分类数,去掉最后一个no-object不要,实际是250个类。可以这样理解,每一个query都会进入网络,但不是每一个query都能从图片找到东西,而找到东西的query,找到的instance所属于的类就是这个query所在的那一行中,最大的数所在列的index。如果一个图片里有两只猫,那结果就是有两个query,分别各自找到一只猫(也就是一个instance),这样也就实现了全景分割的功能。所以说DETR的结构是Mask2Former实现全景分割功能的基础。

pred_boxes是query所找到的instance对应的box。

pred_masks是query所找到的instance所在的像素。

这段代码:

# ## Plot all the remaining masks
ncols = 5
fig, axs = plt.subplots(ncols=ncols, nrows=math.ceil(keep.sum().item() / ncols), figsize=(18, 10))
for line in axs:
	for a in line:
		a.axis('off')
for i, mask in enumerate(outputs['pred_masks'][keep].detach().cpu().numpy()):
	ax = axs[i // ncols, i % ncols]
	ax.imshow(mask, cmap="cividis")
	ax.text(0, 0, labels[i].cpu().numpy().item(0))
	ax.axis('off')
fig.tight_layout()

plt.show()

会画出这样的图像:

有94个query没有找到instance,有6个query找到了instance,2个query找到了猫,也就是250列的17列,两个query找到了遥控器(250列的74列)。

好了,现在DETR-seg的输出我们弄清楚了,接下来进到DETR内部去看看,这个模型封装在了torch.hub中,进入内部的正确方法,说实话笔者也不知道,这里笔者耍了一个不正经的小trick:

import torch

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', pretrained=True)

from models.segmentation import DETRsegm

model = DETRsegm(model)

上面的代码中,models.segmentation文件是DETR的github repository下载下来,里面有的内容,运行上面的代码,会报错

但是通过这个报错,我们可以找到torch.hub中DETR的源代码。这样就可以在源代码里打断点,看DETR 内部了。进入源代码,我们能看到这样的内容,注释的代码是笔者加进去的,不是torch自带的。

self.detr.backbone是resnet18。 这里,我们打开DETR的github代码——detr/models/backbone.py 找到  class BackboneBase 这个类。

这里的 def forward 中的out,就是 DETRsegm 中的 features

features, pos = self.detr.backbone(samples)

pos是正弦函数的 position embeding,就是输入transformer的encoder的位置编码。

在获取features时,DETR detection和Segmentation的区别在与,detection只拿resnet第四层的输出,而segmentation将每层的输出都拿出来。通过这段代码实现(缩进乱了,忽略)。

if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

这里扯远一点,复习一下resnet18  resnet18在torch中也有集成,通过下面的代码得到,

import torch
import torchvision.models as models
from torchvision.models._utils import IntermediateLayerGetter
import cv2
import numpy as np


# 加载一个预训练的 ResNet 模型
model = models.resnet18(pretrained=True)
model.to("cuda")

# 定义要获取的中间层
layers = {'layer1': 'layer1', 'layer2': 'layer2', 'layer3': 'layer3', 'layer4': 'layer4'}

# 创建 IntermediateLayerGetter
intermediate_layers = IntermediateLayerGetter(model, layers)

image0 = cv2.imread("/home/robotics/dino/img/kitaku/002.jpg")
height0, width0, channels = image0.shape
image = cv2.resize(image0, (640, 480), interpolation=cv2.INTER_AREA)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

mean = np.array([0.485, 0.456, 0.406]) * 255
std = np.array([0.229, 0.224, 0.225]) * 255
image = image.astype(float)
for i in range(3):
	image[:, :, i] = (image[:, :, i] - mean[i]) / std[i]

input_data = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(dim=0).to(torch.float32).to("cuda")

# 获取中间层的输出
outputs = intermediate_layers(input_data)

# 打印输出
for key, value in outputs.items():
	print(f"{key}: {value.shape}")


from torchview import draw_graph
model_graph = draw_graph(model, input_size=(1, 3, 480, 640))
model_graph.resize_graph(scale=5.0)
model_graph.visual_graph.render(format='svg')

通过draw_graph得到一个网络结构可视化图,格式是svg,如下图所示

也可以自己写代码,构建一个resnet18

import torch
import torch.nn as nn
from torch.nn import functional as F


class RestNetBasicBlock(nn.Module):
	def __init__(self, in_channels, out_channels, stride):
		super(RestNetBasicBlock, self).__init__()
		self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
		self.bn1 = nn.BatchNorm2d(out_channels)
		self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
		self.bn2 = nn.BatchNorm2d(out_channels)

	def forward(self, x):
		output = self.conv1(x)
		output = F.relu(self.bn1(output))
		output = self.conv2(output)
		output = self.bn2(output)
		return F.relu(x + output)


class RestNetDownBlock(nn.Module):
	def __init__(self, in_channels, out_channels, stride):
		super(RestNetDownBlock, self).__init__()
		self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride[0], padding=1)
		self.bn1 = nn.BatchNorm2d(out_channels)
		self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=1)
		self.bn2 = nn.BatchNorm2d(out_channels)

	def forward(self, x):
		output = self.conv1(x)
		out = F.relu(self.bn1(output))

		out = self.conv2(out)
		out = self.bn2(out)
		return F.relu(x + out)


class RestNet18(nn.Module):
	def __init__(self):
		super(RestNet18, self).__init__()
		self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
		self.bn1 = nn.BatchNorm2d(64)
		self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

		self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
									RestNetBasicBlock(64, 64, 1))

		self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
									RestNetBasicBlock(128, 128, 1))

		self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
									RestNetBasicBlock(256, 256, 1))

		self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2, 1]),
									RestNetBasicBlock(512, 512, 1))

		self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

		self.fc = nn.Linear(512, 10)

	def forward(self, x):
		out = self.conv1(x)
		out = self.layer1(out)
		out = self.layer2(out)
		out = self.layer3(out)
		out = self.layer4(out)
		out = self.avgpool(out)
		out = out.reshape(x.shape[0], -1)
		out = self.fc(out)
		return out


if __name__ == "__main__":
	resnet18 = RestNet18()
	
	from torchview import draw_graph
	model_graph = draw_graph(resnet18, input_size=(1, 3, 480, 640))
	model_graph.resize_graph(scale=5.0)
	model_graph.visual_graph.render(format='svg')

好,回到DETR,还是看torch.hub中的这段detr代码:

 上面的代码中:

hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

self.detr.transformer就是下图经典的transformer结构,memory是encoder的输出  hs是decoder的输出。mask对于理解detr不重要,是masked attention中用到了。

pos和query_embed.weight分别对应上图的Positional Encoding和decoder中,output上面的OutputEmbeding。

DETR中transformer的理解,可以参看下面的代码,这段代码是笔者自己写的,对于DETR本身是没有用处的,只是为了方便理解,有错误的地方(feature那里只拿出resnet18最后一层的输出,是错的,不影响大局,懒得改了),自己理解时候做一个参考就好。

import torch
from torch import nn
from torchvision.models import resnet50
from models.transformer import TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder
from models.segmentation import MHAttentionMap, MaskHeadSmallConv
import time


class DETR(nn.Module):
    def __init__(self, num_classes, d_model, nheads, dim_feedforward=2048,
                 num_encoder_layers=6, num_decoder_layers=6, dropout=0.9, activation='relu', normalize_before=True):
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(dim_feedforward, d_model, 1)
        
        encoder_layer = TransformerEncoderLayer(d_model, nheads, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nheads, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=True)

        self.linear_class = nn.Linear(d_model, num_classes + 1)
        self.linear_bbox = nn.Linear(d_model, 4)
        
        num_queries = 100
        self.query_embed = nn.Embedding(num_queries, d_model)
        
        #position embedding
        self.row_embed = nn.Parameter(torch.rand(50, d_model // 2))
        self.col_embed = nn.Parameter(torch.rand(50, d_model // 2))
        
        self.bbox_attention = MHAttentionMap(d_model, d_model, nheads, dropout=0.0)
        self.mask_head = MaskHeadSmallConv(d_model + nheads, [1024, 512, 256], d_model)
 
    def forward(self, inputs):
        #inputs是[1,3,800,1200]
        features = self.backbone(inputs)
        #x是[1,2048,25,38]
        hh = self.conv(features)
        #hh是[1,256,25,38]
        H, W = hh.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        
        bs, c, h, w = hh.shape
        
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        tgt = torch.zeros_like(query_embed)
        src = hh.flatten(2).permute(2, 0, 1)
        memory = self.encoder(src=src, pos=pos)
        hs = self.decoder(tgt, memory, pos=pos, query_pos=query_embed)
        
        hs = hs.transpose(1, 2)
        memory = memory.permute(1, 2, 0).view(bs, c, h, w)
        bbox_mask = self.bbox_attention(hs[-1], memory)
        
        seg_masks = self.mask_head(hh, bbox_mask, [features[0], features[0], features[0]])
        outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
        
        #self.query_pos是[100,256]
        #src是encoder输入,tgt是decoder输入
        

        #h是[100,1,256]
        return outputs_seg_masks
 
 
#coco是91个类, hidden dimension是256, 多头注意力是8, encoder,decoder layer都是6
device = torch.device("cuda")
detr = DETR(num_classes=91, d_model=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval().cuda()


inputs = torch.randn(1, 3, 800, 1200).cuda()

outputs_seg_masks = detr(inputs)
# print(logits, bboxes)
#logits是[100,1,92]
#bboxes是[100,1,4]


还是回到torch.hub那个detr代码,下面结合DETR论文中的那张图来理解一下这段代码

图中multi head attention部分,对应代码中

bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

hs[-1]是左侧四个不同颜色的小框,memory是Encoded image,mask可以忽略,bbox_mask的尺寸是:(batch_size,query的个数,注意力头的个数,attention map的高和宽)

        temp = bbox_mask[0][20].squeeze().cpu().numpy()
        import matplotlib.pyplot as plt
        for i in range(8):
            plt.imshow(temp[i, :, :], cmap="viridis")
            plt.show()

上面那段代码就是拿出第1个batchsize的第21个query的所有注意力头的attention map,也就是可以画出这几张图:

 

seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])

mask_head对应图中这个结构:

输入是多头注意力map,以及resnet18这个backbone的4层的特征。

本文先到这,下一篇写DETR的损失函数。

未完待续

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-01-13 15:02:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-13 15:02:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-13 15:02:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-13 15:02:02       18 阅读

热门阅读

  1. Mybatis 37_使用隐式参数名处理多个参数

    2024-01-13 15:02:02       39 阅读
  2. 问题解决记录-pypcd

    2024-01-13 15:02:02       34 阅读
  3. What is `response.isCommitted()` does?

    2024-01-13 15:02:02       36 阅读
  4. 【基础数据结构】栈和队列

    2024-01-13 15:02:02       34 阅读
  5. uniapp搜索附近蓝牙信标(iBeacon)

    2024-01-13 15:02:02       37 阅读
  6. SQL常用时间处理函数总结

    2024-01-13 15:02:02       37 阅读
  7. RNN和LSTM的区别是什么?

    2024-01-13 15:02:02       33 阅读
  8. 利用PowerShell和Mkvtoolnix批量去除视频封面

    2024-01-13 15:02:02       41 阅读
  9. 深入理解虚拟DOM:原理、优势与实践

    2024-01-13 15:02:02       28 阅读
  10. 一文读懂Qt信号与槽的机制

    2024-01-13 15:02:02       31 阅读