提示微调
提示微调 将特定于任务的提示添加到输入中,并且这些提示参数独立于冻结的预训练模型参数进行更新。
论文摘要如下
在这项工作中,我们探索了“提示微调”(prompt tuning),这是一种简单而有效的机制,用于学习“软提示”(soft prompts)来调节冻结的语言模型,以执行特定的下游任务。与 GPT-3 使用的离散文本提示不同,软提示是通过反向传播学习的,并且可以进行微调以整合来自任意数量标记示例的信号。我们的端到端学习方法在很大程度上优于 GPT-3 的“少样本”(few-shot)学习。更值得注意的是,通过使用 T5 对模型大小进行消融研究,我们发现提示微调随着规模的扩大而变得更具竞争力:当模型参数超过数十亿时,我们的方法“缩小了差距”,并与模型微调(其中所有模型权重都经过微调)的强大性能相匹配。这一发现尤其重要,因为大型模型的共享和服务成本很高,而能够重用一个冻结的模型来执行多个下游任务可以减轻这种负担。我们的方法可以看作是对 Li 和 Liang (2021) 最近提出的“前缀微调”(prefix tuning)的简化,我们提供了与这种方法和其他类似方法的比较。最后,我们表明,与完整模型微调相比,用软提示调节冻结模型可以提高对域转移的鲁棒性。.
PromptTuningConfig
类 peft.PromptTuningConfig
< 源代码 >( peft_type: Union = None auto_mapping: Optional = None base_model_name_or_path: Optional = None revision: Optional = None task_type: Union = None inference_mode: bool = False num_virtual_tokens: int = None token_dim: int = None num_transformer_submodules: Optional = None num_attention_heads: Optional = None num_layers: Optional = None prompt_tuning_init: Union = <PromptTuningInit.RANDOM: 'RANDOM'> prompt_tuning_init_text: Optional = None tokenizer_name_or_path: Optional = None tokenizer_kwargs: Optional = None )
参数
- prompt_tuning_init (Union[
PromptTuningInit
,str
]) — 提示嵌入的初始化。 - prompt_tuning_init_text (
str
, 可选) — 用于初始化提示嵌入的文本。仅当prompt_tuning_init
为TEXT
时使用。 - tokenizer_name_or_path (
str
, 可选) — 分词器的名称或路径。仅当prompt_tuning_init
为TEXT
时使用。 - tokenizer_kwargs (
dict
, 可选) — 传递给AutoTokenizer.from_pretrained
的关键字参数。仅当prompt_tuning_init
为TEXT
时使用。
这是一个配置类,用于存储 PromptEmbedding 的配置。
PromptEmbedding
类 peft.PromptEmbedding
< 源代码 >( config word_embeddings )
参数
- config (PromptTuningConfig) — 提示嵌入的配置。
- word_embeddings (
torch.nn.Module
) — 基础Transformer模型的词嵌入。
将虚拟token编码为提示嵌入的模型。
属性:
- embedding (
torch.nn.Embedding
) — 提示嵌入的嵌入层。
示例
>>> from peft import PromptEmbedding, PromptTuningConfig
>>> config = PromptTuningConfig(
... peft_type="PROMPT_TUNING",
... task_type="SEQ_2_SEQ_LM",
... num_virtual_tokens=20,
... token_dim=768,
... num_transformer_submodules=1,
... num_attention_heads=12,
... num_layers=12,
... prompt_tuning_init="TEXT",
... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
... tokenizer_name_or_path="t5-base",
... )
>>> # t5_model.shared is the word embeddings of the base model
>>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
输入形状: (batch_size
, total_virtual_tokens
)
输出形状: (batch_size
, total_virtual_tokens
, token_dim
)