【机器学习】对抗生成网络

e499c31491d0805f8758c82b482dae48.png

6bb0ee8b15aa9b9c459d1c3eb0b37bc0.png

一、随机数据生成

365a097e15910e7485c4c465b4a1f536.png

随机数据生成算法

c0ca391377032bc7a9d00b033b6d07bf.png

a0277f1e9e03b4ea2e673df49dabff58.png

随机数据生成的显示建模和隐式建模

1dda46d769c6425afb1012ab6d99a0b8.png

二、生成对抗网络结构

b9f8d7347953150df5aa19c520028d61.png

生成对抗网络(GAN)中,生成模型(Generator)和判别模型(Discriminator)的任务和训练目标分别是:

  • 生成模型的任务是从随机噪声中生成尽可能真实的数据,例如图像、文本、音频等。生成模型的训练目标是最小化生成数据被判别模型识别为假的概率,也就是最大化生成数据的真实性。

  • 判别模型的任务是区分输入的数据是真实的还是生成模型生成的。判别模型的训练目标是最大化真实数据被识别为真的概率和生成数据被识别为假的概率,也就是最小化判别模型的误判率。

生成模型和判别模型的训练是一个对抗的过程,它们互相竞争,不断提高自己的能力,最终达到一个平衡点,使得生成模型生成的数据无法被判别模型区分。这样,生成模型就可以生成高质量的数据,判别模型就可以提高数据的鉴别能力。

生成对抗网络的原理

861e1614346edf1f047018b99e83e85d.png

27e910394da50069af50a12ca433159a.png

cdedf4d639b2553eadfd31d41347411e.png

113fb7aff3158702c168f3acd2621434.png

bb7888077e2670cc1f313c0122352770.png

三、模型的训练

76bfb3630b1ed4803bff65922470bd06.png

GAN的训练过程是怎样的?

4bd817ae387b407914ad81832e819607.png

# 初始化生成器G和判别器D
G = Generator()
D = Discriminator()


# 设置优化器和超参数
optimizer_G = Optimizer(G.parameters(), ...)
optimizer_D = Optimizer(D.parameters(), ...)
epochs = ...
batch_size = ...
latent_dim = ...


# 循环训练epochs次
for epoch in range(epochs):
  # 循环训练每个批次的数据
  for x in data_loader(batch_size):
    # 训练判别器D
    optimizer_D.zero_grad()
    z = random_noise(latent_dim)
    fake_x = G(z)
    real_pred = D(x)
    fake_pred = D(fake_x.detach())
    loss_D = binary_cross_entropy(real_pred, 1) + binary_cross_entropy(fake_pred, 0)
    loss_D.backward()
    optimizer_D.step()


    # 训练生成器G
    optimizer_G.zero_grad()
    fake_pred = D(fake_x)
    loss_G = binary_cross_entropy(fake_pred, 1)
    loss_G.backward()
    optimizer_G.step()


    # 打印训练信息
    print(f"Epoch {epoch}, Loss_D: {loss_D}, Loss_G: {loss_G}")

四、应用和改进

GAN的变体

3faba4b99a9155c98c046bd8f994165a.png

这些变体的原理和GAN有什么不同?

68a772bc1ab622500382a07b1d386a9b.png

如何选择适合自己任务的GAN变体?

f90a1e484ccb5dc86e33cb970e9cb8a2.png

4.1 改进方案

CGAN

90e739b5c7275ab2ce1382e71ced8b3e.png

CGAN和GAN的训练过程有什么不同?

4d0bf741bc818d450dd16d01d68246b7.png

DCGAN

243d741f88742da818d187fc7d0599bc.png

拉普拉斯金字塔GAN

a07f29140032bb955ea7aeaaeda82927.png

dedd5a398a9583d5e7f613d4e25485fb.png

025593c3429e64047c8359d8dee12e8b.png

1d6c0c664cacd194e20c6853cb377c5a.png

GRAN

循环神经网络 (Recurrent Neural Networks, RNNs) 是一种深度学习的方法,它可以处理序列数据,如文本,语音,音乐等。RNNs 的特点是它们有一个内部状态,可以记住之前的信息,从而捕捉序列数据的长期依赖和结构。RNNs 可以用于生成对抗网络 (Generative Adversarial Networks, GANs) 的框架中,作为生成器或判别器,来生成或评估序列数据。这种结合了 RNNs 和 GANs 的方法称为生成循环对抗网络 (Generative Recurrent Adversarial Networks, GRANs)。GRANs 可以利用 RNNs 的能力来生成逼真和多样的序列数据,如文本,语音,音乐等。

328ef8d662e57217e5399f3589446210.png

InfoGAN

20c06fdc75b8a1932a5eec4d45024cc9.png

1b54d3b76cd79a0e49732d1bdf209f2d.png

4.2 典型应用

57074afdddabcd01d3155da19aa42303.png

09384c5eee64ee4e4b8e0e896357b287.png

9f343f76a7e98489d3ae40cea31f4cd9.png

1df35d16e8506d4bc9c718636fbe89af.png

Real-ESRGAN 超分辨率图像示例:

206456f51decb798fd984d870e038c27.png

从 https://github.com/ai-forever/Real-ESRGAN 下载源码,从

https://huggingface.co/ai-forever/Real-ESRGAN/tree/main 手动下载模型,放在weights文件夹中。 

主程序

# 导入os模块,用于操作系统相关的功能,如文件和目录的管理
import os 
# 导入torch模块,用于深度学习的计算和模型的构建
import torch
# 导入PIL模块,用于图像的处理和显示
from PIL import Image
# 导入numpy模块,用于科学计算和数组的操作
import numpy as np
# 导入RealESRGAN模块,这是一个基于生成对抗网络的超分辨率模型,可以将低分辨率的图像转换为高分辨率的图像
from RealESRGAN import RealESRGAN




# 定义一个主函数,返回值类型为整数
def main() -> int:
    # 判断当前设备是否支持CUDA,如果支持,就使用CUDA作为设备类型,否则使用CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 创建一个RealESRGAN的实例,指定设备类型和放大倍数为4
    model = RealESRGAN(device, scale=4)
    # 加载预训练的模型权重,从本地目录中读取文件,不需要下载
    model.load_weights('Real-ESRGAN/weights/RealESRGAN_x4.pth', download=False)
    # 遍历输入目录中的所有图像文件,使用enumerate函数给每个文件编号
    for i, image in enumerate(os.listdir("Real-ESRGAN/inputs")):
        # 打开图像文件,并转换为RGB模式
        image = Image.open(f"Real-ESRGAN/inputs/{image}").convert('RGB')
        # 使用模型对图像进行预测,得到超分辨率的图像
        sr_image = model.predict(image)
        # 将超分辨率的图像保存到结果目录中,文件名为编号.png
        sr_image.save(f'Real-ESRGAN/results/{i}.png')
    # 返回0表示程序正常结束
    return 0




# 如果当前文件是作为主程序运行,而不是被其他文件导入,就执行主函数
if __name__ == '__main__':
    main()

效果:

d957d408e6283dbff6705758892084fa.jpeg

原始图 650x650

8e4a0664ac25839cdbf50e62cd59d5ac.png

超分辨率图 2600x2600

参考网址:

https://en.wikipedia.org/wiki/Generative_adversarial_network

https://github.com/pytorch/examples/tree/main

https://arxiv.org/abs/1406.2661

https://zhuanlan.zhihu.com/p/53473337 [GAN学习系列2] GAN的起源

https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650730721&idx=2&sn=95b97b80188f507c409f4c72bd0a2767&chksm=871b349fb06cbd891771f72d77563f77986afc9b144f42c8232db44c7c56c1d2bc019458c4e4&scene=21#wechat_redirect

https://www.mindspore.cn/tutorials/application/zh-CN/r1.7/cv/dcgan.html 生成式对抗网络

https://pytorch.org/examples/

https://github.com/leftthomas/SRGAN

https://pytorch.org/hub/

https://github.com/xiong-jie-y/ml-examples/tree/master

https://huggingface.co/ai-forever/Real-ESRGAN/tree/main    模型下载Real-ESRGAN

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-01-10 09:50:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-10 09:50:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-10 09:50:03       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-10 09:50:03       20 阅读

热门阅读

  1. 关于c++中vector的使用(声明、清空、追加)

    2024-01-10 09:50:03       43 阅读
  2. 基本工具配置

    2024-01-10 09:50:03       33 阅读
  3. OCR识别PDF扫描件

    2024-01-10 09:50:03       36 阅读
  4. c++中getline的用法理解

    2024-01-10 09:50:03       32 阅读
  5. 第28关 k8s监控实战之Prometheus(六)

    2024-01-10 09:50:03       39 阅读