PyTorch 共享张量
概述
使用特定函数,这些函数在大多数情况下应该可以满足您的需求。但这并非没有副作用。
from safetensors.torch import load_model, save_model
save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")
load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))
什么是共享张量?
Pytorch 在某些计算中使用共享张量。这对于减少整体内存使用非常有用。
一个非常经典的用例是在 transformers 中,embeddings
与 lm_head
共享。通过使用相同的矩阵,模型使用更少的参数,并且梯度可以更好地流向 embeddings
(它是模型的起点,因此梯度不易流向那里,而 lm_head
是模型的尾部,因此梯度非常好,因为它们是相同的张量,所以两者都受益)。
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(100, 100)
self.b = self.a
def forward(self, x):
return self.b(self.a(x))
model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer
为什么 safetensors 中不保存共享张量?
有多个原因。
并非所有框架都支持它们,例如
tensorflow
不支持。因此,如果有人在 torch 中保存共享张量,则无法以类似的方式加载它们,因此我们无法保持相同的Dict[str, Tensor]
API。这使得延迟加载变得非常快。 延迟加载是指仅加载特定张量或特定文件部分张量的能力。在没有共享张量的情况下,这很容易实现,但在共享张量的情况下,
with safe_open("model.safetensors", framework="pt") as f: a = f.get_tensor("a") b = f.get_tensor("b")
现在,使用当前代码无法在之后“重新共享”缓冲区。一旦我们提供
a
张量,当您请求b
时,我们就无法将相同的内存返回(在此特定示例中,我们可以跟踪已提供的缓冲区,但通常情况并非如此,因为您可以对a
执行任意操作,例如在请求b
之前将其发送到另一个设备)。这会导致文件比必要的大得多。 如果您要保存一个仅占较大张量一部分的共享张量,那么使用 pytorch 保存它会导致保存整个缓冲区,而不是仅保存所需的部分。
a = torch.zeros((100, 100)) b = a[:1, :] torch.save({"b": b}, "model.bin") # File is 41k instead of the expected 400 bytes # In practice it could happen that you save several 10GB instead of 1GB.
现在,考虑到所有这些原因,这里没有一成不变的规定。共享张量不会导致不安全或拒绝服务漏洞,因此如果当前的解决方法无法令人满意,可以重新考虑此决定。
它是如何工作的?
设计相当简单。我们将查找所有共享张量,然后查找覆盖整个缓冲区的张量(可能有多个此类张量)。这将提供多个名称,可以保存这些名称,我们只需选择第一个名称。
在 load_model
期间,我们加载的方式类似于 load_state_dict
的加载方式,不同之处在于我们查看模型本身以检查共享缓冲区,并忽略实际上通过缓冲区共享而被覆盖的“缺少键”(由于缓冲区在幕后加载,因此它们已正确加载)。所有其他错误都会照常引发。
注意: 这意味着我们正在删除文件中的某些键。也就是说,如果要检查磁盘上保存的键,您将看到一些“缺少的张量”,或者如果您使用的是 load_state_dict
。除非我们开始直接在格式中支持共享张量,否则别无选择。