文本反转(Textual Inversion)Token 是如何破坏提示词的?

社区文章 发布于2024年3月31日

本博客旨在探讨文本反转(Textual Inversion)Token 究竟是如何破坏提示词或在扩散模型中过度主导交叉注意力的。澄清一下,我认为我接近了答案,但我只能形成一个猜测。

背景

为了评估像 Dreambooth 和文本反转(Textual Inversion)这样的自定义生成方法,常用的评估指标是 CLIP 分数。CLIP 分数使用由 Open AI 创建的 CLIP 模型来衡量图像对齐度,即生成图像与训练图像的相似度,以及文本对齐度,即生成图像与我们提示词的符合度。

image/png

虽然文本反转可以生成相当高质量的图像,但在遵循提示方面表现不佳,如下面的自定义扩散论文所示,该论文表明文本反转的文本对齐度始终最低。

image/png

其原因是文本反转被训练为忽略提示词,只生成你用它训练的图像。例如,在 diffusers 中,我们告诉模型在给定以下所有提示词的情况下生成相同的图像:

imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

这促使我们的模型生成主体并忽略提示词的其余部分/用概念覆盖提示词。例如,当我们要求模型生成“A <cat-toy> next to a man with a friend”时,我们得到:

image/png

这似乎忽略了“man with a friend”部分,并将其替换为 <cat-toy>。这很有趣,因为每个 token 只占用一个 768 维的向量,这与扩散模型的其余部分相比非常小。此外,它只影响 clip 文本编码器中的一个词。

然而,首先,让我们确认存在问题。

问题

对于本博客,我只测试了一个示例,但如果我有时间,我可能会测试更多。我正在使用 daam,它可以显示每个 token 对输出的贡献,如下所示:

image/png 关于如何做到这一点,它来自于扩散模型中一个有趣的特性:特定 token 的交叉注意力图倾向于固定在该 token 的索引中,如下所示(图片取自 prompt-to-prompt 论文):

image/png 所以我们可以直接查看“bear”的交叉注意力图,就能看到“bear”这个 token 对输出的贡献!有关更多信息,请参阅此处

现在,如果我们查看提示词“A <cat-toy> next to a man with a friend”中每个 token 的贡献,当我们查看平均注意力图时,所有正常 token 的范数都在 10~70 左右。然而,<cat-toy> token 的注意力图范数始终在 200 左右。此外,<cat-toy> token 的注意力图似乎比其他 token 更清晰。例如,下面是 <cat-toy> 与 token next 的注意力图比较:

image/png

image/png 所以文本反转 token 确实有一些独特之处。

为了找出原因,我主要进行了 4 项测试:

  1. 范数
  2. 角度
  3. CLIP 注意力
  4. 检查与 SOS Token 的关系

精确地找出导致这种交叉注意力破坏的原因

范数

我在 LAION discord 服务器和文献 “Encoder-based Domain Tuning for Fast Personalization of Text-to-Image Models”(也称为 E4T)中发现,关于这种情况发生的主要理论是文本反转向量的范数导致了这种崩溃。为了防止范数变得过大,E4T 尝试用 l1 范数来限制它的增长。这些说法有一定的根据。token 的范数遵循如下分布:

image/png

其中 y 轴是频率,x 轴是范数。0 处的微小凸起表示从未训练过的 token。平均值约为 0.385。现在,对于上面提到的 <cat-toy> token,其范数为 3.1,大约是平均 token 的 8 倍。那么,这是否就是导致该 token 过度表示的原因呢?

对此的一个反驳是,如果你查看 clip 文本模型(即稳定扩散的文本编码器)的代码,我们会看到这行:

last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)

就在我们插入扩散模型之前。所以本质上,所有可能的比例信息都可能在这里丢失。然而,这种比例很有可能以某种方式混入 clip 侧的其他 token 中,以过度表示这个概念。如果我们假设这一点,那么如果我们把 token 的范数缩小到大约 0.385(平均 token 范数),我们应该会看到这个 token 的表示减少。然而,我们得到的是下面这个结果:

image/png

在范数为 0.385 时。如果我们将其缩放到 0.385*2、0.385*3,以此类推,直到原始比例,我们得到以下图像:

image/png image/png image/png image/png image/png image/png image/png

至少对我来说,范数似乎确实略微提高了质量,但降低它对提高提示对齐度的影响可以忽略不计。我发现这非常令人着迷,因为仅仅一个 768 维的向量就可以如此过度地主导提示。

角度

如果范数不是罪魁祸首,那会是角度吗?为了进行这项测试,我计算了 1000 个 token 之间的点积(而不是与自身)。出于计算原因,我没有对全部 49408 个 token 进行计算。点积的结果如下:

image/png

回顾一下,当我们计算点积时,我们还可以通过除以它们的范数来得到两个向量之间的余弦,如下所示:

cos(θ)=uvuv\cos(\theta) = \dfrac{u \cdot v}{\mid u \mid \mid v \mid}

余弦值如下:

image/png

因此,我认为我们可以相当自信地说,每个 token 的每个输入嵌入与其他 token 都相当不相似。由于它们的范数始终在 0.385 左右,我们可以想象每个向量在 token 空间中占据一个球体的一部分!

image/png

现在,让我们计算每个 token 与文本反转 token 嵌入的余弦。结果如下:

image/png

一个观察是,余弦的绝对值略小,这意味着文本反转 token 与其他常规 token 相比,与其余 token 稍微更正交/垂直,这可能会给我们一些线索。

CLIP 注意力

现在,我一开始提到“正常 token 的交叉注意力范数最多在 10~70 左右,而 <cat-toy> token 的范数始终在 200 左右。”

然而,有一个 token 我当时没有提到。那就是序列开始 token (SOS token)。SOS token 是用于启动每个提示词的 token。出于某种原因,它的交叉注意力图中范数竟然达到了 2000 左右,这确实超过了我所知道的所有 token 范数。所以我形成的一个假设是,也许文本反转 token 正在获取与 SOS token 相似的特性,通过非常高的注意力图来过度主导提示词。其次,文本反转 token 主要通过攻击文本编码器来实现这一点,而不是 Stable Diffusion UNet 本身。

所以让我们检查 CLIP 文本编码器层的 2 个注意力图。一个编码提示词“A <cat-toy> next to a man with a friend”,另一个编码“A cat toy next to a man with a friend”。

我们从普通提示词(不含文本反转 token)开始。对于 CLIP 文本转换器的每一层,我们都会得到这样的注意力图:

第一层

image/png

第二层

image/png

第三层

image/png

为了理解这些注意力图的含义,我们可以查看坐标轴上的数字。如果我们看 y 轴上的 1 和 x 轴上的 0。我们可以看到位置 1 的 token 对位置 0 的 token 的关注程度。因此,随着层级越来越深,所有 token 都只关注 SOS token。这是一个众所周知的事实,至少在大型语言模型中是这样。我最初是在 注意力流论文(Tom Aarsen 在这里发表了一篇关于它的博客文章)中了解到这一点的,作者们利用这一事实来扩展 LLMS 的上下文长度!

然而,有一个层似乎与其他层不同。这是下面的第一层,其中 y 轴上的 token 似乎正在关注自身、SOS token,有时还关注中间的 token。我的理解是,CLIP 模型的第一层负责掌握提示中的所有单词,并将其编码到起始 token 中。

那么现在让我们看看文本反转提示的第一层注意力图。

image/png

我们看到在索引为 2 的 token <cat-toy> 的位置,文本反转 token 似乎只关注自身。事实上,如果我们放大看:

image/png 我们发现它对起始 token 的关注非常少!因此,在 CLIP 的后续层中,我的假设是文本反转 token 通过跳过其索引处的起始 token 来过度主导提示的其余部分。这解释了为什么文本反转 token 的噪声相对较少。它完全控制了生成,而其他单词则被忽略了。具体来说,在索引 2, 2 处的文本反转 token 在注意力图中具有最高值(0.905),除了索引 0, 0 处的值为 1。我们从这里将索引 2, 2 处的值称为文本反转注意力。所以我目前的猜测是,SOS token 和我们的文本反转 token 之间一定存在某种关系,导致文本反转注意力具有如此高的值。

检查与 SOS Token 的关系

让我们看看如果我们将文本反转嵌入替换为 SOS token 嵌入会发生什么。有趣的是,如下方交叉注意力放大图所示,我们似乎不再关注我们的 token 了:

image/png

生成的图像如下:

image/png

这确实表明 CLIP 文本编码器注意力图中较低的注意力分数表示提示破坏程度较小,这与我们猜测的一致。另一个有趣的发现是,注意力图并未将更多注意力分配给与 SOS token 相似的 token。

当我们计算 token 与 SOS token 的余弦时,我们发现余弦为 -0.0821,这在与其他 SOS token 的余弦相比并不显著,如下所示:

image/png

但我形成的一个假设是,也许 CLIP 文本编码器更关注与 SOS token 不相似的 token。为了证实这一点,我尝试将文本反转 token 的输入嵌入设置为 SOS token 的负值。我得到的注意力图如下:

image/png

文本反转注意力值约为 0.88!为了进一步证实这一点,我在缩放到 0.385 范数的文本反转 token 和 SOS token 之间进行了球面线性插值 (SLERP)。我发现,如果我从 SOS token 旋转开,文本反转注意力至少保持在 0.87 左右,这始终高于其他 token。然而,如果我们旋转向文本反转 token,文本反转注意力会迅速下降,当达到 50% 时,所有猫玩具的迹象都消失了。在插值因子约为 0.184,文本反转注意力为 0.77 时,我得到了下面的图片:

image/png

以上是我从中得到的最佳图片,它表明一些图像保真度已经丧失。其中一个有趣的部分是,当我们把旋转后的 token 恢复到原始比例时,虽然注意力分数随着接近 SOS token 而降低的总体趋势仍然存在,但速度变慢了。此外,在某个点上,图像变得非常扭曲,这表明这个谜团中肯定缺少一些部分。

我目前的猜测是,在计算第一个注意力图时,存在某种操作来减去 token 的 SOS token 部分,并将减去的量分配给第 0 列等等。有点像 PCA,但不是正交的,因为负 SOS token 起作用了。然而,这可能是未来博客的主题。

结论与未来方向

上述小型研究主要旨在突出文本反转 token 中一个非常有趣的现象。一个唾手可得的成果是看看这是否可以用于其他 token,但我认为一个有趣的目标是看看我们是否能在不改变 token 任何东西的情况下保持保真度但仍使用文本反转 token。例如,如果我们手动将文本反转注意力设置为 0.7 左右而无需改变 token 的任何特性,这可能会很有趣。但现在,希望大家喜欢!

旁注 - 多主题生成

在多主题生成中,当我将两个独立的文本反转 token(一个是猫,一个是椅子)合并到一个提示中时,我得到了以下结果:

image/png

虽然它们单独生成得很好,如下所示:

image/png

image/png

我猜测这是因为每个 token 都试图垄断所有的交叉注意力,所以在生成过程中它们会相互破坏!

社区

注册登录 发表评论