AWS Trainium & Inferentia 文档

🚀 使用 LoRA 微调 Qwen3 8B

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

🚀 使用 LoRA 微调 Qwen3 8B

本教程展示了如何使用 optimum-neuron 在 AWS Trainium 加速器上微调 Qwen3 模型。

本教程基于 Qwen3 微调示例脚本

1. 🛠️ 设置 AWS 环境

我们将使用一个拥有 16 个 Trainium 加速器(32 个 Neuron 核心)的 `trn1.32xlarge` 实例以及 Hugging Face Neuron 深度学习 AMI。

Hugging Face AMI 包含了所有预装的必需库

  • datasets, transformers, optimum-neuron
  • Neuron SDK 包
  • 无需额外环境设置

要创建您的实例,请遵循此处的指南。

2. 📊 加载并准备数据集

我们将使用 simple recipes 数据集来微调我们的模型以生成食谱。

{
    'recipes': "- Preheat oven to 350 degrees\n- Butter two 9x5' loaf pans\n- Cream the sugar and the butter until light and whipped\n- Add the bananas, eggs, lemon juice, orange rind\n- Beat until blended uniformly\n- Be patient, and beat until the banana lumps are gone\n- Sift the dry ingredients together\n- Fold lightly and thoroughly into the banana mixture\n- Pour the batter into prepared loaf pans\n- Bake for 45 to 55 minutes, until the loaves are firm in the middle and the edges begin to pull away from the pans\n- Cool the loaves on racks for 30 minutes before removing from the pans\n- Freezes well",
    'names': 'Beat this banana bread'
}

我们使用 `datasets` 库中的 `load_dataset()` 方法来加载数据集。

from random import randrange

from datasets import load_dataset


# Load dataset from the hub
dataset_id = "tengomucho/simple_recipes"
recipes = load_dataset(dataset_id, split="train")

dataset_size = len(recipes)
print(f"dataset size: {dataset_size}")
print(recipes[randrange(dataset_size)])
# dataset size: 20000

为了调整我们的模型,我们需要将结构化的示例转换为带有给定上下文的引用集合,因此我们定义了分词函数,以便可以将其映射到数据集上。

数据集应以输入-输出对的形式构建,其中每个输入是提示,输出是模型的预期响应。我们将利用模型的分词器聊天模板并对数据集进行预处理,以便输入到训练器中。

# Preprocesses the dataset
def preprocess_dataset_with_eos(eos_token):
    def preprocess_function(examples):
        recipes = examples["recipes"]
        names = examples["names"]

        chats = []
        for recipe, name in zip(recipes, names):
            # Append the EOS token to the response
            recipe += eos_token

            chat = [
                {"role": "user", "content": f"How can I make {name}?"},
                {"role": "assistant", "content": recipe},
            ]

            chats.append(chat)
        return {"messages": chats}

    dataset = recipes.map(preprocess_function, batched=True, remove_columns=recipes.column_names)
    return dataset

# Structures the dataset into prompt-expected output pairs.
def formatting_function(examples):
    return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)

注意:这些函数引用了 `eos_token` 和 `tokenizer`,它们在用于运行本教程的 Python 脚本中已经明确定义。

3. 🎯 使用 NeuronSFTTrainer 和 PEFT 微调 Qwen3

对于标准的 PyTorch 微调,您通常会使用带 LoRA 适配器的 PEFT`SFTTrainer`

在 AWS Trainium 上,`optimum-neuron` 提供了 `NeuronSFTTrainer` 作为直接替代品。

在 Trainium 上进行分布式训练: 由于 Qwen3 无法在单个加速器上运行,我们使用分布式训练技术

  • 数据并行 (DDP)
  • 张量并行
  • 流水线并行

模型加载和 LoRA 配置与其他加速器类似。

将所有部分组合在一起,并假设数据集已经加载,我们可以编写以下代码在 AWS Trainium 上微调 Qwen3

model_id = "Qwen/Qwen3-8B"

# Define the training arguments
output_dir = "qwen3-finetuned-recipes"
training_args = NeuronTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    do_train=True,
    max_steps=-1,  # -1 means train until the end of the dataset
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-4,
    bf16=True,  
    tensor_parallel_size=8,
    logging_steps=2,
    lr_scheduler_type="cosine",
    overwrite_output_dir=True,
)

# Load the model with the NeuronModelForCausalLM class.
# It will load the model with a custom modeling speficically designed for AWS Trainium.
trn_config = training_args.trn_config
dtype = torch.bfloat16 if training_args.bf16 else torch.float32
model = NeuronModelForCausalLM.from_pretrained(
    model_id,
    trn_config,
    torch_dtype=dtype,
    # Use FlashAttention2 for better performance and to be able to use larger sequence lengths.
    use_flash_attention_2=True,
)

lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=[
        "embed_tokens",
        "q_proj",
        "v_proj",
        "o_proj",
        "k_proj",
        "up_proj",
        "down_proj",
        "gate_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
)

# Converting the NeuronTrainingArguments to a dictionary to feed them to the NeuronSFTConfig.
args = training_args.to_dict()

sft_config = NeuronSFTConfig(
    max_seq_length=4096,
    packing=True,
    **args,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = preprocess_dataset_with_eos(tokenizer.eos_token)

 def formatting_function(examples):
     return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)

 # The NeuronSFTTrainer will use `formatting_function` to format the dataset and `lora_config` to apply LoRA on the
 # model.
 trainer = NeuronSFTTrainer(
     args=sft_config,
     model=model,
     peft_config=lora_config,
     tokenizer=tokenizer,
     train_dataset=dataset,
     formatting_func=formatting_function,
 )
 trainer.train()

📝 完整脚本可用: 上述所有步骤都已整合到一个即用型脚本中 finetune_qwen3.py

要启动训练,只需在您的 AWS Trainium 实例中运行以下命令

# Flags for Neuron compilation
export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation"
export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime
export MALLOC_ARENA_MAX=64 # Host OOM mitigation

# Variables for training
PROCESSES_PER_NODE=32
NUM_EPOCHS=3
TP_DEGREE=8
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=2
MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-finetuned"
DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE"
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
    MAX_STEPS=5
else
    MAX_STEPS=-1
fi

torchrun --nproc_per_node $PROCESSES_PER_NODE finetune_qwen3.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --learning_rate 8e-4 \
  --bf16 \
  --tensor_parallel_size $TP_DEGREE \
  --zero_1 \
  --async_save \
  --logging_steps $LOGGING_STEPS \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "cosine" \
  --overwrite_output_dir

🔧 单命令执行: 完整的 bash 训练脚本 finetune_qwen3.sh 可用

./finetune_qwen3.sh

4. 🔄 整合并测试微调后的模型

在分布式训练期间,Optimum Neuron 会分别保存模型分片。这些分片在使用前需要进行整合。

使用 Optimum CLI 进行整合

optimum-cli neuron consolidate Qwen3-8B-finetuned Qwen3-8B-finetuned/adapter_default

这将创建一个 `adapter_model.safetensors` 文件,即我们在上一步中训练的 LoRA 适配器权重。我们现在可以重新加载模型并将其合并,以便加载进行评估

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig


MODEL_NAME = "Qwen/Qwen3-8B"
ADAPTER_PATH = "Qwen3-8B-finetuned/adapter_default"
MERGED_MODEL_PATH = "Qwen3-8B-recipes"

# Load base model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load adapter configuration and model
adapter_config = PeftConfig.from_pretrained(ADAPTER_PATH)
finetuned_model = PeftModel.from_pretrained(model, ADAPTER_PATH, config=adapter_config)

print("Saving tokenizer")
tokenizer.save_pretrained(MERGED_MODEL_PATH)
print("Saving model")
finetuned_model = finetuned_model.merge_and_unload()
finetuned_model.save_pretrained(MERGED_MODEL_PATH)

完成此步骤后,就可以用新的提示来测试模型了。

您已成功从 Qwen3 创建了一个微调模型!

5. 🤗 推送到 Hugging Face Hub

通过将您微调的模型上传到 Hugging Face Hub,与社区分享。

第 1 步:身份验证

huggingface-cli login

第 2 步:上传您的模型

from transformers import AutoModelForCausalLM, AutoTokenizer

MERGED_MODEL_PATH = "Qwen3-8B-recipes"
HUB_MODEL_NAME = "your-username/qwen3-8b-recipes"

# Load and push tokenizer
tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH)
tokenizer.push_to_hub(HUB_MODEL_NAME)

# Load and push model
model = AutoModelForCausalLM.from_pretrained(MERGED_MODEL_PATH)
model.push_to_hub(HUB_MODEL_NAME)

🎉 您微调的 Qwen3 模型现在已在 Hub 上可供他人使用!