PEFT 文档
自定义模型
并获得增强的文档体验
开始使用
自定义模型
一些微调技术,如提示调优,是针对语言模型的。这意味着在 🤗 PEFT 中,假定使用的是 🤗 Transformers 模型。然而,其他微调技术——如 LoRA——并不局限于特定的模型类型。
在本指南中,我们将看到如何将 LoRA 应用于多层感知机、来自 timm 库的计算机视觉模型或新的 🤗 Transformers 架构。
多层感知机
假设我们想要用 LoRA 微调一个多层感知机。下面是定义:
from torch import nn
class MLP(nn.Module):
def __init__(self, num_units_hidden=2000):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(20, num_units_hidden),
nn.ReLU(),
nn.Linear(num_units_hidden, num_units_hidden),
nn.ReLU(),
nn.Linear(num_units_hidden, 2),
nn.LogSoftmax(dim=-1),
)
def forward(self, X):
return self.seq(X)
这是一个简单的多层感知机,包含一个输入层、一个隐藏层和一个输出层。
在这个玩具示例中,我们选择了一个非常大的隐藏单元数量来突出 PEFT 带来的效率提升,但这些提升与更现实的示例是一致的。
在这个模型中有几个线性层可以用 LoRA 进行调优。当使用常见的 🤗 Transformers 模型时,PEFT 会知道要对哪些层应用 LoRA,但在这种情况下,由我们用户来选择层。要确定要调优的层名称:
print([(n, type(m)) for n, m in MLP().named_modules()])
这应该会打印出:
[('', __main__.MLP),
('seq', torch.nn.modules.container.Sequential),
('seq.0', torch.nn.modules.linear.Linear),
('seq.1', torch.nn.modules.activation.ReLU),
('seq.2', torch.nn.modules.linear.Linear),
('seq.3', torch.nn.modules.activation.ReLU),
('seq.4', torch.nn.modules.linear.Linear),
('seq.5', torch.nn.modules.activation.LogSoftmax)]
假设我们想对输入层和隐藏层应用 LoRA,它们是 'seq.0'
和 'seq.2'
。此外,假设我们想在不使用 LoRA 的情况下更新输出层,即 'seq.4'
。相应的配置将是:
from peft import LoraConfig
config = LoraConfig(
target_modules=["seq.0", "seq.2"],
modules_to_save=["seq.4"],
)
有了这个配置,我们就可以创建我们的 PEFT 模型并检查训练参数的比例:
from peft import get_peft_model
model = MLP()
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
# prints trainable params: 56,164 || all params: 4,100,164 || trainable%: 1.369798866581922
最后,我们可以使用任何我们喜欢的训练框架,或者编写我们自己的拟合循环来训练 peft_model
。
有关完整示例,请查看此笔记本。
timm 模型
timm 库包含了大量预训练的计算机视觉模型。这些模型也可以使用 PEFT 进行微调。让我们看看这在实践中是如何工作的。
首先,确保在 Python 环境中安装了 timm:
python -m pip install -U timm
接下来,我们加载一个用于图像分类任务的 timm 模型:
import timm
num_classes = ...
model_id = "timm/poolformer_m36.sail_in1k"
model = timm.create_model(model_id, pretrained=True, num_classes=num_classes)
同样,我们需要决定在哪些层上应用 LoRA。由于 LoRA 支持 2D 卷积层,并且这些层是该模型的主要构建块,我们应该在 2D 卷积层上应用 LoRA。为了识别这些层的名称,让我们查看所有层的名称:
print([(n, type(m)) for n, m in model.named_modules()])
这将打印一个非常长的列表,我们只显示前几个:
[('', timm.models.metaformer.MetaFormer),
('stem', timm.models.metaformer.Stem),
('stem.conv', torch.nn.modules.conv.Conv2d),
('stem.norm', torch.nn.modules.linear.Identity),
('stages', torch.nn.modules.container.Sequential),
('stages.0', timm.models.metaformer.MetaFormerStage),
('stages.0.downsample', torch.nn.modules.linear.Identity),
('stages.0.blocks', torch.nn.modules.container.Sequential),
('stages.0.blocks.0', timm.models.metaformer.MetaFormerBlock),
('stages.0.blocks.0.norm1', timm.layers.norm.GroupNorm1),
('stages.0.blocks.0.token_mixer', timm.models.metaformer.Pooling),
('stages.0.blocks.0.token_mixer.pool', torch.nn.modules.pooling.AvgPool2d),
('stages.0.blocks.0.drop_path1', torch.nn.modules.linear.Identity),
('stages.0.blocks.0.layer_scale1', timm.models.metaformer.Scale),
('stages.0.blocks.0.res_scale1', torch.nn.modules.linear.Identity),
('stages.0.blocks.0.norm2', timm.layers.norm.GroupNorm1),
('stages.0.blocks.0.mlp', timm.layers.mlp.Mlp),
('stages.0.blocks.0.mlp.fc1', torch.nn.modules.conv.Conv2d),
('stages.0.blocks.0.mlp.act', torch.nn.modules.activation.GELU),
('stages.0.blocks.0.mlp.drop1', torch.nn.modules.dropout.Dropout),
('stages.0.blocks.0.mlp.norm', torch.nn.modules.linear.Identity),
('stages.0.blocks.0.mlp.fc2', torch.nn.modules.conv.Conv2d),
('stages.0.blocks.0.mlp.drop2', torch.nn.modules.dropout.Dropout),
('stages.0.blocks.0.drop_path2', torch.nn.modules.linear.Identity),
('stages.0.blocks.0.layer_scale2', timm.models.metaformer.Scale),
('stages.0.blocks.0.res_scale2', torch.nn.modules.linear.Identity),
('stages.0.blocks.1', timm.models.metaformer.MetaFormerBlock),
('stages.0.blocks.1.norm1', timm.layers.norm.GroupNorm1),
('stages.0.blocks.1.token_mixer', timm.models.metaformer.Pooling),
('stages.0.blocks.1.token_mixer.pool', torch.nn.modules.pooling.AvgPool2d),
...
('head.global_pool.flatten', torch.nn.modules.linear.Identity),
('head.norm', timm.layers.norm.LayerNorm2d),
('head.flatten', torch.nn.modules.flatten.Flatten),
('head.drop', torch.nn.modules.linear.Identity),
('head.fc', torch.nn.modules.linear.Linear)]
]
仔细观察后,我们发现 2D 卷积层的名称类似于 "stages.0.blocks.0.mlp.fc1"
和 "stages.0.blocks.0.mlp.fc2"
。我们如何专门匹配这些层名呢?你可以编写正则表达式来匹配层名。对于我们的情况,正则表达式 r".*\.mlp\.fc\d"
应该可以完成任务。
此外,与第一个示例一样,我们应该确保输出层,在这里是分类头,也得到更新。查看上面打印列表的末尾,我们可以看到它的名称是 'head.fc'
。考虑到这一点,这是我们的 LoRA 配置:
config = LoraConfig(target_modules=r".*\.mlp\.fc\d", modules_to_save=["head.fc"])
然后我们只需将我们的基础模型和配置传递给 get_peft_model
来创建 PEFT 模型:
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
# prints trainable params: 1,064,454 || all params: 56,467,974 || trainable%: 1.88505789139876
这表明我们只需要训练不到 2% 的参数,这是一个巨大的效率提升。
有关完整示例,请查看此笔记本。
新的 transformers 架构
当新的流行 transformers 架构发布时,我们会尽力将其快速添加到 PEFT 中。如果你遇到一个开箱即用不支持的 transformers 模型,别担心,如果配置设置正确,它很可能仍然可以工作。具体来说,你必须识别应该被适配的层,并在初始化相应的配置类(例如 LoraConfig
)时正确设置它们。以下是一些有助于此的提示。
作为第一步,查看现有模型以获取灵感是一个好主意。你可以在 PEFT 仓库的 constants.py 文件中找到它们。通常,你会找到一个使用相同名称的类似架构。例如,如果新模型架构是“mistral”模型的变体,并且你想应用 LoRA,你可以看到 TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
中“mistral”的条目包含 ["q_proj", "v_proj"]
。这告诉你对于“mistral”模型,LoRA 的 target_modules
应该是 ["q_proj", "v_proj"]
。
from peft import LoraConfig, get_peft_model
my_mistral_model = ...
config = LoraConfig(
target_modules=["q_proj", "v_proj"],
..., # other LoRA arguments
)
peft_model = get_peft_model(my_mistral_model, config)
如果这没有帮助,请使用 named_modules
方法检查你的模型架构中现有的模块,并尝试识别注意力层,特别是键、查询和值层。这些层通常会有名为 c_attn
、query
、q_proj
等。键层并不总是被适配,理想情况下,你应该检查包含它是否能带来更好的性能。
此外,线性层是常见的适配目标(例如,在 QLoRA 论文中,作者建议也适配它们)。它们的名称通常会包含字符串 fc
或 dense
。
如果你想向 PEFT 添加一个新模型,请在 constants.py 中创建一个条目,并在 仓库 上发起一个拉取请求。别忘了同时更新 README 文件。
验证参数和层
你可以通过几种方式来验证你是否已正确地将 PEFT 方法应用于你的模型。
- 使用 print_trainable_parameters() 方法检查可训练参数的比例。如果这个数字低于或高于预期,请通过打印模型来检查模型的
repr
。这会显示模型中所有层类型的名称。确保只有预期的目标层被适配器层替换。例如,如果将 LoRA 应用于nn.Linear
层,那么你应该只看到lora.Linear
层被使用。
peft_model.print_trainable_parameters()
- 另一种查看已适配层的方法是使用
targeted_module_names
属性来列出每个被适配模块的名称。
print(peft_model.targeted_module_names)
不支持的模块类型
像 LoRA 这样的方法只有在目标模块被 PEFT 支持时才有效。例如,可以将 LoRA 应用于 nn.Linear
和 nn.Conv2d
层,但不能应用于 nn.LSTM
。如果你发现想要应用 PEFT 的层类不受支持,你可以:
- 定义一个自定义映射,以在 LoRA 中动态分派自定义模块
- 提交一个 issue 并请求该功能,维护者将会实现它,或者如果对此模块类型的需求足够高,他们会指导你如何自己实现
实验性支持 LoRA 中自定义模块的动态分派
此功能是实验性的,可能会根据社区的反馈而改变。如果有显著的需求,我们将引入一个公开且稳定的 API。
PEFT 为 LoRA 的自定义模块类型提供了一个实验性 API。假设你有一个用于 LSTMs 的 LoRA 实现。通常情况下,你无法告诉 PEFT 使用它,即使它理论上可以与 PEFT 一起工作。然而,通过自定义层的动态分派,这是可能的。
实验性 API 目前如下所示:
class MyLoraLSTMLayer:
...
base_model = ... # load the base model that uses LSTMs
# add the LSTM layer names to target_modules
config = LoraConfig(..., target_modules=["lstm"])
# define a mapping from base layer type to LoRA layer type
custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
# register the new mapping
config._register_custom_module(custom_module_mapping)
# after registration, create the PEFT model
peft_model = get_peft_model(base_model, config)
# do training
当你调用 get_peft_model() 时,你会看到一个警告,因为 PEFT 无法识别目标模块类型。在这种情况下,你可以忽略这个警告。
通过提供自定义映射,PEFT 首先检查基础模型的层与自定义映射,如果匹配,则分派到自定义 LoRA 层类型。如果没有匹配,PEFT 会检查内置的 LoRA 层类型以寻找匹配项。
因此,此功能也可用于覆盖现有的分派逻辑,例如,如果你想为 nn.Linear
使用自己的 LoRA 层而不是 PEFT 提供的层。
创建自定义 LoRA 模块时,请遵循与现有 LoRA 模块相同的规则。需要考虑的一些重要约束:
- 自定义模块应继承自
nn.Module
和peft.tuners.lora.layer.LoraLayer
。 - 自定义模块的
__init__
方法应具有位置参数base_layer
和adapter_name
。之后,你可以自由使用或忽略额外的**kwargs
。 - 可学习参数应存储在
nn.ModuleDict
或nn.ParameterDict
中,其中键对应于特定适配器的名称(请记住,一个模型可以同时有多个适配器)。 - 这些可学习参数属性的名称应以
"lora_"
开头,例如self.lora_new_param = ...
。 - 一些方法是可选的,例如,只有在你想要支持权重合并时才需要实现
merge
和unmerge
。
目前,保存模型时不会保留有关自定义模块的信息。加载模型时,你必须再次注册自定义模块。
# saving works as always and includes the parameters of the custom modules
peft_model.save_pretrained(<model-path>)
# loading the model later:
base_model = ...
# load the LoRA config that you saved earlier
config = LoraConfig.from_pretrained(<model-path>)
# register the custom module again, the same way as the first time
custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
config._register_custom_module(custom_module_mapping)
# pass the config instance to from_pretrained:
peft_model = PeftModel.from_pretrained(model, tmp_path / "lora-custom-module", config=config)
如果你使用此功能并觉得有用,或者遇到问题,请通过在 GitHub 上创建 issue 或讨论来告知我们。这使我们能够评估对此功能的需求,并在需求足够高时添加公共 API。
< > 在 GitHub 上更新