扩散模型课程文档

音频扩散模型

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

Open In Colab

音频扩散模型

在本笔记本中,我们将简要了解如何使用扩散模型生成音频。

你将学到什么:

  • 音频如何在计算机中表示
  • 原始音频数据和频谱图之间转换的方法
  • 如何使用自定义合并函数准备数据加载器,将音频片段转换为频谱图
  • 在特定音乐类型上微调现有的音频扩散模型
  • 将自定义管道上传到 Hugging Face Hub

注意:这主要用于教育目的 - 我们无法保证我们的模型声音效果良好 😉。

让我们开始吧!

设置和导入

%pip install -q datasets diffusers torchaudio accelerate
import torch, random
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm
from IPython.display import Audio
from matplotlib import pyplot as plt
from diffusers import DiffusionPipeline
from torchaudio import transforms as AT
from torchvision import transforms as IT

从预训练音频管道采样

让我们首先按照 音频扩散模型文档 加载现有的音频扩散模型管道

# Load a pre-trained audio diffusion pipeline
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-instrumental-hiphop-256").to(device)

与我们在前面单元中使用的管道一样,我们可以通过如下调用管道来创建样本

>>> # Sample from the pipeline and display the outputs
>>> output = pipe()
>>> display(output.images[0])
>>> display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))

这里,rate 参数指定音频的采样率;我们稍后将对此进行更深入的探讨。您还会注意到管道返回了多个内容。这里发生了什么?让我们仔细看看这两个输出。

第一个是数据数组,表示生成的音频

# The audio array
output.audios[0].shape

第二个看起来像是灰度图像

# The output image (spectrogram)
output.images[0].size

这为我们提供了一个关于此管道如何工作的提示。音频不是直接使用扩散生成的 - 相反,该管道与我们在 单元 1 中看到的无条件图像生成管道具有相同的 2D U-Net,用于生成频谱图,然后将其后处理成最终的音频。

该管道还有一个处理这些转换的额外组件,我们可以通过 pipe.mel 访问它

pipe.mel

从音频到图像再到音频

音频“波形”对时间上的原始音频样本进行编码 - 例如,这可能是从麦克风接收到的电信号。处理这种“时域”表示可能很棘手,因此通常的做法是将其转换为其他形式,通常称为频谱图。频谱图显示了不同频率(y 轴)与时间(x 轴)的强度

>>> # Calculate and show a spectrogram for our generated audio sample using torchaudio
>>> spec_transform = AT.Spectrogram(power=2)
>>> spectrogram = spec_transform(torch.tensor(output.audios[0]))
>>> print(spectrogram.min(), spectrogram.max())
>>> log_spectrogram = spectrogram.log()
>>> plt.imshow(log_spectrogram[0], cmap="gray")
tensor(0.) tensor(6.0842)

我们刚刚制作的频谱图的值介于 0.0000000000001 和 1 之间,大多数值接近该范围的低端。这对于可视化或建模来说并不理想 - 事实上,我们必须取这些值的对数才能获得显示任何细节的灰度图。出于这个原因,我们通常使用一种称为梅尔频谱图的特殊类型的频谱图,它通过对信号的不同频率分量应用一些变换来设计捕捉对人类听觉重要的信息。

torchaudio 文档图表 来自 torchaudio 文档 的一些音频变换

幸运的是,我们甚至不必过多地担心这些转换 - 管道的 mel 功能为我们处理了这些细节。使用它,我们可以将频谱图图像转换为音频,如下所示

a = pipe.mel.image_to_audio(output.images[0])
a.shape

我们可以通过首先加载原始音频数据,然后调用 audio_slice_to_image() 函数将音频数据数组转换为频谱图图像。较长的剪辑会自动切分成适当长度的块以生成 256x256 的频谱图图像

>>> pipe.mel.load_audio(raw_audio=a)
>>> im = pipe.mel.audio_slice_to_image(0)
>>> im

音频表示为一个长数字数组。为了将其大声播放,我们需要另一条关键信息:采样率。我们使用多少个样本(单个值)来表示一秒钟的音频?

我们可以使用以下方法查看在此管道训练期间使用的采样率

sample_rate_pipeline = pipe.mel.get_sample_rate()
sample_rate_pipeline

如果我们错误地指定了采样率,则会得到加速或减速的音频

display(Audio(output.audios[0], rate=44100))  # 2x speed

微调管道

现在我们已经对管道的运行方式有了大致的了解,让我们在一些新的音频数据上对其进行微调!

该数据集是不同类型的音频剪辑的集合,我们可以像这样从 Hub 加载它

from datasets import load_dataset

dataset = load_dataset("lewtun/music_genres", split="train")
dataset

您可以使用下面的代码查看数据集中不同的流派以及每个流派包含多少样本。

>>> for g in list(set(dataset["genre"])):
...     print(g, sum(x == g for x in dataset["genre"]))
Pop 945
Blues 58
Punk 2582
Old-Time / Historic 408
Experimental 1800
Folk 1214
Electronic 3071
Spoken 94
Classical 495
Country 142
Instrumental 1044
Chiptune / Glitch 1181
International 814
Ambient Electronic 796
Jazz 306
Soul-RnB 94
Hip-Hop 1757
Easy Listening 13
Rock 3095

该数据集以数组的形式存储音频。

>>> audio_array = dataset[0]["audio"]["array"]
>>> sample_rate_dataset = dataset[0]["audio"]["sampling_rate"]
>>> print("Audio array shape:", audio_array.shape)
>>> print("Sample rate:", sample_rate_dataset)
>>> display(Audio(audio_array, rate=sample_rate_dataset))
Audio array shape: (1323119,)
Sample rate: 44100

请注意,此音频的采样率较高 - 如果我们想使用现有的管道,则需要将其“重采样”以匹配。剪辑也比管道设置的剪辑更长。幸运的是,当我们使用pipe.mel加载音频时,它会自动将剪辑切成较小的片段。

>>> a = dataset[0]["audio"]["array"]  # Get the audio array
>>> pipe.mel.load_audio(raw_audio=a)  # Load it with pipe.mel
>>> pipe.mel.audio_slice_to_image(0)  # View the first 'slice' as a spectrogram

我们需要记住调整采样率,因为来自此数据集的数据每秒的样本数是原来的两倍。

sample_rate_dataset = dataset[0]["audio"]["sampling_rate"]
sample_rate_dataset

这里我们使用torchaudio的变换(导入为AT)进行重采样,使用管道的mel将音频转换为图像,并使用torchvision的变换(导入为IT)将图像转换为张量。这给了我们一个函数,可以将音频剪辑转换为我们可以用于训练的频谱图张量。

resampler = AT.Resample(sample_rate_dataset, sample_rate_pipeline, dtype=torch.float32)
to_t = IT.ToTensor()


def to_image(audio_array):
    audio_tensor = torch.tensor(audio_array).to(torch.float32)
    audio_tensor = resampler(audio_tensor)
    pipe.mel.load_audio(raw_audio=np.array(audio_tensor))
    num_slices = pipe.mel.get_number_of_slices()
    slice_idx = random.randint(0, num_slices - 1)  # Pic a random slice each time (excluding the last short slice)
    im = pipe.mel.audio_slice_to_image(slice_idx)
    return im

我们将使用我们的to_image()函数作为自定义collate函数的一部分,将我们的数据集转换为我们可以用于训练的数据加载器。collate函数定义了如何将数据集的一批示例转换为最终的、准备用于训练的数据批次。在本例中,我们将每个音频样本转换为频谱图图像,并将生成的张量堆叠在一起。

>>> def collate_fn(examples):
...     # to image -> to tensor -> rescale to (-1, 1) -> stack into batch
...     audio_ims = [to_t(to_image(x["audio"]["array"])) * 2 - 1 for x in examples]
...     return torch.stack(audio_ims)


>>> # Create a dataset with only the 'Chiptune / Glitch' genre of songs
>>> batch_size = 4  # 4 on colab, 12 on A100
>>> chosen_genre = "Electronic"  # <<< Try training on different genres <<<
>>> indexes = [i for i, g in enumerate(dataset["genre"]) if g == chosen_genre]
>>> filtered_dataset = dataset.select(indexes)
>>> dl = torch.utils.data.DataLoader(
...     filtered_dataset.shuffle(), batch_size=batch_size, collate_fn=collate_fn, shuffle=True
... )
>>> batch = next(iter(dl))
>>> print(batch.shape)
torch.Size([4, 1, 256, 256])

注意:除非您有足够的GPU vRAM可用,否则您需要使用较小的批次大小(例如,4)。

训练循环

这是一个简单的训练循环,它在数据加载器中运行几个周期以微调管道的UNet。您也可以跳过此单元格,并使用以下单元格中的代码加载管道。

epochs = 3
lr = 1e-4

pipe.unet.train()
pipe.scheduler.set_timesteps(1000)
optimizer = torch.optim.AdamW(pipe.unet.parameters(), lr=lr)

for epoch in range(epochs):
    for step, batch in tqdm(enumerate(dl), total=len(dl)):

        # Prepare the input images
        clean_images = batch.to(device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(0, pipe.scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        noisy_images = pipe.scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction
        noise_pred = pipe.unet(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward(loss)

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()
# OR: Load the version I trained earlier
pipe = DiffusionPipeline.from_pretrained("johnowhitaker/Electronic_test").to(device)
>>> output = pipe()
>>> display(output.images[0])
>>> display(Audio(output.audios[0], rate=22050))
>>> # Make a longer sample by passing in a starting noise tensor with a different shape
>>> noise = torch.randn(1, 1, pipe.unet.sample_size[0], pipe.unet.sample_size[1] * 4).to(device)
>>> output = pipe(noise=noise)
>>> display(output.images[0])
>>> display(Audio(output.audios[0], rate=22050))

输出的声音不是最棒的,但这只是一个开始:) 探索调整学习率和周期数,并在Discord上分享您的最佳结果,以便我们共同改进!

一些需要考虑的事情

  • 我们正在使用256px正方形的频谱图图像,这限制了我们的批次大小。您可以从128x128的频谱图中恢复足够质量的音频吗?
  • 代替随机图像增强,我们每次都选择音频剪辑的不同片段,但这在训练多个周期时可以通过一些不同的增强类型来改进吗?
  • 我们还可以如何利用它来生成更长的剪辑?也许您可以生成一个5秒的起始剪辑,然后使用受图像修复启发的想法来继续生成后续的音频片段……
  • 在此频谱图扩散上下文中,图像到图像的等价物是什么?

推送到Hub

当您对模型感到满意时,您可以保存它并将其推送到Hub,供其他人使用。

from huggingface_hub import get_full_repo_name, HfApi, create_repo, ModelCard
# Pick a name for the model
model_name = "audio-diffusion-electronic"
hub_model_id = get_full_repo_name(model_name)
# Save the pipeline locally
pipe.save_pretrained(model_name)
>>> # Inspect the folder contents
>>> !ls {model_name}
mel  model_index.json  scheduler  unet
# Create a repository
create_repo(hub_model_id)
# Upload the files
api = HfApi()
api.upload_folder(folder_path=f"{model_name}/scheduler", path_in_repo="scheduler", repo_id=hub_model_id)
api.upload_folder(folder_path=f"{model_name}/mel", path_in_repo="mel", repo_id=hub_model_id)
api.upload_folder(folder_path=f"{model_name}/unet", path_in_repo="unet", repo_id=hub_model_id)
api.upload_file(
    path_or_fileobj=f"{model_name}/model_index.json",
    path_in_repo="model_index.json",
    repo_id=hub_model_id,
)
# Push a model card
content = f"""
---
license: mit
tags:
- pytorch
- diffusers
- unconditional-audio-generation
- diffusion-models-class
---

# Model Card for Unit 4 of the [Diffusion Models Class 🧨](https://github.com/huggingface/diffusion-models-class)

This model is a diffusion model for unconditional audio generation of music in the genre {chosen_genre}

## Usage

<pre>
from IPython.display import Audio
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("{hub_model_id}")
output = pipe()
display(output.images[0])
display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate()))
</pre>
"""

card = ModelCard(content)
card.push_to_hub(hub_model_id)

结论

希望本笔记本让您对音频生成的潜力有了一点了解。查看本单元介绍中链接的一些参考文献,以了解一些更高级的方法以及它们可以创建的惊人样本!