Bitsandbytes 文档

8 位优化器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

8 位优化器

使用 8 位优化器,大型模型可以使用减少 75% 的 GPU 内存进行微调,而与使用标准 32 位优化器训练相比,不会损失任何精度。 内存需求降低意味着 8 位优化器比标准优化器快 4 倍,并且不需要超参数调整。

本指南将向您展示如何使用 8 位优化器。

8 位优化器减少内存使用量并加速各种任务的优化。 然而,由于 8 位优化器仅按参数数量比例减少内存,因此使用大量激活内存的模型(例如卷积网络)实际上并没有从 8 位优化器中受益。 8 位优化器最有利于在高度内存受限的 GPU 上训练或微调具有大量参数的模型。

8 位优化器是常规优化器的直接替代品,这意味着它们也接受与常规优化器相同的参数。 对于 NLP 模型,建议使用 StableEmbedding 类以提高稳定性和结果。

import bitsandbytes as bnb

- adam = torch.optim.Adam(...)
+ adam = bnb.optim.Adam8bit(...)

# recommended for NLP models
- before: torch.nn.Embedding(...)
+ bnb.nn.StableEmbedding(...)

默认情况下,即使您使用 8 位优化器初始化这些参数,所有元素少于 4096 个的参数张量都将保持在 32 位。 这样做是因为小张量不会节省太多内存,并且通常包含高度可变的参数(偏差)或需要高精度的参数(批归一化、层归一化)。

您可以使用 min_8bit_size 参数更改此值。 例如,如果您只想在最小尺寸为 16384 个值时(建议使用 4096 的倍数)将参数优化为 8 位

import bitsandbytes as bnb

adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)

您可以配置的其他参数包括学习率 (lr)、衰减率 (betas)、优化器状态的位数 (optim_bits) 和百分位数裁剪 (percentile_clipping),这可以提高稳定性。 例如,要初始化具有第 5 个百分位数裁剪的 32 位 Adam 优化器

import bitsandbytes as bnb

adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5)

优化不稳定参数

要使用 32 位 Adam 优化一些不稳定参数,并使用 8 位 Adam 优化其他参数,请使用 GlobalOptimManager 类来覆盖特定层的特定超参数。 您需要

  1. 在参数位于 CPU 上时注册它们。
import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = MyModel()
mng.register_parameters(model.parameters())
  1. 使用新的所需超参数覆盖配置。 例如,让我们覆盖 model.fc1.weight 层以使用 32 位 Adam。

查看优化器 API 文档以获取有关您可以覆盖的其他超参数的更多信息。

model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)

# override the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, "optim_bits", 32)

您还可以通过将多个层作为列表传递并将新的超参数作为字典传递来一次性覆盖多个层。 例如,让我们覆盖 model.special.weightmodel.also_special.weight 层以使用稀疏优化和较低的学习率和衰减率。

mng.override_config([model.special.weight, model.also_special.weight],
                    key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})

对于特定层,我们建议在每个模块中本地覆盖。 将模块、参数及其属性名称传递给 GlobalOptimManager

class MyModule(torch.nn.Module):
  def __init__(d_in, d_out):
    super(MyModule, self).__init__()
    self.linear = torch.nn.Linear(d_in, d_out)
    # optimization will happen in 32-bit and
    # learning rate will be set to 0.0001 independent of the main learning rate
    config = {'optim_bits': 32, 'lr' : 0.0001}
    GlobalOptimManager.get_instance().register_module_override(self, 'weight', config)

下一步

有关 8 位优化器的更多概念细节和解释,请查看 8 位优化器 指南。

< > GitHub 上更新