加速LLM推理:使用Gumbel-Max技巧进行快速采样
社区文章 发布于2024年10月24日
引言
大型语言模型(LLM)的推理速度受token采样过程的严重影响。在每个生成步骤中,我们需要从整个词汇表(通常为32K到100K个token)的概率分布中采样下一个token。使用torch.multinomial
的标准方法已成为推理管道中的一个显著瓶颈。
传统LLM采样的问题
LLM推理中的传统采样过程如下:
- 从模型获取logits
- 应用softmax将logits转换为概率
- 使用
torch.multinomial
从概率分布中采样
这种方法有两个主要瓶颈:
- 在大型词汇表上计算softmax代价高昂
- 多项式采样操作本身相对较慢
核心洞察:Gumbel-Max采样
我们方法的核心创新来自于对Gumbel-Max技巧的两个关键观察:
使用Gumbel-Max采样在数学上等同于类别采样
# Instead of: probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # We can do: gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits))) next_token = torch.argmax(logits + gumbel_noise, dim=-1)
关键优化:Gumbel噪声可以预先计算
- 噪声张量独立于logits
- 我们可以在收到模型输出之前准备好它
- 这将其从token生成的关键路径中移除
- 我们完全避免了softmax的计算
A100上的性能结果
我们在A100 80GB上的基准测试显示,在不同规模下都有显著的加速。完整的基准测试代码和实现可以在以下网址找到:https://github.com/NonvolatileMemory/fast_llm_sampling/tree/main
小规模(批次大小=32,词汇大小=32000)
- 传统:0.600毫秒 ± 0.058毫秒
- Gumbel-Max:0.214毫秒 ± 0.004毫秒
- 提速2.8倍
中等规模(批次大小=128,词汇大小=50000)
- 传统:4.549毫秒 ± 2.609毫秒
- Gumbel-Max:1.294毫秒 ± 0.009毫秒
- 提速3.5倍
大规模(批次大小=512,词汇大小=100000)
- 传统:64.386毫秒 ± 2.748毫秒
- Gumbel-Max:30.544毫秒 ± 1.725毫秒
- 提速2.1倍
实现细节
高效实现的关键在于正确的噪声预计算
class GumbelSampler:
def __init__(self, batch_size, vocab_size, device):
self.batch_size = batch_size
self.vocab_size = vocab_size
# Pre-compute noise
self.noise = self._prepare_gumbel_noise(device)
def _prepare_gumbel_noise(self, device):
# Generate noise tensor once
uniform_noise = torch.rand(self.batch_size, self.vocab_size, device=device)
return -torch.log(-torch.log(uniform_noise))
def sample(self, logits):
# Direct sampling without softmax
return torch.argmax(logits + self.noise, dim=-1)