AWS Trainium & Inferentia 文档
🚀 使用 LoRA 微调 Qwen3 8B
并获得增强的文档体验
开始使用
🚀 使用 LoRA 微调 Qwen3 8B
本教程展示了如何使用 optimum-neuron 在 AWS Trainium 加速器上微调 Qwen3 模型。
本教程基于 Qwen3 微调示例脚本。
1. 🛠️ 设置 AWS 环境
我们将使用一个拥有 16 个 Trainium 加速器(32 个 Neuron 核心)的 `trn1.32xlarge` 实例以及 Hugging Face Neuron 深度学习 AMI。
Hugging Face AMI 包含了所有预装的必需库
datasets
,transformers
,optimum-neuron
- Neuron SDK 包
- 无需额外环境设置
要创建您的实例,请遵循此处的指南。
2. 📊 加载并准备数据集
我们将使用 simple recipes 数据集来微调我们的模型以生成食谱。
{
'recipes': "- Preheat oven to 350 degrees\n- Butter two 9x5' loaf pans\n- Cream the sugar and the butter until light and whipped\n- Add the bananas, eggs, lemon juice, orange rind\n- Beat until blended uniformly\n- Be patient, and beat until the banana lumps are gone\n- Sift the dry ingredients together\n- Fold lightly and thoroughly into the banana mixture\n- Pour the batter into prepared loaf pans\n- Bake for 45 to 55 minutes, until the loaves are firm in the middle and the edges begin to pull away from the pans\n- Cool the loaves on racks for 30 minutes before removing from the pans\n- Freezes well",
'names': 'Beat this banana bread'
}
我们使用 `datasets` 库中的 `load_dataset()` 方法来加载数据集。
from random import randrange
from datasets import load_dataset
# Load dataset from the hub
dataset_id = "tengomucho/simple_recipes"
recipes = load_dataset(dataset_id, split="train")
dataset_size = len(recipes)
print(f"dataset size: {dataset_size}")
print(recipes[randrange(dataset_size)])
# dataset size: 20000
为了调整我们的模型,我们需要将结构化的示例转换为带有给定上下文的引用集合,因此我们定义了分词函数,以便可以将其映射到数据集上。
数据集应以输入-输出对的形式构建,其中每个输入是提示,输出是模型的预期响应。我们将利用模型的分词器聊天模板并对数据集进行预处理,以便输入到训练器中。
# Preprocesses the dataset
def preprocess_dataset_with_eos(eos_token):
def preprocess_function(examples):
recipes = examples["recipes"]
names = examples["names"]
chats = []
for recipe, name in zip(recipes, names):
# Append the EOS token to the response
recipe += eos_token
chat = [
{"role": "user", "content": f"How can I make {name}?"},
{"role": "assistant", "content": recipe},
]
chats.append(chat)
return {"messages": chats}
dataset = recipes.map(preprocess_function, batched=True, remove_columns=recipes.column_names)
return dataset
# Structures the dataset into prompt-expected output pairs.
def formatting_function(examples):
return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)
注意:这些函数引用了 `eos_token` 和 `tokenizer`,它们在用于运行本教程的 Python 脚本中已经明确定义。
3. 🎯 使用 NeuronSFTTrainer 和 PEFT 微调 Qwen3
对于标准的 PyTorch 微调,您通常会使用带 LoRA 适配器的 PEFT 和 `SFTTrainer`。
在 AWS Trainium 上,`optimum-neuron` 提供了 `NeuronSFTTrainer` 作为直接替代品。
在 Trainium 上进行分布式训练: 由于 Qwen3 无法在单个加速器上运行,我们使用分布式训练技术
- 数据并行 (DDP)
- 张量并行
- 流水线并行
模型加载和 LoRA 配置与其他加速器类似。
将所有部分组合在一起,并假设数据集已经加载,我们可以编写以下代码在 AWS Trainium 上微调 Qwen3
model_id = "Qwen/Qwen3-8B"
# Define the training arguments
output_dir = "qwen3-finetuned-recipes"
training_args = NeuronTrainingArguments(
output_dir=output_dir,
num_train_epochs=3,
do_train=True,
max_steps=-1, # -1 means train until the end of the dataset
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
learning_rate=5e-4,
bf16=True,
tensor_parallel_size=8,
logging_steps=2,
lr_scheduler_type="cosine",
overwrite_output_dir=True,
)
# Load the model with the NeuronModelForCausalLM class.
# It will load the model with a custom modeling speficically designed for AWS Trainium.
trn_config = training_args.trn_config
dtype = torch.bfloat16 if training_args.bf16 else torch.float32
model = NeuronModelForCausalLM.from_pretrained(
model_id,
trn_config,
torch_dtype=dtype,
# Use FlashAttention2 for better performance and to be able to use larger sequence lengths.
use_flash_attention_2=True,
)
lora_config = LoraConfig(
r=64,
lora_alpha=128,
lora_dropout=0.05,
target_modules=[
"embed_tokens",
"q_proj",
"v_proj",
"o_proj",
"k_proj",
"up_proj",
"down_proj",
"gate_proj",
],
bias="none",
task_type="CAUSAL_LM",
)
# Converting the NeuronTrainingArguments to a dictionary to feed them to the NeuronSFTConfig.
args = training_args.to_dict()
sft_config = NeuronSFTConfig(
max_seq_length=4096,
packing=True,
**args,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = preprocess_dataset_with_eos(tokenizer.eos_token)
def formatting_function(examples):
return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)
# The NeuronSFTTrainer will use `formatting_function` to format the dataset and `lora_config` to apply LoRA on the
# model.
trainer = NeuronSFTTrainer(
args=sft_config,
model=model,
peft_config=lora_config,
tokenizer=tokenizer,
train_dataset=dataset,
formatting_func=formatting_function,
)
trainer.train()
📝 完整脚本可用: 上述所有步骤都已整合到一个即用型脚本中 finetune_qwen3.py。
要启动训练,只需在您的 AWS Trainium 实例中运行以下命令
# Flags for Neuron compilation
export NEURON_CC_FLAGS="--model-type transformer --retry_failed_compilation"
export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 # Async Runtime
export MALLOC_ARENA_MAX=64 # Host OOM mitigation
# Variables for training
PROCESSES_PER_NODE=32
NUM_EPOCHS=3
TP_DEGREE=8
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=2
MODEL_NAME="Qwen/Qwen3-8B" # Change this to the desired model name
OUTPUT_DIR="$(echo $MODEL_NAME | cut -d'/' -f2)-finetuned"
DISTRIBUTED_ARGS="--nproc_per_node $PROCESSES_PER_NODE"
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
MAX_STEPS=5
else
MAX_STEPS=-1
fi
torchrun --nproc_per_node $PROCESSES_PER_NODE finetune_qwen3.py \
--model_id $MODEL_NAME \
--num_train_epochs $NUM_EPOCHS \
--do_train \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BS \
--gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
--learning_rate 8e-4 \
--bf16 \
--tensor_parallel_size $TP_DEGREE \
--zero_1 \
--async_save \
--logging_steps $LOGGING_STEPS \
--output_dir $OUTPUT_DIR \
--lr_scheduler_type "cosine" \
--overwrite_output_dir
🔧 单命令执行: 完整的 bash 训练脚本 finetune_qwen3.sh 可用
./finetune_qwen3.sh
4. 🔄 整合并测试微调后的模型
在分布式训练期间,Optimum Neuron 会分别保存模型分片。这些分片在使用前需要进行整合。
使用 Optimum CLI 进行整合
optimum-cli neuron consolidate Qwen3-8B-finetuned Qwen3-8B-finetuned/adapter_default
这将创建一个 `adapter_model.safetensors` 文件,即我们在上一步中训练的 LoRA 适配器权重。我们现在可以重新加载模型并将其合并,以便加载进行评估
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
MODEL_NAME = "Qwen/Qwen3-8B"
ADAPTER_PATH = "Qwen3-8B-finetuned/adapter_default"
MERGED_MODEL_PATH = "Qwen3-8B-recipes"
# Load base model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load adapter configuration and model
adapter_config = PeftConfig.from_pretrained(ADAPTER_PATH)
finetuned_model = PeftModel.from_pretrained(model, ADAPTER_PATH, config=adapter_config)
print("Saving tokenizer")
tokenizer.save_pretrained(MERGED_MODEL_PATH)
print("Saving model")
finetuned_model = finetuned_model.merge_and_unload()
finetuned_model.save_pretrained(MERGED_MODEL_PATH)
完成此步骤后,就可以用新的提示来测试模型了。
您已成功从 Qwen3 创建了一个微调模型!
5. 🤗 推送到 Hugging Face Hub
通过将您微调的模型上传到 Hugging Face Hub,与社区分享。
第 1 步:身份验证
huggingface-cli login
第 2 步:上传您的模型
from transformers import AutoModelForCausalLM, AutoTokenizer
MERGED_MODEL_PATH = "Qwen3-8B-recipes"
HUB_MODEL_NAME = "your-username/qwen3-8b-recipes"
# Load and push tokenizer
tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH)
tokenizer.push_to_hub(HUB_MODEL_NAME)
# Load and push model
model = AutoModelForCausalLM.from_pretrained(MERGED_MODEL_PATH)
model.push_to_hub(HUB_MODEL_NAME)
🎉 您微调的 Qwen3 模型现在已在 Hub 上可供他人使用!