optimum-tpu 文档

在 Google Cloud 上首次进行 TPU 训练

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

在 Google Cloud 上首次进行 TPU 训练

本教程将引导您使用 optimum-tpu 包在 TPU 上设置和运行模型训练。

先决条件

开始之前,请确保您有一个正在运行的 TPU 实例(参见TPU 设置指南

环境设置

首先,创建并激活虚拟环境

python -m venv .venv
source .venv/bin/activate

安装所需包

# Install optimum-tpu with PyTorch/XLA support
pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html

# Install additional training dependencies
pip install transformers datasets accelerate trl peft evaluate

了解用于 TPU 训练的 FSDP

为了加快 TPU 上的训练速度,您可以依赖 Optimum TPU 与 FSDP(Fully Sharded Data Parallel,完全分片数据并行)的集成。在训练大型模型时,FSDP 会自动将模型分片(拆分)到所有可用的 TPU 工作器上,提供几个关键优势:

  1. 内存效率:每个 TPU 工作器只存储模型参数的一部分,从而降低了每个设备的内存需求。
  2. 自动扩展:FSDP 处理模型分发和梯度聚合的复杂性。
  3. 性能优化:Optimum TPU 的实现专门针对 TPU 硬件进行了调整。

当您在训练设置中使用 fsdp_v2.get_fsdp_training_args(model) 配置时,这种分片会自动发生,从而可以轻松训练无法在单个 TPU 设备上容纳的更大模型。

如何设置 FSDP

启用 FSDP 的关键修改只有这几行

+from optimum.tpu import fsdp_v2
+fsdp_v2.use_fsdp_v2()
+fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

然后将这些参数包含在您的训练器配置中

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        ...
+       dataloader_drop_last=True,  # Required for FSDPv2
+       **fsdp_training_args,
    ),
    ...
)

完整示例

这是一个演示使用 FSDP 进行 TPU 训练的完整工作示例

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from optimum.tpu import fsdp_v2

# Enable FSDPv2 for TPU
fsdp_v2.use_fsdp_v2()

# Load model and dataset
model_id = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1000]")

# Get FSDP training arguments
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

# Create trainer with minimal configuration
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        output_dir="./output",
        dataloader_drop_last=True,  # Required for FSDPv2
        **fsdp_training_args,
    ),
    peft_config=LoraConfig(
        r=8,
        target_modules=["k_proj", "v_proj"],
        task_type="CAUSAL_LM",
    ),
)

# Start training
trainer.train()

将此代码保存为 train.py 并运行

python train.py

您现在应该会看到训练期间损失值下降。训练完成后,您将获得一个微调模型。恭喜 - 您刚刚在 TPU 上训练了您的第一个模型!🙌

后续步骤

通过探索以下内容继续您的 TPU 训练之旅: