torchao

Open In Colab: Torchao Demo

torchao is a PyTorch architecture optimization library that supports custom high-performance data types, quantization, and sparsity. It composes with native PyTorch features such as torch.compile for faster inference and training.

See the table below for additional torchao features.

| Feature | Description |
|---|---|
| Quantization-Aware Training (QAT) | Train quantized models with minimal accuracy loss (see the QAT README) |
| Float8 training | High-throughput training in the float8 format (see torchtitan and the Accelerate docs) |
| Sparsity support | Semi-structured (2:4) sparsity for faster inference (see the Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity blog post) |
| Optimizer quantization | Reduce optimizer-state memory with 4-bit and 8-bit variants of Adam |
| KV cache quantization | Long-context inference with less memory (see KV cache quantization) |
| Custom kernel support | Use your own torch.compile compatible ops |
| FSDP2 | Composes with FSDP2 for training |

For more details about the library, see the torchao README.md.

torchao supports the following quantization techniques:

  • A16W8 Float8 dynamic quantization
  • A16W8 Float8 weight-only quantization
  • A8W8 Int8 dynamic quantization
  • A16W8 Int8 weight-only quantization
  • A16W4 Int4 weight-only quantization
  • A16W4 Int4 weight-only quantization + 2:4 sparsity
  • Autoquantization

torchao also supports per-module configuration: you specify a dictionary whose keys are fully qualified module names and whose values are the quantization configs to apply to them. This lets you skip quantization for certain layers and use different quantization configs for different modules.
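For example, a minimal sketch of such a dictionary (the module name below is purely illustrative; complete, runnable examples follow in the per-module quantization section):

from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig
from transformers import TorchAoConfig

# "_default" applies to every quantizable layer; a value of None skips quantization for that module
quant_config = FqnToConfig({
    "_default": Int4WeightOnlyConfig(group_size=128),
    "model.layers.0.self_attn.q_proj": None,  # illustrative fully qualified module name
})
quantization_config = TorchAoConfig(quant_type=quant_config)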

Check the table below to see whether your hardware is compatible.

| Component | Compatibility |
|---|---|
| CUDA versions | ✅ cu118, cu126, cu128 |
| XPU versions | ✅ pytorch2.8 |
| CPU | ✅ change device_map="cpu" (see the examples below) |

Install torchao from PyPI or the PyTorch index with the following command.

PyPI
# Update 🤗 Transformers to the latest version, since the example scripts below use the new auto compilation
# Stable release from PyPI, which defaults to CUDA 12.6
pip install --upgrade torchao transformers

If your torchao version is below 0.10.0, you need to upgrade it. See the deprecation notice below for more details.
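A quick way to check which torchao version is installed (a minimal sketch):

import torchao
print(torchao.__version__)  # upgrade if this prints a version below 0.10.0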

Quantization examples

TorchAO provides a variety of quantization configs. Each config can be further customized with parameters such as group_size, scheme, and layout to optimize for specific hardware and model architectures.

For a complete list of available configs, see the quantization API documentation.

You can choose the quantization type and settings manually, or have them selected automatically.

Create a TorchAoConfig and specify the weight type to quantize to, along with group_size (for int8 weight-only and int4 weight-only). Set cache_implementation to "static" to automatically torch.compile the forward method.

The examples below show the recommended quantization methods for different hardware (for example, A100 GPU, H100 GPU, CPU).

With cache_implementation="static", torchao automatically compiles the model on the first inference call. The model is recompiled every time the batch size or max_new_tokens changes. Pass disable_compile=True to generate() to quantize without compiling.
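As a minimal sketch, generation without compilation looks like this (it assumes the quantized_model, tokenizer, and input_ids objects created in the examples below):

# run eager generation on the quantized model instead of auto-compiling it
output = quantized_model.generate(**input_ids, max_new_tokens=10, disable_compile=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))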

H100 GPU

float8-dynamic-and-weight-only
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig

quant_config = Float8DynamicActivationFloat8WeightConfig()
# or float8 weight only quantization
# quant_config = Float8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device, quantized_model.dtype)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
int4-weight-only-24sparse
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout

quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuracy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Sparse-Llama-3.1-8B-2of4",
    dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

A100 GPU

int8-dynamic-and-weight-only
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device, quantized_model.dtype)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
int4-weight-only-24sparse
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout

quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuracy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Sparse-Llama-3.1-8B-2of4",
    dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device, quantized_model.dtype)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

Intel XPU

int8-dynamic-and-weight-only
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device, quantized_model.dtype)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

CPU

int8-dynamic-and-weight-only
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device, quantized_model.dtype)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))

Per-module quantization

1. Skip quantization for certain layers

With FqnToConfig, we can specify a default config for all layers while skipping quantization for certain layers.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"

from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig
config = Int4WeightOnlyConfig(group_size=128)

# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
quant_config = FqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device, quantized_model.dtype)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

2. Quantize different layers with different quantization configs (without regex)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "facebook/opt-125m"

from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType

weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=weight_dtype,
    granularity=granularity,
    mapping_type=mapping_type,
)
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
quant_config = FqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
# set `include_embedding` to True in order to include embedding in quantization
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", dtype=torch.bfloat16, quantization_config=quantization_config)
print("quantized model:", quantized_model)
# make sure embedding is quantized
print("embed_tokens weight:", quantized_model.model.decoder.embed_tokens.weight)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cpu", quantized_model.dtype)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, cache_implementation="static")
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

3. Quantize different layers with different quantization configs (with regex)

We can also use regular expressions to specify a config for all modules whose module_fqn matches the regex. Every regex should start with re:, for example re:layers\..*\.gate_proj will match all layers like layers.0.gate_proj. See the documentation here for details.

import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)

model_id = "facebook/opt-125m"

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    FqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)

float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))

qconfig_dict = {
    # highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # vllm
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,

    "re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    "re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    "re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this should not take effect and we'll fallback to _default
    # since there is no full match (missing `j` at the end)
    "re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # vllm
    "re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,

    "_default": intxwo,
}
quant_config = FqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)

# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to(quantized_model.device, quantized_model.dtype)
# setting temperature to 0 to make sure result deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)

correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])


# Save the quantized model, then reload it from the local checkpoint
# (the directory name below is just an illustrative local path)
save_to = "opt-125m-fqn-quantized"
quantized_model.save_pretrained(save_to)
tokenizer.save_pretrained(save_to)

# Load model from saved checkpoint
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    dtype=torch.bfloat16,
    # quantization_config=quantization_config,
)

generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])

assert(correct_output_text == output_text)

Serialization

Saving a quantized model with save_pretrained (in safetensors format) is only supported for torchao >= v0.15. With any earlier version, you can only save manually with torch.save as an unsafe .bin checkpoint.

Save locally
# torchao >= 0.15
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained("llama3-8b-int4wo-128")

Loading quantized models

How you load a quantized model depends on the quantization scheme. For schemes such as int8 and float8, you can quantize the model on any device and also load it on any device. The example below quantizes the model on CPU and then loads it on CUDA or XPU.

import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

quant_config = Int8WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)
# save the quantized model
output_dir = "llama-3.1-8b-torchao-int8"
quantized_model.save_pretrained(output_dir)

# reload the quantized model
reloaded_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

For int4, the model can only be loaded on the same type of device it was quantized on, because the layout is device-specific. The example below quantizes and loads the model on CPU.

import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout

quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)
# save the quantized model
output_dir = "llama-3.1-8b-torchao-int4-cpu"
quantized_model.save_pretrained(output_dir)

# reload the quantized model
reloaded_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    device_map="cpu",
    dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

⚠️ Deprecation notice

Starting with version 0.10.0, the string-based API for quantization configuration (for example, TorchAoConfig("int4_weight_only", group_size=128)) is deprecated and will be removed in a future release.

Use the new AOBaseConfig-based approach instead:

# Old way (deprecated)
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)

# New way (recommended)
from torchao.quantization import Int4WeightOnlyConfig
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.

Migration guide

Here is how to migrate from common string identifiers to their AOBaseConfig equivalents:

| Old string-based API | New AOBaseConfig API |
|---|---|
| "int4_weight_only" | Int4WeightOnlyConfig() |
| "int8_weight_only" | Int8WeightOnlyConfig() |
| "int8_dynamic_activation_int8_weight" | Int8DynamicActivationInt8WeightConfig() |

All configuration objects accept parameters for customization (for example, group_size, scheme, layout).
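For illustration, a short sketch of passing such a parameter (the value is arbitrary):

from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

# smaller group_size means finer-grained quantization scales (typically better accuracy, more metadata)
quant_config = Int4WeightOnlyConfig(group_size=64)
quantization_config = TorchAoConfig(quant_type=quant_config)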

Resources

For a better idea of expected performance, check out the benchmarks for various models on the CUDA and XPU backends. You can also run the code below to benchmark a model yourself.

from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
# assumes `quantized_model`, `tokenizer`, and `input_ids` from the examples above
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # same checkpoint as the examples above
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

For best performance, apply the recommended settings by calling torchao.quantization.utils.recommended_inductor_config_setter().
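A minimal sketch of applying it, using only the function named above:

from torchao.quantization.utils import recommended_inductor_config_setter

# apply torchao's recommended torch._inductor settings before compiling/generating
recommended_inductor_config_setter()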

See the other available quantization techniques for more examples and documentation.

Issues

If you run into any issues with the Transformers integration, open an issue in the Transformers repository. For issues directly related to torchao, open an issue in the torchao repository.
