MLA:通过低秩投影和按需解压重塑 KV-Cache

社区文章 发布于 2025 年 2 月 4 日

引言

随着大型语言模型 (LLM) 的蓬勃发展,硬件资源仍然是一个令人生畏的“天花板”——尤其是在 GPU 内存 (VRAM) 有限的情况下。如何在受限资源下实现更长的上下文长度和更快的推理速度,长期以来一直是工程和研究界关注的重点。除了常见的量化和剪枝技术外,人们越来越重视**“在推理时减少 KV-Cache 占用”**。

本文首先回顾了 MHA (Multi-Head Attention)、MQA (Multi-Query Attention) 和 GQA (Grouped-Query Attention) 如何处理或减少 K/V 存储,然后重点介绍 DeepSeek 提出的 MLA (Multi-Head Latent Attention) 方法。与早期主要在“K/V 共享”层面工作的方法不同,MLA 采用了**低秩投影**和**按需解压**的组合,使我们能够绕过直接存储多头 K/V。MLA 使用**“潜在向量”**,并进一步利用矩阵合并技巧,使得在推理过程中,注意力机制只需极少的 VRAM 即可运行。

值得注意的是,MLA 在实际系统中部署时,通常需要适应 RoPE (Rotary Position Embedding)。为了清晰起见,我们将**首先解释 MLA 的核心思想(低秩投影)**,然后再讨论如何集成 RoPE。我们希望这种结构化的方法能深入了解 MLA 设计背后的推理和细微之处。

特别鸣谢:本文的部分灵感来源于苏剑林老师的博客文章缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA。我们向他的工作致敬。


1. 为何要减少 KV-Cache?

1.1 长上下文推理中的“隐形瓶颈”

在 Transformer 的自回归生成过程中,每个新生成的 token 都会引用所有先前 token 的历史 Key/Value ((K, V)) 向量。这些存储的 Key/Value 向量在推理过程中构成了 KV-Cache。如果序列长度为 (L),注意力头数为 (h),每个头的维度为 (d_k) 或 (d_v),则总 KV-Cache 大致按 L×h×(dk+dv), L \times h \times (d_k + d_v), LL 成线性增长。一旦上下文达到数千甚至数万个 token,KV-Cache 就会主导 VRAM 使用,甚至超过模型自身的激活存储。

1.2 VRAM 和带宽的限制

当序列非常长时,将所有 KV-Cache 放在单个 GPU 上变得不可行。将缓存拆分到多个 GPU 或机器上会导致带宽瓶颈——设备间通信通常比设备内内存访问慢得多。简而言之,如果我们能在更少的 GPU 上处理更长的上下文,就能最大限度地减少通信开销并提高吞吐量。这正是 MQA、GQA 和 MLA 出现并不断发展的原因。


2. MHA → MQA → GQA:多头注意力中 K/V 的简化

在介绍 MLA 之前,让我们简要概述一下传统的多头注意力 (MHA) 以及基于共享的 MQA 和 GQA 方法,它们旨在减少 K/V 存储。此上下文为将 MLA 与先前工作进行比较奠定了基础。

2.1 多头注意力 (MHA) 基础

2.1.1 经典注意力公式

在 Transformer 中,token 序列 x1,,xl\mathbf{x}_1,\dots,\mathbf{x}_l 被投影到多组 (Q,K,V)(Q, K, V) 中进行注意力计算。对于第 (s) 个头,假设隐藏维度为 ddqi(s)=xiWq(s),ki(s)=xiWk(s),vi(s)=xiWv(s). \mathbf{q}_i^{(s)} = \mathbf{x}_i \mathbf{W}_q^{(s)},\quad \mathbf{k}_i^{(s)} = \mathbf{x}_i \mathbf{W}_k^{(s)},\quad \mathbf{v}_i^{(s)} = \mathbf{x}_i \mathbf{W}_v^{(s)}. 在自回归解码中,第 (t) 步的注意力分数通常写为 αt,i(s)=qt(s)ki(s),for it. \alpha_{t,i}^{(s)} = \mathbf{q}_t^{(s)} \,\mathbf{k}_i^{(s)\top}, \quad\text{for } i \le t. 为了加速推理,我们将计算出的 ki(s)\mathbf{k}_i^{(s)}vi(s)\mathbf{v}_i^{(s)} 缓存到 VRAM 中以供后续 token 使用,这种存储被称为 KV-Cache

2.1.2 VRAM 压力

由于 MHA 通常为每个头保留不同的 K/V,如果 (h) 很大,您最终会存储 (h) 组 Key/Value,这会迅速耗尽 VRAM。因此,研究人员想知道:我们能否让多头注意力共享或压缩这些 K/V 表示?

2.2 MQA:极致的 K/V 共享

MQA (Multi-Query Attention) 专注于让每个头共享一对 K/V: ki=xiWk,vi=xiWv, \mathbf{k}_i = \mathbf{x}_i \mathbf{W}_k,\quad \mathbf{v}_i = \mathbf{x}_i \mathbf{W}_v, 而每个头仍保留自己的 (\mathbf{q}_i^{(s)})。这样,KV-Cache 就只有 1 组 K/V,而不是 (h) 组。VRAM 使用量可以减少 (1/h) 倍。PaLM 或 StarCoder 等实现都采用了 MQA。然而,由于所有头共享相同的 K/V,某些任务的性能可能会下降,除非采用额外的训练策略。

2.3 GQA:分组头

如果 MQA 过于激进,GQA (Grouped-Query Attention) 提供了一个折衷方案:将 (h) 个头分成 (g) 个组,每个组共享一组 K/V。因此,KV-Cache 缩小到 (g) 组(而不是 (h) 组)。示例包括 LLaMA2-70B 和 ChatGLM2/3。GQA 保留了比 MQA 更大的多样性,但仍比标准 MHA 节省 VRAM。

2.4 MHA / MQA / GQA 对比

方法 KV-Cache 存储 K/V 共享? VRAM 节省 头部多样性
MHA 存储 (h) 份 每个头独立 低(基线)
MQA 存储 1 份 K/V 完全共享 更低
GQA 存储 (g) 份 按组共享 中等 相当高

无论是 MQA 还是 GQA,它们都围绕着“K/V 在不同头之间共享多少”的问题。相比之下,MLA 重新思考了“我们在推理时实际存储了什么”:它将大部分 K/V 内容转移到**潜在向量**中,仅在需要时按需重建。让我们探讨 MLA 如何使用低秩投影和按需解压,暂时不考虑 RoPE。


3. MLA 核心:低秩投影和按需重建(不含 RoPE)

3.1 核心思想:从“存储多头 K/V”到“存储低维潜在向量”

MLA (Multi-Head Latent Attention) 中,我们仍然在训练时将每个输入投影到 Key 和 Value 中,但我们引入了一个低秩潜在向量 (\mathbf{c}_i)。在推理过程中,我们不再缓存高维多头 K/V,而是只存储这个紧凑的 (\mathbf{c}_i),然后在需要时**合并矩阵**以重建多头 K/V。具体来说,暂时忽略 RoPE,我们可以这样表示 MLA 的训练步骤:

ci=xiWc(a low-rank projection, WcRd×dc), \mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c \quad (\text{a low-rank projection, } \mathbf{W}_c \in \mathbb{R}^{d \times d_c}),

并为每个头(s) 定义投影矩阵(\mathbf{W}_{kc}^{(s)}, \mathbf{W}_v^{(s)}),使得 ki(s)=ciWkc(s),vi(s)=ciWv(s). \mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}. 因此,无论我们有多少个头,只有在训练时才明确生成多头 K/V。在**推理**时,我们只需缓存潜在向量 (\mathbf{c}_i),并使用矩阵组合即时重建 K/V。这就是 MLA 将 KV-Cache 成本与头数大大解耦的方式。

3.2 按需解压:VRAM 如何节省

在推理过程中,每当我们生成 token (t) 并需要计算先前 token (i < t) 的点积 (\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}) 时,传统方法会从 VRAM 中读取 (\mathbf{k}_i^{(s)})。然而,MLA 将它们合并:

ki(s)=ciWkc(s)qt(s)ki(s)=(xtWq(s))(ciWkc(s)). \mathbf{k}_i^{(s)} = \mathbf{c}_i\,\mathbf{W}_{kc}^{(s)}\quad\Longrightarrow\quad\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}=(\mathbf{x}_t \mathbf{W}_q^{(s)})(\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top.

由于矩阵乘法的性质,我们可以进一步进行一些合并,例如

(xtWq(s))(ciWkc(s))=xt(Wq(s)Wkc(s))(ci)=xtWmerged(s)(ci) (\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top = \mathbf{x}_t (\mathbf{W}^{(s)}_q \mathbf{W}^{(s)\top}_{kc}) (\mathbf{c}_i)^\top = \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)} (\mathbf{c}_i)^\top

因此,我们只在 VRAM 中保留 (如果 (

3.3 低秩投影如何大幅减少存储

有人可能会问,压缩率能有多大?假设 (d=4096),(h=32),以及单头维度 (d_k=128)。在标准 MHA 中,每个令牌的 Key 为 (32 \times 128 = 4096) 个元素(Value 也类似)。MLA 可以将潜在向量 (\mathbf{c}_i) 设置为 512 个元素,从而将 VRAM 使用量从 4096 减少到 512——提高 8 倍。在更极端的情况下,你可能会看到数十甚至数百倍的压缩因子。

当然,这是没有位置编码的理想情况。在实践中,Transformer 通常使用 RoPE(旋转位置嵌入),它修改了 Q 和 K 的投影方式。因此,我们将首先澄清 MLA 的基本低秩方法,然后探讨 RoPE 如何融入其中。接下来,我们将用“智能相册”类比来阐述 MLA 的工作流程,然后再回到 RoPE。


4. 将 MLA 理解为智能相册中的“低秩缩略图”

即使你理解了 MLA 的公式,可能仍然觉得它们有些抽象。让我们使用一个更直观的比喻:将每个“令牌”视为一张“照片”,“多头注意力”视为“应用于照片的滤镜”,而“KV 缓存”视为“相册存储”。这个类比展示了 MLA 如何实现压缩存储按需解压

4.1 照片存储:低秩缩略图

想象一下,每次你拍一张照片(处理一个令牌 (\mathbf{x}_i))时,你只保存一个“小而信息丰富的缩略图”——潜在向量 (\mathbf{c}_i),而不是保存“全分辨率图像加上所有滤镜”。例如,如果原始图像分辨率是 4096(如 MLA 中的 (d=4096)),缩略图大小可能是 512,实现了大约 1/8 的原始大小。
在数学上,ci=xiWc,WcR4096×512. \mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c, \quad \mathbf{W}_c \in \mathbb{R}^{4096 \times 512}. 这类似于“在捕获时对照片进行下采样”,大大减少了存储开销。

4.2 查看照片:实时解压

当你用某个滤镜“查看”照片时——对应于生成 Key/Value 的注意力头——MLA 会执行:ki(s)=ciWkc(s),vi(s)=ciWv(s). \mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}. 因此,无论有多少滤镜(头),你都只保留缩略图 ((\mathbf{c}_i)),而不是同一图像的多个版本。在推理时,每个滤镜的参数都可以从该缩略图重建 Key 或 Value,从而实现大规模存储缩减

4.3 按需重建:计算时的合并

在实际推理中,点积计算时会立即执行“合并”步骤

qt(s)ki(s)=(xtWq(s))(ciWkc(s))xtWmerged(s)ci. \mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = (\mathbf{x}_t \mathbf{W}_q^{(s)})(\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top \approx \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)} \mathbf{c}_i^\top.

因此,滤镜参数 (\mathbf{W}_{kc}^{(s)}) 不必为每个图片(头)存储在“相册”中。只保留紧凑的缩略图 ((\mathbf{c}_i))。这种方法省略了标准多头注意力所需的重复 Key/Value 开销。


5. RoPE 的挑战:为什么要添加“位置贴纸”?

既然我们已经介绍了 MLA 的低秩方法和按需重建,我们必须解决一个实际问题。Transformer 通常依赖 RoPE(旋转位置嵌入)将位置信息整合到 Key/Query 中。RoPE 使“仅潜在向量”的直接方法复杂化,因为每个令牌的位置 (i) 都引入了一个旋转矩阵 (\mathbf{\mathcal{R}}_i),它对点积的影响不同。

5.1 RoPE:时间戳和 GPS 坐标

回到我们的相册类比,RoPE 变成了每张照片上的“时间戳或 GPS 位置”。除了核心视觉内容 ((\mathbf{c}_i)),我们还保留了一个小贴纸,用于编码照片拍摄的时间和地点。如果我们尝试将时间数据直接嵌入到缩略图中,相对距离(时间差)可能会丢失。因此,在 MLA 中,Key/Query 维度的一部分仍明确乘以 (\mathbf{\mathcal{R}}_i),即使在我们的低秩方案下也能保持相对位置。

5.2 分割策略:(\mathbf{c}_i) + 小 RoPE 维度

正式地,为了保持旋转特性 (\mathbf{\mathcal{R}}_m \mathbf{\mathcal{R}}n^\top = \mathbf{\mathcal{R}}{m-n}),MLA 将每个 Key(以及类似的 Query)分成两部分:

ki(s)=(ciWkc(s))compressed portion    (xiWkrRi)positional portion, \mathbf{k}_i^{(s)}=\underbrace{\bigl(\mathbf{c}_i \mathbf{W}_{kc}^{(s)}\bigr)}_{\text{compressed portion}}\;\oplus\;\underbrace{\bigl(\mathbf{x}_i \mathbf{W}_{kr}\,\mathbf{\mathcal{R}}_i\bigr)}_{\text{positional portion}},

因此 KV-缓存只通过一个适度的“位置维度”增长,而主要存储仍然是潜在向量 ((\mathbf{c}_i))。这种设计巧妙地将低秩投影旋转嵌入融合在一起,用于相对位置。


6. MLA 的综合优势:存储创新、灵活检索和时空保真度

在将 MLA 分解为核心步骤后,我们可以看到三个主要优势:

  • 存储:传统多头注意力必须为每个头存储 K/V。MLA 则使用潜在向量 ((\mathbf{c}_i))(低秩)以及一个用于 RoPE(如果需要)的小维度。VRAM 可以缩小数倍或更多。
  • 计算:通过按需解压,Key/Value 仅在必要时才重建——这可以通过将它们合并到 Query 中进一步优化。对于非常长的序列,内存带宽是真正的瓶颈,因此减少重复的 K/V 可以显著加快推理速度。
  • 位置:当模型需要相对位置(RoPE)时,MLA 可以保留一个单独的“位置贴纸”维度。这可以在不强制整个 K 空间单独存储的情况下保留时间/空间信息。

7. 从工程角度:关键考虑因素

7.1 平衡 VRAM 和速度

如果你的应用程序涉及数千或数万个令牌的序列,MLA 的潜在压缩有助于大幅减少 KV-Cache,并允许单个 GPU 处理更多令牌或更大的批量大小,从而提高吞吐量。

7.2 调整 RoPE 维度

当将 K 分割为低秩区域 ((\mathbf{c}_i)) 和 RoPE 区域时,如果 RoPE 维度太小,极长上下文可能无法获得足够的定位信号。相反,如果它太大,MLA 的压缩优势就会减弱。最佳权衡通常通过经验实验得出。

7.3 数值稳定性和精度

由于 MLA 在推理时合并权重矩阵,使用 BF16/FP16 可能会因乘法顺序的改变而引入小的累积误差。通常,这是可以接受的。如果你的应用程序对精度极其敏感,请考虑使用更高精度的累加器或部分 float32 回退。


8. 总结和未来方向

MLA (多头潜在注意力) 不仅仅是“K/V 的低秩分解”。它通过在推理仅缓存潜在向量 ((\mathbf{c}_i)),并通过按需解压矩阵合并重建多头 K/V,从而大幅减少 KV-Cache 的使用。然后,通过对 RoPE 采用分割策略,MLA 在不强制整个 K/V 保持显式的情况下保留了相对位置信息。

从工程角度来看,MLA 在长上下文 LLM 推理中的 VRAM 效率是一个巨大的优势,可能会增加单个显卡或小型集群处理的令牌数量。然而,如何精确地在潜在向量和 RoPE 之间划分维度,取决于你的任务和模型规模。这个概念也可以扩展到其他位置编码(ALiBi、NTK Scaling 等)或专业领域。

无论如何,MLA 显然为减少 KV-Cache 提供了一条全新且强大的途径。我们可能会看到更多结合其他注意力优化的 MLA 变体,帮助大型模型在严格的硬件限制下实现更高的性能。


附录:关键公式及其“相册”类比

下面是 MLA 中的关键公式,以及它们与智能相册比喻的对应关系:

  1. 低秩投影

ci=xiWc \mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c

\quad\updownarrow\quad

“将缩略图而非完整照片存储,从而大幅缩小存储空间。” \text{“Store a thumbnail instead of a full photo,” drastically shrinking storage.}

  1. 动态解压

ki(s)=ciWkc(s),vi(s)=ciWv(s) \mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}

\quad\updownarrow\quad

“在运行时从单个缩略图生成过滤后的视图(Key/Value)。” \text{“Generate filtered views (Key/Value) from the single thumbnail at runtime.”}

  1. 按需重建(合并矩阵)

qt(s)ki(s)=(xtWq(s))(ciWkc(s)) \mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}=(\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top

\quad\updownarrow\quad

“将滤镜计算与查看步骤相结合,进一步减少存储。” \text{“Combine the filter math with the viewing step, further reducing storage.”}

  1. RoPE 分割

ki(s)=(ciWkc(s))    (xiWkrRi) \mathbf{k}_i^{(s)} = (\mathbf{c}_i \mathbf{W}_{kc}^{(s)}) \;\oplus\; (\mathbf{x}_i \mathbf{W}_{kr}\mathbf{\mathcal{R}}_i)

\quad\updownarrow\quad

“为时间戳或 GPS 保留一个单独的小标签,即相对位置。” \text{“Keep a separate small label for timestamps or GPS, i.e., relative position.”}

通读这些步骤,你可以看到 MLA 如何无缝地从低秩投影发展到按需 Key/Value 恢复,最终与 RoPE 共同用于位置编码。正如其名称所示,MLA 既保留了多头注意力的强大表达能力,又将主要的 K/V 存储负担委托给较小的潜在表示——从而在有限的 GPU 内存下实现更长的上下文。


9. MLA 的一个最小工作示例

为了更直观地感受 MLA 的核心概念,这里是一个基于 MLA 的 MoE Transformer 的最小可运行代码示例,展示了“潜在变量”、“按需重建”和“RoPE 集成”如何在实际代码结构中体现。在 `ModelArgs` 数据类中,字段 `attn_impl: Literal["naive", "absorb"] = "absorb"` 控制我们是使用 `naive` 风格存储经典 K、V 还是依赖 MLA 的潜在缓存 (`absorb`)。要了解完整功能,请参阅 DeepSeek 官方仓库

以下是整合后的示例

import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
import torch.nn.functional as F
from torch import nn

@dataclass
class ModelArgs:
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 2
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # rope
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # kv-cache
    attn_impl: Literal["naive", "absorb"] = "absorb"

@dataclass
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return F.linear(x, weight, bias)


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min_val, max_val, dim):
        if min_val == max_val:
            max_val += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class MLA(nn.Module):
    """
    Multi-Headed Attention Layer (MLA).
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = self.n_heads
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        self.attn_impl = args.attn_impl

        # Q
        if self.q_lora_rank == 0:
            self.wq = Linear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # K,V
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = Linear(
            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
        )

        # Output
        self.wo = Linear(self.n_heads * self.v_head_dim, self.dim)

        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale *= mscale * mscale

        # register different buffer based on the choice of "attn_impl"
        if self.attn_impl == "naive":
            self.register_buffer(
                "k_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim),
                persistent=False
            )
            self.register_buffer(
                "v_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim),
                persistent=False
            )
        else:
            self.register_buffer(
                "kv_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
                persistent=False
            )
            self.register_buffer(
                "pe_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim),
                persistent=False
            )


    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen

        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)

        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        if self.attn_impl == "naive":
            # naive mode
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))  # (bsz, seqlen, n_heads*(qk_nope_head_dim + v_head_dim))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)

            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)

            # write cache
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v

            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
            out = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])

        else:
            # the absorb mode proposed in MLA
            wkv_b = self.wkv_b.weight 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)

            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

            q_nope = torch.einsum(
                "bshd,hdc->bshc",
                q_nope,
                wkv_b[:, :self.qk_nope_head_dim] 
            )

            scores = (
                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
            ) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

            out = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            out = torch.einsum(
                "bshc,hdc->bshd",
                out,
                wkv_b[:, -self.v_head_dim:]
            )

        out = self.wo(out.flatten(2))
        return out


class MLP(nn.Module):

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.
    """

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        nn.init.xavier_normal_(self.weight)
        self.bias = nn.Parameter(torch.zeros(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight, self.bias)

        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()

        original_scores = scores
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices_groups = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices_groups, True)
            scores = (scores * mask.unsqueeze(-1)).flatten(1)

        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.gate = Gate(args)

        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) for _ in range(self.n_routed_experts)
        ])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)

        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.n_routed_experts):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]

        z = self.shared_experts(x)
        return (y + z).view(shape)


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = torch.nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = Linear(args.dim, args.vocab_size)

        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        h = self.norm(h)[:, -1]
        logits = self.head(h)
        return logits


if __name__ == "__main__":
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    logits = model(x)
    print(logits.shape)  # (batch_size, vocab_size)

社区

很棒的文章!如果能修复文本中数学符号的格式问题就更好了(我不知道是不是浏览器问题,但在我的 Firefox 浏览器上不起作用)

注册登录 发表评论