Transformers 文档

TorchAO

Hugging Face's logo
加入 Hugging Face 社区

并获取增强型文档体验

入门

TorchAO

TorchAO 是一个用于 PyTorch 的架构优化库,它提供高性能数据类型、优化技术和内核,用于推理和训练,并与原生 PyTorch 功能(如 `torch.compile`、FSDP 等)具有可组合性。一些基准数据可以在 此处 找到。

在开始之前,请确保以下库已安装其最新版本。

pip install --upgrade torch torchao
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
# We support int4_weight_only, int8_weight_only and int8_dynamic_activation_int8_weight
# More examples and documentations for arguments can be found in https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# compile the quantized model to get speedup
import torchao
torchao.quantization.utils.recommended_inductor_config_setter()
quantized_model = torch.compile(quantized_model, mode="max-autotune")

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

TorchAO 量化是使用张量子类实现的,目前它不适用于 Hugging Face 序列化,包括 safetensor 选项和 非 safetensor 选项,当它可工作时,我们将在此处更新说明。

< > 在 GitHub 上更新