DiT结构原理代码详解

1. 背景知识

1.1 Classifier Guidance 和 Classifier-free Guidance

2. 算法详解

1. 模块化

2. DiT模块

2.1 上下文条件(In-context conditioning)

2.2 交叉注意力块(Cross-Attention)

2.3 自适应层归一化块(Adaptive Layer Normalization,AdaLN)

2.4 adaZero-Block

3. 总结

前言

在Sora[1]的技术报告中,作者指出Sora是一个Diffusion Transformer。这个Diffusion Transformer便是我们这里将要介绍的DiT[2]。相较于我们之前介绍的LDM[3],DiTs也是作用在潜空间,它最大的改进是将U-Net的CNN替换为了Transformer。同时DiT是一个可扩展的架构,而且样本质量和网络复杂度存在这强烈的相关性。

与DiTs最密切的算法是LDM,LDM最大的特点是将DDPM[4]的计算空间从图像空间改到了潜空间。而图像空间和潜空间的互相转换则通过VQ-VAE[5]的编码器和解码器。LDM采用了一个CNN和交叉注意力的混合结构,其中CNN用于对图像进行编码,交叉注意力用于将条件特征融入到模型中。而DiT则是将LDM的CNN完全替换为了Transformer。

1. 背景知识

1.1 Classifier Guidance 和 Classifier-free Guidance

Classifier Guidance是OpenAI在《Diffusion models beat gans on image synthesis》[6]中提出的思想,它使得扩散模型可以按照指定的类生成图像。Classifier Guidance可以通过Score function来解释,我们可以使用贝叶斯定理对条件生成概率进行分解,如式(1)。从中可以看出Classifier Guidance的条件生成只需要添加一个额外的Classifier梯度即可。(1)∇��log⁡�(��∣�)=∇��log⁡(�(��)�(�∣��)�(�))=∇��log⁡�(��)+∇��log⁡�(�∣��)−∇��log⁡�(�)=∇��log⁡�(��)⏟unconditional score +∇��log⁡�(�∣��)⏟classifier gradient 我们可以添加一个权重项 � 来调整来灵活的控制unconditional score和classifier gradient的占比,如式(2)。

(2)∇��log⁡�(��∣�)=∇��log⁡�(�∣��)+�∇��log⁡�(��)

从式(1)中我们也可以看出Classifer Guidance的几个问题,首先因为需要训练Classifier梯度项,这相当于要额外训练一个根据噪声得到类别标签的分类器,显然是一个非常困难的任务。此外这个分类器的结果反映到了生成梯度上,无疑会对生成效果产生一定程度的影响。

为了解决这个问题,Google提出了Classifier-free Guidance方案[7]。Classifier-free guidance的核心是通过一个隐式分类器来代替显式分类器,使得生成过程不再依赖这个显式的分类器,从而解决了Classifier Guidance的这几个问题。具体来讲,我们对式(1)进行移项,可得:(3)∇��log⁡�(�∣��)=∇��log⁡�(��∣�)−∇��log⁡�(��)将式(3)代入到式(2)中,我们有(4)∇��log⁡�(��∣�)=∇��log⁡�(��)+�(∇��log⁡�(��∣�)−∇log⁡�(��))=∇��log⁡�(��)+�∇log⁡�(��∣�)−�∇��log⁡�(��)=�∇��log⁡�(��∣�)⏟conditional score +(1−�)∇��log⁡�(��)⏟unconditional score 根据式(4),我们的分类器由conditional score和unconditional score两部分组成。在训练时,我们可以通过一个对标签的Dropout来将标签以一定概率置空,从而实现了两个score在同一个模型中的训练。

2. 算法详解

和LDM一样,DiT也是一个作用在潜空间上的模型,因此它也采用了一个VQ-VAE将图像编码到潜空间。这里我们主要介绍DiT在潜空间上的扩散过程做的改进,它的结构如图1所示,跟论文一样,我们也是按照DiTs的前向顺序介绍这个图。DiT的具体实现见FAIR的开源代码[8],下面我们结合代码来具体介绍它们。

图1:DiTs的网络结构

首先,我们看一下DiT的forward函数的实现。

class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
      super().__init__()
      self.learn_sigma = learn_sigma
      self.in_channels = in_channels
      self.out_channels = in_channels * 2 if learn_sigma else in_channels
      self.patch_size = patch_size
      self.num_heads = num_heads

      self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
      self.t_embedder = TimestepEmbedder(hidden_size)
      self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
      num_patches = self.x_embedder.num_patches
      # Will use fixed sin-cos embedding:
      self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

      self.blocks = nn.ModuleList([
          DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
      ])
      self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
      self.initialize_weights()
      
def forward(self, x, t, y):
  """
  Forward pass of DiT.
  x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
  t: (N,) tensor of diffusion timesteps
  y: (N,) tensor of class labels
  """
  x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
  t = self.t_embedder(t)                   # (N, D)
  y = self.y_embedder(y, self.training)    # (N, D)
  c = t + y                                # (N, D)
  for block in self.blocks:
      x = block(x, c)                      # (N, T, D)
  x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
  x = self.unpatchify(x)                   # (N, out_channels, H, W)
  return x
  • 第45行的x_embedder是import自timm.models.vision_transformer,DiT中将其叫做模块化(Patchify);
  • 第46行是对扩散模型的时间片t进行编码,使用的是Transformer中介绍的不可学习的绝对位置编码;
  • 第47行是计算标签特征,使用了Classifier-free guidance的思想,具体细节我会在后面进行介绍;
  • 第48行是对条件特征和时间片特征进行合并;
  • 第49-51行是对特征进行加工,使用了DiTBlock类和FinalLayer类,接下来我也回详细介绍。
  • 最后第52行的unpatchify是将一维序列还原为二维潜空间。

1. 模块化

模块化(Patchify)的作用是将VAE编码的二维特征转化为一维序列。这里有两个细节:

  • 因为DiT去掉了CNN,因此需要添加位置编码,DiT采用的是ViT中使用的同样是不可学习的绝对位置编码(sin/cos);
  • p是一个可调的超参数,表示每个patch的大小,通过调整 � 我们可以控制序列 � 的长度。

图2:DiT的模块化部分

模块化的实现继承自timm.models.vision_transformerPatchEmbed函数。

 from timm.models.vision_transformer import PatchEmbed

2. DiT模块

DiT模块有两个作用,一个是对特征进行加工,另一个是融合图像的特征和不同模态的条件特征。DiT中探索了四个不同的模块:

2.1 上下文条件(In-context conditioning)

如图1.(d)所示,基于上下文条件的DiT直接将条件特征附加到输入序列中,这个操作类似于在输入序列中添加了一个[CLS] token。DiT的条件编码是通过LabelEmbedder类实现的,具体实现见下面代码片段。

class LabelEmbedder(nn.Module):
  """
  Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
  """
  def __init__(self, num_classes, hidden_size, dropout_prob):
      super().__init__()
      use_cfg_embedding = dropout_prob > 0
      self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
      self.num_classes = num_classes
      self.dropout_prob = dropout_prob

  def token_drop(self, labels, force_drop_ids=None):
      """
      Drops labels to enable classifier-free guidance.
      """
      if force_drop_ids is None:
          drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
      else:
          drop_ids = force_drop_ids == 1
      labels = torch.where(drop_ids, self.num_classes, labels)
      return labels

  def forward(self, labels, train, force_drop_ids=None):
      use_dropout = self.dropout_prob > 0
      if (train and use_dropout) or (force_drop_ids is not None):
          labels = self.token_drop(labels, force_drop_ids)
      embeddings = self.embedding_table(labels)
      return embeddings

从上面的代码中我们可以看出,LabelEmbdder的核心计算是通过一个embedding层对类别标签进行编码。注意DiT对标签进行了dropoout。如第1.1节介绍的,label dropout的作用是为了classifier-free guidance。

2.2 交叉注意力块(Cross-Attention)

如图1.(c)所示,我们将时间片特征t和条件特征c拼成一个长度为2的序列(图1.(a))。然后将这个序列输入到一个多头交叉注意力模块中和图像特征进行融合。关于DiT交叉注意力的具体实现参照我的LDM一文。

2.3 自适应层归一化块(Adaptive Layer Normalization,AdaLN)

DiT在模型中尝试了AdaLN[9],AdaLN的核心思想是使用模型中的一些信息学习 � 和 � 两个归一化参数。DiT是使用时间片特征 � 和条件特征 � 相加后的结果计算这两个参数(也就是第一个代码片段中的变量c)。此外,DiT在每个残差连接之后还接了一个回归缩放参数 � ,它同样是由变量c计算得到。接下来我们根据下面的代码片段详细介绍DiT的具体结构。

from timm.models.vision_transformer import Attention, Mlp
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

首先我们观察forword函数的第1行,它使用adaLN_modulation计算了6个变量shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,这6个变量分别对应了多头自注意力的LN的归一化参数与缩放参数(图1.(b)的 �1 , �1 , �1 )以及MLP的LN的归一化参数与缩放参数(图1.(b)的 �2 , �2 , �2 )。

forward函数的第二行是计算多头自注意力以及它的LN,它首先计算的是modulate函数,实现方式如下面代码片段,即相当于使用学习好的\beta和\gamma对LN进行归一化。接下来再计算的注意力模块,计算方式和Transformer相同。最后在通过乘以gate_msa对注意力计算的结果进行缩放。

 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

接下来forward是计算MLP部分,除了把attn函数换为mlp函数外,它和第二行基本相同,这里不再赘述。

当对特征加工完之后,我们需要使用FinalLayer模块来将特征还原为与输入相同的尺寸。它是由一个AdaLN和一个线性层组成,具体实现见下面代码片段。

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

2.4 adaZero-Block

之前有研究表明使用0初始化网络中的某些参数可以加速模型的训练。例如我们可以将残差网络的残差部分初始化为0,这样初始化后的残差块相当于一个单位映射,可以直接将上一层的特征透传给下一层。我们也可以将BN的归一化因子 � 初始化为0来加速模型的训练[10]。DiT对模型参数的初始化都是在initialize_weights函数中实现的,它的作用是对DiT中的变量进行初始化,我们具体看一下这个函数。

def initialize_weights(self):
    # Initialize transformer layers:
    def _basic_init(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
    self.apply(_basic_init)

    # Initialize (and freeze) pos_embed by sin-cos embedding:
    pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
    self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
    w = self.x_embedder.proj.weight.data
    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
    nn.init.constant_(self.x_embedder.proj.bias, 0)

    # Initialize label embedding table:
    nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

    # Initialize timestep embedding MLP:
    nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
    nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    # Zero-out adaLN modulation layers in DiT blocks:
    for block in self.blocks:
        nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    # Zero-out output layers:
    nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
    nn.init.constant_(self.final_layer.linear.weight, 0)
    nn.init.constant_(self.final_layer.linear.bias, 0)
  • 首先对于图像的位置编码,它使用了ViT等模型中使用的二维相对位置编码;
  • 对于DiTBlock涉及的adaLN中计算归一化参数和缩放参数,均使用了0初始化;
  • 对于FinaLayer的adaLN和线性层,也是使用0初始化;
  • 剩余的其它参数,则是使用常见的正态分部初始化或者xavier初始化。

作者对上面四种模块进行了对照实验,并使用了FID(Fréchet inception distance)指标对四个模块进行了效果评估。FID是计算真实图像和相似图像之间距离的的一种度量方式。他根据Inception v3分类模型计算得到的。分数越低则代表两组图像越相似,FID在最佳的情况下值是0,表示两组图完全相同。从实验结果我们可以看出adaLN-Zero还是有比较显著的优势的。

图3:DiT在四个模块上的对照实验

3. 总结

DiT最大的创新点是将Transformer引入到了扩散模型中,并完全抛弃了CNN。但是DiT并不是第一个引入Transformer的,例如之前的U-ViT[11],UniDiffuser[12]等都尝试了将Transformer引入到扩散模型中。至于对效果提升同样非常有帮助的adaLN,zero-初始化,classifier-free guidance等则是已有的工作了。DiT引入条件信息还是仅仅局限在样本类别,接下来我们有必要学习一些引入文本序列作为条件的生成模型了。

参考

  1. ^https://openai.com/sora Sora
  2. ^Peebles, William, and Saining Xie. "Scalable diffusion models with transformers." *Proceedings of the IEEE/CVF International Conference on Computer Vision*. 2023.
  3. ^Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." *Proceedings of the IEEE/CVF conference on computer vision and pattern recognition*. 2022.
  4. ^Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." *Advances in Neural Information Processing Systems* 33 (2020): 6840-6851.
  5. ^Van Den Oord A, Vinyals O. Neural discrete representation learning[J]. Advances in neural information processing systems, 2017, 30.
  6. ^Dhariwal, Prafulla, and Alexander Nichol. "Diffusion models beat gans on image synthesis." *Advances in neural information processing systems* 34 (2021): 8780-8794.
  7. ^Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." *arXiv preprint arXiv:2207.12598* (2022).
  8. ^https://github.com/facebookresearch/DiT/blob/main GitHub - facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"
  9. ^Ethan Perez, Florian Strub, Harm De Vries, Vincent Du- moulin, and Aaron Courville. Film: Visual reasoning with a general conditioning layer. In AAAI, 2018. 2, 5
  10. ^Goyal, Priya, et al. "Accurate, large minibatch sgd: Training imagenet in 1 hour." *arXiv preprint arXiv:1706.02677* (2017).
  11. ^Bao, Fan, et al. "All are worth words: A vit backbone for diffusion models." *Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition*. 2023.
  12. ^Bao, Fan, et al. "One transformer fits all distributions in multi-modal diffusion at scale." *arXiv preprint arXiv:2303.06555* (2023).

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-03-10 16:58:01       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-10 16:58:01       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-10 16:58:01       20 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-10 16:58:01       20 阅读

热门阅读

  1. C. Messenger in MAC - 堆优化枚举

    2024-03-10 16:58:01       23 阅读
  2. vue-treeselect的下拉列表中的字体样式修改

    2024-03-10 16:58:01       23 阅读
  3. Spring-Cloud-Gateway Filter详细配置说明

    2024-03-10 16:58:01       23 阅读
  4. 53. 最大子数组和(力扣LeetCode)

    2024-03-10 16:58:01       25 阅读
  5. 阿里巴巴商家爬虫工具 1688采集软件使用教程

    2024-03-10 16:58:01       21 阅读
  6. hadoop 总结

    2024-03-10 16:58:01       28 阅读
  7. 解决:Glide 在回调中再次加载图片报错

    2024-03-10 16:58:01       21 阅读
  8. sql返回数据怎么添加索引

    2024-03-10 16:58:01       17 阅读
  9. 速盾网络:cdn加速技术和云计算的区别

    2024-03-10 16:58:01       22 阅读
  10. adb shell pm 查询设备应用

    2024-03-10 16:58:01       24 阅读