Safetensors 文档

Torch 共享张量

您正在查看的是需要从源码安装。如果您想使用常规的 pip 安装,请查看最新的稳定版本 (v0.5.0-rc.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Torch 共享张量

总结

使用特定的函数,这在大多数情况下应该能满足您的需求。但这并非没有副作用。

from safetensors.torch import load_model, save_model

save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")

load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))

什么是共享张量?

Pytorch 使用共享张量进行一些计算。这对于普遍减少内存使用非常有帮助。

一个非常经典的用例是在 Transformers 中,embeddingslm_head 共享。通过使用相同的矩阵,模型使用的参数更少,并且梯度能更好地流向 embeddings(因为 embeddings 位于模型的起始部分,梯度不容易流到那里,而 lm_head 位于模型的末端,梯度在那里非常强。由于它们是同一个张量,两者都能受益)。

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(100, 100)
        self.b = self.a

    def forward(self, x):
        return self.b(self.a(x))


model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer

为什么共享张量不保存在 safetensors 中?

原因有多种:

  • 并非所有框架都支持它们,例如 tensorflow 就不支持。所以如果有人在 torch 中保存了共享张量,就无法以类似的方式在其他框架中加载它们,因此我们无法保持相同的 Dict[str, Tensor] API。

  • 这使得延迟加载变得非常快。延迟加载是指只加载文件中的部分张量或张量的部分。在没有张量共享的情况下,这很容易实现,但有了张量共享就不一样了。

    with safe_open("model.safetensors", framework="pt") as f:
        a = f.get_tensor("a")
        b = f.get_tensor("b")

    现在,使用给定的代码,事后“重新共享”缓冲区是不可能的。一旦我们给出了 a 张量,当您请求 b 时,我们无法返回相同的内存。(在这个特定的例子中,我们可以跟踪已给出的缓冲区,但通常情况并非如此,因为您可能在请求 b 之前对 a 进行了任意操作,比如将其发送到另一个设备)。

  • 它可能导致文件比必要的大得多。如果您要保存的共享张量只是一个更大张量的一部分,那么用 PyTorch 保存会导致保存整个缓冲区,而不是只保存需要的部分。

    a = torch.zeros((100, 100))
    b = a[:1, :]
    torch.save({"b": b}, "model.bin")
    # File is 41k instead of the expected 400 bytes
    # In practice it could happen that you save several 10GB instead of 1GB.

尽管提到了所有这些原因,但目前还没有定论。共享张量不会导致不安全或潜在的拒绝服务攻击,因此如果当前的解决方案不尽人意,这个决定可能会被重新考虑。

它是如何工作的?

设计相当简单。我们将查找所有共享张量,然后查找所有覆盖整个缓冲区的张量(可能有多个这样的张量)。这样我们就得到了多个可以保存的名称,我们只需选择第一个。

load_model 期间,我们的加载方式有点像 load_state_dict,但我们会检查模型本身,以查找共享缓冲区,并忽略那些因缓冲区共享而被覆盖的“缺失键”(它们实际上已正确加载,因为有一个缓冲区在底层加载了)。所有其他错误都将按原样引发。

注意事项:这意味着我们正在丢弃文件中的一些键。也就是说,如果您检查保存在磁盘上的键,您会看到一些“缺失的张量”,或者如果您正在使用 load_state_dict。除非我们开始直接在格式中支持共享张量,否则没有真正的解决办法。

< > 在 GitHub 上更新