Safetensors 文档
Torch 共享张量
并获得增强的文档体验
开始使用
Torch 共享张量
总结
使用特定的函数,这在大多数情况下应该能满足您的需求。但这并非没有副作用。
from safetensors.torch import load_model, save_model
save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")
load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))
什么是共享张量?
Pytorch 使用共享张量进行一些计算。这对于普遍减少内存使用非常有帮助。
一个非常经典的用例是在 Transformers 中,embeddings
与 lm_head
共享。通过使用相同的矩阵,模型使用的参数更少,并且梯度能更好地流向 embeddings
(因为 embeddings
位于模型的起始部分,梯度不容易流到那里,而 lm_head
位于模型的末端,梯度在那里非常强。由于它们是同一个张量,两者都能受益)。
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(100, 100)
self.b = self.a
def forward(self, x):
return self.b(self.a(x))
model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer
为什么共享张量不保存在 safetensors 中?
原因有多种:
并非所有框架都支持它们,例如
tensorflow
就不支持。所以如果有人在 torch 中保存了共享张量,就无法以类似的方式在其他框架中加载它们,因此我们无法保持相同的Dict[str, Tensor]
API。这使得延迟加载变得非常快。延迟加载是指只加载文件中的部分张量或张量的部分。在没有张量共享的情况下,这很容易实现,但有了张量共享就不一样了。
with safe_open("model.safetensors", framework="pt") as f: a = f.get_tensor("a") b = f.get_tensor("b")
现在,使用给定的代码,事后“重新共享”缓冲区是不可能的。一旦我们给出了
a
张量,当您请求b
时,我们无法返回相同的内存。(在这个特定的例子中,我们可以跟踪已给出的缓冲区,但通常情况并非如此,因为您可能在请求b
之前对a
进行了任意操作,比如将其发送到另一个设备)。它可能导致文件比必要的大得多。如果您要保存的共享张量只是一个更大张量的一部分,那么用 PyTorch 保存会导致保存整个缓冲区,而不是只保存需要的部分。
a = torch.zeros((100, 100)) b = a[:1, :] torch.save({"b": b}, "model.bin") # File is 41k instead of the expected 400 bytes # In practice it could happen that you save several 10GB instead of 1GB.
尽管提到了所有这些原因,但目前还没有定论。共享张量不会导致不安全或潜在的拒绝服务攻击,因此如果当前的解决方案不尽人意,这个决定可能会被重新考虑。
它是如何工作的?
设计相当简单。我们将查找所有共享张量,然后查找所有覆盖整个缓冲区的张量(可能有多个这样的张量)。这样我们就得到了多个可以保存的名称,我们只需选择第一个。
在 load_model
期间,我们的加载方式有点像 load_state_dict
,但我们会检查模型本身,以查找共享缓冲区,并忽略那些因缓冲区共享而被覆盖的“缺失键”(它们实际上已正确加载,因为有一个缓冲区在底层加载了)。所有其他错误都将按原样引发。
注意事项:这意味着我们正在丢弃文件中的一些键。也就是说,如果您检查保存在磁盘上的键,您会看到一些“缺失的张量”,或者如果您正在使用 load_state_dict
。除非我们开始直接在格式中支持共享张量,否则没有真正的解决办法。