TRL 文档

多适配器强化学习 (MARL) - 一个适用于所有场景的单一基础模型

Hugging Face's logo
加入 Hugging Face 社区

并访问增强型文档体验

入门

多适配器强化学习 (MARL) - 一个适用于所有场景的单一基础模型

我们在此介绍一种方法,该方法使用单一基础模型来执行整个 PPO 算法 - 包括检索参考 logits、计算活动 logits 和奖励。此功能处于实验阶段,因为我们没有测试该方法的收敛性。我们鼓励社区告知我们是否遇到任何问题。

要求

您只需要安装 peft,如果您想使用 8 位基础模型进行更节省内存的微调,还可以选择安装 bitsandbytes

总结

您需要分三个阶段解决此方法,我们将其总结如下

1- 在目标领域(例如 IMDB 数据集)上训练一个基础模型 - 这是监督微调阶段 - 它可以利用 TRL 中的 SFTTrainer。 2- 使用 peft 训练奖励模型。这是为了在 RL 优化过程中(如下面的步骤 3)重新使用适配器。我们在 此示例 中展示了利用 TRL 中的 RewardTrainer 的示例。 3- 使用 PPO 和奖励适配器对基础模型上的新适配器进行微调。(“0 抽象 RL”)

确保在阶段 2 和 3 中使用相同的模型(即相同的架构和相同的权重)。

快速入门

假设您已经使用RewardTrainerllama-7b模型上训练了您的奖励适配器,并将权重推送到trl-lib/llama-7b-hh-rm-adapter下的中心。在进行PPO时,在将模型传递给PPOTrainer之前,请按如下方式创建您的模型

model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"

# PPO adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
)

...
trainer = PPOTrainer(
    model=model,
    ...
)

...

然后在您的PPO训练循环中,通过访问PPOTrainermodel属性来调用compute_reward_score方法。

rewards = trainer.model.compute_reward_score(**inputs)

高级用法

控制适配器名称

如果您熟悉peft库,您知道您可以在同一个模型中使用多个适配器。您可以做的是在同一个基础模型上训练多个适配器,以便对不同的策略进行微调。在这种情况下,您希望能够控制您想要激活的适配器名称,以便在检索奖励后将其激活回来。为此,只需在调用compute_reward_score时,将相应的adapter_name传递给ppo_adapter_name参数即可。

adapter_name_policy_1 = "policy_1"
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
...

使用4位和8位基础模型

为了更有效地进行内存微调,您可以在将适配器保持在默认精度(float32)的同时,将您的基础模型加载到8位或4位。只需将相应的参数(即load_in_8bit=Trueload_in_4bit=True)传递给AutoModelForCausalLMWithValueHead.from_pretrained,如下所示(假设您已安装bitsandbytes

model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"

# PPO adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
    load_in_8bit=True,
)

...
trainer = PPOTrainer(
    model=model,
    ...
)
...
< > 在GitHub上更新