开源 AI 食谱文档
在消费级 GPU 上使用 TRL 微调 SmolVLM
并获得增强的文档体验
开始使用
在消费级 GPU 上使用 TRL 微调 SmolVLM
在本食谱中,我们将演示如何使用 Hugging Face 生态系统微调一个微小的 🤏 视觉语言模型 (VLM),利用强大的 Transformer 强化学习库 (TRL)。这份分步指南将使您能够为特定任务定制 VLM,即使在消费级 GPU 上也能实现。
🌟 模型和数据集概述
在本笔记本中,我们将使用 SmolVLM 模型和 ChartQA 数据集进行微调。SmolVLM 是一款高性能且内存高效的模型,使其成为此任务的理想选择。ChartQA 数据集包含各种图表类型的图像,并配有问答对,为增强模型的视觉问答 (VQA) 能力提供了宝贵的资源。这些技能对于一系列实际应用至关重要,包括数据分析、商业智能和教育工具。
💡 注意: 我们正在微调的 instruct 模型已经在这个数据集上进行了训练,因此它对数据很熟悉。但是,这可以作为理解微调技术的宝贵教育练习。有关用于训练此模型的完整数据集列表,请查看此文档。
📖 更多资源
通过以下资源扩展您对视觉语言模型和相关工具的知识
- 食谱中的多模态食谱: 探索多模态模型的实用食谱,包括 RAG 管道和微调。我们已经有使用 TRL 微调 VLM 的食谱,请参考它了解更多详情。
- TRL 社区教程: 大量教程可加深您对 TRL 及其应用的理解。
有了这些资源,您将能够更深入地探索 VLM 的世界,并突破它们可以实现的界限!
此笔记本使用 L4 GPU 进行测试。
1. 安装依赖项
让我们开始安装微调所需的必要库!🚀
!pip install -U -q transformers trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.46.3, trl==0.12.1, datasets==3.1.0, bitsandbytes==0.45.0, peft==0.13.2, accelerate==1.1.1
!pip install -q flash-attn --no-build-isolation
使用您的 Hugging Face 帐户进行身份验证,以便直接从该笔记本保存和共享您的模型 🗝️。
from huggingface_hub import notebook_login
notebook_login()
2. 加载数据集 📁
我们将加载 HuggingFaceM4/ChartQA 数据集,该数据集提供图表图像以及相应的问答对——非常适合微调视觉问答模型。
我们将创建一个系统消息,使 VLM 充当图表分析专家,提供关于图表图像的简明答案。
system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""
我们将数据集格式化为聊天机器人结构,其中包含每个交互的系统消息、图像、用户查询和答案。
💡有关使用此模型的更多技巧,请查看模型卡片。
def format_data(sample):
return [
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": sample["image"],
},
{
"type": "text",
"text": sample["query"],
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["label"][0]}],
},
]
出于教育目的,我们将仅加载数据集中每个拆分的 10%。在实际应用场景中,您将加载整个数据集。
from datasets import load_dataset
dataset_id = "HuggingFaceM4/ChartQA"
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=["train[:10%]", "val[:10%]", "test[:10%]"])
让我们看看数据集结构。它包括图像、查询、标签(答案)和我们将要丢弃的第四个特征。
train_dataset
现在,让我们使用聊天机器人结构格式化数据。这将为模型设置交互。
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]
train_dataset[200]
3. 加载模型并检查性能!🤔
现在我们已经加载了数据集,是时候加载 HuggingFaceTB/SmolVLM-Instruct,这是一个 2B 参数视觉语言模型 (VLM),它在内存使用方面高效的同时,还提供最先进 (SOTA) 的性能。
要更广泛地比较最先进的 VLM,请浏览 WildVision Arena 和 OpenVLM Leaderboard,您可以在其中找到各种基准测试中性能最佳的模型。
import torch
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
model_id = "HuggingFaceTB/SmolVLM-Instruct"
接下来,我们将加载模型和 tokenizer 以准备进行推理。
model = Idefics3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
为了评估模型的性能,我们将使用数据集中的一个样本。首先,让我们检查此样本的内部结构,以了解数据的组织方式。
train_dataset[1]
我们将使用不带系统消息的样本来评估 VLM 的原始理解能力。这是我们将使用的输入
train_dataset[1][1:2]
现在,让我们看一下与样本对应的图表。您可以根据视觉信息回答查询吗?
>>> train_dataset[1][1]["content"][0]["image"]
让我们创建一个方法,该方法接受模型、处理器和样本作为输入,以生成模型的答案。这将使我们能够简化推理过程并轻松评估 VLM 的性能。
def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
# Prepare the text input by applying the chat template
text_input = processor.apply_chat_template(
sample[1:2], add_generation_prompt=True # Use the sample without the system message
)
image_inputs = []
image = sample[1]["content"][0]["image"]
if image.mode != "RGB":
image = image.convert("RGB")
image_inputs.append([image])
# Prepare the inputs for the model
model_inputs = processor(
# text=[text_input],
text=text_input,
images=image_inputs,
return_tensors="pt",
).to(
device
) # Move inputs to the specified device
# Generate text with the model
generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
# Trim the generated ids to remove the input ids
trimmed_generated_ids = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]
# Decode the output text
output_text = processor.batch_decode(
trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] # Return the first decoded output text
output = generate_text_from_sample(model, processor, train_dataset[1])
output
模型似乎引用了错误的线,导致其失败。为了提高其性能,我们可以使用更多相关数据微调模型,以确保它更好地理解上下文并提供更准确的响应。
移除模型并清理 GPU
在下一节中继续训练模型之前,让我们清除当前变量并清理 GPU 以释放资源。
>>> import gc
>>> import time
>>> def clear_memory():
... # Delete variables if they exist in the current global scope
... if "inputs" in globals():
... del globals()["inputs"]
... if "model" in globals():
... del globals()["model"]
... if "processor" in globals():
... del globals()["processor"]
... if "trainer" in globals():
... del globals()["trainer"]
... if "peft_model" in globals():
... del globals()["peft_model"]
... if "bnb_config" in globals():
... del globals()["bnb_config"]
... time.sleep(2)
... # Garbage collection and clearing CUDA memory
... gc.collect()
... time.sleep(2)
... torch.cuda.empty_cache()
... torch.cuda.synchronize()
... time.sleep(2)
... gc.collect()
... time.sleep(2)
... print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
... print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
>>> clear_memory()
GPU allocated memory: 0.01 GB GPU reserved memory: 0.06 GB
4. 使用 TRL 微调模型
4.1 加载用于训练的量化模型 ⚙️
接下来,我们将使用 bitsandbytes 加载量化模型。如果您想了解有关量化的更多信息,请查看 这篇博文 或 这篇博文。
from transformers import BitsAndBytesConfig
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
# Load model and tokenizer
model = Idefics3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=bnb_config,
_attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
4.2 设置 QLoRA 和 SFTConfig 🚀
接下来,我们将为我们的训练设置配置 QLoRA。QLoRA 通过减少内存占用量,可以高效地微调大型模型。与使用低秩近似的传统 LoRA 不同,QLoRA 进一步量化 LoRA 适配器权重,从而实现更低的内存使用率和更快的训练速度。
为了提高效率,我们还可以在 QLoRA 实施期间利用分页优化器或 8 位优化器。这种方法增强了内存效率并加快了计算速度,使其成为优化我们的模型而不牺牲性能的理想选择。
>>> from peft import LoraConfig, get_peft_model
>>> # Configure LoRA
>>> peft_config = LoraConfig(
... r=8,
... lora_alpha=8,
... lora_dropout=0.1,
... target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
... use_dora=True,
... init_lora_weights="gaussian",
... )
>>> # Apply PEFT model adaptation
>>> peft_model = get_peft_model(model, peft_config)
>>> # Print trainable parameters
>>> peft_model.print_trainable_parameters()
trainable params: 11,269,248 || all params: 2,257,542,128 || trainable%: 0.4992
我们将使用监督式微调 (SFT) 来提高模型在特定任务上的性能。为了实现这一点,我们将使用 SFTConfig 类从 TRL 库定义训练参数。SFT 利用标记数据来帮助模型生成更准确的响应,使其适应任务。这种方法增强了模型理解和更有效地响应视觉查询的能力。
from trl import SFTConfig
# Configure training arguments using SFTConfig
training_args = SFTConfig(
output_dir="smolvlm-instruct-trl-sft-ChartQA",
num_train_epochs=1,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=50,
learning_rate=1e-4,
weight_decay=0.01,
logging_steps=25,
save_strategy="steps",
save_steps=25,
save_total_limit=1,
optim="adamw_torch_fused",
bf16=True,
push_to_hub=True,
report_to="tensorboard",
remove_unused_columns=False,
gradient_checkpointing=True,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
)
4.3 训练模型 🏃
为了确保在训练期间数据结构正确,我们需要定义一个 collator 函数。此函数将处理数据集输入的格式化和批处理,确保数据正确对齐以进行训练。
👉 有关更多详细信息,请查看官方 TRL 示例脚本。
image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index("<image>")
]
def collate_fn(examples):
texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
image_inputs = []
for example in examples:
image = example[1]["content"][0]["image"]
if image.mode != "RGB":
image = image.convert("RGB")
image_inputs.append([image])
batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 # Mask padding tokens in labels
labels[labels == image_token_id] = -100 # Mask image token IDs in labels
batch["labels"] = labels
return batch
现在,我们将定义 SFTTrainer,它是 transformers.Trainer 类的包装器,并继承其属性和方法。当提供 PeftModel 对象时,此类通过正确初始化 PeftConfig 对象来简化微调过程。通过使用 SFTTrainer
,我们可以有效地管理训练工作流程,并确保我们的视觉语言模型获得流畅的微调体验。
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=collate_fn,
peft_config=peft_config,
tokenizer=processor.tokenizer,
)
模型训练时间到!🎉
trainer.train()
让我们保存结果 💾
trainer.save_model(training_args.output_dir)
5. 测试微调后的模型 🔍
现在我们的视觉语言模型 (VLM) 已经过微调,是时候评估其性能了!在本节中,我们将使用 ChartQA 数据集中的示例来测试模型,以评估其根据图表图像回答问题的准确程度。让我们深入了解结果,看看它的表现如何!🚀
让我们清理 GPU 内存以确保最佳性能 🧹
>>> clear_memory()
GPU allocated memory: 16.34 GB GPU reserved memory: 18.69 GB
我们将使用与之前相同的管道重新加载基础模型。
model = Idefics3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)
我们将训练好的适配器附加到预训练模型。此适配器包含训练期间进行的微调调整,使基础模型能够利用新知识,同时保持其核心参数不变。通过集成适配器,我们增强了模型的功能,而无需更改其原始结构。
adapter_path = "sergiopaniego/smolvlm-instruct-trl-sft-ChartQA"
model.load_adapter(adapter_path)
让我们在一个未见过的样本上评估模型。
test_dataset[20][:2]
>>> test_dataset[20][1]["content"][0]["image"]
output = generate_text_from_sample(model, processor, test_dataset[20])
output
模型已成功学习按照数据集中指定的查询进行响应。我们实现了目标!🎉✨
💻 我开发了一个示例应用程序来测试模型,您可以在此处找到。您可以轻松地将其与另一个具有预训练模型的 Space 进行比较,该 Space 在此处可用。
from IPython.display import IFrame
IFrame(src="https://sergiopaniego-smolvlm-trl-sft-chartqa.hf.space", width=1000, height=800)
6. 继续学习之旅 🧑🎓️
为了进一步提高您在多模态模型方面的技能,我建议查看本笔记本开头分享的资源,或重新访问 使用 Hugging Face 生态系统 (TRL) 微调视觉语言模型 (Qwen2-VL-7B) 中同名的章节。
这些资源将帮助您加深在多模态学习方面的知识和专业知识。
< > 在 GitHub 上更新