PEFT 文档
P-tuning
并获得增强的文档体验
开始使用
P-tuning
P-tuning 将可训练的 prompt 嵌入添加到输入中,这些嵌入由 prompt 编码器优化,以找到更好的 prompt,从而消除了手动设计 prompt 的需要。prompt 标记可以添加到输入序列中的任何位置,并且 p-tuning 还引入了锚标记以提高性能。
论文摘要如下
虽然使用传统微调的 GPT 在自然语言理解 (NLU) 方面未能取得出色的结果,但我们表明,通过一种新颖的方法 P-tuning(它采用可训练的连续 prompt 嵌入),GPT 在 NLU 任务上可以优于或与类似大小的 BERT 相媲美。在知识探测 (LAMA) 基准测试中,最佳 GPT 在测试时无需提供任何额外文本的情况下,恢复了 64\% (P@1) 的世界知识,这大大提高了之前最佳水平 20 多个百分点。在 SuperGlue 基准测试中,GPT 在监督学习中实现了与类似大小的 BERT 相当甚至更好的性能。重要的是,我们发现 P-tuning 还提高了 BERT 在少样本和监督设置中的性能,同时大大减少了 prompt 工程的需求。因此,P-tuning 在少样本 SuperGlue 基准测试中优于最先进的方法。.
PromptEncoderConfig
class peft.PromptEncoderConfig
< source >( task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None auto_mapping: typing.Optional[dict] = None base_model_name_or_path: typing.Optional[str] = None revision: typing.Optional[str] = None inference_mode: bool = False num_virtual_tokens: int = None token_dim: int = None num_transformer_submodules: typing.Optional[int] = None num_attention_heads: typing.Optional[int] = None num_layers: typing.Optional[int] = None encoder_reparameterization_type: typing.Union[str, peft.tuners.p_tuning.config.PromptEncoderReparameterizationType] = <PromptEncoderReparameterizationType.MLP: 'MLP'> encoder_hidden_size: int = None encoder_num_layers: int = 2 encoder_dropout: float = 0.0 )
这是用于存储 PromptEncoder 配置的配置类。
PromptEncoder
prompt 编码器网络,用于生成 p-tuning 的虚拟标记嵌入。
示例
>>> from peft import PromptEncoder, PromptEncoderConfig
>>> config = PromptEncoderConfig(
... peft_type="P_TUNING",
... task_type="SEQ_2_SEQ_LM",
... num_virtual_tokens=20,
... token_dim=768,
... num_transformer_submodules=1,
... num_attention_heads=12,
... num_layers=12,
... encoder_reparameterization_type="MLP",
... encoder_hidden_size=768,
... )
>>> prompt_encoder = PromptEncoder(config)
属性:
- embedding (
torch.nn.Embedding
) — prompt 编码器的嵌入层。 - mlp_head (
torch.nn.Sequential
) — 如果inference_mode=False
,则为 prompt 编码器的 MLP 头。 - lstm_head (
torch.nn.LSTM
) — 如果inference_mode=False
且encoder_reparameterization_type="LSTM"
,则为 prompt 编码器的 LSTM 头。 - token_dim (
int
) — 基础 Transformer 模型的隐藏嵌入维度。 - input_size (
int
) — prompt 编码器的输入大小。 - output_size (
int
) — prompt 编码器的输出大小。 - hidden_size (
int
) — prompt 编码器的隐藏大小。 - total_virtual_tokens (
int
):prompt 编码器的虚拟标记总数。 - encoder_type (Union[
PromptEncoderReparameterizationType
,str
]):prompt 编码器的编码器类型。
输入形状:(batch_size
, total_virtual_tokens
)
输出形状:(batch_size
, total_virtual_tokens
, token_dim
)