Transformers 文档
训练脚本
并获得增强的文档体验
开始使用
训练脚本
Transformers 为深度学习框架(PyTorch、TensorFlow、Flax)和 transformers/examples 中的任务提供了许多示例训练脚本。 transformers/research projects 和 transformers/legacy 中还有其他脚本,但这些脚本未得到积极维护,并且需要特定版本的 Transformers。
示例脚本仅为示例,您可能需要根据您的用例调整脚本。为了帮助您做到这一点,大多数脚本在数据预处理方面都非常透明,允许您根据需要对其进行编辑。
对于您想在示例脚本中实现的任何功能,请在提交拉取请求之前在论坛或 issue 中讨论它。虽然我们欢迎贡献,但以可读性为代价添加更多功能的拉取请求不太可能被添加。
本指南将向您展示如何在 PyTorch 和 TensorFlow 中运行示例摘要训练脚本。
设置
从源代码在新虚拟环境中安装 Transformers,以运行最新版本的示例脚本。
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
运行以下命令以从特定或较旧版本的 Transformers 检出脚本。
git checkout tags/v3.5.1
在您设置了正确的版本后,导航到您选择的示例文件夹并安装示例特定的要求。
pip install -r requirements.txt
运行脚本
从较小的数据集开始,包括 max_train_samples
、max_eval_samples
和 max_predict_samples
参数,以将数据集截断为最大样本数。这有助于确保训练按预期工作,然后再提交到可能需要数小时才能完成的整个数据集。
并非所有示例脚本都支持 max_predict_samples
参数。 运行以下命令以检查脚本是否支持它。
examples/pytorch/summarization/run_summarization.py -h
以下示例在 CNN/DailyMail 数据集上微调 T5-small。 T5 需要额外的 source_prefix
参数来提示其进行摘要。
示例脚本下载并预处理数据集,然后使用 Trainer 和受支持的模型架构对其进行微调。
如果训练中断,从检查点恢复训练非常有用,因为您不必重新开始。 有两种方法可以从检查点恢复训练。
--output dir previous_output_dir
从output_dir
中存储的最新检查点恢复训练。 如果您使用此方法,请删除--overwrite_output_dir
参数。--resume_from_checkpoint path_to_specific_checkpoint
从特定的检查点文件夹恢复训练。
使用 --push_to_hub
参数在 Hub 上共享您的模型。 它创建一个存储库并将模型上传到 --output_dir
中指定的文件夹名称。 您也可以使用 --push_to_hub_model_id
参数来指定存储库名称。
python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path google-t5/t5-small \
# remove the `max_train_samples`, `max_eval_samples` and `max_predict_samples` if everything works
--max_train_samples 50 \
--max_eval_samples 50 \
--max_predict_samples 50 \
--do_train \
--do_eval \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--push_to_hub \
--push_to_hub_model_id finetuned-t5-cnn_dailymail \
# remove if using `output_dir previous_output_dir`
# --overwrite_output_dir \
--output_dir previous_output_dir \
# --resume_from_checkpoint path_to_specific_checkpoint \
--predict_with_generate \
对于混合精度和分布式训练,请包含以下参数并使用 torchrun 启动训练。
- 添加
fp16
或bf16
参数以启用混合精度训练。 XPU 设备仅支持bf16
。 - 添加
nproc_per_node
参数以设置用于训练的 GPU 数量。
torchrun \ --nproc_per_node 8 pytorch/summarization/run_summarization.py \ --fp16 \ ... ...
PyTorch 通过 PyTorch/XLA 包支持 TPU,这是一种旨在加速性能的硬件。 启动 xla_spawn.py
脚本并使用 num _cores
设置要用于训练的 TPU 核心数。
python xla_spawn.py --num_cores 8 pytorch/summarization/run_summarization.py \ --model_name_or_path google-t5/t5-small \ ... ...
Accelerate
Accelerate 旨在简化分布式训练,同时提供 PyTorch 训练循环的完整可见性。 如果您计划使用带有 Accelerate 的脚本进行训练,请使用 _no_trainer.py
版本的脚本。
从源代码安装 Accelerate 以确保您拥有最新版本。
pip install git+https://github.com/huggingface/accelerate
运行 accelerate config 命令,回答有关您的训练设置的几个问题。 这将创建并保存有关您系统的配置文件。
accelerate config
您可以使用 accelerate test 来确保您的系统配置正确。
accelerate test
运行 accelerate launch 以开始训练。
accelerate launch run_summarization_no_trainer.py \
--model_name_or_path google-t5/t5-small \
--dataset_name cnn_dailymail \
--dataset_config "3.0.0" \
--source_prefix "summarize: " \
--output_dir ~/tmp/tst-summarization \
自定义数据集
摘要脚本支持自定义数据集,只要它们是 CSV 或 JSONL 文件。 当使用您自己的数据集时,您需要指定以下附加参数。
train_file
和validation_file
指定您的训练和验证文件的路径。text_column
是要摘要的输入文本。summary_column
是要输出的目标文本。
下面显示了摘要自定义数据集的示例命令。
python examples/pytorch/summarization/run_summarization.py \
--model_name_or_path google-t5/t5-small \
--do_train \
--do_eval \
--train_file path_to_csv_or_jsonlines_file \
--validation_file path_to_csv_or_jsonlines_file \
--text_column text_column_name \
--summary_column summary_column_name \
--source_prefix "summarize: " \
--output_dir /tmp/tst-summarization \
--overwrite_output_dir \
--per_device_train_batch_size=4 \
--per_device_eval_batch_size=4 \
--predict_with_generate \