AWS Trainium & Inferentia 文档
为训练贡献自定义模型
并获得增强的文档体验
开始使用
为训练贡献自定义模型
本指南介绍了如何向 `optimum/neuron/models/training/` 目录中添加自定义模型实现。需要自定义模型来支持 AWS Trainium 设备上的分布式训练功能,例如张量并行、流水线并行和序列并行。
架构组件
1. NeuronModelMixin
`NeuronModelMixin` 类提供了核心功能:
- `from_pretrained()`:将常规的 Transformers 权重加载到自定义实现中
- `save_pretrained()`:保存带有合并元数据的分片检查点
- 通过 `PIPELINE_*` 属性支持流水线并行
2. 权重转换规范
转换规范处理权重在以下格式之间的转换:
- 原始 Transformers 格式 → 自定义并行格式(加载期间)
- 自定义并行格式 → 原始 Transformers 格式(检查点合并期间)
关键的转换规范类型:
- `FusedLinearsSpec`:处理融合的线性层(例如 `gate_up_proj`)
- `GQAQKVColumnParallelLinearSpec`:处理张量并行大小大于键值头数量时的分组查询注意力投影
有关所有转换规范和实用函数的完整 API 文档,请参阅模型权重转换规范 API 参考。
3. 并行层
使用 `neuronx_distributed` 中的这些并行层:
- `ColumnParallelLinear`:沿输出维度拆分权重矩阵
- `RowParallelLinear`:沿输入维度拆分权重矩阵
- `ParallelEmbedding`:在不同 rank 之间拆分嵌入表
- `GQAQKVColumnParallelLinear`:专门用于张量并行大小大于键值头数量时的分组查询注意力投影
实现步骤
步骤 1:创建模型结构
创建一个新目录:`optimum/neuron/models/training/your_model/`
__init__.py
from .modeling_your_model import YourModelForCausalLM, YourModel
__all__ = ["YourModelForCausalLM", "YourModel"]
步骤 2:实现模型构建块
modeling_your_model.py
导入和依赖
import torch
from torch import nn
from neuronx_distributed.parallel_layers.layers import (
ColumnParallelLinear,
RowParallelLinear,
ParallelEmbedding,
)
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from transformers import PreTrainedModel
from transformers.models.your_model import YourModelConfig
from ..config import TrainingNeuronConfig
from ..modeling_utils import NeuronModelMixin
from ..transformations_utils import (
CustomModule,
FusedLinearsSpec,
GQAQKVColumnParallelLinearSpec,
ModelWeightTransformationSpecs,
)
嵌入层
class YourModelEmbeddings(nn.Module):
def __init__(self, config, trn_config):
super().__init__()
self.embed_tokens = ParallelEmbedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
)
带融合线性层的 MLP 层
重要提示:任何具有转换规范的模块都必须继承自 `CustomModule`,以确保正确处理权重转换,并且转换规范必须定义在 `self.specs` 属性中。
class YourModelMLP(nn.Module, CustomModule):
def __init__(self, config, trn_config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# Fused gate and up projections
self.gate_up_proj = ColumnParallelLinear(
self.hidden_size,
2 * self.intermediate_size,
stride=2, # Important for proper sharding
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
input_is_parallel=True,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
# Define transformation specs
self.specs = ModelWeightTransformationSpecs()
self.specs.add_spec(
FusedLinearsSpec(
fused_linear_name="gate_up_proj",
linear_names=["gate_proj", "up_proj"],
bias=False,
fuse_axis="column", # Fuse along output dimension
original_dims=[self.intermediate_size, self.intermediate_size],
)
)
注意力层
注意力层的实现取决于模型的架构和张量并行配置。主要有三种变体:
1. 分离的 Q、K、V 投影(默认)
class YourModelAttention(nn.Module, CustomModule):
def __init__(self, config, trn_config, layer_idx):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // self.num_heads
# Separate projections for Q, K, V
self.q_proj = ColumnParallelLinear(
config.hidden_size,
self.num_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.k_proj = ColumnParallelLinear(
config.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.v_proj = ColumnParallelLinear(
config.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
config.hidden_size,
bias=False,
input_is_parallel=True,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
# No transformation specs needed - regular parallel layers
self.specs = ModelWeightTransformationSpecs()
2. 融合的 QKV 投影(多头注意力)
class YourModelAttention(nn.Module, CustomModule):
def __init__(self, config, trn_config, layer_idx):
super().__init__()
# ... (same setup as above)
tp_size = get_tensor_model_parallel_size()
# Only use fused QKV when num_heads == num_key_value_heads (no GQA)
if trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
self.qkv_proj = ColumnParallelLinear(
config.hidden_size,
3 * self.num_heads * self.head_dim, # Q + K + V
stride=3, # Important for proper sharding
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
dtype=config.torch_dtype,
)
# Define transformation specs for fused QKV
self.specs = ModelWeightTransformationSpecs()
self.specs.add_spec(
FusedLinearsSpec(
fused_linear_name="qkv_proj",
linear_names=["q_proj", "k_proj", "v_proj"],
bias=False,
fuse_axis="column",
original_dims=[self.num_heads * self.head_dim] * 3,
)
)
self.split_size = self.num_heads * self.head_dim // tp_size
3. GQA QKV 投影(用于具有挑战性的 TP 配置)
class YourModelAttention(nn.Module, CustomModule):
def __init__(self, config, trn_config, layer_idx):
super().__init__()
# ... (same setup as above)
tp_size = get_tensor_model_parallel_size()
# Use GQA QKV when KV heads can't be evenly distributed across TP ranks
# This happens when: num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0
self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
if self.qkv_linear:
# Calculate KV size multiplier to ensure even distribution
if trn_config.kv_size_multiplier is None:
self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
else:
self.kv_size_multiplier = trn_config.kv_size_multiplier
self.qkv_proj = GQAQKVColumnParallelLinear(
config.hidden_size,
[self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
bias=False,
gather_output=False,
sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
kv_size_multiplier=self.kv_size_multiplier,
fuse_qkv=trn_config.fuse_qkv,
dtype=config.torch_dtype,
)
# Define transformation specs for GQA QKV
self.specs = ModelWeightTransformationSpecs()
self.specs.add_spec(
GQAQKVColumnParallelLinearSpec(
gqa_qkv_projection_name="qkv_proj",
query_projection_name="q_proj",
key_projection_name="k_proj",
value_projection_name="v_proj",
output_projection_name="o_proj",
num_attention_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
kv_size_multiplier=self.kv_size_multiplier,
q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
fuse_qkv=trn_config.fuse_qkv,
)
)
何时使用每种变体
- 分离的 Q、K、V:默认方法,适用于所有配置,但效率可能较低
- 融合的 QKV:当 `num_heads == num_key_value_heads`(没有分组查询注意力)且 `fuse_qkv=True` 时使用
- GQA QKV:在使用分组查询注意力且 KV 头无法在 TP rank 之间均匀分布的具有挑战性的张量并行配置时需要
选择通常取决于:
tp_size = get_tensor_model_parallel_size()
use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
use_fused_qkv = trn_config.fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv
步骤 3:实现主模型类
基础模型
class YourPreTrainedModel(PreTrainedModel, NeuronModelMixin):
config_class = YourModelConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["YourModelDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
class YourModel(NeuronModelMixin, YourPreTrainedModel):
def __init__(self, config: YourModelConfig, trn_config: TrainingNeuronConfig):
YourPreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.trn_config = trn_config
self.embed_tokens = ParallelEmbedding(...)
self.layers = nn.ModuleList([
YourModelDecoderLayer(config, trn_config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = YourModelRMSNorm(...)
self.post_init()
CausalLM 模型
class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
# Pipeline parallelism support
SUPPORTS_PIPELINE_PARALLELISM = True
PIPELINE_TRANSFORMER_LAYER_CLS = YourModelDecoderLayer
PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask"]
def __init__(self, config, trn_config):
super().__init__(config)
self.trn_config = trn_config
self.model = YourModel(config, trn_config)
self.vocab_size = config.vocab_size
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
dtype=config.torch_dtype,
)
self.post_init()
步骤 4:注册模型
更新 `optimum/neuron/models/training/__init__.py`
from .your_model import YourModelForCausalLM, YourModel
__all__ = [..., "YourModelForCausalLM", "YourModel"]
更新 `optimum/neuron/models/training/auto_models.py`
from .your_model.modeling_your_model import YourModelForCausalLM, YourModel
# Register the base model (without head)
register_neuron_model_for_training("your_model", "model")(YourModel)
# Register the CausalLM model
register_neuron_model_for_training("your_model", "text-generation")(YourModelForCausalLM)
这里的 `"your_model"` 对应于模型配置类中的 `model_type` 属性。
最佳实践
1. 并行层配置
- 对中间层使用 `gather_output=False`
- 对接收并行输入的层设置 `input_is_parallel=True`
- 在所有层中一致地配置 `sequence_parallel_enabled`
- 使用适当的 `stride` 值以实现正确的权重分片
2. 权重转换规范
- 始终为使用融合或并行层的模块定义规范
- 对于任何具有转换规范的模块,使用 `CustomModule` mixin
- 确保规范参数名称与实际模块结构匹配
- 测试常规权重和 LoRA 权重转换
3. 流水线并行
- 对于支持的模型,设置 `SUPPORTS_PIPELINE_PARALLELISM = True`
- 将 `PIPELINE_TRANSFORMER_LAYER_CLS` 定义为你的解码器层类
- 在 `PIPELINE_INPUT_NAMES` 中列出所有输入名称
4. Flash Attention 支持
- 如果你的模型支持,设置 `_supports_flash_attn_2 = True`
- 实现 eager 和 flash attention 两种路径
- 使用适当的注意力函数分派
测试你的实现
`tests/training/` 中的训练测试提供了一个全面的测试框架,用于验证数值正确性、分布式训练场景和检查点兼容性。大多数测试并非为每个自定义模型实现而设计,而是为了验证 Optimum Neuron 训练基础设施的核心功能。考虑到这一点,以下是你需要为你的自定义模型实现的内容:
1. 自定义模型验证
`test_custom_modeling.py` 文件验证你的自定义实现是否与原始 Transformers 模型产生相同的输出。
更新 `tests/training/test_custom_modeling.py`
CUSTOM_MODELINGS_TO_TEST = [
# ... existing models ...
("YourModelForCausalLM", "your-org/your-model-name"),
]
重要提示:对于自定义模型验证测试,请使用小型/微型模型以确保 CI 效率。测试模型应具有:
- 小词汇量(例如,1000-8000 个词元)
- 少数层(例如,2-4 层)
- 小隐藏维度(例如,128-512)
- 最少的注意力头(例如,4-8 个头)
适合自定义模型验证的测试模型示例:
- `"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random"` - 4 层,4 个 KV 头
- `"michaelbenayoun/granite-tiny-4kv-heads-4layers-random"` - 微型 Granite 模型
- `"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"` - 微型 Qwen3 模型
你的模型必须通过的关键测试:
def test_custom_modeling_matches_original() # Output matching
- 数值正确性:确保自定义模型与 Transformers 的输出完全匹配
- 并行化支持:测试各种 QKV 实现(常规、融合、GQA)
2. 端到端训练验证
`test_overfit.py` 文件验证训练收敛性。要将你的模型包含在端到端训练验证中,你必须将其添加到参数化测试用例中。
更新 `tests/training/test_overfit.py`
@pytest.mark.parametrize(
"model_class_name,model_name_or_path,learning_rate,warmup_ratio,training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps",
[
# ... existing models ...
[
"YourModelForCausalLM",
"your-org/your-model-name",
1e-4,
0.03,
{},
True,
0.5,
2048,
50,
],
],
ids=[
# ... existing model IDs ...
"your-org/your-model-name",
],
)
此测试验证:
- 收敛性验证:确保模型可以在简单数据集上过拟合
你的模型将在以下方面进行测试:
def test_overfit_custom_modeling_causal_lm() # Basic training (your model included)
3. 自动模型加载
`test_modeling_auto.py` 文件验证你的模型可以使用 `NeuronModel` 和 `NeuronModelForCausalLM` 自动类加载。要将你的模型包含在这些测试中,你必须将其添加到测试用例中。
更新 `tests/training/test_modeling_auto.py`
@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
trn_config = TrainingNeuronConfig()
kwargs = {"torch_dtype": torch.bfloat16}
for model_name_or_path in [
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
"michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
"your-org/your-model-name", # Add your model here
]:
# ... rest of test logic
@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
trn_config = TrainingNeuronConfig()
kwargs = {"torch_dtype": torch.bfloat16}
for model_name_or_path in [
"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
"michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
"your-org/your-model-name", # Add your model here
]:
# ... rest of test logic
此测试验证:
- 自动模型加载:测试 `NeuronModel.from_pretrained()` 和 `NeuronModel.from_config()` 是否正常工作
- 自动 CausalLM 加载:测试 `NeuronModelForCausalLM.from_pretrained()` 和 `NeuronModelForCausalLM.from_config()` 是否正常工作
4. 运行测试
测试需要 AWS Trainium 实例。运行特定的测试类别:
# Run all custom modeling tests
pytest tests/training/test_custom_modeling.py -v
# Run specific model tests
pytest tests/training/test_custom_modeling.py -v -k "your_model"
# Run end-to-end training validation
pytest tests/training/test_overfit.py -v
5. 测试要求
你的实现必须:
- 通过数值正确性测试,与原始 Transformers 实现对比
- 支持并行化策略(至少支持 DP 和 TP;推荐支持 PP)
- 处理各种 QKV 实现(常规、融合、GQA)
- 支持分布式训练的检查点合并
- 支持 LoRA 训练(如果适用)
- 通过过拟合测试证明收敛性
测试框架确保你的自定义模型与现有的 Optimum Neuron 训练基础设施保持兼容,同时提供预期的性能和正确性保证。
常见问题
- 权重形状不匹配:确保转换规范正确处理张量形状
- 流水线并行错误:检查所有必需的属性是否已设置
- 内存问题:考虑梯度检查点和激活重计算
- 注意力兼容性:验证注意力实现是否与你的模型架构兼容
其他资源
本指南为实现自定义模型提供了基础。有关完整示例和高级模式,请参考以下现有实现:
- LLaMA: `optimum/neuron/models/training/llama/modeling_llama.py` - 包含(常规、融合和 GQA 注意力)、融合 MLP 的完整实现
- Qwen3: `optimum/neuron/models/training/qwen3/modeling_qwen3.py` - 演示了如何调整 Llama 实现以适应 Qwen3 的 `q_norm` 和 `k_norm` 层
需要研究的关键文件:
- `optimum/neuron/models/training/modeling_utils.py` - `NeuronModelMixin` 基类
- `optimum/neuron/models/training/transformations_utils.py` - 权重转换规范
- `optimum/neuron/models/training/config.py` - 用于并行设置的 `TrainingNeuronConfig`