把组合损失中的权重设置为可学习参数

目前的需求是:有一个模型,准备使用组合损失,其中有2个或者多个损失函数。准备对其进行加权并线性叠加。但想让这些权重进行自我学习,更新迭代成最优加权组合。

目录

1、构建组合损失类

2、调用组合损失类

3、为其构建优化器

4、梯度归零

5、跟新优化器参数

6、结果展示


1、构建组合损失类

每项损失函数可以定义在init里面,这样的话就只需要模型的输出和训练目标。我这里没有这样设置,选择把每项损失值传过来进行线性加权叠加。

# 定义组合损失函数---------------------------------------START
class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        # 定义损失函数权重作为可训练参数
        self.w_adv = nn.Parameter(torch.ones(1, requires_grad=True))  # 对抗损失的权重,初始值为0.2 
        self.w_con = nn.Parameter(torch.ones(1, requires_grad=True))  # 内容感知损失的权重,初始值为0.2
        self.w_mse = nn.Parameter(torch.ones(1, requires_grad=True))  # 均方误差损失的权重,初始值为0.2
        self.w_s3im = nn.Parameter(torch.ones(1, requires_grad=True))  # 随机结构相似性损失的权重,初始值为0.2
        self.w_gui = nn.Parameter(torch.ones(1, requires_grad=True))  # 边缘引导损失的权重,初始值为0.2


    def forward(self, loss_adv, loss_con, loss_mse, loss_s3im, loss_gui):
        return self.w_adv*loss_adv + self.w_con*loss_con + self.w_mse*loss_mse + self.w_s3im*loss_s3im + self.w_gui*loss_gui

2、调用组合损失类

在计算组合损失之前,需要初始化类对象。

combinedloss = Loss.CombinedLoss()

unet_loss = self.combinedloss(
                            loss_adv = unet_gan_loss, 
                            loss_con = gen_content_loss, 
                            loss_mse = unet_criterion, 
                            loss_s3im = s3im_loss, 
                            loss_gui = guid_loss)

3、为其构建优化器

最好单独构建优化器,这样我们可以设置与总损失不用的学习率。避免学习率过大导致梯度消失。

self.lr_weight_optimizer = optim.Adam(
            self.combinedloss.parameters(),
            lr = 1e-4,
            betas=(0.9, 0.999)
        )

4、梯度归零

在每次计算总损失之前,需要把每个优化器的梯度归零

self.lr_weight_optimizer.zero_grad()

5、跟新优化器参数

在总损失反向传播之后,需要对优化器的参数进行更新

self.lr_weight_optimizer.step()

6、结果展示

每个权重都会自动更新。 

相关推荐

  1. 神经网络初始化

    2024-03-29 22:58:03       16 阅读
  2. PyTorch定义学习参数

    2024-03-29 22:58:03       35 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-03-29 22:58:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-29 22:58:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-29 22:58:03       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-29 22:58:03       20 阅读

热门阅读

  1. 将一个nextjs项目部署到vercel

    2024-03-29 22:58:03       21 阅读
  2. Android-AR眼镜屏幕显示

    2024-03-29 22:58:03       23 阅读
  3. [c++]类和对象常见题目详解

    2024-03-29 22:58:03       22 阅读
  4. MySQL新建用户并授权、删除用户、修改用户名

    2024-03-29 22:58:03       18 阅读
  5. 大数据导论-大数据可视化——沐雨先生

    2024-03-29 22:58:03       19 阅读
  6. python编程入门

    2024-03-29 22:58:03       17 阅读
  7. 面试——深度分页问题的优化

    2024-03-29 22:58:03       18 阅读
  8. python报错unable to rollback pymysql

    2024-03-29 22:58:03       17 阅读
  9. mysql一些常用查询语句

    2024-03-29 22:58:03       19 阅读
  10. mysql03-内外连接

    2024-03-29 22:58:03       16 阅读
  11. docker 共享内存不足问题

    2024-03-29 22:58:03       13 阅读
  12. Python100个库分享第1个—Chardet

    2024-03-29 22:58:03       18 阅读
  13. Ubuntu 的cuda更新

    2024-03-29 22:58:03       17 阅读