RIS 系列 Mask Grounding for Referring Image Segmentation 论文阅读笔记


写在前面

  一篇 Arxiv 上面的新文章,看看清华大佬们的研究。

  • 论文地址:Mask Grounding for Referring Image Segmentation
  • 代码地址:原论文说将会开源,静待佳音~
  • 预计提交于:CVPR 2024
  • Ps:2023 年每周一篇博文阅读笔记,主页 更多干货,欢迎关注呀,期待 6 千粉丝有你的参与呦~

一、Abstract

  Referring Image Segmentation (RIS) 的定义,目前的 SOTA 方法仍然存在像素和词水平上的语言-图像模态鸿沟。主要原因:通常依赖于句子级别的语言特征用于语言-图像对齐;缺乏对细粒度视觉定位的监督。另外,由于弱的视觉和语言特征间的关联,因此需要更有效的推理去理解那些包含多个目标的复杂场景。于是本文引入 Mask Grounding 辅助任务来提升视觉定位的性能,Mask Grounding 直接适用于之前的模型。此外,为全面解决模态鸿沟,设计了一种跨模态对齐损失和一种辅助对齐模块。在 MagNet(Mask-grounded Network) 上达到了 SOTA 的效果。

二、引言

  首先指出图像分割任务与 Referring Image Segmentation 的区别,RIS 的应用。然后就是 RIS 的挑战:如何减少语言和图像特征间的模态鸿沟?需要一个有效性的对齐方法。

在这里插入图片描述
  如上图所示,之前的方法主要关注于设计不同的损失函数或者引入网络结构/模块来促进对齐,然而缺陷有俩:往往依赖于句子级别的语言特征对齐;缺乏对细粒度视觉定位的显著训练监督。于是难以处理复杂的目标间关系或者包含很少或混乱上下文的子句。如下图所示:

在这里插入图片描述
  于是文本引入一种 Mask Grounding 的辅助任务用于显式地教会模型进行细粒度的对齐。具体来说,在训练过程中,模型随机 mask 掉一些文本词汇,并且让模型来预测这些词汇的实体信息。除了整合文本上下文信息外,还利用了视觉和分割的信息。

  除 Mask Grounding 外,还提出一种跨模态对齐损失和一个对齐模块来全面填补模态鸿沟。整合的模型 MagNet (Mask-grounded Network) 达到了 SOTA 的效果。主要贡献如下:

  • 突出了最近 SOTA 的 RIS 方法的缺陷,指出细粒度视觉定位的缺乏;
  • 引入 Mask Grounding 辅助任务,旨在增强细粒度的视觉定位算法;
  • Mask Grounding + 跨模态对齐损失 + 辅助对齐模块 = MagNet (Mask-grounded Network),实现了新的 SOTA。

三、相关工作

Architecture Design for RIS

  早期的方法遵循拼接-卷积的操作,后续的工作采用 RNN 或动态卷积的方法。还有一些方法设计出语言-图像融合模块。此外,还有一些工作利用已知的语言结构或目标关系来增强融合。随着注意力结构的成功,当前的工作通常采用无向或双向的交叉注意力模块来执行语言-图像融合。有一些工作使用元学习的方法用于 RIS。受到大语言模型的驱动,一些方法将 RIS 视为自回归向量生成问题。接下来是一些举例:VPD、ReLA、DMMI。这些方法的缺陷在于:总是期待语言-图像对齐发生在 mask 预测过程中。于是本文引入一种辅助任务来显式地对齐语言-图像特征。

Loss Design for RIS

  早期用于训练 RIS 的方法通常采用简单的 binary cross entropy 损失, 接下来是对比学习的损失。与之前使用全局池化的语言特征计算损失相比,本文关注在像素-词水平上学习细粒度的目标联系。

Masked Language Modeling

  Masked language modeling (MLM) 在自然语言处理中很普遍。首先引入的是 BERT,之后就成为预训练模型及视觉-语言模型的标配。最近的 MaskedVLM 采用 MLM 执行 mask vision and language 建模,在一个模态的辅助作用下重建另外一个模态。Mask Grounding 与其不同,通过使用外部 mask 信号直接匹配缺失的单词从而确保重建过程,学习到相应的细粒度视觉定位的信息。

四、方法

在这里插入图片描述

4.1 结构

  MagNet (Mask-grounded Network) 由三个模块组成,首先 Mask Grounding 旨在提升细粒度视觉定位性能,使用视觉线索、语言上下文、分割信息来教会模型预测 masked 文本 tokens。之后 Cross-modal Alignment Module (CAM) 用于微调语言和图像特征间的双向交互。最后 Crossmodal Alignment Loss (CAL) 监督像素-像素和像素-文本的对齐。

4.2 Mask Grounding

  如图 3 所示,给定输入图像及其对应的指代表达式和分割 mask。首先利用一些特定词汇来代替指代表达式中的一些 tokens,从而训练模型来预测这些 tokens(类似于 MLM)。具体来说,首先取得 mask 区域的中心坐标,然后通过一个 2 层的 MLP,将其 mask 编码到一个 mask embedding 中。同时采用一个线性层将语言 embedding 投影到与图像 embedding 相同的维度。然后应用提出的 Masked Token 预测器在注意力机制的作用下来处理所有拼接的 embedding 用于 masked token 预测。最后,使用一个 cross-entropy 损失 L grounding \mathcal L_{\text{grounding}} Lgrounding 来比较最终的预测分布与目标分布。数学公式描述如下:令 T , I , M \bold{T},\bold{I},\bold{M} T,I,M 分别表示语言编码器、图像编码器、mask 编码器的输入:
O = Language E n c o d e r ( M a s k ( T ) ) P = I m a g e Encoder ( I ) C = MaskEncoder ( M ) L g r o u n d i n g = L C E ( y g t , Predictor ( C o n c a t ( [ O , P , C ] ) \mathbf{O}=\text{Language}\mathrm{Encoder}(\mathrm{Mask}(\mathbf{T}))\\ \mathbf{P}=\mathrm{Image}\text{Encoder}(\mathbf{I})\\ \mathbf{C}=\text{MaskEncoder}(\mathbf{M})\\ \mathcal{L}_{\mathrm{grounding}}=\mathcal{L}_{\mathrm{CE}}(\mathbf{y}_{\mathrm{gt}},\text{Predictor}(\mathrm{Concat}([\mathbf{O},\mathbf{P},\mathbf{C}]) O=LanguageEncoder(Mask(T))P=ImageEncoder(I)C=MaskEncoder(M)Lgrounding=LCE(ygt,Predictor(Concat([O,P,C])其中预测器为类似 BERT 编码器的结构, M M M 为 GT masks 的中心坐标, y g t \mathbf{y}_{\mathrm{gt}} ygt 为 masked token 的标签, L C E \mathcal{L}_{\mathrm{CE}} LCE 为交叉熵损失。实验中设置 Swin-B 为图像编码器,BERT-base 为语言编码器,但方法不限于此。

讨论

在这里插入图片描述
  如上表所示,Mask Grounding 超越了标准的 masked language modeling (MLM) 和 masked-vision language modeling (MaskedVLM)。原因在于:模型整合:传统的 MLM 为单模态,缺乏了指代表达式及其匹配的视觉目标的联系,而 MaskedVLM 为多模态,Mask Grounding 能够超越的目的在于引入了额外的 masking 信号来对齐 masked words 和匹配的视觉目标。这一结果表明 词-目标联系和细粒度的视觉定位很重要;任务属性:MLM 和 MaksedVLM 作为一般的预训练任务,需要在下游任务上进行微调,而 Mask Grounding 设计于一个 RIS 的辅助任务,在训练阶段增强了细粒度的视觉定位性能,且不需要额外的微调;预测上下文:MLM 和 MaskedVLM 采用文本或文本-视觉上下文预测,而 Mask Grounding 整合了外部的分割信息,于是性能更好。

4.3 跨模态对齐模块

在这里插入图片描述
  如上图所示,提出的跨模态对齐、 cross-modal alignment module (CAM) 将全局上下文先验注入到图像特征中,再进行跨模态融合。CAM 首先采用不同窗口尺寸的池化操作生成 K K K 个不同尺度的特征图构成特征金字塔。然后,每个特征图将会通过一个 3 层的 MLP 用于提取全局特征。之后所有的输出特征将通过双线性插值上采样到原始特征图,然后沿着特征维度拼接。同样采用一个门控单元来调制最后的输出。最终,输出后的特征返回到输入特征上用于下一阶段图像或语言编码器的输入。将语言编码器划分为 4 个阶段,并在每个阶段的末尾添加 CAM 模块。

  用数学公式表示如下:令 T i \bold T_i Ti I i \bold I_i Ii 分别表示语言和图像编码器每个阶段的输入,于是每个阶段有:
O i = LanguageStage ( T i ) , P i = I m a g e Stage ( I i ) P i k = M L P k ( P o o l k ( P i ) ) , p 2 t K , P i , t 2 p k = X − M H A k ( O i , P n k ) O i , p 2 t = C o n c a t ( [ O i , p 2 t i , . . . , O i , p 2 t N ] 2 p = C o n c a t ( [ U p ( P i , t 2 p 1 , . . . , U p ( P i , t 2 p N ) ] O i + 1 = O i + tanh ⁡ ( M L P ( O i , p 2 t ) ) P i + 1 = P i + tanh ⁡ ( M L P ( P i , t 2 p ) ) \begin{gathered} \mathbf{O}_i=\text{LanguageStage}(\mathbf{T}_i),\mathbf{P}_i=\mathrm{Image}\text{Stage}(\mathbf{I}_i)\\ \mathbf{P}_i^k=\mathrm{MLP}_k(\mathrm{Pool}_k(\mathbf{P}_i))\\ {}_{,p2t}^{K},\mathbf{P}_{i,t2p}^{k}=\mathrm{X-MHA}_{k}(\mathbf{O}_{i},\mathbf{P}_{n}^{k}) \\ \mathbf{O}_{i,p2t}=\mathrm{Concat}([\mathbf{O}_{i,p2t}^i,...,\mathbf{O}_{i,p2t}^N] \\ _{2p}=\mathrm{Concat}([\mathrm{Up}(\mathbf{P}_{i,t2p}^{1},...,\mathrm{Up}(\mathbf{P}_{i,t2p}^{N})]\\ \mathbf{O}_{i+1}=\mathbf{O}_i+\tanh(\mathsf{MLP}(\mathbf{O}_{i,p2t}))\\ \mathbf{P}_{i+1}=\mathbf{P}_{i}+\operatorname{tanh}(\mathsf{MLP}(\mathbf{P}_{i,t2p})) \end{gathered} Oi=LanguageStage(Ti),Pi=ImageStage(Ii)Pik=MLPk(Poolk(Pi)),p2tK,Pi,t2pk=XMHAk(Oi,Pnk)Oi,p2t=Concat([Oi,p2ti,...,Oi,p2tN]2p=Concat([Up(Pi,t2p1,...,Up(Pi,t2pN)]Oi+1=Oi+tanh(MLP(Oi,p2t))Pi+1=Pi+tanh(MLP(Pi,t2p))其中 U p Up Up 表示上采样,X-MHA 表示双向跨模态多头注意力。

4.4 跨模态对齐损失

  采用跨模态对齐损失来对齐语言和图像特征,其中 cross-modal alignment loss (CAL) 全面总结了像素-像素级别的损失 L P2P \mathcal L_{\text{P2P}} LP2P 和像素-文本损失 L P2T \mathcal L_{\text{P2T}} LP2T。用数学公式表示如下:给定语言编码器产生的语言特征 T ∈ R M × D \bold T\in\mathbb{R}^{M\times D} TRM×D,包含 ∣ P ∣ |\mathcal P| P 个正样本像素特征的最终像素解码器 mask 特征 I ∈ R C l × H l × W l \bold I\in \mathbb{R}^{C_l\times H_l \times W_l} IRCl×Hl×Wl ∣ N ∣ |\mathcal N| N 个负样本像素特征。 I i + \bold I_i^+ Ii+ 表示正样本集合 P \mathcal P P 中的第 i t h i^{th} ith 个像素特征, I j − I_j^- Ij 表示负样本集合 N \mathbb N N 中的第 j t h j^{th} jth 个像素特征, T k \bold T_k Tk 表示第 k t h k^{th} kth 个语言特征,然后有:
L C A L = L P 2 P + L P 2 T L P 2 P = − 1 ∣ P ∣ ∑ i ∣ P ∣ e I i + ⋅ I a v g + / τ 1 e I i + ⋅ I a v g + / τ 1 + ∑ j ∣ N ∣ e I i + ⋅ I j − / τ 1 + − 1 ∣ N ∣ ∑ j ∣ N ∣ e I j − ⋅ I a v g − / τ 1 e I j − ⋅ I a v g − / τ 1 + ∑ i ∣ P ∣ e I j − ⋅ I i + / τ 1 L P 2 T = − 1 ∣ P ∣ ∑ i ∣ P ∣ e I i + ⋅ T a v g / τ 2 e I i + ⋅ T a v g / τ 2 + ∑ j ∣ N ∣ e I i + ⋅ I j − / τ 2 \begin{aligned} \mathcal{L}_{\mathrm{CAL}}=\mathcal{L}_{\mathrm{P2P}}+\mathcal{L}_{\mathrm{P2T}}\\ \mathcal{L}_{\mathrm{P2P}}=-\frac{1}{|\mathcal{P}|}\sum_i^{|\mathcal{P}|}\frac{e^{\mathbf{I}_i^+ \cdot \mathbf{I}_{\mathbf{avg}}^+/\tau_1}}{e^{\mathbf{I}_i^+\cdot\mathbf{I}_{\mathbf{avg}}^+/\tau_1}+\sum_j^{|\mathcal{N}|}e^{\mathbf{I}_i^+\cdot\mathbf{I}_j^-/\tau_1}} +-\frac{1}{|\mathcal{N}|}\sum_j^{|\mathcal{N}|}\frac{e^{\mathbf{I}_j^-\cdot\mathbf{I}_{\mathrm{avg}}^-/\tau_1}}{e^{\mathbf{I}_j^-\cdot\mathbf{I}_{\mathrm{avg}}^-/\tau_1}+\sum_i^{|\mathcal{P}|}e^{\mathbf{I}_j^-\cdot\mathbf{I}_i^+/\tau_1}}\\ \mathcal{L}_{\mathrm{P2T}} = - \frac 1 { | \mathcal{P}|}\sum_i^{|\mathcal{P}|}\frac{e^{\mathbf{I}_i^+\cdot\mathbf{T}_{\mathrm{avg}}/\tau_2}}{e^{\mathbf{I}_i^+\cdot\mathbf{T}_{\mathrm{avg}}/\tau_2}+\sum_j^{|\mathcal{N}|}e^{\mathbf{I}_i^+\cdot\mathbf{I}_j^-/\tau_2}} \end{aligned} LCAL=LP2P+LP2TLP2P=P1iPeIi+Iavg+/τ1+jNeIi+Ij/τ1eIi+Iavg+/τ1+N1jNeIjIavg/τ1+iPeIjIi+/τ1eIjIavg/τ1LP2T=P1iPeIi+Tavg/τ2+jNeIi+Ij/τ2eIi+Tavg/τ2其中 I a v g + = 1 ∣ P ∣ ∑ i ∣ P ∣ I i + \mathbf{I}_\mathrm{avg}^+=\frac1{|\mathcal{P}|}\sum_i^{|\mathcal{P}|}\mathbf{I}_i^+ Iavg+=P1iPIi+ I a v g − = 1 ∣ N ∣ ∑ j ∣ N ∣ I j − \mathbf{I}_\mathrm{avg}^-=\frac1{|\mathcal{N}|}\sum_j^{|\mathcal{N}|}\mathbf{I}_j^- Iavg=N1jNIj 分别表示平均池化后正样本像素特征和负样本像素特征。 T a v g = proj ⁡ ( 1 M ∑ m M T k ) \mathbf{T}_{\mathrm{avg}}=\operatorname*{proj}(\frac1M\sum_{m}^{M}\mathbf{T}_{k}) Tavg=proj(M1mMTk) 为平均池化和线性投影后的词特征, τ 1 \tau_1 τ1 τ 2 \tau_2 τ2 为超参数。需要注意的是所有的特征在进行点乘前均经过 L2 归一化处理,但并未在上式中体现。

4.5 损失函数

  损失函数为下列 4 个不同损失的加权求和:
L = λ B C E L B C E + λ D i c e L D i c e + λ C A L L C A L + λ g r o u n d i n g L g r o u n d i n g , \begin{gathered} \mathcal{L}=\lambda_{\mathrm{BCE}}\mathcal{L}_{\mathrm{BCE}}+\lambda_{\mathrm{Dice}}\mathcal{L}_{\mathrm{Dice}}+ \lambda_\mathrm{CAL}\mathcal{L}_\mathrm{CAL}+\lambda_\mathrm{grounding}\mathcal{L}_\mathrm{grounding}, \end{gathered} L=λBCELBCE+λDiceLDice+λCALLCAL+λgroundingLgrounding,实验中 λ B C E = 2.0 \lambda_{\mathrm{BCE}}=2.0 λBCE=2.0 λ D i c e = 2.0 \lambda_{\mathrm{Dice}}=2.0 λDice=2.0 λ g r o u n d i n g = 1.0 \lambda_\mathrm{grounding}=1.0 λgrounding=1.0

五、实验

5.1 数据集及评估指标

  • 数据集:RefCOCO、RefCOCO+、GRef
  • 评估指标:overall intersection-over-union (oIoU)、mean intersection-overunion (mIoU)

5.2 主要结果

在这里插入图片描述
在这里插入图片描述

5.3 可视化

在这里插入图片描述

5.4 消融研究

  训练 10 个 epoch,输入图像尺寸 224 × 224 224\times 224 224×224。所有的消融实验执行在 RefCOCO 和 RefCOCO+ 数据集上。

  • RIS 性能的影响;
  • Mask 编码器的设计;
  • Mask Token 预测器的设计;
  • Mask Grounding 的统一效果。
  • CAM 的有效性
  • CAL 的有效性

在这里插入图片描述

语言-图像对齐的影响

在这里插入图片描述

MagNet 组件的兼容性

在这里插入图片描述

六、结论

  本文提出 Mask Grounding,基于周围的文本、视觉和分割信息,通过教导模型预测随机 mask 掉的文本 tokens,实验效果很好。为全面解决模态鸿沟,设计了一种跨模态对齐损失和一种辅助对齐模块。当一齐作用时,提出的 MagNet 实现了 SOTA 的性能。

写在后面

  这篇论文咋说呢,感觉就是那种顶会的边缘,创新点属于可拒可不拒的那种。最大的败笔还是论文的写作确实不咋地,没有一种连贯之感。另外,实验缺少了实验细节的介绍,放在补充材料中吗?

相关推荐

  1. GPT系列 论文阅读笔记

    2023-12-25 06:12:04       17 阅读
  2. 论文阅读笔记】清单

    2023-12-25 06:12:04       51 阅读
  3. PointMixer论文阅读笔记

    2023-12-25 06:12:04       31 阅读
  4. BERT 论文阅读笔记

    2023-12-25 06:12:04       30 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-25 06:12:04       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-25 06:12:04       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-25 06:12:04       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-25 06:12:04       18 阅读

热门阅读

  1. Spring DefaultListableBeanFactory源码分析

    2023-12-25 06:12:04       41 阅读
  2. Python jupyter notebook 自定义魔术方法

    2023-12-25 06:12:04       29 阅读
  3. conda镜像源,Jupyter内核配置

    2023-12-25 06:12:04       37 阅读
  4. EtherCAT主站SOEM -- 11 -- EtherCAT从站 XML 文件解析

    2023-12-25 06:12:04       35 阅读
  5. 【PostgreSQL表增加/删除字段是否会重写表】

    2023-12-25 06:12:04       34 阅读
  6. C#编程简单应用程序批量修改文件名2.0

    2023-12-25 06:12:04       42 阅读
  7. Node.js教程-mysql模块

    2023-12-25 06:12:04       36 阅读
  8. SQL面试题挑战06:互相关注的人

    2023-12-25 06:12:04       32 阅读
  9. 客户需求分析常用的ChatGPT通用提示词模板

    2023-12-25 06:12:04       38 阅读