Transformers 文档

HIGGS

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

HIGGS

HIGGS 是一种零样本量化算法,它将哈达玛预处理与 MSE-最优量化网格相结合,以实现更低的量化误差和最先进的性能。

HIGGS 的运行时支持通过 FLUTE 库实现。目前仅支持 Llama 3 和 Llama 3.0 的 70B 和 405B 变体,以及 Gemma 2 的 8B 和 27B 变体。HIGGS 目前通常也不支持量化训练和反向传播。

运行以下命令安装 FLUTE。

CUDA 12.1
CUDA 11.8
pip install flute-kernel

使用要将模型量化到的位数创建 HiggsConfig

from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)

在官方 ISTA-DASLab collection 中查找已使用 HIGGS 预量化的模型。

torch.compile

HIGGS 完全兼容 torch.compile

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)

model = torch.compile(model)

请参考下表,了解在 RTX4090 上 Llama-3.1-8B-Instruct 的每秒前向传播基准。

批量大小 BF16(带 torch.compile HIGGS 4位(不带 torch.compile HIGGS 4位(带 torch.compile
1 59 41 124
4 57 42 123
16 56 41 120
< > 在 GitHub 上更新