将 LLM 微调至 1.58 位:轻松实现极限量化

发布于 2024 年 9 月 18 日
在 GitHub 上更新

随着大型语言模型 (LLM) 规模和复杂性的增长,寻找降低其计算和能源成本的方法已成为一项关键挑战。一种流行的解决方案是量化,即将参数的精度从标准的 16 位浮点 (FP16) 或 32 位浮点 (FP32) 降低到 8 位或 4 位等低位格式。虽然这种方法显著减少了内存使用并加快了计算速度,但通常会以牺牲精度为代价。过度降低精度可能导致模型丢失关键信息,从而导致性能下降。

BitNet 是一种特殊的 Transformer 架构,它仅用三个值来表示每个参数:(-1, 0, 1),从而实现了每参数 1.58 (log2(3) log_2(3) ) 位的极限量化。然而,它需要从头开始训练模型。虽然结果令人印象深刻,但并非所有人都拥有预训练 LLM 的预算。为了克服这一限制,我们探索了一些技巧,允许将现有模型微调到 1.58 位!继续阅读以了解如何实现!

目录

TL;DR

BitNet 是微软研究院推出的一种架构,它采用极限量化,每个参数仅用三个值表示:-1、0 和 1。这使得模型每个参数仅使用 1.58 位,显著降低了计算和内存需求。

与 LLaMA LLM 的 FP16 加法和乘法运算相比,该架构在执行矩阵乘法时使用 INT8 加法计算。

The new computation paradigm of BitNet b1.58
BitNet b1.58 的新计算范式(来源:BitNet 论文 https://arxiv.org/abs/2402.17764)

这导致理论上能耗降低,BitNet b1.58 在矩阵乘法方面比 Llama 基线节省了 71.4 倍的算术运算能耗。

Energy consumption of BitNet b1.58 compared to LLaMA
BitNet b1.58 与 Llama 的能耗比较(来源:BitNet 论文 https://arxiv.org/abs/2402.17764)

我们已成功使用 BitNet 架构微调了 Llama3 8B 模型,并在下游任务中取得了出色的表现。我们开发的 8B 模型在 HF1BitLLM 组织下发布。其中两个模型在 100 亿个 Token 上进行了不同训练设置的微调,而第三个模型在 1000 亿个 Token 上进行了微调。值得注意的是,我们的模型在 MMLU 基准测试中超越了 Llama 1 7B 模型。

如何与 Transformers 配合使用

为了将 BitNet 架构集成到 Transformers 中,我们引入了一种新的量化方法,名为“bitnet”(PR)。这种方法涉及将标准线性层替换为与 BitNet 架构兼容的专用 BitLinear 层,并进行适当的动态激活量化、权重解包和矩阵乘法。

在 Transformers 中加载和测试模型非常简单,API 没有任何变化

model = AutoModelForCausalLM.from_pretrained(
    "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)    
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"

input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=10)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

通过这段代码,一切都在后台无缝管理,因此无需担心额外的复杂性,您只需安装最新版本的 transformers 即可。

要快速测试模型,请查看此 Notebook

BitNet 深入解读

BitNet 用名为 BitLinear 的特殊层取代了多头注意力和前馈网络中的传统线性层,这些特殊层使用三元精度(或二进制,在初始版本中)。我们在这个项目中使用的 BitLinear 层使用三元精度(值分别为 -1、0 和 1)量化权重,并将激活量化为 8 位精度。我们用于训练的 BitLinear 实现与用于推理的不同,这将在下一节中看到。

三元精度训练的主要障碍是权重值是离散的(通过 round() 函数),因此不可微分。BitLinear 通过一个巧妙的技巧解决了这个问题:STE (Straight Through Estimator)。STE 允许梯度通过不可微分的舍入操作,通过将其梯度近似为 1(将 round() 视为等同于恒等函数)。另一种看法是,STE 不会在舍入步骤停止梯度,而是让梯度通过,就像舍入从未发生过一样,从而可以使用标准基于梯度的优化技术更新权重。

The architecture of BitNet with BitLinear layers
带 BitLinear 层的 BitNet 架构(来源:BitNet 论文 https://arxiv.org/pdf/2310.11453)

训练

我们以全精度进行训练,但在此过程中使用对称的每张量量化将权重量化为三元值。首先,我们计算权重矩阵的绝对值平均值并将其用作比例。然后,我们将权重除以比例,对值进行舍入,将其限制在 -1 和 1 之间,最后将其重新缩放以继续以全精度进行计算。

scalew=11nmijWij scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|}

Wq=clamp[1,1](round(Wscale)) W_q = \text{clamp}_{[-1,1]}(\text{round}(W*scale))

Wdequantized=Wqscalew W_{dequantized} = W_q*scale_w

然后,激活量化为指定的位宽(在本例中为 8 位),使用 absmax 每 token 量化(有关量化方法的全面介绍,请查看此帖子)。这涉及将激活缩放到 8 位位宽的范围 [−128, 127]。量化公式为:

scalex=127Xmax,dim=1 scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}}

Xq=clamp[128,127](round(Xscale)) X_q = \text{clamp}_{[-128,127]}(\text{round}(X*scale))

Xdequantized=Xqscalex X_{dequantized} = X_q * scale_x

为了使公式更清晰,这里有使用 3x3 矩阵进行权重和激活量化的示例


示例 1:权重矩阵量化

设权重矩阵 ( W ) 为:

W=[0.80.51.21.50.40.91.30.70.2] W = \begin{bmatrix} 0.8 & -0.5 & 1.2 \\ -1.5 & 0.4 & -0.9 \\ 1.3 & -0.7 & 0.2 \end{bmatrix}

步骤 1:计算权重比例

使用公式

scalew=11nmijWij scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|}

我们计算 ( W ) 的平均绝对值:

1nmijWij=19(0.8+0.5+1.2+1.5+0.4+0.9+1.3+0.7+0.2)=19(7.5)=0.8333 \frac{1}{nm} \sum_{ij} |W_{ij}| = \frac{1}{9}(0.8 + 0.5 + 1.2 + 1.5 + 0.4 + 0.9 + 1.3 + 0.7 + 0.2) = \frac{1}{9}(7.5) = 0.8333

现在,比例因子是

scalew=10.83331.2 scale_w = \frac{1}{0.8333} \approx 1.2

步骤 2:量化权重矩阵

使用公式

Wq=clamp[1,1](round(W×scalew)) W_q = \text{clamp}_{[-1, 1]}(\text{round}(W \times scale_w))

我们首先将权重按 scalew1.2 scale_w \approx 1.2 缩放。

W×scalew=[0.8×1.20.5×1.21.2×1.21.5×1.20.4×1.20.9×1.21.3×1.20.7×1.20.2×1.2]=[0.960.61.441.80.481.081.560.840.24] W \times scale_w = \begin{bmatrix} 0.8 \times 1.2 & -0.5 \times 1.2 & 1.2 \times 1.2 \\ -1.5 \times 1.2 & 0.4 \times 1.2 & -0.9 \times 1.2 \\ 1.3 \times 1.2 & -0.7 \times 1.2 & 0.2 \times 1.2 \end{bmatrix} = \begin{bmatrix} 0.96 & -0.6 & 1.44 \\ -1.8 & 0.48 & -1.08 \\ 1.56 & -0.84 & 0.24 \end{bmatrix}

接下来,我们对值进行舍入并将其限制在范围 [1,1] [-1, 1] 内。

Wq=[111101110] W_q = \begin{bmatrix} 1 & -1 & 1 \\ -1 & 0 & -1 \\ 1 & -1 & 0 \end{bmatrix}

步骤 3:反量化权重

最后,我们使用以下公式反量化权重:

Wdequantized=Wq×scalew W_{dequantized} = W_q \times scale_w

代入 scale_w,我们得到:

Wdequantized=[1×1.21×1.21×1.21×1.20×1.21×1.21×1.21×1.20×1.2]=[1.21.21.21.201.21.21.20] W_{dequantized} = \begin{bmatrix} 1 \times 1.2 & -1 \times 1.2 & 1 \times 1.2 \\ -1 \times 1.2 & 0 \times 1.2 & -1 \times 1.2 \\ 1 \times 1.2 & -1 \times 1.2 & 0 \times 1.2 \end{bmatrix} = \begin{bmatrix} 1.2 & -1.2 & 1.2 \\ -1.2 & 0 & -1.2 \\ 1.2 & -1.2 & 0 \end{bmatrix}

示例 2:激活矩阵量化

设激活矩阵 ( X ) 为:

X=[1.00.60.70.90.41.20.80.50.3] X = \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix}

步骤1:计算激活的尺度

对于每一行(或通道),计算最大绝对值

  • 第1行:最大绝对值 = 1.0
  • 第2行:最大绝对值 = 1.2
  • 第3行:最大绝对值 = 0.8

计算每一行的尺度因子

scale=[1271.01271.21270.8]=[127105.83158.75] \text{scale} = \begin{bmatrix} \frac{127}{1.0} \\ \frac{127}{1.2} \\ \frac{127}{0.8} \end{bmatrix} = \begin{bmatrix} 127 \\ 105.83 \\ 158.75 \end{bmatrix}

步骤2:量化激活矩阵

使用公式

Xq=clamp[128,127](round(X×scale)) X_q = \text{clamp}_{[-128,127]}(\text{round}(X \times \text{scale}))

缩放激活

X×scale=[1.0×1270.6×1270.7×1270.9×105.830.4×105.831.2×105.830.8×158.750.5×158.750.3×158.75]=[12776.288.995.242.312712779.447.6] X \times \text{scale} = \begin{bmatrix} 1.0 \times 127 & -0.6 \times 127 & 0.7 \times 127 \\ -0.9 \times 105.83 & 0.4 \times 105.83 & -1.2 \times 105.83 \\ 0.8 \times 158.75 & -0.5 \times 158.75 & 0.3 \times 158.75 \end{bmatrix} = \begin{bmatrix} 127 & -76.2 & 88.9 \\ -95.2 & 42.3 & -127 \\ 127 & -79.4 & 47.6 \end{bmatrix}

将值四舍五入并将其限制在 [128,127][-128, 127] 范围内

Xq=[127768995421271277948] X_q = \begin{bmatrix} 127 & -76 & 89 \\ -95 & 42 & -127 \\ 127 & -79 & 48 \end{bmatrix}

步骤3:反量化激活

最后,使用以下公式反量化激活:

Xdequantized=Xq×1scale X_{dequantized} = X_q \times \frac{1}{\text{scale}}

代入尺度

Xdequantized=[127×112776×112789×112795×1105.8342×1105.83127×1105.83127×1158.7579×1158.7548×1158.75]=[1.00.60.70.90.41.20.80.50.3] X_{dequantized} = \begin{bmatrix} 127 \times \frac{1}{127} & -76 \times \frac{1}{127} & 89 \times \frac{1}{127} \\ -95 \times \frac{1}{105.83} & 42 \times \frac{1}{105.83} & -127 \times \frac{1}{105.83} \\ 127 \times \frac{1}{158.75} & -79 \times \frac{1}{158.75} & 48 \times \frac{1}{158.75} \end{bmatrix} = \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix}


我们在量化激活之前应用层归一化(LN)以保持输出的方差。

LN(x)=xE(x)Var(x)+ϵ \text{LN}(x) = \frac{x - E(x)}{\sqrt{\text{Var}(x) + \epsilon}}

其中 $\epsilon$ 是一个很小的数字,以防止溢出。

如前所述,`round()` 函数是不可微分的。我们使用 `detach()` 作为一种技巧,在反向传播中实现可微分的直通估计器。

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn 
import torch.nn.functional as F

def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
 
def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        x_norm = LN(x)
        
        # A trick for implementing Straight−Through−Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        
        # Perform quantized linear transformation
        y = F.linear(x_quant, w_quant)
        return y

推理

在推理过程中,我们只需将权重量化为三元值,无需重新缩放。我们对激活应用相同的8位精度方法,然后使用高效内核执行矩阵乘法,再除以权重和激活尺度。这应该能显著提高推理速度,尤其是在优化硬件上。你可以看到,训练期间的重新缩放过程有所不同,因为矩阵乘法保持在fp16/bf16/fp32以进行正确训练。

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn 
import torch.nn.functional as F

def activation_quant_inference(x):
    x = LN(x)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale
 
class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight # weights here are already quantized to (-1, 0, 1)    
        w_scale = self.w_scale  
        x_quant, x_scale = activation_quant_inference(x)
        y = efficient_kernel(x_quant, w) / w_scale / x_scale
        return y

1.58b 预训练结果

在尝试微调之前,我们首先尝试使用预训练重现 BitNet 论文的结果。我们从一个小型数据集 tinystories 和一个 Llama3 8B 模型开始。我们确认,像论文那样添加归一化函数可以提高性能。例如,经过2000步训练后,验证集上的困惑度在没有归一化的情况下为6.3,在有归一化的情况下为5.9。两种情况下的训练都是稳定的。

Pre-training plots without (blue) & with (green) layer normalisation
预训练图(不含层归一化为蓝色,含层归一化为橙色)

虽然这种预训练方法看起来非常有趣,但只有少数机构能够以必要的规模进行。然而,目前已经有各种强大的预训练模型,如果它们能在预训练后转换为1.58位,那将非常有益。其他小组报告称,微调结果不如预训练实现的结果强,因此我们着手研究如何使1.58微调起作用。

1.58位微调

当我们开始从预训练的Llama3 8B权重进行微调时,模型表现略好,但不如我们预期。

注意:我们所有的实验都是使用 Nanotron 进行的。如果您对尝试 1.58 位预训练或微调感兴趣,可以查看此 PR

Fine-tuning plot compared to pre-training plot
微调图与预训练图对比

为了理解原因,我们尝试检查随机初始化模型和预训练模型的权重分布,以找出潜在问题。

Random weights distribution (2 merged stds)
随机权重分布(2个合并的标准差)
Pre-trained Llama3 weights distribution
预训练的 Llama3 权重分布

两个分布的尺度值分别为

Random weights scales distribution
随机权重尺度分布
Pre-trained Llama3 weights distribution
预训练的 Llama3 权重分布

初始随机权重分布是两个正态分布的混合

  • 一个标准差 (std) 为 0.025 0.025
  • 另一个标准差为 0.0252num_hidden_layers=0.00325 \frac{0.025}{\sqrt{2 \cdot \text{num\_hidden\_layers}}} = 0.00325

这是由于在 `nanotron` 中,列线性权重和行线性权重使用不同的标准差。在量化版本中,所有矩阵只有两个权重尺度(50.25 和 402),它们是每个矩阵权重绝对值的平均值的倒数:`scale = 1.0 / w.abs().mean().clamp_(min=1e-5)`。

  • 对于 scale=50.25\text{scale} = 50.25 w.abs().mean()=0.0199 w.abs().mean() = 0.0199 ,这与我们的第一个标准差 std=0.025\text{std} = 0.025 相匹配。用于推导标准差的公式基于 w |w| 的半正态分布期望。
    E(w)=std(w)2π \mathbb{E}(|w|) = \text{std}(w) \cdot \sqrt{\frac{2}{\pi}}
  • 对于 scale=402 \text{scale} = 402 w.abs().mean()=0.0025 w.abs().mean() = 0.0025 ,导致 std=0.00325\text{std} = 0.00325

另一方面,预训练权重的分布看起来像一个标准差为 std=0.013 \text{std} = 0.013 的正态分布。

显然,预训练模型以更多信息(尺度)开始,而随机初始化模型几乎没有信息,并随着时间的推移增加信息。我们的结论是,使用随机权重开始会给模型提供最小的初始信息,从而实现渐进式学习过程,而在微调过程中,BitLinear层的引入会使模型不堪重负,从而丢失所有先前的的信息。

为了改善微调结果,我们尝试了不同的技术。例如,我们尝试了逐行和逐列量化,而不是逐张量量化,以保留Llama 3权重中的更多信息。我们还尝试改变尺度计算方式:不再仅仅将权重的平均绝对值作为尺度,而是将离群值的平均绝对值作为尺度(离群值是超出k*平均绝对值的值,其中k是我们尝试在实验中改变的常数),但我们没有发现显著改进。

def scale_outliers(tensor, threshold_factor=1):
    mean_absolute_value = torch.mean(torch.abs(tensor))
    threshold = threshold_factor * mean_absolute_value
    outliers = tensor[torch.abs(tensor) > threshold]
    mean_outlier_value = torch.mean(torch.abs(outliers))
    return mean_outlier_value

def weight_quant_scaling(w):
    scale = 1.0 / scale_outliers(w).clamp_(min=1e-5)
    quantized_weights = (w * scale).round().clamp_(-1, 1) / scale
    return quantized_weights

我们观察到,随机权重和Llama 3权重都导致损失从大约相同的13值开始。这表明Llama 3模型在引入量化时丢失了其所有先验信息。为了进一步研究模型在此过程中丢失了多少信息,我们尝试了逐组量化。

作为一项健全性检查,我们首先将组大小设置为1,这实际上意味着没有量化。在这种情况下,损失从1.45开始,与我们在正常微调期间看到的情况相同。然而,当我们将组大小增加到2时,损失跃升至约11。这表明即使组大小最小为2,模型仍然几乎丢失了所有信息。

为了解决这个问题,我们考虑了逐步引入量化的可能性,而不是对每个张量的权重和激活突然应用量化。为此,我们实现了一个lambda值来控制该过程。

lambda_ = ?
x_quant = x + lambda_ * (activation_quant(x) - x).detach()
w_quant = w + lambda_ * (weight_quant(w) - w).detach()

当 `lambda` 设置为 0 时,基本上没有发生量化,而当 `lambda=1` 时,则应用完全量化。

我们最初测试了一些离散的 `lambda` 值,例如 0.25、0.5、0.75 和 1。然而,这种方法并没有导致结果的显著改善,主要是因为 `lambda=0.25` 已经足够高,导致损失从一开始就非常高。

Fine-tuning plot with lambda = 0.25->0.5->0.75->1
lambda = 0.25->0.5->0.75->1的微调图

因此,我们决定尝试根据训练步数动态调整的 `lambda` 值。

lambda_ = training_step / total_training_steps

使用这个动态的 `lambda` 值导致了更好的损失收敛,但当 `lambda` 设置为 1 时,推理期间的困惑度(ppl)结果仍然远不尽如人意。我们意识到这可能是因为模型在 `lambda=1` 下训练的时间不够长。为了解决这个问题,我们调整了 `lambda` 值以改进训练过程。

lambda_ = min(2 * training_step / total_training_steps, 1)

在此配置下,经过2000步训练后,我们得到

Fine-tuning plot with lambda = min(2*training_step/total_training_steps, 1)
lambda = min(2*training_step/total_training_steps, 1) 的微调图

我们的微调方法整体显示出更好的收敛性。您会注意到在1000步左右,损失曲线略有增加,这对应于我们开始接近`lambda=1`或完全量化的时候。然而,在此之后,损失立即开始再次收敛,导致困惑度提高到大约4。

尽管取得了这一进展,但当我们在WikiText数据集(而不是我们用于微调的tinystories数据集)上测试量化模型时,它显示出非常高的困惑度。这表明在特定数据集上以低位模式微调模型会导致它失去大部分通用知识。这个问题可能是因为我们用三元权重追求的最小表示在不同数据集之间可能存在显著差异。为了解决这个问题,我们扩大了训练过程,包括更大的FineWeb-edu数据集。我们保持了

lambda_ = min(training_step/1000, 1)

我们选择这个 `lambda` 值是因为它似乎是预热模型的一个良好起点。然后,我们使用 1e-4 的学习率在 FineWeb-edu 数据集上训练了模型 5,000 步。训练涉及 200 万的批处理大小 (BS),总计 100 亿个 token。

找到合适的学习率和衰减率极具挑战性;这似乎是模型性能的关键因素。

Fine-tuning plot with warmup quantization on Fineweb-edu
使用Fineweb-edu进行预热量化的微调图

在 Fineweb-Edu 上进行微调后,WikiText 数据集上的困惑度达到了 12.2,考虑到我们只使用了 100 亿个标记,这相当令人印象深刻。其他评估指标也显示出强劲的性能,考虑到有限的数据量(参见结果)。

我们还尝试在 lambda 接近 1 时平滑急剧增加。为此,我们考虑使用 lambda 调度器,它们首先呈指数增长,然后随着接近 1 而趋于平稳。

def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

对于不同的 k 值,在总预热步数为 1 的情况下,我们有如下所示的图:

Exponential scheduler for different k values
不同k值的指数调度器

我们使用表现最佳的学习率1e-4进行了4次实验,测试了k在[4, 6, 8, 10]中的值。

Fine-tuning plots with exponential scheduler
指数调度器微调图

平滑处理效果良好,没有像线性调度器那样出现尖峰。然而,困惑度并不理想,保持在15左右,并且下游任务的性能也没有更好。

我们还注意到一开始的尖峰,模型很难从中恢复过来。当lambda=0时,基本上没有量化,所以损失从低点开始,大约在2左右。但紧接着第一步,就出现了一个尖峰,类似于线性调度器的情况(如上图蓝色曲线所示)。因此,我们尝试了另一种调度器——S形调度器——它开始缓慢,急剧上升到1,然后随着接近1而趋于平稳。

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))

对于不同的 k 值,我们有以下曲线

Sigmoid scheduler for different k values
不同k值的S形调度器

我们这次运行了 5 次实验,k 的取值范围是 [15, 20, 25, 40, 100]。

Finetuning plots with sigmoid scheduler
S形调度器微调图

lambda的急剧增加导致了第500步左右的不稳定性,并且没有解决第一个发散问题。然而,对于 k=100 k = 100 ,我们观察到下游任务有所改善(参见结果表),尽管困惑度仍保持在13.5左右。尽管如此,它并未显示出比线性调度器更明显的性能提升。

此外,我们还尝试从头开始,使用随机权重和不同的学习率来训练模型。这使我们能够比较微调方法与传统预训练方法的有效性。

Different Pre-training plots with different learning rates
不同学习率下的预训练曲线图

所有从随机权重训练的模型表现都不如我们微调后的模型。这些模型的最佳困惑度为26,远低于我们微调方法的结果。

扩展到 1000 亿个 Token!

我们将实验扩展到 1000 亿个 Token,以观察能否与 Llama 3 8B 的性能相匹配。我们进行了更长时间的训练,从线性调度器下较短运行中表现最佳的检查点开始,并继续微调 45,000 步。我们尝试了不同的学习率,虽然模型在某些指标上与 Llama 3 模型表现接近,但平均而言仍有所落后。

以下是我们在训练过程中不同检查点评估的指标示例:

Metrics evaluations during the training for different lrs
不同学习率下训练过程中的指标评估

平均得分如下:

Average evaluation during the training for different lrs
不同学习率下训练过程中的平均评估

小型模型实验

在我们对 SmolLM 等小型模型的初步实验中,我们观察到热身量化技术并未像在大型模型中那样带来显著改进。这表明热身量化的有效性可能与模型大小和复杂性更密切相关。

例如,这里是 SmolLM 135M 模型的损失曲线,比较了热身量化与从一开始就进行完全量化的效果。有趣的是,曲线非常吻合,并且最终的困惑度没有显著差异。

Smoll LLm fine-tuning experiment with & without warmup quantization
Smol LLm 微调实验(带/不带热身量化)

结果与比较

BitNet 在提供强大性能方面,尤其是低比特级别,与基线方法相比非常有效。根据论文所述,BitNet 取得了与 8 比特模型相当的分数,但推理成本显著降低。对于 4 比特模型,仅量化权重的方法优于同时量化权重和激活的方法,因为激活更难量化。然而,使用 1.58 比特权重的 BitNet 超越了仅权重和权重与激活量化这两种方法。

下表展示了 Llama3 8B 经过 100 亿次微调后的各项指标结果。这些结果与其他模型架构的结果进行了比较,以提供全面的性能概览(所有评估均使用 LightevalNanotron 格式模型上进行)

Metrics comparison with Llama models
与 Llama 模型的指标比较:Linear 表示线性 lambda 调度器,Sigmoid 表示 Sigmoid lambda 调度器(在我们的例子中 k = 100)

仅使用三元权重对模型进行 100 亿个 Token 的微调后,模型展现出令人印象深刻的性能,尤其是在与经过更广泛训练的其他模型进行比较时。例如,它优于 Bitnet 7B 模型,后者在规模显著更大的 1000 亿个 Token 数据集上进行了训练。此外,它也优于 FBI LLM (Fully Binarized LLM),一个在更庞大的 1.26 万亿个 Token 上进行蒸馏的模型。这凸显了该模型尽管微调规模相对较小,但其效率和有效性。

对于 1000 亿个 Token 的实验,我们表现最好的检查点如下:

Metrics comparison with Llama models for the model trained on 100B tokens
在 1000 亿个 Token 上训练的模型与 Llama 模型的指标比较

为了复现这些结果,您可以查看此 PR 以将模型转换为 nanotron 格式,解包权重(查看函数 unpack_weights),并使用 lighteval

请注意,尽管这些模型是从 Instruct-tuned 模型微调而来,但它们仍然需要使用 Instruct 数据集进行微调。这些可以被视为基础模型。

自定义内核与基准测试

为了利用 BitNet 的低精度权重,我们将它们打包成 int8 张量(这将参数数量从 8B 减少到 2.8B!)。在推理过程中,这些权重必须在进行矩阵乘法之前解包。我们用 Cuda 和 Triton 实现了自定义内核,以处理矩阵乘法过程中的即时解包。对于矩阵乘法本身,我们采用了缓存平铺矩阵乘法技术。为了充分理解这种方法,我们首先回顾一些 Cuda 编程基础知识。

基本 GPU 概念:线程、块和共享内存

在深入研究缓存平铺矩阵乘法之前,理解一些基本的 GPU 概念非常重要:

  • 线程和块:GPU 同时执行数千个线程。这些线程被分组到块中,每个块独立运行。网格由这些块组成,它代表整个问题空间。例如,在矩阵乘法中,每个线程可能负责计算输出矩阵的一个元素。
  • 共享内存:每个块都可以访问有限的共享内存,其速度远快于全局内存(GPU 上的主内存)。然而,共享内存的大小有限,并且在块内的所有线程之间共享。有效地使用共享内存是提高 GPU 程序性能的关键。

矩阵乘法中的挑战

在 GPU 上实现矩阵乘法的简单方法可能涉及每个线程通过直接从全局内存读取所需元素来计算结果矩阵的单个元素。然而,这种方法可能效率低下,原因如下:

  • 内存带宽:与 GPU 核心执行计算的速度相比,访问全局内存相对较慢。如果每个线程直接从全局内存读取矩阵元素,内存访问时间可能会成为瓶颈。
  • 冗余数据访问:在矩阵乘法中,输入矩阵的许多元素被多次使用。如果每个线程独立地从全局内存获取所需数据,则相同的数据可能会被多次加载到 GPU 中,从而导致效率低下。例如,如果每个线程用于计算输出矩阵中的单个元素,则负责计算位置 (i, j) 处元素的线程将需要从全局内存加载矩阵 A 的第 i 行和矩阵 B 的第 j 列。然而,其他线程,例如计算位置 (i+1, j) 处元素的线程,无法重用此数据,并且必须再次从全局内存中重新加载相同的第 j 列。

平铺的思想

平铺是一种用于解决这些挑战的技术,它主要用于 FlashAttention 以提高内核的效率。基本思想是将矩阵分成更小的子矩阵,称为瓦片(tiles),它们可以放入 GPU 的共享内存中。计算不再一次性完成整个输出矩阵,而是分解成更小的部分,逐瓦片处理。

在矩阵乘法中,这意味着将矩阵 A 和 B 分成块(瓦片),将这些瓦片加载到共享内存中,然后对这些较小的块进行乘法运算。这种方法允许线程重用存储在快速共享内存中的数据,从而减少重复访问全局内存的需求。

工作原理如下:

  • 将瓦片加载到共享内存:每个线程块协同地将矩阵 A 的一个瓦片和矩阵 B 的一个对应瓦片从全局内存加载到共享内存中。此操作每个瓦片只执行一次,然后该瓦片由块中的线程多次重用。
  • 计算部分积:一旦瓦片加载到共享内存中,每个线程计算一个部分积。由于块中的所有线程都在共享内存中的相同瓦片上工作,它们可以有效地重用数据而无需额外的全局内存访问。
  • 累积结果:计算完一个瓦片的部分积后,线程将矩阵 A 和 B 的下一个瓦片加载到共享内存中,并重复该过程。结果累积在寄存器(或本地内存)中,一旦所有瓦片都处理完毕,输出矩阵元素的最终值将写回全局内存。
Tiled Matrix multiplication illustration
平铺矩阵乘法示意图(来源:https://cnugteren.github.io/tutorial/pages/page4.html)

实际考量

在实现缓存平铺矩阵乘法时,需要考虑几个因素:

  • 瓦片大小:瓦片的大小应在可放入共享内存的数据量和全局内存访问次数之间取得平衡。
  • 内存合并:全局内存访问被合并,这意味着相邻线程访问相邻内存位置。
  • 占用率:每个块的线程数和网格中的块数应选择,以确保高占用率,这意味着在 GPU 上尽可能多地拥有活动的 warp(warp 是 32 个线程的集合),以隐藏内存延迟。

Triton 内核

这是我们进行基准测试的 Triton 内核:

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn, 
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  
        GROUP_SIZE_M: tl.constexpr,
):

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    for i in range(4) : 
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K) ):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j 

            # BLOCK_SIZE_K must be a divisor of K / 4 
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)
            mask = 3<<(2*i)
            b = ((b_uint8 & mask) >> (2*i))

            # We accumulate the tiles along the K dimension.
            tensor_full = tl.full((1,), 1, dtype=tl.int8)

            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

代码分解

  1. 确定瓦片位置

内核首先确定每个线程块负责的输出矩阵瓦片(块)

  • pid 是每个线程块的唯一标识符,通过 tl.program_id(axis=0) 获取。
  • 网格被划分为线程块组(GROUP_SIZE_M)。每个组处理输出矩阵的一部分。
  • pid_mpid_n 分别是瓦片在 M 和 N 维度上的坐标。
  • 计算偏移量(offs_am, offs_bn, offs_k)以确定块中每个线程将处理矩阵 A 和 B 的哪些元素。
  1. 加载和计算瓦片

内核使用循环以 BLOCK_SIZE_K 的块大小迭代 K 维度。对于每个块:

  • 加载瓦片:从全局内存加载矩阵 A 和 B 的瓦片。
  • 解包矩阵 B:内核假定矩阵 B 用 int8 值打包,这意味着每个元素实际上代表四个打包成一个字节的较小值。解包在循环中进行:
    • b_uint8 作为打包的 int8 从全局内存加载。
    • 每个打包的值都被解包以获取用于计算的实际权重值。
  • 点积:内核计算从 A 和 B 加载的瓦片的点积,并将结果累积到 accumulator 中。accumulator 存储输出矩阵 C 的瓦片的部分结果。
  1. 存储结果

在沿 K 维度处理完所有瓦片后,存储在 accumulator 中的最终结果将转换为 float16 并写回全局内存中矩阵 C 的相应瓦片。写入过程通过掩码遵循内存边界,以确保只写入有效元素。

有关代码的更详细说明,请查看此 PR

基准测试

我们对自定义内核与使用 @torch.compile 解包权重后执行 BF16 精度的矩阵乘法进行了基准测试,发现两种方法均取得了大致相同的性能。为确保基准测试的准确性,我们对矩阵乘法操作进行了 2000 次迭代,并对最后 1000 次迭代所花费的时间取平均值,以消除任何与初始加载或编译相关的低效率。下图显示了基准测试结果。我们还测试了各种矩阵大小,其中 x 轴表示乘法次数(对数刻度),y 轴表示平均时间(毫秒)。

Triton kernel compared to torch.compile
Triton 内核与 torch.compile 的比较

我们还尝试了使用 BitBlas,这是一个旨在执行混合精度矩阵运算的软件库。它通过允许以 INT8、INT4 甚至 INT2 等低精度格式进行计算,而不是传统的 FP32 或 FP16 格式,从而帮助优化这些操作。

基准测试结果令人鼓舞,如折线图所示,BitBlas 在低精度方面优于我们的自定义内核和 Torch 的 matmul 函数。

Bitblas benchmark
Bitblas 基准测试

然而,在模型加载过程中,BitBlas 需要编译针对权重矩阵形状定制的内核并将其存储在本地数据库中,这可能会增加初始加载时间。

结论

总而言之,随着大型语言模型(LLM)的不断扩展,通过量化降低其计算需求至关重要。本文探讨了 1.58 位量化方法,该方法使用三元权重。虽然以 1.58 位预训练模型是资源密集型的,但我们已经证明,通过一些技巧,可以对现有模型进行此精度级别的微调,从而在不牺牲准确性的前提下实现高效性能。通过专用内核优化推理速度,BitNet 为使 LLM 更加实用和可扩展开辟了新的可能性。

致谢

我们衷心感谢 Leandro von Werra、Thomas Wolf 和 Marc Sun 在整个项目中提供的宝贵帮助和见解。我们还要感谢 Omar Sanseviero 和 Pedro Cuenca 在完善这篇博客文章方面的贡献,帮助我们清晰有效地向人工智能社区传达我们的发现。此外,我们还要感谢 GeneralAI 团队在 BitNet 项目上的开创性工作。他们的研究是我们努力的基础,我们特别感谢他们在论文中提供了清晰精确的图表。

额外资源

  1. H. Wang 等人,《BitNet: Scaling 1-bit Transformers for Large Language Models》。arxiv 论文
  2. S. Ma 等人,《1 比特 LLM 时代:所有大型语言模型均为 1.58 比特》。arxiv 论文
  3. S. Ma 等人,《1 比特 LLM 时代:训练技巧、代码和常见问题解答》。链接
  4. RJ. Honicky,《所有大型语言模型真的都是 1.58 位吗?》。博客文章
  5. L. Mao,《CUDA 矩阵乘法优化》。博客文章
  6. 《教程:针对 Kepler 的 OpenCL SGEMM 调优》。链接
  7. 《CUDAMODE》。githubyoutube
  8. Wen-mei W. Hwu, David B. Kirk, Izzat El Hajj, 《大规模并行处理器编程:实践方法》

社区

好文章,谢谢分享。

我想知道为什么微调是在指令模型而不是基础模型上进行的?

注册登录发表评论