Transformers.js 文档

utils/generation

您正在查看 main 版本,该版本需要从源代码安装。如果您想要常规的 npm 安装,请查看最新的稳定版本 (v3.0.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

utils/generation

用于生成的类、函数和实用工具。

待办事项

  • 描述如何创建自定义 GenerationConfig

utils/generation.LogitsProcessorList ⇐ Callable

表示 logits 处理器列表的类。logits 处理器是修改语言模型 logits 输出的函数。此类提供了添加新处理器和将所有处理器应用于一批 logits 的方法。

类型utils/generation 的静态类
继承自Callable


new LogitsProcessorList()

构造 LogitsProcessorList 的新实例。


logitsProcessorList.push(item)

向列表添加新的 logits 处理器。

类型LogitsProcessorList 的实例方法

参数类型描述
itemLogitsProcessor

要添加的 logits 处理器函数。


logitsProcessorList.extend(items)

向列表添加多个 logits 处理器。

类型LogitsProcessorList 的实例方法

参数类型描述
itemsArray.<LogitsProcessor>

要添加的 logits 处理器函数。


logitsProcessorList._call(input_ids, batchedLogits)

将列表中的所有 logits 处理器应用于一批 logits,就地修改它们。

类型LogitsProcessorList 的实例方法

参数类型描述
input_idsArray.<number>

语言模型的输入 ID。

batchedLogitsArray.<Array<number>>

logits 的二维数组,其中每行对应于批处理中的单个输入序列。


utils/generation.LogitsProcessor ⇐ Callable

用于处理 logits 的基类。

类型utils/generation 的静态类
继承自Callable


logitsProcessor._call(input_ids, logits)

将处理器应用于输入 logits。

类型LogitsProcessor 的实例抽象方法
抛出:

  • Error 如果子类中未实现 `_call`,则抛出错误。
参数类型描述
input_idsArray

输入 id。

logitsTensor

要处理的 logits。


utils/generation.ForceTokensLogitsProcessor ⇐ LogitsProcessor

一种 logits 处理器,它强制解码器生成特定的 token。

类型utils/generation 的静态类
继承自LogitsProcessor


new ForceTokensLogitsProcessor(forced_decoder_ids)

构造 ForceTokensLogitsProcessor 的新实例。

参数类型描述
forced_decoder_idsArray

应该强制执行的 token 的 ID。


forceTokensLogitsProcessor._call(input_ids, logits) ⇒ Tensor

将处理器应用于输入 logits。

类型ForceTokensLogitsProcessor 的实例方法
返回Tensor - 处理后的 logits。

参数类型描述
input_idsArray

输入 id。

logitsTensor

要处理的 logits。


utils/generation.ForcedBOSTokenLogitsProcessor ⇐ LogitsProcessor

一种 LogitsProcessor,它强制在生成的序列的开头添加 BOS token。

类型utils/generation 的静态类
继承自LogitsProcessor


new ForcedBOSTokenLogitsProcessor(bos_token_id)

创建 ForcedBOSTokenLogitsProcessor。

参数类型描述
bos_token_idnumber

要强制执行的序列开始 token 的 ID。


forcedBOSTokenLogitsProcessor._call(input_ids, logits) ⇒ Object

将 BOS token 强制应用于 logits。

类型ForcedBOSTokenLogitsProcessor 的实例方法
返回Object - 具有 BOS token 强制的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.ForcedEOSTokenLogitsProcessor ⇐ LogitsProcessor

一种 logits 处理器,它强制将序列结束 token 概率设为 1。

类型utils/generation 的静态类
继承自LogitsProcessor


new ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)

创建 ForcedEOSTokenLogitsProcessor。

参数类型描述
max_lengthnumber

序列的最大长度。

forced_eos_token_idnumber | Array.<number>

要强制执行的序列结束 token 的 ID。


forcedEOSTokenLogitsProcessor._call(input_ids, logits)

将处理器应用于 input_ids 和 logits。

类型ForcedEOSTokenLogitsProcessor 的实例方法

参数类型描述
input_idsArray.<number>

输入 id。

logitsTensor

logits tensor。


utils/generation.SuppressTokensAtBeginLogitsProcessor ⇐ LogitsProcessor

一种 LogitsProcessor,当 generate 函数开始使用 begin_index tokens 生成时,它会抑制 token 列表。这应确保由 begin_suppress_tokens 定义的 token 在生成开始时不会被采样。

类型utils/generation 的静态类
继承自LogitsProcessor


new SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)

创建 SuppressTokensAtBeginLogitsProcessor。

参数类型描述
begin_suppress_tokensArray.<number>

要抑制的 token 的 ID。

begin_indexnumber

在抑制 token 之前要生成的 token 数量。


suppressTokensAtBeginLogitsProcessor._call(input_ids, logits) ⇒ Object

将 BOS token 强制应用于 logits。

类型SuppressTokensAtBeginLogitsProcessor 的实例方法
返回Object - 具有 BOS token 强制的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.WhisperTimeStampLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个处理向生成的文本添加时间戳的 LogitsProcessor。

类型utils/generation 的静态类
继承自LogitsProcessor


new WhisperTimeStampLogitsProcessor(generate_config)

构造一个新的 WhisperTimeStampLogitsProcessor。

参数类型描述
generate_configObject

传递给 transformer 模型的 generate() 方法的配置对象。

generate_config.eos_token_idnumber

序列结束 (end-of-sequence) 标记的 ID。

generate_config.no_timestamps_token_idnumber

用于指示标记不应具有时间戳的标记 ID。

[generate_config.forced_decoder_ids]Array.<Array<number>>

一个由双元素数组组成的数组,表示强制出现在输出中的解码器 ID。每个数组的第二个元素指示标记是否为时间戳。

[generate_config.max_initial_timestamp_index]number

初始时间戳可以出现的最大索引。


whisperTimeStampLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

修改 logits 以处理时间戳标记。

类型WhisperTimeStampLogitsProcessor 的实例方法
返回值Tensor - 修改后的 logits。

参数类型描述
input_idsArray

标记的输入序列。

logitsTensor

模型输出的 logits。


utils/generation.NoRepeatNGramLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个 logits 处理器,禁止重复出现特定大小的 n-gram。

类型utils/generation 的静态类
继承自LogitsProcessor


new NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)

创建一个 NoRepeatNGramLogitsProcessor。

参数类型描述
no_repeat_ngram_sizenumber

no-repeat-ngram 大小。此大小的所有 n-gram 只能出现一次。


noRepeatNGramLogitsProcessor.getNgrams(prevInputIds) ⇒ <code> Map. < string, Array < number > > </code>

从标记 ID 序列生成 n-gram。

类型NoRepeatNGramLogitsProcessor 的实例方法
返回值Map.<string, Array<number>> - 生成的 n-gram 的 Map

参数类型描述
prevInputIdsArray.<number>

先前输入 ID 的列表


noRepeatNGramLogitsProcessor.getGeneratedNgrams(bannedNgrams, prevInputIds) ⇒ <code> Array. < number > </code>

从标记 ID 序列生成 n-gram。

类型NoRepeatNGramLogitsProcessor 的实例方法
返回值Array.<number> - 生成的 n-gram 的 Map

参数类型描述
bannedNgramsMap.<string, Array<number>>

被禁止的 n-gram 的 Map

prevInputIdsArray.<number>

先前输入 ID 的列表


noRepeatNGramLogitsProcessor.calcBannedNgramTokens(prevInputIds) ⇒ <code> Array. < number > </code>

计算被禁止的 n-gram 标记

类型NoRepeatNGramLogitsProcessor 的实例方法
返回值Array.<number> - 生成的 n-gram 的 Map

参数类型描述
prevInputIdsArray.<number>

先前输入 ID 的列表


noRepeatNGramLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

将 no-repeat-ngram 处理器应用于 logits。

类型NoRepeatNGramLogitsProcessor 的实例方法
返回值Object - 经过 no-repeat-ngram 处理的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.RepetitionPenaltyLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个惩罚重复输出标记的 logits 处理器。

类型utils/generation 的静态类
继承自LogitsProcessor


new RepetitionPenaltyLogitsProcessor(penalty)

创建一个 RepetitionPenaltyLogitsProcessor。

参数类型描述
penaltynumber

应用于重复标记的惩罚。


repetitionPenaltyLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

将重复惩罚应用于 logits。

类型RepetitionPenaltyLogitsProcessor 的实例方法
返回值Object - 经过重复惩罚处理的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.MinLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制执行最小标记数量的 logits 处理器。

类型utils/generation 的静态类
继承自LogitsProcessor


new MinLengthLogitsProcessor(min_length, eos_token_id)

创建一个 MinLengthLogitsProcessor。

参数类型描述
min_lengthnumber

低于此最小长度,eos_token_id 的分数将被设置为负无穷大。

eos_token_idnumber | Array.<number>

序列结束标记的 ID/IDs。


minLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用 logit 处理器。

类型MinLengthLogitsProcessor 的实例方法
返回值Object - 处理后的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.MinNewTokensLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制执行最小新标记数量的 logits 处理器。

类型utils/generation 的静态类
继承自LogitsProcessor


new MinNewTokensLengthLogitsProcessor(prompt_length_to_skip, min_new_tokens, eos_token_id)

创建一个 MinNewTokensLengthLogitsProcessor。

参数类型描述
prompt_length_to_skipnumber

输入标记长度。

min_new_tokensnumber

低于此最小标记长度,eos_token_id 的分数将被设置为负无穷大。

eos_token_idnumber | Array.<number>

序列结束标记的 ID/IDs。


minNewTokensLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用 logit 处理器。

类型MinNewTokensLengthLogitsProcessor 的实例方法
返回值Object - 处理后的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.NoBadWordsLogitsProcessor

类型utils/generation 的静态类


new NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)

创建一个 NoBadWordsLogitsProcessor

参数类型描述
bad_words_idsArray.<Array<number>>

不允许生成的标记 ID 列表的列表。

eos_token_idnumber | Array.<number>

序列结束标记的 ID。 可选地,使用列表来设置多个序列结束标记。


noBadWordsLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用 logit 处理器。

类型NoBadWordsLogitsProcessor 的实例方法
返回值Object - 处理后的 logits。

参数类型描述
input_idsArray

输入 ID。

logitsObject

logits。


utils/generation.Sampler

Sampler 是用于文本生成的所有采样方法的基础类。

类型utils/generation 的静态类


new Sampler(generation_config)

使用指定的生成配置创建一个新的 Sampler 对象。

参数类型描述
generation_configGenerationConfigType

生成配置。


sampler._call(logits, index) ⇒ <code> void </code>

执行 sampler,使用指定的 logits。

类型Sampler 的实例方法

参数类型
logitsTensor
indexnumber

sampler.sample(logits, index)

用于采样 logits 的抽象方法。

类型Sampler 的实例方法
抛出:

  • 错误
参数类型
logitsTensor
indexnumber

sampler.getLogits(logits, index) ⇒ <code> Float32Array </code>

返回指定的 logits 数组,并应用温度系数。

类型Sampler 的实例方法

参数类型
logitsTensor
indexnumber

sampler.randomSelect(probabilities) ⇒ <code> number </code>

根据指定的概率随机选择一个项目。

类型Sampler 的实例方法
返回值number - 所选项目的索引。

参数类型描述
probabilitiesArray

用于选择的概率数组。


Sampler.getSampler(generation_config) ⇒ <code> Sampler </code>

根据指定的选项返回一个 Sampler 对象。

类型Sampler 的静态方法
返回值Sampler - 一个 Sampler 对象。

参数类型描述
generation_configGenerationConfigType

包含 sampler 选项的对象。


utils/generation.GenerationConfig : <code> * </code>

用于保存生成任务配置的类。

类型utils/generation 的静态常量


utils/generation~GenerationConfig

类型utils/generation 的内部类


new GenerationConfig(kwargs)

创建一个新的 GenerationConfig 对象。

参数类型
kwargs(关键字参数)GenerationConfigType

utils/generation~GreedySampler ⇐ <code> Sampler </code>

表示贪婪采样器的类。

类型utils/generation 的内部类
继承自: Sampler


greedySampler.sample(logits, [index]) ⇒ <code> Array </code>

对给定的 logits 张量进行最大概率采样。

类型GreedySampler 的实例方法
返回值Array - 包含单个元组的数组,元组包含最大值的索引和一个无意义的分数(因为这是贪婪搜索)。

参数类型默认
logitsTensor
[index](索引)number-1

utils/generation~MultinomialSampler ⇐ <code> Sampler </code>

表示多项式采样器的类。

类型utils/generation 的内部类
继承自: Sampler


multinomialSampler.sample(logits, index) ⇒ <code> Array </code>

从 logits 中采样。

类型MultinomialSampler 的实例方法

参数类型
logitsTensor
indexnumber

utils/generation~BeamSearchSampler ⇐ <code> Sampler </code>

表示束搜索采样器的类。

类型utils/generation 的内部类
继承自: Sampler


beamSearchSampler.sample(logits, index) ⇒ <code> Array </code>

从 logits 中采样。

类型BeamSearchSampler 的实例方法

参数类型
logitsTensor
indexnumber

utils/generation~GenerationConfigType : <code> Object </code>

默认配置参数。

类型utils/generation 的内部类型定义
属性

名称类型默认描述
[max_length](最大长度)number20

生成的 token 可以具有的最大长度。对应于输入提示的长度 + max_new_tokens。如果也设置了 max_new_tokens,则其效果将被 max_new_tokens 覆盖。

[max_new_tokens](最大新 token 数)number

要生成的最大 token 数,忽略提示中的 token 数。

[min_length](最小长度)number0

要生成的序列的最小长度。对应于输入提示的长度 + min_new_tokens。如果也设置了 min_new_tokens,则其效果将被 min_new_tokens 覆盖。

[min_new_tokens](最小新 token 数)number

要生成的最小 token 数,忽略提示中的 token 数。

[early_stopping](早停)boolean | "never"false

控制基于 beam 的方法(如束搜索)的停止条件。它接受以下值:

  • true,当存在 num_beams 个完整候选项时,生成停止;
  • false,应用启发式方法,当找到更好候选项的可能性非常低时,生成停止;
  • "never",束搜索过程仅在无法找到更好候选项时停止(规范束搜索算法)。
[max_time](最大时间)number

允许计算运行的最长时间(秒)。即使超过分配的时间,生成仍将完成当前 pass。

[do_sample](是否采样)booleanfalse

是否使用采样;否则使用贪婪解码。

[num_beams](束数量)number1

束搜索的束数量。1 表示不进行束搜索。

[num_beam_groups](束组数量)number1

num_beams 分成组的数量,以确保不同束组之间的多样性。有关更多详细信息,请参阅本文

[penalty_alpha](惩罚 alpha)number

这些值平衡了模型的置信度和对比搜索解码中的退化惩罚。

[use_cache](使用缓存)booleantrue

模型是否应使用过去的最后 key/values attention(如果适用于模型)来加速解码。

[temperature](温度)number1.0

用于调节下一个 token 概率的值。

[top_k](Top-K)number50

为 top-k 过滤保留的最高概率词汇 token 的数量。

[top_p](Top-P)number1.0

如果设置为 float < 1,则仅保留概率之和达到 top_p 或更高的最小概率 token 集合以进行生成。

[typical_p](典型 P)number1.0

局部典型性衡量预测下一个目标 token 的条件概率与预测下一个随机 token 的预期条件概率(给定已生成的部分文本)的相似程度。如果设置为 float < 1,则仅保留概率之和达到 typical_p 或更高的最小局部典型 token 集合以进行生成。有关更多详细信息,请参阅本文

[epsilon_cutoff](epsilon 截断)number0.0

如果设置为严格介于 0 和 1 之间的 float,则仅采样条件概率大于 epsilon_cutoff 的 token。在论文中,建议值范围为 3e-4 到 9e-4,具体取决于模型的大小。有关更多详细信息,请参阅截断采样作为语言模型平滑

[eta_cutoff](eta 截断)number0.0

Eta 采样是局部典型性采样和 epsilon 采样的混合体。如果设置为严格介于 0 和 1 之间的 float,则仅当 token 大于 eta_cutoffsqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))) 时才考虑该 token。后一个术语在直观上是预期的下一个 token 概率,按 sqrt(eta_cutoff) 缩放。在论文中,建议值范围为 3e-4 到 2e-3,具体取决于模型的大小。有关更多详细信息,请参阅截断采样作为语言模型平滑

[diversity_penalty](多样性惩罚)number0.0

如果某个束生成的 token 与来自其他组的任何束在特定时间生成的 token 相同,则从该束的分数中减去此值。请注意,仅当启用 group beam search (分组束搜索)时,diversity_penalty 才有效。

[repetition_penalty](重复惩罚)number1.0

重复惩罚的参数。1.0 表示没有惩罚。有关更多详细信息,请参阅本文

[encoder_repetition_penalty](编码器重复惩罚)number1.0

encoder_repetition_penalty 的参数。对不在原始输入中的序列施加指数惩罚。1.0 表示没有惩罚。

[length_penalty](长度惩罚)number1.0

用于基于 beam 的生成的长度的指数惩罚。它作为指数应用于序列长度,而序列长度又用于除以序列的分数。由于分数是序列的对数似然(即负数),因此 length_penalty > 0.0 鼓励更长的序列,而 length_penalty < 0.0 鼓励更短的序列。

[no_repeat_ngram_size](无重复 N 元语法大小)number0

如果设置为 int > 0,则该大小的所有 n-gram 只能出现一次。

[bad_words_ids](禁用词 ID)Array.<Array<number>>

不允许生成的 token ID 列表。为了获得不应出现在生成的文本中的单词的 token ID,请使用 (await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids

[force_words_ids](强制词 ID)Array<Array<number>> | Array<Array<Array<number>>>

必须生成的 token ID 列表。如果给定 number[][],则将其视为必须包含的简单单词列表,与 bad_words_ids 相反。如果给定 number[][][],则会触发析取约束,其中可以允许每个单词的不同形式。

[renormalize_logits](重新归一化 logits)booleanfalse

在应用所有 logits 处理器或变形器(包括自定义的)之后是否重新归一化 logits。强烈建议将此标志设置为 true,因为搜索算法假设分数 logits 已归一化,但某些 logits 处理器或变形器会破坏归一化。

[constraints](约束)Array.<Object>

可以添加到生成的自定义约束,以确保输出将包含 Constraint 对象定义的某些 token 的使用,以最合理的方式。

[forced_bos_token_id](强制 BOS token ID)number

decoder_start_token_id 之后强制作为第一个生成的 token 的 token ID。对于像 mBART 这样的多语言模型很有用,其中第一个生成的 token 需要是目标语言 token。

[forced_eos_token_id](强制 EOS token ID)number | Array.<number>

当达到 max_length 时,强制作为最后一个生成的 token 的 token ID。或者,使用列表设置多个序列结束 token。

[remove_invalid_values](移除无效值)booleanfalse

是否删除模型可能产生的 naninf 输出,以防止生成方法崩溃。请注意,使用 remove_invalid_values 可能会减慢生成速度。

[exponential_decay_length_penalty](指数衰减长度惩罚)Array.<number>

此元组添加指数增长的长度惩罚,在生成一定数量的 token 后。该元组应包含:(start_index, decay_factor),其中 start_index 指示惩罚开始的位置,decay_factor 表示指数衰减的因子。

[suppress_tokens](抑制 token)Array.<number>

在生成时将被抑制的 token 列表。SupressTokens logits 处理器将它们的对数概率设置为 -inf,以便不采样它们。

[begin_suppress_tokens](开始抑制 token)Array.<number>

在生成开始时将被抑制的 token 列表。SupressBeginTokens logits 处理器将它们的对数概率设置为 -inf,以便不采样它们。

[forced_decoder_ids](强制解码器 ID)Array.<Array<number>>

整数对列表,指示从生成索引到 token 索引的映射,这些 token 将在采样之前强制执行。例如,[[1, 123]] 表示第二个生成的 token 将始终是索引为 123 的 token。

[num_return_sequences](返回序列数)number1

批次中每个元素独立计算的返回序列数。

[output_attentions](输出 attentions)booleanfalse

是否返回所有 attention 层的 attention 张量。有关更多详细信息,请参阅返回张量下的 attentions

[output_hidden_states](输出 hidden states)booleanfalse

是否返回所有层的 hidden states。有关更多详细信息,请参阅返回张量下的 hidden_states

[output_scores](输出 scores)booleanfalse

是否返回预测分数。有关更多详细信息,请参阅返回张量下的 scores

[return_dict_in_generate](在 generate 中返回 dict)booleanfalse

是否返回 ModelOutput 而不是普通元组。

[pad_token_id](填充 token ID)number

填充 token 的 ID。

[bos_token_id](BOS token ID)number

序列开始 token 的 ID。

[eos_token_id](EOS token ID)number | Array.<number>

序列结束标记的 ID。 可选地,使用列表来设置多个序列结束标记。

[encoder_no_repeat_ngram_size](编码器无重复 N 元语法大小)number0

如果设置为 int > 0,则在 encoder_input_ids 中出现的所有该大小的 n-gram 不能在 decoder_input_ids 中出现。

[decoder_start_token_id](解码器开始 token ID)number

如果编码器-解码器模型以与 bos 不同的 token 开始解码,则该 token 的 ID。

[generation_kwargs](生成 kwargs)Object{}

额外的生成 kwargs 将转发到模型的 generate 函数。generate 签名中不存在的 Kwargs 将在模型前向传递中使用。


< > 在 GitHub 上更新