TRL 文档

多适配器强化学习 (MARL) - 一个基础模型搞定一切

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

多适配器强化学习 (MARL) - 一个基础模型搞定一切

这里我们提出一种方法,它使用单个基础模型来完成整个 PPO 算法——包括检索参考 logits、计算活动 logits 以及计算奖励。此功能尚处于实验阶段,因为我们尚未测试该方法的收敛性。我们鼓励社区成员在遇到潜在问题时告知我们。

环境要求

您只需安装 peft,如果想使用 8 位基础模型以实现更高效的内存微调,还可以选择安装 bitsandbytes

概要

您需要分三个阶段来实施此方法,我们总结如下:

1- 在目标领域(例如 IMDB 数据集)上训练一个基础模型——这是监督式微调(SFT)阶段——可以利用 TRL 中的 SFTTrainer。 2- 使用 peft 训练一个奖励模型。这是为了在强化学习优化过程(下面的步骤 3)中复用适配器所必需的。我们在这个例子中展示了如何利用 TRL 中的 RewardTrainer。 3- 使用 PPO 和奖励适配器在基础模型上微调新的适配器。(“零抽象强化学习”)

请确保在第 2 阶段和第 3 阶段使用相同的模型(即相同的架构和相同的权重)。

快速入门

假设您已经使用 RewardTrainerllama-7b 模型上训练了奖励适配器,并将其权重推送到了 Hub 上的 trl-lib/llama-7b-hh-rm-adapter。在进行 PPO 训练时,在将模型传递给 PPOTrainer 之前,请按如下方式创建您的模型:

model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"

# PPO adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
)

...
trainer = PPOTrainer(
    model=model,
    ...
)

...

然后在您的 PPO 训练循环中,通过访问 PPOTrainermodel 属性来调用 compute_reward_score 方法。

rewards = trainer.model.compute_reward_score(**inputs)

高级用法

控制适配器名称

如果您熟悉 peft 库,您会知道可以在同一个模型中使用多个适配器。您可以做的是,在同一个基础模型上训练多个适配器,以针对不同的策略进行微调。在这种情况下,您希望能够在检索到奖励后,控制要重新激活的适配器名称。为此,在调用 compute_reward_score 时,只需将相应的 adapter_name 传递给 ppo_adapter_name 参数即可。

adapter_name_policy_1 = "policy_1"
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
...

使用 4 位和 8 位基础模型

为了实现更高效的内存微调,您可以将基础模型加载为 8 位或 4 位,同时保持适配器为默认精度(float32)。只需将适当的参数(即 load_in_8bit=Trueload_in_4bit=True)传递给 AutoModelForCausalLMWithValueHead.from_pretrained 即可,如下所示(假设您已安装 bitsandbytes):

model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"

# PPO adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
    load_in_8bit=True,
)

...
trainer = PPOTrainer(
    model=model,
    ...
)
...
< > 在 GitHub 上更新