8 位优化器
使用 8 位优化器,大型模型可以使用比标准 32 位优化器少 75% 的 GPU 内存进行微调,而不会损失任何准确性。降低的内存需求意味着 8 位优化器比标准优化器快 4 倍,并且不需要任何超参数调整。
本指南将向您展示如何使用 8 位优化器。
8 位优化器减少了内存使用量并加速了各种任务的优化。但是,由于 8 位优化器仅减少与参数数量成比例的内存,因此使用大量激活内存的模型(例如卷积网络)实际上并不能从 8 位优化器中受益。对于在高度内存受限的 GPU 上训练或微调具有大量参数的模型,8 位优化器最有效。
8 位优化器是常规优化器的直接替代品,这意味着它们也接受与常规优化器相同的参数。对于 NLP 模型,建议使用 StableEmbedding 类来提高稳定性和结果。
import bitsandbytes as bnb
- adam = torch.optim.Adam(...)
+ adam = bnb.optim.Adam8bit(...)
# recommended for NLP models
- before: torch.nn.Embedding(...)
+ bnb.nn.StableEmbedding(...)
默认情况下,即使您使用 8 位优化器初始化这些参数,所有小于 4096 个元素的参数张量也将保留为 32 位。这样做是因为小型张量不会节省太多内存,并且通常包含高度可变的参数(偏差)或需要高精度的参数(批归一化、层归一化)。
您可以使用 min_8bit_size
参数更改此值。例如,如果您希望仅当最小大小为 16384 个值时才将参数优化为 8 位(建议使用 4096 的倍数)
import bitsandbytes as bnb
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
您可以配置的其他参数包括学习率 (lr
)、衰减率 (betas
)、优化器状态的位数 (optim_bits
) 和百分位数裁剪 (percentile_clipping
),这可以提高稳定性。例如,要使用 5% 百分位数裁剪初始化一个 32 位 Adam 优化器
import bitsandbytes as bnb
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5)
优化不稳定的参数
要使用 32 位 Adam 优化某些不稳定的参数,并使用 8 位 Adam 优化其他参数,请使用 GlobalOptimManager 类覆盖特定层的特定超参数。您需要
- 在 CPU 上注册参数。
import torch
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters())
- 使用新的所需超参数覆盖配置。例如,让我们覆盖
model.fc1.weight
层以使用 32 位 Adam。
查看优化器 API 文档以获取有关您可以覆盖的其他超参数的更多信息。
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# override the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, "optim_bits", 32)
您还可以通过将它们作为列表传递并将新的超参数作为字典传递来一次覆盖多个层。例如,让我们覆盖 model.special.weight
和 model.also_special.weight
层以使用稀疏优化和更低的学习率和衰减率。
mng.override_config([model.special.weight, model.also_special.weight],
key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})
对于特定层,我们建议在每个模块中本地覆盖。将模块、参数及其属性名称传递给 GlobalOptimManager
class MyModule(torch.nn.Module):
def __init__(d_in, d_out):
super(MyModule, self).__init__()
self.linear = torch.nn.Linear(d_in, d_out)
# optimization will happen in 32-bit and
# learning rate will be set to 0.0001 independent of the main learning rate
config = {'optim_bits': 32, 'lr' : 0.0001}
GlobalOptimManager.get_instance().register_module_override(self, 'weight', config)
下一步
有关 8 位优化器的更多概念细节和解释,请查看 8 位优化器 指南。
< > 在 GitHub 上更新