Transformers 文档

训练脚本

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

训练脚本

Transformers 为深度学习框架(PyTorch、TensorFlow、Flax)和任务在 transformers/examples 中提供了许多示例训练脚本。在 transformers/research projectstransformers/legacy 中还有其他脚本,但这些脚本没有积极维护,并且需要特定版本的 Transformers。

示例脚本只是示例,你可能需要根据你的用例调整脚本。为了帮助你,大多数脚本在数据预处理方面都非常透明,允许你根据需要进行编辑。

对于你想在示例脚本中实现的任何功能,请在提交拉取请求之前在 论坛issue 中讨论。虽然我们欢迎贡献,但增加更多功能但牺牲可读性的拉取请求不太可能被接受。

本指南将向你展示如何在 PyTorchTensorFlow 中运行示例摘要训练脚本。

设置

在新虚拟环境中从源代码安装 Transformers,以运行最新版本的示例脚本。

git clone https://github.com/huggingface/transformers
cd transformers
pip install .

运行以下命令以从特定或旧版本 Transformers 中检出脚本。

git checkout tags/v3.5.1

设置正确版本后,导航到你选择的示例文件夹并安装示例特定的依赖项。

pip install -r requirements.txt

运行脚本

通过包含 `max_train_samples`、`max_eval_samples` 和 `max_predict_samples` 参数来截断数据集到最大样本数,从而使用较小的数据集开始。这有助于确保训练按预期进行,然后再提交整个数据集,这可能需要数小时才能完成。

并非所有示例脚本都支持 `max_predict_samples` 参数。运行以下命令以检查脚本是否支持它。

examples/pytorch/summarization/run_summarization.py -h

以下示例在 CNN/DailyMail 数据集上对 T5-small 进行微调。T5 需要一个额外的 `source_prefix` 参数来提示它进行摘要。

PyTorch
TensorFlow

示例脚本下载并预处理数据集,然后使用 Trainer 和受支持的模型架构对其进行微调。

如果训练中断,从检查点恢复训练非常有用,因为你无需从头开始。有两种方法可以从检查点恢复训练。

  • `--output dir previous_output_dir` 从存储在 `output_dir` 中的最新检查点恢复训练。如果使用此方法,请删除 `--overwrite_output_dir` 参数。
  • `--resume_from_checkpoint path_to_specific_checkpoint` 从特定的检查点文件夹恢复训练。

使用 `--push_to_hub` 参数在 Hub 上分享你的模型。它会创建一个仓库并将模型上传到 `--output_dir` 中指定的文件夹名称。你也可以使用 `--push_to_hub_model_id` 参数来指定仓库名称。

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    # remove the `max_train_samples`, `max_eval_samples` and `max_predict_samples` if everything works
    --max_train_samples 50 \
    --max_eval_samples 50 \
    --max_predict_samples 50 \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --push_to_hub \
    --push_to_hub_model_id finetuned-t5-cnn_dailymail \
    # remove if using `output_dir previous_output_dir`
    # --overwrite_output_dir \
    --output_dir previous_output_dir \
    # --resume_from_checkpoint path_to_specific_checkpoint \
    --predict_with_generate \

对于混合精度和分布式训练,请包含以下参数并使用 torchrun 启动训练。

  • 添加 `fp16` 或 `bf16` 参数以启用混合精度训练。XPU 设备仅支持 `bf16`。
  • 添加 `nproc_per_node` 参数以设置要训练的 GPU 数量。
torchrun \
    --nproc_per_node 8 pytorch/summarization/run_summarization.py \
    --fp16 \
    ...
    ...

PyTorch 通过 PyTorch/XLA 包支持 TPU,这是一种旨在加速性能的硬件。启动 `xla_spawn.py` 脚本并使用 `num_cores` 设置要训练的 TPU 核心数量。

python xla_spawn.py --num_cores 8 pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    ...
    ...

Accelerate

Accelerate 旨在简化分布式训练,同时提供对 PyTorch 训练循环的完全可见性。如果你计划使用 Accelerate 训练脚本,请使用脚本的 `_no_trainer.py` 版本。

从源代码安装 Accelerate,以确保你拥有最新版本。

pip install git+https://github.com/huggingface/accelerate

运行 accelerate config 命令,回答有关你的训练设置的几个问题。这将创建并保存一个关于你系统的配置文件。

accelerate config

你可以使用 accelerate test 确保你的系统已正确配置。

accelerate test

运行 accelerate launch 以开始训练。

accelerate launch run_summarization_no_trainer.py \
    --model_name_or_path google-t5/t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ~/tmp/tst-summarization \

自定义数据集

摘要脚本支持自定义数据集,只要它们是 CSV 或 JSONL 文件。使用你自己的数据集时,你需要指定以下附加参数。

  • `train_file` 和 `validation_file` 指定训练和验证文件的路径。
  • `text_column` 是要摘要的输入文本。
  • `summary_column` 是要输出的目标文本。

下面显示了摘要自定义数据集的示例命令。

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --text_column text_column_name \
    --summary_column summary_column_name \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
< > 在 GitHub 上更新