TRL 文档

情感微调示例

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

情感微调示例

这些示例中的笔记本和脚本展示了如何使用情感分类器(例如 lvwerra/distilbert-imdb)对模型进行微调。

以下是 trl 存储库 中笔记本和脚本的概述

文件 描述
examples/scripts/ppo.py 在 Colab 中打开 此脚本展示了如何使用 PPOTrainer 使用 IMDB 数据集对情感分析模型进行微调
examples/notebooks/gpt2-sentiment.ipynb 此笔记本演示了如何在 Jupyter Notebook 上重现 GPT2 imdb 情感微调示例。
examples/notebooks/gpt2-control.ipynb 在 Colab 中打开 此笔记本演示了如何在 Jupyter Notebook 上重现 GPT2 情感控制示例。

用法

# 1. run directly
python examples/scripts/ppo.py
# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed)
accelerate config # will prompt you to define the training configuration
accelerate launch examples/scripts/ppo.py # launches training
# 3. get help text and documentation
python examples/scripts/ppo.py --help
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16

注意:如果您不想使用 wandb 记录,请在脚本/笔记本中删除 log_with="wandb"。您也可以将其替换为您喜欢的实验跟踪器,该跟踪器受 accelerate 支持

多 GPU 的一些说明

要在使用 DDP(分布式数据并行)的多 GPU 设置中运行,请将 device_map 值更改为 device_map={"": Accelerator().process_index},并确保使用 accelerate launch yourscript.py 运行您的脚本。如果您想应用简单的管道并行,可以使用 device_map="auto"

基准测试

以下是 examples/scripts/ppo.py 的一些基准测试结果。要本地复现,请查看下面的 --command 参数。

python benchmark/benchmark.py \
    --command "python examples/scripts/ppo.py --log_with wandb" \
    --num-seeds 5 \
    --start-seed 1 \
    --workers 10 \
    --slurm-nodes 1 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 12 \
    --slurm-template-path benchmark/trl.slurm_template

使用和不使用梯度累积

python benchmark/benchmark.py \
    --command "python examples/scripts/ppo.py --exp_name sentiment_tuning_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
    --num-seeds 5 \
    --start-seed 1 \
    --workers 10 \
    --slurm-nodes 1 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 12 \
    --slurm-template-path benchmark/trl.slurm_template

比较不同的模型 (gpt2, gpt2-xl, falcon, llama2)

python benchmark/benchmark.py \
    --command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2 --log_with wandb" \
    --num-seeds 5 \
    --start-seed 1 \
    --workers 10 \
    --slurm-nodes 1 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 12 \
    --slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
    --command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
    --num-seeds 5 \
    --start-seed 1 \
    --workers 10 \
    --slurm-nodes 1 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 12 \
    --slurm-template-path benchmark/trl.slurm_template
python benchmark/benchmark.py \
    --command "python examples/scripts/ppo.py --exp_name sentiment_tuning_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
    --num-seeds 5 \
    --start-seed 1 \
    --workers 10 \
    --slurm-nodes 1 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 12 \
    --slurm-template-path benchmark/trl.slurm_template

使用和不使用PEFT

python benchmark/benchmark.py \
    --command "python examples/scripts/ppo.py --exp_name sentiment_tuning_peft --use_peft --log_with wandb" \
    --num-seeds 5 \
    --start-seed 1 \
    --workers 10 \
    --slurm-nodes 1 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 12 \
    --slurm-template-path benchmark/trl.slurm_template

< > 在GitHub上更新