Timm ❤️ Transformers:在 Transformers 中使用任何 timm 模型

发布于 2025 年 1 月 16 日
在 GitHub 上更新

在友好的 🤗 transformers 生态系统中,为**任何** timm 模型实现闪电般的推理速度、快速量化、torch.compile 加速和轻松微调。

隆重介绍 TimmWrapper——一个简单而强大的工具,释放了这一潜力。

在这篇文章中,我们将涵盖

  • timm 集成的工作原理及其为何能改变游戏规则。
  • 如何将 timm 模型与 🤗 transformers 集成。
  • 实践示例:pipeline、量化、微调等。

要跟随本博客文章,请运行以下命令安装最新版本的 transformerstimm

pip install -Uq transformers timm

查看包含所有代码示例和 notebook 的完整代码库:🔗 TimmWrapper 示例

什么是 timm?

PyTorch Image Models (timm) 库提供了丰富的最先进的计算机视觉模型,以及有用的层、工具、优化器和数据增强。截至本文撰写时,它在 GitHub 上拥有超过 32K 颗星,每日下载量超过 200K,是图像分类和特征提取(用于目标检测、分割、图像搜索等下游任务)的首选资源。

timm 拥有涵盖各种架构的预训练模型,简化了计算机视觉从业者的工作流程。

为何使用 timm 集成?

虽然 🤗 transformers 支持多种视觉模型,但 timm 提供了更广泛的集合,包括许多在 transformers 中不可用的移动端友好和高效的模型。

timm 集成弥补了这一差距,带来了两全其美的优势

  • Pipeline API 支持:轻松将任何 timm 模型插入到高级 transformers pipeline 中,以实现流线型推理。
  • 🧩 与 Auto 类兼容:虽然 timm 模型本身与 transformers 不兼容,但此集成使其能够与 Auto 类 API 无缝协作。
  • 快速量化:只需约 5 行代码,您就可以量化**任何** timm 模型以进行高效推理。
  • 🎯 使用 Trainer API 进行微调:使用 Trainer API 微调 timm 模型,甚至可以与低秩自适应 (LoRA) 等适配器集成。
  • 🔁 返回 timm:在 timm 中再次使用微调后的模型。
  • 🚀 Torch Compile 加速:利用 torch.compile 优化推理时间。

Pipeline API:使用 timm 模型进行图像分类

timm 集成的一个突出特点是它允许您利用 🤗 pipeline APIpipeline API 抽象了许多复杂性,使得加载预训练模型、执行推理和查看结果变得非常简单,只需几行代码即可完成。

让我们看看如何将 transformers pipeline 与 MobileNetV4 一起使用。该架构没有原生的 transformers 实现,但可以轻松地从 timm 中使用

from transformers import pipeline
import requests

# Load the image classification pipeline with a timm model
image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")

# URL of the image to classify
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"

# Perform inference
outputs = image_classifier(url)

# Print the results
for output in outputs:
    print(f"Label: {output['label'] :20} Score: {output['score'] :0.2f}")

输出:

Device set to use cpu
Label: tabby, tabby cat     Score: 0.69
Label: tiger cat            Score: 0.21
Label: Egyptian cat         Score: 0.02
Label: bee                  Score: 0.00
Label: marmoset             Score: 0.00

Gradio 集成:构建食物分类器演示 🍣

想要快速创建一个用于图像分类的交互式 Web 应用吗?Gradio 使您能够用最少的代码构建一个用户友好的界面。让我们将 Gradiopipeline API 结合起来,使用一个微调过的 timm ViT 模型来分类食物图像(我们将在后面的章节中介绍微调)。

以下是如何使用 timm 模型快速设置一个演示

import gradio as gr
from transformers import pipeline

# Load the image classification pipeline using a timm model
pipe = pipeline(
    "image-classification",
    model="ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
)

def classify(image):
    return pipe(image)[0]["label"]

demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=[["./sushi.png", "sushi"]]
)

demo.launch()

这是一个托管在 Hugging Face Spaces 上的实时示例。您可以直接在浏览器中测试!

Auto 类:简化模型加载

🤗 transformers 库提供了 Auto 类 来抽象化加载模型和处理器的复杂性。通过 TimmWrapper,您可以使用 AutoModelForImageClassificationAutoImageProcessor 轻松加载任何 timm 模型。

这是一个快速示例

from transformers import (
    AutoModelForImageClassification,
    AutoImageProcessor,
)
from transformers.image_utils import load_image

image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)

# Use Auto classes to load a timm model
checkpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

# Check the types
print(type(image_processor))  # TimmWrapperImageProcessor
print(type(model))            # TimmWrapperForImageClassification

运行量化的 timm 模型

量化是一种强大的技术,可以减小模型大小并加速推理,尤其适用于资源受限的设备。通过 timm 集成,您可以使用 bitsandbytes 中的 BitsAndBytesConfig,只需几行代码即可即时量化任何 timm 模型。

以下是量化一个 timm 模型是多么简单

from transformers import TimmWrapperForImageClassification, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"

model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")
model_8bit = TimmWrapperForImageClassification.from_pretrained(
    checkpoint,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)
original_footprint = model.get_memory_footprint()
quantized_footprint = model_8bit.get_memory_footprint()

print(f"Original model size: {original_footprint / 1e6:.2f} MB")
print(f"Quantized model size: {quantized_footprint / 1e6:.2f} MB")
print(f"Reduction: {(original_footprint - quantized_footprint) / original_footprint * 100:.2f}%")

输出

Original model size: 346.27 MB  
Quantized model size: 88.20 MB  
Reduction: 74.53%  

量化模型在推理时的性能与全精度模型几乎完全相同

模型 标签 准确率
原始模型 遥控器,遥控 0.35%
量化模型 遥控器,遥控 0.33%

timm 模型的监督式微调

使用 🤗 transformersTrainer API 微调 timm 模型是直接且高度灵活的。您可以使用 Trainer 类在自定义数据集上微调您的模型,该类处理训练循环、日志记录和评估。此外,您可以使用 LoRA (低秩自适应) 进行微调,以更少的参数高效地进行训练。

本节对标准微调和 LoRA 微调进行了简要概述,并提供了完整代码的链接。

使用 Trainer API 进行标准微调

Trainer API 使得用最少的代码设置训练变得容易。以下是微调设置的概要

from transformers import TrainingArguments, Trainer

# Define training arguments
training_args = TrainingArguments(
    output_dir="my_model_output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    load_best_model_at_end=True,
    push_to_hub=True,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Start training
trainer.train()

这种方法的显著之处在于,它反映了用于原生 transformers 模型的完全相同的工作流程,从而在不同模型类型之间保持了一致性。

这意味着您可以使用熟悉的 Trainer API 不仅微调 Transformers 模型,还可以微调**任何 timm 模型**——将 timm 库中强大的模型引入 Hugging Face 生态系统,只需进行最少的调整。这极大地拓宽了您可以使用相同可信赖的工具和工作流程进行微调的模型范围。

模型示例
Food-101 上微调的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101

LoRA 微调以实现高效训练

LoRA (低秩自适应) 允许您通过仅训练少量额外参数而不是完整的模型权重来高效地微调大型模型。这使得微调更快,并允许使用消费级硬件。您可以使用 PEFT 库通过 LoRA 微调一个 timm 模型。

以下是您可以如何设置它

from peft import LoraConfig, get_peft_model

model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["qkv"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["head"],
)

# Wrap the model with PEFT
lora_model = get_peft_model(model, lora_config)

lora_model.print_trainable_parameters()

使用 LoRA 的可训练参数

trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.77%

模型示例
Food-101 上进行 LoRA 微调的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101

LoRA 只是您可以应用于 timm 模型的众多高效适配器微调方法中的一个例子。timm 与 🤗 生态系统的集成为您开启了各种**参数高效微调 (PEFT)** 技术的大门,让您可以选择最适合您应用场景的方法。

使用 LoRA 微调模型进行推理

一旦模型经过 LoRA 微调,我们仅将适配器权重推送到 Hugging Face Hub。本节将帮助您下载适配器权重,将适配器权重与基础模型合并,然后进行推理。

from peft import PeftModel, PeftConfig

repo_name = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101"
config = PeftConfig.from_pretrained(repo_name)

model = AutoModelForImageClassification.from_pretrained(
    config.base_model_name_or_path,
    label2id=label2id,
    num_labels=num_labels,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
inference_model = PeftModel.from_pretrained(model, repo_name)

# Make prediction with the model

image of sushi with prediction from a fine tuned model

双向集成

Ross (timm的创建者) 最喜欢的一个功能是,这种集成保持了完整的“双向”兼容性。也就是说,使用包装器,人们可以使用 transformerTrainer 在新数据集上微调 timm 模型,将结果模型发布到 Hugging Face hub,然后再次使用 timm.create_model('hf-hub:my_org/my_fine_tuned_model', pretrained=True)timm 中加载微调后的模型。

让我们看看如何用 timm 加载我们微调过的模型 ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101

checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"

config = AutoConfig.from_pretrained(checkpoint)

model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True) # Load the model with timm
model = model.eval()

image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")

data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(image).unsqueeze(0))

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f"Label: {config.id2label[idx.item()] :20} Score: {prob/100 :0.2f}%")

输出

Label: sushi                Score: 0.98%
Label: spring_rolls         Score: 0.01%
Label: sashimi              Score: 0.00%
Label: club_sandwich        Score: 0.00%
Label: cannoli              Score: 0.00%

Torch Compile:即时加速

在 PyTorch 2.0 中使用 torch.compile,您只需一行代码即可编译模型,从而实现更快的推理速度timm 集成完全兼容 torch.compile。以下是一个快速基准测试,用于比较使用 TimmWrapper 时有无 torch.compile 的推理时间。

# Load the model and input
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to(device)
processed_input = image_processor(image, return_tensors="pt").to(device)

# Benchmark function
def run_benchmark(model, input_data, warmup_runs=5, benchmark_runs=300):
    # Warm-up phase
    model.eval()
    with torch.no_grad():
        for _ in range(warmup_runs):
            _ = model(**input_data)

    # Benchmark phase
    times = []
    with torch.no_grad():
        for _ in range(benchmark_runs):
            start_time = time.perf_counter()
            _ = model(**input_data)
            if device.type == "cuda":
                torch.cuda.synchronize(device=device)  # Ensure synchronization for CUDA
            times.append(time.perf_counter() - start_time)

    avg_time = sum(times) / benchmark_runs
    return avg_time

# Run benchmarks
time_no_compile = run_benchmark(model, processed_input)
compiled_model = torch.compile(model).to(device)
time_compile = run_benchmark(compiled_model, processed_input)

# Results
print(f"Without torch.compile: {time_no_compile:.4f} s")
print(f"With torch.compile: {time_compile:.4f} s")

compile timing

总结

timm 与 transformers 的集成为利用最先进的视觉模型开辟了新的大门,且只需最少的努力。无论您是想进行微调、量化,还是仅仅运行推理,这种集成都提供了一个统一的 API 来简化您的工作流程。

立即开始探索,解锁计算机视觉的新可能!

致谢

我们要向在 Transformers PR #34564 中促成此次集成的各位同仁表示衷心的感谢。排名不分先后,衷心感谢 Pavel Iakubovskii、Ross Wightman、Lysandre Debut、Pablo Montalvo、Arthur Zucker 和 Amy Roberts 所做的杰出工作。你们的共同努力使这个想法变成了现实,让每个人都能享受到这个功能!

社区

文章作者
此评论已被隐藏

对此非常兴奋,谢谢!我们正准备切换到 timm,这使得它变得更容易了!

也许一个简单天真的问题,我正在尝试编写一个演示训练脚本,从以下位置加载基础模型
TimmWrapperForImageClassification.from_pretrained("timm/mobilenetv4_conv_medium.e500_r256_in1k").to("cuda")

但随后在 food101 数据集上进行训练,只是为了说明在新自定义数据集上进行训练。训练正常,但推理返回的是动物名称作为标签。

当我加载微调模型时,是否应该设置 label2id, num_labels, id2label, 等参数?它似乎在训练期间将数据存储在某个地方,但 TrainingArguments 不允许我设置 TypeError: TrainingArguments.__init__() got an unexpected keyword argument 'label2id'

@davidrs 嗯,可能是标签处理有问题,在发布前对集成做了一个改动,以保持标签与 timm 的使用兼容(保留 label_names 字段而不是 id2label/label2id),实际上这两种情况混合在一起了,而且 @ariG23498 的许多微调都有 id2label,尽管在发布时我被告知它应该生成 label_names...

你有一个我可以看的公开模型吗?你是在 Transformers 中使用示例图像分类脚本 (https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) 还是像上面的例子一样在自定义脚本/notebook 中直接使用 Trainer?

我的错,我错过了博客顶部的链接,那里有更完整的代码示例,也许这能帮到我 https://github.com/ariG23498/timm-wrapper-examples/blob/main/%2304_sft.ipynb

在自定义脚本中直接使用 Trainer,我做了一个 Colab notebook 来说明我目前正在尝试的东西,在最后的预测测试中,标签不是食物。
https://colab.research.google.com/drive/14jTpetYR61B6EVoJ6o8_B8gi6-SiizCA?usp=sharing

@davidrs 示例中用目标数据集重置分类器/标签的这部分很重要。

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

如果你不这样做,我相信它会用原始标签、一个 imagenet-1k (1000) 类分类器等被推送。不过,我猜新数据集中最低的 'n' 个类别会被目标微调(如果类别更少,如果新数据集有更多类别,它会崩溃)。

顺便说一下,我刚在 pipeline 中发现一个 bug,由于一个 bug (https://github.com/huggingface/transformers/pull/35848),它默认应用了 sigmoid 而不是 softmax,所以如果你想要 softmax 概率,请添加 function_to_apply='softmax'……这不特定于 timm 集成,而且看起来已经存在一段时间了。我确认了如果你像上面那样设置标签,微调后的 timm 模型将用正确的标签进行预测,并且也应该用这些标签推送到 hub...

如何与 Optimum 集成并加载模型的 onnx 版本

注册登录 以发表评论