PEFT 文档

上下文感知 Prompt Tuning:通过对抗方法改进 In-Context Learning

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

上下文感知 Prompt Tuning:通过对抗方法改进 In-Context Learning

CPT 结合了 In-Context Learning (ICL)、Prompt Tuning (PT) 和对抗优化,通过改进上下文嵌入来提升少样本学习。CPT 通过优化上下文和训练样本来更新上下文 tokens,并将它们封装到一个新颖的损失设计中,从而最大限度地减少过拟合,实现更有效的优化,并显著提升分类任务的性能。

论文摘要如下:

大型语言模型(LLM)可以使用基于优化的方法或 In-Context Learning (ICL) 来执行少样本学习。基于优化的方法通常容易过拟合,因为它们需要使用有限的数据更新大量参数。相比之下,ICL 避免了过拟合,但与基于优化的方法相比,性能通常较差,并且对演示示例的选择、顺序和格式高度敏感。为了克服这些挑战,我们引入了上下文感知 Prompt Tuning (CPT),这是一种受 ICL、Prompt Tuning (PT) 和对抗攻击启发的​​方法。CPT 基于 ICL 策略,即在输入之前连接示例,并通过结合类似 PT 的学习来扩展它,通过迭代优化来改进上下文嵌入,从训练示例中提取更深入的见解。我们的方法仔细修改特定的上下文 tokens,同时考虑上下文中示例的独特结构。除了使用类似 PT 的优化更新上下文外,CPT 还从对抗攻击中汲取灵感,根据上下文中存在的标签调整输入,同时保留用户提供数据的内在价值。为了确保优化过程中的鲁棒性和稳定性,我们采用了一种投影梯度下降算法,约束 token 嵌入保持接近其原始值,并保证上下文的质量。我们的方法在使用各种 LLM 模型的多个分类任务中展示了卓越的准确性,优于现有的基线,并有效地解决了少样本学习中的过拟合挑战。

请查看示例,获取关于如何使用 CPT 训练模型的逐步指南。

CPTConfig

class peft.CPTConfig

< >

( task_type: typing.Union[str, peft.utils.peft_types.TaskType, NoneType] = None peft_type: typing.Union[str, peft.utils.peft_types.PeftType, NoneType] = None auto_mapping: typing.Optional[dict] = None base_model_name_or_path: typing.Optional[str] = None revision: typing.Optional[str] = None inference_mode: bool = False num_virtual_tokens: int = None token_dim: int = None num_transformer_submodules: typing.Optional[int] = None num_attention_heads: typing.Optional[int] = None num_layers: typing.Optional[int] = None cpt_token_ids: typing.Optional[list[int]] = None cpt_mask: typing.Optional[list[int]] = None cpt_tokens_type_mask: typing.Optional[list[int]] = None opt_weighted_loss_type: typing.Optional[typing.Literal['none', 'decay']] = 'none' opt_loss_decay_factor: typing.Optional[float] = 1.0 opt_projection_epsilon: typing.Optional[float] = 0.1 opt_projection_format_epsilon: typing.Optional[float] = 0.1 tokenizer_name_or_path: typing.Optional[str] = None )

CPT 配置类,扩展了 PeftConfig,用于上下文感知 Prompt Tuning (CPT)。

此类引入了 CPT 所需的其他参数,例如

  • Token 类型掩码
  • Prompt tuning 初始化
  • 损失权重
  • 投影设置

有关更多详细信息,请参阅论文:https://arxiv.org/abs/2410.17222

CPTEmbedding

class peft.CPTEmbedding

< >

( config word_embeddings )

CPTEmbedding 是一个自定义嵌入层,专为 PEFT 中的上下文感知 Prompt Tuning (CPT) 而设计。它初始化嵌入,应用特定于 prompt 的投影,并使用标签掩码计算损失。

calculate_loss

  • base_model_output (ModelOutput) — Output from the base model containing logits.
  • labels (torch.Tensor) — Ground-truth labels for the input tokens.
  • cpt_type_mask (torch.Tensor) — Token type mask used for filtering valid loss terms.
  • config (Namespace) — Configuration object containing loss-related hyperparameters.

Returns

ModelOutput

The base model output with computed loss.

Computes the loss for CPT models with optional exponential decay.

forward

< >

( indices ) torch.Tensor

Parameters

  • indices (torch.Tensor) — Indices of the tokens to be embedded.

Returns

torch.Tensor

Sum of prompt embeddings and delta embeddings.

Computes the prompt embeddings and applies delta adjustments.

get_projection

< >

( )

Applies epsilon-based projection to the delta embeddings to control their norm.

set_updated_tokens

< >

( )

Sets up a backward hook to selectively update token gradients based on the CPT token type mask.

< > Update on GitHub