在本地运行您自定义的 LoRA 精调 MusicGen Large

社区文章 发布于 2024 年 12 月 6 日

使用 LoRA(低秩适应)微调模型(如 MusicGen Large)创建自定义 AI 生成音乐是利用 AI 创造潜力的强大方式。本指南将引导您在本地运行自己的微调模型,微调自己的 LoRA 模型,并将其部署到 API 以实现更广泛的访问。或者,您可以使用 Hugging Face 等平台上的预训练 LoRA 适配器。


1. 设置您的环境

在开始之前,请确保您具备以下条件:

  • CUDA 兼容 GPU 以进行加速。
  • 已安装 Python 3.8 或更高版本。
  • 所需的 Python 库:torchtransformerspeftsoundfilefastapiuvicorn

使用 pip 安装所需的库:

pip install torch transformers peft soundfile fastapi uvicorn

2. 微调您自己的 LoRA 模型

如果您想微调自己的 LoRA 适配器,请遵循以下步骤:

  1. 准备您的数据集:

    • 收集高质量音频及相应文本提示。
    • 预处理数据以对齐音频-文本对。
  2. 使用 PEFT 框架:

    • 使用 PEFT 库训练 LoRA 适配器。
    • 将适配器配置和权重保存在目录中。
  3. 测试您的微调模型:

    • 将微调后的适配器加载到基础模型中进行评估。

有关详细说明,请参阅 PEFT 文档


3. 使用 Hugging Face 上的预训练 LoRA

如果您更喜欢使用预训练的 LoRA 适配器:

  1. 下载适配器:访问 Hugging Face 并搜索与 MusicGen 兼容的 LoRA。下载或克隆存储库。

  2. 设置本地仓库路径:更新代码中的 local_repo_path,使其指向包含您的 LoRA 适配器文件的目录。


4. 在本地运行模型

下面是一个 Python 脚本,用于在本地运行您自定义的 LoRA 精调 MusicGen Large 模型并将其部署到 API。


代码

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForTextToWaveform, AutoProcessor
import soundfile as sf
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import os
import time
import gc
from contextlib import asynccontextmanager

class MusicRequest(BaseModel):
    prompt: str
    duration: int

# Configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
local_repo_path = "/path/to/your/lora"  # Update this path
base_model_name = "facebook/musicgen-large"
model, processor = None, None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, processor
    try:
        # Load LoRA configuration
        adapter_config_path = os.path.join(local_repo_path, "adapter_config.json")
        if os.path.exists(adapter_config_path):
            adapter_config = PeftConfig.from_pretrained(local_repo_path)
            base_model_name_or_path = adapter_config.base_model_name_or_path
            base_model = AutoModelForTextToWaveform.from_pretrained(
                base_model_name_or_path,
                torch_dtype=torch.float16,
                local_files_only=True
            )
            model = PeftModel.from_pretrained(base_model, local_repo_path, local_files_only=True).to(device)
        else:
            model = AutoModelForTextToWaveform.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16,
                local_files_only=True
            ).to(device)

        processor = AutoProcessor.from_pretrained(
            local_repo_path if os.path.exists(adapter_config_path) else base_model_name
        )
        yield
    finally:
        del model
        torch.cuda.empty_cache()
        gc.collect()

app = FastAPI(lifespan=lifespan)

@app.post("/generate-music/")
async def generate_music(request: MusicRequest):
    global model, processor
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    if request.duration <= 0:
        raise HTTPException(status_code=400, detail="Invalid duration")
    
    # Preprocess input prompt
    processed_prompt = f"Genre: upbeat; Description: {request.prompt}"
    
    inputs = processor(text=[processed_prompt], return_tensors="pt").to(device)
    max_new_tokens = int(request.duration * 50)  # Approximate tokens per second
    audio_values = model.generate(
        **inputs,
        do_sample=True,
        guidance_scale=3,
        max_new_tokens=max_new_tokens
    )
    
    # Post-process audio: normalize and save
    sampling_rate = model.config.audio_encoder.sampling_rate
    output_path = f"song_{int(time.time())}.wav"
    audio_values_normalized = (audio_values[0].cpu().numpy() / abs(audio_values[0].cpu().numpy()).max()) * 0.9  # Normalize
    sf.write(output_path, audio_values_normalized, sampling_rate)
    return {"song": output_path}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

5. 部署您的 API

  1. 运行脚本:

    python script.py
    
  2. 测试 API:使用 Postman 或 curl 等工具。

    curl -X POST "https://:8000/generate-music/" -H "Content-Type: application/json" -d '{"prompt": "upbeat jazz", "duration": 10}'
    

6. 提高音乐质量的建议

1. 预处理提示

  • 使用结构化和描述性的提示来为模型提供清晰的指令(例如,包括流派、节奏或乐器)。

2. 后处理音频

  • 规范化音频输出以消除削波并平衡音量。
  • 使用 librosa 等库应用低通滤波器或去噪音频。

3. 使用高质量数据进行微调

  • 策划与您的目标音乐风格相符的数据集。
  • 微调时,侧重于高保真、低噪声的录音。

4. 试验参数

  • 调整 guidance_scale 以控制创造性与连贯性。
  • 增加 max_new_tokens 以获得更复杂的输出。

5. 验证音频质量

  • 使用信噪比 (SNR) 等音频指标,确保生成的音乐达到质量阈值。

6. 添加风格控制

  • 微调或使用 LoRA 适配器,以允许控制特定风格或乐器,从而获得更多样化的输出。

这些改进将确保更清晰、更精致的音乐生成,同时为用户提供更大的创作控制权。

社区

注册登录 发表评论