开源 SD-Small 和 SD-Tiny 的知识蒸馏代码和权重

发布日期:2023年8月1日
在 GitHub 上更新

近期,AI 社区见证了更大、性能更强的语言模型(如 Falcon 40B、LLaMa-2 70B、Falcon 40B、MPT 30B)以及图像领域模型(如 SD2.1 和 SDXL)的显著发展。这些进步无疑推动了 AI 能力的边界,实现了高度通用和最先进的图像生成和语言理解能力。然而,当我们惊叹于这些模型的强大和复杂性时,也必须认识到对 AI 模型进行小型化、高效化和更易于访问(尤其是通过开源)的需求日益增长。

Segmind,我们一直在致力于如何使生成式 AI 模型更快、更便宜。去年,我们开源了加速 SD-WebUI 库,名为 voltaML,它是一个基于 AITemplate/TensorRT 的推理加速库,已将推理速度提高了 4-6 倍。为了继续实现使生成式模型更快、更小、更便宜的目标,我们正在开源我们压缩的 SD 模型 SD-Small 和 SD-Tiny 的权重和训练代码。预训练检查点可在 Huggingface 🤗 上获取。

知识蒸馏

我们的新压缩模型已采用知识蒸馏(KD)技术进行训练,这项工作主要基于这篇论文。作者描述了一种块移除知识蒸馏方法,其中一些 UNet 层被移除,然后训练学生模型权重。利用论文中描述的 KD 方法,我们能够使用 🧨 diffusers 库训练出两个压缩模型:SmallTiny,它们比基础模型分别减少了 35% 和 55% 的参数,同时实现了与基础模型相当的图像保真度。我们已将蒸馏代码开源到此存储库,并将预训练检查点发布到 Huggingface 🤗

训练神经网络的知识蒸馏类似于老师一步步引导学生。一个大型教师模型在大数据上进行预训练,然后一个较小的模型在较小的数据集上进行训练,以模仿大型模型的输出,并结合数据集上的传统训练。

在这种特定的知识蒸馏类型中,学生模型被训练来执行从纯噪声中恢复图像的正常扩散任务,但同时,模型被要求匹配大型教师模型的输出。输出匹配发生在 U-net 的每个块中,因此模型质量得以大部分保留。因此,使用之前的类比,我们可以说,在这种蒸馏过程中,学生不仅会尝试从问题和答案中学习,还会从教师的答案以及获取答案的循序渐进的方法中学习。我们损失函数中有 3 个组成部分来实现这一点:首先是目标图像的潜在变量与生成图像的潜在变量之间的传统损失。其次是教师生成的图像的潜在变量与学生生成的图像的潜在变量之间的损失。最后,也是最重要的组成部分是特征级别损失,即教师和学生每个块的输出之间的损失。

所有这些结合起来构成了知识蒸馏训练。下面是论文中描述的 KD 中使用的块移除 UNet 的架构。

图片摘自 Shinkook 等人的论文《文本到图像扩散模型的架构压缩》(“On Architectural Compression of Text-to-Image Diffusion Models”)

我们以 Realistic-Vision 4.0 作为我们的基础教师模型,并在 LAION Art Aesthetic 数据集(图像评分高于 7.5 分,因其高质量图像描述)上进行训练。与论文不同,我们选择在 100 万张图像上分别对 Small 和 Tiny 模型训练 10 万步和 12.5 万步。蒸馏训练代码可以在这里找到。

模型使用

该模型可使用 🧨 diffusers 中的 DiffusionPipeline 进行使用。


from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("segmind/small-sd", torch_dtype=torch.float16)
prompt = "Portrait of a pretty girl"
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
image = pipeline(prompt, negative_prompt = negative_prompt).images[0]
image.save("my_image.png")

推理延迟速度

我们观察到,蒸馏模型比原始基础模型快达 100%。基准测试代码可在此处找到。

潜在限制

蒸馏模型尚处于早期阶段,输出质量可能尚未达到生产级别。这些模型可能不是最好的通用模型。它们最适合用于特定概念/风格的微调或 LoRA 训练。蒸馏模型在可组合性或多概念方面尚不理想。

在肖像数据集上微调 SD-tiny 模型

我们已在用 Realistic Vision v4.0 模型生成的肖像图像上微调了我们的 sd-tiny 模型。以下是使用的微调参数。

  • 步数:131000
  • 学习率:1e-4
  • 批量大小:32
  • 梯度累积步数:4
  • 图像分辨率:768
  • 数据集大小 - 7k 张图像
  • 混合精度:fp16

我们能够生成接近原始模型图像质量的图像,参数减少了近 40%,以下样本结果足以说明一切。

基础模型的微调代码可以在此处找到。

LoRA 训练

在蒸馏模型上进行 LoRA 训练的优势之一是训练速度更快。以下是我们对蒸馏模型进行 LoRA 训练的一些抽象概念图像。LoRA 训练的代码可以在这里找到。

结论

我们邀请开源社区帮助我们改进这些蒸馏 SD 模型并实现更广泛的采用。用户可以加入我们的 Discord 服务器,我们将在其中发布这些模型的最新更新,发布更多检查点和一些激动人心的新 LoRA。如果您喜欢我们的工作,请在我们的 Github 上给我们一颗星。

社区

注册登录 发表评论