Amazon SageMaker 文档

在 Amazon SageMaker 上训练和部署 Hugging Face

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

在 Amazon SageMaker 上训练和部署 Hugging Face

快速入门指南将向您展示如何在 Amazon SageMaker 上快速使用 Hugging Face。了解如何在 SageMaker 上微调和部署预训练的 🤗 Transformers 模型,以完成二元文本分类任务。

💡 如果您是 Hugging Face 的新手,我们建议您首先阅读 🤗 Transformers 快速入门

📓 打开 agemaker-notebook.ipynb 文件 并跟随操作!

安装与设置

首先安装必要的 Hugging Face 库和 SageMaker。如果您尚未安装,还需要安装 PyTorchTensorFlow

pip install "sagemaker>=2.140.0" "transformers==4.26.1" "datasets[s3]==2.10.1" --upgrade

如果您想在 SageMaker Studio 中运行此示例,请为 🤗 Datasets 库升级 ipywidgets 并重启内核

%%capture
import IPython
!conda install -c conda-forge ipywidgets -y
IPython.Application.instance().kernel.do_shutdown(True)

接下来,您应该设置您的环境:SageMaker 会话和 S3 存储桶。S3 存储桶将存储数据、模型和日志。您将需要访问具有所需权限的 IAM 执行角色

如果您计划在本地环境中使用 SageMaker,则需要自己提供 role。了解更多关于如何进行设置的信息 此处

⚠️ 执行角色仅在您在 SageMaker 内运行笔记本时可用。如果您尝试在非 SageMaker 的笔记本中运行 get_execution_role,您将收到区域错误。

import sagemaker

sess = sagemaker.Session()
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

预处理

🤗 Datasets 库使下载和预处理数据集以进行训练变得容易。下载并标记 IMDb 数据集

from datasets import load_dataset
from transformers import AutoTokenizer

# load dataset
train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

# create tokenization function
def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True)

# tokenize train and test datasets
train_dataset = train_dataset.map(tokenize, batched=True)
test_dataset = test_dataset.map(tokenize, batched=True)

# set dataset format for PyTorch
train_dataset =  train_dataset.rename_column("label", "labels")
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
test_dataset = test_dataset.rename_column("label", "labels")
test_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

将数据集上传到 S3 存储桶

接下来,使用 🤗 Datasets S3 文件系统 实现将预处理的数据集上传到您的 S3 会话存储桶

# save train_dataset to s3
training_input_path = f's3://{sess.default_bucket()}/{s3_prefix}/train'
train_dataset.save_to_disk(training_input_path)

# save test_dataset to s3
test_input_path = f's3://{sess.default_bucket()}/{s3_prefix}/test'
test_dataset.save_to_disk(test_input_path)

启动训练作业

创建一个 Hugging Face Estimator 来处理端到端的 SageMaker 训练和部署。需要注意的最重要参数是

  • entry_point 指的是微调脚本,您可以在 train.py 文件中找到它。
  • instance_type 指的是将启动的 SageMaker 实例。请查看 此处 以获取实例类型的完整列表。
  • hyperparameters 指的是模型将使用其进行微调的训练超参数。
from sagemaker.huggingface import HuggingFace

hyperparameters={
    "epochs": 1,                                       # number of training epochs
    "train_batch_size": 32,                            # training batch size
    "model_name":"distilbert/distilbert-base-uncased"  # name of pretrained model
}

huggingface_estimator = HuggingFace(
    entry_point="train.py",                 # fine-tuning script to use in training job
    source_dir="./scripts",                 # directory where fine-tuning script is stored
    instance_type="ml.p3.2xlarge",          # instance type
    instance_count=1,                       # number of instances
    role=role,                              # IAM role used in training job to acccess AWS resources (S3)
    transformers_version="4.26",             # Transformers version
    pytorch_version="1.13",                  # PyTorch version
    py_version="py39",                      # Python version
    hyperparameters=hyperparameters         # hyperparameters to use in training job
)

用一行代码开始训练

huggingface_estimator.fit({"train": training_input_path, "test": test_input_path})

部署模型

训练作业完成后,通过调用 deploy() 以及实例数量和实例类型来部署微调后的模型

predictor = huggingface_estimator.deploy(initial_instance_count=1,"ml.g4dn.xlarge")

在您的数据上调用 predict()

sentiment_input = {"inputs": "It feels like a curtain closing...there was an elegance in the way they moved toward conclusion. No fan is going to watch and feel short-changed."}

predictor.predict(sentiment_input)

运行请求后,删除端点

predictor.delete_endpoint()

下一步是什么?

恭喜您,您刚刚在 SageMaker 上微调和部署了一个预训练的 🤗 Transformers 模型! 🎉

对于您的后续步骤,请继续阅读我们的文档,了解有关训练和部署的更多详细信息。还有许多有趣的功能,例如 分布式训练Spot 实例

< > 在 GitHub 上更新