0%

Gumbel-Max Trick 简记

最近读到一篇非常好的博客:

lips.cs.princeton.edu: The Gumbel-Max Trick for Discrete Distributions

强烈推荐直接阅读原文。这篇文章非常简洁地解释了 Gumbel-Max trick。很多论文会引用这个技巧,但真正讲清楚的资料并不多。这里简单记录几个关键点,并补充一点理解。


一句话核心思想

从 softmax 分布采样,可以通过给每个 logit 加一个 Gumbel 噪声,然后取最大值来实现。

其中

重复很多次后,第 i 个类别被选中的概率正好是

也就是 softmax。


原文中的变量含义

原文定义:

于是采样规则可以写成:

这里:

  • x_i:logit / score / log-probability
  • g_i:Gumbel 噪声
  • z_i:加入随机扰动后的 score

因此整篇文章实际上是在计算:

这和计算

是同一件事,因为原文只是把

换了个记号。


为什么 z 仍然是 Gumbel

原文有一个容易忽略的点:

如果

那么

仍然服从 Gumbel 分布,只是 location 参数变成了 x:

这就是为什么原文后面可以直接把 z 当成一个 Gumbel 随机变量来处理。

一句话说,这里用到的是:

Gumbel 分布在“加一个常数”之后,还是 Gumbel 分布。


一个直觉理解

可以把 Gumbel-Max 理解为:

每个类别先有一个固定分数 x_i,再额外加上一点随机扰动 g_i,最后看谁最大。

也就是说,我们比较的是:

如果重复很多次,原本分数更高的类别更容易赢,但又不是每次都赢。最终它被选中的概率恰好就是 softmax:


从 Gumbel-Max 到 Gumbel-Softmax

Gumbel-Max 的问题在于:

不可导,因此不能直接用于神经网络训练。

一个自然的想法是:用 softmax 去近似 argmax。于是得到

当温度 tau 趋近于 0 时,这个分布会越来越接近 one-hot,也就越来越像 argmax 的结果。

这就是 Gumbel-Softmax 的来历:

  • Gumbel-Max:加 Gumbel 噪声后做 argmax
  • Gumbel-Softmax:加 Gumbel 噪声后做 softmax 近似

它的意义是:既保留了离散采样的味道,又让整个过程变得可导。