Bitnet 1.5 的实验 (ngmi)

社区文章 发布于 2024 年 3 月 30 日

论文:arxiv.org/abs/2402.17764

image/png

(有没有人想过他们为什么不在论文中展示损失曲线)

更新:人们已在大规模上独立验证,前景光明。https://huggingface.co/1bitLLM/bitnet_b1_58-3B

有点作用,但别太兴奋

这更像是量化感知训练,而不是什么根本性的新事物。CUDA 核融合和优化还有很大的空间,因为假设存储的权重矩阵只包含 1、0、-1,这意味着我们只需要加、减操作,这比昂贵的乘法更好。虽然在推理时我们可以实现惊人的加速,但我不能确定在训练时也能实现。因为我们需要平滑的优化器梯度动量变化。

如有疑问或指正,请通过 twitter/x00shxf@gmail.com 联系我。我是在混乱中编写代码的,所以如果发现需要更正的地方,请告诉我。

如果您只对代码感兴趣,代码在这里。

实现细节

这是基于作者官方分享的 PDF 的 Bitnet 实现。image/png

权重量化

请注意,对于当前的实验,所有权重都以 fp32/16 存储并在正向传播过程中进行量化。

image/png

这里 γ\gamma = 缩放因子,W~\widetilde{W} = w_quant

from torch import Tensor

def weight_quant(w: Tensor) -> tuple[Tensor, Tensor]:
    scale: Tensor = 1.0 / w.abs().mean().clamp(min=1e-5)
    quant: Tensor = (w * scale).round().clamp(-1, 1) / scale
    w_quant: Tensor = w + (quant - w).detach()
    scale = abs(w_quant).max().detach()
    w_quant = w_quant / scale
    return w_quant, scale 

w_quant 是由 1、0、-1 组成的矩阵。

激活量化

这来自一篇前导论文 https://arxiv.org/pdf/2310.11453.pdf

image/png

from torch import Tensor

def activation_quant(x: Tensor) -> Tensor:
    scale: Tensor = 127.0 / x.abs().max(dim=1, keepdim=True).values.clamp(min=1e-5)
    y: Tensor = (x * scale).round().clamp(-128, 127) / scale
    return x + (y - x).detach()

BitLinear

这两个函数实现了 BitLinear;这一层是 nn.Linear 的直接替代。通过核融合、手动编写反向传播等方式,仍有很大的优化空间。

import torch
import torch.nn as nn

class BitLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super(BitLinear, self).__init__(*args, **kwargs)
        self.rms_norm = RMSNorm(self.in_features)

    def forward(self, x: Tensor) -> Tensor:
        w = self.weight
        x_norm = self.rms_norm(x)

        x_quant = activation_quant(x_norm)
        w_quant, scale = weight_quant(w)

        output = nn.functional.linear(x_quant, w_quant)
        return output * scale

训练代码

您可以在这里找到我的 PyTorch 实现:https://github.com/joey00072/ohara/tree/master/experiments/bitnet

我用 15.5M 的 llama 和 15.6M 的 bitnet llama(bitnet RMSNorm 额外 0.1M)训练了它。您可以查看损失曲线;bitnet 比 llama 差(我非常确定 2 比特量化感知训练会更好或相同)。

image/png

目前,BitNet 在现有硬件上进行推理没什么用;如果需要量化版本,最好以 bfloat16 训练模型并使用量化版本。为了实现显著的加速,

BitNet 1.5 的前提条件相当高。我们需要有人制造一种支持 2 比特混合精度的特殊芯片。所以,最大的假设是我们将使用一个 2 比特模型进行推理,这意味着需要有人花费大量(非常大量)资金来制造芯片、软件并训练量化感知 LLM。更不用说我们不知道 1.5 比特(即 2 比特)量化训练的扩展定律是否与普通训练的相同。除非有人准备投入大量资金。

所以,BITNET NGMI


ko-fi

社区

注册登录 发表评论