Week G1: Getting Started with Generative Adversarial Networks (GANs)

Preliminaries

Define the hyperparameters:

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
 
## Create the output folders (relative paths, matching the paths used in the rest of the code)
os.makedirs("./images", exist_ok=True)           ## sample images saved during training
os.makedirs("./save", exist_ok=True)             ## where the trained models are saved
os.makedirs("./datasets/mnist", exist_ok=True)   ## where the downloaded dataset is stored
 
## Hyperparameter configuration
n_epochs=50          # number of training epochs
batch_size=512       # batch size
lr=0.0002            # learning rate for both optimizers
b1=0.5               # Adam beta1
b2=0.999             # Adam beta2
n_cpu=2              # number of CPU workers (not used by the DataLoader below)
latent_dim=100       # dimensionality of the latent noise vector z
img_size=28          # height/width of the MNIST images
channels=1           # number of image channels (grayscale)
sample_interval=500  # save sample images every 500 batches

## Image shape: (1, 28, 28); total number of pixels per image: 784
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

## Check whether CUDA is available
cuda = True if torch.cuda.is_available() else False
print(cuda)

Download the dataset and train the model (use one of the two code blocks below):

Option 1: If your GPU driver is up to date and compatible with the installed CUDA version, run the model with the CUDA-enabled build of PyTorch.
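
Before choosing this option, it can help to confirm that the installed PyTorch build actually sees the GPU. A minimal check (not part of the original code) might look like this:

print(torch.__version__)                   # PyTorch version
print(torch.version.cuda)                  # CUDA version PyTorch was built with (None for CPU-only builds)
print(torch.cuda.is_available())           # True if a usable GPU and driver are detected
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # name of the first GPU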

## Download the MNIST dataset
## Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], matching the generator's Tanh output
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True, transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
## Wrap the dataset in a DataLoader
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

## Discriminator: a 784 -> 512 -> 256 -> 1 MLP that outputs the probability of an image being real
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 784 input features, 512 outputs
            nn.LeakyReLU(0.2, inplace=True),  # non-linear activation
            nn.Linear(512, 256),              # 512 input features, 256 outputs
            nn.LeakyReLU(0.2, inplace=True),  # non-linear activation
            nn.Linear(256, 1),                # 256 input features, 1 output
            nn.Sigmoid(),                     # sigmoid maps the score to [0, 1] so it can be read as a probability (softmax would be used for multi-class problems)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # flatten each image into a 784-dimensional vector: (batch_size, 784)
        validity = self.model(img_flat)      # run it through the discriminator network
        return validity                      # probability in [0, 1] that the image is real

## Generator: maps a latent vector z (100-d) through 128 -> 256 -> 512 -> 1024 -> 784 and reshapes to an image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        ## Building block for the hidden layers: Linear (+ optional BatchNorm) + LeakyReLU
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]          # linear layer mapping in_feat -> out_feat
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) # batch normalization
            layers.append(nn.LeakyReLU(0.2, inplace=True))   # non-linear activation
            return layers
        ## np.prod(img_shape): product of the shape entries, 1*28*28 = 784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), # 100 -> 128, LeakyReLU (no batch norm on the first block)
            *block(128, 256),                         # 128 -> 256, batch norm, LeakyReLU
            *block(256, 512),                         # 256 -> 512, batch norm, LeakyReLU
            *block(512, 1024),                        # 512 -> 1024, batch norm, LeakyReLU
            nn.Linear(1024, img_area),                # 1024 -> 784
            nn.Tanh()                                 # map each of the 784 outputs to [-1, 1]
        )

    ## view(): like numpy's reshape; here it reshapes the output to (batch_size, 1, 28, 28)
    def forward(self, z):                           # z: noise of shape (batch_size, 100)
        imgs = self.model(z)                        # pass the noise through the generator
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape to (batch_size, 1, 28, 28)
        return imgs                                 # a batch of generated (1, 28, 28) images

## Create the generator and discriminator
generator = Generator()
discriminator = Discriminator()

## Loss function: binary cross-entropy
criterion = torch.nn.BCELoss()

## Optimizers: Adam with learning rate lr = 0.0002
## betas: coefficients used for the running averages of the gradient and its square
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## If a GPU is available, move everything to CUDA
if torch.cuda.is_available():
    generator     = generator.cuda()
    discriminator = discriminator.cuda()
    criterion     = criterion.cuda()

for epoch in range(n_epochs):                   # n_epochs: 50
    for i, (imgs, _) in enumerate(dataloader):  # imgs: (batch_size, 1, 28, 28), _: labels (unused)

        imgs = imgs.view(imgs.size(0), -1)    # flatten each image to 28*28 = 784 -> (batch_size, 784)
        real_img = Variable(imgs).cuda()      # wrap in a Variable so it can take part in autograd (kept for compatibility with older PyTorch)
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## label 1 for real images
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## label 0 for fake images

        ## -------------------------
        ##  Train the discriminator
        ## -------------------------
        real_out = discriminator(real_img)            # run real images through the discriminator
        loss_real_D = criterion(real_out, real_label) # loss on real images
        real_scores = real_out                        # discriminator scores on real images (the closer to 1, the better)
        ## Loss on fake images
        ## detach(): cut the graph so no gradient flows back into G (G is not updated in this step)
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()     ## random noise of shape (batch_size, latent_dim)
        fake_img    = generator(z).detach()                            ## generate fake images from the noise
        fake_out    = discriminator(fake_img)                          ## discriminator scores on the fake images
        loss_fake_D = criterion(fake_out, fake_label)                  ## loss on fake images
        fake_scores = fake_out
        ## Total discriminator loss and optimization step
        loss_D = loss_real_D + loss_fake_D  # sum of the real-image and fake-image losses
        optimizer_D.zero_grad()             # zero the gradients before backpropagation
        loss_D.backward()                   # backpropagate the error
        optimizer_D.step()                  # update the discriminator parameters

        ## ---------------------
        ##  Train the generator
        ## ---------------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()     ## fresh random noise
        fake_img = generator(z)                                        ## generate fake images from the noise
        output = discriminator(fake_img)                               ## discriminator scores on the fake images
        ## Generator loss and optimization step
        loss_G = criterion(output, real_label)                         ## the generator tries to make D output 1 on its fakes
        optimizer_G.zero_grad()                                        ## zero the gradients
        loss_G.backward()                                              ## backpropagate
        optimizer_G.step()                                             ## update the generator parameters

        ## Print training progress
        ## item(): extract the Python number from a single-element tensor
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## Save sample images during training
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

## Save the trained models
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')
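
For reference, the two BCELoss terms in the training loop above implement the standard GAN objective with the commonly used non-saturating generator loss:

$$\mathcal{L}_D = -\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] - \mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[\log\big(1 - D(G(z))\big)\big]$$

$$\mathcal{L}_G = -\mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[\log D(G(z))\big]$$

In the code, loss_real_D + loss_fake_D is a batch estimate of $\mathcal{L}_D$, and loss_G (binary cross-entropy between D(G(z)) and the "real" label 1) is a batch estimate of $\mathcal{L}_G$.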

Option 2: If CUDA is not available on your system, or you do not want to use the GPU, run the model on the CPU.
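
A common alternative to maintaining two copies of the code is to make it device-agnostic: create a torch.device once and move the models and tensors to it with .to(device). A minimal sketch of that idea (not how this tutorial is structured):

## Device-agnostic setup: the same code then runs on the GPU if one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator().to(device)
discriminator = Discriminator().to(device)
criterion = torch.nn.BCELoss().to(device)

## Inside the training loop, each batch and each noise sample is moved to the same device, e.g.:
## real_img = imgs.view(imgs.size(0), -1).to(device)
## z = torch.randn(imgs.size(0), latent_dim, device=device)

The code below keeps the original two-option structure and simply runs everything on the CPU.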

## Download the MNIST dataset
## Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], matching the generator's Tanh output
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True, transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
## Wrap the dataset in a DataLoader
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

## Discriminator: a 784 -> 512 -> 256 -> 1 MLP that outputs the probability of an image being real
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 784 input features, 512 outputs
            nn.LeakyReLU(0.2, inplace=True),  # non-linear activation
            nn.Linear(512, 256),              # 512 input features, 256 outputs
            nn.LeakyReLU(0.2, inplace=True),  # non-linear activation
            nn.Linear(256, 1),                # 256 input features, 1 output
            nn.Sigmoid(),                     # sigmoid maps the score to [0, 1] so it can be read as a probability (softmax would be used for multi-class problems)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # flatten each image into a 784-dimensional vector: (batch_size, 784)
        validity = self.model(img_flat)      # run it through the discriminator network
        return validity                      # probability in [0, 1] that the image is real

## Generator: maps a latent vector z (100-d) through 128 -> 256 -> 512 -> 1024 -> 784 and reshapes to an image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        ## Building block for the hidden layers: Linear (+ optional BatchNorm) + LeakyReLU
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]          # linear layer mapping in_feat -> out_feat
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) # batch normalization
            layers.append(nn.LeakyReLU(0.2, inplace=True))   # non-linear activation
            return layers
        ## np.prod(img_shape): product of the shape entries, 1*28*28 = 784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), # 100 -> 128, LeakyReLU (no batch norm on the first block)
            *block(128, 256),                         # 128 -> 256, batch norm, LeakyReLU
            *block(256, 512),                         # 256 -> 512, batch norm, LeakyReLU
            *block(512, 1024),                        # 512 -> 1024, batch norm, LeakyReLU
            nn.Linear(1024, img_area),                # 1024 -> 784
            nn.Tanh()                                 # map each of the 784 outputs to [-1, 1]
        )

    ## view(): like numpy's reshape; here it reshapes the output to (batch_size, 1, 28, 28)
    def forward(self, z):                           # z: noise of shape (batch_size, 100)
        imgs = self.model(z)                        # pass the noise through the generator
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape to (batch_size, 1, 28, 28)
        return imgs                                 # a batch of generated (1, 28, 28) images

## Create the generator and discriminator
generator = Generator()
discriminator = Discriminator()

## Keep the models on the CPU (newly created models live on the CPU by default, so .cpu() is a no-op here)
generator = generator.cpu()
discriminator = discriminator.cpu()

## Loss function: binary cross-entropy
criterion = torch.nn.BCELoss()

## Optimizers: Adam with learning rate lr = 0.0002
## betas: coefficients used for the running averages of the gradient and its square
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

for epoch in range(n_epochs):                   # n_epochs: 50
    for i, (imgs, _) in enumerate(dataloader):  # imgs: (batch_size, 1, 28, 28), _: labels (unused)

        imgs = imgs.view(imgs.size(0), -1)    # flatten each image to 28*28 = 784 -> (batch_size, 784)
        real_img = Variable(imgs)             # wrap in a Variable so it can take part in autograd (kept for compatibility with older PyTorch)
        real_label = Variable(torch.ones(imgs.size(0), 1))      ## label 1 for real images
        fake_label = Variable(torch.zeros(imgs.size(0), 1))     ## label 0 for fake images

        ## -------------------------
        ##  Train the discriminator
        ## -------------------------
        real_out = discriminator(real_img)            # run real images through the discriminator
        loss_real_D = criterion(real_out, real_label) # loss on real images
        real_scores = real_out                        # discriminator scores on real images (the closer to 1, the better)
        ## Loss on fake images
        ## detach(): cut the graph so no gradient flows back into G (G is not updated in this step)
        z = Variable(torch.randn(imgs.size(0), latent_dim))     ## random noise of shape (batch_size, latent_dim)
        fake_img    = generator(z).detach()                     ## generate fake images from the noise
        fake_out    = discriminator(fake_img)                   ## discriminator scores on the fake images
        loss_fake_D = criterion(fake_out, fake_label)           ## loss on fake images
        fake_scores = fake_out
        ## Total discriminator loss and optimization step
        loss_D = loss_real_D + loss_fake_D  # sum of the real-image and fake-image losses
        optimizer_D.zero_grad()             # zero the gradients before backpropagation
        loss_D.backward()                   # backpropagate the error
        optimizer_D.step()                  # update the discriminator parameters

        ## ---------------------
        ##  Train the generator
        ## ---------------------
        z = Variable(torch.randn(imgs.size(0), latent_dim))     ## fresh random noise
        fake_img = generator(z)                                 ## generate fake images from the noise
        output = discriminator(fake_img)                        ## discriminator scores on the fake images
        ## Generator loss and optimization step
        loss_G = criterion(output, real_label)                  ## the generator tries to make D output 1 on its fakes
        optimizer_G.zero_grad()                                 ## zero the gradients
        loss_G.backward()                                       ## backpropagate
        optimizer_G.step()                                      ## update the generator parameters

        ## Print training progress
        ## item(): extract the Python number from a single-element tensor
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## Save sample images during training
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)

## Save the trained models
torch.save(generator.state_dict(), './save/generator.pth')
torch.save(discriminator.state_dict(), './save/discriminator.pth')

Training output:

[Epoch 0/50] [Batch 99/118] [D loss: 1.348898] [G loss: 0.768190] [D real: 0.680194] [D fake: 0.608467]
[Epoch 1/50] [Batch 99/118] [D loss: 1.088056] [G loss: 0.825508] [D real: 0.743123] [D fake: 0.538369]
[Epoch 2/50] [Batch 99/118] [D loss: 1.114249] [G loss: 1.370899] [D real: 0.772192] [D fake: 0.571547]
[Epoch 3/50] [Batch 99/118] [D loss: 1.074770] [G loss: 1.163773] [D real: 0.687619] [D fake: 0.492250]
[Epoch 4/50] [Batch 99/118] [D loss: 1.126985] [G loss: 0.981320] [D real: 0.586216] [D fake: 0.420572]
[Epoch 5/50] [Batch 99/118] [D loss: 1.402424] [G loss: 0.648313] [D real: 0.352514] [D fake: 0.158908]
[Epoch 6/50] [Batch 99/118] [D loss: 1.128472] [G loss: 1.799203] [D real: 0.806444] [D fake: 0.586730]
[Epoch 7/50] [Batch 99/118] [D loss: 1.066108] [G loss: 1.737643] [D real: 0.764517] [D fake: 0.531414]
[Epoch 8/50] [Batch 99/118] [D loss: 1.162140] [G loss: 1.896749] [D real: 0.797096] [D fake: 0.600104]
[Epoch 9/50] [Batch 99/118] [D loss: 0.931134] [G loss: 1.207216] [D real: 0.600548] [D fake: 0.271762]
[Epoch 10/50] [Batch 99/118] [D loss: 0.906784] [G loss: 1.635424] [D real: 0.649781] [D fake: 0.306829]
[Epoch 11/50] [Batch 99/118] [D loss: 1.128253] [G loss: 0.814038] [D real: 0.455113] [D fake: 0.181365]
[Epoch 12/50] [Batch 99/118] [D loss: 0.656877] [G loss: 2.148012] [D real: 0.799558] [D fake: 0.318620]
[Epoch 13/50] [Batch 99/118] [D loss: 0.758273] [G loss: 1.781502] [D real: 0.849074] [D fake: 0.437607]
[Epoch 14/50] [Batch 99/118] [D loss: 0.982824] [G loss: 2.315076] [D real: 0.795012] [D fake: 0.504426]
[Epoch 15/50] [Batch 99/118] [D loss: 0.846314] [G loss: 1.125144] [D real: 0.594142] [D fake: 0.150700]
[Epoch 16/50] [Batch 99/118] [D loss: 0.788453] [G loss: 1.134926] [D real: 0.598429] [D fake: 0.113131]
[Epoch 17/50] [Batch 99/118] [D loss: 0.860472] [G loss: 1.416159] [D real: 0.554753] [D fake: 0.070943]
[Epoch 18/50] [Batch 99/118] [D loss: 0.729715] [G loss: 2.033889] [D real: 0.813916] [D fake: 0.380846]
[Epoch 19/50] [Batch 99/118] [D loss: 0.699210] [G loss: 2.655535] [D real: 0.845672] [D fake: 0.384237]
[Epoch 20/50] [Batch 99/118] [D loss: 0.608509] [G loss: 1.670838] [D real: 0.758573] [D fake: 0.230607]
[Epoch 21/50] [Batch 99/118] [D loss: 0.669346] [G loss: 2.555538] [D real: 0.817196] [D fake: 0.330501]
[Epoch 22/50] [Batch 99/118] [D loss: 0.811412] [G loss: 3.608017] [D real: 0.880692] [D fake: 0.466917]
[Epoch 23/50] [Batch 99/118] [D loss: 0.879888] [G loss: 1.472922] [D real: 0.610781] [D fake: 0.124722]
[Epoch 24/50] [Batch 99/118] [D loss: 0.767168] [G loss: 3.407905] [D real: 0.906761] [D fake: 0.470930]
[Epoch 25/50] [Batch 99/118] [D loss: 0.534345] [G loss: 2.263444] [D real: 0.890575] [D fake: 0.311338]
[Epoch 26/50] [Batch 99/118] [D loss: 0.473837] [G loss: 1.867095] [D real: 0.807679] [D fake: 0.173667]
[Epoch 27/50] [Batch 99/118] [D loss: 0.672992] [G loss: 2.960940] [D real: 0.846083] [D fake: 0.356172]
[Epoch 28/50] [Batch 99/118] [D loss: 0.726250] [G loss: 2.020569] [D real: 0.658650] [D fake: 0.034624]
[Epoch 29/50] [Batch 99/118] [D loss: 0.503680] [G loss: 2.267217] [D real: 0.826359] [D fake: 0.216285]
[Epoch 30/50] [Batch 99/118] [D loss: 0.987975] [G loss: 1.588039] [D real: 0.544412] [D fake: 0.043705]
[Epoch 31/50] [Batch 99/118] [D loss: 1.162546] [G loss: 2.823585] [D real: 0.729459] [D fake: 0.494907]
[Epoch 32/50] [Batch 99/118] [D loss: 0.924303] [G loss: 1.293745] [D real: 0.582127] [D fake: 0.117892]
[Epoch 33/50] [Batch 99/118] [D loss: 0.747387] [G loss: 2.206166] [D real: 0.797877] [D fake: 0.343705]
[Epoch 34/50] [Batch 99/118] [D loss: 0.623693] [G loss: 3.111738] [D real: 0.898811] [D fake: 0.381497]
[Epoch 35/50] [Batch 99/118] [D loss: 0.567340] [G loss: 2.021876] [D real: 0.757147] [D fake: 0.179998]
[Epoch 36/50] [Batch 99/118] [D loss: 0.727314] [G loss: 1.915004] [D real: 0.755489] [D fake: 0.287818]
[Epoch 37/50] [Batch 99/118] [D loss: 0.826854] [G loss: 1.472841] [D real: 0.674383] [D fake: 0.238948]
[Epoch 38/50] [Batch 99/118] [D loss: 1.143365] [G loss: 0.757286] [D real: 0.489352] [D fake: 0.069917]
[Epoch 39/50] [Batch 99/118] [D loss: 0.818748] [G loss: 1.114080] [D real: 0.601727] [D fake: 0.127117]
[Epoch 40/50] [Batch 99/118] [D loss: 0.918430] [G loss: 1.276388] [D real: 0.629529] [D fake: 0.249412]
[Epoch 41/50] [Batch 99/118] [D loss: 0.727234] [G loss: 1.541735] [D real: 0.718813] [D fake: 0.211931]
[Epoch 42/50] [Batch 99/118] [D loss: 0.979106] [G loss: 0.957361] [D real: 0.568877] [D fake: 0.127781]
[Epoch 43/50] [Batch 99/118] [D loss: 0.683977] [G loss: 1.902616] [D real: 0.765684] [D fake: 0.275655]
[Epoch 44/50] [Batch 99/118] [D loss: 0.681833] [G loss: 2.164286] [D real: 0.775665] [D fake: 0.293220]
[Epoch 45/50] [Batch 99/118] [D loss: 0.762346] [G loss: 1.543463] [D real: 0.613166] [D fake: 0.084738]
[Epoch 46/50] [Batch 99/118] [D loss: 0.780659] [G loss: 1.477143] [D real: 0.697691] [D fake: 0.234303]
[Epoch 47/50] [Batch 99/118] [D loss: 0.709177] [G loss: 1.837770] [D real: 0.750658] [D fake: 0.254356]
[Epoch 48/50] [Batch 99/118] [D loss: 0.884956] [G loss: 2.488509] [D real: 0.832655] [D fake: 0.457649]
[Epoch 49/50] [Batch 99/118] [D loss: 0.990627] [G loss: 4.116466] [D real: 0.913515] [D fake: 0.563187]
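
After training, the saved generator can be reloaded and used to sample new digits. A minimal sketch (assuming the Generator class, the hyperparameters, and the save path defined above):

## Reload the trained generator and sample a 5x5 grid of digits
generator = Generator()
generator.load_state_dict(torch.load('./save/generator.pth', map_location='cpu'))
generator.eval()                                # switch BatchNorm layers to evaluation mode

with torch.no_grad():
    z = torch.randn(25, latent_dim)             # 25 random latent vectors
    samples = generator(z)                      # (25, 1, 28, 28), values in [-1, 1]
    save_image(samples, "./images/samples.png", nrow=5, normalize=True)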
