从“RL for LLM”视角重新理解KL近似:关于“近似KL散度”的笔记

社区文章 发布于2025年8月11日

PPO和GRPO中使用的KL散度估计方法有什么区别?

John Schulman的博客文章“近似KL散度”讨论了如何通过采样(蒙特卡罗)近似KL散度,并介绍了三种估计器(\(k_1\)、k2k_2k3k_3)及其偏差-方差行为。但原始文章是在一般概率分布的背景下提出的,并未涉及大型语言模型(LLM)的强化学习训练设置。本文记录了我在阅读时遇到的问题、将内容映射到RL for LLM后形成的思考,以及一些我认为原始解释可以进一步阐述的地方。

“近似KL散度”说了什么(用我自己的话)

在本节中,我假定读者尚未阅读原始文章,因此我们快速浏览最重要的部分。简单来说,这篇文章是关于当我们无法直接计算KL散度时,如何构建合理的蒙特卡罗式估计器。

KL(q,p)=xq(x)logq(x)p(x)=Exq ⁣[logq(x)p(x)]. \mathrm{KL}(q, p) = \sum_x q(x)\,\log\frac{q(x)}{p(x)} = \mathbb{E}_{x\sim q}\!\left[\log\frac{q(x)}{p(x)}\right].

如公式所示:当估计两个(复杂)分布之间的KL散度时,人们常用一种编码技巧:即通过从qq中抽样,用log ⁣(q(x)p(x))\log\!\big(\frac{q(x)}{p(x)}\big)的样本均值来近似KL(而不是试图精确评估完整的期望)。文章接着指出另一种方法:使用12(logr)2\tfrac{1}{2}(\log r)^2的样本均值来替代更“标准”的logr\log r形式,其中r=q(x)p(x)r=\frac{q(x)}{p(x)}。本文解释了为什么这个表达式可以成为KL的良好(尽管有偏)估计器,以及如何在保持低方差的同时使其无偏。

我们计算KL的方式取决于我们如何访问ppqq。这里我们假设我们可以评估任何xxp(x)p(x)q(x)q(x)(概率或密度),但我们**无法**对xx进行解析求和/积分。我们为什么不能进行解析求和/积分呢?可能是因为精确计算在计算或内存方面过于昂贵,可能没有闭合形式,或者我们为了简化代码,只存储对数概率而不是完整的分布,尤其是在KL仅用于诊断时(强化学习中常出现这种情况)。近似求和或积分最常见的策略是**蒙特卡罗**。给定从qq中抽取的样本x1,x2,,xnqx_1, x_2, \dots, x_n \sim q,我们如何构建一个好的估计器?

一个好的估计器应该**无偏**(平均值正确)且**方差低**。我们知道一个无偏估计器

k1=logq(x)p(x). k_1 = \log\frac{q(x)}{p(x)}.

但它的方差很高:根据定义,KL是一个非负量,然而对于上述估计器,大约“一半”的样本值可能是负的(如果我们不对ppqq做任何先验假设),这使得平均值波动很大,因此方差很高。为了符号方便,设r=q(x)p(x)r = \frac{q(x)}{p(x)}。那么原始的KL可以写成

KL[q,p]  =  Exq[logr]. \mathrm{KL}[q, p] \;=\; \mathbb{E}_{x\sim q}\,[\log r].

为了减少方差,我们可以设计一个替代估计器:k2=12(logr)2. k_2 = \frac{1}{2}(\log r)^2. 它的方差较低,但有偏。直观上,k2k_2感觉更好,因为每个样本都给出了ppqq之间的非负“距离”,因此它保持正值。经验上,k2k_2的方差确实比k1k_1低得多,而且偏差可以很小。至于为什么k2k_2相比k1k_1能够大幅降低方差,原始文章使用了f-散度视图给出了分析解释,这里我不再赘述。

现在,我们能否得到一个既**无偏**又**低方差**的估计器呢?一个通用的技巧是使用**控制变量**:从无偏的k1k_1开始,并添加一个期望值为零且与它负相关的量以降低方差。这里一个非常方便的零均值量是r1r-1。因此,对于任意λ\lambdak  =  logr+λ(r1) k \;=\; -\log r + \lambda\,(r-1) 仍然是一个无偏的KL估计器。理论上,我们可以在λ\lambda上最小化方差,但其闭合形式取决于ppqq,不容易得到。然而请注意,由于log(x)\log(x)是凹函数,log(x)    x1, \log(x) \;\le\; x-1, 所以如果我们选择λ=1\lambda=1,该表达式保证非负。在这里,r1r-1logr\log rr=1r=1处的切线。因此,当λ=1\lambda=1时,我们实际上测量的是log(x)\log(x)与其切线之间的垂直距离。这导致了估计器k3  =  (r1)    logr, k_3 \;=\; (r - 1) \;-\; \log r, 它总是非负的。而k3k_3正是实际中GRPO与PPO在KL估计方式上有所不同的地方(PPO使用k1k_1)。

从“RL for LLM”的角度讨论KL估计

在强化学习(例如PPO、GRPO等)中,我们通常会在损失函数中加入一个KL散度项,以防止新策略偏离旧策略太远。这里,qq是旧策略分布(πold\pi_{\text{old}}),pp是新策略分布(πnew\pi_{\text{new}}),而xx是一个完整的动作样本(在LLM中,这表示一个token或一个token序列)。我们通常用ss表示状态(在LLM中,这是提示或上下文),xx是在该上下文中生成的特定token。当我们计算KL时,我们实际上是在**给定状态下的动作分布**上计算KL,然后对状态进行平均:

KL[p,q]=Es[xp(xs)logp(xs)]. \mathrm{KL}[p, q] = \mathbb{E}_{s} \left[ \sum_x p(x|s) \log \frac{p(x|s)}{q(x|s)} \right].

在采样时,我们通常会固定一个提示(状态),然后为该提示估计此KL散度。

那么**为什么我们不能直接精确计算KL散度,而非要估计它呢?**原因与原始博客文章中列出的完全相同;在LLM的强化学习中,主要症结在于**原因1**:*动作空间(token空间)太大,无法对所有可能的xx进行求和/积分*。例如,如果一个分词器有50,000个词汇条目,即使计算单个token的KL散度也意味着对50,000个动作求和;而在强化学习中,我们通常进行多步(序列)生成,因此空间呈指数级增长,这完全不切实际。还有一个实用原因:在训练过程中,我们通常不存储完整的分布(所有token的概率);我们只保留沿轨迹实际生成的token的对数概率,以节省GPU内存和I/O。因此,我们必须使用**蒙特卡罗采样**:从某个分布(通常是qq,即旧策略)中抽取xx,并使用这些样本来近似KL散度。这就把我们直接带入了博客文章所讨论的领域。

在该文章中,我们一直谈论的**估计器**实际上只是样本的一个函数:它接收某个采样xxp(x)p(x)q(x)q(x)(或它们的比率r=q(x)p(x)r = \frac{q(x)}{p(x)}),并输出一个数字。然后,我们对这些数字在样本上求平均,以近似KL散度。例如:

  • k1(x)=logrk_1(x) = -\log r
  • k2(x)=(logr)2k_2(x) = \frac12 (\log r)^2
  • k3(x)=(r1)logrk_3(x) = (r - 1) - \log r

这些kik_i只是不同的KL估计器公式。它们都通过**对样本求平均**来近似KL散度,但在偏差和方差上有所不同。一旦我们选择了一个估计器,我们实际上就承诺使用一个特定的公式来近似KL散度。这个过程看起来像这样:

  1. 采样
    从旧策略qq中采样一批token(或序列)x1,x2,,xNx_1, x_2, \dots, x_N
  2. 计算对数概率
    对于每个样本,计算新旧策略下的对数概率

logp(xi), logq(xi) \log p(x_i),\ \log q(x_i)

并得到ri=q(xi)p(xi)r_i = \frac{q(x_i)}{p(x_i)}logri\log r_i。3. **代入估计器公式**
例如,如果我们选择k3k_3

k3(xi)=(ri1)logri k_3(x_i) = (r_i - 1) - \log r_i

  1. 平均分

KL^1Ni=1Nk3(xi) \widehat{\mathrm{KL}} \approx \frac1N \sum_{i=1}^N k_3(x_i)

这是近似的 KL 值,代表了真实的 KL。

如果我们将这与离散概率分布(LLM 单令牌步长)的真实 KL 计算(无估计)进行比较:我们需要遍历每个可能的令牌 xxKL(pq)=xp(x)logp(x)q(x) \mathrm{KL}(p\|q) = \sum_x p(x) \log \frac{p(x)}{q(x)} 您可以立即看到,使用估算器,计算量比进行完整求和小得多,尤其是在高维动作空间中。

谈论不同 KL 估计器的方差

重要提示:我们这里讨论的“方差”是估计器在样本上输出值的方差: Varxq[k(x)] \mathrm{Var}_{x \sim q}[k(x)] 也就是说, k(x)k(x) 在样本空间中的波动程度。一个**无偏**估计器意味着在无限多的样本下,其均值等于真实 KL。但高方差估计器意味着即使均值正确(无偏),在少量样本下,平均值也可能偏差很大。在 LLM 的强化学习中,KL 项通常是损失中的正则化因子(例如, βKL\beta \cdot \mathrm{KL})。如果 KL 估计器的方差很大,会使损失变得嘈杂,进而使梯度嘈杂并导致训练不稳定。

在原帖中,为了让读者直观理解为什么 k1k_1 不是低方差的,作者写道:

然而,它(k1k_1)具有高方差,因为它对一半的样本是负的,而 KL 始终是正的。

作者指出,尽管 k1k_1 是无偏的,但如果没有对 ppqq 的先验约束,一半的样本会一个比另一个大,所以一半的 k1k_1 值是正的,一半是负的。到目前为止,我都同意。但随后作者说:因为 KL 总是大于 0(一个基本不等式),所以 k1k_1 因此必须具有高方差。而在这里,我认为因果关系并不成立:你不能用期望的符号来决定单个样本的符号。一个简单的反例:在计算期望时, p(x)logp(x)q(x)p(x) \log \frac{p(x)}{q(x)} 也时而为正,时而为负;这个事实本身并不能说明方差。实际上,单样本的**对数比率**(无论是 logq(x)p(x)\log \frac{q(x)}{p(x)}logp(x)q(x)\log \frac{p(x)}{q(x)})都可以是正的或负的,就像 k1k_1 一样,所以**单独的符号翻转并不是高方差的唯一原因**。

根据 KL 定义: KL(qp)=Exq[logq(x)p(x)] \mathrm{KL}(q \| p) = \mathbb{E}_{x\sim q}\left[ \log \frac{q(x)}{p(x)} \right] 期望值**保证非负**,但被积函数 logq(x)p(x)\log\frac{q(x)}{p(x)} 可以对单个样本是正的或负的。而 k1k_1 正是这个被积函数: k1(x)=logq(x)p(x) k_1(x) = \log \frac{q(x)}{p(x)} 所以每个样本值确实可以是正的或负的,与 KL 定义中的被积函数相同。

那么为什么 k1k_1 会有高方差?

这不仅仅是“符号翻转”。真正的原因是 k1k_1 的值分布通常很宽(重尾)。例如,如果 p(x)p(x) 对于某些样本来说很小,那么 logqp\log\frac{q}{p} 可能会非常大(正或负)。这些极端值主导有限样本平均值,推高了方差。换句话说,它是**极端值 + 正负抵消**的组合:抵消意味着你需要更多的样本才能收敛到真实平均值,而极端值会使样本方差本身更大。因此,博客中“一半为负”的评论更多的是一种直觉提示,而不是完整的解释。

从这个角度来看,如果我们看其他估计器 k2k_2k3k_3,我们发现: k2=12(logr)2k_2 = \frac12 (\log r)^2 总是正的,所以没有抵消,但这引入了偏差;平方也平滑了幅度,降低了方差。k3k_3 使用控制变量来消除部分波动源,在保持无偏性的同时降低方差(详细信息见下文)。

在 PPO/GRPO 中,如果您使用 k1k_1 并且批次很小或分布相距很远,KL 估计值将跳来跳去(因为少数极端样本会使平均值剧烈波动)。这使得 KL 惩罚系数不稳定:它可能突然变得过强或过弱。切换到低方差估计器( k2k_2k3k_3 )使每个样本的 KL 贡献更稳定,更不容易被少数极端样本主导。

为什么 k3k_3 既能无偏又能低方差?

乍一看, k3k_3 总是正的,所以你可能会认为它的平均值必须大于 k1k_1 的平均值。
但请记住: k3k_3 是通过**控制变量**从 k1k_1 导出的。博客的推理如下: k~(x)=k1(x)+λh(x) \tilde{k}(x) = k_1(x) + \lambda \cdot h(x) 其中 h(x)=r1h(x) = r - 1,并且在 xqx\sim q 下,其期望值为: Exq[h(x)]=Eq[p(x)q(x)1]=xp(x)1=11=0. \mathbb{E}_{x\sim q}[h(x)] = \mathbb{E}_q\left[\frac{p(x)}{q(x)} - 1\right] = \sum_x p(x) - 1 = 1 - 1 = 0. 因此,添加任何 h(x)h(x) 的倍数都不会改变期望值。当 λ=1\lambda = 1 时: k~(x)=logr+(r1)=(r1)logr=k3(x). \tilde{k}(x) = -\log r + (r - 1) = (r - 1) - \log r = k_3(x). 这解释了为什么 k3k_3 的期望值等于 k1k_1 的期望值,并等于 KL,使其成为一个无偏估计器。

k3k_3k1k_1 具有更低方差的原因是: k1k_1 只有 logr-\log r,其值可能剧烈波动(既有正有负,偶尔出现巨大值)。但是 r1r - 1logr-\log r 在数值上高度相关(一个增长,另一个也增长/收缩),并且这种相关性是**负的**。添加 (r1)(r - 1) 就像引入一个**负相关项来抵消波动**。抵消后, k3k_3 中剩下的值范围更紧密,始终为正,因此样本方差更低。

社区

注册登录 发表评论