AIGC笔记--Classifer Guidance的代码理解

1--完整可debug代码

Simple_Classifer_Guide

2--核心代码

import torch as th
import torch.nn.functional as F

from tools.unet import EncoderUNetModel

# 核心函数
def cond_fn(x, t, y = None, classifier = None): 
    # x.shape: [batch_size, 3, 64, 64]
    # t.shape: [batch_size] 
    # y.shape: [batch_size]
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t) # [batch_size, 1000] # classifier本质上是EncoderUNetModel
        log_probs = F.log_softmax(logits, dim = -1) # [batch_size, 1000]
        selected = log_probs[range(len(logits)), y.view(-1)] # [batch_size]
        classifier_scale = 1.0 # set by ljf
        return th.autograd.grad(selected.sum(), x_in)[0] * classifier_scale # th.autograd.grad(y, x)计算y对x的导数
        # return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale # [batch_size, 3, 64, 64]

def condition_mean(cond_fn, p_mean_var, x, t, model_kwargs = None, classifier = None):
    # gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) # origin version
    gradient = cond_fn(x, t, y = model_kwargs, classifier = classifier) # [batch_size, 3, 64, 64] set by ljf
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() # [batch_size, 3, 64, 64]
    )
    return new_mean

def p_sample(x, t, cond_fn = None, model_kwargs = None, classifier = None):
        # origin version
        # out = self.p_mean_variance(
        #     model,
        #     x,
        #     t,
        #     clip_denoised=clip_denoised,
        #     denoised_fn=denoised_fn,
        #     model_kwargs=model_kwargs,
        # )

        # 这里使用random的方法模拟p_mean_variance的过程
        out = {}
        out["mean"] = th.rand(x.shape)
        out["variance"] = th.rand(x.shape)
        out["log_variance"] = th.rand(x.shape)
        out["pred_xstart"] = th.rand(x.shape)

        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = condition_mean(
                cond_fn, out, x, t, model_kwargs = model_kwargs, classifier = classifier
            )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise # [batch_size, 3, 64, 64]
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}


if __name__ == "__main__":
    batch_size = 2
    channel = 3
    height = 64
    width = 64
    x = th.rand(batch_size, channel, height, width) # 随机生成输入
    t = th.tensor([999]).repeat(batch_size) # 随机模拟time steps
    y = th.randint(low = 0, high = 999, size = (2,)) # 模拟类别
    classifier = EncoderUNetModel(
        image_size = 64,
        in_channels = 3,
        model_channels = 128,
        out_channels = 1000, # 1000个类别
        num_res_blocks = 3,
        attention_resolutions = [32,16,8])
    
    output = p_sample(x = x, t = t, model_kwargs = y, cond_fn = cond_fn, classifier = classifier) # 传入cond_fn函数和分类器classifier
    assert list(output['sample'].shape) == [batch_size, channel, height, width]
    print("output['sample'].shape: ", output['sample'].shape)
    print("All Done!")

相关推荐

  1. AIGC笔记--Classifer Guidance代码理解

    2024-07-19 01:18:03       25 阅读
  2. AIGC笔记--VAE模型搭建

    2024-07-19 01:18:03       61 阅读
  3. AIGC笔记--Diffuser基本使用

    2024-07-19 01:18:03       30 阅读
  4. AIGC笔记--Diffuser训练pipeline

    2024-07-19 01:18:03       33 阅读
  5. 代码,我理解

    2024-07-19 01:18:03       35 阅读
  6. 代码开发业务在AIGC时代应用

    2024-07-19 01:18:03       56 阅读
  7. AIGC示例代码

    2024-07-19 01:18:03       36 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-19 01:18:03       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-19 01:18:03       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-19 01:18:03       58 阅读
  4. Python语言-面向对象

    2024-07-19 01:18:03       69 阅读

热门阅读

  1. rust 构建自己的库和模块

    2024-07-19 01:18:03       20 阅读
  2. 大语言模型系列-Transformer

    2024-07-19 01:18:03       24 阅读
  3. Git入门

    2024-07-19 01:18:03       25 阅读
  4. JVM高频面试题

    2024-07-19 01:18:03       23 阅读
  5. 割点(Articulation Point)

    2024-07-19 01:18:03       24 阅读
  6. [C/C++入门][变量和运算]4、带余除法

    2024-07-19 01:18:03       20 阅读
  7. 理解 Nginx 中的 sites-enabled 目录

    2024-07-19 01:18:03       24 阅读