optimum-tpu 文档

首次在 Google Cloud 上进行 TPU 训练

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

首次在 Google Cloud 上进行 TPU 训练

本教程将引导您完成使用 optimum-tpu 包在 TPU 上设置和运行模型训练的过程。

前提条件

在开始之前,请确保您有一个正在运行的 TPU 实例(请参阅TPU 设置指南

环境设置

首先,创建并激活一个虚拟环境

python -m venv .venv
source .venv/bin/activate

安装所需的软件包

# Install optimum-tpu with PyTorch/XLA support
pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html

# Install additional training dependencies
pip install transformers datasets accelerate trl peft evaluate

了解用于 TPU 训练的 FSDP

为了加速您在 TPU 上的训练,您可以依赖 Optimum TPU 与 FSDP(完全分片数据并行)的集成。在训练大型模型时,FSDP 会自动将您的模型分片(拆分)到所有可用的 TPU worker 上,从而提供以下几个关键优势

  1. 内存效率:每个 TPU worker 仅存储模型参数的一部分,从而降低了每个设备的内存需求
  2. 自动缩放:FSDP 处理分发模型和聚合梯度的复杂性
  3. 性能优化:Optimum TPU 的实现专门针对 TPU 硬件进行了调整

当您在训练设置中使用 fsdp_v2.get_fsdp_training_args(model) 配置时,这种分片会自动发生,从而可以轻松训练更大的模型,而这些模型无法在单个 TPU 设备上运行。

如何设置 FSDP

启用 FSDP 的关键修改仅仅是这几行代码

+from optimum.tpu import fsdp_v2
+fsdp_v2.use_fsdp_v2()
+fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

然后将这些参数包含在您的 trainer 配置中

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        ...
+       dataloader_drop_last=True,  # Required for FSDPv2
+       **fsdp_training_args,
    ),
    ...
)

完整示例

这是一个完整的可运行示例,演示了使用 FSDP 进行 TPU 训练

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from optimum.tpu import fsdp_v2

# Enable FSDPv2 for TPU
fsdp_v2.use_fsdp_v2()

# Load model and dataset
model_id = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1000]")

# Get FSDP training arguments
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

# Create trainer with minimal configuration
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        output_dir="./output",
        dataloader_drop_last=True,  # Required for FSDPv2
        **fsdp_training_args,
    ),
    peft_config=LoraConfig(
        r=8,
        target_modules=["k_proj", "v_proj"],
        task_type="CAUSAL_LM",
    ),
)

# Start training
trainer.train()

将此代码保存为 train.py 并运行它

python train.py

您现在应该看到损失在训练期间减少。当训练完成时,您将拥有一个微调模型。恭喜 - 您刚刚在 TPU 上训练了您的第一个模型!🙌

下一步

通过探索以下内容,继续您的 TPU 训练之旅