Bitsandbytes 文档
8 位优化器
并获得增强的文档体验
开始使用
8 位优化器
使用 8 位优化器,微调大型模型可以减少 75% 的 GPU 内存使用,且与使用标准 32 位优化器训练相比不会损失任何精度。减少内存需求意味着 8 位优化器比标准优化器快 4 倍,并且无需进行超参数调整。
本指南将向您展示如何使用 8 位优化器。
8 位优化器可在多种任务中减少内存使用并加速优化。然而,由于 8 位优化器减少的内存与参数数量成正比,因此对于使用大量激活内存的模型(如卷积网络)来说,它们并不能带来真正的益处。8 位优化器对于在内存极其受限的 GPU 上训练或微调具有大量参数的模型最为有利。
8 位优化器是常规优化器的直接替代品,这意味着它们也接受与常规优化器相同的参数。对于 NLP 模型,建议使用 StableEmbedding 类来提高稳定性和结果。
import bitsandbytes as bnb
- adam = torch.optim.Adam(...)
+ adam = bnb.optim.Adam8bit(...)
# recommended for NLP models
- before: torch.nn.Embedding(...)
+ bnb.nn.StableEmbedding(...)
默认情况下,即使您使用 8 位优化器初始化参数,所有元素数量少于 4096 的参数张量仍将保持 32 位。这样做是因为小张量节省的内存不多,并且通常包含高度可变的参数(偏置)或需要高精度的参数(批量归一化、层归一化)。
您可以使用 min_8bit_size
参数更改此值。例如,如果您希望仅在最小大小为 16384 个值时才将参数优化为 8 位(建议使用 4096 的倍数)
import bitsandbytes as bnb
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
您可以配置的其他参数包括学习率 (lr
)、衰减率 (betas
)、优化器状态的位数 (optim_bits
) 和百分位裁剪 (percentile_clipping
),后者可以增加稳定性。例如,要初始化一个带有第 5 百分位裁剪的 32 位 Adam 优化器
import bitsandbytes as bnb
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5)
优化不稳定参数
要使用 32 位 Adam 优化一些不稳定参数,同时使用 8 位 Adam 优化其他参数,请使用 GlobalOptimManager 类为特定层覆盖特定的超参数。您需要:
- 在参数仍在 CPU 上时注册它们。
import torch
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters())
- 使用新的期望超参数覆盖配置。例如,让我们覆盖
model.fc1.weight
层以使用 32 位 Adam。
有关您可以覆盖的其他超参数的更多信息,请查阅优化器 API 文档。
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# override the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, "optim_bits", 32)
您还可以通过将多个层作为列表传递,并将新的超参数作为字典传递来一次性覆盖多个层。例如,让我们覆盖 model.special.weight
和 model.also_special.weight
层以使用稀疏优化以及更低的学习率和衰减率。
mng.override_config([model.special.weight, model.also_special.weight],
key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})
对于特定层,我们建议在每个模块中进行局部覆盖。将模块、参数及其属性名称传递给 GlobalOptimManager
class MyModule(torch.nn.Module):
def __init__(d_in, d_out):
super(MyModule, self).__init__()
self.linear = torch.nn.Linear(d_in, d_out)
# optimization will happen in 32-bit and
# learning rate will be set to 0.0001 independent of the main learning rate
config = {'optim_bits': 32, 'lr' : 0.0001}
GlobalOptimManager.get_instance().register_module_override(self, 'weight', config)
后续步骤
有关 8 位优化器的更多概念细节和解释,请参阅 8 位优化器 指南。
< > 在 GitHub 上更新