开源 AI 食谱文档
使用 TRL 微调视觉语言模型 (VLM) 以进行目标检测定位
并获得增强的文档体验
开始使用
使用 TRL 对 VLM 进行微调,以实现目标检测基础
作者: Sergio Paniego
🚨 警告:本笔记本资源密集,需要大量的计算能力。如果在 Colab 中运行,它将使用 A100 GPU。
🔍 您将学到什么
在本教程中,我们将演示如何使用 TRL 对 视觉语言模型 (VLM) 进行微调,以实现目标检测基础。
传统上,目标检测涉及在图像中识别预定义的一组类别(例如,“汽车”、“人物”、“狗”)。然而,随着 Grounding DINO、GLIP 或 OWL-ViT 等模型的出现,这一范式发生了转变,引入了开放式目标检测——使模型能够检测自然语言中描述的任何类别。
基础化更进一步,增加了上下文理解。除了仅仅检测“汽车”之外,基础化检测可以定位“左侧的汽车”或“树后的红色汽车”。这为目标检测提供了一种更细致、更强大的方法。
在本教程中,我们将逐步讲解如何为此任务微调 VLM。具体来说,我们将使用 Google 开发的 PaliGemma 2,这是一种开箱即支持目标检测的视觉语言模型。虽然并非所有 VLM 都默认提供检测功能,但本笔记本中的概念和步骤也可以适用于没有内置目标检测功能的模型。
为了训练我们的模型,我们将使用 RefCOCO,它是流行 COCO 数据集的扩展,专门为引用表达式理解而设计——即通过自然语言将目标检测与基础化相结合。
本教程还基于我最近发布的 这个 Space,它允许您在目标检测、关键点检测等目标理解任务上比较不同的 VLM。
📚 附加资源
在本笔记本的末尾,如果您有兴趣进一步探索该主题,您会找到更多资源。
1. 安装依赖项
让我们首先安装所需的依赖项
!pip install -Uq transformers datasets trl supervision albumentations
我们将登录我们的 Hugging Face 帐户,以访问受限模型并保存我们训练过的检查点。
您将需要一个访问 令牌 🗝️。
from huggingface_hub import notebook_login
notebook_login()
2. 📁 加载数据集
在此示例中,我们将使用 RefCOCO,这是一个包含基础目标检测注释的数据集——可实现更稳健、更具上下文意识的检测。
为了简单高效,我们将使用数据集的一个子集。
from datasets import load_dataset
refcoco_dataset = load_dataset("jxu124/refcoco", split="train[:5%]")
加载后,我们来看看里面有什么
refcoco_dataset
我们可以看到数据集包含有用的信息,例如 `bbox` 和 `captions` 列。在这种情况下,bboxes 遵循 `xyxy` 格式。
但是,图像本身无法直接从这些字段访问。有关图像源的更多详细信息,我们可以检查 `raw_image_info` 列。
refcoco_dataset[13]["raw_image_info"]
2.1 🖼️ 向数据集添加图像
虽然我们可以将每个示例链接到 COCO 数据集 中对应的图像,但我们将通过直接从 Flickr 下载图像来简化该过程。
但是,这种方法可能会导致某些图像缺失,因此我们需要相应地处理这些情况。
import json
import requests
from PIL import Image
from io import BytesIO
def add_image(example):
try:
raw_info = json.loads(example["raw_image_info"])
url = raw_info.get("flickr_url", None)
if url:
response = requests.get(url, timeout=10)
image = Image.open(BytesIO(response.content)).convert("RGB")
example["image"] = image
else:
example["image"] = None
except Exception as e:
print(f"Error loading image: {e}")
example["image"] = None
return example
refcoco_dataset_with_images = refcoco_dataset.map(add_image, desc="Adding image from flickr", num_proc=16)
太棒了!我们的图片现在已下载并准备就绪。
refcoco_dataset_with_images
接下来,让我们过滤数据集,只包含带有相关图像的样本。
filtered_dataset = refcoco_dataset_with_images.filter(
lambda example: example["image"] is not None, desc="Removing failed image downloads"
)
2.2 删除不需要的列
filtered_dataset
数据集包含许多我们在此任务中不需要的列。
让我们通过只保留 `bbox`、`captions` 和 `image` 列来简化它。
filtered_dataset = filtered_dataset.remove_columns(
[
"sent_ids",
"file_name",
"ann_id",
"ref_id",
"image_id",
"split",
"sentences",
"category_id",
"raw_anns",
"raw_image_info",
"raw_sentences",
"image_path",
"global_image_id",
"anns_id",
]
)
现在看起来好多了!
filtered_dataset
2.3 将标题分离成独立的样本
最后一步:每个样本目前都有多个标题。为了简化数据集,我们将它们分开,使每个标题成为一个独立的样本。
def separate_captions_into_unique_samples(batch):
new_images = []
new_bboxes = []
new_captions = []
for image, bbox, captions in zip(batch["image"], batch["bbox"], batch["captions"]):
for caption in captions:
new_images.append(image)
new_bboxes.append(bbox)
new_captions.append(caption)
return {
"image": new_images,
"bbox": new_bboxes,
"caption": new_captions,
}
filtered_dataset = filtered_dataset.map(
separate_captions_into_unique_samples,
batched=True,
batch_size=100,
num_proc=4,
remove_columns=filtered_dataset.column_names,
)
现在一切都准备好了,我们来看一个例子吧!
filtered_dataset[20]["caption"]
filtered_dataset[20]["bbox"]
>>> filtered_dataset[20]["image"]
2.4 显示带有边界框的样本
我们的数据集准备工作已经完成。现在,让我们在一个样本的图像上可视化边界框。
为此,我们将创建一个辅助函数,可以在整个食谱中重复使用。
我们将使用 supervision 库来协助显示边界框。
labels = [(filtered_dataset[20]["caption"], filtered_dataset[20]["bbox"])]
>>> import supervision as sv
>>> import numpy as np
>>> def get_annotated_image(image, parsed_labels):
... if not parsed_labels:
... return image
... xyxys = []
... labels = []
... for label, bbox in parsed_labels:
... xyxys.append(bbox)
... labels.append(label)
... detections = sv.Detections(xyxy=np.array(xyxys))
... bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
... label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
... annotated_image = bounding_box_annotator.annotate(scene=image, detections=detections)
... annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
... return annotated_image
>>> annotated_image = get_annotated_image(filtered_dataset[20]["image"], labels)
>>> annotated_image
太棒了!我们现在可以看到与每个边界框关联的基础标题。
2.5 划分数据集
我们的数据集已准备就绪,但在继续之前,让我们将其划分为训练集和验证集,以便进行适当的模型评估。
split_dataset = filtered_dataset.train_test_split(test_size=0.2, seed=42, shuffle=False)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
train_dataset, val_dataset
3. 使用数据集检查预训练模型
如前所述,我们将使用 PaliGemma 2 作为我们的模型,因为它已经包含目标检测功能,这简化了我们的工作流程。
如果我们要使用没有内置目标检测功能的视觉语言模型 (VLM),我们可能需要先对其进行训练才能获得这些功能。
有关此内容的更多信息,请查看我们关于“微调 Gemma 3 以进行目标检测”的项目,其中详细介绍了此训练过程。
现在,让我们加载模型和处理器。我们将使用预训练模型 google/paligemma2-3b-pt-448,它未经对话任务微调。
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import torch
model_id = "google/paligemma2-3b-pt-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)
3.1 在一个样本上进行推理
让我们评估模型在单个图像和标题上的当前性能。
image = train_dataset[20]["image"]
caption = train_dataset[20]["caption"]
由于我们的模型不是指令模型,因此输入应按如下方式格式化
<image>detect [CAPTION]
这里,`<image>` 表示图像 token,后跟关键词 `detect` 以指定目标检测任务,然后是描述要检测内容的标题。
如我们接下来将看到的,此格式将产生特定输出。
>>> prompt = f"<image>detect {caption}"
>>> model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
>>> input_len = model_inputs["input_ids"].shape[-1]
>>> with torch.inference_mode():
... generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
... generation = generation[0][input_len:]
... output = processor.decode(generation, skip_special_tokens=True)
... print(output)
middle vase ; middle vase ; middle vase
我们可以看到模型以 `<locXXXX>...` 这样的特殊格式生成位置标记,后面跟着检测到的类别。每个检测都由 `;` 分隔。
这些位置标记遵循 PaliGemma 格式,该格式特定于模型并相对于输入大小——在这种情况下,如模型名称所示,为 `448x448`。
为了正确显示检测结果,我们需要将这些标记转换回可用格式。让我们创建一个辅助函数来处理这种转换。
import re
# https://github.com/ariG23498/gemma3-object-detection/blob/main/utils.py#L17 thanks to Aritra Roy Gosthipaty
def parse_paligemma_labels(label, width, height):
predictions = label.strip().split(";")
results = []
for pred in predictions:
pred = pred.strip()
if not pred:
continue
loc_pattern = r"<loc(\d{4})>"
locations = [int(loc) for loc in re.findall(loc_pattern, pred)]
if len(locations) != 4:
continue
category = pred.split(">")[-1].strip()
y1_norm, x1_norm, y2_norm, x2_norm = locations
x1 = (x1_norm / 1024) * width
y1 = (y1_norm / 1024) * height
x2 = (x2_norm / 1024) * width
y2 = (y2_norm / 1024) * height
results.append((category, [x1, y1, x2, y2]))
return results
现在,我们可以使用此函数将 PaliGemma 标签解析为常见的 COCO 格式。
width, height = image.size parsed_labels = parse_paligemma_labels(output, width, height) parsed_labels
接下来,我们可以使用之前的函数来检索图像。
让我们将其与解析后的边界框一起显示!
annotated_image = get_annotated_image(image, parsed_labels)
>>> annotated_image
我们可以看到,该模型在目标检测方面表现良好,但在基础化方面有些挣扎。
例如,它将所有三个花瓶标记为“中间的花瓶”,而不是只有一个。
我们来努力改进这一点!🙂
4. 使用 LoRA 和 TRL 微调模型
为了微调视觉语言模型 (VLM),我们将利用 LoRA 和 TRL。
让我们从配置 LoRA 开始
>>> from peft import LoraConfig, get_peft_model
>>> target_modules = ["q_proj", "v_proj", "fc1", "fc2", "linear", "gate_proj", "up_proj", "down_proj"]
>>> # Configure LoRA
>>> peft_config = LoraConfig(
... lora_alpha=16,
... lora_dropout=0.05,
... r=8,
... bias="none",
... target_modules=target_modules,
... task_type="CAUSAL_LM",
... )
>>> # Apply PEFT model adaptation
>>> peft_model = get_peft_model(model, peft_config)
>>> # Print trainable parameters
>>> peft_model.print_trainable_parameters()
trainable params: 12,165,888 || all params: 3,045,293,040 || trainable%: 0.3995
接下来,让我们配置 TRL 中的 SFT 训练 管道。
这个管道通过抽象大部分底层复杂性并为我们管理它来简化训练过程。
from trl import SFTConfig
training_args = SFTConfig(
output_dir="paligemma2-3b-pt-448-od-grounding",
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=4,
gradient_checkpointing=False,
learning_rate=1e-05,
num_train_epochs=2,
logging_steps=10,
eval_steps=100,
eval_strategy="steps",
save_steps=10,
bf16=True,
report_to=["tensorboard"],
dataset_kwargs={"skip_prepare_dataset": True},
remove_unused_columns=False,
push_to_hub=True,
dataloader_pin_memory=False,
label_names=["labels"],
)
我们差不多准备好了!
接下来,我们将定义几个辅助函数来处理整理器中的目标检测。
这些功能直观且不言自明。
def coco_to_xyxy(coco_bbox):
x, y, width, height = coco_bbox
x1, y1 = x, y
x2, y2 = x + width, y + height
return [x1, y1, x2, y2]
def convert_to_detection_string(bboxs, image_width, image_height, category):
def format_location(value, max_value):
return f"<loc{int(round(value * 1024 / max_value)):04}>"
detection_strings = []
for bbox in bboxs:
x1, y1, x2, y2 = coco_to_xyxy(bbox)
locs = [
format_location(y1, image_height),
format_location(x1, image_width),
format_location(y2, image_height),
format_location(x2, image_width),
]
detection_string = "".join(locs) + f" {category}"
detection_strings.append(detection_string)
return " ; ".join(detection_strings)
def format_objects(example):
height = example["height"]
width = example["width"]
bboxs = example["bbox"]
category = example["caption"][0]
formatted_objects = convert_to_detection_string(bboxs, width, height, category)
return {"label_for_paligemma": formatted_objects}
由于我们正在微调 VLM,我们还可以引入数据增强。
在我们的案例中,我们将处理图像大小调整——这是强制性的,以确保一致的输入大小——因为模型期望的图像大小为 `448x448`。
为了参考,我们已经包含了几个可能的增强(注释掉)。
import albumentations as A
resize_size = 448
augmentations = A.Compose(
[
A.Resize(height=resize_size, width=resize_size),
# A.HorizontalFlip(p=0.5),
# A.ColorJitter(p=0.2),
],
bbox_params=A.BboxParams(format="coco", label_fields=["category_ids"], filter_invalid_bboxes=True),
)
现在,让我们创建 collate 函数,它为 VLM 输入准备批次。
在这一步中,我们需要仔细处理数据增强过程,以确保一致性和正确性。
from functools import partial
# Create a data collator to encode text and image pairs
def collate_fn(examples, transform=None):
images = []
prompts = []
suffixes = []
for sample in examples:
if transform:
transformed = transform(
image=np.array(sample["image"]), bboxes=[sample["bbox"]], category_ids=[sample["caption"]]
)
sample["image"] = transformed["image"]
sample["bbox"] = transformed["bboxes"]
sample["caption"] = transformed["category_ids"]
sample["height"] = sample["image"].shape[0]
sample["width"] = sample["image"].shape[1]
sample["label_for_paligemma"] = format_objects(sample)["label_for_paligemma"]
images.append([sample["image"]])
prompts.append(f"<image>Detect {sample['caption']}.")
suffixes.append(sample["label_for_paligemma"])
batch = processor(images=images, text=prompts, suffix=suffixes, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone() # Clone input IDs for labels
image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index("<image>")
]
# Mask tokens for not being used in the loss computation
labels[labels == processor.tokenizer.pad_token_id] = -100
labels[labels == image_token_id] = -100
batch["labels"] = labels
batch["pixel_values"] = batch["pixel_values"].to(model.device)
return batch
train_collate_fn = partial(collate_fn, transform=augmentations)
最后,我们可以实例化 `SFTTrainer` 并开始训练我们的模型!
from trl import SFTTrainer
trainer = SFTTrainer(
model=peft_model,
args=training_args,
data_collator=train_collate_fn,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)
trainer.train()
让我们将训练好的模型和结果保存到 Hugging Face Hub。
processor.save_pretrained(training_args.output_dir) trainer.save_model(training_args.output_dir) trainer.push_to_hub()
5. 测试微调模型
我们已经对模型进行了微调,以实现基础目标检测。最后一步,让我们在测试集中的一个样本上测试其功能。
模型:sergiopaniego/paligemma2-3b-pt-448-od-grounding
让我们使用微调后的检查点实例化我们的模型
trained_model_id = "sergiopaniego/paligemma2-3b-pt-448-od-grounding"
model_id = "google/paligemma2-3b-pt-448"
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
from peft import PeftModel
import torch
base_model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
trained_model = PeftModel.from_pretrained(base_model, trained_model_id).eval()
trained_processor = PaliGemmaProcessor.from_pretrained(model_id, use_fast=True)
5.1 在训练样本上测试
让我们从测试一张训练图片开始。
这能让我们初步了解训练情况,但请记住,由于模型在训练期间已经见过此样本,因此这可能会有点误导。
对于这个测试,我们将使用我们前面介绍的示例来检查模型现在是否可以正确执行推理。
image = train_dataset[20]["image"]
caption = train_dataset[20]["caption"]
>>> prompt = f"<image>detect {caption}"
>>> model_inputs = (
... trained_processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(trained_model.device)
... )
>>> input_len = model_inputs["input_ids"].shape[-1]
>>> with torch.inference_mode():
... generation = trained_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
... generation = generation[0][input_len:]
... output = trained_processor.decode(generation, skip_special_tokens=True)
... print(output)
middle vase
width, height = image.size parsed_labels = parse_paligemma_labels(output, width, height) parsed_labels
annotated_image = get_annotated_image(image, parsed_labels)
让我们看看微调是否成功… 🥁
>>> annotated_image
太好了!模型现在能够正确识别“中间的花瓶”。
5.3 对比验证样本进行测试
最后,让我们在验证样本上评估模型的能力,以正确评估它是否已学习基础化和目标检测。
image = val_dataset[13]["image"]
caption = val_dataset[13]["caption"]
caption
>>> prompt = f"<image>detect {caption}"
>>> model_inputs = (
... trained_processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(trained_model.device)
... )
>>> input_len = model_inputs["input_ids"].shape[-1]
>>> with torch.inference_mode():
... generation = trained_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
... generation = generation[0][input_len:]
... output = trained_processor.decode(generation, skip_special_tokens=True)
... print(output)
darker bear
width, height = image.size parsed_labels = parse_paligemma_labels(output, width, height) parsed_labels
annotated_image = get_annotated_image(image, parsed_labels)
我们来看看… 🥁
>>> annotated_image
成功了!我们的模型能够正确识别图像中的“深色熊”,并避免为每只熊生成多个检测。
请记住,我们的训练是轻量级的——只使用了数据集的一个子集——并且训练配置可以进一步优化。我们将这些改进留给您探索!
6. 继续学习之旅 🧑🎓️
为了进一步加深您的理解和技能,请查看以下宝贵资源
- 微调 Grounding DINO — LearnOpenCV
- RefCOCO 数据集 — Papers with Code
- 微调 PaliGemma — GitHub
- 微调 Gemma 3 进行目标检测 — GitHub
- VLM 对象理解 — Hugging Face Space
- GPT-4o 对视觉的理解程度如何?评估多模态基础模型在标准计算机视觉任务上的表现
- 视觉语言模型 (更好、更快、更强) 博客
- 查看 HF 开源 AI 食谱中的其他多模态食谱
欢迎探索这些资源,加深您的知识并不断突破界限!
< > 在 GitHub 上更新