Gumbel Softmax

Argmax是不可求导的,Gumbel Softmax允许模型能从网络层的离散分布(比如类别分布categorical distribution)中稀疏采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数。

算法流程

  1. 对于某个网络层输出的 n \mathrm{n} n 维向量 v = [ v 1 , v 2 , … , v n ] v=\left[v_1, v_2, \ldots, v_n\right] v=[v1,v2,,vn],生成 n \mathrm{n} n 个服从均匀分布 U ( 0 , 1 ) \mathrm{U}(0,1) U(0,1) 的独立样本 ϵ 1 , … , ϵ n \epsilon_1, \ldots, \epsilon_n ϵ1,,ϵn
  2. 通过 G i = − log ⁡ ( − log ⁡ ( ϵ i ) ) G_i=-\log \left(-\log \left(\epsilon_i\right)\right) Gi=log(log(ϵi)) 计算得到 G i G_i Gi
  3. 对应相加得到新的值向量 v ′ = [ v 1 + G 1 , v 2 + G 2 , … , v n + G n ] v^{\prime}=\left[v_1+G_1, v_2+G_2, \ldots, v_n+G_n\right] v=[v1+G1,v2+G2,,vn+Gn]
  4. 通过softmax函数计算各个类别的概率大小,其中 τ \tau τ 是温度参数:
    p τ ( v i ′ ) = e v i ′ / r ∑ j = 1 n e v j ′ / τ p_\tau\left(v_i^{\prime}\right)=\frac{e^{v_i^{\prime} / r}}{\sum_{j=1}^n e^{v_j^{\prime} / \tau}} pτ(vi)=j=1nevj/τevi/r

Gumbel-Max Trick

Gumbel分布是专门用来建模从其他分布(比如高斯分布)采样出来的极值形成的分布,而我们这里“使用argmax挑出概率最大的那个类别索引”就属于取极值的操作,所以它属于Gumbel分布。

注意,极值的分布也是有规律的。

Gumbel-Max Trick的采样思想:先用均匀分布采样出一个随机值,然后把这个值带入到gumbel分布的CDF函数的逆函数得到采样值,即我们最终想要的类别索引。公示如下:
z = argmax ⁡ i ( log ⁡ ( p i ) + g i ) g i = − log ⁡ ( − log ⁡ ( u i ) ) , u i ∼ U ( 0 , 1 ) z=\operatorname{argmax}_i\left(\log \left(p_i\right)+g_i\right) \\ g_i=-\log \left(-\log \left(u_i\right)\right), u_i \sim U(0,1) z=argmaxi(log(pi)+gi)gi=log(log(ui)),uiU(0,1)
上式使用了重参数技巧把采样过程分成了确定性的部分和随机性的部分,我们会计算所有类别的log分布概率(确定性的部分),然后加上一些噪音(随机性的部分),这里噪音是标准gumbel分布。在我们把采样过程的确定性部分和随机性部分结合起来之后,我们在此基础上再用一个argmax来找到具有最大概率的类别。

Softmax

使用softmax替换不可导的argmax,用温度系数 τ \tau τ 来近似argmax:
p i ′ = exp ⁡ ( g i + log ⁡ p i τ ) ∑ j exp ⁡ ( g j + log ⁡ p j τ ) p_i^{\prime}=\frac{\exp \left(\frac{g_i+\log p_i}{\tau}\right)}{\sum_j \exp \left(\frac{g_j+\log p_j}{\tau}\right)} pi=jexp(τgj+logpj)exp(τgi+logpi)
τ \tau τ 越大,越接近argmax。


参考

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-08 07:54:01       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-08 07:54:01       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-08 07:54:01       82 阅读
  4. Python语言-面向对象

    2024-04-08 07:54:01       91 阅读

热门阅读

  1. 【恩智浦FRDM-MCX947开箱实践指南1】

    2024-04-08 07:54:01       36 阅读
  2. 迁移学习和微调

    2024-04-08 07:54:01       33 阅读
  3. Python初级笔记4 排序

    2024-04-08 07:54:01       36 阅读
  4. go | chan 并发传输或者设置chan缓存|死锁

    2024-04-08 07:54:01       41 阅读
  5. Elasticsearch 如何实现 master 选举

    2024-04-08 07:54:01       38 阅读
  6. 柒拾贰- tushare 模拟策略交易 (三)

    2024-04-08 07:54:01       33 阅读
  7. STM32F103系列五个特殊引脚作为GPIO时的配置

    2024-04-08 07:54:01       34 阅读