理解 Gemma 3n:MatFormer 如何让你在一个模型中拥有多个模型
当我们谈论部署大型语言模型时,话题几乎总是落在一个熟悉的权衡上:你可以拥有一个更大、更智能的模型,或者一个更小、更快、能适应你硬件的模型。这似乎是常识,对吧?你在性能与资源的曲线上选择一个点,然后就固定下来。
但如果你不必如此呢?如果你可以训练一个大模型,并免费获得一整套小而高性能的模型呢?
这就是谷歌 Gemma 3n 背后的核心思想,它建立在一个名为 Matryoshka Transformer (套娃 Transformer),或称 MatFormer 的迷人架构之上。这是一项巧妙的工程设计,改变了我们对模型效率的看法。
让我们像上次在 https://huggingface.co/blog/rishiraj/kld-guided-quantization 中一样,一起来分解这个问题。我们将从核心架构思想开始,逐步了解它如何在推理时为我们提供如此大的灵活性。
套娃原则:一个模型,多种尺寸
你知道下图中的俄罗斯套娃吗?打开一个,会发现里面有一个更小的、一模一样的娃娃,再打开那个,里面还有一个。这正是 MatFormer 的完美心智模型。
在标准的 Transformer 模块中,前馈网络 (FFN) 有一个固定的中间层大小。例如,它可能接收一个 4096 维的输入,将其扩展到一个 16384 维的中间层 (W_in
),然后再将其投影回 4096 维 (W_out
)。这些维度是固定的。
MatFormer 改变了这一点。在每个 Transformer 层内部,它不只有一个 FFN,而是一系列 *嵌套* 的 FFN。这不仅仅是概念上的嵌套,而是字面上的。较小 FFN 的权重矩阵是较大 FFN 的子矩阵。
让我们具体点。如果最大的 FFN (我们称之为尺寸 S
) 的权重矩阵是 W_in
(4096x16384) 和 W_out
(16384x4096),那么下一个较小的 FFN (S/2
) 将只使用这些矩阵的左上部分——比如说,W_in
的前 8192 列和 W_out
的前 8192 行。S/4
FFN 将使用前 4096 列/行,依此类推。它们物理上嵌入在同一个参数块内。
那么,如何训练这样的东西而不会让较小的网络落后呢?
诀窍在于训练过程,这是一种随机深度或随机路径训练的形式。在每个训练步骤中,对于每一层,模型会随机选择一个“容量因子”——S
、S/2
、S/4
等。该层的输入随后只通过那个特定的子网络进行前向传播。有一次,一个输入可能在第 1 层通过 S/2
FFN,在第 2 层通过 S/8
FFN。下一次,它可能在两层都使用完整的 S
FFN。
通过让每个子模块都有平等的机会看到数据、计算梯度和更新其权重,训练确保了 *所有* 子模块都变得有能力。较小的网络不仅仅是弱的近似;它们是经过明确和稳健训练的。结果是,你不仅仅是在训练一个大模型。你同时在训练指数级数量的、更小的、有效的、并且都嵌套在同一组权重内的子模型。
回报:推理时“选择你的战士”
现在看看下面的架构,因为这是架构优雅在实践中得到回报的地方。因为每个子模型都是一个经过充分训练、可行的网络,所以在运行模型时你会获得令人难以置信的灵活性。
1. 简单的缩小: 假设你训练了一个大模型,但需要将其部署在只有四分之一内存的设备上。使用 MatFormer,你可以简单地决定在 *每一层* 都使用 S/4
大小的 FFN 子模块。你会立即得到一个大小约为原始模型 1/4 的模型。至关重要的是,由于这个配置是经过明确训练的,它的性能明显优于在那个较小尺寸下从头开始训练的独立模型。它得益于与更大、更强大的路径共同训练所带来的“知识转移”。
2. “混搭”杰作: 这才是真正有趣的地方。在 Transformer 中,并非所有层对每个任务的贡献都相同。早期层可能处理语法和局部模式,而深层则管理更抽象的语义推理。
使用 MatFormer,你可以在不同层之间“混搭”子模块,以创建定制的架构。你可以对模型进行性能分析,找到对你的任务最关键的层,并为它们分配更大的 FFN (如 S
或 S/2
),同时通过使用较小的 FFN (如 S/8
) 来节省不那么关键的层的容量。
例如,如果你确定第 5 层对于处理翻译任务中的语法细微差别至关重要,你可以为其分配完整的 S
FFN。但如果第 20 层的影响较小,你可以将其缩小到 S/8
,从而在对该特定任务的性能损失最小的情况下,节省大量的计算和内存。这使你能够构建一个定制的模型,以最佳方式平衡性能和资源使用。
内存魔法:50 亿参数如何装入 20 亿参数的内存占用空间
所以,我们有了 MatFormer 这种灵活的计算结构。但 Gemma 3n 还有另一个锦囊妙计,而且完全关乎内存。你可能已经看到 Gemma 3n 2B 模型 (E2B) 实际上有大约 50 亿个真实参数,但它占用的 GPU 内存却与典型的 2B 模型相当。这怎么可能?
答案是 逐层嵌入 (Per-Layer Embeddings, PLE)。
在标准的语言模型中,词元嵌入表是一个单一、庞大的内存块。它是一个大小为 词汇表大小 x 隐藏层维度
的巨型查找表,必须驻留在你的 GPU 显存 (VRAM) 中。让我们用数字来说明。对于一个拥有 256,000 个词元的词汇表和 2048 维隐藏层的模型,使用 bfloat16 (每个参数 2 字节),仅嵌入表就需要 256,000 * 2048 * 2 字节 ≈ 1.05 GB
。在你处理单个词元之前,这是一个巨大的、静态的成本。
PLE 巧妙地避开了这个问题,它将嵌入权重从高速但稀缺的 GPU VRAM 卸载到容量大得多但速度较慢的 CPU RAM。当模型需要处理一个输入序列时,它不会加载整个表。相反,它只通过 PCIe 总线将该序列中词元特定的嵌入向量从 CPU 拉到 GPU。
这是一个经典的工程权衡。你接受了从 CPU 到 GPU 数据传输带来的微小延迟,但作为回报,你释放了大量的 VRAM。这使得一个拥有更大 *真实* 参数数量的模型能够在受限的内存预算内运行。
这正是 Gemma 3n 家族的构建方式。4B 模型 (E4B,实际上是 5.44B 参数) 是完整的模型。2B 模型 (E2B) 是其内部的一个子网络,通过结合两样东西创建:
- MatFormer: 选择较小的 FFN 子模块以减少计算量和活动参数数量。
- 逐层嵌入 (Per-Layer Embeddings): 使用内存卸载来管理完整的 5B 参数集的内存占用。
最后一块拼图:利用 KV 缓存共享加速长上下文
对于涉及长序列的任务,如总结文档或处理长音频剪辑,键值 (KV) 缓存通常是主要瓶颈。在自回归生成中,模型会存储所有先前词元的计算出的键 (Key) 和值 (Value),这样就不必为每个新词元重新计算它们。
这个缓存的大小与序列长度成线性增长,并且可能变得非常巨大:序列长度 * 层数 * 注意力头数 * 注意力头维度 * 2
。对于非常长的上下文,这个缓存很容易超过可用的 VRAM。
Gemma 3n 使用 KV 缓存共享 来缓解这个问题,尤其是在多模态输入中。这项技术允许模型的不同部分或不同模态 (例如,音频和文本) 重用或共享此缓存的部分。通过避免冗余存储,它显著减少了内存压力并加速了“预填充”阶段——即对整个输入提示的初始、昂贵的处理。不过,我目前对这部分的技术理解还不够深入,希望以后能了解更多。
融会贯通
Gemma 3n 不仅仅是模型排行榜上的又一个点。它展示了智能、高效的架构设计。通过结合:
- MatFormer: 实现灵活、嵌套的计算结构,在一个模型中为你提供指数级数量的模型。
- 逐层嵌入 (Per-Layer Embeddings): 实现巧妙的内存管理,让更大的模型能适应更小的空间。
- KV 缓存共享: 加速长上下文、多模态任务。
...你得到了一个天生具有适应性的系统。它让我们摆脱了僵化的“一刀切”方法,赋予开发者权力,可以为他们的特定应用、硬件,甚至是特定输入选择正确的权衡。这是一个强有力的提醒:最激动人心的创新不总是关于规模的扩大,也关乎 *更聪明* 地扩展。