【扩散模型(五)】IP-Adapter 源码详解3-推理代码

系列文章目录



前言

这里以 /path/to/IP-Adapter/ip_adapter_demo.ipynb 中最基础的以图生图(Image Variations)为例:

SD1.5-IPA 的推理流程如下图所示,可被分为 3 个部分:

  1. 输入处理:对 img prompt 和 txt prompt 分别先得到 embedding 后再送入 SD 的 pipeline;
  2. 过 Unet:与一般输入 txt prompt 类似,通过 Unet 的各个模块;
  3. Unet 中的 CA:对于 img prompt 部分需要拆出来,单独过针对性的 k (to_k_ip)和 v(to_v_ip)。

其中的关键在第一部分,与一般将 txt prompt 直接送入 SD pipeline 不太一样,是先处理为 embedding 再送入 pipeline 的。
在这里插入图片描述

*图中的 bs 代表 batch size

一、输入处理

IP-Adapter 的推理代码核心是在 /path/to/IP-Adapter/ip_adapter/ip_adapter.py 文件的 IPAdapter 类的 generate() 函数中。

在这里插入图片描述

  1. 输入1: image prompt
    • 通过冻结住的 image encoder(CLIPImageProcessor 先预处理,再通过 CLIPVisionModelWithProjection)
    • 以及训练好的 image_proj_model(ImageProjModel)
  2. 输入1对应的输出1有:
    • image_prompt_embeds
    • uncond_image_prompt_embeds(纯 0 tensor 过一次 ImageProjModel)
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
    self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.image_proj_model.load_state_dict(state_dict["image_proj"])# 从训好的权重中读取
...
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
  1. 输入2: text prompt、negative_prompt(默认的 ['monochrome, lowres, bad anatomy, worst quality, low quality']

    • text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
      • encode_prompt 中,对于直接文字的 prompt(str 字符串格式的),会先通过 tokenizer
      • 检查是否超过 clip 的长度
      • 通过 text_encoder (CLIPTextModel) 得到 prompt_embeds(文本特征)
    • negative_prompt 同样通过 tokenizer 和 text_encoder 得到 negative_prompt_embeds
  2. 输入2 对应的输出2有:

    • prompt_embeds_
    • negative_prompt_embeds_
  3. 输出1 的 image_prompt_embeds、uncond_image_prompt_embeds 分别和 输出2 prompt_embeds_、negative_prompt_embeds_ 在维度1上 torch.cat 后得到 self.pipe(第二次 encoder_prompt)的输入:prompt_embeds 和 negative_prompt_embeds。

with torch.inference_mode():
    prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
        prompt,
        device=self.device,
        num_images_per_prompt=num_samples,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
    )
    prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
    negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

二、过 Unet

  1. 按照 prompt 和 negative_prompt 为 None、将 prompt_embeds 和 negative_prompt_embeds 作为输入,通过 encode_prompt(),
    • 得到进一步的 prompt_embeds 和 negative_prompt_embeds
  2. prompt_embeds 和 negative_prompt_embeds 做 torch.cat 是在维度 0 上,这是针对 do_classifier_free_guidance 的操作,避免做两次前向传播。
 # For classifier free guidance, we need to do two forward passes.
 # Here we concatenate the unconditional and text embeddings into a single batch
 # to avoid doing two forward passes
 if self.do_classifier_free_guidance:
     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  1. 接下来的路径和 SD1.5 基本的推理流程基本一致,除了被替换的 Cross-Attn(CA)。
    在这里插入图片描述

三、Unet 中被替换的 CA

该部分应该无需多说,与训练部分一致,即增加一个针对 image prompt 的 k 和 v。上篇 也有相应代码的介绍。

在这里插入图片描述

相关推荐

  1. ChatGLM3 解析(

    2024-07-17 22:56:01       30 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-17 22:56:01       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-17 22:56:01       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-17 22:56:01       58 阅读
  4. Python语言-面向对象

    2024-07-17 22:56:01       69 阅读

热门阅读

  1. Linux下Supervisor的安装与配置

    2024-07-17 22:56:01       18 阅读
  2. 布儒斯特定律

    2024-07-17 22:56:01       16 阅读
  3. A const member function

    2024-07-17 22:56:01       21 阅读
  4. 代码随想录算法训练营第14天/优先掌握递归

    2024-07-17 22:56:01       23 阅读
  5. ES6函数部分和数组部分的小练习

    2024-07-17 22:56:01       20 阅读
  6. 学习笔记(数据库)1

    2024-07-17 22:56:01       17 阅读
  7. 后端实现图片上传本地,可采用url查看图片

    2024-07-17 22:56:01       21 阅读
  8. 总览

    总览

    2024-07-17 22:56:01      18 阅读
  9. C4D S26新功能完整列表

    2024-07-17 22:56:01       24 阅读
  10. 大模型日报 2024-07-17

    2024-07-17 22:56:01       25 阅读