guided-diffusion 相比于improved-diffusion的sample增加的cond_fn()

1、cond_fn()函数代码

def cond_fn(x, t, y=None):
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

cond_fn 的函数接受三个参数:x、t 和一个可选的 y。这个函数的主要目的是计算一个关于输入 x 的梯度,这个梯度是基于通过某个分类器 classifier 对 x 和 t 进行分类时,针对特定标签 y 的对数概率的梯度。

参数检查: assert y is not None 确保 y 不为 None。这是必要的,因为后续的操作依赖于 y 来选择对数概率。
启用梯度计算: with torch.enable_grad(): 确保在这个代码块内,所有需要梯度的操作都会被记录,以便后续可以计算梯度。不过,在 PyTorch 中,更常见的做法是直接设置张量的 .requires_grad 属性,因为 torch.enable_grad() 主要用于全局控制梯度记录,而在这个函数中,我们只需要对 x_in 进行这样的设置。
准备输入: x_in = x.detach().requires_grad_(True) 通过 detach() 创建一个 x 的新副本,并从计算图中分离出来,然后通过 requires_grad_(True) 允许 PyTorch 对这个副本的操作进行梯度追踪。
前向传播: 通过 classifier(x_in, t) 获取分类器的输出(logits),然后使用 F.log_softmax(logits, dim=-1) 计算对数概率。
选择特定标签的对数概率: selected = log_probs[range(len(logits)), y.view(-1)] 这行代码通过索引选择每个样本对应标签 y 的对数概率。y.view(-1) 确保 y 的形状与 logits 的最后一维相匹配。log_probs[range(len(logits)), y.view(-1)]:这行代码使用高级索引(advanced indexing)来从log_probs中选择元素,range(len(logits))值是行索引, y.view(-1)是列索引。具体来说,它首先通过range(len(logits))生成一个与样本数量相等的索引序列,然后使用y.view(-1)来提供每个样本对应真实类别的索引。因此,这行代码实际上是在选择每个样本对应其真实类别的对数概率值。
计算梯度: th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 计算 selected.sum()(即所有选中对数概率的和)关于 x_in 的梯度,并将这个梯度乘以一个缩放因子 args.classifier_scale。th.autograd.grad 返回的是一个元组,其中包含所有需要梯度的张量的梯度,这里我们只关心 x_in 的梯度,所以通过 [0] 索引获取。
总的来说,这个函数计算了分类器对于输入 x 和条件 t,在给定标签 y 下的对数概率梯度,并对这个梯度进行了缩放。这样的梯度可以用于各种优化或学习算法中,特别是在需要基于条件梯度的场景下。

2、softmax与log_softmax函数

当Softmax的输入比较大的时候,可能会产生上溢出,超出float的能表示范围;同理,当输入为负值且绝对值比较大的时候,分子分母会极小,接近0,从而导致下溢出。log_Softmax能够很好的解决溢出问题,且可以加快运算速度,提升数据稳定性。
softmax
在这里插入图片描述

log_softmax
在这里插入图片描述

相关推荐

  1. Stable Diffusion数学原理

    2024-07-11 14:44:05       39 阅读
  2. Stable DiffusionEmbeddings

    2024-07-11 14:44:05       38 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 14:44:05       8 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 14:44:05       8 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 14:44:05       7 阅读
  4. Python语言-面向对象

    2024-07-11 14:44:05       10 阅读

热门阅读

  1. C# 委托和事件

    2024-07-11 14:44:05       10 阅读
  2. MySQL常见的几种索引类型及对应的应用场景

    2024-07-11 14:44:05       10 阅读
  3. 带内管理与带外管理

    2024-07-11 14:44:05       7 阅读
  4. linux 内核 红黑树接口说明

    2024-07-11 14:44:05       9 阅读
  5. 使用Python绘制堆积面积图

    2024-07-11 14:44:05       9 阅读
  6. React@16.x(53)Redux@4.x(2)- action

    2024-07-11 14:44:05       9 阅读
  7. TS-类型别名和接口的区别

    2024-07-11 14:44:05       9 阅读
  8. 索引

    2024-07-11 14:44:05       8 阅读