从“RL for LLM”视角重新理解KL近似:关于“近似KL散度”的笔记
PPO和GRPO中使用的KL散度估计方法有什么区别?
John Schulman的博客文章“近似KL散度”讨论了如何通过采样(蒙特卡罗)近似KL散度,并介绍了三种估计器(\(k_1\)、、)及其偏差-方差行为。但原始文章是在一般概率分布的背景下提出的,并未涉及大型语言模型(LLM)的强化学习训练设置。本文记录了我在阅读时遇到的问题、将内容映射到RL for LLM后形成的思考,以及一些我认为原始解释可以进一步阐述的地方。
“近似KL散度”说了什么(用我自己的话)
在本节中,我假定读者尚未阅读原始文章,因此我们快速浏览最重要的部分。简单来说,这篇文章是关于当我们无法直接计算KL散度时,如何构建合理的蒙特卡罗式估计器。
如公式所示:当估计两个(复杂)分布之间的KL散度时,人们常用一种编码技巧:即通过从中抽样,用的样本均值来近似KL(而不是试图精确评估完整的期望)。文章接着指出另一种方法:使用的样本均值来替代更“标准”的形式,其中。本文解释了为什么这个表达式可以成为KL的良好(尽管有偏)估计器,以及如何在保持低方差的同时使其无偏。
我们计算KL的方式取决于我们如何访问和。这里我们假设我们可以评估任何的和(概率或密度),但我们**无法**对进行解析求和/积分。我们为什么不能进行解析求和/积分呢?可能是因为精确计算在计算或内存方面过于昂贵,可能没有闭合形式,或者我们为了简化代码,只存储对数概率而不是完整的分布,尤其是在KL仅用于诊断时(强化学习中常出现这种情况)。近似求和或积分最常见的策略是**蒙特卡罗**。给定从中抽取的样本,我们如何构建一个好的估计器?
一个好的估计器应该**无偏**(平均值正确)且**方差低**。我们知道一个无偏估计器
但它的方差很高:根据定义,KL是一个非负量,然而对于上述估计器,大约“一半”的样本值可能是负的(如果我们不对和做任何先验假设),这使得平均值波动很大,因此方差很高。为了符号方便,设。那么原始的KL可以写成
为了减少方差,我们可以设计一个替代估计器:它的方差较低,但有偏。直观上,感觉更好,因为每个样本都给出了和之间的非负“距离”,因此它保持正值。经验上,的方差确实比低得多,而且偏差可以很小。至于为什么相比能够大幅降低方差,原始文章使用了f-散度视图给出了分析解释,这里我不再赘述。
现在,我们能否得到一个既**无偏**又**低方差**的估计器呢?一个通用的技巧是使用**控制变量**:从无偏的开始,并添加一个期望值为零且与它负相关的量以降低方差。这里一个非常方便的零均值量是。因此,对于任意,仍然是一个无偏的KL估计器。理论上,我们可以在上最小化方差,但其闭合形式取决于和,不容易得到。然而请注意,由于是凹函数,所以如果我们选择,该表达式保证非负。在这里,是在处的切线。因此,当时,我们实际上测量的是与其切线之间的垂直距离。这导致了估计器它总是非负的。而正是实际中GRPO与PPO在KL估计方式上有所不同的地方(PPO使用)。
从“RL for LLM”的角度讨论KL估计
在强化学习(例如PPO、GRPO等)中,我们通常会在损失函数中加入一个KL散度项,以防止新策略偏离旧策略太远。这里,是旧策略分布(),是新策略分布(),而是一个完整的动作样本(在LLM中,这表示一个token或一个token序列)。我们通常用表示状态(在LLM中,这是提示或上下文),是在该上下文中生成的特定token。当我们计算KL时,我们实际上是在**给定状态下的动作分布**上计算KL,然后对状态进行平均:
在采样时,我们通常会固定一个提示(状态),然后为该提示估计此KL散度。
那么**为什么我们不能直接精确计算KL散度,而非要估计它呢?**原因与原始博客文章中列出的完全相同;在LLM的强化学习中,主要症结在于**原因1**:*动作空间(token空间)太大,无法对所有可能的进行求和/积分*。例如,如果一个分词器有50,000个词汇条目,即使计算单个token的KL散度也意味着对50,000个动作求和;而在强化学习中,我们通常进行多步(序列)生成,因此空间呈指数级增长,这完全不切实际。还有一个实用原因:在训练过程中,我们通常不存储完整的分布(所有token的概率);我们只保留沿轨迹实际生成的token的对数概率,以节省GPU内存和I/O。因此,我们必须使用**蒙特卡罗采样**:从某个分布(通常是,即旧策略)中抽取,并使用这些样本来近似KL散度。这就把我们直接带入了博客文章所讨论的领域。
在该文章中,我们一直谈论的**估计器**实际上只是样本的一个函数:它接收某个采样的和(或它们的比率),并输出一个数字。然后,我们对这些数字在样本上求平均,以近似KL散度。例如:
这些只是不同的KL估计器公式。它们都通过**对样本求平均**来近似KL散度,但在偏差和方差上有所不同。一旦我们选择了一个估计器,我们实际上就承诺使用一个特定的公式来近似KL散度。这个过程看起来像这样:
- 采样
从旧策略中采样一批token(或序列)。 - 计算对数概率
对于每个样本,计算新旧策略下的对数概率
并得到或。3. **代入估计器公式**
例如,如果我们选择
- 平均分
这是近似的 KL 值,代表了真实的 KL。
如果我们将这与离散概率分布(LLM 单令牌步长)的真实 KL 计算(无估计)进行比较:我们需要遍历每个可能的令牌 : 您可以立即看到,使用估算器,计算量比进行完整求和小得多,尤其是在高维动作空间中。
谈论不同 KL 估计器的方差
重要提示:我们这里讨论的“方差”是估计器在样本上输出值的方差: 也就是说, 在样本空间中的波动程度。一个**无偏**估计器意味着在无限多的样本下,其均值等于真实 KL。但高方差估计器意味着即使均值正确(无偏),在少量样本下,平均值也可能偏差很大。在 LLM 的强化学习中,KL 项通常是损失中的正则化因子(例如, )。如果 KL 估计器的方差很大,会使损失变得嘈杂,进而使梯度嘈杂并导致训练不稳定。
在原帖中,为了让读者直观理解为什么 不是低方差的,作者写道:
然而,它()具有高方差,因为它对一半的样本是负的,而 KL 始终是正的。
作者指出,尽管 是无偏的,但如果没有对 和 的先验约束,一半的样本会一个比另一个大,所以一半的 值是正的,一半是负的。到目前为止,我都同意。但随后作者说:因为 KL 总是大于 0(一个基本不等式),所以 因此必须具有高方差。而在这里,我认为因果关系并不成立:你不能用期望的符号来决定单个样本的符号。一个简单的反例:在计算期望时, 也时而为正,时而为负;这个事实本身并不能说明方差。实际上,单样本的**对数比率**(无论是 或 )都可以是正的或负的,就像 一样,所以**单独的符号翻转并不是高方差的唯一原因**。
根据 KL 定义: 期望值**保证非负**,但被积函数 可以对单个样本是正的或负的。而 正是这个被积函数: 所以每个样本值确实可以是正的或负的,与 KL 定义中的被积函数相同。
那么为什么 会有高方差?
这不仅仅是“符号翻转”。真正的原因是 的值分布通常很宽(重尾)。例如,如果 对于某些样本来说很小,那么 可能会非常大(正或负)。这些极端值主导有限样本平均值,推高了方差。换句话说,它是**极端值 + 正负抵消**的组合:抵消意味着你需要更多的样本才能收敛到真实平均值,而极端值会使样本方差本身更大。因此,博客中“一半为负”的评论更多的是一种直觉提示,而不是完整的解释。
从这个角度来看,如果我们看其他估计器 和 ,我们发现: 总是正的,所以没有抵消,但这引入了偏差;平方也平滑了幅度,降低了方差。 使用控制变量来消除部分波动源,在保持无偏性的同时降低方差(详细信息见下文)。
在 PPO/GRPO 中,如果您使用 并且批次很小或分布相距很远,KL 估计值将跳来跳去(因为少数极端样本会使平均值剧烈波动)。这使得 KL 惩罚系数不稳定:它可能突然变得过强或过弱。切换到低方差估计器( 或 )使每个样本的 KL 贡献更稳定,更不容易被少数极端样本主导。
为什么 既能无偏又能低方差?
乍一看, 总是正的,所以你可能会认为它的平均值必须大于 的平均值。
但请记住: 是通过**控制变量**从 导出的。博客的推理如下: 其中 ,并且在 下,其期望值为: 因此,添加任何 的倍数都不会改变期望值。当 时: 这解释了为什么 的期望值等于 的期望值,并等于 KL,使其成为一个无偏估计器。
比 具有更低方差的原因是: 只有 ,其值可能剧烈波动(既有正有负,偶尔出现巨大值)。但是 和 在数值上高度相关(一个增长,另一个也增长/收缩),并且这种相关性是**负的**。添加 就像引入一个**负相关项来抵消波动**。抵消后, 中剩下的值范围更紧密,始终为正,因此样本方差更低。