Timm ❤️ Transformers:在 Transformers 中使用任何 timm 模型
在友好的 🤗
transformers
生态系统中,为**任何** timm
模型实现闪电般的推理速度、快速量化、torch.compile
加速和轻松微调。
隆重介绍 TimmWrapper
——一个简单而强大的工具,释放了这一潜力。
在这篇文章中,我们将涵盖
- timm 集成的工作原理及其为何能改变游戏规则。
- 如何将
timm
模型与 🤗transformers
集成。 - 实践示例:pipeline、量化、微调等。
要跟随本博客文章,请运行以下命令安装最新版本的
transformers
和timm
:pip install -Uq transformers timm
查看包含所有代码示例和 notebook 的完整代码库:🔗 TimmWrapper 示例
什么是 timm?
PyTorch Image Models (timm
) 库提供了丰富的最先进的计算机视觉模型,以及有用的层、工具、优化器和数据增强。截至本文撰写时,它在 GitHub 上拥有超过 32K 颗星,每日下载量超过 200K,是图像分类和特征提取(用于目标检测、分割、图像搜索等下游任务)的首选资源。
timm
拥有涵盖各种架构的预训练模型,简化了计算机视觉从业者的工作流程。
为何使用 timm 集成?
虽然 🤗 transformers
支持多种视觉模型,但 timm
提供了更广泛的集合,包括许多在 transformers 中不可用的移动端友好和高效的模型。
timm
集成弥补了这一差距,带来了两全其美的优势
- ✅ Pipeline API 支持:轻松将任何
timm
模型插入到高级transformers
pipeline 中,以实现流线型推理。 - 🧩 与 Auto 类兼容:虽然
timm
模型本身与transformers
不兼容,但此集成使其能够与Auto
类 API 无缝协作。 - ⚡ 快速量化:只需约 5 行代码,您就可以量化**任何**
timm
模型以进行高效推理。 - 🎯 使用 Trainer API 进行微调:使用
Trainer
API 微调timm
模型,甚至可以与低秩自适应 (LoRA) 等适配器集成。 - 🔁 返回 timm:在
timm
中再次使用微调后的模型。 - 🚀 Torch Compile 加速:利用
torch.compile
优化推理时间。
Pipeline API:使用 timm 模型进行图像分类
timm
集成的一个突出特点是它允许您利用 🤗 pipeline
API。pipeline
API 抽象了许多复杂性,使得加载预训练模型、执行推理和查看结果变得非常简单,只需几行代码即可完成。
让我们看看如何将 transformers pipeline 与 MobileNetV4 一起使用。该架构没有原生的 transformers
实现,但可以轻松地从 timm
中使用
from transformers import pipeline
import requests
# Load the image classification pipeline with a timm model
image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")
# URL of the image to classify
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
# Perform inference
outputs = image_classifier(url)
# Print the results
for output in outputs:
print(f"Label: {output['label'] :20} Score: {output['score'] :0.2f}")
输出:
Device set to use cpu
Label: tabby, tabby cat Score: 0.69
Label: tiger cat Score: 0.21
Label: Egyptian cat Score: 0.02
Label: bee Score: 0.00
Label: marmoset Score: 0.00
Gradio 集成:构建食物分类器演示 🍣
想要快速创建一个用于图像分类的交互式 Web 应用吗?Gradio 使您能够用最少的代码构建一个用户友好的界面。让我们将 Gradio 与 pipeline
API 结合起来,使用一个微调过的 timm ViT 模型来分类食物图像(我们将在后面的章节中介绍微调)。
以下是如何使用 timm
模型快速设置一个演示
import gradio as gr
from transformers import pipeline
# Load the image classification pipeline using a timm model
pipe = pipeline(
"image-classification",
model="ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
)
def classify(image):
return pipe(image)[0]["label"]
demo = gr.Interface(
fn=classify,
inputs=gr.Image(type="pil"),
outputs="text",
examples=[["./sushi.png", "sushi"]]
)
demo.launch()
这是一个托管在 Hugging Face Spaces 上的实时示例。您可以直接在浏览器中测试!
Auto 类:简化模型加载
🤗 transformers
库提供了 Auto 类 来抽象化加载模型和处理器的复杂性。通过 TimmWrapper
,您可以使用 AutoModelForImageClassification
和 AutoImageProcessor
轻松加载任何 timm
模型。
这是一个快速示例
from transformers import (
AutoModelForImageClassification,
AutoImageProcessor,
)
from transformers.image_utils import load_image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)
# Use Auto classes to load a timm model
checkpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
# Check the types
print(type(image_processor)) # TimmWrapperImageProcessor
print(type(model)) # TimmWrapperForImageClassification
运行量化的 timm 模型
量化是一种强大的技术,可以减小模型大小并加速推理,尤其适用于资源受限的设备。通过 timm
集成,您可以使用 bitsandbytes
中的 BitsAndBytesConfig
,只需几行代码即可即时量化任何 timm
模型。
以下是量化一个 timm
模型是多么简单
from transformers import TimmWrapperForImageClassification, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")
model_8bit = TimmWrapperForImageClassification.from_pretrained(
checkpoint,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
)
original_footprint = model.get_memory_footprint()
quantized_footprint = model_8bit.get_memory_footprint()
print(f"Original model size: {original_footprint / 1e6:.2f} MB")
print(f"Quantized model size: {quantized_footprint / 1e6:.2f} MB")
print(f"Reduction: {(original_footprint - quantized_footprint) / original_footprint * 100:.2f}%")
输出
Original model size: 346.27 MB
Quantized model size: 88.20 MB
Reduction: 74.53%
量化模型在推理时的性能与全精度模型几乎完全相同
模型 | 标签 | 准确率 |
---|---|---|
原始模型 | 遥控器,遥控 | 0.35% |
量化模型 | 遥控器,遥控 | 0.33% |
timm
模型的监督式微调
使用 🤗 transformers
的 Trainer
API 微调 timm
模型是直接且高度灵活的。您可以使用 Trainer
类在自定义数据集上微调您的模型,该类处理训练循环、日志记录和评估。此外,您可以使用 LoRA (低秩自适应) 进行微调,以更少的参数高效地进行训练。
本节对标准微调和 LoRA 微调进行了简要概述,并提供了完整代码的链接。
使用 Trainer
API 进行标准微调
Trainer
API 使得用最少的代码设置训练变得容易。以下是微调设置的概要
from transformers import TrainingArguments, Trainer
# Define training arguments
training_args = TrainingArguments(
output_dir="my_model_output",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
load_best_model_at_end=True,
push_to_hub=True,
)
# Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
# Start training
trainer.train()
这种方法的显著之处在于,它反映了用于原生 transformers
模型的完全相同的工作流程,从而在不同模型类型之间保持了一致性。
这意味着您可以使用熟悉的 Trainer
API 不仅微调 Transformers 模型,还可以微调**任何 timm
模型**——将 timm
库中强大的模型引入 Hugging Face 生态系统,只需进行最少的调整。这极大地拓宽了您可以使用相同可信赖的工具和工作流程进行微调的模型范围。
模型示例
在 Food-101 上微调的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101
LoRA 微调以实现高效训练
LoRA (低秩自适应) 允许您通过仅训练少量额外参数而不是完整的模型权重来高效地微调大型模型。这使得微调更快,并允许使用消费级硬件。您可以使用 PEFT 库通过 LoRA 微调一个 timm
模型。
以下是您可以如何设置它
from peft import LoraConfig, get_peft_model
model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)
lora_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["qkv"],
lora_dropout=0.1,
bias="none",
modules_to_save=["head"],
)
# Wrap the model with PEFT
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()
使用 LoRA 的可训练参数
trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.77%
模型示例
在 Food-101 上进行 LoRA 微调的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101
LoRA 只是您可以应用于 timm
模型的众多高效适配器微调方法中的一个例子。timm
与 🤗 生态系统的集成为您开启了各种**参数高效微调 (PEFT)** 技术的大门,让您可以选择最适合您应用场景的方法。
使用 LoRA 微调模型进行推理
一旦模型经过 LoRA 微调,我们仅将适配器权重推送到 Hugging Face Hub。本节将帮助您下载适配器权重,将适配器权重与基础模型合并,然后进行推理。
from peft import PeftModel, PeftConfig
repo_name = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101"
config = PeftConfig.from_pretrained(repo_name)
model = AutoModelForImageClassification.from_pretrained(
config.base_model_name_or_path,
label2id=label2id,
num_labels=num_labels,
id2label=id2label,
ignore_mismatched_sizes=True,
)
inference_model = PeftModel.from_pretrained(model, repo_name)
# Make prediction with the model
双向集成
Ross (timm
的创建者) 最喜欢的一个功能是,这种集成保持了完整的“双向”兼容性。也就是说,使用包装器,人们可以使用 transformer
的 Trainer
在新数据集上微调 timm 模型,将结果模型发布到 Hugging Face hub,然后再次使用 timm.create_model('hf-hub:my_org/my_fine_tuned_model', pretrained=True)
在 timm
中加载微调后的模型。
让我们看看如何用 timm
加载我们微调过的模型 ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101
checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
config = AutoConfig.from_pretrained(checkpoint)
model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True) # Load the model with timm
model = model.eval()
image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(image).unsqueeze(0))
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
print(f"Label: {config.id2label[idx.item()] :20} Score: {prob/100 :0.2f}%")
输出
Label: sushi Score: 0.98%
Label: spring_rolls Score: 0.01%
Label: sashimi Score: 0.00%
Label: club_sandwich Score: 0.00%
Label: cannoli Score: 0.00%
Torch Compile:即时加速
在 PyTorch 2.0 中使用 torch.compile
,您只需一行代码即可编译模型,从而实现更快的推理速度。timm
集成完全兼容 torch.compile
。以下是一个快速基准测试,用于比较使用 TimmWrapper
时有无 torch.compile
的推理时间。
# Load the model and input
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to(device)
processed_input = image_processor(image, return_tensors="pt").to(device)
# Benchmark function
def run_benchmark(model, input_data, warmup_runs=5, benchmark_runs=300):
# Warm-up phase
model.eval()
with torch.no_grad():
for _ in range(warmup_runs):
_ = model(**input_data)
# Benchmark phase
times = []
with torch.no_grad():
for _ in range(benchmark_runs):
start_time = time.perf_counter()
_ = model(**input_data)
if device.type == "cuda":
torch.cuda.synchronize(device=device) # Ensure synchronization for CUDA
times.append(time.perf_counter() - start_time)
avg_time = sum(times) / benchmark_runs
return avg_time
# Run benchmarks
time_no_compile = run_benchmark(model, processed_input)
compiled_model = torch.compile(model).to(device)
time_compile = run_benchmark(compiled_model, processed_input)
# Results
print(f"Without torch.compile: {time_no_compile:.4f} s")
print(f"With torch.compile: {time_compile:.4f} s")
总结
timm
与 transformers 的集成为利用最先进的视觉模型开辟了新的大门,且只需最少的努力。无论您是想进行微调、量化,还是仅仅运行推理,这种集成都提供了一个统一的 API 来简化您的工作流程。
立即开始探索,解锁计算机视觉的新可能!
致谢
我们要向在 Transformers PR #34564 中促成此次集成的各位同仁表示衷心的感谢。排名不分先后,衷心感谢 Pavel Iakubovskii、Ross Wightman、Lysandre Debut、Pablo Montalvo、Arthur Zucker 和 Amy Roberts 所做的杰出工作。你们的共同努力使这个想法变成了现实,让每个人都能享受到这个功能!