AWS Trainium & Inferentia 文档

为训练贡献自定义模型

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

为训练贡献自定义模型

本指南介绍了如何向 `optimum/neuron/models/training/` 目录中添加自定义模型实现。需要自定义模型来支持 AWS Trainium 设备上的分布式训练功能,例如张量并行、流水线并行和序列并行。

架构组件

1. NeuronModelMixin

`NeuronModelMixin` 类提供了核心功能:

  • `from_pretrained()`:将常规的 Transformers 权重加载到自定义实现中
  • `save_pretrained()`:保存带有合并元数据的分片检查点
  • 通过 `PIPELINE_*` 属性支持流水线并行

2. 权重转换规范

转换规范处理权重在以下格式之间的转换:

  • 原始 Transformers 格式 → 自定义并行格式(加载期间)
  • 自定义并行格式 → 原始 Transformers 格式(检查点合并期间)

关键的转换规范类型:

  • `FusedLinearsSpec`:处理融合的线性层(例如 `gate_up_proj`)
  • `GQAQKVColumnParallelLinearSpec`:处理张量并行大小大于键值头数量时的分组查询注意力投影

有关所有转换规范和实用函数的完整 API 文档,请参阅模型权重转换规范 API 参考

3. 并行层

使用 `neuronx_distributed` 中的这些并行层:

  • `ColumnParallelLinear`:沿输出维度拆分权重矩阵
  • `RowParallelLinear`:沿输入维度拆分权重矩阵
  • `ParallelEmbedding`:在不同 rank 之间拆分嵌入表
  • `GQAQKVColumnParallelLinear`:专门用于张量并行大小大于键值头数量时的分组查询注意力投影

实现步骤

步骤 1:创建模型结构

创建一个新目录:`optimum/neuron/models/training/your_model/`

__init__.py

from .modeling_your_model import YourModelForCausalLM, YourModel

__all__ = ["YourModelForCausalLM", "YourModel"]

步骤 2:实现模型构建块

modeling_your_model.py

导入和依赖

import torch
from torch import nn
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    ParallelEmbedding,
)
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from transformers import PreTrainedModel
from transformers.models.your_model import YourModelConfig

from ..config import TrainingNeuronConfig
from ..modeling_utils import NeuronModelMixin
from ..transformations_utils import (
    CustomModule,
    FusedLinearsSpec,
    GQAQKVColumnParallelLinearSpec,
    ModelWeightTransformationSpecs,
)

嵌入层

class YourModelEmbeddings(nn.Module):
    def __init__(self, config, trn_config):
        super().__init__()
        self.embed_tokens = ParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.torch_dtype,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
        )

带融合线性层的 MLP 层

重要提示:任何具有转换规范的模块都必须继承自 `CustomModule`,以确保正确处理权重转换,并且转换规范必须定义在 `self.specs` 属性中。

class YourModelMLP(nn.Module, CustomModule):
    def __init__(self, config, trn_config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        
        # Fused gate and up projections
        self.gate_up_proj = ColumnParallelLinear(
            self.hidden_size,
            2 * self.intermediate_size,
            stride=2,  # Important for proper sharding
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        # Define transformation specs
        self.specs = ModelWeightTransformationSpecs()
        self.specs.add_spec(
            FusedLinearsSpec(
                fused_linear_name="gate_up_proj",
                linear_names=["gate_proj", "up_proj"],
                bias=False,
                fuse_axis="column",  # Fuse along output dimension
                original_dims=[self.intermediate_size, self.intermediate_size],
            )
        )

注意力层

注意力层的实现取决于模型的架构和张量并行配置。主要有三种变体:

1. 分离的 Q、K、V 投影(默认)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        
        # Separate projections for Q, K, V
        self.q_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        self.k_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        self.v_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.torch_dtype,
        )
        
        # No transformation specs needed - regular parallel layers
        self.specs = ModelWeightTransformationSpecs()

2. 融合的 QKV 投影(多头注意力)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Only use fused QKV when num_heads == num_key_value_heads (no GQA)
        if trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
            self.qkv_proj = ColumnParallelLinear(
                config.hidden_size,
                3 * self.num_heads * self.head_dim,  # Q + K + V
                stride=3,  # Important for proper sharding
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                dtype=config.torch_dtype,
            )
            
            # Define transformation specs for fused QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                FusedLinearsSpec(
                    fused_linear_name="qkv_proj",
                    linear_names=["q_proj", "k_proj", "v_proj"],
                    bias=False,
                    fuse_axis="column",
                    original_dims=[self.num_heads * self.head_dim] * 3,
                )
            )
            self.split_size = self.num_heads * self.head_dim // tp_size

3. GQA QKV 投影(用于具有挑战性的 TP 配置)

class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Use GQA QKV when KV heads can't be evenly distributed across TP ranks
        # This happens when: num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0
        self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
        
        if self.qkv_linear:
            # Calculate KV size multiplier to ensure even distribution
            if trn_config.kv_size_multiplier is None:
                self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
            else:
                self.kv_size_multiplier = trn_config.kv_size_multiplier
                
            self.qkv_proj = GQAQKVColumnParallelLinear(
                config.hidden_size,
                [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                kv_size_multiplier=self.kv_size_multiplier,
                fuse_qkv=trn_config.fuse_qkv,
                dtype=config.torch_dtype,
            )
            
            # Define transformation specs for GQA QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                GQAQKVColumnParallelLinearSpec(
                    gqa_qkv_projection_name="qkv_proj",
                    query_projection_name="q_proj",
                    key_projection_name="k_proj", 
                    value_projection_name="v_proj",
                    output_projection_name="o_proj",
                    num_attention_heads=self.num_heads,
                    num_key_value_heads=self.num_key_value_heads,
                    kv_size_multiplier=self.kv_size_multiplier,
                    q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
                    kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
                    fuse_qkv=trn_config.fuse_qkv,
                )
            )

何时使用每种变体

  • 分离的 Q、K、V:默认方法,适用于所有配置,但效率可能较低
  • 融合的 QKV:当 `num_heads == num_key_value_heads`(没有分组查询注意力)且 `fuse_qkv=True` 时使用
  • GQA QKV:在使用分组查询注意力且 KV 头无法在 TP rank 之间均匀分布的具有挑战性的张量并行配置时需要

选择通常取决于:

tp_size = get_tensor_model_parallel_size()
use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
use_fused_qkv = trn_config.fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv

步骤 3:实现主模型类

基础模型

class YourPreTrainedModel(PreTrainedModel, NeuronModelMixin):
    config_class = YourModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YourModelDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True


class YourModel(NeuronModelMixin, YourPreTrainedModel):
    def __init__(self, config: YourModelConfig, trn_config: TrainingNeuronConfig):
        YourPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.trn_config = trn_config
        
        self.embed_tokens = ParallelEmbedding(...)
        self.layers = nn.ModuleList([
            YourModelDecoderLayer(config, trn_config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = YourModelRMSNorm(...)
        
        self.post_init()

CausalLM 模型

class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    
    # Pipeline parallelism support
    SUPPORTS_PIPELINE_PARALLELISM = True
    PIPELINE_TRANSFORMER_LAYER_CLS = YourModelDecoderLayer
    PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask"]
    
    def __init__(self, config, trn_config):
        super().__init__(config)
        self.trn_config = trn_config
        self.model = YourModel(config, trn_config)
        self.vocab_size = config.vocab_size
        
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
            dtype=config.torch_dtype,
        )
        
        self.post_init()

步骤 4:注册模型

更新 `optimum/neuron/models/training/__init__.py`

from .your_model import YourModelForCausalLM, YourModel

__all__ = [..., "YourModelForCausalLM", "YourModel"]

更新 `optimum/neuron/models/training/auto_models.py`

from .your_model.modeling_your_model import YourModelForCausalLM, YourModel

# Register the base model (without head)
register_neuron_model_for_training("your_model", "model")(YourModel)

# Register the CausalLM model
register_neuron_model_for_training("your_model", "text-generation")(YourModelForCausalLM)

这里的 `"your_model"` 对应于模型配置类中的 `model_type` 属性。

最佳实践

1. 并行层配置

  • 对中间层使用 `gather_output=False`
  • 对接收并行输入的层设置 `input_is_parallel=True`
  • 在所有层中一致地配置 `sequence_parallel_enabled`
  • 使用适当的 `stride` 值以实现正确的权重分片

2. 权重转换规范

  • 始终为使用融合或并行层的模块定义规范
  • 对于任何具有转换规范的模块,使用 `CustomModule` mixin
  • 确保规范参数名称与实际模块结构匹配
  • 测试常规权重和 LoRA 权重转换

3. 流水线并行

  • 对于支持的模型,设置 `SUPPORTS_PIPELINE_PARALLELISM = True`
  • 将 `PIPELINE_TRANSFORMER_LAYER_CLS` 定义为你的解码器层类
  • 在 `PIPELINE_INPUT_NAMES` 中列出所有输入名称

4. Flash Attention 支持

  • 如果你的模型支持,设置 `_supports_flash_attn_2 = True`
  • 实现 eager 和 flash attention 两种路径
  • 使用适当的注意力函数分派

测试你的实现

`tests/training/` 中的训练测试提供了一个全面的测试框架,用于验证数值正确性、分布式训练场景和检查点兼容性。大多数测试并非为每个自定义模型实现而设计,而是为了验证 Optimum Neuron 训练基础设施的核心功能。考虑到这一点,以下是你需要为你的自定义模型实现的内容:

1. 自定义模型验证

`test_custom_modeling.py` 文件验证你的自定义实现是否与原始 Transformers 模型产生相同的输出。

更新 `tests/training/test_custom_modeling.py`

CUSTOM_MODELINGS_TO_TEST = [
    # ... existing models ...
    ("YourModelForCausalLM", "your-org/your-model-name"),
]

重要提示:对于自定义模型验证测试,请使用小型/微型模型以确保 CI 效率。测试模型应具有:

  • 小词汇量(例如,1000-8000 个词元)
  • 少数层(例如,2-4 层)
  • 小隐藏维度(例如,128-512)
  • 最少的注意力头(例如,4-8 个头)

适合自定义模型验证的测试模型示例:

  • `"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random"` - 4 层,4 个 KV 头
  • `"michaelbenayoun/granite-tiny-4kv-heads-4layers-random"` - 微型 Granite 模型
  • `"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"` - 微型 Qwen3 模型

你的模型必须通过的关键测试:

def test_custom_modeling_matches_original()  # Output matching
  • 数值正确性:确保自定义模型与 Transformers 的输出完全匹配
  • 并行化支持:测试各种 QKV 实现(常规、融合、GQA)

2. 端到端训练验证

`test_overfit.py` 文件验证训练收敛性。要将你的模型包含在端到端训练验证中,你必须将其添加到参数化测试用例中。

更新 `tests/training/test_overfit.py`

@pytest.mark.parametrize(
    "model_class_name,model_name_or_path,learning_rate,warmup_ratio,training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps",
    [
        # ... existing models ...
        [
            "YourModelForCausalLM",
            "your-org/your-model-name",
            1e-4,
            0.03,
            {},
            True,
            0.5,
            2048,
            50,
        ],
    ],
    ids=[
        # ... existing model IDs ...
        "your-org/your-model-name",
    ],
)

此测试验证:

  • 收敛性验证:确保模型可以在简单数据集上过拟合

你的模型将在以下方面进行测试:

def test_overfit_custom_modeling_causal_lm()       # Basic training (your model included)

3. 自动模型加载

`test_modeling_auto.py` 文件验证你的模型可以使用 `NeuronModel` 和 `NeuronModelForCausalLM` 自动类加载。要将你的模型包含在这些测试中,你必须将其添加到测试用例中。

更新 `tests/training/test_modeling_auto.py`

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"torch_dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random", 
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"torch_dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random", 
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

此测试验证:

  • 自动模型加载:测试 `NeuronModel.from_pretrained()` 和 `NeuronModel.from_config()` 是否正常工作
  • 自动 CausalLM 加载:测试 `NeuronModelForCausalLM.from_pretrained()` 和 `NeuronModelForCausalLM.from_config()` 是否正常工作

4. 运行测试

测试需要 AWS Trainium 实例。运行特定的测试类别:

# Run all custom modeling tests
pytest tests/training/test_custom_modeling.py -v

# Run specific model tests
pytest tests/training/test_custom_modeling.py -v -k "your_model"

# Run end-to-end training validation
pytest tests/training/test_overfit.py -v

5. 测试要求

你的实现必须:

  1. 通过数值正确性测试,与原始 Transformers 实现对比
  2. 支持并行化策略(至少支持 DP 和 TP;推荐支持 PP)
  3. 处理各种 QKV 实现(常规、融合、GQA)
  4. 支持分布式训练的检查点合并
  5. 支持 LoRA 训练(如果适用)
  6. 通过过拟合测试证明收敛性

测试框架确保你的自定义模型与现有的 Optimum Neuron 训练基础设施保持兼容,同时提供预期的性能和正确性保证。

常见问题

  • 权重形状不匹配:确保转换规范正确处理张量形状
  • 流水线并行错误:检查所有必需的属性是否已设置
  • 内存问题:考虑梯度检查点和激活重计算
  • 注意力兼容性:验证注意力实现是否与你的模型架构兼容

其他资源

本指南为实现自定义模型提供了基础。有关完整示例和高级模式,请参考以下现有实现:

  • LLaMA: `optimum/neuron/models/training/llama/modeling_llama.py` - 包含(常规、融合和 GQA 注意力)、融合 MLP 的完整实现
  • Qwen3: `optimum/neuron/models/training/qwen3/modeling_qwen3.py` - 演示了如何调整 Llama 实现以适应 Qwen3 的 `q_norm` 和 `k_norm` 层

需要研究的关键文件:

  • `optimum/neuron/models/training/modeling_utils.py` - `NeuronModelMixin` 基类
  • `optimum/neuron/models/training/transformations_utils.py` - 权重转换规范
  • `optimum/neuron/models/training/config.py` - 用于并行设置的 `TrainingNeuronConfig`