Transformers 文档

FLAN-UL2

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

FLAN-UL2

概述

Flan-UL2 是基于 T5 架构的编码器-解码器模型。它使用与去年发布的 UL2 模型相同的配置。它使用“Flan”提示微调和数据集收集进行微调。与 Flan-T5 类似,可以直接使用 FLAN-UL2 权重,无需对模型进行微调。

根据原始博客,以下是值得注意的改进:

  • 原始 UL2 模型仅使用 512 的感受野进行训练,这使其不适合 N-shot 提示,其中 N 很大。
  • Flan-UL2 检查点使用 2048 的感受野,使其更适合少样本上下文学习。
  • 原始 UL2 模型还具有模式切换令牌,这对于获得良好的性能是强制性的。然而,它们有点麻烦,因为这通常需要在推理或微调期间进行一些更改。在此更新/更改中,我们在应用 Flan 指令微调之前,对 UL2 20B 进行了额外的 100k 步训练(小批次),以“忘记”模式令牌。此 Flan-UL2 检查点不再需要模式令牌。Google 发布了以下变体:

原始检查点可以在 这里 找到。

在资源有限的设备上运行

该模型非常庞大(半精度约 40GB),因此如果您只是想运行该模型,请确保以 8 位加载模型,并使用 device_map="auto" 以确保不会出现任何 OOM 问题!

>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-ul2", load_in_8bit=True, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")

>>> inputs = tokenizer("A step by step recipe to make bolognese pasta:", return_tensors="pt")
>>> outputs = model.generate(**inputs)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
['In a large skillet, brown the ground beef and onion over medium heat. Add the garlic']

有关 API 参考、技巧、代码示例和笔记本,请参阅 T5 的文档页面

< > 更新 在 GitHub 上