PEFT 文档
上下文感知 Prompt Tuning:通过对抗方法改进 In-Context Learning
并获得增强的文档体验
开始使用
上下文感知 Prompt Tuning:通过对抗方法改进 In-Context Learning
CPT 结合了 In-Context Learning (ICL)、Prompt Tuning (PT) 和对抗优化,通过改进上下文嵌入来提升少样本学习。CPT 通过优化上下文和训练样本来更新上下文 tokens,并将它们封装到一个新颖的损失设计中,从而最大限度地减少过拟合,实现更有效的优化,并显著提升分类任务的性能。
论文摘要如下:
大型语言模型(LLM)可以使用基于优化的方法或 In-Context Learning (ICL) 来执行少样本学习。基于优化的方法通常容易过拟合,因为它们需要使用有限的数据更新大量参数。相比之下,ICL 避免了过拟合,但与基于优化的方法相比,性能通常较差,并且对演示示例的选择、顺序和格式高度敏感。为了克服这些挑战,我们引入了上下文感知 Prompt Tuning (CPT),这是一种受 ICL、Prompt Tuning (PT) 和对抗攻击启发的方法。CPT 基于 ICL 策略,即在输入之前连接示例,并通过结合类似 PT 的学习来扩展它,通过迭代优化来改进上下文嵌入,从训练示例中提取更深入的见解。我们的方法仔细修改特定的上下文 tokens,同时考虑上下文中示例的独特结构。除了使用类似 PT 的优化更新上下文外,CPT 还从对抗攻击中汲取灵感,根据上下文中存在的标签调整输入,同时保留用户提供数据的内在价值。为了确保优化过程中的鲁棒性和稳定性,我们采用了一种投影梯度下降算法,约束 token 嵌入保持接近其原始值,并保证上下文的质量。我们的方法在使用各种 LLM 模型的多个分类任务中展示了卓越的准确性,优于现有的基线,并有效地解决了少样本学习中的过拟合挑战。
请查看示例,获取关于如何使用 CPT 训练模型的逐步指南。
CPTConfig
class peft.CPTConfig
< source >( task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None auto_mapping: typing.Optional[dict] = None base_model_name_or_path: typing.Optional[str] = None revision: typing.Optional[str] = None inference_mode: bool = False num_virtual_tokens: int = None token_dim: int = None num_transformer_submodules: typing.Optional[int] = None num_attention_heads: typing.Optional[int] = None num_layers: typing.Optional[int] = None cpt_token_ids: typing.Optional[list[int]] = None cpt_mask: typing.Optional[list[int]] = None cpt_tokens_type_mask: typing.Optional[list[int]] = None opt_weighted_loss_type: typing.Optional[typing.Literal['none', 'decay']] = 'none' opt_loss_decay_factor: typing.Optional[float] = 1.0 opt_projection_epsilon: typing.Optional[float] = 0.1 opt_projection_format_epsilon: typing.Optional[float] = 0.1 tokenizer_name_or_path: typing.Optional[str] = None )
CPT 配置类,扩展了 PeftConfig,用于上下文感知 Prompt Tuning (CPT)。
此类引入了 CPT 所需的其他参数,例如
- Token 类型掩码
- Prompt tuning 初始化
- 损失权重
- 投影设置
有关更多详细信息,请参阅论文:https://arxiv.org/abs/2410.17222
CPTEmbedding
CPTEmbedding 是一个自定义嵌入层,专为 PEFT 中的上下文感知 Prompt Tuning (CPT) 而设计。它初始化嵌入,应用特定于 prompt 的投影,并使用标签掩码计算损失。
forward
< source > ( indices ) → torch.Tensor
Computes the prompt embeddings and applies delta adjustments.
Applies epsilon-based projection to the delta embeddings to control their norm.
Sets up a backward hook to selectively update token gradients based on the CPT token type mask.