PEFT 文档

提示微调

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

提示微调

提示微调 将特定于任务的提示添加到输入中,并且这些提示参数独立于冻结的预训练模型参数进行更新。

论文摘要如下

在这项工作中,我们探索了“提示微调”(prompt tuning),这是一种简单而有效的机制,用于学习“软提示”(soft prompts)来调节冻结的语言模型,以执行特定的下游任务。与 GPT-3 使用的离散文本提示不同,软提示是通过反向传播学习的,并且可以进行微调以整合来自任意数量标记示例的信号。我们的端到端学习方法在很大程度上优于 GPT-3 的“少样本”(few-shot)学习。更值得注意的是,通过使用 T5 对模型大小进行消融研究,我们发现提示微调随着规模的扩大而变得更具竞争力:当模型参数超过数十亿时,我们的方法“缩小了差距”,并与模型微调(其中所有模型权重都经过微调)的强大性能相匹配。这一发现尤其重要,因为大型模型的共享和服务成本很高,而能够重用一个冻结的模型来执行多个下游任务可以减轻这种负担。我们的方法可以看作是对 Li 和 Liang (2021) 最近提出的“前缀微调”(prefix tuning)的简化,我们提供了与这种方法和其他类似方法的比较。最后,我们表明,与完整模型微调相比,用软提示调节冻结模型可以提高对域转移的鲁棒性。.

PromptTuningConfig

peft.PromptTuningConfig

< >

( peft_type: Union = None auto_mapping: Optional = None base_model_name_or_path: Optional = None revision: Optional = None task_type: Union = None inference_mode: bool = False num_virtual_tokens: int = None token_dim: int = None num_transformer_submodules: Optional = None num_attention_heads: Optional = None num_layers: Optional = None prompt_tuning_init: Union = <PromptTuningInit.RANDOM: 'RANDOM'> prompt_tuning_init_text: Optional = None tokenizer_name_or_path: Optional = None tokenizer_kwargs: Optional = None )

参数

  • prompt_tuning_init (Union[PromptTuningInit, str]) — 提示嵌入的初始化。
  • prompt_tuning_init_text (str, 可选) — 用于初始化提示嵌入的文本。仅当 prompt_tuning_initTEXT 时使用。
  • tokenizer_name_or_path (str, 可选) — 分词器的名称或路径。仅当 prompt_tuning_initTEXT 时使用。
  • tokenizer_kwargs (dict, 可选) — 传递给 AutoTokenizer.from_pretrained 的关键字参数。仅当 prompt_tuning_initTEXT 时使用。

这是一个配置类,用于存储 PromptEmbedding 的配置。

PromptEmbedding

peft.PromptEmbedding

< >

( config word_embeddings )

参数

  • config (PromptTuningConfig) — 提示嵌入的配置。
  • word_embeddings (torch.nn.Module) — 基础Transformer模型的词嵌入。

将虚拟token编码为提示嵌入的模型。

属性:

  • embedding (torch.nn.Embedding) — 提示嵌入的嵌入层。

示例

>>> from peft import PromptEmbedding, PromptTuningConfig

>>> config = PromptTuningConfig(
...     peft_type="PROMPT_TUNING",
...     task_type="SEQ_2_SEQ_LM",
...     num_virtual_tokens=20,
...     token_dim=768,
...     num_transformer_submodules=1,
...     num_attention_heads=12,
...     num_layers=12,
...     prompt_tuning_init="TEXT",
...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
...     tokenizer_name_or_path="t5-base",
... )

>>> # t5_model.shared is the word embeddings of the base model
>>> prompt_embedding = PromptEmbedding(config, t5_model.shared)

输入形状: (batch_size, total_virtual_tokens)

输出形状: (batch_size, total_virtual_tokens, token_dim)

< > 在GitHub上更新