加速LLM推理:使用Gumbel-Max技巧进行快速采样

社区文章 发布于2024年10月24日

引言

大型语言模型(LLM)的推理速度受token采样过程的严重影响。在每个生成步骤中,我们需要从整个词汇表(通常为32K到100K个token)的概率分布中采样下一个token。使用torch.multinomial的标准方法已成为推理管道中的一个显著瓶颈。

传统LLM采样的问题

LLM推理中的传统采样过程如下:

  1. 从模型获取logits
  2. 应用softmax将logits转换为概率
  3. 使用torch.multinomial从概率分布中采样

这种方法有两个主要瓶颈:

  • 在大型词汇表上计算softmax代价高昂
  • 多项式采样操作本身相对较慢

核心洞察:Gumbel-Max采样

我们方法的核心创新来自于对Gumbel-Max技巧的两个关键观察:

  1. 使用Gumbel-Max采样在数学上等同于类别采样

    # Instead of:
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    
    # We can do:
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
    next_token = torch.argmax(logits + gumbel_noise, dim=-1)
    
  2. 关键优化:Gumbel噪声可以预先计算

    • 噪声张量独立于logits
    • 我们可以在收到模型输出之前准备好它
    • 这将其从token生成的关键路径中移除
    • 我们完全避免了softmax的计算

A100上的性能结果

我们在A100 80GB上的基准测试显示,在不同规模下都有显著的加速。完整的基准测试代码和实现可以在以下网址找到:https://github.com/NonvolatileMemory/fast_llm_sampling/tree/main

小规模(批次大小=32,词汇大小=32000)

  • 传统:0.600毫秒 ± 0.058毫秒
  • Gumbel-Max:0.214毫秒 ± 0.004毫秒
  • 提速2.8倍

中等规模(批次大小=128,词汇大小=50000)

  • 传统:4.549毫秒 ± 2.609毫秒
  • Gumbel-Max:1.294毫秒 ± 0.009毫秒
  • 提速3.5倍

大规模(批次大小=512,词汇大小=100000)

  • 传统:64.386毫秒 ± 2.748毫秒
  • Gumbel-Max:30.544毫秒 ± 1.725毫秒
  • 提速2.1倍

实现细节

高效实现的关键在于正确的噪声预计算

class GumbelSampler:
    def __init__(self, batch_size, vocab_size, device):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        # Pre-compute noise
        self.noise = self._prepare_gumbel_noise(device)
    
    def _prepare_gumbel_noise(self, device):
        # Generate noise tensor once
        uniform_noise = torch.rand(self.batch_size, self.vocab_size, device=device)
        return -torch.log(-torch.log(uniform_noise))
    
    def sample(self, logits):
        # Direct sampling without softmax
        return torch.argmax(logits + self.noise, dim=-1)

社区

注册登录 发表评论