使用推理端点进行解码的远程 VAE 🤗

发布于 2025 年 2 月 24 日
在 GitHub 上更新

当使用潜在空间扩散模型进行高分辨率图像和视频合成时,VAE 解码器可能会消耗相当多的内存。这使得用户很难在消费级 GPU 上运行这些模型,而无需牺牲延迟等。

例如,在卸载时,会产生设备传输开销,从而导致整体推理延迟。平铺是另一种解决方案,它允许我们对所谓的“输入块”进行操作。然而,它可能对最终图像的质量产生负面影响。

因此,我们希望与社区共同试行一个想法——将解码过程委托给远程端点。

不存储或跟踪任何数据,并且代码是开源的。我们对 huggingface-inference-toolkit 进行了一些更改,并使用 自定义处理程序

此实验功能由 Diffusers 🧨 开发

目录:

开始使用

下面,我们涵盖了我们认为远程 VAE 推理会受益的三个用例。

代码

首先,我们创建了一个用于与远程 VAE 交互的辅助方法。

从 `main` 安装 `diffusers` 以运行代码。 `pip install git+https://github.com/huggingface/diffusers@main`

代码

from diffusers.utils.remote_utils import remote_decode

基本示例

这里,我们展示了如何在随机张量上使用远程 VAE。

代码

image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    scaling_factor=0.18215,
)

Flux 的用法略有不同。Flux 潜在向量是打包的,因此我们需要发送 `height` 和 `width`。

代码

image = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
    height=1024,
    width=1024,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)

最后,是 HunyuanVideo 的一个示例。

代码

video = remote_decode(
    endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
    output_type="mp4",
)
with open("video.mp4", "wb") as f:
    f.write(video)

生成

但我们希望在实际管道上使用 VAE 来获得实际图像,而不是随机噪声。下面的示例展示了如何使用 SD v1.5 来实现。

代码

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    vae=None,
).to("cuda")

prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"

latent = pipe(
    prompt=prompt,
    output_type="latent",
).images
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("test.jpg")

这是 Flux 的另一个例子。

代码

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    vae=None,
).to("cuda")

prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"

latent = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    output_type="latent",
).images
image = remote_decode(
    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    height=1024,
    width=1024,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
image.save("test.jpg")

这是 HunyuanVideo 的一个例子。

代码

from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
).to("cuda")

latent = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    output_type="latent",
).frames

video = remote_decode(
    endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    output_type="mp4",
)

if isinstance(video, bytes):
    with open("video.mp4", "wb") as f:
        f.write(video)

排队

使用远程 VAE 的一大优点是我们可以将多个生成请求排队。当当前的潜在向量正在进行解码处理时,我们已经可以排队另一个请求。这有助于提高并发性。

代码

import queue
import threading
from IPython.display import display
from diffusers import StableDiffusionPipeline

def decode_worker(q: queue.Queue):
    while True:
        item = q.get()
        if item is None:
            break
        image = remote_decode(
            endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
            tensor=item,
            scaling_factor=0.18215,
        )
        display(image)
        q.task_done()

q = queue.Queue()
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
thread.start()

def decode(latent: torch.Tensor):
    q.put(latent)

prompts = [
    "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
    "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
    "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
    "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
    "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
    "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
]

pipe = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-8",
    torch_dtype=torch.float16,
    vae=None,
).to("cuda")

pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

_ = pipe(
    prompt=prompts[0],
    output_type="latent",
)

for prompt in prompts:
    latent = pipe(
        prompt=prompt,
        output_type="latent",
    ).images
    decode(latent)

q.put(None)
thread.join()

可用 VAE

使用远程 VAE 的优势

这些表格展示了不同 GPU 的 VRAM 要求。内存使用百分比决定了某些 GPU 的用户是否需要卸载。卸载时间因 CPU、RAM 和 HDD/NVMe 而异。平铺解码会增加推理时间。

SD v1.5

GPU 分辨率 时间(秒) 内存 (%) 平铺时间(秒) 平铺内存 (%)
英伟达 GeForce RTX 4090 512x512 0.031 5.60% 0.031 (0%) 5.60%
英伟达 GeForce RTX 4090 1024x1024 0.148 20.00% 0.301 (+103%) 5.60%
英伟达 GeForce RTX 4080 512x512 0.05 8.40% 0.050 (0%) 8.40%
英伟达 GeForce RTX 4080 1024x1024 0.224 30.00% 0.356 (+59%) 8.40%
英伟达 GeForce RTX 4070 Ti 512x512 0.066 11.30% 0.066 (0%) 11.30%
英伟达 GeForce RTX 4070 Ti 1024x1024 0.284 40.50% 0.454 (+60%) 11.40%
英伟达 GeForce RTX 3090 512x512 0.062 5.20% 0.062 (0%) 5.20%
英伟达 GeForce RTX 3090 1024x1024 0.253 18.50% 0.464 (+83%) 5.20%
英伟达 GeForce RTX 3080 512x512 0.07 12.80% 0.070 (0%) 12.80%
英伟达 GeForce RTX 3080 1024x1024 0.286 45.30% 0.466 (+63%) 12.90%
英伟达 GeForce RTX 3070 512x512 0.102 15.90% 0.102 (0%) 15.90%
英伟达 GeForce RTX 3070 1024x1024 0.421 56.30% 0.746 (+77%) 16.00%

SDXL

GPU 分辨率 时间(秒) 内存消耗 (%) 平铺时间(秒) 平铺内存 (%)
英伟达 GeForce RTX 4090 512x512 0.057 10.00% 0.057 (0%) 10.00%
英伟达 GeForce RTX 4090 1024x1024 0.256 35.50% 0.257 (+0.4%) 35.50%
英伟达 GeForce RTX 4080 512x512 0.092 15.00% 0.092 (0%) 15.00%
英伟达 GeForce RTX 4080 1024x1024 0.406 53.30% 0.406 (0%) 53.30%
英伟达 GeForce RTX 4070 Ti 512x512 0.121 20.20% 0.120 (-0.8%) 20.20%
英伟达 GeForce RTX 4070 Ti 1024x1024 0.519 72.00% 0.519 (0%) 72.00%
英伟达 GeForce RTX 3090 512x512 0.107 10.50% 0.107 (0%) 10.50%
英伟达 GeForce RTX 3090 1024x1024 0.459 38.00% 0.460 (+0.2%) 38.00%
英伟达 GeForce RTX 3080 512x512 0.121 25.60% 0.121 (0%) 25.60%
英伟达 GeForce RTX 3080 1024x1024 0.524 93.00% 0.524 (0%) 93.00%
英伟达 GeForce RTX 3070 512x512 0.183 31.80% 0.183 (0%) 31.80%
英伟达 GeForce RTX 3070 1024x1024 0.794 96.40% 0.794 (0%) 96.40%

提供反馈

如果您喜欢这个想法和功能,请帮助我们提供反馈,说明我们如何做得更好,以及您是否有兴趣将此类功能更原生集成到 Hugging Face 生态系统中。如果这个试点进展顺利,我们计划为更多模型创建优化的 VAE 端点,包括那些可以生成高分辨率视频的模型!

步骤:

  1. 通过 此链接 在 Diffusers 上提出问题。
  2. 回答问题并提供您想要的任何额外信息。
  3. 点击提交!

社区

这篇文章有些地方不够清晰

  • “时间”是指“推理时间”还是“卸载时间”?
  • 平铺内存/平铺时间是什么意思?
  • VAE 平铺总是发生吗?
·
文章作者

我不确定这有什么不清楚的,但为了完整性,特此说明。

VAE 平铺总是发生吗?

我们从未提及这会发生。

平铺内存/平铺时间是什么意思?

平铺内存/时间意味着我们正在应用平铺。

“时间”是指“推理时间”还是“卸载时间”?

总往返时间。如果它意味着其他任何东西,它就会像其他一样被指定。

这需要专业账户吗?

文章作者

不。

你好。
非常棒的实现,运行完美,解码时间为 2-4 秒,具体取决于图像大小。

一个简单的问题是,是否有可能在另一台机器(非 HF)上创建本地端点并将其用作 vae-decode 机器?

例如,一个舒适的用户界面实现。

文章作者

舒适实现:https://github.com/kijai/ComfyUI-HFRemoteVae

如果输入和输出方案匹配,您可以在本地机器上托管端点并按博客文章中所示使用它。

·

好的,我再读一遍。
谢谢。

📻 🎙️ 嘿,我为这篇博文生成了一个 AI 播客,快来听听看吧!

此播客通过 ngxson/kokoro-podcast-generator 生成,使用 DeepSeek-R1Kokoro-TTS

·
文章作者

无音频。

非常棒的功能。

我们还能获得 Wan Video 2.1 VAE 用于远程解码吗?

注册登录 发表评论