Transformers 文档
文本生成
并获得增强的文档体验
开始使用
文本生成
文本生成是大型语言模型(LLM)最受欢迎的应用。LLM 经过训练,可以根据初始文本(提示)以及自身生成的输出,生成下一个单词(标记),直到达到预定义长度或遇到序列结束(EOS
)标记。
在 Transformers 中,generate() API 处理文本生成,它适用于所有具有生成能力的模型。本指南将向您展示使用 generate() 进行文本生成的基础知识以及一些常见的常见陷阱。
您还可以直接从命令行与模型聊天。(参考)
transformers chat Qwen/Qwen2.5-0.5B-Instruct
默认生成
在开始之前,安装 bitsandbytes 以量化大型模型以减少其内存使用量会很有帮助。
!pip install -U transformers bitsandbytes
Bitsandbytes 除了基于 CUDA 的 GPU 外,还支持多种后端。请参阅多后端安装指南以了解更多信息。
使用 from_pretrained() 加载 LLM 并添加以下两个参数以减少内存需求。
device_map="auto"
启用 Accelerate 的大型模型推理功能,用于自动初始化模型骨架并在所有可用设备上加载和分派模型权重,从最快的设备(GPU)开始。quantization_config
是一个配置对象,定义了量化设置。此示例使用 bitsandbytes 作为量化后端(有关更多可用后端,请参阅量化部分),它以4 位加载模型。
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", quantization_config=quantization_config)
对您的输入进行标记化,并将 padding_side()
参数设置为 "left"
,因为 LLM 未经训练以从填充标记继续生成。标记器返回输入 ID 和注意力掩码。
通过向标记器传递字符串列表,一次处理多个提示。批量输入以在延迟和内存方面付出少量代价来提高吞吐量。
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")
将输入传递给 generate() 以生成标记,并使用 batch_decode() 将生成的标记解码回文本。
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
"A list of colors: red, blue, green, yellow, orange, purple, pink,"
生成配置
所有生成设置都包含在 GenerationConfig 中。在上面的示例中,生成设置来自 mistralai/Mistral-7B-v0.1 的 generation_config.json
文件。如果模型没有保存配置,则使用默认解码策略。
通过 generation_config
属性检查配置。它只显示与默认配置不同的值,在此示例中是 bos_token_id
和 eos_token_id
。
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
model.generation_config
GenerationConfig {
"bos_token_id": 1,
"eos_token_id": 2
}
您可以通过覆盖 GenerationConfig 中的参数和值来自定义 generate()。有关常用调整参数,请参阅下面的此部分。
# enable beam search sampling strategy
model.generate(**inputs, num_beams=4, do_sample=True)
generate() 还可以通过外部库或自定义代码进行扩展
logits_processor
参数接受自定义 LogitsProcessor 实例,用于操纵下一个标记概率分布;stopping_criteria
参数支持自定义 StoppingCriteria 以停止文本生成;- 其他自定义生成方法可以通过
custom_generate
标志加载(文档)。
请参阅生成策略指南,了解有关搜索、采样和解码策略的更多信息。
保存
创建 GenerationConfig 实例并指定所需的解码参数。
from transformers import AutoModelForCausalLM, GenerationConfig
model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
generation_config = GenerationConfig(
max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
使用 save_pretrained() 保存特定的生成配置,并将 push_to_hub
参数设置为 True
以将其上传到 Hub。
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
将 config_file_name
参数留空。当在一个目录中存储多个生成配置时,应使用此参数。它为您提供了一种指定要加载哪个生成配置的方式。您可以为不同的生成任务(使用采样的创意文本生成,使用束搜索的摘要)创建不同的配置,以供单个模型使用。
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
translation_generation_config = GenerationConfig(
num_beams=4,
early_stopping=True,
decoder_start_token_id=0,
eos_token_id=model.config.eos_token_id,
pad_token=model.config.pad_token_id,
)
translation_generation_config.save_pretrained("/tmp", config_file_name="translation_generation_config.json", push_to_hub=True)
generation_config = GenerationConfig.from_pretrained("/tmp", config_file_name="translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
常用选项
generate() 是一个功能强大的工具,可以高度自定义。这对于新用户来说可能令人生畏。本节列出了 Transformers 中大多数文本生成工具中可以定义的流行生成选项:generate()、GenerationConfig、pipelines
、chat
CLI 等。
选项名称 | 类型 | 简化描述 |
---|---|---|
max_new_tokens | int | 控制最大生成长度。务必定义它,因为它通常默认为一个较小的值。 |
do_sample | 布尔值 | 定义生成是采样下一个标记(True ),还是贪婪生成(False )。大多数用例应将此标志设置为 True 。有关更多信息,请查看此指南。 |
temperature | 浮点数 | 下一个选定标记的不可预测性。高值(>0.8 )适用于创意任务,低值(例如 <0.4 )适用于需要“思考”的任务。需要 do_sample=True 。 |
num_beams | int | 当设置为 >1 时,激活束搜索算法。束搜索在基于输入的任务中效果良好。有关更多信息,请查看此指南。 |
repetition_penalty | 浮点数 | 如果模型经常重复自身,请将其设置为 >1.0 。值越大,惩罚越大。 |
eos_token_id | list[int] | 将导致生成停止的标记。默认值通常很好,但您可以指定不同的标记。 |
常见陷阱
以下部分介绍了一些您在文本生成过程中可能遇到的常见问题以及如何解决它们。
输出长度
除非在模型的 GenerationConfig 中另有指定,否则 generate() 默认最多返回 20 个标记。强烈建议使用 max_new_tokens
参数手动设置生成的标记数量以控制输出长度。仅解码器模型返回初始提示以及生成的标记。
model_inputs = tokenizer(["A sequence of numbers: 1, 2"], return_tensors="pt").to("cuda")
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
'A sequence of numbers: 1, 2, 3, 4, 5'
解码策略
除非在模型的 GenerationConfig 中另有指定,否则 generate() 中的默认解码策略是*贪婪搜索*,它选择下一个最可能的标记。虽然此解码策略适用于基于输入的任务(转录、翻译),但它不适用于更具创意性的用例(故事创作、聊天应用程序)。
例如,启用多项式采样策略以生成更多样化的输出。请参阅生成策略指南,了解更多解码策略。
model_inputs = tokenizer(["I am a cat."], return_tensors="pt").to("cuda")
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
填充侧
如果输入长度不一致,则需要填充。但是 LLM 未经训练以从填充标记继续生成,这意味着 padding_side()
参数需要设置为输入的左侧。
model_inputs = tokenizer(
["1, 2, 3", "A, B, C, D, E"], padding=True, return_tensors="pt"
).to("cuda")
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
'1, 2, 33333333333'
提示格式
某些模型和任务需要特定的输入提示格式,如果格式不正确,模型将返回次优输出。您可以在提示工程指南中了解有关提示的更多信息。
例如,聊天模型将输入视为聊天模板。您的提示应包含 role
和 content
以指示对话的参与者。如果您尝试将提示作为单个字符串传递,模型不总是返回预期的输出。
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceH4/zephyr-7b-alpha", device_map="auto", load_in_4bit=True
)
prompt = """How many cats does it take to change a light bulb? Reply as a pirate."""
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
input_length = model_inputs.input_ids.shape[1]
generated_ids = model.generate(**model_inputs, max_new_tokens=50)
print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])
"Aye, matey! 'Tis a simple task for a cat with a keen eye and nimble paws. First, the cat will climb up the ladder, carefully avoiding the rickety rungs. Then, with"
资源
请看下面一些更具体和专业的文本生成库。
- Optimum:Transformers 的扩展,专注于优化特定硬件设备上的训练和推理
- Outlines:一个用于受限文本生成的库(例如生成 JSON 文件)。
- SynCode:一个用于上下文无关语法引导生成的库(JSON、SQL、Python)。
- Text Generation Inference:一个用于 LLM 的生产就绪服务器。
- Text generation web UI:一个用于文本生成的 Gradio Web UI。
- logits-processor-zoo:用于控制文本生成的附加 logits 处理器。