论文阅读:Forget-Me-Not: Learning to Forget in Text-to-Image Diffusion Models

Forget-Me-Not: Learning to Forget in Text-to-Image Diffusion Models

论文链接
代码链接
这篇文章提出了Forget-Me-Not (FMN),用来消除文生图扩散模型中的特定内容。FMN的流程图如下:
framework
可以看到,FMN的损失函数是最小化要消除的概念对应的attention map的 L 2 L_2 L2范数。这里需要补充一些关于diffusion model的知识。
首先,以Stable Diffusion为代表的模型使用U-Net对图片的低维嵌入进行建模。文本条件在被CLIP的text encoder编码为文本嵌入后,通过U-Net中的cross-attention layers输入到U-Net中。cross-attention层的具体映射过程是一个QKV (Query-Key-
Value)结构,如上图的中间所示。其中,Q代表图片的视觉信息,K和V都是文本嵌入经过线性层后计算得到的( k i = W k c i   a n d   v i = W v c i k_i = W_kc_i~and~v_i = W_vc_i ki=Wkci and vi=Wvci)。而FMN损失函数中的attention map的计算过程如下:
attention map
然而,attention map还不是cross attention层的输出,其输出通过以下公式计算:
cross-attention output
上面两个公式,也就是图3中间方框中的内容,可以用下面的公式概括,
cross-attention
从FMN的源码中可以看到对应的部分如下:

class AttnController:
        def __init__(self) -> None:
            self.attn_probs = []
            self.logs = []
        def __call__(self, attn_prob, m_name) -> Any:
            bs, _ = self.concept_positions.shape
            head_num = attn_prob.shape[0] // bs
            target_attns = attn_prob.masked_select(self.concept_positions[:,None,:].repeat(head_num, 1, 1)).reshape(-1, self.concept_positions[0].sum())
            self.attn_probs.append(target_attns)
            self.logs.append(m_name)
        def set_concept_positions(self, concept_positions):
            self.concept_positions = concept_positions
        def loss(self):
            return torch.cat(self.attn_probs).norm()
        def zero_attn_probs(self):
            self.attn_probs = []
            self.logs = []
            self.concept_positions = None

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-25 03:44:02       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-25 03:44:02       101 阅读
  3. 在Django里面运行非项目文件

    2024-03-25 03:44:02       82 阅读
  4. Python语言-面向对象

    2024-03-25 03:44:02       91 阅读

热门阅读

  1. 力扣hot100题解(python版74-80题)

    2024-03-25 03:44:02       44 阅读
  2. vue的history路由实现形式

    2024-03-25 03:44:02       40 阅读
  3. 关于Al大规模学习

    2024-03-25 03:44:02       40 阅读
  4. MYSQL远程登录权限设置

    2024-03-25 03:44:02       40 阅读
  5. 【CSP试题回顾】202209-1-如此编码(优化)

    2024-03-25 03:44:02       43 阅读
  6. python笔记基础--类(6)

    2024-03-25 03:44:02       36 阅读
  7. Git 的 cherry-pick

    2024-03-25 03:44:02       30 阅读
  8. LeetCode热题Hot100-两数相加

    2024-03-25 03:44:02       42 阅读
  9. LeetCode第二天(628. 三个数的最大乘积)

    2024-03-25 03:44:02       44 阅读
  10. 设计模式之观察者模式

    2024-03-25 03:44:02       47 阅读
  11. C++异常处理

    2024-03-25 03:44:02       44 阅读
  12. c++统计字符出现次数

    2024-03-25 03:44:02       39 阅读
  13. 字母在字符串中的百分比

    2024-03-25 03:44:02       40 阅读