最近读到一篇非常好的博客:
lips.cs.princeton.edu: The Gumbel-Max Trick for Discrete Distributions
强烈推荐直接阅读原文。这篇文章非常简洁地解释了 Gumbel-Max trick。很多论文会引用这个技巧,但真正讲清楚的资料并不多。这里简单记录几个关键点,并补充一点理解。
一句话核心思想
从 softmax 分布采样,可以通过给每个 logit 加一个 Gumbel 噪声,然后取最大值来实现。
其中
重复很多次后,第 i 个类别被选中的概率正好是
也就是 softmax。
原文中的变量含义
原文定义:
于是采样规则可以写成:
这里:
- x_i:logit / score / log-probability
- g_i:Gumbel 噪声
- z_i:加入随机扰动后的 score
因此整篇文章实际上是在计算:
这和计算
是同一件事,因为原文只是把
换了个记号。
为什么 z 仍然是 Gumbel
原文有一个容易忽略的点:
如果
那么
仍然服从 Gumbel 分布,只是 location 参数变成了 x:
这就是为什么原文后面可以直接把 z 当成一个 Gumbel 随机变量来处理。
一句话说,这里用到的是:
Gumbel 分布在“加一个常数”之后,还是 Gumbel 分布。
一个直觉理解
可以把 Gumbel-Max 理解为:
每个类别先有一个固定分数 x_i,再额外加上一点随机扰动 g_i,最后看谁最大。
也就是说,我们比较的是:
如果重复很多次,原本分数更高的类别更容易赢,但又不是每次都赢。最终它被选中的概率恰好就是 softmax:
从 Gumbel-Max 到 Gumbel-Softmax
Gumbel-Max 的问题在于:
不可导,因此不能直接用于神经网络训练。
一个自然的想法是:用 softmax 去近似 argmax。于是得到
当温度 tau 趋近于 0 时,这个分布会越来越接近 one-hot,也就越来越像 argmax 的结果。
这就是 Gumbel-Softmax 的来历:
- Gumbel-Max:加 Gumbel 噪声后做 argmax
- Gumbel-Softmax:加 Gumbel 噪声后做 softmax 近似
它的意义是:既保留了离散采样的味道,又让整个过程变得可导。