TRL 文档

使用 SFT 微调多模态模型(单图或多图数据集)

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

使用 SFT 微调多模态模型(单图或多图数据集)

VLM SFT training procedure

概览

本指南将引导你完成使用监督式微调 (Supervised Fine-Tuning, SFT) 来微调多模态语言模型(例如 Gemma 3)的过程。我们涵盖两种情况:

  • 单张图片 + 文本
  • 多张图片 + 文本

本指南是对现有 VLM SFT 脚本详细解读和补充。如果你已熟悉这些概念,可以直接使用该脚本。

我们使用两个数据集来演示微调过程,但这些原则同样适用于其他视觉语言模型 (Vision-Language Models, VLMs) 和数据集。

理解数据集

为了同时处理单张图片 + 文本多张图片 + 文本这两种场景,我们使用了两个非常适合此任务的数据集。

HuggingFaceH4/llava-instruct-mix-vsft 数据集(图片 + 文本)

此数据集是 LLaVA Instruct Mix 的重新格式化版本。它由对话组成,其中用户同时提供文本单张图片作为输入。

模型(被称为“助手”)会根据用户分享的视觉和文本信息进行回应。该数据集对于训练多模态模型以理解并生成基于图片和文本的响应特别有用。

FanqingM/MMIU-Benchmark 数据集(多张图片 + 文本)

FanqingM/MMIU-Benchmark 数据集包含:

  • 上下文:包含在系统提示中。
  • 问题:作为用户输入的一部分提供。
  • 一系列图片:与问题相关的多张图片。
  • 答案:模型的预期响应。

此数据集专为需要模型对多张图片进行推理,并根据视觉和文本输入生成明智回应的任务而设计。

为多模态 SFT 开发微调脚本

在本节中,我们将构建一个用于微调多模态模型的脚本,该脚本适用于单张图片 + 文本多张图片 + 文本两种用例。

设置环境

在微调之前,我们需要安装所需的依赖项。让我们从设置环境开始:

# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation 
pip install -U -q trl bitsandbytes peft hf_xet tensorboard

所有依赖项安装完毕后,我们需要登录到 Hugging Face Hub。由于 Gemma 3 是一个受限模型,因此需要访问权限。

如果你尚未申请访问权限,请访问模型卡片并提交申请。

要登录,你需要从你的 Hugging Face 账户生成一个访问令牌

huggingface-cli login

加载数据

如前所述,我们将涵盖两种可能的用例。虽然具体流程可能因数据集而异,但核心原则保持一致。

本指南支持两种用例,请根据你的具体场景参考单张图片 + 文本多张图片 + 文本部分。

单张图片 + 文本

Single Image + Text

在这种情况下,批次中的每个样本都包含一张图片与文本配对。由于数据集已经格式化为监督式微调 (SFT) 格式,我们可以直接使用 load_dataset 加载它。

from datasets import load_dataset

dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"

# Load Dataset
dataset = load_dataset(dataset_name)

多张图片 + 文本(或交错)

Multi-Image + Text

Gemma 3 也支持多张图片 + 文本的场景,其中:

  • 模型接收一个图片列表以及一条用户消息。
  • 模型处理对话中交错的图片和文本

对于这个数据集,在训练前需要进行一些预处理。

from datasets import load_dataset

dataset_name = "FanqingM/MMIU-Benchmark"

# Load Dataset
dataset = load_dataset(dataset_name)

加载数据集后,我们需要将其预处理并格式化为对话结构。以下是数据可能的样子示例:

{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},

这里,images_list 是一个图片列表。

images_list = [
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
]

这种结构可以像这样转换成代码:

import os
import zipfile
import io
from datasets import DatasetDict
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image

dataset_train_split = "test"

def format_data(samples: dict[str, any]) -> dict[str, list]:
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        for img_path in samples["input_image_path"][cont]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue

        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
    return formatted_samples

# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]

    for zip_filename in zip_files:
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        extract_folder = zip_filename.replace(".zip", "")
        os.makedirs(extract_folder, exist_ok=True)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_folder)

    dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
    return dataset

dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)

至此,你的多张图片 + 文本数据集已经准备好用于训练了。

准备训练

我们首先加载模型和处理器。在本例中,我们使用 google/gemma-3-4b-it,但同样的过程也适用于其其他变体和类似模型。

为了优化内存使用,我们配置 BitsAndBytes 来加载模型的量化版本。

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

model_id = "google/gemma-3-4b-it"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(
    model_id, 
    device_map="auto", 
    torch_dtype=torch.bfloat16,
    attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
    quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"

接下来,我们设置量化低秩适配 (Quantized Low-Rank Adaptation, QLoRA),这是一种针对大型语言模型 (LLMs) 和视觉语言模型 (VLMs) 的高效微调技术。

from peft import LoraConfig, get_peft_model

# Configure QLoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

QLoRA 设置完成后,我们需要为 SFT 定义训练参数。SFTConfig 类简化了这一过程,提供了一种根据我们的具体需求轻松调整参数的方法。

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft",     # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
    num_train_epochs=1,                                             # Set the number of epochs to train the model.
    per_device_train_batch_size=8,                                  # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
    gradient_accumulation_steps=4,                                  # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
    gradient_checkpointing=True,                                    # Enable gradient checkpointing to reduce memory usage during training.
    optim="adamw_torch_fused",                                      # Use the fused AdamW optimizer for better performance.
    save_strategy="epoch",                                          # Save checkpoints at the end of each epoch.
    learning_rate=2e-05,                                            # Learning rate for training.
    bf16=True,                                                      # Enable bfloat16 precision for training to save memory and speed up computations.
    push_to_hub=True,                                               # Automatically push the fine-tuned model to Hugging Face Hub after training.
    report_to="tensorboard",                                        # Automatically report metrics to tensorboard.
    gradient_checkpointing_kwargs={"use_reentrant": False},         # Set gradient checkpointing to non-reentrant to avoid issues.
    dataset_kwargs={"skip_prepare_dataset": True},                  # Skip dataset preparation to handle preprocessing manually.
    remove_unused_columns=False,                                    # Ensure unused columns are not removed in the collator (important for batch processing).
)

collate_fn 负责处理和准备单个样本以形成一个批次。

批次中的每个样本都会经历以下步骤:

  1. 聊天模板应用于文本。
  2. 处理器textsimages 进行分词,将它们编码成张量。
  3. 用于训练的标签被设置为样本的 input_ids
  4. 在损失计算过程中,某些特殊标记掩码(忽略)
    • pad_token_id
    • <image_token_id>
    • <image_soft_token>(对应 ID 262144

这个过程在不同类型的数据集中是相似的,只是在处理图片的方式上略有不同:

  • 单张图片 + 文本 → 一个图片列表被直接处理。
  • 多张图片 + 文本 → 使用一个由图片列表组成的列表,其中每个批次元素包含多张图片。
from PIL import Image

# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs

def collate_fn(examples):
    texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
    if "images" in examples[0]:  # single-image
        images = [
            [img.convert("RGB") for img in example["images"]]
            for example in examples
        ]
    else:  # multi-image
        images = [process_vision_info(example["messages"]) for example in examples]

    # Tokenize the texts and process the images
    batch = processor(
        images=images, text=texts, return_tensors="pt", padding=True
    )  # Encode texts and images into tensors

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch  # Return the prepared batch

训练模型

所有组件都设置好后,我们现在使用先前定义的设置来配置 SFTTrainer,并开始训练过程。

# Training
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
    processing_class=processor,
    peft_config=peft_config,
)

trainer.train()

# Save the final model
trainer.save_model()

我们将微调后的模型保存到 Hub,使其易于将来使用。此外,TRL 会根据所选配置,自动将训练结果记录到 Weights & Biases (Wandb)TensorBoard

结果

在训练期间和之后,我们可以使用 Weights & Biases (Wandb)TensorBoard 来检查结果。例如:

局限性

目前,微调 Gemma 存在一些已知的局限性。我们建议遵循本指南中概述的步骤以确保最佳结果。

参考文献

如需进一步阅读和补充资源,请查看以下内容:

< > 在 GitHub 上更新