SwiftBrush算法与代码解读

个人理解:一步去噪是利用教师模型强力去噪能力去校正学生模型(Input:高斯白噪声+prompt Output:去噪后的样本)的参数,随机采样 t t t可以覆盖各个阶段,从而使学生模型学到教师模型在不同噪声级别上的去噪能力。最终实现一步去噪。


Algorithm 1 SwiftBrush Distillation


1: Require: a pretrained text-to-image teacher ϵ ψ \epsilon_\psi ϵψ, a
LoRA teacher ϵ ϕ \epsilon_\phi ϵϕ, a student model f θ f_\theta fθ, two learning rates
η 1 \eta_1 η1 and η 2 \eta_2 η2, a weighting function ω \omega ω, a prompts dataset
Y Y Y, the maximum number of time steps T T T and the noise
schedule { ( α t , σ t ) } t = 1 T \{(\alpha_t, \sigma_t)\}_{t=1}^T {(αt,σt)}t=1T of the teacher model
2: Initialize: ϕ ← ψ \phi \leftarrow \psi ϕψ, θ ← ψ \theta \leftarrow \psi θψ
3: while not converged do
4: \qquad Sample input noise z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) zN(0,I)
5: \qquad Sample text caption input y ∼ Y y \sim Y yY
6: \qquad Compute student output x ^ 0 = f θ ( z , y ) \hat{x}_0 = f_\theta(z, y) x^0=fθ(z,y)
7: \qquad Sample timestep t ∼ U ( 0.02 T , 0.98 T ) t \sim \mathcal{U}(0.02T, 0.98T) tU(0.02T,0.98T)
8: \qquad Sample added noise ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)
9: \qquad Compute noisy sample x ^ t = α t x ^ 0 + σ t ϵ \hat{x}_t = \alpha_t \hat{x}_0 + \sigma_t \epsilon x^t=αtx^0+σtϵ
10: \qquad θ ← θ − η 1 [ ω ( t ) ( ϵ ψ ( x ^ t , t , y ) − ϵ ϕ ( x ^ t , t , y ) ) ∂ x ^ 0 ∂ θ ] \theta \leftarrow \theta - \eta_1 \left[ \omega(t) \left( \epsilon_\psi(\hat{x}_t, t, y) - \epsilon_\phi(\hat{x}_t, t, y) \right) \frac{\partial \hat{x}_0}{\partial \theta} \right] θθη1[ω(t)(ϵψ(x^t,t,y)ϵϕ(x^t,t,y))θx^0]
11: \qquad Sample timestep t ′ ∼ U ( 0 , T ) t' \sim \mathcal{U}(0, T) tU(0,T)
12: \qquad Sample added noise ϵ ′ ∼ N ( 0 , I ) \epsilon' \sim \mathcal{N}(0, I) ϵN(0,I)
13: \qquad Compute noisy sample x ^ t ′ = α t ′ x ^ 0 + σ t ′ ϵ ′ \hat{x}_{t'} = \alpha_{t'} \hat{x}_0 + \sigma_{t'} \epsilon' x^t=αtx^0+σtϵ
14: \qquad ϕ ← ϕ − η 2 ∇ ϕ ∥ ϵ ϕ ( x ^ t ′ , t ′ , y ) − ϵ ′ ∥ 2 \phi \leftarrow \phi - \eta_2 \nabla_\phi \|\epsilon_\phi(\hat{x}_{t'}, t', y) - \epsilon'\|^2 ϕϕη2ϕϵϕ(x^t,t,y)ϵ2
15: end while
16: return trained student model f θ f_\theta fθ


teacher_lora ( ϵ ϕ \epsilon_\phi ϵϕ) 和 unet ( ϵ θ \epsilon_\theta ϵθ) 都加载的Unet预训练权重 ( ψ \psi ψ),对应于初始化

# Get alphas cummulative product
    alphas_cumprod = noise_scheduler.alphas_cumprod
    alphas_cumprod = alphas_cumprod.to(accelerator.device, dtype=weight_dtype)

    for epoch in range(first_epoch, args.num_train_epochs):
        train_loss_vsd = train_loss_lora = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet, teacher_lora):
                bsz = batch["prompt_embeds"].shape[0] #获取batch size

                # Sample input noise
                input_shape = (bsz, 4, args.resolution // 8, args.resolution // 8)
                input_noise = torch.randn(*input_shape, dtype=weight_dtype, device=accelerator.device)

                # Predict the noise residual
                # 将批次中的提示嵌入向量转换到使用的设备上(例如GPU),并确保数据类型与模型权重的数据类型一致。
                prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype) 
                # 创建一个与批次大小相同的提示嵌入向量的副本,这些副本将用于生成无条件的噪声样本。
                # `null_dict` 包含空文本的嵌入向量,通过重复它来模拟没有文本提示时的条件。
                prompt_null_embeds = (
                    null_dict["prompt_embeds"].repeat(bsz, 1, 1).to(accelerator.device, dtype=weight_dtype)
                )

                # Get predicted original sampls from unet
                # x_0
                pred_original_samples = predict_original(unet, noise_scheduler, input_noise, prompt_embeds).to(
                    dtype=weight_dtype
                )

                # VSD loss

                # Sample noise that we'll add to the predicted original samples
                # Sample added noise
                noise = torch.randn_like(pred_original_samples)

                # Sample a random timestep for each image
                timesteps_range = torch.tensor([0.02, 0.981]) * noise_scheduler.config.num_train_timesteps
                timesteps = torch.randint(*timesteps_range.long(), (bsz,), device=accelerator.device).long()

                # Add noise to the predicted original samples according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                # Compute noisy sample
                # x_t
                noisy_samples = noise_scheduler.add_noise(pred_original_samples, noise, timesteps)

                # Prepare outputs from the teacher
                # 不计算梯度的情况下,获取教师模型(teacher model)的有条件和无条件预测。
                with torch.no_grad():
                	# 禁用LoRA adapters以获取基础模型的预测。
                    accelerator.unwrap_model(teacher_lora).disable_adapters()
                    teacher_pred_cond = teacher_lora(noisy_samples, timesteps, prompt_embeds).sample
                    teacher_pred_uncond = teacher_lora(noisy_samples, timesteps, prompt_null_embeds).sample
					# 启用LoRA适配器以获取使用LoRA方法调整后的模型预测。
                    accelerator.unwrap_model(teacher_lora).enable_adapters()
                    lora_pred_cond = teacher_lora(noisy_samples, timesteps, prompt_embeds).sample
                    lora_pred_uncond = teacher_lora(noisy_samples, timesteps, prompt_null_embeds).sample

                    # Apply classifier-free guidance to the teacher prediction
                    teacher_pred = teacher_pred_uncond + args.guidance_scale * (
                        teacher_pred_cond - teacher_pred_uncond
                    )
                    # 通过计算 lora_pred_cond - lora_pred_uncond,我们得到了一个噪声残差,这个残差表示在给定文本提示的条件下,预测的图像与没有文本提示时的预测图像之间的差异。
                    # 这种差异随后被用于引导模型生成更符合文本提示的图像。
                    lora_pred = lora_pred_uncond + args.guidance_scale * (lora_pred_cond - lora_pred_uncond)

                # Compute the score gradient for updating the model
                sigma_t = ((1 - alphas_cumprod[timesteps]) ** 0.5).view(-1, 1, 1, 1)
                score_gradient = torch.nan_to_num(sigma_t**2 * (teacher_pred - lora_pred))

                # Compute the VSD loss for the model
                target = (pred_original_samples - score_gradient).detach()
                loss_vsd = 0.5 * F.mse_loss(pred_original_samples.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss_vsd = accelerator.gather(loss_vsd.repeat(args.train_batch_size)).mean()
                train_loss_vsd += avg_loss_vsd.item() / args.gradient_accumulation_steps

                # Backpropagate for the unet
                accelerator.backward(loss_vsd)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                # LoRA loss

                # Sample noise that we'll add to the predicted original samples
                # Sample added noise
                noise = torch.randn_like(pred_original_samples.detach())

                # Sample a random timestep for each image
                # Sample timestep
                timesteps_range = torch.tensor([0, 1]) * noise_scheduler.config.num_train_timesteps
                timesteps = torch.randint(*timesteps_range.long(), (bsz,), device=accelerator.device).long()

                # Add noise to the predicted original samples according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                # x_t'
                noisy_samples = noise_scheduler.add_noise(pred_original_samples.detach(), noise, timesteps)

                # Compute output for updating the LoRA teacher
                # 为了增加模型的鲁棒性,使其在训练时能够处理没有文本提示的情况
                encoder_hidden_states = prompt_null_embeds if random.random() < 0.1 else prompt_embeds
                # \epsilon_\phi
                lora_pred = teacher_lora(noisy_samples, timesteps, encoder_hidden_states).sample

                alpha_t = (alphas_cumprod[timesteps] ** 0.5).view(-1, 1, 1, 1)
                lora_pred = alpha_t * lora_pred
                target = alpha_t * noise # 通过乘以 alpha_t,我们可以将噪声和预测值调整到同一尺度,使得损失计算在不同时间步长上是可比的。

                # Compute the loss for LoRA teacher
                loss_lora = F.mse_loss(lora_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss_lora = accelerator.gather(loss_lora.repeat(args.train_batch_size)).mean()
                train_loss_lora += avg_loss_lora.item() / args.gradient_accumulation_steps

                # Backpropagate for the LoRA teacher
                accelerator.backward(loss_lora)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(teacher_lora.parameters(), args.max_grad_norm)
                optimizer_lora.step()
                lr_scheduler_lora.step()
                optimizer_lora.zero_grad(set_to_none=args.set_grads_to_none)

分数梯度(score gradient)的计算在VSD损失中起到了至关重要的作用。为了更好地理解分数梯度的计算,我们需要回顾扩散模型的核心思想以及分数匹配的概念。

背景:扩散模型与分数匹配

在扩散模型中,数据逐步加入噪声,最终变成纯噪声。反向过程则试图从纯噪声中逐步去噪,恢复原始数据。分数匹配(score matching)是扩散模型中的一个关键技术,它试图学习数据的分数函数(score function),即数据概率分布的对数梯度。

分数函数与分数梯度

分数函数 ∇ x log ⁡ p ( x ) \nabla_x \log p(x) xlogp(x) 表示数据分布的对数概率密度的梯度,反映了在每个点上,数据分布如何变化。分数梯度的学习目标是逼近这个分数函数。

分数梯度的计算过程

在代码中,分数梯度通过以下公式计算:
score_gradient = σ t 2 ⋅ ( ϵ teacher − ϵ lora ) \text{score\_gradient} = \sigma_t^2 \cdot (\epsilon_{\text{teacher}} - \epsilon_{\text{lora}}) score_gradient=σt2(ϵteacherϵlora)
这里的每一部分都有其特定的含义。

1. σ t 2 \sigma_t^2 σt2 的作用

σ t \sigma_t σt 表示在时间步 t t t 时刻的噪声方差。 σ t 2 \sigma_t^2 σt2 σ t \sigma_t σt 的平方,它调整了噪声的幅度。噪声方差的平方乘以噪声预测残差,确保了在不同时间步上,噪声对样本的影响是适当的。

2. ϵ teacher − ϵ lora \epsilon_{\text{teacher}} - \epsilon_{\text{lora}} ϵteacherϵlora

这里, ϵ teacher \epsilon_{\text{teacher}} ϵteacher ϵ lora \epsilon_{\text{lora}} ϵlora 是教师模型和LoRA教师模型对噪声的预测。这两个模型分别给出了带噪样本的预测值,它们的差异表示了在给定条件下的噪声残差。

通过计算 ϵ teacher \epsilon_{\text{teacher}} ϵteacher ϵ lora \epsilon_{\text{lora}} ϵlora 的差异,我们得到了一个估计的分数梯度。这个分数梯度反映了在当前噪声水平下,如何调整样本来更接近真实数据分布。

3. score_gradient \text{score\_gradient} score_gradient 的作用

分数梯度用于计算目标样本,它通过调整预测的样本,使得生成的样本更接近无噪声的原始数据。

数学公式解释

给定一个带噪样本 x ^ t \hat{x}_t x^t,分数梯度的计算公式如下:
∇ x ^ t log ⁡ p ( x ^ t ) ≈ σ t 2 ⋅ ( ϵ teacher − ϵ lora ) \nabla_{\hat{x}_t} \log p(\hat{x}_t) \approx \sigma_t^2 \cdot (\epsilon_{\text{teacher}} - \epsilon_{\text{lora}}) x^tlogp(x^t)σt2(ϵteacherϵlora)
这里, σ t 2 \sigma_t^2 σt2 调整了噪声预测的幅度,而 ϵ teacher − ϵ lora \epsilon_{\text{teacher}} - \epsilon_{\text{lora}} ϵteacherϵlora 提供了噪声预测的残差。


这两行代码是用于计算VSD(Variance-Scaled Distillation)损失的核心部分。为了理解它们,我们需要拆解并分析每个步骤背后的数学原理和动机。

1. target = (pred_original_samples - score_gradient).detach()

背景

在VSD损失中,我们希望引导模型学习一个更精确的去噪过程。为此,我们需要生成一个目标值(target),使得模型的预测能够更接近真实的无噪声样本。

分步解释
  • pred_original_samples:这是通过模型(例如,UNet)预测的无噪声样本,即去噪后的样本。
  • score_gradient:这是从教师模型(teacher model)和LoRA模型(LoRA model)的预测中计算出的分数梯度(score gradient)。这个梯度表示在当前噪声水平下,如何调整样本来更接近真实数据分布。
  • pred_original_samples - score_gradient:通过从模型的预测样本中减去分数梯度,我们得到了一个调整后的样本,这个样本更接近于我们期望的无噪声目标。这一步的动机是使用教师模型的知识来校正学生模型的预测。
  • .detach():这个操作会从计算图中分离出目标值,防止在计算损失时传播梯度。这意味着在后续的反向传播过程中,不会计算与目标值相关的梯度,确保目标值是一个静态的参考点。

数学上,我们希望目标值 x ^ 0 \mathbf{\hat{x}_0} x^0 能够更好地反映无噪声的样本:
x ^ 0 = x 0 − ∇ x ^ t log ⁡ p ( x ^ t ) \mathbf{\hat{x}_0} = \mathbf{x_0} - \nabla_{\mathbf{\hat{x}_t}} \log p(\mathbf{\hat{x}_t}) x^0=x0x^tlogp(x^t)
其中, ∇ x ^ t log ⁡ p ( x ^ t ) \nabla_{\mathbf{\hat{x}_t}} \log p(\mathbf{\hat{x}_t}) x^tlogp(x^t) 是通过分数梯度近似的对数概率梯度。

2. loss_vsd = 0.5 * F.mse_loss(pred_original_samples.float(), target.float(), reduction="mean")

背景

我们使用均方误差损失(MSE loss)来衡量模型预测的样本与目标值之间的差异。通过最小化这个差异,我们希望模型能够生成更精确的去噪样本。

分步解释
  • F.mse_loss(pred_original_samples.float(), target.float(), reduction="mean"):计算模型预测样本和目标样本之间的均方误差(MSE)。float() 确保数据类型一致,reduction="mean" 表示对所有样本的损失取平均值。
  • 0.5 *:缩放损失值。这里的0.5是一个常数因子,用于平衡损失的尺度。这个因子可以根据具体的应用场景进行调整。

数学上,MSE损失计算公式为:
MSE = 1 N ∑ i = 1 N ( x ^ 0 ( i ) − x 0 ( i ) ) 2 \text{MSE} = \frac{1}{N} \sum_{i=1}^N (\mathbf{\hat{x}_0^{(i)}} - \mathbf{x_0^{(i)}})^2 MSE=N1i=1N(x^0(i)x0(i))2
其中, x ^ 0 ( i ) \mathbf{\hat{x}_0^{(i)}} x^0(i) 是第 i i i 个样本的预测值, x 0 ( i ) \mathbf{x_0^{(i)}} x0(i) 是第 i i i 个样本的目标值。

通过最小化这个损失,我们希望模型能够学习到更精确的去噪过程,使得生成的样本更接近于无噪声的原始样本。


为何能实现一步去噪:
这个算法实现一步去噪的核心在于通过两个教师模型(预训练的教师模型 ϵ ψ \epsilon_\psi ϵψ 和 LoRA 教师模型 ϵ ϕ \epsilon_\phi ϵϕ)对学生模型 f θ f_\theta fθ 进行指导。算法的目标是通过学习在不同时间步长下的噪声预测,从而在一个前向传播步骤中生成去噪后的图像。

算法步骤解析

  1. 初始化

    • 将学生模型 f θ f_\theta fθ 和 LoRA 教师模型 ϵ ϕ \epsilon_\phi ϵϕ 都初始化为预训练的教师模型 ϵ ψ \epsilon_\psi ϵψ
  2. 训练过程

    • 噪声和文本提示采样
      • 在每次迭代中,从标准正态分布中采样输入噪声 z z z 和添加噪声 ϵ \epsilon ϵ,从数据集 Y Y Y 中采样文本提示 y y y
    • 计算学生模型的输出
      • 使用学生模型 f θ f_\theta fθ 生成初始去噪图像 x ^ 0 = f θ ( z , y ) \hat{x}_0 = f_\theta(z, y) x^0=fθ(z,y)
    • 生成带噪图像样本
      • 根据噪声计划 { ( α t , σ t ) } \{(\alpha_t, \sigma_t)\} {(αt,σt)} 在一个随机时间步 t t t 生成带噪图像 x ^ t = α t x ^ 0 + σ t ϵ \hat{x}_t = \alpha_t \hat{x}_0 + \sigma_t \epsilon x^t=αtx^0+σtϵ
    • 更新学生模型
      • 计算两个教师模型在带噪图像 x ^ t \hat{x}_t x^t 下的预测噪声,并根据其差异指导学生模型的更新:
        θ ← θ − η 1 [ ω ( t ) ( ϵ ψ ( x ^ t , t , y ) − ϵ ϕ ( x ^ t , t , y ) ) ∂ x ^ 0 ∂ θ ] \theta \leftarrow \theta - \eta_1 \left[ \omega(t) \left( \epsilon_\psi(\hat{x}_t, t, y) - \epsilon_\phi(\hat{x}_t, t, y) \right) \frac{\partial \hat{x}_0}{\partial \theta} \right] θθη1[ω(t)(ϵψ(x^t,t,y)ϵϕ(x^t,t,y))θx^0]
    • 更新 LoRA 教师模型
      • 使用不同的随机时间步 t ′ t' t 和噪声 ϵ ′ \epsilon' ϵ 生成另一个带噪图像样本 x ^ t ′ \hat{x}_{t'} x^t,并最小化其预测噪声与真实噪声的均方误差:
        ϕ ← ϕ − η 2 ∇ ϕ ∥ ϵ ϕ ( x ^ t ′ , t ′ , y ) − ϵ ′ ∥ 2 \phi \leftarrow \phi - \eta_2 \nabla_\phi \|\epsilon_\phi(\hat{x}_{t'}, t', y) - \epsilon'\|^2 ϕϕη2ϕϵϕ(x^t,t,y)ϵ2
  3. 训练终止

    • 训练过程在模型收敛后终止,返回训练后的学生模型 f θ f_\theta fθ

实现一步去噪的关键点

  • 指导信号的利用:在训练过程中,学生模型 f θ f_\theta fθ 通过从两个教师模型 ϵ ψ \epsilon_\psi ϵψ ϵ ϕ \epsilon_\phi ϵϕ 之间的预测差异中学习。这种指导信号使得学生模型能够在不同时间步长下学习到更准确的去噪能力。
  • 噪声计划的使用:通过在每次迭代中随机采样时间步长 t t t t ′ t' t,算法能够覆盖扩散过程中的各个阶段,使学生模型在不同的噪声水平上都能有效去噪。
  • 分数函数的优化学生模型通过优化一个分数函数,即在噪声样本上预测噪声残差,从而学习到如何一步生成去噪图像。

数学推导

  1. 学生模型的更新

    • 通过最小化噪声残差的均方误差,学生模型学习在每个时间步长下的去噪能力:
      L θ = ∥ ϵ ψ ( x ^ t , t , y ) − ϵ ϕ ( x ^ t , t , y ) ∥ 2 \mathcal{L}_\theta = \|\epsilon_\psi(\hat{x}_t, t, y) - \epsilon_\phi(\hat{x}_t, t, y)\|^2 Lθ=ϵψ(x^t,t,y)ϵϕ(x^t,t,y)2
      更新公式为:
      θ ← θ − η 1 ∇ θ L θ \theta \leftarrow \theta - \eta_1 \nabla_\theta \mathcal{L}_\theta θθη1θLθ
  2. LoRA 教师模型的更新

    • LoRA 教师模型通过优化其在不同时间步长下的预测噪声,使其逐渐逼近真实噪声:
      L ϕ = ∥ ϵ ϕ ( x ^ t ′ , t ′ , y ) − ϵ ′ ∥ 2 \mathcal{L}_\phi = \|\epsilon_\phi(\hat{x}_{t'}, t', y) - \epsilon'\|^2 Lϕ=ϵϕ(x^t,t,y)ϵ2
      更新公式为:
      ϕ ← ϕ − η 2 ∇ ϕ L ϕ \phi \leftarrow \phi - \eta_2 \nabla_\phi \mathcal{L}_\phi ϕϕη2ϕLϕ

通过上述训练过程,学生模型 f θ f_\theta fθ 学会了在一个前向传播步骤中有效去噪,从而实现了一步去噪的目标。这种方法不仅提高了生成效率,还保持了高质量的生成图像。

相关推荐

  1. SwiftBrush算法代码解读

    2024-07-13 13:56:04       20 阅读
  2. 微软面试高频算法解析代码实现(C++)

    2024-07-13 13:56:04       33 阅读
  3. 麻雀搜索算法解释代码

    2024-07-13 13:56:04       50 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-13 13:56:04       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-13 13:56:04       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-13 13:56:04       57 阅读
  4. Python语言-面向对象

    2024-07-13 13:56:04       68 阅读

热门阅读

  1. 005-基于Sklearn的机器学习入门:逻辑回归

    2024-07-13 13:56:04       28 阅读
  2. opencv—常用函数学习_“干货“_总

    2024-07-13 13:56:04       21 阅读
  3. Web组成架构

    2024-07-13 13:56:04       22 阅读
  4. Artificial intelligence machine learning DATA4800

    2024-07-13 13:56:04       23 阅读
  5. 自用的C++20协程学习资料

    2024-07-13 13:56:04       20 阅读
  6. 如何在uniapp中使用websocket?

    2024-07-13 13:56:04       15 阅读
  7. 【linux】预防rm误删文件的3种方法

    2024-07-13 13:56:04       22 阅读
  8. 掌控版本脉动:Gradle依赖更新策略全解析

    2024-07-13 13:56:04       20 阅读
  9. 解释器模式(大话设计模式)C/C++版本

    2024-07-13 13:56:04       19 阅读
  10. 资源搜索网址

    2024-07-13 13:56:04       18 阅读
  11. 关于最近项目数字前端FLOW的一些总结

    2024-07-13 13:56:04       19 阅读
  12. 【AI应用探讨】—迁移学习(TL)应用场景

    2024-07-13 13:56:04       29 阅读