GradNorm理解

主要参考这一篇,GradNorm:Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks,梯度归一化_grad norm-CSDN博客

14:20-15:30 

提前需要理解的概念

损失函数,衡量ypred与ytruth的差距。

Grad Loss定义为:各个任务【实际的梯度范数】与【理想的梯度范数】的【差的绝对值和】;

先把范数简单理解成长度,目前把范数想象成了神经网络找最优参数时向某个方向走的距离;

参数理解

其中,alpha 是设定恢复力强度的超参数,即将任务的训练速度调节到平均水准的强度。如果任务的复杂程度很不一样,导致任务之间的学习速率大不相同,就应该使用较高的 alpha 来进行较强的训练速率平衡;反之,对于多个相似的任务,应该使用较小的 alpha。

Training with GradNorm

用自己的想法总结了下,

gradnorm在单个batch step的流程总结如下:
整体任务是指L=w_aL_a+w_bL_b, W是神经网络的参数值, gradnorm主要在动态学w_a, w_b;
1、前向传播计算总损失L=w_a*L_a+w_b*L_b(假设我现在有2个任务);
2、计算第i个任务对整体任务的梯度范数,计算任务i的相对反向训练速度,计算所有任务对整体任务的梯度范数的平均;
3、计算GradLoss;
4、计算GradLoss对wi的导数(wi是指w_a,w_b);
5、利用第1步计算的Loss反向传播更新神经网络参数;
6、利用第4步的导数更新wi(更新后在下一个bacth step生效);
7、对wi进行renormalize(下一个bacth step使用的是renormalize之后的wi,意思就是下一个batch训练的时候wa,wb已经换了);
 

代码(aaa)

Multi-Task Learning:GradNorm - 知乎(参数含义的查漏补缺)

GitHub - QunBB/DeepLearning: All about DeepLearning: 推荐系统、自然语言处理、Tensorflow、Pytorch等

- 附加第二个代码的相关实验数据:多任务学习MTL模型:多目标Loss优化策略 - 知乎 

    - 多目标loss优化在开源数据实验一(uncertainty weight、GradNorm) - 知乎

GitHub - brianlan/pytorch-grad-norm: Pytorch implementation of the GradNorm. GradNorm addresses the problem of balancing multiple losses for multi-task learning by learning adjustable weight coefficients.

为什么loss量级大的task1的权重更大呢?

相关推荐

  1. JVM<span style='color:red;'>理解</span>

    JVM理解

    2023-12-14 09:34:04      30 阅读
  2. <span style='color:red;'>理解</span>CAS

    理解CAS

    2023-12-14 09:34:04      17 阅读
  3. keepalive 理解

    2023-12-14 09:34:04       12 阅读
  4. rpc理解

    2023-12-14 09:34:04       9 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-14 09:34:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-14 09:34:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-14 09:34:04       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-14 09:34:04       20 阅读

热门阅读

  1. mysql 当前时间加3个工作日

    2023-12-14 09:34:04       41 阅读
  2. Hive的几种排序方式、区别,使用场景

    2023-12-14 09:34:04       30 阅读
  3. 【Python基础】迭代器

    2023-12-14 09:34:04       31 阅读
  4. 哪些数据适合放入缓存?

    2023-12-14 09:34:04       33 阅读
  5. 子组件调用父组件的方法

    2023-12-14 09:34:04       42 阅读
  6. ElasticSearch之cat templates API

    2023-12-14 09:34:04       42 阅读
  7. prim算法求最小生成树

    2023-12-14 09:34:04       33 阅读
  8. QEMU源码全解析 —— virtio(6)

    2023-12-14 09:34:04       47 阅读
  9. Android WebView 响应缓存 笔记

    2023-12-14 09:34:04       44 阅读
  10. 【工具】VUE 前端列表拖拽功能代码

    2023-12-14 09:34:04       44 阅读