Transformers 文档

自定义模型组件

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

自定义模型组件

自定义模型的另一种方法是修改其组件,而不是完全编写新模型,从而允许您根据特定用例定制模型。例如,您可以添加新层或优化架构的注意力机制。自定义直接应用于 Transformers 模型,以便您可以继续使用诸如 TrainerPreTrainedModelPEFT 库等功能。

本指南将向您展示如何自定义模型的注意力机制,以便应用 低秩自适应 (LoRA)

当您迭代地修改和开发模型代码时,clear_import_cache 实用程序非常有用。它会删除所有缓存的 Transformers 模块,并允许 Python 重新加载修改后的代码,而无需不断重启您的环境。

from transformers import AutoModel
from transformers.utils.import_utils import clear_import_cache

model = AutoModel.from_pretrained("bert-base-uncased")
# modifications to model code
# clear cache to reload modified code
clear_import_cache()
# re-import to use updated code
model = AutoModel.from_pretrained("bert-base-uncased")

注意力类

Segment Anything 是一个图像分割模型,它在其注意力机制中结合了查询-键-值 (qkv) 投影。为了减少可训练参数的数量和计算开销,您可以将 LoRA 应用于 qkv 投影。这需要拆分 qkv 投影,以便您可以分别使用 LoRA 定位 qv

  1. 通过子类化原始 SamVisionAttention 类来创建一个自定义注意力类 SamVisionAttentionSplit。在 __init__ 中,删除组合的 qkv,并为 qkv 创建单独的线性层。
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # remove combined qkv
        del self.qkv
        # separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
  1. _split_qkv_load_hook 函数在加载模型时将预训练的 qkv 权重拆分为单独的 qkv 权重,以确保与任何预训练模型兼容。
    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # mark the old qkv key for deletion
                keys_to_delete.append(key)
        
        # remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]
  1. forward 传递中,qkv 是单独计算的,而注意力机制的其余部分保持不变。
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size *  self.num_attention_heads,  height * width, -1)
        query = self.q(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs

将自定义 SamVisionAttentionSplit 类分配给原始模型的 SamVisionAttention 模块以替换它。模型中所有 SamVisionAttention 的实例都将替换为拆分注意力版本。

使用 from_pretrained() 加载模型。

from transformers import SamModel
from transformers.models.sam import modeling_sam

# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")

LoRA

使用单独的 qkv 投影,将 LoRA 应用于 qv

创建一个 LoraConfig 并指定秩 rlora_alphalora_dropouttask_type,以及最重要的,要定位的模块。

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    # apply LoRA to q and v
    target_modules=["q", "v"],
    lora_dropout=0.1,
    task_type="mask-generation"
)

将模型和 LoraConfig 传递给 get_peft_model 以将 LoRA 应用于模型。

model = get_peft_model(model, config)

调用 print_trainable_parameters 以查看您正在训练的参数数量与参数总数。

model.print_trainable_parameters()
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
< > 在 GitHub 上更新