Safetensors 文档

Torch 共享张量

您正在查看 main 版本,该版本需要从源代码安装。如果您想要常规 pip 安装,请查看最新的稳定版本 (v0.5.0-rc.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Torch 共享张量

TL;DR (太长不看)

使用特定的函数,这在大多数情况下应该对您有效。但这并非没有副作用。

from safetensors.torch import load_model, save_model

save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")

load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))

什么是共享张量?

Pytorch 在某些计算中使用共享张量。这对于总体减少内存使用非常有趣。

一个非常经典的用例是在 transformers 中,embeddingslm_head 共享。通过使用相同的矩阵,模型使用的参数更少,并且梯度可以更好地流向 embeddings(这是模型的开始,因此梯度不容易流到那里,而 lm_head 在模型的末尾,因此梯度在那里非常好,由于它们是相同的张量,因此两者都受益)

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(100, 100)
        self.b = self.a

    def forward(self, x):
        return self.b(self.a(x))


model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer

为什么共享张量不保存在 safetensors 中?

有多种原因

  • 并非所有框架都支持它们,例如 tensorflow 就不支持。因此,如果有人在 torch 中保存共享张量,则无法以类似的方式加载它们,因此我们无法保持相同的 Dict[str, Tensor] API。

  • 它使惰性加载变得非常快速。 惰性加载是仅加载给定文件的一些张量或部分张量的能力。在不共享张量的情况下,这很容易做到,但通过张量共享

    with safe_open("model.safetensors", framework="pt") as f:
        a = f.get_tensor("a")
        b = f.get_tensor("b")

    现在,对于给定的代码,在事后“重新共享”缓冲区是不可能的。一旦我们给出了 a 张量,当您请求 b 时,我们就无法返回相同的内存。(在这个特定的示例中,我们可以跟踪给定的缓冲区,但这在一般情况下并非如此,因为您可以对 a 进行任意操作,例如在请求 b 之前将其发送到另一个设备)

  • 它可能导致文件比必要的更大。如果您正在保存的共享张量只是较大张量的一小部分,那么使用 pytorch 保存会导致保存整个缓冲区,而不是仅保存所需的内容。

    a = torch.zeros((100, 100))
    b = a[:1, :]
    torch.save({"b": b}, "model.bin")
    # File is 41k instead of the expected 400 bytes
    # In practice it could happen that you save several 10GB instead of 1GB.

现在,考虑到所有这些原因,其中没有任何内容是板上钉钉的。共享张量不会导致不安全或拒绝服务风险,因此如果当前解决方法不能令人满意,则可以重新考虑此决定。

它是如何工作的?

设计相当简单。我们将查找所有共享张量,然后查找覆盖整个缓冲区的所有张量(可以有多个这样的张量)。这为我们提供了多个可以保存的名称,我们只需选择第一个

load_model 期间,我们的加载方式有点像 load_state_dict,不同的是我们正在查看模型本身,以检查共享缓冲区,并忽略由于缓冲区共享而实际覆盖的“遗漏键”(它们已正确加载,因为有一个缓冲区在后台加载)。所有其他错误都按原样引发

注意:这意味着我们正在删除文件中的一些键。这意味着如果您正在检查磁盘上保存的键,您将看到一些“丢失的张量”,或者如果您正在使用 load_state_dict。除非我们开始直接在格式中支持共享张量,否则没有真正的方法可以解决它。

< > 在 GitHub 上更新