开源 AI 食谱文档
使用 TRL 在消费级 GPU 上微调 SmolVLM
并获得增强的文档体验
开始使用
使用 TRL 在消费级 GPU 上微调 SmolVLM
作者: Sergio Paniego
在本教程中,我们将演示如何利用强大的 Transformer 强化学习库 (TRL),使用 Hugging Face 生态系统微调一个 Vision Language Model (VLM)。本分步指南将使您能够自定义 VLM 以完成特定任务,即使是在消费级 GPU 上也能实现。
🌟 模型与数据集概述
在本笔记中,我们将使用 ChartQA 数据集对 SmolVLM 模型进行微调。SmolVLM 是一款高性能、内存高效的模型,是此任务的理想选择。ChartQA 数据集包含各种图表类型的图像以及问答对,为增强模型的视觉问答 (VQA) 能力提供了宝贵的资源。这些技能对于数据分析、商业智能和教育工具等一系列实际应用至关重要。
💡 注意:我们正在微调的 instruct 模型已经在此数据集上进行了训练,因此它熟悉数据。然而,这对于理解微调技术来说是一个宝贵的教育练习。有关用于训练此模型的数据集的完整列表,请查看此文档。
📖 附加资源
通过这些资源扩展您对视觉语言模型和相关工具的知识
- 食谱中的多模态教程: 探索多模态模型的实用教程,包括 RAG 管道和微调。我们已经有使用 TRL 微调 VLM 的教程,请参考它获取更多详细信息。
- TRL 社区教程: 包含大量教程,可加深您对 TRL 及其应用的理解。
有了这些资源,您将能够更深入地探索 VLM 的世界,并突破它们的极限!
本笔记使用 L4 GPU 进行测试。
1. 安装依赖项
让我们先安装微调所需的基本库吧!🚀
!pip install -U -q transformers trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.53.0.dev0, trl==0.20.0.dev0, datasets==3.6.0, bitsandbytes==0.46.0, peft==0.15.2, accelerate==1.8.1
!pip install -q flash-attn --no-build-isolation
使用您的 Hugging Face 账户进行身份验证,以便直接从本 Notebook 保存和分享您的模型 🗝️。
from huggingface_hub import notebook_login
notebook_login()
2. 加载数据集 📁
我们将加载 HuggingFaceM4/ChartQA 数据集,该数据集提供图表图像以及相应的问答对——非常适合微调视觉问答模型。
我们将创建一个系统消息,使 VLM 充当图表分析专家,提供关于图表图像的简洁答案。
system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""
我们将数据集格式化为聊天机器人结构,每次交互都包含系统消息、图像、用户查询和答案。
💡有关使用此模型的更多提示,请查看模型卡。
def format_data(sample):
return [
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": sample["image"],
},
{
"type": "text",
"text": sample["query"],
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["label"][0]}],
},
]
出于教育目的,我们将只加载数据集中每个分割的 10%。在实际场景中,您将加载整个数据集。
from datasets import load_dataset
dataset_id = "HuggingFaceM4/ChartQA"
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train[:10%]", "val[:10%]", "test[:10%]"])
让我们看看数据集的结构。它包括一个图像、一个查询、一个标签(答案)和一个我们将丢弃的第四个特征。
train_dataset
现在,让我们使用聊天机器人结构格式化数据。这将为模型设置交互。
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]
train_dataset[200]
3. 加载模型并检查性能!🤔
现在我们已经加载了数据集,是时候加载 HuggingFaceTB/SmolVLM-Instruct 模型了,这是一个 2B 参数的视觉语言模型 (VLM),它提供了最先进 (SOTA) 的性能,同时在内存使用方面效率很高。
为了更广泛地比较最先进的 VLM,请探索 WildVision Arena 和 OpenVLM 排行榜,在那里您可以找到在各种基准测试中表现最佳的模型。
import torch
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
model_id = "HuggingFaceTB/SmolVLM-Instruct"
接下来,我们将加载模型和分词器,为推理做准备。
model = Idefics3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
为了评估模型的性能,我们将使用数据集中的一个样本。首先,让我们检查此样本的内部结构,以了解数据的组织方式。
train_dataset[1]
我们将使用不带系统消息的样本来评估 VLM 的原始理解能力。这是我们将使用的输入:
train_dataset[1][1:2]
现在,让我们看一下与样本对应的图表。您能根据视觉信息回答查询吗?
>>> train_dataset[1][1]["content"][0]["image"]
让我们创建一个方法,该方法将模型、处理器和样本作为输入,以生成模型的答案。这将使我们能够简化推理过程并轻松评估 VLM 的性能。
def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
# Prepare the text input by applying the chat template
text_input = processor.apply_chat_template(
sample[1:2], add_generation_prompt=True # Use the sample without the system message
)
image_inputs = []
image = sample[1]["content"][0]["image"]
if image.mode != "RGB":
image = image.convert("RGB")
image_inputs.append([image])
# Prepare the inputs for the model
model_inputs = processor(
# text=[text_input],
text=text_input,
images=image_inputs,
return_tensors="pt",
).to(
device
) # Move inputs to the specified device
# Generate text with the model
generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
# Trim the generated ids to remove the input ids
trimmed_generated_ids = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]
# Decode the output text
output_text = processor.batch_decode(
trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] # Return the first decoded output text
output = generate_text_from_sample(model, processor, train_dataset[1])
output
看来模型引用了错误的行,导致它失败。为了提高其性能,我们可以使用更多相关数据对模型进行微调,以确保它更好地理解上下文并提供更准确的响应。
移除模型并清理 GPU
在下一节中进行模型训练之前,让我们清除当前变量并清理 GPU 以释放资源。
>>> import gc
>>> import time
>>> def clear_memory():
... # Delete variables if they exist in the current global scope
... if "inputs" in globals():
... del globals()["inputs"]
... if "model" in globals():
... del globals()["model"]
... if "processor" in globals():
... del globals()["processor"]
... if "trainer" in globals():
... del globals()["trainer"]
... if "peft_model" in globals():
... del globals()["peft_model"]
... if "bnb_config" in globals():
... del globals()["bnb_config"]
... time.sleep(2)
... # Garbage collection and clearing CUDA memory
... gc.collect()
... time.sleep(2)
... torch.cuda.empty_cache()
... torch.cuda.synchronize()
... time.sleep(2)
... gc.collect()
... time.sleep(2)
... print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
... print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
>>> clear_memory()
GPU allocated memory: 0.01 GB GPU reserved memory: 0.06 GB
4. 使用 TRL 微调模型
4.1 加载量化模型进行训练 ⚙️
接下来,我们将使用 bitsandbytes 加载量化模型。如果您想了解更多关于量化的信息,请查看这篇博客文章或这篇文章。
from transformers import BitsAndBytesConfig
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
# Load model and tokenizer
model = Idefics3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=bnb_config,
_attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
4.2 设置 QLoRA 和 SFTConfig 🚀
接下来,我们将为我们的训练设置配置 QLoRA。QLoRA 通过减少内存占用,实现大型模型的高效微调。与使用低秩近似的传统 LoRA 不同,QLoRA 进一步量化了 LoRA 适配器权重,从而实现更低的内存使用和更快的训练。
为了提高效率,我们还可以在 QLoRA 实现期间利用分页优化器或8 位优化器。这种方法提高了内存效率并加快了计算速度,使其成为优化我们模型而又不牺牲性能的理想选择。
>>> from peft import LoraConfig, get_peft_model
>>> # Configure LoRA
>>> peft_config = LoraConfig(
... r=8,
... lora_alpha=8,
... lora_dropout=0.1,
... target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
... use_dora=True,
... init_lora_weights="gaussian",
... )
>>> # Apply PEFT model adaptation
>>> peft_model = get_peft_model(model, peft_config)
>>> # Print trainable parameters
>>> peft_model.print_trainable_parameters()
trainable params: 11,269,248 || all params: 2,257,542,128 || trainable%: 0.4992
我们将使用监督微调 (SFT) 来提高模型在特定任务上的性能。为此,我们将使用 TRL 库中的 SFTConfig 类定义训练参数。SFT 利用标记数据帮助模型生成更准确的响应,使其适应任务。这种方法增强了模型理解和更有效地响应视觉查询的能力。
from trl import SFTConfig
# Configure training arguments using SFTConfig
training_args = SFTConfig(
output_dir="smolvlm-instruct-trl-sft-ChartQA",
num_train_epochs=1,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=50,
learning_rate=1e-4,
weight_decay=0.01,
logging_steps=25,
save_strategy="steps",
save_steps=25,
save_total_limit=1,
optim="adamw_torch_fused",
bf16=True,
push_to_hub=True,
report_to="tensorboard",
remove_unused_columns=False,
gradient_checkpointing=True,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
)
4.3 训练模型 🏃
为确保数据在训练期间正确地为模型构建,我们需要定义一个 collator 函数。此函数将处理数据集输入的格式化和批处理,确保数据与训练正确对齐。
👉 更多详情,请查看官方 TRL 示例脚本。
image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index("<image>")
]
def collate_fn(examples):
texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
image_inputs = []
for example in examples:
image = example[1]["content"][0]["image"]
if image.mode != "RGB":
image = image.convert("RGB")
image_inputs.append([image])
batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 # Mask padding tokens in labels
labels[labels == image_token_id] = -100 # Mask image token IDs in labels
batch["labels"] = labels
return batch
现在,我们将定义 SFTTrainer,它是 transformers.Trainer 类的包装器,并继承其属性和方法。当提供 PeftConfig 对象时,此类别通过正确初始化 PeftModel 来简化微调过程。通过使用 SFTTrainer
,我们可以有效地管理训练流程,并确保我们的视觉语言模型获得流畅的微调体验。
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=collate_fn,
peft_config=peft_config,
processing_class=processor.tokenizer,
)
是时候训练模型了!🎉
trainer.train()
让我们保存结果 💾
trainer.save_model(training_args.output_dir)
5. 测试微调模型 🔍
现在我们的视觉语言模型 (VLM) 已经微调完成,是时候评估其性能了!在本节中,我们将使用 ChartQA 数据集中的示例来测试模型,以评估其根据图表图像回答问题的准确性。让我们深入了解结果,看看它的表现如何!🚀
让我们清理 GPU 内存以确保最佳性能 🧹
>>> clear_memory()
GPU allocated memory: 16.34 GB GPU reserved memory: 18.69 GB
我们将使用与之前相同的流程重新加载基础模型。
model = Idefics3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
我们将把训练好的适配器附加到预训练模型。该适配器包含训练期间进行的微调调整,使基础模型能够利用新知识,同时保持其核心参数不变。通过集成适配器,我们在不改变模型原始结构的情况下增强了模型的功能。
adapter_path = "sergiopaniego/smolvlm-instruct-trl-sft-ChartQA"
model.load_adapter(adapter_path)
让我们在一个未见过的样本上评估模型。
test_dataset[20][:2]
>>> test_dataset[20][1]["content"][0]["image"]
output = generate_text_from_sample(model, processor, test_dataset[20])
output
模型已成功学会按照数据集中指定的方式响应查询。我们已经实现了目标!🎉✨
💻 我已经开发了一个用于测试模型的示例应用程序,您可以在这里找到。您可以轻松将其与另一个包含预训练模型的 Space 进行比较,该 Space 可在这里获取。
from IPython.display import IFrame
IFrame(src="https://sergiopaniego-smolvlm-trl-sft-chartqa.hf.space", width=1000, height=800)
6. 继续学习之旅 🧑🎓️
为了进一步提升您使用多模态模型的技能,我建议您查看本笔记开头分享的资源,或重新访问使用 Hugging Face 生态系统 (TRL) 微调视觉语言模型 (Qwen2-VL-7B) 中同名的部分。
这些资源将帮助您加深对多模态学习的知识和专业技能。
< > 在 GitHub 上更新