【DiT】Inference Code


!git clone https://github.com/facebookresearch/DiT.git
import os
os.chdir('DiT')  # work from the repo root so diffusion, download and models can be imported
os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
!pip install diffusers timm --upgrade
# DiT imports:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)  # sampling only, so autograd is not needed
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

Automatically download the model matching the selected image size

image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
latent_size = int(image_size) // 8  # the VAE downsamples by 8x in each spatial dimension
# Load model:
model = DiT_XL_2(input_size=latent_size).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
model.eval() # important!
vae = AutoencoderKL.from_pretrained(vae_model).to(device)
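The `latent_size = image_size // 8` line reflects the 8x spatial downsampling of the Stable Diffusion VAE, and the "/2" in DiT-XL/2 is the patch size, so the transformer sees (latent_size / 2)² tokens. The helper below is a hypothetical illustration (its name and defaults are assumptions, not part of the DiT repo) that computes both numbers for the two resolutions offered in the notebook.

```python
def dit_latent_geometry(image_size: int, vae_downsample: int = 8, patch_size: int = 2):
    """Return (latent_size, num_tokens) for a DiT model (hypothetical helper).

    Assumes an 8x-downsampling VAE (as with the sd-vae-ft-* models) and the
    patch size 2 implied by the "/2" in DiT-XL/2.
    """
    if image_size % vae_downsample != 0:
        raise ValueError("image_size must be divisible by the VAE downsampling factor")
    latent_size = image_size // vae_downsample
    if latent_size % patch_size != 0:
        raise ValueError("latent_size must be divisible by the patch size")
    tokens_per_side = latent_size // patch_size
    return latent_size, tokens_per_side ** 2


for size in (256, 512):
    latent, tokens = dit_latent_geometry(size)
    print(f"{size}x{size} image -> 4x{latent}x{latent} latent -> {tokens} tokens")
```

Going from 256 to 512 quadruples the sequence length (256 to 1024 tokens), which is why each sampling step of the 512x512 checkpoint costs considerably more compute.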

Output

Downloading https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt to pretrained_models/DiT-XL-2-256x256.pt
100%|██████████| 2700611775/2700611775 [00:11<00:00, 225927973.23it/s]
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with:
pip install accelerate
.
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
config.json: 100% 547/547 [00:00<00:00, 30.0kB/s]
diffusion_pytorch_model.safetensors: 100% 335M/335M [00:01<00:00, 184MB/s]

Sample with the pretrained model

# %%time
# Set user inputs:
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 1000 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 10 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}

# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)

# Setup classifier-free guidance:
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)  # 1000 = null ("unconditional") class id
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# Sample images:
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False, 
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
samples = vae.decode(samples / 0.18215).sample

# Save and display images:
save_image(samples, "sample.png", nrow=int(samples_per_row), 
           normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)
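`create_diffusion(str(num_sampling_steps))` passes the step count as a timestep-respacing string: the model was trained with 1000 diffusion steps, and requesting fewer sampling steps makes the sampler visit an evenly spaced subset of them, trading some quality for speed. The function below is a simplified sketch of that even spacing, in the spirit of the respacing code the DiT repo inherits from OpenAI's diffusion codebases; the name `even_timesteps` and its exact rounding are assumptions, not the repo's API.

```python
from typing import List


def even_timesteps(num_train_steps: int, num_sample_steps: int) -> List[int]:
    """Pick num_sample_steps evenly spaced indices from range(num_train_steps).

    Simplified sketch (hypothetical helper): always keeps the first step (0)
    and the last step (num_train_steps - 1).
    """
    if not 1 <= num_sample_steps <= num_train_steps:
        raise ValueError("num_sample_steps must be in [1, num_train_steps]")
    if num_sample_steps == 1:
        return [0]
    span = num_train_steps - 1
    last = num_sample_steps - 1
    # i * span / last runs from 0 to span in equal fractional strides; round to the nearest step
    return [round(i * span / last) for i in range(num_sample_steps)]


steps = even_timesteps(1000, 250)
print(len(steps), steps[:4], steps[-1])  # 250 [0, 4, 8, 12] 999
```

With `num_sampling_steps = 1000` every training step is kept, so no respacing happens. The guard also rejects 0, a value the notebook's slider (`min:0`) allows.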

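The classifier-free guidance setup in the sampling code duplicates the noise batch and pairs the second half with the null label 1000, so a single forward pass yields both a class-conditional and an unconditional prediction; `samples.chunk(2, dim=0)` then discards the unconditional half. The standard guidance rule combines the two predictions as eps = eps_uncond + s * (eps_cond - eps_uncond), with s = `cfg_scale`. The pure-Python sketch below only illustrates that bookkeeping and arithmetic on plain lists; the helper names are hypothetical, and it is not the repo's `forward_with_cfg`, whose exact handling of the output channels should be checked in `models.py`.

```python
from typing import Any, List, Sequence, Tuple

NULL_CLASS = 1000  # ImageNet-1k labels are 0..999, so 1000 serves as the "no label" id


def build_cfg_batch(items: Sequence[Any], labels: Sequence[int]) -> Tuple[List[Any], List[int]]:
    """Duplicate the batch: first half conditional, second half unconditional."""
    return list(items) + list(items), list(labels) + [NULL_CLASS] * len(labels)


def cfg_combine(cond_eps: Sequence[float], uncond_eps: Sequence[float], scale: float) -> List[float]:
    """Classifier-free guidance, element-wise: uncond + scale * (cond - uncond)."""
    return [u + scale * (c - u) for c, u in zip(cond_eps, uncond_eps)]


z, y = build_cfg_batch(["z0", "z1"], [207, 360])
print(z)  # ['z0', 'z1', 'z0', 'z1']
print(y)  # [207, 360, 1000, 1000]
print(cfg_combine([1.0, 2.0], [0.5, 0.5], 4.0))  # [2.5, 6.5]
```

A scale of 1 recovers the conditional prediction (up to floating-point rounding); larger values, such as the `cfg_scale = 10` used here, push samples harder toward the requested class at the cost of diversity.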