利用预训练语言模型检查点进行编码器-解码器模型训练

发布于 2020 年 11 月 9 日
在 GitHub 上更新
Open In Colab

Vaswani 等人(2017)提出了基于 Transformer 的编码器-解码器模型,最近受到了广泛关注,例如 Lewis 等人(2019)Raffel 等人(2019)Zhang 等人(2020)Zaheer 等人(2020)Yan 等人(2020)

与 BERT 和 GPT2 类似,大型预训练编码器-解码器模型已显示出在各种序列到序列任务中显著提升性能,例如 Lewis 等人(2019)Raffel 等人(2019)。然而,由于预训练编码器-解码器模型的计算成本巨大,此类模型的开发主要局限于大型公司和机构。

《Leveraging Pre-trained Checkpoints for Sequence Generation Tasks》(2020)中,Sascha Rothe、Shashi Narayan 和 Aliaksei Severyn 利用预训练的仅编码器和/或仅解码器检查点(例如 BERT、GPT2)来初始化编码器-解码器模型,以跳过昂贵的预训练过程。作者表明,此类温启动编码器-解码器模型在多种序列到序列任务上,仅需一小部分训练成本即可获得与大型预训练编码器-解码器模型(如T5Pegasus)相媲美的结果。

在本笔记本中,我们将详细解释如何温启动编码器-解码器模型,根据Rothe 等人(2020)的论文提供实用技巧,最后通过一个完整的代码示例,展示如何使用 🤗Transformers 温启动编码器-解码器模型。

本笔记本分为 4 个部分

  • 引言 - 对 NLP 中预训练语言模型以及温启动编码器-解码器模型的需求进行简要总结。
  • 编码器-解码器模型的温启动(理论) - 图解说明编码器-解码器模型是如何进行温启动的?
  • 编码器-解码器模型的温启动(分析) - 《Leveraging Pre-trained Checkpoints for Sequence Generation Tasks》(2020)摘要 - 哪些模型组合对温启动编码器-解码器模型有效;它在不同任务之间有何差异?
  • 使用 🤗Transformers 温启动编码器-解码器模型(实践) - 详细展示如何使用 EncoderDecoderModel 框架温启动基于 Transformer 的编码器-解码器模型的完整代码示例。

强烈建议(甚至可能是必要)阅读这篇博文,了解基于 Transformer 的编码器-解码器模型。

我们首先介绍温启动编码器-解码器模型的背景。

引言

最近,预训练语言模型1{}^1彻底改变了自然语言处理(NLP)领域。

最早的预训练语言模型是基于循环神经网络(RNN),由Dai 等人(2015)提出。Dai 等人表明,在未标注数据上预训练基于 RNN 的模型,然后针对特定任务进行微调2{}^2,比直接在任务上训练随机初始化的模型效果更好。然而,直到 2018 年,预训练语言模型才在 NLP 领域被广泛接受。Peters 等人提出的 ELMOHoward 等人提出的 ULMFit是首批显著提升一系列自然语言理解(NLU)任务最新技术的预训练语言模型。仅仅几个月后,OpenAI 和 Google 发布了基于 Transformer 的预训练语言模型,分别命名为Radford 等人提出的 GPTDevlin 等人提出的 BERT。基于 Transformer 的语言模型在效率上优于 RNN,使得 GPT2 和 BERT 能够在大量未标注文本数据上进行预训练。一旦预训练完成,BERT 和 GPT 被证明只需要很少的微调即可在十多个 NLU 任务上打破最新技术水平3{}^3

预训练语言模型将任务无关知识有效地迁移到任务特定知识的能力,极大地推动了 NLU 的发展。工程师和研究人员以前需要从头开始训练语言模型,而现在,公开可用的、大型预训练语言模型的检查点可以在极短的时间内以极低的成本进行微调。这在工业界可以节省数百万美元,在研究中则可以实现更快的原型开发和更好的基准测试。

预训练语言模型已经在 NLU 任务上确立了新的性能水平,越来越多的研究都建立在利用这些预训练语言模型来改进 NLU 系统上。然而,独立的 BERT 和 GPT 模型在序列到序列任务上(例如,文本摘要机器翻译句子改写等)的表现则不尽如人意。

序列到序列任务定义为将输入序列 X1:n\mathbf{X}_{1:n} 映射到输出序列 Y1:m\mathbf{Y}_{1:m},其中输出长度 mm 是预先未知的。因此,序列到序列模型应该定义输出序列 Y1:m\mathbf{Y}_{1:m} 关于输入序列 X1:n\mathbf{X}_{1:n} 的条件概率分布

pθmodel(Y1:mX1:n). p_{\theta_{\text{model}}}(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n}).

不失一般性,包含 nn 个词的输入词序列表示为向量序列 X1:n=x1,,xn\mathbf{X}_{1:n} = \mathbf{x}_1, \ldots, \mathbf{x}_n;包含 mm 个词的输出序列表示为 Y1:m=y1,,ym\mathbf{Y}_{1:m} = \mathbf{y}_1, \ldots, \mathbf{y}_m

让我们看看 BERT 和 GPT2 如何适应序列到序列任务。

BERT

BERT 是一个仅编码器模型,它将输入序列 X1:n\mathbf{X}_{1:n} 映射到语境化的编码序列 X1:n\mathbf{\overline{X}}_{1:n}

fθBERT:X1:nX1:n. f_{\theta_{\text{BERT}}}: \mathbf{X}_{1:n} \to \mathbf{\overline{X}}_{1:n}.

BERT 的语境化编码序列 X1:n\mathbf{\overline{X}}_{1:n} 随后可以由分类层进一步处理,以用于 NLU 分类任务,例如情感分析自然语言推理等。为此,分类层(通常是池化层后接前馈层)作为最后一层添加到 BERT 的顶部,将语境化编码序列 X1:n\mathbf{\overline{X}}_{1:n} 映射到类别 cc

fθp,c:X1:nc. f_{\theta{\text{p,c}}}: \mathbf{\overline{X}}_{1:n} \to c.

已经证明,在预训练的 BERT 模型 θBERT\theta_{\text{BERT}} 顶部添加定义为 θp,c\theta_{\text{p,c}} 的池化层和分类层,然后微调整个模型 {θp,c,θBERT}\{\theta_{\text{p,c}}, \theta_{\text{BERT}}\} 可以在各种 NLU 任务上获得最先进的性能,参见Devlin 等人提出的 BERT

让我们来可视化 BERT。

texte du
lien

BERT 模型显示为灰色。该模型堆叠了多个BERT 块,每个块由双向自注意力层(红色框下部所示)和两个前馈层(红色框上部所示)组成。

每个 BERT 块都利用**双向**自注意力来处理输入序列 x1,,xn\mathbf{x'}_1, \ldots, \mathbf{x'}_n(浅灰色所示),以生成更“精炼”的语境化输出序列 x1,,xn\mathbf{x''}_1, \ldots, \mathbf{x''}_n(略深灰色所示)4{}^4。最后一个 BERT 块的语境化输出序列,即 X1:n\mathbf{\overline{X}}_{1:n},可以通过添加一个任务特定的分类层(橙色所示),如上所述,映射到单个输出类别 cc

仅编码器模型只能将输入序列映射到先验已知输出长度的输出序列。总之,输出维度不依赖于输入序列,这使得使用仅编码器模型进行序列到序列任务具有不利和不切实际的缺点。

正如所有仅编码器模型一样,BERT 的架构与编码器-解码器笔记本中“编码器”部分所示的基于 Transformer 的编码器-解码器模型的编码器部分架构完全对应。

GPT2

GPT2 是一个仅解码器模型,它利用单向(即“因果”)自注意力,将输入序列 Y0:m1\mathbf{Y}_{0: m - 1} 1{}^1 映射到“下一个词”的 logits 向量序列 L1:m\mathbf{L}_{1:m}

fθGPT2:Y0:m1L1:m. f_{\theta_{\text{GPT2}}}: \mathbf{Y}_{0: m - 1} \to \mathbf{L}_{1:m}.

通过对 logits 向量 L1:m\mathbf{L}_{1:m} 应用 softmax 操作,模型可以定义词序列 Y1:m\mathbf{Y}_{1:m} 的概率分布。具体来说,词序列 Y1:m\mathbf{Y}_{1:m} 的概率分布可以分解为 m1m-1 个条件“下一个词”分布

pθGPT2(Y1:m)=i=1mpθGPT2(yiY0:i1). p_{\theta_{\text{GPT2}}}(\mathbf{Y}_{1:m}) = \prod_{i=1}^{m} p_{\theta_{\text{GPT2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}). pθGPT2(yiY0:i1)p_{\theta_{\text{GPT2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}) 表示给定所有先前的词 y0,,yi1\mathbf{y}_0, \ldots, \mathbf{y}_{i-1} 3{}^3 后,下一个词 yi\mathbf{y}_i 的概率分布,并定义为对 logits 向量 li\mathbf{l}_i 应用 softmax 操作。总而言之,以下等式成立。

pθgpt2(yiY0:i1)=Softmax(li)=Softmax(fθGPT2(Y0:i1)). p_{\theta_{\text{gpt2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}) = \textbf{Softmax}(\mathbf{l}_i) = \textbf{Softmax}(f_{\theta_{\text{GPT2}}}(\mathbf{Y}_{0: i - 1})).

有关更多详细信息,请参阅编码器-解码器博客文章的解码器部分。

现在也让我们可视化 GPT2。

texte du
lien

与BERT类似,GPT2由一系列*GPT2块*组成。与BERT块不同,GPT2块利用**单向**自注意力机制处理一些输入向量y0,,ym1\mathbf{y'}_0, \ldots, \mathbf{y'}_{m-1} (右下方浅蓝色所示),将其处理成输出向量序列y0,,ym1\mathbf{y''}_0, \ldots, \mathbf{y''}_{m-1} (右上方深蓝色所示)。除了GPT2块堆栈外,该模型还有一个线性层,称为*LM Head*,它将最后一个GPT2块的输出向量映射到logits向量l1,,lm\mathbf{l}_1, \ldots, \mathbf{l}_m。如前所述,logits向量li\mathbf{l}_i可以用于采样新的输入向量yi\mathbf{y}_i5{}^5

GPT2主要用于*开放域*文本生成。首先,将输入提示Y0:i1\mathbf{Y}_{0:i-1}输入到模型中,以获得条件分布pθgpt2(yY0:i1)p_{\theta_{\text{gpt2}}}(\mathbf{y} | \mathbf{Y}_{0:i-1})。然后从该分布中采样下一个词yi\mathbf{y}_i(上图中灰色箭头所示),并将其添加到输入中。以自回归的方式,词yi+1\mathbf{y}_{i+1}可以从pθgpt2(yY0:i)p_{\theta_{\text{gpt2}}}(\mathbf{y} | \mathbf{Y}_{0:i})中采样,以此类推。

因此,GPT2非常适合*语言生成*,但不太适合*条件*生成。通过将输入提示Y0:i1\mathbf{Y}_{0: i-1}设置为等于序列输入X1:n\mathbf{X}_{1:n},GPT2可以很好地用于条件生成。然而,与编码器-解码器架构相比,该模型架构存在根本性缺陷,如Raffel et al. (2019)第17页所解释。简而言之,单向自注意力强制模型对序列输入X1:n\mathbf{X}_{1:n}的表示受到不必要的限制,因为xi\mathbf{x}_i不能依赖于xi+1,i{1,,n}\mathbf{x}_{i+1}, \forall i \in \{1,\ldots, n\}

编码器-解码器

由于*仅编码器*模型需要*预先*知道输出长度,因此它们似乎不适用于序列到序列任务。*仅解码器*模型可以很好地用于序列到序列任务,但也如上所述存在某些架构限制。

当前处理*序列到序列*任务的主要方法是*基于Transformer*的**编码器-解码器**模型——通常也称为*seq2seq Transformer*模型。编码器-解码器模型由Vaswani et al. (2017)引入,此后已被证明在*序列到序列*任务上的表现优于独立语言模型(即仅解码器模型),例如Raffel et al. (2020)。本质上,编码器-解码器模型是*独立*编码器(如BERT)和*独立*解码器模型(如GPT2)的组合。有关基于Transformer的编码器-解码器模型的具体架构的更多详细信息,请参阅这篇博客文章

现在,我们知道大型预训练*独立*编码器和解码器模型的自由可用检查点(例如*BERT*和*GPT*)可以提高许多NLU任务的性能并降低训练成本。我们也知道编码器-解码器模型本质上是*独立*编码器和解码器模型的组合。这自然引出了一个问题:如何利用独立模型检查点用于编码器-解码器模型,以及哪些模型组合在某些*序列到序列*任务上表现最佳。

2020年,Sascha Rothe、Shashi Narayan和Aliaksei Severyn在他们的论文**利用预训练检查点进行序列生成任务**中精确地研究了这个问题。该论文对不同的编码器-解码器模型组合和微调技术进行了精彩分析,我们将在后面更详细地研究。

从预训练的独立模型检查点构建编码器-解码器模型被定义为对编码器-解码器模型进行*热启动*。以下各节将展示热启动编码器-解码器模型的理论工作原理、如何使用🤗Transformers将理论付诸实践,并提供提高性能的实用技巧。


1{}^1 *预训练语言模型*定义为神经网络

  • 在*未标记*文本数据上进行训练,即以任务无关的无监督方式进行训练,并且
  • 将输入词序列处理成*上下文相关*的嵌入。例如,Mikolov et al. (2013)的*连续词袋*和*跳字图*模型不被视为预训练语言模型,因为它们的嵌入是上下文无关的。

2{}^2 *微调*被定义为使用预训练语言模型的权重初始化的模型进行*任务特定*训练。

3{}^3 输入向量y0\mathbf{y}_0对应于预测第一个输出词y1\mathbf{y}_1所需的BOS\text{BOS}嵌入向量。

4{}^4 在不失一般性的前提下,我们排除归一化层,以免使方程和插图过于繁琐。

5{}^5 有关单向自注意力如何用于“仅解码器”模型(如GPT2)以及采样如何精确工作的更多详细信息,请参阅编码器-解码器博客文章的解码器部分。

热启动编码器-解码器模型(理论)

阅读了引言后,我们现在熟悉了*仅编码器*和*仅解码器*模型。我们注意到编码器-解码器模型架构本质上是*独立*编码器模型和*独立*解码器模型的组合,这引导我们思考如何从*独立*模型检查点*热启动*编码器-解码器模型。

有多种可能性可以热启动编码器-解码器模型。可以

  1. 从*仅编码器*模型检查点初始化编码器和解码器部分,例如BERT,
  2. 从*仅编码器*模型检查点初始化编码器部分,例如BERT,并从*仅解码器*检查点初始化解码器部分,例如GPT2,
  3. 仅使用*仅编码器*模型检查点初始化编码器部分,或
  4. 仅使用*仅解码器*模型检查点初始化解码器部分。

在下文中,我们将重点放在可能性1和2上。在理解了前两种可能性之后,可能性3和4就变得微不足道了。

编码器-解码器模型回顾

首先,让我们快速回顾一下编码器-解码器架构。

texte du
lien

编码器(绿色所示)是*编码器块*的堆栈。每个编码器块由一个*双向自注意力*层和两个前馈层1{}^1组成。解码器(橙色所示)是*解码器块*的堆栈,后面是一个称为*LM Head*的密集层。每个解码器块由一个*单向自注意力*层、一个*交叉注意力*层和两个前馈层组成。

编码器将输入序列X1:n\mathbf{X}_{1:n}映射到上下文编码序列X1:n\mathbf{\overline{X}}_{1:n},其方式与BERT完全相同。然后,解码器将上下文编码序列X1:n\mathbf{\overline{X}}_{1:n}和目标序列Y0:m1\mathbf{Y}_{0:m-1}映射到logits向量L1:m\mathbf{L}_{1:m}。与GPT2类似,logits然后通过*softmax*操作用于定义目标序列Y1:m\mathbf{Y}_{1:m}以输入序列X1:n\mathbf{X}_{1:n}为条件的分布。

用数学术语来说,首先,条件分布通过贝叶斯规则分解为m1m - 1个下一个词yi\mathbf{y}_i的条件分布。

pθenc, dec(Y1:mX1:n)=pθdec(Y1:mX1:n)=i=1mpθdec(yiY0:i1,X1:n), with X1:n=fθenc(X1:n). p_{\theta_{\text{enc, dec}}}(\mathbf{Y}_{1:m} | \mathbf{X}_{1:n}) = p_{\theta_{\text{dec}}}(\mathbf{Y}_{1:m} | \mathbf{\overline{X}}_{1:n}) = \prod_{i=1}^m p_{\theta_{\text{dec}}}(\mathbf{y}_i | \mathbf{Y}_{0: i -1}, \mathbf{\overline{X}}_{1:n}), \text{ with } \mathbf{\overline{X}}_{1:n} = f_{\theta_{\text{enc}}}(\mathbf{X}_{1:n}).

每个“下一个词”的条件分布由logits向量的*softmax*定义,如下所示。

pθdec(yiY0:i1,X1:n)=Softmax(li). p_{\theta_{\text{dec}}}(\mathbf{y}_i | \mathbf{Y}_{0: i -1}, \mathbf{\overline{X}}_{1:n}) = \textbf{Softmax}(\mathbf{l}_i).

更多详情请参阅编码器-解码器notebook

使用BERT热启动编码器-解码器

现在,让我们演示如何使用预训练的BERT模型热启动编码器-解码器模型。BERT的预训练权重参数用于初始化编码器的权重参数和解码器的权重参数。为此,BERT的架构与编码器的架构进行比较,编码器中所有在BERT中也存在的层都将使用相应BERT层的预训练权重参数进行初始化。编码器中所有在BERT中不存在的层将简单地随机初始化其权重参数。

让我们进行可视化。

texte du
lien

我们可以看到编码器架构与BERT的架构一一对应。**所有**编码器块的*双向自注意力层*和两个*前馈层*的权重参数都使用相应BERT块的权重参数进行初始化。这在第二个编码器块(底部红色框)中得到了示例说明,其权重参数θencself-attn,2\theta_{\text{enc}}^{\text{self-attn}, 2}θencfeed-forward,2\theta_{\text{enc}}^{\text{feed-forward}, 2}在初始化时分别设置为BERT的权重参数θBERTfeed-forward,2\theta_{\text{BERT}}^{\text{feed-forward}, 2}θBERTself-attn,2\theta_{\text{BERT}}^{\text{self-attn}, 2}

在微调之前,编码器因此表现得与预训练的 BERT 模型完全一样。假设传递给编码器的输入序列 x1,,xn\mathbf{x}_1, \ldots, \mathbf{x}_n(绿色所示)等于传递给 BERT 的输入序列 x1BERT,,xnBERT\mathbf{x}_1^{\text{BERT}}, \ldots, \mathbf{x}_n^{\text{BERT}}(灰色所示),这意味着相应的输出向量序列 x1,,xn\mathbf{\overline{x}}_1, \ldots, \mathbf{\overline{x}}_n(深绿色所示)和 x1BERT,,xnBERT\mathbf{\overline{x}}_1^{\text{BERT}}, \ldots, \mathbf{\overline{x}}_n^{\text{BERT}}(深灰色所示)也必须相等。

接下来,我们来演示解码器是如何进行热启动的。

texte du
lien

解码器的架构与 BERT 的架构有三个不同之处。

  1. 首先,解码器必须通过交叉注意力层以上下文编码序列 X1:n\mathbf{\overline{X}}_{1:n} 为条件。因此,在每个 BERT 块的自注意力层和两个前馈层之间添加了随机初始化的交叉注意力层。这在第二个块中以 +θdeccross-attention, 2+\theta_{\text{dec}}^{\text{cross-attention, 2}} 为例进行表示,并在右下方红色框中以新增的红色完全连接图表示。这必然会改变每个修改过的 BERT 块的行为,使得输入向量(例如 y0\mathbf{y'}_0)现在会产生随机输出向量 y0\mathbf{y''}_0(由输出向量 y0\mathbf{y''}_0 周围的红色边框突出显示)。

  2. 其次,BERT 的*双向*自注意力层必须改为*单向*自注意力层,以符合自回归生成的要求。由于双向和单向自注意力层都基于相同的*键*、*查询*和*值*投影权重,因此解码器的自注意力层权重可以用 BERT 的自注意力层权重进行初始化。例如,解码器的单向自注意力层的查询、键和值权重参数用 BERT 双向自注意力层的相应参数进行初始化:θBERTself-attn,2={WBERT,kself-attn,2,WBERT,vself-attn,2,WBERT,qself-attn,2}θdecself-attn,2={Wdec,kself-attn,2,Wdec,vself-attn,2,Wdec,qself-attn,2}.\theta_{\text{BERT}}^{\text{self-attn}, 2} = \{\mathbf{W}_{\text{BERT}, k}^{\text{self-attn}, 2}, \mathbf{W}_{\text{BERT}, v}^{\text{self-attn}, 2}, \mathbf{W}_{\text{BERT}, q}^{\text{self-attn}, 2} \} \to \theta_{\text{dec}}^{\text{self-attn}, 2} = \{\mathbf{W}_{\text{dec}, k}^{\text{self-attn}, 2}, \mathbf{W}_{\text{dec}, v}^{\text{self-attn}, 2}, \mathbf{W}_{\text{dec}, q}^{\text{self-attn}, 2} \}. 然而,在*单向*自注意力中,每个 token 只关注所有先前的 token,因此即使解码器的自注意力层共享相同的权重,它们也会产生与 BERT 自注意力层不同的输出向量。例如,比较右侧框中解码器的因果连接图与左侧框中 BERT 的完全连接图。

  3. 第三,解码器输出一个 logit 向量序列 L1:m\mathbf{L}_{1:m},以便定义条件概率分布 pθdec(Y1:nX)p_{\theta_{\text{dec}}}(\mathbf{Y}_{1:n} | \mathbf{\overline{X}})。因此,在最后一个解码器块的顶部添加了一个*LM Head*层。*LM Head*层的权重参数通常与词嵌入 Wemb\mathbf{W}_{\text{emb}} 的权重参数相对应,因此不是随机初始化的。这在顶部通过初始化 θBERTword-embθdeclm-head\theta_{\text{BERT}}^{\text{word-emb}} \to \theta_{\text{dec}}^{\text{lm-head}}θdeclm-head 进行说明。

总而言之,当从预训练的 BERT 模型热启动解码器时,只有交叉注意力层权重是随机初始化的。所有其他权重,包括自注意力层和 LM Head 的权重,都用 BERT 的预训练权重参数进行初始化。

在热启动编码器-解码器模型后,权重将在*序列到序列*的下游任务(如摘要)上进行微调。

使用 BERT 和 GPT2 热启动编码器-解码器

我们可以不使用 BERT 检查点热启动编码器和解码器,而是利用 BERT 检查点用于编码器,并利用 GPT2 检查点用于解码器。乍一看,一个仅包含解码器的 GPT2 检查点似乎更适合热启动解码器,因为它已经过因果语言建模训练,并且使用*单向*自注意力层。

让我们演示如何使用 GPT2 检查点来热启动解码器。

texte du
lien

我们可以看到,解码器与 GPT2 的相似度高于其与 BERT 的相似度。解码器的*LM Head*的权重参数可以直接用 GPT2 的*LM Head*权重参数进行初始化,例如 θGPT2lm-headθdeclm-head\theta_{\text{GPT2}}^{\text{lm-head}} \to \theta_{\text{dec}}^{\text{lm-head}}。此外,解码器和 GPT2 的块都使用*单向*自注意力,因此假设输入向量相同,解码器自注意力层的输出向量与 GPT2 的输出向量等效,例如 y0GPT2=y0\mathbf{y'}_0^{\text{GPT2}} = \mathbf{y'}_0y0。与 BERT 初始化的解码器相反,GPT2 初始化的解码器因此保留了自注意力层的因果连接图,如下方红色框所示。

然而,GPT2 初始化的解码器也必须以 X1:n\mathbf{\overline{X}}_{1:n} 为条件。因此,与 BERT 初始化的解码器类似,为每个解码器块添加了随机初始化的交叉注意力层权重参数。这以 +θdeccross-attention, 2+\theta_{\text{dec}}^{\text{cross-attention, 2}} 为例进行说明。

尽管 GPT2 比 BERT 更像编码器-解码器模型中的解码器部分,但由于每个解码器块中随机初始化的交叉注意力层,GPT2 初始化的解码器在不进行微调的情况下也会产生随机的 logit 向量 L1:m\mathbf{L}_{1:m}。研究 GPT2 初始化的解码器是否能产生更好的结果或更有效地进行微调将是很有趣的。

编码器-解码器权重共享

Raffel et al. (2020) 中,作者表明,一个随机初始化的编码器-解码器模型,如果将编码器的权重与解码器共享,从而将内存占用减少一半,其性能仅比其“不共享”版本略差。将编码器的权重与解码器共享意味着解码器中与编码器中相同位置的所有层共享相同的权重参数,即网络计算图中的相同节点。
例如,第三个编码器块中自注意力层的查询、键和值投影矩阵,定义为 WEnc,kself-attn,3\mathbf{W}^{\text{self-attn}, 3}_{\text{Enc}, k}WEnc,vself-attn,3\mathbf{W}^{\text{self-attn}, 3}_{\text{Enc}, v}WEnc,qself-attn,3\mathbf{W}^{\text{self-attn}, 3}_{\text{Enc}, q},与第三个解码器块中自注意力层的相应查询、键和值投影矩阵 2{}^2 相同

Wkself-attn,3=Wenc,kself-attn,3Wdec,kself-attn,3, \mathbf{W}^{\text{self-attn}, 3}_{k} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, k} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, k}, =Wenc,kself-attn,3Wdec,kself-attn,3, Wqself-attn,3=Wenc,qself-attn,3Wdec,qself-attn,3, \mathbf{W}^{\text{self-attn}, 3}_{q} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, q} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, q}, =Wenc,qself-attn,3Wdec,qself-attn,3, Wvself-attn,3=Wenc,vself-attn,3Wdec,vself-attn,3, \mathbf{W}^{\text{self-attn}, 3}_{v} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, v} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, v}, =Wenc,vself-attn,3Wdec,vself-attn,3,

因此,键投影权重 Wkself-attn,3,Wvself-attn,3,Wqself-attn,3\mathbf{W}^{\text{self-attn}, 3}_{k}, \mathbf{W}^{\text{self-attn}, 3}_{v}, \mathbf{W}^{\text{self-attn}, 3}_{q},Wvself-attn,3,Wqself-attn,3 在每次反向传播过程中更新两次——一次当梯度通过第三个解码器块反向传播时,另一次当梯度通过第三个编码器块反向传播时。

以同样的方式,我们可以通过共享编码器权重与解码器来热启动编码器-解码器模型。为了在编码器和解码器之间共享权重,解码器架构(不包括交叉注意力权重)需要与编码器架构相同。因此,*编码器-解码器权重共享*仅在编码器-解码器模型从单个*仅编码器*预训练检查点热启动时才相关。

太棒了!这就是关于热启动编码器-解码器模型的理论。现在让我们看看一些结果。


1{}^1 不失一般性,我们排除了归一化层,以免混淆方程式和插图。 2{}^2 有关自注意力层如何运作的更多详细信息,请参阅变压器编码器-解码器模型博客文章的此部分(编码器部分)和此部分(解码器部分)。

热启动编码器-解码器模型(分析)

在本节中,我们将总结 Sascha Rothe、Shashi Narayan 和 Aliaksei Severyn 的《Leveraging Pre-trained Checkpoints for Sequence Generation Tasks》中提出的热启动编码器-解码器模型的研究结果。作者比较了热启动编码器-解码器模型与随机初始化编码器-解码器模型在多个*序列到序列*任务(特别是*摘要*、*翻译*、*句子拆分*和*句子合并*)上的性能。

更确切地说,公共可用的预训练检查点 BERTRoBERTaGPT2 以不同方式用于热启动编码器-解码器模型。例如,BERT 初始化的编码器与 BERT 初始化的解码器配对,生成 BERT2BERT 模型;或者 RoBERTa 初始化的编码器与 GPT2 初始化的解码器配对,生成 RoBERTa2GPT2 模型。此外,还研究了 RoBERTa 的编码器和解码器权重共享(如前一节所述)的效果,即 RoBERTaShare,以及 BERT 的效果,即 BERTShare。随机或部分随机初始化的编码器-解码器模型用作基线,例如完全随机初始化的编码器-解码器模型(称为 Rnd2Rnd)或 BERT 初始化的解码器与随机初始化的编码器配对(定义为 Rnd2BERT)。

下表显示了所有研究模型变体的完整列表,包括随机初始化的权重数量(即“随机”)和从各自预训练检查点初始化的权重数量(即“利用”)。所有模型均基于 12 层架构,隐藏尺寸嵌入为 768 维,对应于 🤗Transformers 模型中心的 `bert-base-cased`、`bert-base-uncased`、`roberta-base` 和 `gpt2` 检查点。

模型 随机 利用 总计
Rnd2Rnd 2.21 亿 0 2.21 亿
Rnd2BERT 1.12 亿 1.09 亿 2.21 亿
BERT2Rnd 1.12 亿 1.09 亿 2.21 亿
Rnd2GPT2 1.14 亿 1.25 亿 2.38 亿
BERT2BERT 2600 万 1.95 亿 2.21 亿
BERTShare 2600 万 1.09 亿 1.35 亿
RoBERTaShare 2600 万 1.26 亿 1.52 亿
BERT2GPT2 2600 万 2.34 亿 2.60 亿
RoBERTa2GPT2 2600 万 2.50 亿 2.76 亿

基于 BERT2BERT 架构的模型*Rnd2Rnd*包含 2.21 亿权重参数,所有这些参数都是随机初始化的。另外两个“基于 BERT”的基线*Rnd2BERT*和*BERT2Rnd*大约有一半的权重(即 1.12 亿参数)是随机初始化的。其余的 1.09 亿权重参数分别从预训练的 `bert-base-uncased` 检查点中提取,用于编码器或解码器部分。模型*BERT2BERT*、*BERT2GPT2*和*RoBERTa2GPT2*的所有编码器权重参数都得到了利用(分别来自 `bert-base-uncased`、`roberta-base`),并且大部分解码器权重参数也得到了利用(分别来自 `gpt2`、`bert-base-uncased`)。其中,2600 万个解码器权重参数(对应于 12 个交叉注意力层)是随机初始化的。RoBERTa2GPT2 和 BERT2GPT2 与*Rnd2GPT2*基线进行了比较。此外,需要注意的是,共享模型变体*BERTShare*和*RoBERTaShare*的参数数量显著减少,因为所有编码器权重参数都与相应的解码器权重参数共享。

实验

上述模型在四个复杂程度递增的序列到序列任务上进行了训练和评估:句子级合并、句子级拆分、翻译和抽象摘要。下表显示了每个任务使用的数据集。

序列到序列任务 数据集 论文 🤗数据集
句子合并 DiscoFuse Geva et al. (2019) 链接
句子拆分 WikiSplit Botha et al. (2018) -
翻译 WMT14 英语 => 德语 Bojar et al. (2014) 链接
WMT14 德语 => 英语 Bojar et al. (2014) 链接
抽象摘要 CNN/Dailymail Hermann et al. (2015) 链接
BBC XSum Narayan et al. (2018a) 链接
Gigaword Napoles et al. (2012) 链接

根据任务的不同,使用了略有不同的训练方案。例如,根据数据集的大小和具体任务,训练步数范围为 20 万到 50 万,批处理大小设置为 128 或 256,输入长度范围为 128 到 512,输出长度在 32 到 128 之间变化。然而,需要强调的是,在每个任务中,所有模型都使用相同的超参数进行训练和评估,以确保公平比较。有关任务特定超参数设置的更多信息,建议读者参阅论文的*实验*部分。

现在我们将简要概述每个任务的结果。

句子合并和拆分(DiscoFuse、WikiSplit)

句子合并是将多个句子组合成一个连贯句子的任务。例如,以下两个句子:

作为一名跑动阻挡者,蔡特勒的移动相对不错。 蔡特勒在空间接触点上经常挣扎。

应该用一个合适的*连接词*连接起来,例如:

作为一名跑动阻挡者,蔡特勒的移动相对不错。然而在空间接触点上经常挣扎。

可以看出,“然而”这个连接词为第一个句子到第二个句子提供了连贯的过渡。一个能够生成这种连接词的模型可以说已经学会推断出上述两个句子是相互对比的。

逆任务称为句子拆分,包括将一个复杂的句子拆分成多个更简单的句子,这些句子共同保留相同的含义。句子拆分被认为是文本简化中的一项重要任务,参见Botha et al. (2018)

例如,句子

《街头霸王》是1989年为PC和Commodore 64发布的系列两款游戏中的第一款

可以简化为

《街头霸王》是系列两款游戏中的第一款于1989年为PC和Commodore 64发布

可以看出,长句试图传达两个重要的信息。一是这款游戏是为PC发布的系列两款游戏中的第一款,二是它发布的年份。因此,句子拆分要求模型理解句子的哪个部分应该被分成两个句子,这使得这项任务比句子合并更困难。

评估模型在句子合并和拆分任务上性能的常用指标是 SARI (Wu et al. (2016),它大致基于标签和模型输出的 F1 分数。

让我们看看模型在句子合并和拆分上的表现。

模型 100% DiscoFuse (SARI) 10% DiscoFuse (SARI) 100% WikiSplit (SARI)
Rnd2Rnd 86.9 81.5 61.7
Rnd2BERT 87.6 82.1 61.8
BERT2Rnd 89.3 86.1 63.1
Rnd2GPT2 86.5 81.4 61.3
BERT2BERT 89.3 86.1 63.2
BERTShare 89.2 86.0 63.5
RoBERTaShare 89.7 86.0 63.4
BERT2GPT2 88.4 84.1 62.4
RoBERTa2GPT2 89.9 87.1 63.2
--- --- --- ---
RoBERTaShare (大型) 90.3 87.7 63.8

前两列显示了编码器-解码器模型在 DiscoFuse 评估数据上的性能。第一列显示了在所有 (100%) 训练数据上训练的编码器-解码器模型的结果,而第二列显示了仅在 10% 训练数据上训练的模型的结果。我们观察到,热启动模型比随机初始化的基线模型 *Rnd2Rnd*、*Rnd2Bert* 和 *Rnd2GPT2* 表现显著更好。仅在 10% 训练数据上训练的热启动 *RoBERTa2GPT2* 模型与在 100% 训练数据上训练的 *Rnd2Rnd* 模型性能相当。有趣的是,*Bert2Rnd* 基线模型的表现与完全热启动的 *Bert2Bert* 模型一样好,这表明热启动编码器部分比热启动解码器部分更有效。最好的结果是由 *RoBERTa2GPT2* 获得,其次是 *RobertaShare*。共享编码器和解码器权重参数似乎确实略微提高了模型的性能。

在更困难的句子拆分任务中,也出现了类似的模式。热启动编码器-解码器模型的性能显著优于编码器随机初始化模型,并且具有共享权重参数的编码器-解码器模型比具有非耦合权重参数的模型产生更好的结果。在句子拆分任务上,BertShare 模型表现最佳,紧随其后的是 RobertaShare

除了12层模型变体,作者还训练和评估了一个24层*RobertaShare (large)*模型,其性能显著优于所有12层模型。

机器翻译 (WMT14)

接下来,作者在机器翻译 (MT) 中可能最常见的基准上评估了热启动的编码器-解码器模型——即 En \to DeDe \to En WMT14 数据集。在本 Notebook 中,我们展示了 newstest2014 评估数据集的结果。因为该基准要求模型理解英语和德语词汇,所以 BERT 初始化的编码器-解码器模型是从多语言预训练检查点 bert-base-multilingual-cased 热启动的。由于没有公开可用的多语言 RoBERTa 检查点,因此 MT 中排除了 RoBERTa 初始化的编码器-解码器模型。GPT2 初始化的模型像之前的实验一样从 gpt2 预训练检查点初始化。翻译结果使用 BLUE-4 分数指标报告 1{}^1

模型 \to 德 (BLEU-4) \to 英 (BLEU-4)
Rnd2Rnd 26.0 29.1
Rnd2BERT 27.2 30.4
BERT2Rnd 30.1 32.7
Rnd2GPT2 19.6 23.2
BERT2BERT 30.1 32.7
BERTShare 29.6 32.6
BERT2GPT2 23.2 31.4
--- --- ---
BERT2Rnd (大型,自定义) 31.7 34.2
BERTShare (大型,自定义) 30.5 33.8

我们再次观察到,通过热启动编码器部分,性能得到了显著提升,其中 BERT2RndBERT2BERTEn \to DeDe \to En 任务上都取得了最佳结果。GPT2 初始化的模型在 En \to De 上的表现甚至明显差于 Rnd2Rnd 基线。考虑到 gpt2 检查点仅在英文文本上训练,BERT2GPT2Rnd2GPT2 模型在生成德语翻译时遇到困难并不令人惊讶。这一假设得到了 BERT2GPT2De \to En 任务上具有竞争力的结果(例如,31.4 对 32.7)的支持,因为 GPT2 的词汇表适合英文输出格式。与句子融合和句子拆分获得的结果相反,共享编码器和解码器权重参数并未在机器翻译中带来性能提升。作者指出的可能原因包括

  • MT 中编码器-解码器模型容量是一个重要因素,以及
  • 编码器和解码器必须处理不同的语法和词汇

由于 bert-base-multilingual-cased 检查点在超过 100 种语言上进行了训练,其词汇量对于 En \to DeDe \to En MT 可能不理想地过大。因此,作者在维基百科转储的英文和德文子集上预训练了一个大型 BERT 仅编码器检查点,并随后用它来热启动 BERT2RndBERTShare 编码器-解码器模型。由于词汇表的改进,观察到了另一个显著的性能提升,其中 BERT2Rnd (大型,自定义) 显著优于所有其他模型。

摘要 (CNN/Dailymail, BBC XSum, Gigaword)

最后,编码器-解码器模型在可以说是最具挑战性的序列到序列任务——摘要上进行了评估。作者选择了三个具有不同特征的摘要数据集进行评估:Gigaword (标题生成)、BBC XSum (极端摘要) 和 CNN/Dailymail (抽象摘要)。

Gigaword 数据集包含句子级别的抽象摘要,要求模型学习句子级别的理解、抽象和最终的转述。Gigaword 中的典型数据样本,例如

"*委内瑞拉总统乌戈·查韦斯周四表示,他已下令调查一起涉嫌涉及现役和退役军官的政变阴谋。*",

将有一个相应的标题作为其标签,例如

"查韦斯下令调查涉嫌政变阴谋"。

BBC XSum 数据集包含更长的文章式文本输入,其标签大多是单句摘要。该数据集要求模型不仅要学习文档级别的推理,还要学习高水平的抽象转述。BBC XSUM 数据集的一些数据样本可以在此处查看。

对于 CNN/Dailymail 数据集,与 BBC XSum 数据集长度相似的文档必须摘要为要点式故事摘要。因此,标签通常包含多个句子。除了文档级别的理解之外,CNN/Dailymail 数据集还要求模型善于复制最突出的信息。一些示例可以在此处查看。

模型使用 Rouge 指标进行评估,其中 Rouge-2 分数如下所示。

好的,让我们看看结果。

模型 CNN/Dailymail (Rouge-2) BBC XSum (Rouge-2) Gigaword (Rouge-2)
Rnd2Rnd 14.00 10.23 18.71
Rnd2BERT 15.55 11.52 18.91
BERT2Rnd 17.76 15.83 19.26
Rnd2GPT2 8.81 8.77 18.39
BERT2BERT 17.84 15.24 19.68
BERTShare 18.10 16.12 19.81
RoBERTaShare 18.95 17.50 19.70
BERT2GPT2 4.96 8.37 18.23
RoBERTa2GPT2 14.72 5.20 19.21
--- --- --- ---
RoBERTaShare (大型) 18.91 18.79 19.78

我们再次观察到,热启动编码器部分与随机初始化的编码器模型相比有显著的性能提升,这在文档级抽象任务(即 CNN/Dailymail 和 BBC XSum)中尤为明显。这表明,需要高水平抽象的任务比仅需要句子级抽象的任务更能从预训练的编码器部分中受益。除了 Gigaword,基于 GPT2 的编码器-解码器模型似乎不适合摘要。

此外,共享编码器-解码器模型在摘要任务中表现最佳。RoBERTaShareBERTShare 在所有数据集上都表现最佳,其中在 BBC XSum 数据集上的优势尤为显著,在该数据集上,RoBERTaShare (大型)BERT2BERTBERT2Rnd 高出约 3 个 Rouge-2 点,比 Rnd2Rnd 高出 8 个 Rouge-2 点以上。正如作者所说,“这可能是因为 BBC 摘要句子的分布与文档中句子的分布相似,而 Gigaword 标题和 CNN/DailyMail 要点摘要则不一定如此”。直观地说,这意味着在 BBC XSum 中,编码器处理的输入句子与解码器处理的单句摘要在结构上非常相似,即长度相同、词语选择相似、语法相似。

结论

好的,让我们得出结论并尝试提出一些实用技巧。

  • 我们已经观察到,在所有任务中,与随机初始化编码器的编码器-解码器模型相比,热启动编码器部分能显著提高性能。另一方面,热启动解码器似乎不那么重要,在大多数任务中,BERT2BERTBERT2Rnd 相当。一个直观的原因是,由于 BERT 或 RoBERTa 初始化的编码器部分没有任何随机初始化的权重参数,因此编码器可以充分利用 BERT 或 RoBERTa 预训练检查点所获得的知识。相比之下,热启动的解码器始终有部分权重参数是随机初始化的,这可能使得有效利用用于初始化解码器的检查点所获得的知识变得更加困难。

  • 接下来,我们注意到共享编码器和解码器权重通常是有益的,特别是当目标分布与输入分布相似时(例如 BBC XSum)。然而,对于目标数据分布与输入数据分布差异更大的数据集,以及已知模型容量 2{}^2 在其中扮演重要角色的数据集,例如 WMT14,编码器-解码器权重共享似乎是不利的。

  • 最后,我们看到预训练的“独立”检查点的词汇表与解决序列到序列任务所需的词汇表非常重要。例如,一个热启动的 BERT2GPT2 编码器-解码器在 En \to De MT 上的表现会很差,因为 GPT2 是在英语上预训练的,而目标语言是德语。与 BERT2BERTBERTSharedRoBERTaShared 相比,BERT2GPT2Rnd2GPT2RoBERTa2GPT2 的整体表现不佳表明共享词汇表更有效。此外,这表明用预训练的 GPT2 检查点初始化解码器部分并不比用预训练的 BERT 检查点初始化它更有效,尽管 GPT2 在其架构上与解码器更相似。

对于上述每个任务,性能最佳的模型已移植到 🤗Transformers,可在此处访问


1{}^1 为了获取 BLEU-4 分数,使用了 Tensorflow 官方 Transformer 实现 https://github.com/tensorflow/models/tree master/official/nlp/transformer 中的脚本。请注意,与 Vaswani 等人 (2017) 使用的 tensor2tensor/utils/ get_ende_bleu.sh 不同,该脚本不拆分名词复合词,但在注意到预处理的训练集只包含 ascii 引号后,将 utf-8 引号标准化为 ascii 引号。

2{}^2 模型容量是模型对复杂模式建模能力的非正式定义。有时也定义为模型从越来越多数据中学习的能力。模型容量通常通过可训练参数的数量来衡量——参数越多,模型容量越高。

使用 🤗Transformers 热启动编码器-解码器模型(实践)

我们已经解释了热启动编码器-解码器模型的理论,分析了多个数据集上的实证结果,并得出了实际结论。现在,我们将通过一个完整的代码示例来演示如何热启动 BERT2BERT 模型,并将其在 CNN/Dailymail 摘要任务上进行微调。我们将利用 🤗datasets 和 🤗Transformers 库。

此外,以下列表提供了本 Notebook 和其他关于热启动其他编码器-解码器模型组合的 Notebook 的精简版本。

  • 关于 CNN/Dailymail 上的 BERT2BERT(本 Notebook 的精简版),请点击此处
  • 关于 BBC XSum 上的 RoBERTaShare,请点击此处
  • 关于 WMT14 En \to De 上的 BERT2Rnd,请点击此处
  • 关于 DiscoFuse 上的 RoBERTa2GPT2,请点击此处

注意:本 Notebook 仅使用少量训练、验证和测试数据样本进行演示。要对完整训练数据进行编码器-解码器模型微调,用户应根据注释中突出显示的内容相应地更改训练和数据预处理参数。

数据预处理

本节将展示如何对数据进行预处理以进行训练。更重要的是,我们试图让读者对如何决定预处理数据的过程有一些了解。

我们将需要安装 datasets 和 transformers。

!pip install datasets==1.0.2
!pip install transformers==4.2.1

让我们首先下载 CNN/Dailymail 数据集。

import datasets
train_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train")

好的,让我们先对数据集有一个初步印象。另外,数据集也可以使用优秀的在线 datasets viewer 进行可视化。

train_data.info.description

我们的输入称为 article,我们的标签称为 highlights。现在让我们打印出训练数据的第一个示例,以便对数据有一个感觉。

import pandas as pd
from IPython.display import display, HTML
from datasets import ClassLabel

df = pd.DataFrame(train_data[:1])
del df["id"]
for column, typ in train_data.features.items():
      if isinstance(typ, ClassLabel):
          df[column] = df[column].transform(lambda i: typ.names[i])
display(HTML(df.to_html()))
OUTPUT:
-------
Article:
"""It's official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons. The proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It's a step that is set to turn an international crisis into a fierce domestic political battle. There are key questions looming over the debate: What did U.N. weapons inspectors find in Syria? What happens if Congress votes no? And how will the Syrian government react? In a televised address from the White House Rose Garden earlier Saturday, the president said he would take his case to Congress, not because he has to -- but because he wants to. "While I believe I have the authority to carry out this military action without specific congressional authorization, I know that the country will be stronger if we take this course, and our actions will be even more effective," he said. "We should have this debate, because the issues are too big for business as usual." Obama said top congressional leaders had agreed to schedule a debate when the body returns to Washington on September 9. The Senate Foreign Relations Committee will hold a hearing over the matter on Tuesday, Sen. Robert Menendez said. Transcript: Read Obama's full remarks . Syrian crisis: Latest developments . U.N. inspectors leave Syria . Obama's remarks came shortly after U.N. inspectors left Syria, carrying evidence that will determine whether chemical weapons were used in an attack early last week in a Damascus suburb. "The aim of the game here, the mandate, is very clear -- and that is to ascertain whether chemical weapons were used -- and not by whom," U.N. spokesman Martin Nesirky told reporters on Saturday. But who used the weapons in the reported toxic gas attack in a Damascus suburb on August 21 has been a key point of global debate over the Syrian crisis. Top U.S. officials have said there's no doubt that the Syrian government was behind it, while Syrian officials have denied responsibility and blamed jihadists fighting with the rebels. British and U.S. intelligence reports say the attack involved chemical weapons, but U.N. officials have stressed the importance of waiting for an official report from inspectors. The inspectors will share their findings with U.N. Secretary-General Ban Ki-moon Ban, who has said he wants to wait until the U.N. team's final report is completed before presenting it to the U.N. Security Council. The Organization for the Prohibition of Chemical Weapons, which nine of the inspectors belong to, said Saturday that it could take up to three weeks to analyze the evidence they collected. "It needs time to be able to analyze the information and the samples," Nesirky said. He noted that Ban has repeatedly said there is no alternative to a political solution to the crisis in Syria, and that "a military solution is not an option." Bergen:  Syria is a problem from hell for the U.S. Obama: 'This menace must be confronted' Obama's senior advisers have debated the next steps to take, and the president's comments Saturday came amid mounting political pressure over the situation in Syria. Some U.S. lawmakers have called for immediate action while others warn of stepping into what could become a quagmire. Some global leaders have expressed support, but the British Parliament's vote against military action earlier this week was a blow to Obama's hopes of getting strong backing from key NATO allies. On Saturday, Obama proposed what he said would be a limited military action against Syrian President Bashar al-Assad. Any military attack would not be open-ended or include U.S. ground forces, he said. Syria's alleged use of chemical weapons earlier this month "is an assault on human dignity," the president said. A failure to respond with force, Obama argued,  "could lead to escalating use of chemical weapons or their proliferation to terrorist groups who would do our people harm. In a world with many dangers, this menace must be confronted." Syria missile strike: What would happen next? Map: U.S. and allied assets around Syria . Obama decision came Friday night . On Friday night, the president made a last-minute decision to consult lawmakers. What will happen if they vote no? It's unclear. A senior administration official told CNN that Obama has the authority to act without Congress -- even if Congress rejects his request for authorization to use force. Obama on Saturday continued to shore up support for a strike on the al-Assad government. He spoke by phone with French President Francois Hollande before his Rose Garden speech. "The two leaders agreed that the international community must deliver a resolute message to the Assad regime -- and others who would consider using chemical weapons -- that these crimes are unacceptable and those who violate this international norm will be held accountable by the world," the White House said. Meanwhile, as uncertainty loomed over how Congress would weigh in, U.S. military officials said they remained at the ready. 5 key assertions: U.S. intelligence report on Syria . Syria: Who wants what after chemical weapons horror . Reactions mixed to Obama's speech . A spokesman for the Syrian National Coalition said that the opposition group was disappointed by Obama's announcement. "Our fear now is that the lack of action could embolden the regime and they repeat his attacks in a more serious way," said spokesman Louay Safi. "So we are quite concerned." Some members of Congress applauded Obama's decision. House Speaker John Boehner, Majority Leader Eric Cantor, Majority Whip Kevin McCarthy and Conference Chair Cathy McMorris Rodgers issued a statement Saturday praising the president. "Under the Constitution, the responsibility to declare war lies with Congress," the Republican lawmakers said. "We are glad the president is seeking authorization for any military action in Syria in response to serious, substantive questions being raised." More than 160 legislators, including 63 of Obama's fellow Democrats, had signed letters calling for either a vote or at least a "full debate" before any U.S. action. British Prime Minister David Cameron, whose own attempt to get lawmakers in his country to support military action in Syria failed earlier this week, responded to Obama's speech in a Twitter post Saturday. "I understand and support Barack Obama's position on Syria," Cameron said. An influential lawmaker in Russia -- which has stood by Syria and criticized the United States -- had his own theory. "The main reason Obama is turning to the Congress:  the military operation did not get enough support either in the world, among allies of the US or in the United States itself," Alexei Pushkov, chairman of the international-affairs committee of the Russian State Duma, said in a Twitter post. In the United States, scattered groups of anti-war protesters around the country took to the streets Saturday. "Like many other Americans...we're just tired of the United States getting involved and invading and bombing other countries," said Robin Rosecrans, who was among hundreds at a Los Angeles demonstration. What do Syria's neighbors think? Why Russia, China, Iran stand by Assad . Syria's government unfazed . After Obama's speech, a military and political analyst on Syrian state TV said Obama is "embarrassed" that Russia opposes military action against Syria, is "crying for help" for someone to come to his rescue and is facing two defeats -- on the political and military levels. Syria's prime minister appeared unfazed by the saber-rattling. "The Syrian Army's status is on maximum readiness and fingers are on the trigger to confront all challenges," Wael Nader al-Halqi said during a meeting with a delegation of Syrian expatriates from Italy, according to a banner on Syria State TV that was broadcast prior to Obama's address. An anchor on Syrian state television said Obama "appeared to be preparing for an aggression on Syria based on repeated lies." A top Syrian diplomat told the state television network that Obama was facing pressure to take military action from Israel, Turkey, some Arabs and right-wing extremists in the United States. "I think he has done well by doing what Cameron did in terms of taking the issue to Parliament," said Bashar Jaafari, Syria's ambassador to the United Nations. Both Obama and Cameron, he said, "climbed to the top of the tree and don't know how to get down." The Syrian government has denied that it used chemical weapons in the August 21 attack, saying that jihadists fighting with the rebels used them in an effort to turn global sentiments against it. British intelligence had put the number of people killed in the attack at more than 350. On Saturday, Obama said "all told, well over 1,000 people were murdered." U.S. Secretary of State John Kerry on Friday cited a death toll of 1,429, more than 400 of them children. No explanation was offered for the discrepancy. Iran: U.S. military action in Syria would spark 'disaster' Opinion: Why strikes in Syria are a bad idea ."""
Summary:
"""Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"\nObama sends a letter to the heads of the House and Senate .\nObama to seek congressional approval on military action against Syria .\nAim is to determine whether CW were used, not by whom, says U.N. spokesman"""

输入数据似乎由短新闻文章组成。有趣的是,标签似乎是项目符号式的摘要。此时,应该查看其他几个示例,以便更好地了解数据。

这里还应该注意到文本是区分大小写的。这意味着如果我们想使用不区分大小写的模型,我们必须小心。由于 CNN/Dailymail 是一个摘要数据集,模型将使用 ROUGE 指标进行评估。检查 🤗datasets 中 ROUGE 的描述(参见此处),我们可以看到该指标是不区分大小写的,这意味着在评估期间大写字母将被规范化为小写字母。因此,我们可以安全地利用不带大小写的检查点,例如 bert-base-uncased

太棒了!接下来,让我们了解一下输入数据和标签的长度。

由于模型以token 长度计算长度,我们将使用 bert-base-uncased 分词器来计算文章和摘要的长度。

首先,我们加载分词器。

from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

接下来,我们使用 .map() 来计算文章及其摘要的长度。由于我们知道 bert-base-uncased 可以处理的最大长度为 512,因此我们也对输入样本超过最大长度的百分比感兴趣。同样,我们计算摘要长度分别超过 64 和 128 的百分比。

我们可以将 .map() 函数定义如下。

# map article and summary len to dict as well as if sample is longer than 512 tokens
def map_to_length(x):
  x["article_len"] = len(tokenizer(x["article"]).input_ids)
  x["article_longer_512"] = int(x["article_len"] > 512)
  x["summary_len"] = len(tokenizer(x["highlights"]).input_ids)
  x["summary_longer_64"] = int(x["summary_len"] > 64)
  x["summary_longer_128"] = int(x["summary_len"] > 128)
  return x

查看前 10000 个样本就足够了。我们可以通过使用 num_proc=4 的多个进程来加速映射。

sample_size = 10000
data_stats = train_data.select(range(sample_size)).map(map_to_length, num_proc=4)

计算出前 10000 个样本的长度后,我们现在应该将它们平均。为此,我们可以使用 .map() 函数,并设置 batched=Truebatch_size=-1,以便在 .map() 函数中访问所有 10000 个样本。

def compute_and_print_stats(x):
  if len(x["article_len"]) == sample_size:
    print(
        "Article Mean: {}, %-Articles > 512:{}, Summary Mean:{}, %-Summary > 64:{}, %-Summary > 128:{}".format(
            sum(x["article_len"]) / sample_size,
            sum(x["article_longer_512"]) / sample_size, 
            sum(x["summary_len"]) / sample_size,
            sum(x["summary_longer_64"]) / sample_size,
            sum(x["summary_longer_128"]) / sample_size,
        )
    )

output = data_stats.map(
  compute_and_print_stats, 
  batched=True,
  batch_size=-1,
)
    OUTPUT:
    -------
    Article Mean: 847.6216, %-Articles > 512:0.7355, Summary Mean:57.7742, %-Summary > 64:0.3185, %-Summary > 128:0.0

我们可以看到,一篇文章平均包含 848 个 token,其中约四分之三的文章长度超过了模型的 max_length 512。摘要平均长度为 57 个 token。我们 10000 个样本的摘要中有超过 30% 的长度超过 64 个 token,但没有一个长度超过 128 个 token。

bert-base-cased 限于 512 个 token,这意味着我们可能需要从文章中裁剪重要的信息。由于大部分重要信息通常出现在文章开头,并且我们希望计算效率高,因此本 Notebook 决定坚持使用 bert-base-cased,其 max_length 为 512。这个选择并非最优,但已在 CNN/Dailymail 上显示出良好效果。或者,可以使用长程序列模型(如 Longformer)作为编码器。

关于摘要长度,我们可以看到长度为 128 已经包含了所有摘要标签。128 很容易在 bert-base-cased 的限制范围内,因此我们决定将生成限制在 128。

我们再次使用 .map() 函数,这次是将每个训练批次转换为模型输入批次。

“article”和“highlights”被分词并分别准备为编码器的“input_ids”和解码器的“decoder_input_ids”。

“标签”会自动向左移动,用于语言建模训练。

最后,非常重要的是要记住忽略填充标签的损失。在 🤗Transformers 中,可以通过将标签设置为 -100 来完成此操作。好的,现在让我们写下映射函数。

encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch

在本 Notebook 中,我们仅使用少量训练示例来训练和评估模型,并将 batch_size 设置为 4,以防止内存不足问题。

以下行将训练数据减少到仅前 32 个示例。该单元格可以注释掉或不运行,以进行完整的训练运行。使用 16 的 batch_size 获得了良好结果。

train_data = train_data.select(range(32))

好的,让我们准备训练数据。

# batch_size = 16
batch_size=4

train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)

查看处理后的训练数据集,我们可以看到列名 articlehighlightsid 已被 EncoderDecoderModel 所需的参数替换。

train_data
OUTPUT:
-------
Dataset(features: {'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}, num_rows: 32)

到目前为止,数据是使用 Python 的 List 格式进行操作的。让我们将数据转换为 PyTorch 张量,以便在 GPU 上进行训练。

train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"],
)

太棒了,训练数据的数据处理已经完成。类似地,我们可以对验证数据做同样的操作。

首先,我们加载 10% 的验证数据集

val_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

为了演示目的,验证数据随后减少到只有 8 个样本,

val_data = val_data.select(range(8))

应用映射函数,

val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)

最后,验证数据也被转换为 PyTorch 张量。

val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"],
)

太好了!现在我们可以继续热启动 EncoderDecoderModel

热启动编码器-解码器模型

本节介绍如何使用 bert-base-cased 检查点热启动编码器-解码器模型。

让我们从导入 EncoderDecoderModel 开始。有关 EncoderDecoderModel 类的更详细信息,建议读者查阅文档

from transformers import EncoderDecoderModel

与 🤗Transformers 中的其他模型类不同,EncoderDecoderModel 类有两种加载预训练权重的方法,即

  1. “标准” .from_pretrained(...) 方法源自通用的 PretrainedModel.from_pretrained(...) 方法,因此与所有其他模型类完全相同。该函数需要一个模型标识符,例如 .from_pretrained("google/bert2bert_L-24_wmt_de_en"),并将单个 .pt 检查点文件加载到 EncoderDecoderModel 类中。

  2. 一个特殊的 .from_encoder_decoder_pretrained(...) 方法,可用于从两个模型标识符(一个用于编码器,一个用于解码器)热启动编码器-解码器模型。第一个模型标识符用于通过 AutoModel.from_pretrained(...) 加载编码器(参见文档此处),第二个模型标识符用于通过 AutoModelForCausalLM 加载解码器(参见文档此处)。

好的,让我们热启动我们的 BERT2BERT 模型。如前所述,我们将使用 "bert-base-cased" 检查点热启动编码器和解码器。

bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
OUTPUT:
-------
"""Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
    - This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
    - This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.2.crossattention.self.query.bias', 'bert.encoder.layer.2.crossattention.self.key.weight', 'bert.encoder.layer.2.crossattention.self.key.bias', 'bert.encoder.layer.2.crossattention.self.value.weight', 'bert.encoder.layer.2.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.output.dense.weight', 'bert.encoder.layer.2.crossattention.output.dense.bias', 'bert.encoder.layer.2.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.2.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.3.crossattention.self.query.weight', 'bert.encoder.layer.3.crossattention.self.query.bias', 'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.key.bias', 'bert.encoder.layer.3.crossattention.self.value.weight', 'bert.encoder.layer.3.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.output.dense.weight', 'bert.encoder.layer.3.crossattention.output.dense.bias', 'bert.encoder.layer.3.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.3.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.4.crossattention.self.query.weight', 'bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.4.crossattention.self.key.bias', 'bert.encoder.layer.4.crossattention.self.value.weight', 'bert.encoder.layer.4.crossattention.self.value.bias', 'bert.encoder.layer.4.crossattention.output.dense.weight', 'bert.encoder.layer.4.crossattention.output.dense.bias', 'bert.encoder.layer.4.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.4.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.5.crossattention.self.query.weight', 'bert.encoder.layer.5.crossattention.self.query.bias', 'bert.encoder.layer.5.crossattention.self.key.weight', 'bert.encoder.layer.5.crossattention.self.key.bias', 'bert.encoder.layer.5.crossattention.self.value.weight', 'bert.encoder.layer.5.crossattention.self.value.bias', 'bert.encoder.layer.5.crossattention.output.dense.weight', 'bert.encoder.layer.5.crossattention.output.dense.bias', 'bert.encoder.layer.5.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.5.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.6.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.query.bias', 'bert.encoder.layer.6.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.key.bias', 'bert.encoder.layer.6.crossattention.self.value.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.6.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.6.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.7.crossattention.self.query.weight', 'bert.encoder.layer.7.crossattention.self.query.bias', 'bert.encoder.layer.7.crossattention.self.key.weight', 'bert.encoder.layer.7.crossattention.self.key.bias', 'bert.encoder.layer.7.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.7.crossattention.output.dense.weight', 'bert.encoder.layer.7.crossattention.output.dense.bias', 'bert.encoder.layer.7.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.7.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.8.crossattention.self.query.weight', 'bert.encoder.layer.8.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.self.key.weight', 'bert.encoder.layer.8.crossattention.self.key.bias', 'bert.encoder.layer.8.crossattention.self.value.weight', 'bert.encoder.layer.8.crossattention.self.value.bias', 'bert.encoder.layer.8.crossattention.output.dense.weight', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.8.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.8.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.9.crossattention.self.query.weight', 'bert.encoder.layer.9.crossattention.self.query.bias', 'bert.encoder.layer.9.crossattention.self.key.weight', 'bert.encoder.layer.9.crossattention.self.key.bias', 'bert.encoder.layer.9.crossattention.self.value.weight', 'bert.encoder.layer.9.crossattention.self.value.bias', 'bert.encoder.layer.9.crossattention.output.dense.weight', 'bert.encoder.layer.9.crossattention.output.dense.bias', 'bert.encoder.layer.9.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.9.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.10.crossattention.self.query.weight', 'bert.encoder.layer.10.crossattention.self.query.bias', 'bert.encoder.layer.10.crossattention.self.key.weight', 'bert.encoder.layer.10.crossattention.self.key.bias', 'bert.encoder.layer.10.crossattention.self.value.weight', 'bert.encoder.layer.10.crossattention.self.value.bias', 'bert.encoder.layer.10.crossattention.output.dense.weight', 'bert.encoder.layer.10.crossattention.output.dense.bias', 'bert.encoder.layer.10.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.10.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.11.crossattention.self.query.weight', 'bert.encoder.layer.11.crossattention.self.query.bias', 'bert.encoder.layer.11.crossattention.self.key.weight', 'bert.encoder.layer.11.crossattention.self.key.bias', 'bert.encoder.layer.11.crossattention.self.value.weight', 'bert.encoder.layer.11.crossattention.self.value.bias', 'bert.encoder.layer.11.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.output.dense.bias', 'bert.encoder.layer.11.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.11.crossattention.output.LayerNorm.bias']"""
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."""

我们应该仔细查看这里的警告。我们可以看到两个与 "cls" 层对应的权重没有被使用。这不应该是一个问题,因为我们不需要 BERT 的 CLS 层用于序列到序列任务。此外,我们注意到许多权重是“新”或随机初始化的。仔细查看这些权重,它们都对应于交叉注意力层,这正是我们阅读了上述理论后所预期的。

让我们仔细看看模型。

bert2bert
OUTPUT:
-------
    EncoderDecoderModel(
      (encoder): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            ),
                        ...
                        ,
            (11): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (decoder): BertLMHeadModel(
        (bert): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0): BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (crossattention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): BertIntermediate(
                  (dense): Linear(in_features=768, out_features=3072, bias=True)
                )
                (output): BertOutput(
                  (dense): Linear(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              ),
                            ...,
              (11): BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (crossattention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): BertIntermediate(
                  (dense): Linear(in_features=768, out_features=3072, bias=True)
                )
                (output): BertOutput(
                  (dense): Linear(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
        )
        (cls): BertOnlyMLMHead(
          (predictions): BertLMPredictionHead(
            (transform): BertPredictionHeadTransform(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            )
            (decoder): Linear(in_features=768, out_features=30522, bias=True)
          )
        )
      )
    )

我们看到 bert2bert.encoderBertModel 的实例,而 bert2bert.decoderBertLMHeadModel 的实例。然而,这两个实例现在被组合成一个单一的 torch.nn.Module,因此可以保存为一个单一的 .pt 检查点文件。

让我们尝试使用标准 .save_pretrained(...) 方法。

bert2bert.save_pretrained("bert2bert")

同样,模型可以使用标准的 .from_pretrained(...) 方法重新加载。

bert2bert = EncoderDecoderModel.from_pretrained("bert2bert")

太棒了。我们还要检查配置。

bert2bert.config
OUTPUT:
-------
    EncoderDecoderConfig {
      "_name_or_path": "bert2bert",
      "architectures": [
        "EncoderDecoderModel"
      ],
      "decoder": {
        "_name_or_path": "bert-base-uncased",
        "add_cross_attention": true,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "do_sample": false,
        "early_stopping": false,
        "eos_token_id": null,
        "finetuning_task": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": true,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "layer_norm_eps": 1e-12,
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 512,
        "min_length": 0,
        "model_type": "bert",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 12,
        "num_beams": 1,
        "num_hidden_layers": 12,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "pad_token_id": 0,
        "prefix": null,
        "pruned_heads": {},
        "repetition_penalty": 1.0,
        "return_dict": false,
        "sep_token_id": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "type_vocab_size": 2,
        "use_bfloat16": false,
        "use_cache": true,
        "vocab_size": 30522,
        "xla_device": null
      },
      "encoder": {
        "_name_or_path": "bert-base-uncased",
        "add_cross_attention": false,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "do_sample": false,
        "early_stopping": false,
        "eos_token_id": null,
        "finetuning_task": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "layer_norm_eps": 1e-12,
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 512,
        "min_length": 0,
        "model_type": "bert",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 12,
        "num_beams": 1,
        "num_hidden_layers": 12,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "pad_token_id": 0,
        "prefix": null,
        "pruned_heads": {},
        "repetition_penalty": 1.0,
        "return_dict": false,
        "sep_token_id": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "type_vocab_size": 2,
        "use_bfloat16": false,
        "use_cache": true,
        "vocab_size": 30522,
        "xla_device": null
      },
      "is_encoder_decoder": true,
      "model_type": "encoder_decoder"
    }

该配置同样由一个编码器配置和一个解码器配置组成,在我们的例子中,它们都是 BertConfig 的实例。然而,整个配置是 EncoderDecoderConfig 类型,因此它被保存为一个单一的 .json 文件。

总而言之,我们应该记住,一旦实例化了 EncoderDecoderModel 对象,它就提供了与 🤗Transformers 中任何其他编码器-解码器模型(例如 BARTT5ProphetNet 等)相同的功能。唯一的区别是 EncoderDecoderModel 提供了额外的 from_encoder_decoder_pretrained(...) 函数,允许模型类从任意两个编码器和解码器检查点进行热启动。

此外,如果想创建一个共享的编码器-解码器模型,可以额外传递参数 tie_encoder_decoder=True,如下所示

shared_bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased", tie_encoder_decoder=True)

作为比较,我们可以看到,正如预期的那样,共享模型的参数要少得多。

print(f"\n\nNum Params. Shared: {shared_bert2bert.num_parameters()}, Non-Shared: {bert2bert.num_parameters()}")
OUTPUT:
-------
Num Params. Shared: 137298244, Non-Shared: 247363386

在本 Notebook 中,我们仍将训练一个非共享的 Bert2Bert 模型,因此我们继续使用 bert2bert 而不是 shared_bert2bert

# free memory
del shared_bert2bert

我们已经热启动了一个 bert2bert 模型,但我们还没有定义所有与束搜索解码相关的参数。

让我们先设置特殊 token。bert-base-cased 没有 decoder_start_token_ideos_token_id,所以我们将分别使用它的 cls_token_idsep_token_id。此外,我们应该在配置上定义一个 pad_token_id,并确保设置了正确的 vocab_size

bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size

接下来,让我们定义所有与束搜索解码相关的参数。由于 bart-large-cnn 在 CNN/Dailymail 上表现良好,我们将直接复制其束搜索解码参数。

有关这些参数的更多详细信息,请参阅博客文章或文档

bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

好的,现在让我们开始微调热启动的 BERT2BERT 模型。

微调热启动编码器-解码器模型

本节将演示如何使用 Seq2SeqTrainer 微调热启动的编码器-解码器模型。

我们首先导入 Seq2SeqTrainer 及其训练参数 Seq2SeqTrainingArguments

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

此外,我们需要一些 Python 包来使 Seq2SeqTrainer 工作。

!pip install git-python==1.0.3
!pip install rouge_score
!pip install sacrebleu

Seq2SeqTrainer 扩展了 🤗Transformer 的 Trainer,用于编码器-解码器模型。简而言之,它允许在评估期间使用 generate(...) 函数,这对于验证编码器-解码器模型在大多数序列到序列任务(如摘要)上的性能是必需的。

有关 Trainer 的更多信息,请阅读简短教程。

让我们从配置 Seq2SeqTrainingArguments 开始。

参数 predict_with_generate 应设置为 True,这样 Seq2SeqTrainer 就会在验证数据上运行 generate(...) 函数,并将生成的输出作为 predictions 传递给我们稍后将定义的 compute_metric(...) 函数。额外的参数派生自 TrainingArguments,可以在此处阅读。对于完整的训练运行,应根据需要更改这些参数。下面已注释掉了一些不错的默认值。

有关 Seq2SeqTrainer 的更多信息,建议读者查阅代码

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True, 
    output_dir="./",
    logging_steps=2,
    save_steps=10,
    eval_steps=4,
    # logging_steps=1000,
    # save_steps=500,
    # eval_steps=7500,
    # warmup_steps=2000,
    # save_total_limit=3,
)

此外,我们需要定义一个函数来正确计算验证期间的 ROUGE 分数。由于我们激活了 predict_with_generate,因此 compute_metrics(...) 函数需要使用 generate(...) 函数获得的 predictions。与大多数摘要任务一样,CNN/Dailymail 通常使用 ROUGE 分数进行评估。

首先,我们使用 🤗datasets 库加载 ROUGE 指标。

rouge = datasets.load_metric("rouge")

接下来,我们将定义 compute_metrics(...) 函数。rouge 指标从两个字符串列表计算分数。因此,我们解码 predictionslabels,确保 -100 被正确替换为 pad_token_id,并通过设置 skip_special_tokens=True 移除所有特殊字符。

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

太棒了,现在我们可以将所有参数传递给 Seq2SeqTrainer 并开始微调。执行以下单元格将需要大约 10 分钟 ☕。

在完整的 CNN/Dailymail 训练数据上微调 BERT2BERT 模型大约需要一台 TITAN RTX GPU 8 小时。

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=bert2bert,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()

太棒了,我们现在应该完全能够微调热启动的编码器-解码器模型了。为了检查微调结果,让我们看看保存的检查点。

!ls
OUTPUT:
-------
    bert2bert      checkpoint-20  runs	   seq2seq_trainer.py
    checkpoint-10  __pycache__    sample_data  seq2seq_training_args.py

最后,我们可以像往常一样通过 EncoderDecoderModel.from_pretrained(...) 方法加载检查点。

dummy_bert2bert = EncoderDecoderModel.from_pretrained("./checkpoint-20")

评估

最后一步,我们可能希望在测试数据上评估 BERT2BERT 模型。

首先,我们不加载虚拟模型,而是加载一个在完整训练数据集上微调过的 BERT2BERT 模型。此外,我们加载它的分词器,它只是 bert-base-cased 分词器的一个副本。

from transformers import BertTokenizer

bert2bert = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail").to("cuda")
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")

接下来,我们只加载 CNN/Dailymail 测试数据中的 2%。对于完整评估,显然应该使用 100% 的数据。

test_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="test[:2%]")

现在,我们可以再次利用 🤗dataset 的便捷 map() 函数为每个测试样本生成摘要。

对于每个数据样本,我们

  • 首先,对 "article" 进行分词,
  • 其次,生成输出 token IDs,
  • 第三,解码输出 token IDs 以获得我们预测的摘要。
def generate_summary(batch):
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    outputs = bert2bert.generate(input_ids, attention_mask=attention_mask)

    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred_summary"] = output_str

    return batch

让我们运行 map 函数来获取 results 字典,其中存储了模型的每个样本的预测摘要。执行以下单元格可能需要约 10 分钟 ☕。

batch_size = 16  # change to 64 for full evaluation

results = test_data.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

最后,我们计算 ROUGE 分数。

rouge.compute(predictions=results["pred_summary"], references=results["highlights"], rouge_types=["rouge2"])["rouge2"].mid
OUTPUT:
-------
    Score(precision=0.10389454113300968, recall=0.1564771201053348, fmeasure=0.12175271663717585)

就是这样。我们已经展示了如何热启动 BERT2BERT 模型并在 CNN/Dailymail 数据集上进行微调/评估。

完整训练的 BERT2BERT 模型已上传至 🤗模型中心,地址为 patrickvonplaten/bert2bert_cnn_daily_mail

该模型在完整评估数据上取得了 18.22 的 ROUGE-2 分数,甚至比论文中报告的还要好一些。

有关一些摘要示例,建议读者使用模型的在线推理 API,此处

非常感谢 Google Research 的 Sascha Rothe、Shashi Narayan 和 Aliaksei Severyn,以及 🤗Hugging Face 的 Victor Sanh、Sylvain Gugger 和 Thomas Wolf 的校对和宝贵的反馈。

社区

注册登录以评论