Transformers 文档

模块化Transformers

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

模块化Transformers

模块化Transformers通过允许导入和继承,降低了贡献模型的门槛,并显著减少了添加模型所需的代码量。

Transformers的核心设计特性之一是单模型、单文件策略。模型组件(如注意力层)在许多文件中重复,并且任何独立的实现都可能随着对代码特定部分的修复和更改而出现分歧。

# Copied from 语句可以防止代码出现分歧,并且它通过我们的持续集成测试和本地命令强制执行。缺点是这种方法很繁琐,并显著增加了代码行数,其中大部分是样板代码。

动机

模块化Transformers通过在模型文件夹中添加一个*模块化*文件来解决这些问题。模块化文件可以从其他模型导入代码并从其他类继承代码,这与传统的模型和处理文件不同。

模块化Transformers并非旨在取代建模代码,如果你的模型不是基于现有模型,你需要手动添加一个`modeling.py`文件。同样,如果配置、标记化或处理文件不能轻易地从类似文件中继承,你可以直接添加该文件。

模块化文件包含模型、处理器和配置类代码,这些代码在“单一模型,单一文件”策略下本应在单独的文件中。

模型用户仍然可以导入和使用他们已经熟悉的单文件接口。通过这样做,我们希望在坚持我们理念的同时,实现更简单的贡献。

创建一个 `modeling.py` 文件

一个Linter将模块化文件“展开”为`modeling.py`文件,以保留单一模型、单一文件目录结构(建模、处理器等)。继承被扁平化为仅**单一**级别。

运行以下命令以自动从模块化文件生成`modeling.py`文件。

python utils/modular_model_converter.py --files-to-parse src/transformers/models/<your_model>/modular_<your_model>.py

例如:

  • 如果一个配置类继承自另一个类,但又添加和删除了一个参数,那么如果添加了参数,生成的文件会直接引用它;如果删除了参数,则会完全删除它。
  • 如果一个类继承自另一个类,例如`GemmaModel(LlamaModel)`,则会自动推断依赖关系。所有子模块也会自动从超类中推断出来。
  • 如果在模块化文件中定义了新函数并在类中使用,Linter也会自动推断这些函数。

您应该能够在一个模块中编写所有内容(分词器、图像处理器、模型、配置等),然后生成对应的单一文件。

运行下面的命令,确保生成的内容与`modular_<your_model>.py`匹配。

python utils/check_modular_conversion.py --files src/transformers/models/<your_model>/modular_<your_model>.py

下面的例子演示了如何使用模块化Transformers,以显著减少代码行数的方式添加模型。

BERT 和 RoBERTa

BERT 和 RoBERTa 是两个非常相似的模型,它们唯一的区别在于嵌入层的实现方式。

与其完全重新定义模型,不如考虑下面所示的`modular_roberta.py`文件,它包含了建模和配置类(本例中未显示分词器)。

from torch import nn
from ..bert.configuration_bert import BertConfig
from ..bert.modeling_bert import (
    BertModel,
    BertEmbeddings,
    BertForMaskedLM
)

# RoBERTa and BERT config is identical
class RobertaConfig(BertConfig):
  model_type = 'roberta'

# Redefine the embeddings to highlight the padding id difference, and redefine the position embeddings
class RobertaEmbeddings(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config())

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

# RoBERTa and BERT model is identical except for the embedding layer, which is defined above, so no need for additional changes here
class RobertaModel(BertModel):
  def __init__(self, config):
    super().__init__(config)
    self.embeddings = RobertaEmbeddings(config)


# The model heads now only need to redefine the model inside to `RobertaModel`
class RobertaForMaskedLM(BertForMaskedLM):
  def __init__(self, config):
    super().__init__(config)
    self.model = RobertaModel(config)

如果您不使用定义的依赖项,您将收到以下错误。

ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used when you define `BertModel`, as it is one of it's direct dependencies. Make sure you use it in the `__init__` function.

实现模块化文件

最简单的开始方式是浏览Transformers中与您的模型相似的模型,以便从中继承。一些好的起点包括MistralQwen2CohereCohere以及Llama。请参阅下表,了解您的模型可能使用的组件以及可以从何处继承。

组件 模型
专家混合 SwitchTransformers 或 Mixtral
交错(和/或部分)旋转嵌入 GLM, Phi
状态空间模型 Jamba, Bamba, Zamba, Mamba2
循环隐藏状态 Gemma2
每层滑动窗口注意力/全注意力模式 Gemma2, Cohere2
QKV 裁剪 Olmo
QK 归一化 Olmo2, Cohere
融合 QKV (不推荐) Phi3

本节将引导您了解如何使用模块化 Transformers 实现 Olmo2,从 Olmo 开始(您可以参考原始的 modeling.py 文件)。

配置

模块化的`Olmo2Config`如下所示。

from ..olmo.configuration_olmo import OlmoConfig

class Olmo2Config(OlmoConfig):
    r"""
    This is the configuration class to store the configuration of a [Olmo2Model](/docs/transformers/main/en/model_doc/olmo2#transformers.Olmo2Model).
    """

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=1e-5,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            **kwargs,
        )

        self.rms_norm_eps = rms_norm_eps
        del self.clip_qkv

`Olmo2Config`与原始`OlmoConfig`有三个不同点。

  1. 大多数参数的默认值都已更改。
  2. 新增一个参数,`rms_norm_eps`。
  3. `clip_qkv`参数不再使用。

对于新的默认值和参数,用新的默认值覆盖 `__init__` 函数并添加 `rms_norm_eps`。在 `__init__` 函数体中将 `rms_norm_eps` 赋值给 `self`。对于 `clip_qkv` 参数,使用 `del self.clip_qkv` 删除在展开的代码中(经 linter 转换后)此属性的赋值。

请注意`super().__init__(...)`的使用方式。通常,它会调用父`__init__`。

但在模块化Transformers中,如果存在像`super().my_function(...)`这样的调用,linter会将父类中`my_function`的主体展开到`super().my_function(...)`调用发生的位置。`del self.clip_qkv`语句会删除展开主体中对`self.clip_qkv`的引用。

`del self.` 和 `super().my_function(..)` 协同工作,并且它应该始终放置在 `super().my_function(...)` 之后。您可以在调用 `super()` *之前*添加任何您想要的内容,它将放置在父主体之前。

范数

from ..llama.modeling_llama import LlamaRMSNorm

class Olmo2RMSNorm(LlamaRMSNorm):
    pass

`LlamaRMSNorm` 中无需修改。linter 会将 `LlamaRMSNorm` 的确切内容展开到 `Olmo2RMSNorm` 中。文档字符串、类型提示和注释中对 Llama 的引用也会更改为 Olmo2。

注意力

模块化`Olmo2Attention`如下所示。

from ..llama.modeling_llama import eager_attention_forward
from ..olmo.modeling_olmo import OlmoAttention, apply_rotary_pos_emb


# Olmo2 attention is identical to OLMo attention except:
# - Norm is applied to attention queries and keys.
# - No qkv clipping.
class Olmo2Attention(OlmoAttention):
    def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx=layer_idx)
        self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
        self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

`super().__init__(...)`复制了父定义并从`Olmo2RMSNorm`添加了2个新层。前向传播需要被覆盖以使用这2个新层。在用`q_proj`和`k_proj`进行投影之前,添加了一个带有范数层的通道。为了简化,`eager_attention_forward`函数直接从Llama导入,而`apply_rotary_pos_emb`从Olmo导入。

Linter通过从源文件复制它们的定义,自动将这些导入的函数添加到最终的`modeling_olmo2.py`文件中。`rotate_half`和`repeat_kv`函数也添加了,因为它们在`apply_rotary_pos_emb`和`eager_attention_forward`内部使用。

`Attention`类必须重新定义,因为没有任何现有模型包含`RMSNorm`层的`Attention`层。

解码器层

模块化的`DecoderLayer`如下所示。

from ..olmo.modeling_olmo import OlmoDecoderLayer

# The OLMo2 layers are identical to those of the OLMo model except:
# - RMSNorm is used instead of standard layer norm.
# - Norm is applied after attention/feedforward rather than before.
class Olmo2DecoderLayer(OlmoDecoderLayer):
    def __init__(self, config: Olmo2Config, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
        del self.input_layernorm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs

在调用`super().__init__(...)`之后,通过覆盖`self.post_attention_layernorm`来切换`__init__`中的范数类型。删除`self.input_layernorm`属性并将其替换为`self.post_feedforward_layernorm`,因为它在Olmo2中应用在后面。前向方法被覆盖以反映此更改。

如果你只将`self.post_feedforward_layernorm`和`self.input_layernorm`从`LayerNorm`切换到`RMSNorm`,而没有同时更改`self.input_layernorm`的名称和逻辑,那么你就不需要重写forward方法。

模型

模块化的`Olmo2Model`类如下所示。

from ..olmo.modeling_olmo import OlmoModel

# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of
# standard layer norm for the output norm.
class Olmo2Model(OlmoModel):
    def __init__(self, config: Olmo2Config):
        super().__init__(config)
        self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layers = nn.ModuleList(
            [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

您只需将`self.norm`属性的*类型*更改为使用`RMSNorm`而不是`LayerNorm`。此更改不影响前向方法中的逻辑(层名称和用法与父类相同),因此您无需覆盖它。Linter会自动展开它。

模型头部

模块化因果建模头如下所示。

from ..olmo.modeling_olmo import OlmoForCausalLM

class Olmo2ForCausalLM(OlmoForCausalLM):
    pass

逻辑与`OlmoForCausalLM`相同,这意味着您无需在此处进行任何更改。

其他类

由linter生成的`modeling_olmo2.py`还包含一些在`modular_olmo2.py`中未明确定义的类(`Olmo2MLP`、`Olmo2RotaryEmbedding`、`Olmo2PreTrainedModel`)。

作为继承类的依赖项但未明确定义的类,会自动作为依赖项跟踪的一部分添加。这与某些函数添加到`Attention`类而无需直接导入它们的方式类似。

例如,`OlmoDecoderLayer`有一个属性定义为`self.mlp = OlmoMLP(config)`。这个类从未在`Olmo2MLP`中明确重新定义,因此linter会自动创建一个类似于`OlmoMLP`的`Olmo2MLP`类。如果它在`modular_olmo2.py`中明确写入,则与下面的代码相同。

from ..olmo.modeling_olmo import OlmoMLP

class Olmo2MLP(OlmoMLP):
    pass

然而,有必要重写`Olmo2RMSNorm`,因为在`Attention`和`DecoderLayer`类中需要重新定义层归一化。同样,这就是为什么您不需要创建`Olmo2PreTrainedModel`和`Olmo2RotaryEmbedding`类。

未重写的类将从继承模块首次使用它们的文件中复制。这意味着如果您希望`Olmo2MLP`继承自`MistralMLP`,则需要更明确,如下所示。

# switch to mistral definition
from ..mistral.modeling_mistral import MistralMLP

class Olmo2MLP(MistralMLP):
    pass

删除属性

您可以在使用`super().__init__()`之后使用`del`删除父类中定义的属性。但是,如果该属性也在其他地方使用,则不起作用,如下所示。它只抑制赋值。`self.attribute = config.attribute`行被删除,但`if`语句仍然存在并引用该属性。

class DummyModel(nn.Module):

  def __init__(self, config: DummyConfig):
    super().__init__()
    self.attribute = config.attribute
    if self.attribute:
      # do more stuff with `self.attribute` here
      ...

class MyNewDummyModel(DummyModel):

  def __init__(self, config: MyNewDummyConfig):
    super().__init__(config)
    del self.attribute

显式 `super()` 调用

如果您仍然想从 `DummyModel` 继承,但又不想删除 `self.attribute`,请明确指定您正在调用哪个类的 `super()`。下面的示例演示了如何调用 `nn.Module` 的 `super()`(展开的代码显示在右侧)

class MyNewDummyModel(DummyModel, nn.Module):        |     class MyNewDummyModel(nn.Module):
                                                     |
  def __init__(self, config: MyNewDummyConfig):      |       def __init__(self, config: MyNewDummyConfig):
    nn.Module.__init__(config)                       |         super().__init__()
    self.foo = config.foo                            |         self.foo = config.foo
    ...                                              |         ...

删除未使用的函数

通过将其覆盖为`raise AttributeError("")`语句来删除属性,以模仿您在Python中删除父函数时想要的行为。下面的示例删除了展开代码中的方法。

class GemmaTokenizer(LlamaTokenizer):
    ...

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")

定义新函数

默认情况下,如果您继承一个类并使用父方法中的一个或多个装饰器覆盖一个方法,则这些装饰器也会添加到展开的代码中*,仅当您没有自己添加任何装饰器时*。否则,将使用重新定义的装饰器。

例如,如果您有一个如下所示的父类并对其进行覆盖,则会保留父装饰器。

class DummyModel(nn.Module):
  ...

  @decorator(...)
  def forward(...)
    # do stuff here

模块化代码显示在左侧,展开的代码显示在右侧。

class NewModel(DummyModel):       |   class NewModel(nn.Module):
  ...                             |     ...
                                  |
  def forward(...):               |     @decorator(...)
    ...                           |     def forward(...):
                                  |       ...

但是,如果您添加一个新装饰器,则会使用您的新装饰器。

class NewModel(DummyModel):       |   class NewModel(nn.Module):
  ...                             |     ...
                                  |
  @my_new_decorator(...)          |     @my_new_decorator(...)
  def forward(...):               |     def forward(...):
    ...                           |       ...

super_kwargs

在某个前向方法很长且您想切换装饰器的情况下,您不需要重新定义所有内容并复制/粘贴该函数。您可以使用`super().forward(...)`来展开父方法体。当函数签名中有许多参数时,请在重写的签名中使用特殊的`**super_kwargs`语法。

此语法指示linter在此处展开所有父签名参数。下面是AutoModelForCausalLM模型中的一个示例签名,包含许多参数。

class LlamaForCausalLM(nn.Module):
  ...

  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def forward(
      self,
      input_ids: torch.LongTensor = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
      past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.LongTensor] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      num_logits_to_keep: int = 0,
      **kwargs: Unpack[KwargsForCausalLM],
  ) -> Union[Tuple, CausalLMOutputWithPast]:
    ...

与其重写并复制/粘贴所有这些参数,不如使用`super().forward(**super_kwargs)`语句(模块化代码显示在左侧,展开代码显示在右侧)。

class NewModelForCausalLM(LlamaForCausalLM):    |    class LlamaForCausalLM(nn.Module):
  ...                                           |      ...
                                                |
  @my_new_decorator                             |     @my_new_decorator
  def forward(self, **super_kwargs):            |     def forward(
    super().forward(**super_kwargs)             |         self,
                                                |         input_ids: torch.LongTensor = None,
                                                |         attention_mask: Optional[torch.Tensor] = None,
                                                |         position_ids: Optional[torch.LongTensor] = None,
                                                |         past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = |None,
                                                |         inputs_embeds: Optional[torch.FloatTensor] = None,
                                                |         labels: Optional[torch.LongTensor] = None,
                                                |         use_cache: Optional[bool] = None,
                                                |         output_attentions: Optional[bool] = None,
                                                |         output_hidden_states: Optional[bool] = None,
                                                |         return_dict: Optional[bool] = None,
                                                |         cache_position: Optional[torch.LongTensor] = None,
                                                |         num_logits_to_keep: int = 0,
                                                |         **kwargs: Unpack[KwargsForCausalLM],
                                                |     ) -> Union[Tuple, CausalLMOutputWithPast]:
                                                |       ...

这使得切换装饰器变得非常容易,并且明确表明您想要应用的唯一更改是装饰器。

然而,不应使用`**super_kwargs`来避免在重新定义方法时显得过于明确。如果您重写一个方法,您应该像往常一样明确编写签名。`**super_kwargs`语法是切换装饰器和其他一些特殊情况的快捷方式。

文档字符串变量

如果模块化文件和它继承的建模文件中都定义了对象,则模块化定义具有优先权,但包含`DOCSTRING`模式的赋值除外。这些变量通常用于建模文件中的`MODEL_START_DOCSTRING`和`MODEL_INPUT_DOCSTRING`。它们是大的文档字符串块,linter会在所有地方重写这些名称。因此,包含`DOCSTRING`变量的赋值可以使用源文件中找到的定义,而无需复制整个文档字符串,只需在模块化文件中将变量设置为`None`即可。

如果您需要在某个地方引用变量但又不想用总是相同的文档字符串来使模块化文件变得混乱,这非常有用。下面的示例代码允许您自动使用Mistral中与Starcoder2相同的文档字符串。

STARCODER2_INPUTS_DOCSTRING = None  # will be automatically redefined

class Starcoder2Model(MistralModel):
    ...

    @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
    def forward(...)
        ...

将变量设置为`None`以外的任何值都将覆盖文档字符串,以便您可以在需要时自定义文档字符串。

特殊命名

Linter在从类继承时会自动重命名所有内容。为了保持一致性,当从同一文件的不同类继承时,您应该始终使用相同的类名前缀。

不推荐以下示例。它违反了库中的标准,使用了`MyModelIncredibleMLP`而不是`LlamaMLP`,因为linter不知道如何重命名潜在的高阶依赖(`MyModelIncredible`或仅仅`MyModel`)。

class MyModelIncredibleMLP(LlamaMLP):
    ...

class MyModelDecoderLayer(LlamaDecoderLayer):
    ...

但是,如果没有隐式依赖项,则可以局部重命名单个类。请确保您仍然使用新的命名模式明确重新定义类的所有其他提及。例如,所有`LlamaMLP`的提及都应重命名为`MyModelIncredibleMLP`,否则linter可能会添加一个新的且不需要的`MyModelMLP`类。

如果检测到模糊情况,linter会发出警告。它会解释正在发生的事情以及默认用于获取依赖项的前缀。这些警告和重命名模式的复杂性通常只在定义多模态模型时出现。例如,在多模态模型中向类名添加`Text`以明确其指的是哪种模态。

We detected multiple prefix names when inheriting from transformers.models.llama.modeling_llama: ('Emu3Text', 'Emu3'). We will only use the most used 'Emu3' prefix when grabbing args and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different from 'Emu3') or use a single prefix in all the modular (best).

如果存在带有前缀的自动依赖项,但您想要另一个,请使用`pass`类在本地显式重命名类,如下所示。

class Emu3TextMLP(LlamaMLP):
    pass

配置文档字符串

当继承`Config`类或添加和删除属性时,您可能只想重新定义文档字符串中的新属性。但是,linter尚不支持此功能。您需要直接在类定义下的模块化文件中添加整个文档字符串。

< > 在 GitHub 上更新