Safetensors 文档
Torch 共享张量
并获得增强的文档体验
开始使用
Torch 共享张量
TL;DR (太长不看)
使用特定的函数,这在大多数情况下应该对您有效。但这并非没有副作用。
from safetensors.torch import load_model, save_model
save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")
load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))
什么是共享张量?
Pytorch 在某些计算中使用共享张量。这对于总体减少内存使用非常有趣。
一个非常经典的用例是在 transformers 中,embeddings
与 lm_head
共享。通过使用相同的矩阵,模型使用的参数更少,并且梯度可以更好地流向 embeddings
(这是模型的开始,因此梯度不容易流到那里,而 lm_head
在模型的末尾,因此梯度在那里非常好,由于它们是相同的张量,因此两者都受益)
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(100, 100)
self.b = self.a
def forward(self, x):
return self.b(self.a(x))
model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer
为什么共享张量不保存在 safetensors 中?
有多种原因
并非所有框架都支持它们,例如
tensorflow
就不支持。因此,如果有人在 torch 中保存共享张量,则无法以类似的方式加载它们,因此我们无法保持相同的Dict[str, Tensor]
API。它使惰性加载变得非常快速。 惰性加载是仅加载给定文件的一些张量或部分张量的能力。在不共享张量的情况下,这很容易做到,但通过张量共享
with safe_open("model.safetensors", framework="pt") as f: a = f.get_tensor("a") b = f.get_tensor("b")
现在,对于给定的代码,在事后“重新共享”缓冲区是不可能的。一旦我们给出了
a
张量,当您请求b
时,我们就无法返回相同的内存。(在这个特定的示例中,我们可以跟踪给定的缓冲区,但这在一般情况下并非如此,因为您可以对a
进行任意操作,例如在请求b
之前将其发送到另一个设备)它可能导致文件比必要的更大。如果您正在保存的共享张量只是较大张量的一小部分,那么使用 pytorch 保存会导致保存整个缓冲区,而不是仅保存所需的内容。
a = torch.zeros((100, 100)) b = a[:1, :] torch.save({"b": b}, "model.bin") # File is 41k instead of the expected 400 bytes # In practice it could happen that you save several 10GB instead of 1GB.
现在,考虑到所有这些原因,其中没有任何内容是板上钉钉的。共享张量不会导致不安全或拒绝服务风险,因此如果当前解决方法不能令人满意,则可以重新考虑此决定。
它是如何工作的?
设计相当简单。我们将查找所有共享张量,然后查找覆盖整个缓冲区的所有张量(可以有多个这样的张量)。这为我们提供了多个可以保存的名称,我们只需选择第一个
在 load_model
期间,我们的加载方式有点像 load_state_dict
,不同的是我们正在查看模型本身,以检查共享缓冲区,并忽略由于缓冲区共享而实际覆盖的“遗漏键”(它们已正确加载,因为有一个缓冲区在后台加载)。所有其他错误都按原样引发
注意:这意味着我们正在删除文件中的一些键。这意味着如果您正在检查磁盘上保存的键,您将看到一些“丢失的张量”,或者如果您正在使用 load_state_dict
。除非我们开始直接在格式中支持共享张量,否则没有真正的方法可以解决它。