Bamba:推理高效的混合 Mamba2 模型 🐍

摘要
我们介绍 Bamba-9B,这是一个由 IBM、普林斯顿大学、卡内基梅隆大学和伊利诺伊大学厄巴纳-香槟分校在完全开放的数据上训练的推理高效混合 Mamba2 模型。在推理时,与 vLLM 中的标准 Transformer 相比,该模型的吞吐量提升了 2.5 倍,延迟降低了 2 倍。为促进社区实验,该模型可立即在 transformers
、vLLM
、TRL
和 llama.cpp
中使用。我们还发布了带有状态化数据加载器的微调、训练和扩展预训练方案,并邀请社区进一步改进该模型。让我们一起克服 KV 缓存瓶颈!
产出物 📦
动机 🌟
Transformer 模型在实际应用中越来越广泛,但在推理过程中面临内存带宽瓶颈,尤其是在长上下文长度模型中进行逐个 Token 解码时。低精度、层剪枝和压缩等技术可以缓解此问题,但并未解决根本原因,即随着上下文长度的增加,KV 缓存所需的内存量不断增长。新兴架构如 Mamba、Griffin 和 DeltaNet 通过使 KV 缓存大小恒定来消除这一瓶颈。Mamba 架构最近在社区中获得了极大的关注。例如,Jamba 和 Samba 将 Mamba 层与 Transformer 层交错,探索由此产生的混合 Mamba 模型。Codestral Mamba,一个纯 Mamba2 模型,在编码任务上展示了最先进(SOTA)的结果,而 NVIDIA 的混合 Mamba2 模型在长上下文和传统 LLM 基准测试中取得了有竞争力的性能。近期的创新,如 Falcon Mamba 和 Falcon 3 Mamba 在发布时在 Hugging Face 排行榜上取得了 SOTA 排名。
我们介绍了 Bamba-9B,这是一个在 2.2T Token 上训练的混合 Mamba2 模型,进一步验证了这些新兴架构。这项由 IBM、普林斯顿大学、卡内基梅隆大学和伊利诺伊大学厄巴纳-香槟分校合作的项目提供了完整的训练沿袭、模型检查点和预训练代码,以支持可复现性和实验。发布的检查点的训练数据集不包含任何基准对齐的指令数据(FLAN 除外),以保留扩展预训练和微调的灵活性。我们的目标是通过在中低规模模型(7B-10B)上展示强大的性能,来展示混合 Mamba2 架构的潜力,并为社区提供完全可复现且使用开放数据集训练的检查点。
为了促进社区实验,我们还发布了一个分布式无状态洗牌数据加载器,并在开源库如 transformers
、TRL
、vLLM
和 llama.cpp
中启用了混合 Mamba2 架构。我们希望这些努力能推动 Mamba 架构的采用,缓解 KV 缓存瓶颈,并缩小与 SOTA 开源模型的差距。
在 transformers 中的使用 🤗
要将 Bamba 与 transformers 一起使用,您可以使用熟悉的 AutoModel
类和 generate
API。更多详情,请遵循 Bamba GitHub 中概述的说明。
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")
message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
评估 📊
我们将评估分为三个部分
- 与当前最先进的 Transformer 模型的比较
- 与具有相似 Token 预算的 Transformer 模型的比较
- 与其他 Mamba 变体的比较。
评估设置 ⚙️ 🖥️: 我们遵循此处的设置和脚本重新运行了所有基准测试,但 NVIDIA Mamba2 混合模型除外。我们无法对 NVIDIA Mamba2 混合模型进行基准测试,因为其模型权重不兼容 Hugging Face transformers 格式。因此,我们报告了原始论文中的数据。对于 v2 排行榜结果,我们执行了归一化并报告了归一化结果。在所有评估中,除非另有说明,否则越高越好。
评估摘要
Bamba-9B 展示了混合 Mamba 模型相较于 Transformer 模型的竞争力。尽管在数学基准和 MMLU 分数(MMLU、GSM8K、MMLU-PRO、MATH Lvl 5)上存在差距,但排除这些基准后,其平均性能几乎与 Meta Llama 3.1 8B(Llama 为 44.68,Bamba 为 45.53)相当,而后者是在 7 倍多数据上训练的模型。这些差距可以通过以下方式解决:(a) 用更多 Token 进行扩展预训练(MMLU 分数在训练期间稳步提高),以及 (b) 在预训练/退火阶段加入高质量的数学数据。未来的计划包括使用更新的数据集,如 Olmo2 mix,并使用基准对齐的混合数据集(如 Dolmino mix)进行退火。
Bamba-9B 的结果也缓解了人们对 NVIDIA 混合 Mamba2 模型在排行榜基准测试中得分相对较低的担忧。NVIDIA 研究的目标是在相同条件下比较不同架构。与其发现一致,Bamba-9B 再次确认了混合 Mamba2 架构在提供与 Transformer 模型相当性能的同时,可提供高达 5 倍的推理效率。
与当前最先进的 Transformer 模型的比较
我们将 Bamba-9B 与类似规模的 SOTA Transformer 模型(Meta Llama 3.1 8B、IBM Granite v3 8B、Olmo2 7B 和 Gemma 2 9B)进行比较。我们观察到,虽然存在明显的基准差距,但尚不清楚这些差距是否指向基于 Mamba/Mamba2 模型的缺陷。实际上,仔细分析表明,差距主要是由于训练模型所用的数据量以及在退火阶段是否包含基准对齐的指令数据集。例如,我们进行了一次小规模实验,添加了 metamath
数据集,结果我们的 GSM8k
分数从 36.77
提高到了 60.0
。我们将在即将发布的论文中公布详细的分析和发现。
HF OpenLLM v1 排行榜
HF LLM-V1 + OpenbookQA 和 PIQA
模型 | 平均分 | MMLU | ARC-C | GSM8K | Hellaswag | OpenbookQA | Piqa | TruthfulQA | Winogrande |
---|---|---|---|---|---|---|---|---|---|
Bamba 9B | 62.31 | 60.77 | 63.23 | 36.77 | 81.8 | 47.6 | 82.26 | 49.21 | 76.87 |
Meta Llama 3.1 8B | 63.51 | 66.26 | 57.85 | 49.96 | 81.98 | 46.8 | 82.54 | 45.16 | 77.51 |
Olmo2 7B | 66.17 | 63.96 | 64.51 | 68.01 | 81.93 | 49.2 | 81.39 | 43.32 | 77.03 |
IBM Granite v3 8B | 67.47 | 65.45 | 63.74 | 62.55 | 83.29 | 47.6 | 83.41 | 52.89 | 80.82 |
Gemma2 9B | 68.38 | 72.29 | 68.26 | 67.4 | 82.56 | 47.8 | 83.24 | 45.39 | 80.11 |
Qwen2.5 7B | 70.58 | 75.41 | 63.82 | 83.24 | 80.23 | 48.40 | 81.28 | 56.34 | 75.93 |
HF LLM-V2** :
模型 | 平均分 | MMLU-PRO | BBH | GPQA | IFEval | MATH Lvl 5 | MuSR |
---|---|---|---|---|---|---|---|
Bamba 9B | 10.91 | 17.53 | 17.4 | 4.14 | 15.16 | 1.66 | 9.59 |
Meta Llama 3.1 8B | 14.27 | 25.46 | 25.16 | 8.61 | 12.55 | 5.14 | 8.72 |
Olmo2 7B | 13.36 | 22.79 | 21.69 | 4.92 | 16.35 | 4.38 | 10.02 |
IBM Granite v3 8B | 21.14 | 25.83 | 28.02 | 9.06 | 44.79 | 9.82 | 9.32 |
Gemma2 9B | 21.79 | 34.84 | 34.81 | 11.07 | 21.28 | 13.44 | 15.3 |
Qwen2.5 7B | 25.13 | 37.62 | 35.62 | 9.96 | 34.77 | 18.35 | 14.6 |
安全评估
安全基准对于确保 AI 模型生成的内容符合道德、包容且无害至关重要。我们在著名的安全基准上评估了我们的模型,例如 Toxigen(5-shot, logits)(专注于检测有毒语言)、BBQ(5-shot, generation)、PopQA(5-shot, generation)以及 CrowS-Pairs(5-shot, logits)(衡量偏见和公平性)。我们打算通过全面的 SFT 和 DPO 方法来解决这些安全方面的差距。
模型 | PopQA | Toxigen | BBQ | Crow-SPairs* |
---|---|---|---|---|
Bamba 9B | 20.5 | 57.4 | 44.2 | 70.8 |
Meta Llama 3.1 8B | 28.77 | 67.02 | 59.97 | 70.84 |
IBM Granite v3 8B | 27.5 | 79.9 | 82.1 | 75 |
Olmo2 7B | 25.7 | 63.1 | 58.4 | 72 |
Olmo1.5 7B | 20.4 | 56.7 | 53.3 | 72.2 |
Gemma2 9B | 27.3 | 69.6 | 59.9 | 71.7 |
Qwen2.5 7B | 18.2 | 64.1 | 78.1 | 70 |
*越低越好
与具有相似 Token 预算的 Transformer 模型的比较
我们挑选了几个著名的模型:在相同数据上训练的 Olmo 7B (2024),Meta Llama 2 7B (2023) 和 IBM Granite 7B (2023),这些模型的训练 Token 量约为 2T。在这些 Transformer 模型中,Olmo 7B 在 8 个关键基准上的平均得分最高。Bamba-9B 的性能优于在相同数量的 Token 和数据集上训练的 Olmo 7B。由于 Bamba-9B 模型有 9B 参数,直接比较仍然困难,但主要结论是,混合 Mamba2 模型与具有相似 Token 预算的 Transformer 模型相比具有竞争力。
模型 | 平均分 | MMLU | ARC-C | GSM8K | Hellaswag | OpenbookQA | Piqa | TruthfulQA | Winogrande |
---|---|---|---|---|---|---|---|---|---|
Bamba 9B (2.2T) | 62.31 | 60.77 | 63.23 | 36.77 | 81.8 | 47.6 | 82.26 | 49.21 | 76.87 |
Olmo1.5 7B (2T) | 55.8 | 53.38 | 50.51 | 27.67 | 79.13 | 45.2 | 81.56 | 35.92 | 73.09 |
Bamba 9B (2T) | 59.11 | 59.05 | 57.25 | 24.03 | 83.66 | 47.6 | 83.62 | 38.26 | 79.4 |
Meta Llama2 7B (2T) | 53.78 | 46.64 | 52.65 | 13.57 | 78.95 | 45.2 | 80.03 | 38.96 | 74.27 |
IBM Granite 7B (2T) | 52.07 | 49.02 | 49.91 | 10.84 | 77.0 | 40.8 | 80.14 | 38.7 | 70.17 |
Mamba/Mamba2 比较
与基于 Mamba/Mamba2 架构的语言模型的比较
在过去 6 个月里,多个基于 Mamba/Mamba2 架构的模型开始出现(例如,NVIDIA 混合 Mamba2、Codestral Mamba、Falcon Mamba 和 Zamba 7B v1),进一步提升了这些架构的性能,展示了它们优越的推理性能,并缩小了与 Transformer 模型在基准测试结果上的差距。我们比较了 Bamba-9B、NVIDIA 混合 Mamba2、Zamba 和 Falcon Mamba 在 8 个关键基准上的表现。
Falcon Mamba 是一个纯 Mamba 模型,Zamba 每 6 个 Mamba 层共享一个注意力层,而 Bamba-9B 和 NVIDIA 都是混合模型,其中穿插着完整的注意力层和 Mamba2 层。Falcon Mamba 经过 5.5T Token 的训练,整体表现最佳,但在长上下文任务上的表现仍有待观察,而基于 Mamba 的架构在这些任务的推理性能上真正大放异彩。Zamba 训练的 Token 数量较少(1T),但采用了不同的混合架构,并使用了基准对齐的指令数据集,包括那些由更强大的语言模型生成的数据集。Bamba-9B 和 NVIDIA 混合 Mamba2 非常相似(差异细节在模型架构部分总结),但 Bamba-9B 训练了 2.2T Token,而 NVIDIA 混合 Mamba 训练了 3.5T Token。
注意:在撰写此博客时,Falcon3 Mamba 7B 已经发布,其结果甚至优于 Falcon Mamba。我们计划借鉴 Falcon3 Mamba 的任何经验,并在我们下一个 Bamba 版本中进行改进。
模型 | 平均分 | MMLU | ARC-C | GSM8K | Hellaswag | OpenbookQA | Piqa | TruthfulQA | Winogrande |
---|---|---|---|---|---|---|---|---|---|
Bamba 9B | 62.31 | 60.77 | 63.23 | 36.77 | 81.8 | 47.6 | 82.26 | 49.21 | 76.87 |
NVIDIA Mamba2 混合 8B* | 58.78 | 53.6 | 47.7 | 77.69 | -- | 42.8 | 79.65 | 38.72 | 71.27 |
Zamba 7B | 64.36 | 57.85 | 55.38 | 61.33 | 82.27 | 46.8 | 82.21 | 49.69 | 79.32 |
Falcon Mamba 7B | 65.31 | 63.19 | 63.4 | 52.08 | 80.82 | 47.8 | 83.62 | 53.46 | 78.14 |
* 结果取自 NVIDIA 论文。
💡 注意: 训练数据集和训练过程中见过的 Token 数量的差异使得直接比较这些模型变得困难。从这个表中可以得出的关键结论是,混合 Mamba2 架构可以提供有竞争力的结果,同时训练效率几乎与 Transformer 模型一样高。此外,尽管穿插了完整的注意力层和 Mamba2 层,它们仍可以在推理效率上实现显著提升(理论上高达 5 倍)。我们正在继续使用最新的数据集对 Bamba-9B 模型进行预训练,并计划在模型改进时发布未来的检查点。
推理效率 ⚡🏎️
KV 缓存瓶颈是大型语言模型面临的主要挑战,这促使了量化、剪枝以及 Mamba2、线性 Transformer 和 RetNets 等新颖架构的解决方案。即使是标准 Transformer,要实现规模化的推理效率,也通常需要自定义内核。Bamba-9B 建立在社区内核可用性的势头之上,通过与 vLLM 模型服务框架的集成进一步改进。
我们在 vLLM 集成方面的进展通过 此 PR 进行跟踪,将 Bamba-9B 与 Meta Llama 3.1 8B 在 NVIDIA H100 80GB GPU 上进行基准测试。我们使用 1K Token 的输入大小和 2K 到 64K 的输出大小,在不同的批处理大小下,测量了吞吐量(Token/秒)和延迟。结果显示,随着批处理大小和序列长度的增加,Bamba-9B 的吞吐量和延迟比 Transformer 模型提高了 2-2.5 倍。这些收益增强了实时应用和 GPU 利用率,更高的吞吐量比率(>1)和更低的延迟比率(<1)是有益的。
我们的分析表明,在 H100 NVIDIA GPU 上,当推理转向内存瓶颈时(这通常发生在生产环境中),我们预计会有 5 倍的加速——请参阅附录中的计算强度部分。然而,由于以下三个主要原因,我们尚未在 vLLM 中实现这种加速
- 分块预填充(Chunked pre-fill)尚不支持 Bamba 和任何基于 Mamba2 的架构
- 内存分配假设为标准 Transformer KV 缓存
- Mamba2 内核未针对 H100 GPU 进行优化
这些问题正在这里进行跟踪。
模型架构
我们的模型架构基于 NVIDIA 混合 Mamba2,但有以下改动。
参数 | Bamba 9B | NVIDIA 混合 Mamba2 8B |
---|---|---|
层数 | 32 | 29 |
注意力层数 | 3 | 4 |
Mamba2 层数 | 29 | 25 |
MLP 扩展因子 | 3.5 | 4 |
词汇表大小 | 128k | 256k |
非嵌入参数 | 8.8B | 8.6B |
RoPE | 是的 | 否 |
门控线性单元 | 是的 | 否 |
我们总共有 8B 参数在 Mamba2 层中,800M 在全注意力层中,1B 在嵌入层中。隐藏状态大小为 4K,全注意力的 GQA 有 8 个 KV 头和 32 个头,Mamba2 层的头维度为 64,卷积滤波器大小为 4。两个模型之间最显著的变化是将全注意力层从 NVIDIA 混合 Mamba2 模型中的 4 层减少到 Bamba-9B 中的 3 层,并引入了 RoPE 嵌入。
数据
自 The Pile 数据集问世以来,开源数据已经取得了长足的进步。当我们开始训练这个模型时,最好的开源数据是 Dolma v1.7,通过 Olmo 模型和 Hugging Face 数据团队的消融实验证明其性能非常出色。此后,又发布了几个更高质量的开源数据集,例如 DCLM、FineWeb-2 和 Olmo2 mix。
我们在第一阶段训练中使用 Dolma v1.7,选择的数据混合如下所示。在第二阶段训练中,我们使用了 Fineweb-edu 和 Cosmopedia。这些数据集以其原始形式下载,我们使用在内部大规模 Red Hat Open Shift 集群上运行的 Ray 框架对它们进行分词。我们计划尽快发布分词和格式化的 parquet 数据,以实现可复现性。
预训练第一阶段的数据混合
预训练
Bamba 的预训练分阶段进行,我们进行了几次 1.8B 模型大小和 100B Token 的消融实验以确定正确的学习率。基于这项研究的有希望的结果,我们使用 Dolma mix 训练了一个更大规模的模型——3B 到 2T Token。我们还使用相同的数据混合训练了一个遵循 Meta Llama 架构的 3B Transformer 模型,并观察到 Bamba 模型的性能相似或更好,这与 NVIDIA 同时进行的研究得出的结论一致。最后,我们设计了一个 9B 模型架构,并使用相同的混合数据重新训练。PyTorch FSDP 用于训练我们所有的模型。
训练细节:我们使用了余弦学习率调度,峰值学习率为
3e−4
,在 2000 步内进行二次预热,衰减因子为 0.033,在 2T Token 上的结束学习率为1e−5
。我们使用了 AdamW 优化器,β1
为 0.9,β2
为 0.95。我们使用了 0.1 的权重衰减,4096 的序列长度,以及 1.5M Token/批次的全局批处理大小。我们使用了来自 IBM Cloud Vela 生产集群的 192 个 A100 GPU,在 2 个月的时间内训练了这个模型。该集群由 Red Hat OpenShift 管理。我们经历了 3 次作业中断,原因是作业部署不正确和硬件故障。硬件相关的作业故障是使用 autopilot 自动检测的。我们还使用来自 Hugging Face 的高质量数据 FineWeb-edu 和 Cosmopedia 进行了第二阶段的训练,额外训练了 200B Token。我们使用了 2e-5 的学习率和一个余弦调度来退火模型,这有助于提高我们的分数。我们目前正在试验额外的高质量数据,并将作为我们对开源承诺的一部分发布任何未来的检查点。
数据加载器
训练高质量语言模型有几个方面,数据加载器是其中重要的一环。在过去的 18 个月里,我们一直致力于开发一个能满足大规模分布式训练需求的数据加载器。我们开源了这个数据加载器,以便其他人可以将其与他们选择的框架结合使用。我们在 Bamba 模型训练中使用了它,并将其与 Torch Titan 集成。到目前为止,我们相信这是唯一一个提供如此丰富功能的开源数据加载器。
该数据加载器提供以下关键功能
- 有状态且可检查点,以确保在周期中无缝恢复
- 自动扩展以适应变化的工作负载和 GPU 分配
- 数据流式传输,零开销进行数据混洗
- 异步分布式操作,无点对点通信
- 允许动态数据混合和即时分词
- PyTorch 原生、模块化和可扩展
我们已经在数百个训练作业中对这个数据加载器进行了实战测试,并在数月的持续运行中对其进行了优化。主要代码库位于我们的仓库 这里,我们还与 Torch Titan 团队合作,使其在这里可用。我们正在与 Meta PyTorch 团队合作,将这个数据加载器贡献到 PyTorch 核心中。
量化
我们最近开源了一个用于模型量化的框架。通过这个框架,我们利用 llm-compressor 将 Bamba 检查点量化为 fp8
。我们观察到,在 OpenLLM 排行榜的所有基准测试中,准确度损失极小。具体来说,对于 Bamba 9B,V1 的平均得分差异可以忽略不计,为 0.1
(从 62.31
降至 61.5
),而 V2 的平均得分下降了 0.9
(从 10.91
降至 10.04
)。这些量化后的检查点也与 bf16
对应版本一起发布。这也验证了 Bamba 模型与 SOTA Transformer 模型一样,同样适用于量化。
我们正在 vLLM 中为该模型启用 fp8
推理,这将需要更新内核。线性层和全注意力层将很容易处理,但 Mamba2 层将需要更新 Triton/CUDA 内核以处理 fp8
。
上下文长度扩展
我们目前正在探索各种长上下文长度扩展的方法,首先是应用 LongRope 到全注意力层。我们使用 PhoneBook 检索作为任务的初步发现表明,LongRoPE 可以应用于该模型。我们将 Bamba-9B 的上下文长度扩展了 4 倍和 8 倍,并将上下文扩展后的 Bamba-9B 与 Meta Llama 的三个变体——LLama2、Llama3、LLama3.1 进行比较,它们的训练上下文长度分别为 4K、8K 和 128K。结果绘制如下。
我们观察到,上下文扩展后的 Bamba-9B 模型在未经任何调整的情况下,在高达 16K 的上下文长度下表现非常出色,大幅超越了原始的 Bamba-9B 模型、Llama2-7B 和 Llama3-8B,并获得了与 Llama3.1-8B 相当的性能。在序列长度为 32K 时,LLama3.1 取得了最佳性能结果。我们计划在准备就绪后发布长上下文长度扩展模型。
总结 🎯
Bamba-9B 是由 IBM、普林斯顿大学、卡内基梅隆大学和伊利诺伊大学厄巴纳-香槟分校合作开发的一款性能强大的混合 Mamba2 模型。该模型完全在开放数据集上训练,我们正在发布中间和最终检查点。为了促进社区实验,该模型可立即在 transformers
、vLLM
、TRL
和 llama.cpp
中使用。我们还发布了带有状态化数据加载器的微调、训练和扩展预训练方案,并邀请社区进一步改进该模型。
关键要点
推理效率:Bamba-9B 在吞吐量和延迟方面实现了显著提升,增强了实时应用性能。使用 vLLM 对比 Llama 3.1 8B 的基准测试显示,吞吐量提升了 2.5 倍,延迟降低了 2 倍,并且未来还会有更多改进!
有竞争力的基准:Bamba-9B 的性能与 Meta Llama 3.1 8B 等最先进的 (SoTA) Transformer 模型相比具有竞争力。在排除数学和 MMLU 任务后,它的平均基准性能与它们相当,并且有机会通过扩展训练和专注于数学的数据集来缩小这些差距。
开放合作:模型的开发利用了开放数据,促进了人工智能社区内的透明度和可复现性。
有关更多详细信息以及访问模型和相关资源,请访问 Bamba GitHub 仓库。
未来工作
我们打算探索几个方向,并进一步研究推理高效的 mamba2 混合架构
- 通过在额外数据上持续预训练来不断改进模型;我们欢迎社区的任何反馈,以便我们能够共同创建一个出色的 Mamba2 混合模型。
- 使用 SFT 数据集(如 Tuluv3、agent instruct 和 Anteater)对基础模型进行 SFT,并将结果模型与其他最先进的指令微调模型进行比较。
- 与社区合作,在 vLLM 中启用该模型。分块预填充和管理该架构的内存分配问题将是关键。
- 启用
fp8
内核以使推理更快。 - 训练时间改进和应用
torch.compile
以及fp8
训练,我们团队已在与 Meta 合作的 Transformer 架构上展示了这两项技术。 - 长上下文长度扩展至 1M+
贡献者
- 数据收集和整理:我们感谢并感谢 AllenAI 团队提供了高质量的开源数据集 Dolma,以及 Hugging Face 数据团队提供了 FineWeb-edu 和 Cosmopedia。这些都是巨大的贡献,使我们能够创建这个模型。
- 数据预处理:我们感谢 IBM 内部的数据预处理团队,特别是 Tuan Hoang Trong、Syed Zawad、Jay Gala 和 Ryan Gordon,他们帮助我们大规模地对数据进行分词。分词代码可在此处获取:这里。
- 模型架构:模型架构设计由普林斯顿大学、卡内基梅隆大学、IBM 和伊利诺伊大学厄巴纳-香槟分校共同完成,参与人员包括:Tri Dao (普林斯顿大学)、Albert Gu (卡内基梅隆大学)、Linsong Chu (IBM)、Davis Wertheimer (IBM)、Minjia Zhang (伊利诺伊大学厄巴纳-香槟分校)、Mudhakar Srivatsa (IBM) 和 Raghu Ganti (IBM)。
- 模型训练:模型训练主要由 IBM 团队使用 Tri Dao 和 Albert Gu 的 Mamba2 内核和层实现来完成。IBM 的以下人员主要参与其中:Linsong Chu、Divya Kumari、Davis Wertheimer、Raghu Ganti 和 Dakshi Agrawal。
- 模型微调:模型的微调由 IBM 团队在 TRL 中启用和验证,参与人员包括 Sukriti Sharma 和 Anh Uong。
- 模型推理:在
transformers
、vLLM
和llama.cpp
中的模型推理建立在普林斯顿大学和卡内基梅隆大学编写的内核之上。IBM 团队正在与社区合作,以便在各种生态系统中启用它。该团队包括 Fabian Lim、Antoni viros i Martin、Adnan Hoque、Jamie Yang、Nelson Nimura Gonzalez、Joshua Rosenkranz、Nick Hill 和 Gabe Goodhart。 - 量化:量化由 IBM 团队领导 - Naigang Wang 和 Charlie Liu。
- 评估:评估由 IBM 的一个团队领导,长上下文评估由伊利诺伊大学厄巴纳-香槟分校执行,参与人员包括:Yotam Perlitz、Ofir Arviv、Michal Shmueli-Scheuer (IBM)、Haoechen Shen 和 Minjia Zhang (伊利诺伊大学厄巴纳-香槟分校)。
最后,我们要感谢我们的领导层对这项工作的支持——Priya Nagpurkar、David Cox、Sriram Raghavan、Aya Soffer、Ruchir Puri 和 Mukesh Khare。
我们还要感谢社区,特别是来自 Hugging Face 的 Pablo Montalvo-Leroux、Aritra Roy Gosthipaty 和 Vaibhav Srivastav,以及来自 Contextual AI 的 Stas Bekman,他们为这篇博客和向 transformers 提交的 PR 提供了宝贵的反馈。此外,我们还要感谢来自 Neural Magic 的 Tyler Michael Smith,他正在指导与 vLLM 的集成。
特别感谢 Meta PyTorch、AllenAI 和 Hugging Face 团队对开放计划的贡献,PyTorch FSDP 让我们能够顺利地训练这个模型,而来自 Dolma 和 Fineweb/Cosmopedia 的数据使这个模型得以诞生!
附录:计算强度
使用以下符号
$b$:批处理大小
$s$:序列长度
$h$:隐藏状态大小 (4096)
$d$:头维度 (128)
$l$:总层数 (32)
$l_{attn}$:注意力层数 (3)
$l_{ssd}$:SSD 层数 (29)
注意力模型和 Bamba 模型都配置了 4:1 的 GQA(在注意力层中),MLP 扩展比为 3.5,并在 MLP 块中使用 GLU。Bamba 中的 SSD 层配置的状态维度为 $d$,头维度为 $d/2$,头数为 $4h/d$。不包括嵌入层的模型大小为
模型类型 | 模型大小 |
---|---|
注意力 | $13h^2l$ |
Bamba | $15.5h^2l$ |
在预填充阶段,模型施加的计算和内存(读+写)要求是
模型类型 | 计算预填充 | 内存预填充 |
---|---|---|
注意力 | $26bsh^2l + 4bs^2hl$ | $13h^2l + 0.5bshl$ |
Bamba | $31bsh^2l + 4bs^2hl_{attn} + 4bsdhl_{ssd}$ | $15.5h^2l + 0.5bshl_{attn} + 4bdhl_{ssd}$ |
在解码阶段,模型施加的计算和内存(读+写)要求是
模型类型 | 计算解码 | 内存解码 |
---|---|---|
注意力 | $26bh^2l + 4bshl$ | $13h^2l + 0.5bshl$ |
Bamba | $31bh^2l + 4bshl_{attn} + 4bdhl_{ssd}$ | $15.5h^2l + 0.5bshl_{attn} + 4bdhl_{ssd}$ |
下文显示了 Bamba 和 LLaMa 模型在预填充阶段的计算浮点运算和解码阶段的内存(读+写)大小的比较。请注意,小于 1 的比率是有益的。由于推理吞吐量主要受解码阶段的瓶颈限制,对于长序列(> 16K),Bamba(相对于 LLaMa)的潜在加速可达 5 倍。目前的测量结果(在 vLLM 上)徘徊在 2.5 倍左右,我们预计在不久的将来会有所改善。