Transformers 文档

训练脚本

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

训练脚本

Transformers 为深度学习框架(PyTorch、TensorFlow、Flax)和 transformers/examples 中的任务提供了许多示例训练脚本。 transformers/research projectstransformers/legacy 中还有其他脚本,但这些脚本未得到积极维护,并且需要特定版本的 Transformers。

示例脚本仅为示例,您可能需要根据您的用例调整脚本。为了帮助您做到这一点,大多数脚本在数据预处理方面都非常透明,允许您根据需要对其进行编辑。

对于您想在示例脚本中实现的任何功能,请在提交拉取请求之前在论坛issue 中讨论它。虽然我们欢迎贡献,但以可读性为代价添加更多功能的拉取请求不太可能被添加。

本指南将向您展示如何在 PyTorchTensorFlow 中运行示例摘要训练脚本。

设置

从源代码在新虚拟环境中安装 Transformers,以运行最新版本的示例脚本。

git clone https://github.com/huggingface/transformers
cd transformers
pip install .

运行以下命令以从特定或较旧版本的 Transformers 检出脚本。

git checkout tags/v3.5.1

在您设置了正确的版本后,导航到您选择的示例文件夹并安装示例特定的要求。

pip install -r requirements.txt

运行脚本

从较小的数据集开始,包括 max_train_samplesmax_eval_samplesmax_predict_samples 参数,以将数据集截断为最大样本数。这有助于确保训练按预期工作,然后再提交到可能需要数小时才能完成的整个数据集。

并非所有示例脚本都支持 max_predict_samples 参数。 运行以下命令以检查脚本是否支持它。

examples/pytorch/summarization/run_summarization.py -h

以下示例在 CNN/DailyMail 数据集上微调 T5-small。 T5 需要额外的 source_prefix 参数来提示其进行摘要。

PyTorch
TensorFlow

示例脚本下载并预处理数据集,然后使用 Trainer 和受支持的模型架构对其进行微调。

如果训练中断,从检查点恢复训练非常有用,因为您不必重新开始。 有两种方法可以从检查点恢复训练。

  • --output dir previous_output_diroutput_dir 中存储的最新检查点恢复训练。 如果您使用此方法,请删除 --overwrite_output_dir 参数。
  • --resume_from_checkpoint path_to_specific_checkpoint 从特定的检查点文件夹恢复训练。

使用 --push_to_hub 参数在 Hub 上共享您的模型。 它创建一个存储库并将模型上传到 --output_dir 中指定的文件夹名称。 您也可以使用 --push_to_hub_model_id 参数来指定存储库名称。

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    # remove the `max_train_samples`, `max_eval_samples` and `max_predict_samples` if everything works
    --max_train_samples 50 \
    --max_eval_samples 50 \
    --max_predict_samples 50 \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --push_to_hub \
    --push_to_hub_model_id finetuned-t5-cnn_dailymail \
    # remove if using `output_dir previous_output_dir`
    # --overwrite_output_dir \
    --output_dir previous_output_dir \
    # --resume_from_checkpoint path_to_specific_checkpoint \
    --predict_with_generate \

对于混合精度和分布式训练,请包含以下参数并使用 torchrun 启动训练。

  • 添加 fp16bf16 参数以启用混合精度训练。 XPU 设备仅支持 bf16
  • 添加 nproc_per_node 参数以设置用于训练的 GPU 数量。
torchrun \
    --nproc_per_node 8 pytorch/summarization/run_summarization.py \
    --fp16 \
    ...
    ...

PyTorch 通过 PyTorch/XLA 包支持 TPU,这是一种旨在加速性能的硬件。 启动 xla_spawn.py 脚本并使用 num _cores 设置要用于训练的 TPU 核心数。

python xla_spawn.py --num_cores 8 pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    ...
    ...

Accelerate

Accelerate 旨在简化分布式训练,同时提供 PyTorch 训练循环的完整可见性。 如果您计划使用带有 Accelerate 的脚本进行训练,请使用 _no_trainer.py 版本的脚本。

从源代码安装 Accelerate 以确保您拥有最新版本。

pip install git+https://github.com/huggingface/accelerate

运行 accelerate config 命令,回答有关您的训练设置的几个问题。 这将创建并保存有关您系统的配置文件。

accelerate config

您可以使用 accelerate test 来确保您的系统配置正确。

accelerate test

运行 accelerate launch 以开始训练。

accelerate launch run_summarization_no_trainer.py \
    --model_name_or_path google-t5/t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ~/tmp/tst-summarization \

自定义数据集

摘要脚本支持自定义数据集,只要它们是 CSV 或 JSONL 文件。 当使用您自己的数据集时,您需要指定以下附加参数。

  • train_filevalidation_file 指定您的训练和验证文件的路径。
  • text_column 是要摘要的输入文本。
  • summary_column 是要输出的目标文本。

下面显示了摘要自定义数据集的示例命令。

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --text_column text_column_name \
    --summary_column summary_column_name \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
< > 在 GitHub 上更新