TRL 文档
多适配器强化学习 (MARL) - 适用于一切的单一基础模型
并获得增强的文档体验
开始使用
多适配器强化学习 (MARL) - 适用于一切的单一基础模型
我们在此介绍一种方法,该方法为整个 PPO 算法使用单一基础模型 - 包括检索参考 logits、计算活动 logits 和奖励。此功能是实验性的,因为我们尚未测试该方法的收敛性。我们鼓励社区告知我们他们可能遇到的问题。
要求
您只需安装 peft
,如果想要使用 8 位基础模型以获得更节省内存的微调,还可以选择安装 bitsandbytes
。
总结
您需要分三个阶段处理此方法,我们总结如下:
1- 在目标领域(例如 IMDB 数据集)上训练基础模型 - 这是监督式微调阶段 - 可以利用 TRL 中的 SFTTrainer
。 2- 使用 peft
训练奖励模型。 这是为了在 RL 优化过程(下面的步骤 3)中重复使用适配器所必需的。 我们在这个示例中展示了如何利用 TRL 中的 RewardTrainer
: 这个示例 3- 使用 PPO 和奖励适配器在基础模型上微调新的适配器。(“0 抽象 RL”)
确保在阶段 2 和 3 中使用相同的模型(即相同的架构和相同的权重)。
快速入门
假设您已使用 RewardTrainer
在 llama-7b
模型上训练了奖励适配器,并将权重推送到了 hub 上的 trl-lib/llama-7b-hh-rm-adapter
下。 在执行 PPO 时,在将模型传递给 PPOTrainer
之前,请按如下方式创建您的模型:
model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
# PPO adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
)
...
trainer = PPOTrainer(
model=model,
...
)
...
然后在您的 PPO 训练循环中,通过访问 PPOTrainer
中的 model
属性来调用 compute_reward_score
方法。
rewards = trainer.model.compute_reward_score(**inputs)
高级用法
控制适配器名称
如果您熟悉 peft
库,您就会知道可以在同一模型内部使用多个适配器。 您可以做的是在同一基础模型上训练多个适配器,以针对不同的策略进行微调。 在这种情况下,您希望能够在检索奖励后控制要激活的适配器名称。 为此,只需在调用 compute_reward_score
时将适当的 adapter_name
传递给 ppo_adapter_name
参数。
adapter_name_policy_1 = "policy_1"
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
...
使用 4 位和 8 位基础模型
为了更节省内存的微调,您可以将基础模型加载为 8 位或 4 位,同时保持适配器为默认精度 (float32)。 只需将适当的参数(即 load_in_8bit=True
或 load_in_4bit=True
)传递给 AutoModelForCausalLMWithValueHead.from_pretrained
,如下所示(假设您已安装 bitsandbytes
):
model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
# PPO adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
load_in_8bit=True,
)
...
trainer = PPOTrainer(
model=model,
...
)
...