开源 AI 食谱文档
使用 TRL 微调 Granite Vision 3.1 2B
并获得增强的文档体验
开始使用
使用 TRL 微调 Granite Vision 3.1 2B
作者:Eli Schwartz
改编自 Sergio Paniego 的 Notebook
本食谱将帮助您微调 IBM 的 Granite Vision 3.1 2B 模型。它是一个轻量级但功能强大的模型,通过微调 Granite 语言模型 并结合图像和文本模态进行训练。我们将使用 Hugging Face 生态系统,利用强大的 Transformer 强化学习库 (TRL)。这份循序渐进的指南将帮助您为特定任务微调 Granite Vision,即使在消费级 GPU 上也能实现。
🌟 模型和数据集概览
在本 notebook 中,我们将使用 Granite Vision 模型和 Geometric Perception 数据集进行微调和评估,该数据集包含模型最初未训练的任务。Granite Vision 是一款高性能且内存高效的模型,非常适合针对新任务进行微调。Geometric Perception 数据集提供了各种几何图形的图像,这些图像编译自高中教科书,并配有问答对。
本 notebook 使用 A100 GPU 进行测试。
1. 安装依赖项
让我们开始安装微调所需的必要库!🚀
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -U -q trl datasets bitsandbytes peft accelerate
# Tested with transformers==4.49.0.dev0, trl==0.14.0, datasets==3.2.0, bitsandbytes==0.45.2, peft==0.14.0, accelerate==1.3.0
>>> !pip install -q flash-attn --no-build-isolation
>>> try:
... import flash_attn
... print("FlashAttention is installed")
... USE_FLASH_ATTENTION = True
>>> except ImportError:
... print("FlashAttention is not installed")
... USE_FLASH_ATTENTION = False
FlashAttention is not installed
2. 加载数据集 📁
我们将加载 Geometric Perception 数据集,该数据集提供各种几何图形的图像,这些图像编译自流行的高中教科书,并配有问答对。
我们将使用模型训练期间使用的原始系统提示。
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
出于教育目的,我们将仅在线段长度比较任务上进行训练和评估,该任务在数据集的 “predicate” 字段中指定。
from datasets import load_dataset
dataset_id = "euclid-multimodal/Geoperception"
dataset = load_dataset(dataset_id)
dataset_LineComparison = dataset["train"].filter(lambda x: x["predicate"] == "LineComparison")
train_test = dataset_LineComparison.train_test_split(test_size=0.5, seed=42)
让我们看一下数据集结构。它包括图像、问题、答案和 “predicate”,我们使用它来过滤数据集。
train_test
我们将数据集格式化为聊天机器人结构,每个交互都包含系统消息、图像、用户查询和答案。
💡有关使用此模型进行推理的更多技巧,请查看 模型卡片。
def format_data(sample):
return [
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": sample["image"],
},
{
"type": "text",
"text": sample["question"],
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["answer"]}],
},
]
现在,让我们使用聊天机器人结构格式化数据。这将为模型设置交互。
train_dataset = [format_data(x) for x in train_test["train"]]
test_dataset = [format_data(x) for x in train_test["test"]]
train_dataset[200]
3. 加载模型并检查性能!🤔
现在我们已经加载了数据集,是时候加载 IBM 的 Granite Vision 模型 了,这是一个 20 亿参数的视觉语言模型 (VLM),它在提供最先进 (SOTA) 性能的同时,还具有高效的内存使用率。
要更广泛地比较最先进的 VLM,请浏览 WildVision Arena 和 OpenVLM Leaderboard,您可以在其中找到各种基准测试中性能最佳的模型。
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
model_id = "ibm-granite/granite-vision-3.1-2b-preview"
接下来,我们将加载模型和分词器,为推理做准备。
model = AutoModelForVision2Seq.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2" if USE_FLASH_ATTENTION else None,
)
processor = AutoProcessor.from_pretrained(model_id)
为了评估模型的性能,我们将使用数据集中的一个样本。首先,让我们检查一下此样本的内部结构,以了解数据的组织方式。
test_idx = 20
sample = test_dataset[test_idx]
sample
现在,让我们看一下与样本对应的图像。您可以根据视觉信息回答查询吗?
>>> sample[1]["content"][0]["image"]
让我们创建一个方法,该方法接受模型、处理器和样本作为输入,以生成模型的答案。这将使我们能够简化推理过程并轻松评估 VLM 的性能。
def generate_text_from_sample(model, processor, sample, max_new_tokens=100, device="cuda"):
# Prepare the text input by applying the chat template
text_input = processor.apply_chat_template(
sample[:2], add_generation_prompt=True # Use the sample without the assistant response
)
image_inputs = []
image = sample[1]["content"][0]["image"]
if image.mode != "RGB":
image = image.convert("RGB")
image_inputs.append([image])
# Prepare the inputs for the model
model_inputs = processor(
# text=[text_input],
text=text_input,
images=image_inputs,
return_tensors="pt",
).to(
device
) # Move inputs to the specified device
# Generate text with the model
generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
# Trim the generated ids to remove the input ids
trimmed_generated_ids = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]
# Decode the output text
output_text = processor.batch_decode(
trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] # Return the first decoded output text
output = generate_text_from_sample(model, processor, sample) output
看来该模型无法比较未明确指定的线段长度。为了提高其性能,我们可以使用更多相关数据微调模型,以确保模型更好地理解上下文并提供更准确的响应。
移除模型并清理 GPU
在下一节中继续训练模型之前,让我们清除当前变量并清理 GPU 以释放资源。
>>> import gc
>>> import time
>>> def clear_memory():
... # Delete variables if they exist in the current global scope
... if "inputs" in globals():
... del globals()["inputs"]
... if "model" in globals():
... del globals()["model"]
... if "processor" in globals():
... del globals()["processor"]
... if "trainer" in globals():
... del globals()["trainer"]
... if "peft_model" in globals():
... del globals()["peft_model"]
... if "bnb_config" in globals():
... del globals()["bnb_config"]
... time.sleep(2)
... # Garbage collection and clearing CUDA memory
... gc.collect()
... time.sleep(2)
... torch.cuda.empty_cache()
... torch.cuda.synchronize()
... time.sleep(2)
... gc.collect()
... time.sleep(2)
... print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
... print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
>>> clear_memory()
GPU allocated memory: 0.01 GB GPU reserved memory: 0.02 GB
4. 使用 TRL 微调模型
4.1 加载用于训练的量化模型 ⚙️
接下来,我们将使用 bitsandbytes 加载量化模型。如果您想了解有关量化的更多信息,请查看 这篇博客文章 或 这篇。
from transformers import BitsAndBytesConfig
USE_QLORA = True
USE_LORA = True
if USE_QLORA:
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
llm_int8_skip_modules=["vision_tower", "lm_head"], # Skip problematic modules
llm_int8_enable_fp32_cpu_offload=True,
)
else:
bnb_config = None
# Load model and tokenizer
model = AutoModelForVision2Seq.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=bnb_config,
_attn_implementation="flash_attention_2" if USE_FLASH_ATTENTION else None,
)
processor = AutoProcessor.from_pretrained(model_id)
4.2 设置 QLoRA 和 SFTConfig 🚀
接下来,我们将为我们的训练设置配置 QLoRA。QLoRA 通过减少内存占用来实现大型模型的高效微调。与使用低秩近似的传统 LoRA 不同,QLoRA 进一步量化 LoRA 适配器权重,从而实现更低的内存使用率和更快的训练速度。
为了提高效率,我们还可以在 QLoRA 实施期间利用分页优化器或 8 位优化器。这种方法提高了内存效率并加快了计算速度,使其成为优化模型而不牺牲性能的理想选择。
if USE_LORA:
from peft import LoraConfig, get_peft_model
# Configure LoRA
peft_config = LoraConfig(
r=8,
lora_alpha=8,
lora_dropout=0.1,
target_modules=[name for name, _ in model.named_modules() if "language_model" in name and "_proj" in name],
use_dora=True,
init_lora_weights="gaussian",
)
# Apply PEFT model adaptation
# model = get_peft_model(model, peft_config)
model.add_adapter(peft_config)
model.enable_adapters()
model = get_peft_model(model, peft_config)
# Print trainable parameters
model.print_trainable_parameters()
else:
peft_config = None
我们将使用监督式微调 (SFT) 来提高模型在特定任务上的性能。为了实现这一点,我们将使用 TRL 库 中的 SFTConfig 类定义训练参数。SFT 利用标记数据来帮助模型生成更准确的响应,使其适应任务。这种方法增强了模型理解和更有效地响应视觉查询的能力。
from trl import SFTConfig
# Configure training arguments using SFTConfig
training_args = SFTConfig(
output_dir="./checkpoints/geoperception",
num_train_epochs=1,
# max_steps=30,
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
warmup_steps=10,
learning_rate=1e-4,
weight_decay=0.01,
logging_steps=10,
save_strategy="steps",
save_steps=20,
save_total_limit=1,
optim="adamw_torch_fused",
bf16=True,
push_to_hub=False,
report_to="none",
remove_unused_columns=False,
gradient_checkpointing=True,
dataset_text_field="",
dataset_kwargs={"skip_prepare_dataset": True},
)
4.3 训练模型 🏃
为了确保在训练期间数据结构正确,我们需要定义一个 collator 函数。此函数将处理数据集输入的格式化和批处理,确保数据正确对齐以进行训练。
👉 有关更多详细信息,请查看官方 TRL 示例脚本。
def collate_fn(examples):
texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
image_inputs = []
for example in examples:
image = example[1]["content"][0]["image"]
if image.mode != "RGB":
image = image.convert("RGB")
image_inputs.append([image])
batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
assistant_tokens = processor.tokenizer("<|assistant|>", return_tensors="pt")["input_ids"][0]
eos_token = processor.tokenizer("<|end_of_text|>", return_tensors="pt")["input_ids"][0]
for i in range(batch["input_ids"].shape[0]):
apply_loss = False
for j in range(batch["input_ids"].shape[1]):
if not apply_loss:
labels[i][j] = -100
if (j >= len(assistant_tokens) + 1) and torch.all(
batch["input_ids"][i][j + 1 - len(assistant_tokens) : j + 1] == assistant_tokens
):
apply_loss = True
if batch["input_ids"][i][j] == eos_token:
apply_loss = False
batch["labels"] = labels
return batch
现在,我们将定义 SFTTrainer,它是 transformers.Trainer 类的包装器,并继承其属性和方法。当提供 PeftConfig 对象时,此类通过正确初始化 PeftModel 来简化微调过程。通过使用 SFTTrainer
,我们可以有效地管理训练工作流程,并确保视觉语言模型的微调体验顺畅。
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=collate_fn,
peft_config=peft_config,
tokenizer=processor.tokenizer,
)
开始训练模型!🎉
trainer.train()
让我们保存结果 💾
trainer.save_model(training_args.output_dir)
5. 测试微调模型 🔍
现在我们的视觉语言模型 (VLM) 已经过微调,是时候评估其性能了!在本节中,我们将使用 ChartQA 数据集中的示例来测试模型,以评估其根据图表图像回答问题的准确性。让我们深入了解结果,看看它的表现如何!🚀
让我们清理 GPU 内存以确保最佳性能 🧹
>>> clear_memory()
GPU allocated memory: 0.02 GB GPU reserved memory: 0.19 GB
我们将使用与之前相同的管道重新加载基础模型。
model = AutoModelForVision2Seq.from_pretrained(
training_args.output_dir,
device_map="auto",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2" if USE_FLASH_ATTENTION else None,
)
processor = AutoProcessor.from_pretrained(model_id)
如果我们正在使用 LORA 适配器,我们将合并它们。
if USE_LORA:
from peft import PeftModel
model = PeftModel.from_pretrained(model, training_args.output_dir)
让我们在一个未见过的样本上评估模型。
test_idx = 20
sample = test_dataset[test_idx]
sample[1:]
>>> sample[1]["content"][0]["image"]
output = generate_text_from_sample(model, processor, sample) output