Transformers 文档
Encoder Decoder Models
并访问增强的文档体验
开始使用
Encoder Decoder Models
概述
EncoderDecoderModel 可以用于初始化序列到序列模型,其中编码器可以是任何预训练的自编码模型,解码器可以是任何预训练的自回归模型。
Sascha Rothe、Shashi Narayan、Aliaksei Severyn 在 Leveraging Pre-trained Checkpoints for Sequence Generation Tasks 中展示了使用预训练检查点初始化序列到序列模型对于序列生成任务的有效性。
在 EncoderDecoderModel 经过训练/微调后,可以像任何其他模型一样保存/加载(有关更多信息,请参见示例)。
此架构的一个应用可能是利用两个预训练的 BertModel 作为编码器和解码器来构建摘要模型,如 Yang Liu 和 Mirella Lapata 在 Text Summarization with Pretrained Encoders 中所示。
从模型配置随机初始化 EncoderDecoderModel。
EncoderDecoderModel 可以从编码器和解码器配置中随机初始化。在以下示例中,我们展示了如何使用默认的 BertModel 配置作为编码器,以及默认的 BertForCausalLM
配置作为解码器来完成此操作。
>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = EncoderDecoderModel(config=config)
从预训练的编码器和预训练的解码器初始化 EncoderDecoderModel。
EncoderDecoderModel 可以从预训练的编码器检查点和预训练的解码器检查点初始化。请注意,任何预训练的自编码模型,例如 BERT,都可以用作编码器,而预训练的自编码模型(例如 BERT)、预训练的因果语言模型(例如 GPT2)以及序列到序列模型的预训练解码器部分(例如 BART 的解码器)都可以用作解码器。根据您选择作为解码器的架构,交叉注意力层可能会随机初始化。从预训练的编码器和解码器检查点初始化 EncoderDecoderModel 需要在下游任务上对模型进行微调,正如 Warm-starting-encoder-decoder 博客文章 中所示。为此,EncoderDecoderModel
类提供了一个 EncoderDecoderModel.from_encoder_decoder_pretrained() 方法。
>>> from transformers import EncoderDecoderModel, BertTokenizer
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
加载现有的 EncoderDecoderModel 检查点并执行推理。
为了加载 EncoderDecoderModel
类的微调检查点,EncoderDecoderModel 提供了 from_pretrained(...)
方法,就像 Transformers 中的任何其他模型架构一样。
要执行推理,可以使用 generate
方法,该方法允许自回归地生成文本。此方法支持各种形式的解码,例如贪婪解码、束搜索和多项式采样。
>>> from transformers import AutoTokenizer, EncoderDecoderModel
>>> # load a fine-tuned seq2seq model and corresponding tokenizer
>>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
>>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")
>>> # let's perform inference on a long piece of text
>>> ARTICLE_TO_SUMMARIZE = (
... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
... )
>>> input_ids = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors="pt").input_ids
>>> # autoregressively generate summary (uses greedy decoding by default)
>>> generated_ids = model.generate(input_ids)
>>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
nearly 800 thousand customers were affected by the shutoffs. the aim is to reduce the risk of wildfires. nearly 800, 000 customers were expected to be affected by high winds amid dry conditions. pg & e said it scheduled the blackouts to last through at least midday tomorrow.
将 PyTorch 检查点加载到 TFEncoderDecoderModel 中。
TFEncoderDecoderModel.from_pretrained() 当前不支持从 PyTorch 检查点初始化模型。将 from_pt=True
传递给此方法将引发异常。如果对于特定的 encoder-decoder 模型只有 PyTorch 检查点,则一种解决方法是
>>> # a workaround to load from pytorch checkpoint
>>> from transformers import EncoderDecoderModel, TFEncoderDecoderModel
>>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
训练
一旦创建了模型,就可以像 BART、T5 或任何其他 encoder-decoder 模型一样对其进行微调。如您所见,模型只需要 2 个输入即可计算损失:input_ids
(编码输入序列的 input_ids
)和 labels
(编码目标序列的 input_ids
)。
>>> from transformers import BertTokenizer, EncoderDecoderModel
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id
>>> input_ids = tokenizer(
... "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side.During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft).Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.",
... return_tensors="pt",
... ).input_ids
>>> labels = tokenizer(
... "the eiffel tower surpassed the washington monument to become the tallest structure in the world. it was the first structure to reach a height of 300 metres in paris in 1930. it is now taller than the chrysler building by 5. 2 metres ( 17 ft ) and is the second tallest free - standing structure in paris.",
... return_tensors="pt",
... ).input_ids
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
有关训练的详细 colab。
此模型由 thomwolf 贡献。此模型的 TensorFlow 和 Flax 版本由 ydshieh 贡献。
EncoderDecoderConfig
class transformers.EncoderDecoderConfig
< source >( **kwargs )
参数
- kwargs (可选) — 关键字参数字典。 特别是:
- encoder (PretrainedConfig, 可选) — 定义编码器配置的配置对象实例。
- decoder (PretrainedConfig, 可选) — 定义解码器配置的配置对象实例。
EncoderDecoderConfig 是用于存储 EncoderDecoderModel 配置的配置类。 它用于根据指定的参数实例化 Encoder Decoder 模型,定义编码器和解码器配置。
配置对象继承自 PretrainedConfig,可用于控制模型输出。 有关更多信息,请阅读 PretrainedConfig 的文档。
示例
>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
>>> # Initializing a BERT google-bert/bert-base-uncased style configuration
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations
>>> model = EncoderDecoderModel(config=config)
>>> # Accessing the model configuration
>>> config_encoder = model.config.encoder
>>> config_decoder = model.config.decoder
>>> # set decoder config to causal lm
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True
>>> # Saving the model, including its configuration
>>> model.save_pretrained("my-model")
>>> # loading model and config from pretrained folder
>>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model")
>>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
from_encoder_decoder_configs
< source >( encoder_config: PretrainedConfig decoder_config: PretrainedConfig **kwargs ) → EncoderDecoderConfig
从预训练的编码器模型配置和解码器模型配置实例化 EncoderDecoderConfig(或派生类)。
EncoderDecoderModel
class transformers.EncoderDecoderModel
< source >( config: typing.Optional[transformers.configuration_utils.PretrainedConfig] = None encoder: typing.Optional[transformers.modeling_utils.PreTrainedModel] = None decoder: typing.Optional[transformers.modeling_utils.PreTrainedModel] = None )
参数
- config (EncoderDecoderConfig) — 具有模型所有参数的模型配置类。 使用配置文件初始化不会加载与模型关联的权重,只会加载配置。 查看 from_pretrained() 方法以加载模型权重。
此类可用于使用任何预训练的自动编码模型作为编码器和任何预训练的自回归模型作为解码器来初始化序列到序列模型。 编码器通过 from_pretrained() 函数加载,解码器通过 from_pretrained() 函数加载。 交叉注意力层会自动添加到解码器,应在下游生成任务(如摘要)上进行微调。
Sascha Rothe、Shashi Narayan、Aliaksei Severyn. Michael Matena、Yanqi Zhou、Wei Li、Peter J. Liu 在 Leveraging Pre-trained Checkpoints for Sequence Generation Tasks 中展示了使用预训练检查点初始化序列到序列模型以进行序列生成任务的有效性。
在训练/微调这样的 Encoder Decoder 模型后,可以像任何其他模型一样保存/加载它(有关更多信息,请参见示例)。
此模型继承自 PreTrainedModel。 查看超类文档,了解库为所有模型实现的通用方法(例如下载或保存、调整输入嵌入大小、剪枝头等)。
此模型也是 PyTorch torch.nn.Module 子类。 将其用作常规 PyTorch 模块,并参阅 PyTorch 文档,了解与常规用法和行为相关的所有事项。
EncoderDecoderModel 是一个通用模型类,当使用编码器的 :meth~transformers.AutoModel.from_pretrained 类方法和解码器的 :meth~transformers.AutoModelForCausalLM.from_pretrained 类方法创建时,它将实例化为一个 transformer 架构,其中库的一个基本模型类作为编码器,另一个作为解码器。
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None decoder_input_ids: typing.Optional[torch.LongTensor] = None decoder_attention_mask: typing.Optional[torch.BoolTensor] = None encoder_outputs: typing.Optional[typing.Tuple[torch.FloatTensor]] = None past_key_values: typing.Tuple[typing.Tuple[torch.FloatTensor]] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None decoder_inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None **kwargs ) → transformers.modeling_outputs.Seq2SeqLMOutput 或 tuple(torch.FloatTensor)
参数
- input_ids (
torch.LongTensor
,形状为(batch_size, sequence_length)
) — 词汇表中输入序列标记的索引。可以使用 PreTrainedTokenizer 获取索引。 有关详细信息,请参见 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
- attention_mask (
torch.FloatTensor
,形状为(batch_size, sequence_length)
,可选) — 用于避免在 padding 标记索引上执行注意力的掩码。 在[0, 1]
中选择的掩码值:- 1 表示未掩码的标记,
- 0 表示掩码的标记。
- decoder_input_ids (
torch.LongTensor
,形状为(batch_size, target_sequence_length)
,可选) — 词汇表中解码器输入序列标记的索引。可以使用 PreTrainedTokenizer 获取索引。 有关详细信息,请参见 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
如果使用
past_key_values
,则可以选择仅输入最后一个decoder_input_ids
(请参见past_key_values
)。对于训练,
decoder_input_ids
由模型自动创建,方法是将labels
向右移动,将 -100 替换为pad_token_id
,并在其前面加上decoder_start_token_id
。 - decoder_attention_mask (
torch.BoolTensor
,形状为(batch_size, target_sequence_length)
,可选) — 默认行为:生成一个张量,该张量忽略decoder_input_ids
中的 padding 标记。 默认情况下,也会使用因果掩码。 - encoder_outputs (
tuple(torch.FloatTensor)
,可选) — 此元组必须由 (last_hidden_state
, 可选:hidden_states
, 可选:attentions
) 组成。last_hidden_state
(torch.FloatTensor
,形状为(batch_size, sequence_length, hidden_size)
) 是编码器最后一层输出的隐藏状态张量。 在解码器的交叉注意力中使用。 - past_key_values (
tuple(tuple(torch.FloatTensor))
,长度为config.n_layers
,每个元组有 4 个形状为(batch_size, num_heads, sequence_length - 1, embed_size_per_head)
的张量) — 包含注意力块的预计算键和值隐藏状态。 可用于加速解码。如果使用
past_key_values
,则用户可以选择仅输入最后一个decoder_input_ids
(那些没有将其过去的键值状态提供给此模型的)形状为(batch_size, 1)
,而不是所有形状为(batch_size, sequence_length)
的decoder_input_ids
。 - inputs_embeds (
torch.FloatTensor
,形状为(batch_size, sequence_length, hidden_size)
,可选) — (可选)您可以选择直接传递嵌入表示,而不是传递input_ids
。 如果您希望比模型的内部嵌入查找矩阵更好地控制如何将input_ids
索引转换为关联的向量,这将非常有用。 - decoder_inputs_embeds (
torch.FloatTensor
,形状为(batch_size, target_sequence_length, hidden_size)
,可选) — (可选) 您可以选择直接传递嵌入表示,而不是传递decoder_input_ids
。如果您希望比模型的内部嵌入查找矩阵更精细地控制如何将decoder_input_ids
索引转换为关联的向量,这将非常有用。 - labels (
torch.LongTensor
,形状为(batch_size, sequence_length)
,可选) — 用于计算解码器的掩码语言建模损失的标签。索引应在[-100, 0, ..., config.vocab_size]
中(请参阅input_ids
文档字符串)。索引设置为-100
的标记将被忽略(掩码),损失仅针对标签在[0, ..., config.vocab_size]
中的标记计算。 - use_cache (
bool
,可选) — 如果设置为True
,则返回past_key_values
键值状态,并可用于加速解码(请参阅past_key_values
)。 - output_attentions (
bool
,可选) — 是否返回所有注意力层的注意力张量。有关更多详细信息,请参阅返回张量下的attentions
。 - output_hidden_states (
bool
,可选) — 是否返回所有层的隐藏状态。有关更多详细信息,请参阅返回张量下的hidden_states
。 - return_dict (
bool
,可选) — 如果设置为True
,模型将返回一个~utils.Seq2SeqLMOutput
而不是纯元组。 - kwargs (可选) — 剩余的关键字参数字典。关键字参数有两种形式:
- 不带前缀的关键字参数将作为
**encoder_kwargs
输入到编码器前向函数。 - 带有 decoder_ 前缀的关键字参数将作为
**decoder_kwargs
输入到解码器前向函数。
- 不带前缀的关键字参数将作为
返回
transformers.modeling_outputs.Seq2SeqLMOutput 或 tuple(torch.FloatTensor)
一个 transformers.modeling_outputs.Seq2SeqLMOutput 或一个 torch.FloatTensor
元组(如果传递了 return_dict=False
或当 config.return_dict=False
时),其中包含各种元素,具体取决于配置 (EncoderDecoderConfig) 和输入。
-
loss (
torch.FloatTensor
,形状为(1,)
,可选,当提供labels
时返回) — 语言建模损失。 -
logits (
torch.FloatTensor
,形状为(batch_size, sequence_length, config.vocab_size)
) — 语言建模头的预测分数(SoftMax 之前的每个词汇表标记的分数)。 -
past_key_values (
tuple(tuple(torch.FloatTensor))
,可选,当传递use_cache=True
或当config.use_cache=True
时返回) — 长度为config.n_layers
的tuple(tuple(torch.FloatTensor))
元组,其中每个元组包含 2 个形状为(batch_size, num_heads, sequence_length, embed_size_per_head)
的张量和 2 个形状为(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)
的额外张量。包含预先计算的隐藏状态(自注意力模块和交叉注意力模块中的键和值),可用于加速顺序解码(请参阅
past_key_values
输入)。 -
decoder_hidden_states (
tuple(torch.FloatTensor)
,可选,当传递output_hidden_states=True
或当config.output_hidden_states=True
时返回) —torch.FloatTensor
元组(如果模型具有嵌入层,则为嵌入输出之一,加上每层输出之一),形状为(batch_size, sequence_length, hidden_size)
。解码器在每层输出处的隐藏状态,加上初始嵌入输出。
-
decoder_attentions (
tuple(torch.FloatTensor)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —torch.FloatTensor
元组(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。解码器的注意力权重,在注意力 softmax 之后,用于计算自注意力头中的加权平均值。
-
cross_attentions (
tuple(torch.FloatTensor)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —torch.FloatTensor
元组(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。解码器的交叉注意力层的注意力权重,在注意力 softmax 之后,用于计算交叉注意力头中的加权平均值。
-
encoder_last_hidden_state (
torch.FloatTensor
,形状为(batch_size, sequence_length, hidden_size)
,可选) — 模型编码器最后一层输出处的隐藏状态序列。 -
encoder_hidden_states (
tuple(torch.FloatTensor)
,可选,当传递output_hidden_states=True
或当config.output_hidden_states=True
时返回) —torch.FloatTensor
元组(如果模型具有嵌入层,则为嵌入输出之一,加上每层输出之一),形状为(batch_size, sequence_length, hidden_size)
。编码器在每层输出处的隐藏状态,加上初始嵌入输出。
-
encoder_attentions (
tuple(torch.FloatTensor)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —torch.FloatTensor
元组(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。编码器的注意力权重,在注意力 softmax 之后,用于计算自注意力头中的加权平均值。
EncoderDecoderModel 前向方法,覆盖了 __call__
特殊方法。
虽然前向传递的配方需要在该函数中定义,但应在此之后调用 Module
实例,而不是调用此函数,因为前者负责运行预处理和后处理步骤,而后者会静默地忽略它们。
示例
>>> from transformers import EncoderDecoderModel, BertTokenizer
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
... "google-bert/bert-base-uncased", "google-bert/bert-base-uncased"
... ) # initialize Bert2Bert from pre-trained checkpoints
>>> # training
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id
>>> model.config.vocab_size = model.config.decoder.vocab_size
>>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
>>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss, logits = outputs.loss, outputs.logits
>>> # save and load from pretrained
>>> model.save_pretrained("bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")
>>> # generation
>>> generated = model.generate(input_ids)
from_encoder_decoder_pretrained
< source >( encoder_pretrained_model_name_or_path: str = None decoder_pretrained_model_name_or_path: str = None *model_args **kwargs )
参数
- encoder_pretrained_model_name_or_path (
str
,可选) — 用于初始化编码器的必要信息。可以是:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- decoder_pretrained_model_name_or_path (
str
,可选,默认为None
) — 用于初始化解码器的必要信息。可以是:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (剩余的位置参数,可选) — 所有剩余的位置参数将传递给底层模型的
__init__
方法。 - kwargs (剩余的关键字参数字典,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。- 要更新编码器配置,请为每个配置参数使用前缀 encoder_。
- 要更新解码器配置,请为每个配置参数使用前缀 decoder_。
- 要更新父模型配置,请不要为每个配置参数使用前缀。
行为取决于是否提供
config
或自动加载config
。
从库的一个或两个基类实例化编码器和解码器,从预训练模型检查点实例化。
默认情况下,模型使用 model.eval()
设置为评估模式(Dropout 模块被禁用)。要训练模型,您需要先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import EncoderDecoderModel
>>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
>>> # saving model after fine-tuning
>>> model.save_pretrained("./bert2bert")
>>> # load fine-tuned model
>>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
TFEncoderDecoderModel
class transformers.TFEncoderDecoderModel
< source >( config: Optional[PretrainedConfig] = None encoder: Optional[TFPreTrainedModel] = None decoder: Optional[TFPreTrainedModel] = None )
参数
- config (EncoderDecoderConfig) — 模型配置类,包含模型的所有参数。使用配置文件初始化不会加载与模型关联的权重,仅加载配置。查看 from_pretrained() 方法以加载模型权重。
此类可用于使用任何预训练的自动编码模型作为编码器和任何预训练的自回归模型作为解码器来初始化序列到序列模型。 编码器通过 from_pretrained() 函数加载,解码器通过 from_pretrained() 函数加载。 交叉注意力层会自动添加到解码器,应在下游生成任务(如摘要)上进行微调。
Sascha Rothe、Shashi Narayan、Aliaksei Severyn. Michael Matena、Yanqi Zhou、Wei Li、Peter J. Liu 在 Leveraging Pre-trained Checkpoints for Sequence Generation Tasks 中展示了使用预训练检查点初始化序列到序列模型以进行序列生成任务的有效性。
在训练/微调这样的 Encoder Decoder 模型后,可以像任何其他模型一样保存/加载它(有关更多信息,请参见示例)。
此模型继承自 TFPreTrainedModel。查看超类文档,了解库为其所有模型实现的通用方法(例如,下载或保存、调整输入嵌入大小、剪枝头等)。
此模型也是 keras.Model 子类。将其用作常规 TF 2.0 Keras 模型,并参阅 TF 2.0 文档,了解与常规用法和行为相关的所有事项。
TFEncoderDecoderModel 是一个通用模型类,当使用编码器的 from_pretrained() 类方法和解码器的 from_pretrained() 类方法创建时,它将实例化为一个 Transformer 架构,其中库的一个基模型类作为编码器,另一个作为解码器。
call
< source >( input_ids: TFModelInputType | None = None attention_mask: np.ndarray | tf.Tensor | None = None decoder_input_ids: np.ndarray | tf.Tensor | None = None decoder_attention_mask: np.ndarray | tf.Tensor | None = None encoder_outputs: np.ndarray | tf.Tensor | None = None past_key_values: Tuple[Tuple[tf.Tensor]] | None = None inputs_embeds: np.ndarray | tf.Tensor | None = None decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None labels: np.ndarray | tf.Tensor | None = None use_cache: Optional[bool] = None output_attentions: Optional[bool] = None output_hidden_states: Optional[bool] = None return_dict: Optional[bool] = None training: bool = False **kwargs ) → transformers.modeling_tf_outputs.TFSeq2SeqLMOutput 或 tuple(tf.Tensor)
参数
- input_ids (
np.ndarray
,tf.Tensor
,List[tf.Tensor]
、Dict[str, tf.Tensor]
或Dict[str, np.ndarray]
,并且每个示例都必须具有形状(batch_size, sequence_length)
) — 词汇表中输入序列标记的索引。可以使用 PreTrainedTokenizer 获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
- attention_mask (
np.ndarray
或tf.Tensor
,形状为(batch_size, sequence_length)
,可选) — 用于避免对填充标记索引执行注意力的掩码。掩码值在[0, 1]
中选择:- 1 表示未掩码的标记,
- 0 表示已掩码的标记。
- decoder_input_ids (
np.ndarray
或tf.Tensor
,形状为(batch_size, target_sequence_length)
,可选) — 词汇表中解码器输入序列标记的索引。可以使用 PreTrainedTokenizer 获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
如果使用
past_key_values
,则可以选择仅输入最后一个decoder_input_ids
(请参阅past_key_values
)。为解码器提供序列到序列训练。可以使用 PreTrainedTokenizer 获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
- decoder_attention_mask (
np.ndarray
或tf.Tensor
,形状为(batch_size, target_sequence_length)
,可选) — 默认行为:生成一个张量,忽略decoder_input_ids
中的填充 token。默认情况下也会使用因果掩码。 - encoder_outputs (
tuple(tuple(tf.Tensor)
, 可选) — 此元组必须包含 (last_hidden_state
, 可选:hidden_states
, 可选:attentions
)last_hidden_state
(tf.Tensor
,形状为(batch_size, sequence_length, hidden_size)
) 是编码器最后一层的输出处的隐藏状态张量。在解码器的交叉注意力中使用。 - past_key_values (
tuple(tuple(tf.Tensor))
,长度为config.n_layers
,每个元组有 4 个形状为(batch_size, num_heads, sequence_length - 1, embed_size_per_head)
的张量) — 包含注意力块的预计算的键和值隐藏状态。可用于加速解码。如果使用
past_key_values
,用户可以选择仅输入最后一次的decoder_input_ids
(那些没有将其过去的键值状态提供给此模型的),形状为(batch_size, 1)
,而不是所有形状为(batch_size, sequence_length)
的decoder_input_ids
。 - inputs_embeds (
np.ndarray
或tf.Tensor
,形状为(batch_size, sequence_length, hidden_size)
,可选) — 可选地,您可以选择直接传递嵌入表示,而不是传递input_ids
。如果您希望比模型的内部嵌入查找矩阵更精细地控制如何将input_ids
索引转换为关联的向量,这将非常有用。 - decoder_inputs_embeds (
np.ndarray
或tf.Tensor
,形状为(batch_size, target_sequence_length, hidden_size)
,可选) — 可选地,您可以选择直接传递嵌入表示,而不是传递decoder_input_ids
。如果您希望比模型的内部嵌入查找矩阵更精细地控制如何将decoder_input_ids
索引转换为关联的向量,这将非常有用。 - labels (
np.ndarray
或tf.Tensor
,形状为(batch_size, sequence_length)
,可选) — 用于计算解码器的掩码语言建模损失的标签。索引应在[-100, 0, ..., config.vocab_size]
中(请参阅input_ids
文档字符串)。索引设置为-100
的 token 将被忽略(掩码),损失仅针对标签在[0, ..., config.vocab_size]
中的 token 计算。 - use_cache (
bool
,可选) — 如果设置为True
,则返回past_key_values
键值状态,并可用于加速解码(请参阅past_key_values
)。 - output_attentions (
bool
,可选) — 是否返回所有注意力层的注意力张量。 有关更多详细信息,请参阅返回张量下的attentions
。 - output_hidden_states (
bool
,可选) — 是否返回所有层的隐藏状态。 有关更多详细信息,请参阅返回张量下的hidden_states
。 - return_dict (
bool
,可选) — 如果设置为True
,模型将返回一个~utils.Seq2SeqLMOutput
而不是纯元组。 - training (
bool
,可选,默认为False
) — 是否在训练模式下使用模型(某些模块,如 dropout 模块,在训练和评估之间具有不同的行为)。 - kwargs (可选) — 剩余的关键字参数字典。关键字参数有两种形式:
- 没有前缀的关键字参数将作为
**encoder_kwargs
输入到编码器前向函数。 - 带有 decoder_ 前缀的关键字参数将作为
**decoder_kwargs“
输入到解码器前向函数。
- 没有前缀的关键字参数将作为
返回
transformers.modeling_tf_outputs.TFSeq2SeqLMOutput 或 tuple(tf.Tensor)
一个 transformers.modeling_tf_outputs.TFSeq2SeqLMOutput 或 tf.Tensor
的元组 (如果传递了 return_dict=False
或当 config.return_dict=False
时),包含各种元素,具体取决于配置 (EncoderDecoderConfig) 和输入。
-
loss (
tf.Tensor
,形状为(n,)
,可选,当提供labels
时返回) — 语言建模损失。 -
logits (
tf.Tensor
,形状为(batch_size, sequence_length, config.vocab_size)
) — 语言建模头的预测分数(SoftMax 之前每个词汇 token 的分数)。 -
past_key_values (
List[tf.Tensor]
,可选,当传递use_cache=True
或当config.use_cache=True
时返回) — 长度为config.n_layers
的tf.Tensor
列表,每个张量的形状为(2, batch_size, num_heads, sequence_length, embed_size_per_head)
)。包含解码器的预计算隐藏状态(注意力块中的键和值),可用于(请参阅
past_key_values
输入)加速顺序解码。 -
decoder_hidden_states (
tuple(tf.Tensor)
,可选,当传递output_hidden_states=True
或当config.output_hidden_states=True
时返回) —tf.Tensor
的元组(每个嵌入输出 + 每个层的输出一个),形状为(batch_size, sequence_length, hidden_size)
。解码器在每层输出处的隐藏状态,加上初始嵌入输出。
-
decoder_attentions (
tuple(tf.Tensor)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —tf.Tensor
的元组(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。解码器的注意力权重,在注意力 softmax 之后,用于计算自注意力头中的加权平均值。
-
cross_attentions (
tuple(tf.Tensor)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —tf.Tensor
的元组(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。解码器的交叉注意力层的注意力权重,在注意力 softmax 之后,用于计算交叉注意力头中的加权平均值。
-
encoder_last_hidden_state (
tf.Tensor
,形状为(batch_size, sequence_length, hidden_size)
,可选) — 模型编码器最后一层输出处的隐藏状态序列。 -
encoder_hidden_states (
tuple(tf.Tensor)
,可选,当传递output_hidden_states=True
或当config.output_hidden_states=True
时返回) —tf.Tensor
的元组(每个嵌入输出 + 每个层的输出一个),形状为(batch_size, sequence_length, hidden_size)
。编码器在每层输出处的隐藏状态,加上初始嵌入输出。
-
encoder_attentions (
tuple(tf.Tensor)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —tf.Tensor
的元组(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。编码器的注意力权重,在注意力 softmax 之后,用于计算自注意力头中的加权平均值。
TFEncoderDecoderModel 前向方法,覆盖了 __call__
特殊方法。
虽然前向传递的配方需要在该函数中定义,但应在此之后调用 Module
实例,而不是调用此函数,因为前者负责运行预处理和后处理步骤,而后者会静默地忽略它们。
示例
>>> from transformers import TFEncoderDecoderModel, BertTokenizer
>>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
>>> # forward
>>> input_ids = tokenizer.encode(
... "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf"
... ) # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
>>> # training
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
>>> loss, logits = outputs.loss, outputs.logits
>>> # save and load from pretrained
>>> model.save_pretrained("bert2gpt2")
>>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2")
>>> # generation
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
from_encoder_decoder_pretrained
< source >( encoder_pretrained_model_name_or_path: str = None decoder_pretrained_model_name_or_path: str = None *model_args **kwargs )
参数
- encoder_pretrained_model_name_or_path (
str
,可选) — 初始化编码器所需的信息。可以是:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的 模型 ID 。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - pytorch 索引检查点文件 的路径或 URL(例如,
./pt_model/
)。在这种情况下,应将encoder_from_pt
设置为True
。
- decoder_pretrained_model_name_or_path (
str
,可选,默认为None
) — 初始化解码器所需的信息。可以是:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的 模型 ID 。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - pytorch 检查点文件 的路径或 URL(例如,
./pt_model/
)。在这种情况下,应将decoder_from_pt
设置为True
。
- model_args (剩余的位置参数,可选) — 所有剩余的位置参数将传递给底层模型的
__init__
方法。 - kwargs (剩余的关键字参数字典,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。- 要更新编码器配置,请为每个配置参数使用前缀 encoder_。
- 要更新解码器配置,请为每个配置参数使用前缀 decoder_。
- 要更新父模型配置,请不要为每个配置参数使用前缀。
根据是否提供
config
或自动加载config
,行为会有所不同。
从库的一个或两个基类实例化编码器和解码器,从预训练模型检查点实例化。
示例
>>> from transformers import TFEncoderDecoderModel
>>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2")
>>> # saving model after fine-tuning
>>> model.save_pretrained("./bert2gpt2")
>>> # load fine-tuned model
>>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2")
FlaxEncoderDecoderModel
class transformers.FlaxEncoderDecoderModel
< source >( config: EncoderDecoderConfig input_shape: typing.Optional[typing.Tuple] = None seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True **kwargs )
参数
- config (EncoderDecoderConfig) — 具有模型所有参数的模型配置类。使用配置文件初始化不会加载与模型关联的权重,仅加载配置。查看 from_pretrained() 方法以加载模型权重。
- dtype (
jax.numpy.dtype
,可选,默认为jax.numpy.float32
) — 计算的数据类型。可以是jax.numpy.float32
、jax.numpy.float16
(在 GPU 上) 和jax.numpy.bfloat16
(在 TPU 上) 之一。这可以用于在 GPU 或 TPU 上启用混合精度训练或半精度推理。如果指定,所有计算将以给定的
dtype
执行。请注意,这仅指定计算的 dtype,而不影响模型参数的 dtype。
此类可用于使用任何预训练的自动编码模型作为编码器和任何预训练的自回归模型作为解码器来初始化序列到序列模型。 编码器通过 from_pretrained() 函数加载,解码器通过 from_pretrained() 函数加载。 交叉注意力层会自动添加到解码器,应在下游生成任务(如摘要)上进行微调。
Sascha Rothe、Shashi Narayan、Aliaksei Severyn. Michael Matena、Yanqi Zhou、Wei Li、Peter J. Liu 在 Leveraging Pre-trained Checkpoints for Sequence Generation Tasks 中展示了使用预训练检查点初始化序列到序列模型以进行序列生成任务的有效性。
在训练/微调这样的 Encoder Decoder 模型后,可以像任何其他模型一样保存/加载它(有关更多信息,请参见示例)。
此模型继承自 FlaxPreTrainedModel。查看超类文档,了解库为其所有模型实现的通用方法(例如下载或保存、调整输入嵌入大小、剪枝头等)。
此模型也是 Flax Linen flax.nn.Module 子类。将其用作常规 Flax 模块,并参考 Flax 文档以获取与常规用法和行为相关的所有事项。
FlaxEncoderDecoderModel 是一个通用模型类,当使用 :meth~transformers.FlaxAutoModel.from_pretrained 类方法创建编码器和使用 :meth~transformers.FlaxAutoModelForCausalLM.from_pretrained 类方法创建解码器时,它将被实例化为一个 transformer 架构,其中模块 (flax.nn.Module) 由库中一个基础模型类的编码器模块和另一个基础模型类的解码器模块组成。
__call__
< source >( input_ids: Array attention_mask: typing.Optional[jax.Array] = None decoder_input_ids: typing.Optional[jax.Array] = None decoder_attention_mask: typing.Optional[jax.Array] = None position_ids: typing.Optional[jax.Array] = None decoder_position_ids: typing.Optional[jax.Array] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None train: bool = False params: dict = None dropout_rng: <function PRNGKey at 0x7f787eb14310> = None ) → transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput 或 tuple(torch.FloatTensor)
参数
- input_ids (
jnp.ndarray
,形状为(batch_size, sequence_length)
) — 词汇表中输入序列 tokens 的索引。如果您提供 padding,默认情况下将被忽略。索引可以使用 PreTrainedTokenizer 获得。 有关详细信息,请参阅 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
- attention_mask (
jnp.ndarray
,形状为(batch_size, sequence_length)
,可选) — 用于避免在 padding token 索引上执行 attention 的 Mask。 Mask 值在[0, 1]
中选择:- 1 代表 未被 Mask 的 tokens,
- 0 代表 被 Mask 的 tokens。
- decoder_input_ids (
jnp.ndarray
,形状为(batch_size, target_sequence_length)
,可选) — 词汇表中解码器输入序列 tokens 的索引。索引可以使用 PreTrainedTokenizer 获得。 有关详细信息,请参阅 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
对于序列到序列的训练,应提供
decoder_input_ids
。decoder_input_ids
应该在模型外部创建,方法是将labels
向右移动,将 -100 替换为pad_token_id
,并在其前面加上decoder_start_token_id
。 - decoder_attention_mask (
jnp.ndarray
,形状为(batch_size, target_sequence_length)
,可选) — 默认行为:生成一个 tensor,该 tensor 忽略decoder_input_ids
中的 pad tokens。 默认情况下,也会使用因果 Mask。 - position_ids (
numpy.ndarray
,形状为(batch_size, sequence_length)
,可选) — 位置 embeddings 中每个输入序列 tokens 的位置索引。在范围[0, config.encoder.max_position_embeddings - 1]
中选择。 - decoder_position_ids (
numpy.ndarray
,形状为(batch_size, sequence_length)
,可选) — 位置 embeddings 中每个解码器输入序列 tokens 的位置索引。在范围[0, config.decoder.max_position_embeddings - 1]
中选择。 - output_attentions (
bool
,可选) — 是否返回所有 attention 层的 attentions tensors。 有关更多详细信息,请参阅返回的 tensors 下的attentions
。 - output_hidden_states (
bool
,可选) — 是否返回所有层的 hidden states。 有关更多详细信息,请参阅返回的 tensors 下的hidden_states
。 - return_dict (
bool
,可选) — 如果设置为True
,模型将返回一个~utils.FlaxSeq2SeqLMOutput
而不是一个普通的 tuple。
返回
transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput 或 tuple(torch.FloatTensor)
一个 transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput 或一个 torch.FloatTensor
的 tuple(如果传递了 return_dict=False
或当 config.return_dict=False
时),包含各种元素,具体取决于配置 (EncoderDecoderConfig) 和输入。
-
logits (
jnp.ndarray
,形状为(batch_size, sequence_length, config.vocab_size)
) — 语言建模头的预测分数(SoftMax 之前每个词汇表 token 的分数)。 -
past_key_values (
tuple(tuple(jnp.ndarray))
,可选,当传递use_cache=True
或当config.use_cache=True
时返回) — 长度为config.n_layers
的tuple(jnp.ndarray)
的 tuple,其中每个 tuple 有 2 个形状为(batch_size, num_heads, sequence_length, embed_size_per_head)
) 的 tensors 和 2 个形状为(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)
的额外 tensors。包含预先计算的隐藏状态(自注意力模块和交叉注意力模块中的键和值),可用于加速顺序解码(请参阅
past_key_values
输入)。 -
decoder_hidden_states (
tuple(jnp.ndarray)
,可选,当传递output_hidden_states=True
或当config.output_hidden_states=True
时返回) —jnp.ndarray
的 tuple(embedding 输出一个,每层输出一个),形状为(batch_size, sequence_length, hidden_size)
。解码器在每层输出处的隐藏状态,加上初始嵌入输出。
-
decoder_attentions (
tuple(jnp.ndarray)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —jnp.ndarray
的 tuple(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。解码器的注意力权重,在注意力 softmax 之后,用于计算自注意力头中的加权平均值。
-
cross_attentions (
tuple(jnp.ndarray)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —jnp.ndarray
的 tuple(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。解码器的交叉注意力层的注意力权重,在注意力 softmax 之后,用于计算交叉注意力头中的加权平均值。
-
encoder_last_hidden_state (
jnp.ndarray
,形状为(batch_size, sequence_length, hidden_size)
,可选) — 模型编码器最后一层输出的 hidden-states 序列。 -
encoder_hidden_states (
tuple(jnp.ndarray)
,可选,当传递output_hidden_states=True
或当config.output_hidden_states=True
时返回) —jnp.ndarray
的 tuple(embedding 输出一个,每层输出一个),形状为(batch_size, sequence_length, hidden_size)
。编码器在每层输出处的隐藏状态,加上初始嵌入输出。
-
encoder_attentions (
tuple(jnp.ndarray)
,可选,当传递output_attentions=True
或当config.output_attentions=True
时返回) —jnp.ndarray
的 tuple(每层一个),形状为(batch_size, num_heads, sequence_length, sequence_length)
。编码器的注意力权重,在注意力 softmax 之后,用于计算自注意力头中的加权平均值。
FlaxEncoderDecoderModel 的 forward 方法,覆盖了 __call__
特殊方法。
虽然前向传递的配方需要在该函数中定义,但应在此之后调用 Module
实例,而不是调用此函数,因为前者负责运行预处理和后处理步骤,而后者会静默地忽略它们。
示例
>>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer
>>> # load a fine-tuned bert2gpt2 model
>>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
>>> # load input & output tokenizer
>>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
>>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
>>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
>>> singing a racist chant. SAE's national chapter suspended the students,
>>> but University of Oklahoma President David Boren took it a step further,
>>> saying the university's affiliation with the fraternity is permanently done.'''
>>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids
>>> # use GPT2's eos_token as the pad as well as eos token
>>> model.config.eos_token_id = model.config.decoder.eos_token_id
>>> model.config.pad_token_id = model.config.eos_token_id
>>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
>>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
>>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
from_encoder_decoder_pretrained
< source >( encoder_pretrained_model_name_or_path: typing.Union[str, os.PathLike, NoneType] = None decoder_pretrained_model_name_or_path: typing.Union[str, os.PathLike, NoneType] = None *model_args **kwargs )
参数
- encoder_pretrained_model_name_or_path (
Union[str, os.PathLike]
,可选) — 用于初始化编码器的必要信息。 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。
- decoder_pretrained_model_name_or_path (
Union[str, os.PathLike]
,可选,默认为None
) — 用于初始化解码器的必要信息。 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。
- model_args (剩余的位置参数,可选) — 所有剩余的位置参数将传递给底层模型的
__init__
方法。 - kwargs (剩余的关键字参数字典,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。- 要更新编码器配置,请为每个配置参数使用前缀 encoder_。
- 要更新解码器配置,请为每个配置参数使用前缀 decoder_。
- 要更新父模型配置,请不要为每个配置参数使用前缀。
行为方式取决于是否提供了
config
或自动加载了config
。
从库的一个或两个基类实例化编码器和解码器,从预训练模型检查点实例化。
示例
>>> from transformers import FlaxEncoderDecoderModel
>>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
>>> # saving model after fine-tuning
>>> model.save_pretrained("./bert2gpt2")
>>> # load fine-tuned model
>>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2")