PyTorch 在 Apple 芯片上的训练
之前,在 Mac 上训练模型只能使用 CPU。随着 PyTorch v1.12 的发布,您可以利用 Apple 芯片的 GPU 训练模型,从而实现显著更快的性能和训练速度。这得益于 PyTorch 集成了 Apple 的 Metal 性能着色器 (MPS) 作为后端。 MPS 后端 将 PyTorch 操作实现为自定义 Metal 着色器,并将这些模块放置在 mps
设备上。
一些 PyTorch 操作尚未在 MPS 中实现,会抛出错误。为了避免这种情况,您应该设置环境变量 PYTORCH_ENABLE_MPS_FALLBACK=1
,以使用 CPU 内核(您仍然会看到 UserWarning
)。
设置 mps
设备后,您可以
- 在本地训练更大的网络或批次大小
- 减少数据检索延迟,因为 GPU 的统一内存架构允许直接访问完整的内存存储
- 降低成本,因为您不需要在云 GPU 上进行训练,也不需要添加额外的本地 GPU
首先,确保您已安装 PyTorch。MPS 加速在 macOS 12.3 及更高版本上受支持。
pip install torch torchvision torchaudio
TrainingArguments 默认情况下使用 mps
设备(如果可用),这意味着您不需要显式设置设备。例如,您可以运行 run_glue.py 脚本,MPS 后端会自动启用,无需进行任何更改。
export TASK_NAME=mrpc
python examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
- --use_mps_device \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
mps
设备不支持 gloo
和 nccl
等 分布式设置 的后端,这意味着您只能使用 MPS 后端在单个 GPU 上进行训练。
您可以在 在 Mac 上介绍加速 PyTorch 训练 博客文章中了解更多关于 MPS 后端的信息。
< > 在 GitHub 上更新