AIGC笔记--VAE模型的搭建

目录

1--VAE模型

2--代码实例


1--VAE模型

简单介绍:

        通过一个 encoder 将图片映射到标准分布(均值和方差),从映射的标准分布中随机采样一个样本,通过 decoder 重构图片;

        计算源图片和重构图片之间的损失,具体损失函数的推导可以参考:变分自编码器(VAE)

2--代码实例

简单的VAE模型搭建:

        Encoder 返回映射标准分布的均值和方差,从标准分布中随机采样,利用Decoder重构图片;

class VAE(nn.Module):
    def __init__(self, input_dim = 784, h_dim = 400, z_dim = 20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, h_dim) # 28*28 → 784
        self.fc21 = nn.Linear(h_dim, z_dim) # 均值
        self.fc22 = nn.Linear(h_dim, z_dim) # 标准差
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim) # 784 → 28*28
        self.input_dim = input_dim

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1) # 均值、标准差

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        # z = mu + eps*std
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # sigmoid: 0-1 之间,后边会用到 BCE loss 计算重构 loss(reconstruction loss)
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

简单的损失计算:

x_reconst, mu, log_var = self.model(x)
# Compute reconstruction loss and kl divergence
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Backprop and optimize
loss = reconst_loss + kl_div

完整可运行代码参考:VAE代码实例

相关推荐

  1. AIGC笔记--VAE模型

    2024-01-20 03:38:02       42 阅读
  2. vue项目---1.基础框架

    2024-01-20 03:38:02       31 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-01-20 03:38:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-20 03:38:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-20 03:38:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-20 03:38:02       18 阅读

热门阅读

  1. Mellanox Cumulus 10GB交换机 - 网卡接口配置成网桥

    2024-01-20 03:38:02       31 阅读
  2. list上

    list上

    2024-01-20 03:38:02      31 阅读
  3. IDEA 常用快捷键(持续更新)

    2024-01-20 03:38:02       31 阅读
  4. Elasticsearch 字段更新机制

    2024-01-20 03:38:02       36 阅读
  5. ASOP的电池设置

    2024-01-20 03:38:02       31 阅读
  6. MacBook将大文件分割成很多个小文件split命

    2024-01-20 03:38:02       30 阅读
  7. 网络的各类型攻击方式

    2024-01-20 03:38:02       30 阅读
  8. mysql 主从通过mysqldump方式搭建

    2024-01-20 03:38:02       31 阅读
  9. 设计模式——访问者模式

    2024-01-20 03:38:02       32 阅读