TRL 中的视觉语言模型对齐 ⚡️
引言
视觉语言模型 (VLM) 越来越强大,但将其与人类偏好进行*对齐*仍然很重要。在 TRL 中,我们已经展示了如何通过有监督微调 (SFT) 和直接偏好优化 (DPO) 对 VLM 进行后期训练。这一次,我们更进一步。
tl;dr 我们在 TRL 中增加了两种新的多模态对齐方法:组相对策略优化 (GRPO)、其变体组序列策略优化 (GSPO) 和混合偏好优化 (MPO)。所有这些方法都允许您超越成对 DPO,从偏好数据中提取更多信号,并更好地与现代 VLM 配合使用。我们发布了训练脚本和演示笔记本,以便轻松上手!
目录
视觉语言模型对齐
传统上,你会使用一个基础模型,应用 SFT 来遵循指令,然后应用 DPO 将其与偏好数据对齐。此前,我们已将此方法应用于视觉语言模型 (VLM) 并在 IDEFICS2 上进行了验证,显示模型响应有所改进。
DPO 通过使用对比损失优化模型响应对之间的偏好来工作:您有一个已选择和已拒绝的答案,并根据您想要和不想要的内容优化您的偏好。
但在过去一年中,新的多模态对齐方法 GRPO 和 MPO 越来越受欢迎,它们可以进一步提升 VLM 性能。在博客文章末尾,您可以找到一个表格,展示模型响应之间的差异。
混合偏好优化 (MPO)
使用 SFT 对齐多模态模型以执行推理任务会因分布偏移而不足。同时,使用 DPO 对齐的模型无法生成连贯的推理,并且可能会生成重复的响应。为了解决这个问题,有一种专门为多模态模型设计的名为混合偏好优化 (MPO) 的新技术。该方法本质上是 DPO 的扩展,具有多个损失:来自 DPO 的偏好损失(Sigmoid)、来自二分类器优化 (BCO) 的质量损失和来自 SFT 的生成损失。根据论文,仅仅切换到这种组合损失就能在 MathVista 中将性能提高 6.2 分!
由于这只修改了损失,我们为 TRL 的 DPOTrainer
类添加了组合损失支持。要使用它,您可以按如下方式初始化 DPOConfig
。
mpo_config = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
learning_rate=9e-1,
loss_type=["sigmoid", "bco_pair", "sft"], # Loss types to combine, as used in the MPO paper
loss_weights=[0.8, 0.2, 1.0], # Corresponding weights, as used in the MPO paper
report_to="none",
bf16=False,
fp16=False,
use_cpu=True,
max_steps=1,
)
然后初始化 DPOTrainer
mpo_trainer = DPOTrainer(
model=model_id,
args=mpo_config,
processing_class=tokenizer,
train_dataset=dataset,
)
mpo_trainer.train()
就是这样!如果您想进一步探索,可以在此处找到一个完整的笔记本示例。
多模态组相对策略优化 (GRPO)
组相对策略优化 (GRPO) 是一种尖端对齐方法,最初在DeepSeek Math 论文中引入,后来集成到开创性的 LLM DeepSeek R1 中。它是 PPO 的一个补充,其中策略更新在组(表示对话如何展开的轨迹批次)上完成。此功能使其对奖励噪声更加鲁棒,因为噪声在组内平均。由于模型学习的是对良好响应的更广泛理解,而不是单一的高奖励样本,因此该方法也使模型具有高性能。
在 TRL 中,我们现在为视觉语言模型引入了 GRPO 支持。我们不会提供完整的训练脚本示例,因为您可以在笔记本中找到它。相反,我们将重点突出关键组件和概念。
为了使训练脚本有效工作,我们需要验证答案的格式是否正确以及解决方案本身是否接近已完成的部分,因此我们编写了两个奖励函数。为了真正看到后一个奖励的改进,您需要一个相当最大化的设置,即您拥有相对较大的模型、大量的生成以及高质量、多样化的数据集。
import re
from math_verify import LatexExtractionConfig, parse, verify
def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
matches = [re.match(pattern, content) for content in completions]
rewards_list = [1.0 if match else 0.0 for match in matches]
rewards = [1.0 if match else 0.0 for match in matches]
print(completions)
print(rewards)
return rewards
def accuracy_reward(completions, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
solutions = kwargs['solution']
completion_contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, solution in zip(completion_contents, solutions):
gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
try:
rewards.append(float(verify(answer_parsed, gold_parsed)))
except Exception:
rewards.append(0.0)
else:
rewards.append(1.0)
return rewards
然后,您可以初始化 GRPOConfig 和 GRPOTrainer,传入我们上面定义的奖励函数,并调用 train() 开始训练。
from trl import GRPOConfig
training_args = GRPOConfig(
learning_rate=1e-5,
remove_unused_columns=False,
max_prompt_length=None,
.. # setup other params of choice here
)
trainer = GRPOTrainer(
model=model,
reward_funcs=[format_reward, accuracy_reward],
args=training_args,
train_dataset=train_dataset,
processing_class=processor
)
trainer.train()
在此处探索完整的笔记本示例。
组序列策略优化 (GSPO)
组序列策略优化 (GSPO) 是 Qwen 最近发布的一种 RL 对齐算法,它克服了 GRPO 的一些局限性。它通过在序列级别而不是每标记计算重要性采样权重来实现更稳定的训练。它的优点在 MoE 风格模型中更具相关性。
最新版本的 TRL 也引入了对 GSPO 的支持,由于它是 GRPO 损失的一种变体,因此它支持多模态。创建训练器的方法与 GRPO 相同,但需要添加以下额外参数(值从论文中提取)。
from trl import GRPOConfig
training_args = GRPOConfig(
...
importance_sampling_level="sequence",
epsilon=3e-4,
epsilon_high=4e-4,
beta=0.0,
loss_type="grpo",
steps_per_generation=1,
steps_per_generation=4
)
比较
下表总结了使用上述技术微调后的 Qwen2.5VL-3B 模型输出。请注意,我们仅对数据集子集进行了少量运行,并且模型在不同数据集上进行了微调,因此此比较仅供参考。
在此处查看答案
问题: 在给定的图中,如果角 APO 为 25.0 度,则角 AOB 的度数是多少?选项:A:140° B:130° C:120° D:110°
数据集中答案: 为了求角 AOB 的度数,我们可以利用四边形内角和为 360 度的性质。应用此性质,我们可以计算角 AOB 为 360 - 角 OAP - 角 OBP - 角 P。由于角 OAP 和角 OBP 均为 90 度,且角 P 已知为 50 度,我们将这些值代入方程得到角 AOB = 360 - 90 - 90 - 50 = 130 度。因此,角 AOB 的度数为 130 度。因此,答案是选项 B。答案:B
基础模型输出: 为了找到角 AOB 的度数,我们需要使用圆中同一弧所对的角属性。具体来说,弧在圆心所对的角(角 AOB)是同一弧在圆周上任意一点所对的角(角 APO)的两倍。已知角 APO = 25°,我们可以如下计算角 AOB:由于角 AOB 是角 APO 的两倍,我们有:[ 角 AOB = 2 * 角 APO ][ 角 AOB = 2 * 25° ][ 角 AOB = 50° ] 然而,根据提供的选项,此计算似乎不正确。让我们使用正确的属性重新评估。正确的属性指出,弧在圆心所对的角是同一弧在圆周上任意一点所对的角的两倍。因此,正确的计算应该是:角 AOB = 2 * 角 APO = 2 * 25° = 50° 但由于这与给定的选项不符,让我们考虑另一种方法。如果角 APO 是 25°,并且考虑到圆内接四边形的属性或圆内接四边形对角和为 180° 的事实,我们可能需要查看弧和弦形成的补角。假设正确的解释涉及找到给定角 APO 的补角,这将给出角 AOB 的度数。角 AOB = 180° - 角 APO = 180° - 25° = 155° 这仍然与选项不符。鉴于提供的选项,似乎可能存在误解或需要不同的解释。然而,根据标准几何原理,正确答案应该是:角 AOB = 2 * 角 APO = 2 * 25° = 50° 如果情况并非如此,请提供额外背景信息或澄清问题陈述。
MPO 输出: 为了找到角 AOB 的度数,我们需要使用圆中同一弧所对的角属性。具体来说,弧在圆心所对的角(角 AOB)是同一弧在圆周上任意一点所对的角(角 APO)的两倍。已知角 APO = 25°,我们可以如下计算角 AOB:由于角 AOB 是角 APO 的两倍,我们有:[ 角 AOB = 2 * 角 APO ][ 角 AOB = 2 * 25° ][ 角 AOB = 50° ] 然而,根据提供的选项,此计算似乎不正确。让我们使用正确的属性重新评估。正确的属性指出,弧在圆心所对的角是同一弧在圆周上任意一点所对的角的两倍。因此,正确的计算应该是:角 AOB = 2 * 角 APO = 2 * 25° = 50° 但由于这与给定的选项不符,让我们考虑另一种方法。如果我们看三角形 OAP,角 AOP 是 90°,因为它是一个直角(因为它内接于半圆)。现在,角 APO 给定为 25°,所以角 AOA 将是 65°(因为它与角 APO 互补)。因此,角 AOB 将是其一半,即 32.5°。然而,这也不符合选项。让我们考虑问题可能涉及不同的解释或不同的几何配置的可能性。鉴于这些选项,我们假设问题陈述或选项本身可能存在错误。如果我们要从给定选项中选择,根据此类问题的标准几何原理,似乎没有一个选项是正确的。
GRPO 输出:为了求角 AOB 的度数,我们需要理解角 AOB 是一个内接角,它截取与圆心角 AOP 相同的弧。根据内接角定理,内接角的度数是截取相同弧的圆心角度数的一半。已知:- 角 APO = 25.0 度 由于角 AOB 是一个内接角,它截取与圆心角 AOP 相同的弧,我们可以使用内接角定理来求角 AOB 的度数:角 AOB 的度数是角 AOP 度数的一半。因此,角 AOB 的度数为 130 度。
GSPO 输出: 为了求角 AOB 的度数,我们需要理解角 AOB 是一个内接角,它截取与圆心角 AOP 相同的弧。根据内接角定理,内接角的度数是截取相同弧的圆心角度数的一半。已知:- 角 APO = 25.0 度 由于角 AOB 是一个内接角,它截取与圆心角 AOP 相同的弧,我们可以使用内接角定理来求角 AOB 的度数:角 AOB 的度数是角 AOP 度数的一半。因此,答案是 B:130°。
TRL 中的 vLLM 集成
vLLM 已集成到 TRL 中,以支持需要在训练期间生成样本的在线对齐方法。运行以下示例脚本即可启用 vLLM
CUDA_VISIBLE_DEVICES=1,2 python3 examples/scripts/grpo_vlm.py --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct … --log_completions --use_vllm --vllm_mode colocate
主要有两种模式:colocate
和 server
。colocate
在与训练循环相同的进程中运行 vLLM,在训练和生成之间共享同一 GPU,在 GRPOTrainer
中创建一个 vLLM LLM 实例。而 server
则要求您在不同的进程中单独运行 vLLM,您可以在其中访问服务器。您可以使用以下命令启动此服务器
trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1
然后您可以按如下方式运行脚本。
CUDA_VISIBLE_DEVICES=1,2 python3 examples/scripts/grpo_vlm.py --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct … --log_completions --use_vllm --vllm_mode server
另一个提示:我们已添加了在 TRL 中使用 transformers 后端与 vLLM 的支持。您可以在使用 colocate
运行脚本或提供模型时通过传递 --vllm_model_impl transformers
标志来启用它。
您可以在此处阅读有关 TRL 中 vLLM 集成的更多信息。
有用资源
以下是探索 VLM 对齐的详细资源汇编。祝您阅读愉快!