使用pytorch构建GAN模型的评估

本文为此系列的第六篇对GAN的评估,上一篇为Controllable GAN。文中使用训练好的分类模型的部分网络提取特征将真实分布与生成分布进行对比来评估模型的好坏,若有不懂的无监督知识点可以看本系列第一篇。

原理

1.评估模型的指标
一般来说,我们评估模型的好坏可以通过对测试集的错误率来体现:比如图像分类我们可以统计几张分错几张分对来量化错误率、目标检测我们可以通过比对每个框得到mAP从而量化错误率…但是我们怎么通过生成的图像来评估GAN的好坏呢?
在这里插入图片描述
我们总不能说,生成的某一个像素要更绿色一点比较好,或者某个像素要更黄色一点比较好吧?
先进行概括一下,全文主要围绕着生成质量(保真度fidelity)、多样性(diversity)进行讲解。
在这里插入图片描述
2. 图像对比有两种方法,pixel distance、feature distance。
第一种像素对比,直接做相减运算。这样做的缺点是尽管两张图片可能非常相似,但是每个像素的像素值会有一些细微的差异,即使我们肉眼看不出来,最终的差值也会非常大,太过于关注细节。
在这里插入图片描述
第二种则是特征对比,通俗的说是成片的像素区域进行对比是否相似,这样的对比更符合我们人眼观察标准。
在这里插入图片描述
那么,接下来的问题就是如何进行特征提取。
3. 特征提取的方法
我们训练好的分类器是一个很好的特征提取器,比如我们训练了一个识别猫狗的分类器,那它必然是学习到了猫狗的特征才会对他们进行分类。
在这里插入图片描述
直接将分类部分的最后一层分类层去掉,其余的都是对我们有价值的。我们一般选择的是连接最后一个全连接层的池化层作为输出特征的层,我们成为特征层,输出的特征我们称为embedding。
选择这个位置并不固定,只是选择的位置越后面,每个单元的感受野越大,所包含的信息就越多,更符合我们的要求。很前面的层获取到的特征可能只是一横或者一竖或者一个弧度等。

  • 我们使用Inception v3作为我们的特征提取器,Inception使用超1400万张图片、2万多类别的ImageNet数据库作为训练集。提取详细流程如图:
    在这里插入图片描述

对总的概括可以概括为一下流程:
在这里插入图片描述
最终我们就是对真实数据提取的特征于生成数据提取的特征进行对比。
4. Frechet Inception Distance(FID)
我们使用FID来量化真假特征的差异。
通俗来说Frechet Distance是用来衡量两条曲线之间的的最小距离,比如人狗同时走所需的最短牵引绳的长度。
在这里插入图片描述
严格来说,Frechet Distance是衡量两个分布之间的差异。
在这里插入图片描述
①我们可以使用以下公式来表示两个单维正态分布的Frechet Distance:
在这里插入图片描述
分别从真实数据和生成数据里面提取大量的特征,分别作为真实特征分布于生成特征分布,计算出各自的均值和标准差即可计算出真假之间的差值。
②两个多变量正态分布的Frechet Distance
我们可以为每个维度提供一个单变量的正态分布,假设是两个变量的(便于举例),如图:
在这里插入图片描述

协方差矩阵:
比如(x1,x2)代表第一变量的正态分布的随机变量与第二正态分布的随机变量之间的协方差。非对角线元素代表不同变量之间的协方差,即不同变量之间的相关性。若两个变量变化趋势一致则协方差为正值,反之负值,若没有线性关系则为0。上图就代表两个变量之间相互不影响相互独立,下图代表两变量之间负相关;
比如(x1,x1)代表第一变量的正态分布的方差。对角线元素代表每个变量分布的方差,即每个变量本身的变化程度。
在这里插入图片描述
由此可以计算我们的多变量正态分布之间的Frechet Distance,可以将单维正态分布之间的Frechet Distance公式展开进行对比发现他们之间其实是相似的:
在这里插入图片描述
Tr运算为矩阵的对角线元素之和,例如上面那个负相关的协方差矩阵的Tr运算结果为2+2=4。
将多变量正态分布之间的Frechet Distance应用于真假特征的分布就是FID了:
在这里插入图片描述
FID越小,就代表着真假分布就越接近,那么GAN就越好。

代码

import torch
import numpy as np
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

z_dim = 64
image_size = 299
device = 'cuda'

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = CelebA(".", download=True, transform=transform)

gen = Generator(z_dim).to(device)
gen.load_state_dict(torch.load(f"pretrained_celeba.pth", map_location=torch.device(device))["gen"])
gen = gen.eval()

from torchvision.models import inception_v3
inception_model = inception_v3(pretrained=False)
inception_model.load_state_dict(torch.load("inception_v3_google-1a9a5a14.pth"))
inception_model.to(device)
inception_model = inception_model.eval() # Evaluation mode

inception_model.fc = torch.nn.Identity()

from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal
    [[1, 0],
     [0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

mean = torch.Tensor([0, 0])
covariance = torch.Tensor(
    [[2, -1],
     [-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

import scipy
def matrix_sqrt(x):
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    return torch.Tensor(y.real, device=x.device)
    
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

def preprocess(img):
    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img

import numpy as np
def get_covariance(features):
    return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

fake_features_list = []
real_features_list = []

n_samples = 512 # The total number of samples
batch_size = 4 # Samples per iteration

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)

cur_samples = 0
with torch.no_grad(): # You don't need to calculate gradients here, so you do this to save memory
    try:
        for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
            real_samples = real_example
            real_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPU
            real_features_list.append(real_features)

            fake_samples = get_noise(len(real_example), z_dim).to(device)
            fake_samples = preprocess(gen(fake_samples))
            fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')
            fake_features_list.append(fake_features)
            cur_samples += len(real_samples)
            if cur_samples >= n_samples:
                break
    except:
        print("Error in loop")

fake_features_all = torch.cat(fake_features_list)
real_features_all = torch.cat(real_features_list)

mu_fake = fake_features_all.mean(0)
mu_real = real_features_all.mean(0)
sigma_fake = get_covariance(fake_features_all)
sigma_real = get_covariance(real_features_all)

indices = [2, 4, 5]
fake_dist = MultivariateNormal(mu_fake[indices], sigma_fake[indices][:, indices])
fake_samples = fake_dist.sample((5000,))
real_dist = MultivariateNormal(mu_real[indices], sigma_real[indices][:, indices])
real_samples = real_dist.sample((5000,))

import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()

with torch.no_grad():
    print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

代码中使用的生成器模型可以从上一篇当中下载,inception_v3_google-1a9a5a14.pth模型可以从这里下载。

代码解析

  • 去掉分类层
inception_model.fc = torch.nn.Identity()

将最后一层的全连接层替换为恒等函数,它将输入的数据不做任何操作、原封不动地输出。
通常Inception模型的全连接层用于图像分类任务,它将提取的特征映射到类别预测上。然而我们不需要进行图像分类,而是想要利用Inception模型的前面部分来提取图像的特征。
这样就将Inception模型从原始的分类任务模型转变为一个特征提取器,从而不再执行图像分类任务,而是将图像转换为特征向量。

  • 可视化多变量正态分布
from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal
    [[1, 0],
     [0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

mean = torch.Tensor([0, 0])
covariance = torch.Tensor(
    [[2, -1],
     [-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

首先定义均值和协方差矩阵(原理中举的两个例子),然后使用MultivariateNormal构建一个多变量正态分布对象covariant_dist。然后从这个分布中抽取了10000个样本,每个样本是一个shape为(samples, 2)的二维向量。最后将生成的样本可视化为二维核密度估计图(Kernel Density Estimate,KDE)。
在这里插入图片描述
在这里插入图片描述

  • 计算矩阵的平方根
def matrix_sqrt(x):
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    return torch.Tensor(y.real, device=x.device)

首先将输入矩阵转移到CPU上并将其转换为NumPy数组。这是因为scipy.linalg.sqrtm函数只能接受NumPy数组作为输入,不能接受PyTorch张量,且在CPU上计算更高效。
然后使用scipy.linalg.sqrtm函数计算平方根且返回一个复数矩阵,所以需要取其实部(real)部分,然后再转换为PyTorch张量。同时,函数还会确保新的张量与输入矩阵在相同的设备(device)上。

  • 计算FID
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

给定两个分布的均值和协方差矩阵,利用原理中的公式进行计算。

  • 对生成图像进行处理
def preprocess(img):
    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img

将输入的图像进行插值操作,插值方法使用双线性插值,参数align_corners=False指示在进行插值操作时不对齐图像的角点,这在图像处理中常用于避免不必要的插值偏差。
在这里插入图片描述

  • 计算协方差矩阵
def get_covariance(features):
    return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

使用NumPy的np.cov()函数计算特征向量集合的协方差矩阵,rowvar=False参数表示传递的数据中每一列代表一个特征向量的观测值,而不是每一行代表一个观测样本。

  • 提取特征
for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
    real_samples = real_example
    real_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPU
    real_features_list.append(real_features)

    fake_samples = get_noise(len(real_example), z_dim).to(device)
    fake_samples = preprocess(gen(fake_samples))
    fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')
    fake_features_list.append(fake_features)
    cur_samples += len(real_samples)
    if cur_samples >= n_samples:
        break

使用预训练的Inception模型提取真实图像和生成图像的特征,并将这些特征存储在列表中,以备后续计算Fréchet Distance。
在这里需要对生成的图像进行preprocess()处理为299的宽高是因为真实数据的宽高为299,而生成数据的宽高为64。
我们可以将生成数据和preprocess处理后的数据显示出来看效果:

import matplotlib.pyplot as plt

# 选择其中一个样本进行显示
sample_index = 0

# 显示生成图像
fake_image = fake[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

# 显示经过处理的图像
fake_image = fake_samples[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

在这里插入图片描述
在这里插入图片描述
可以看到插值操作后平滑很多。

  • 可视化真实数据分布与生成数据分布,并计算FID
indices = [2, 4, 5]
import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()

with torch.no_grad():
    print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

在这里插入图片描述
在这里插入图片描述

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-22 11:22:01       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-22 11:22:01       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-22 11:22:01       82 阅读
  4. Python语言-面向对象

    2024-04-22 11:22:01       91 阅读

热门阅读

  1. 几道练习题八

    2024-04-22 11:22:01       37 阅读
  2. 数据结构中顺序表的应用

    2024-04-22 11:22:01       29 阅读
  3. 使用go_concurrent_map 管理 并发更新缓存

    2024-04-22 11:22:01       36 阅读
  4. html-docx-js网页转为word格式框架

    2024-04-22 11:22:01       33 阅读
  5. Es6Proxy基础用法

    2024-04-22 11:22:01       33 阅读
  6. 笔记:Python 选择结构练习题

    2024-04-22 11:22:01       41 阅读
  7. tcp inflight 守恒算法(tcp_ccr)

    2024-04-22 11:22:01       34 阅读
  8. 将数据库中的数据接入Echarts图表

    2024-04-22 11:22:01       30 阅读
  9. PostCSS概述

    2024-04-22 11:22:01       35 阅读
  10. 环境感知——自动驾驶模型训练(菜鸟版本)

    2024-04-22 11:22:01       30 阅读
  11. 考研依据数学思维导图,整理出的章节知识大纲

    2024-04-22 11:22:01       36 阅读
  12. ZooKeeper的分布式锁

    2024-04-22 11:22:01       41 阅读
  13. 程序员如何修炼线路

    2024-04-22 11:22:01       193 阅读