Transformers 文档

文本生成

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

文本生成

文本生成是大型语言模型(LLM)最受欢迎的应用。LLM 经过训练,可以根据一些初始文本(提示)以及其自身生成的输出,生成下一个单词(token),直到达到预定义的长度或达到序列结束(EOS)token。

在 Transformers 中,generate() API 处理文本生成,它适用于所有具有生成能力的模型。

本指南将向您展示使用 generate() 进行文本生成的基础知识,以及一些需要避免的常见陷阱。

默认生成

在开始之前,建议安装 bitsandbytes,以量化非常大的模型,从而减少其内存使用量。

!pip install -U transformers bitsandbytes

除了基于 CUDA 的 GPU 之外,Bitsandbytes 还支持多个后端。请参阅多后端安装指南以了解更多信息。

使用 from_pretrained() 加载 LLM,并添加以下两个参数以减少内存需求。

  • device_map="auto" 启用 Accelerate 的 Big Model Inference 功能,用于自动初始化模型骨架,并在所有可用设备(从最快的设备 (GPU) 开始)上加载和分派模型权重。
  • quantization_config 是一个配置对象,用于定义量化设置。此示例使用 bitsandbytes 作为量化后端(有关更多可用的后端,请参阅量化部分),并在 4 位中加载模型。
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", quantization_config=quantization_config)

对您的输入进行分词,并将 padding_side() 参数设置为 "left",因为 LLM 没有经过训练以从填充 token 继续生成。分词器返回输入 ID 和注意力掩码。

通过将字符串列表传递给分词器,一次处理多个提示。批量处理输入可以提高吞吐量,但会略微增加延迟和内存成本。

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")

将输入传递给 generate() 以生成 token,并使用 batch_decode() 将生成的 token 解码回文本。

generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
"A list of colors: red, blue, green, yellow, orange, purple, pink,"

生成配置

所有生成设置都包含在 GenerationConfig 中。在上面的示例中,生成设置来源于 mistralai/Mistral-7B-v0.1generation_config.json 文件。当模型没有保存任何配置时,将使用默认的解码策略。

通过 generation_config 属性检查配置。它仅显示与默认配置不同的值,在本例中为 bos_token_ideos_token_id

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
model.generation_config
GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}

您可以通过覆盖 GenerationConfig 中的参数和值来自定义 generate()。一些最常调整的参数是 max_new_tokensnum_beamsdo_samplenum_return_sequences

# enable beam search sampling strategy
model.generate(**inputs, num_beams=4, do_sample=True)

generate() 还可以使用外部库或自定义代码进行扩展。logits_processor 参数接受自定义的 LogitsProcessor 实例,用于操作下一个 token 的概率分布。stopping_criteria 支持自定义的 StoppingCriteria 以停止文本生成。查看 logits-processor-zoo 以获取更多与 generate() 兼容的扩展示例。

请参阅生成策略指南,以了解有关搜索、采样和解码策略的更多信息。

保存

创建 GenerationConfig 的实例,并指定您想要的解码参数。

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
generation_config = GenerationConfig(
    max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)

使用 save_pretrained() 保存特定的生成配置,并将 push_to_hub 参数设置为 True 以将其上传到 Hub。

generation_config.save_pretrained("my_account/my_model", push_to_hub=True)

config_file_name 参数留空。当在单个目录中存储多个生成配置时,应使用此参数。它为您提供了一种指定要加载哪个生成配置的方法。您可以为不同的生成任务(使用采样的创意文本生成,使用集束搜索的摘要)创建不同的配置,以用于单个模型。

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")

translation_generation_config = GenerationConfig(
    num_beams=4,
    early_stopping=True,
    decoder_start_token_id=0,
    eos_token_id=model.config.eos_token_id,
    pad_token=model.config.pad_token_id,
)

translation_generation_config.save_pretrained("/tmp", config_file_name="translation_generation_config.json", push_to_hub=True)

generation_config = GenerationConfig.from_pretrained("/tmp", config_file_name="translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

陷阱

以下部分介绍您在文本生成期间可能遇到的一些常见问题以及如何解决它们。

输出长度

默认情况下,generate() 最多返回 20 个 token,除非在模型的 GenerationConfig 中另有指定。强烈建议手动使用 max_new_tokens 参数设置生成的 token 数量,以控制输出长度。仅解码器模型会返回初始提示以及生成的 token。

model_inputs = tokenizer(["A sequence of numbers: 1, 2"], return_tensors="pt").to("cuda")
默认长度
max_new_tokens
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
'A sequence of numbers: 1, 2, 3, 4, 5'

解码策略

除非在模型的 GenerationConfig 中另有指定,否则 generate() 中的默认解码策略是贪婪搜索,它选择下一个最有可能的 token。虽然这种解码策略对于输入相关的任务(转录、翻译)效果良好,但对于更具创造性的用例(故事写作、聊天应用程序)而言,它并非最佳选择。

例如,启用 多项式采样策略以生成更多样化的输出。有关更多解码策略,请参阅生成策略指南。

model_inputs = tokenizer(["I am a cat."], return_tensors="pt").to("cuda")
贪婪搜索
多项式采样
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

填充侧

如果输入长度不同,则需要进行填充。但是 LLM 没有经过训练以从填充 token 继续生成,这意味着 padding_side() 参数需要设置为输入的左侧。

右侧填充
左侧填充
model_inputs = tokenizer(
    ["1, 2, 3", "A, B, C, D, E"], padding=True, return_tensors="pt"
).to("cuda")
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
'1, 2, 33333333333'

提示格式

某些模型和任务需要特定的输入提示格式,如果格式不正确,模型将返回次优的输出。您可以在提示工程指南中了解有关提示的更多信息。

例如,聊天模型期望输入为聊天模板。您的提示应包括 rolecontent 以指示谁参与了对话。如果您尝试将提示作为单个字符串传递,则模型并不总是返回预期的输出。

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha", device_map="auto", load_in_4bit=True
)
无格式
聊天模板
prompt = """How many cats does it take to change a light bulb? Reply as a pirate."""
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
input_length = model_inputs.input_ids.shape[1]
generated_ids = model.generate(**model_inputs, max_new_tokens=50)
print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])
"Aye, matey! 'Tis a simple task for a cat with a keen eye and nimble paws. First, the cat will climb up the ladder, carefully avoiding the rickety rungs. Then, with"

资源

查看下面的一些更具体和专业的文本生成库。

  • Optimum:Transformers 的扩展,专注于优化特定硬件设备上的训练和推理
  • Outlines:用于约束文本生成的库(例如,生成 JSON 文件)。
  • SynCode:用于上下文无关文法引导生成的库(JSON、SQL、Python)。
  • Text Generation Inference:用于 LLM 的生产就绪服务器。
  • Text generation web UI:用于文本生成的 Gradio Web UI。
  • logits-processor-zoo:用于控制文本生成的其他 logits 处理器。
< > 在 GitHub 上更新