Transformers.js 文档

utils/generation

您正在查看的是需要从源码安装。如果您想进行常规 npm 安装,请查看最新的稳定版本 (v3.0.0)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

utils/generation

用于生成任务的类、函数和实用工具。

待办

  • 描述如何创建自定义 GenerationConfig

utils/generation.LogitsProcessorList ⇐ <code> Callable </code>

一个表示 logits 处理器列表的类。logits 处理器是修改语言模型输出 logits 的函数。该类提供了添加新处理器并将所有处理器应用于一批 logits 的方法。

类型utils/generation 的静态类
继承Callable


new LogitsProcessorList()

构造 LogitsProcessorList 的新实例。


logitsProcessorList.push(item)

向列表添加新的 logits 处理器。

类型LogitsProcessorList 的实例方法

参数量类型描述
itemLogitsProcessor

要添加的 logits 处理器函数。


logitsProcessorList.extend(items)

向列表添加多个 logits 处理器。

类型LogitsProcessorList 的实例方法

参数量类型描述
itemsArray.<LogitsProcessor>

要添加的 logits 处理器函数。


logitsProcessorList._call(input_ids, batchedLogits)

将列表中所有的 logits 处理器应用于一批 logits,并就地修改它们。

类型LogitsProcessorList 的实例方法

参数量类型描述
input_idsArray.<number>

语言模型的输入 ID。

batchedLogitsArray.<Array<number>>

一个二维 logits 数组,其中每行对应于批处理中的单个输入序列。


utils/generation.LogitsProcessor ⇐ <code> Callable </code>

logits 处理器的基类。

类型utils/generation 的静态类
继承Callable


logitsProcessor._call(input_ids, logits)

将处理器应用于输入 logits。

类型LogitsProcessor 的实例抽象方法
抛出:

  • Error 如果子类未实现 `_call`,则抛出错误。
参数量类型描述
input_ids数组

输入 ID。

logits张量

要处理的 logits。


utils/generation.ForceTokensLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制解码器生成特定 token 的 logits 处理器。

类型utils/generation 的静态类
继承LogitsProcessor


new ForceTokensLogitsProcessor(forced_decoder_ids)

构造 ForceTokensLogitsProcessor 的新实例。

参数量类型描述
forced_decoder_ids数组

要强制生成的 token ID。


forceTokensLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

将处理器应用于输入 logits。

类型ForceTokensLogitsProcessor 的实例方法
返回Tensor - 处理后的 logits。

参数量类型描述
input_ids数组

输入 ID。

logits张量

要处理的 logits。


utils/generation.ForcedBOSTokenLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个 LogitsProcessor,它强制在生成的序列开头添加 BOS token。

类型utils/generation 的静态类
继承LogitsProcessor


new ForcedBOSTokenLogitsProcessor(bos_token_id)

创建一个 ForcedBOSTokenLogitsProcessor。

参数量类型描述
bos_token_id数字

要强制使用的序列开始 token 的 ID。


forcedBOSTokenLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

将 BOS token 强制应用于 logits。

类型ForcedBOSTokenLogitsProcessor 的实例方法
返回Object - 强制使用 BOS token 的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.ForcedEOSTokenLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个 logits 处理器,它将序列结束 token 的概率强制设置为 1。

类型utils/generation 的静态类
继承LogitsProcessor


new ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)

创建一个 ForcedEOSTokenLogitsProcessor。

参数量类型描述
max_length数字

序列的最大长度。

forced_eos_token_idnumber | Array<number>

要强制使用的序列结束 token 的 ID。


forcedEOSTokenLogitsProcessor._call(input_ids, logits)

将处理器应用于 input_ids 和 logits。

类型ForcedEOSTokenLogitsProcessor 的实例方法

参数量类型描述
input_idsArray.<number>

输入 ID。

logits张量

logits 张量。


utils/generation.SuppressTokensAtBeginLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个 LogitsProcessor,它在 generate 函数开始使用 begin_index 个 token 生成时立即抑制 token 列表。这应确保在生成开始时不会采样由 begin_suppress_tokens 定义的 token。

类型utils/generation 的静态类
继承LogitsProcessor


new SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)

创建一个 SuppressTokensAtBeginLogitsProcessor。

参数量类型描述
begin_suppress_tokensArray.<number>

要抑制的 token ID。

begin_index数字

在抑制 token 之前要生成的 token 数量。


suppressTokensAtBeginLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

将 BOS token 强制应用于 logits。

类型SuppressTokensAtBeginLogitsProcessor 的实例方法
返回Object - 强制使用 BOS token 的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.WhisperTimeStampLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个 LogitsProcessor,用于处理向生成的文本添加时间戳。

类型utils/generation 的静态类
继承LogitsProcessor


new WhisperTimeStampLogitsProcessor(generate_config)

构造一个新的 WhisperTimeStampLogitsProcessor。

参数量类型描述
generate_configObject

传递给 transformer 模型 generate() 方法的配置对象。

generate_config.eos_token_id数字

序列结束 token 的 ID。

generate_config.no_timestamps_token_id数字

用于指示 token 不应包含时间戳的 token ID。

[generate_config.forced_decoder_ids]Array.<Array<number>>

一个由两个元素组成的数组,表示强制出现在输出中的解码器 ID。每个数组的第二个元素指示该 token 是否为时间戳。

[generate_config.max_initial_timestamp_index]数字

初始时间戳可以出现的最大索引。


whisperTimeStampLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

修改 logits 以处理时间戳 token。

类型WhisperTimeStampLogitsProcessor 的实例方法
返回Tensor - 修改后的 logits。

参数量类型描述
input_ids数组

输入 token 序列。

logits张量

模型输出的 logits。


utils/generation.NoRepeatNGramLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个 logits 处理器,它不允许重复一定大小的 n-gram。

类型utils/generation 的静态类
继承LogitsProcessor


new NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)

创建一个 NoRepeatNGramLogitsProcessor。

参数量类型描述
no_repeat_ngram_size数字

不重复 n-gram 的大小。所有此大小的 n-gram 只能出现一次。


noRepeatNGramLogitsProcessor.getNgrams(prevInputIds) ⇒ <code> Map. < string, Array < number > > </code>

从 token ID 序列生成 n-gram。

类型NoRepeatNGramLogitsProcessor 的实例方法
返回Map.<string, Array<number>> - 生成的 n-gram 的映射

参数量类型描述
prevInputIdsArray.<number>

上一个输入 ID 列表


noRepeatNGramLogitsProcessor.getGeneratedNgrams(bannedNgrams, prevInputIds) ⇒ <code> Array. < number > </code>

从 token ID 序列生成 n-gram。

类型NoRepeatNGramLogitsProcessor 的实例方法
返回Array.<number> - 生成的 n-gram 的映射

参数量类型描述
bannedNgramsMap.<string, Array<number>>

禁用 n-gram 的映射

prevInputIdsArray.<number>

上一个输入 ID 列表


noRepeatNGramLogitsProcessor.calcBannedNgramTokens(prevInputIds) ⇒ <code> Array. < number > </code>

计算禁止的 n-gram token

类型NoRepeatNGramLogitsProcessor 的实例方法
返回Array.<number> - 生成的 n-gram 的映射

参数量类型描述
prevInputIdsArray.<number>

上一个输入 ID 列表


noRepeatNGramLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

将不重复 n-gram 处理器应用于 logits。

类型NoRepeatNGramLogitsProcessor 的实例方法
返回Object - 经过不重复 n-gram 处理的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.RepetitionPenaltyLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个对重复输出 token 进行惩罚的 logits 处理器。

类型utils/generation 的静态类
继承LogitsProcessor


new RepetitionPenaltyLogitsProcessor(penalty)

创建一个 RepetitionPenaltyLogitsProcessor。

参数量类型描述
penalty数字

对重复 token 应用的惩罚。


repetitionPenaltyLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

将重复惩罚应用于 logits。

类型RepetitionPenaltyLogitsProcessor 的实例方法
返回Object - 经过重复惩罚处理的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.MinLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制最小 token 数量的 logits 处理器。

类型utils/generation 的静态类
继承LogitsProcessor


new MinLengthLogitsProcessor(min_length, eos_token_id)

创建一个 MinLengthLogitsProcessor。

参数量类型描述
min_length数字

当长度低于此值时,eos_token_id 的分数将设置为负无穷大。

eos_token_idnumber | Array<number>

序列结束 token 的 ID。


minLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用 logits 处理器。

类型MinLengthLogitsProcessor 的实例方法
返回Object - 处理后的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.MinNewTokensLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制最小新 token 数量的 logits 处理器。

类型utils/generation 的静态类
继承LogitsProcessor


new MinNewTokensLengthLogitsProcessor(prompt_length_to_skip, min_new_tokens, eos_token_id)

创建一个 MinNewTokensLengthLogitsProcessor。

参数量类型描述
prompt_length_to_skip数字

输入 token 长度。

min_new_tokens数字

当新 token 长度低于此值时,eos_token_id 的分数将设置为负无穷大。

eos_token_idnumber | Array<number>

序列结束 token 的 ID。


minNewTokensLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用 logits 处理器。

类型MinNewTokensLengthLogitsProcessor 的实例方法
返回Object - 处理后的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.NoBadWordsLogitsProcessor

类型utils/generation 的静态类


new NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)

创建一个 NoBadWordsLogitsProcessor

参数量类型描述
bad_words_idsArray.<Array<number>>

不允许生成的 token ID 列表的列表。

eos_token_idnumber | Array<number>

“序列结束”标记的 ID。可选地,使用列表来设置多个“序列结束”标记。


noBadWordsLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用 logits 处理器。

类型NoBadWordsLogitsProcessor 的实例方法
返回Object - 处理后的 logits。

参数量类型描述
input_ids数组

输入 ID。

logitsObject

logits。


utils/generation.Sampler

Sampler 是所有用于文本生成的采样方法的基类。

类型utils/generation 的静态类


new Sampler(generation_config)

创建具有指定生成配置的新 Sampler 对象。

参数量类型描述
generation_configGenerationConfigType

生成配置。


sampler._call(logits, index) ⇒ <code> void </code>

执行采样器,使用指定的 logits。

类型Sampler 的实例方法

参数量类型
logits张量
索引数字

sampler.sample(logits, index)

用于采样 logits 的抽象方法。

类型Sampler 的实例方法
抛出:

  • 错误
参数量类型
logits张量
索引数字

sampler.getLogits(logits, index) ⇒ <code> Float32Array </code>

将指定的 logits 作为数组返回,并应用了温度。

类型Sampler 的实例方法

参数量类型
logits张量
索引数字

sampler.randomSelect(probabilities) ⇒ <code> number </code>

根据指定的概率随机选择一个项目。

类型Sampler 的实例方法
返回number - 所选项目的索引。

参数量类型描述
probabilities数组

用于选择的概率数组。


Sampler.getSampler(generation_config) ⇒ <code> Sampler </code>

根据指定选项返回一个 Sampler 对象。

类型Sampler 的静态方法Sampler
返回Sampler - 一个 Sampler 对象。

参数量类型描述
generation_configGenerationConfigType

包含采样器选项的对象。


utils/generation.GenerationConfig : <code> * </code>

保存生成任务配置的类。

类型utils/generation 的静态常量utils/generation


utils/generation~GenerationConfig

类型utils/generation 的内部类utils/generation


new GenerationConfig(kwargs)

创建一个新的 GenerationConfig 对象。

参数量类型
kwargsGenerationConfigType

utils/generation~GreedySampler ⇐ <code> Sampler </code>

表示贪婪采样器的类。

类型utils/generation 的内部类utils/generation
继承Sampler


greedySampler.sample(logits, [index]) ⇒ <code> Array </code>

对给定 logits 张量的最大概率进行采样。

类型GreedySampler 的实例方法GreedySampler
返回Array - 包含一个元组的数组,该元组包含最大值的索引和无意义的分数(因为这是贪婪搜索)。

参数量类型默认
logits张量
[index]数字-1

utils/generation~MultinomialSampler ⇐ <code> Sampler </code>

表示多项式采样器的类。

类型utils/generation 的内部类utils/generation
继承Sampler


multinomialSampler.sample(logits, index) ⇒ <code> Array </code>

从 logits 中采样。

类型MultinomialSampler 的实例方法MultinomialSampler

参数量类型
logits张量
索引数字

utils/generation~BeamSearchSampler ⇐ <code> Sampler </code>

表示 BeamSearchSampler 的类。

类型utils/generation 的内部类utils/generation
继承Sampler


beamSearchSampler.sample(logits, index) ⇒ <code> Array </code>

从 logits 中采样。

类型BeamSearchSampler 的实例方法BeamSearchSampler

参数量类型
logits张量
索引数字

utils/generation~GenerationConfigType : <code> Object </code>

默认配置参数。

类型utils/generation 的内部类型定义utils/generation
属性

名称类型默认描述
[max_length]数字20

生成的令牌可以具有的最大长度。对应于输入提示的长度 + max_new_tokens。如果也设置了 max_new_tokens,其效果将被覆盖。

[max_new_tokens]数字

要生成的最大令牌数,忽略提示中的令牌数。

[min_length]数字0

要生成的序列的最小长度。对应于输入提示的长度 + min_new_tokens。如果也设置了 min_new_tokens,其效果将被覆盖。

[min_new_tokens]数字

要生成的最小令牌数,忽略提示中的令牌数。

[early_stopping]boolean | "never"false

控制基于束的方法(如束搜索)的停止条件。它接受以下值

  • true,表示一旦有 num_beams 个完整候选,生成即停止;
  • false,表示应用启发式方法,当不太可能找到更好的候选时,生成停止;
  • "never",表示束搜索过程只有在无法找到更好候选时才停止(经典束搜索算法)。
[max_time]数字

允许计算运行的最大时间(秒)。生成将在分配时间过后完成当前轮次。

[do_sample]booleanfalse

是否使用采样;否则使用贪婪解码。

[num_beams]数字1

束搜索的束数。1 表示没有束搜索。

[num_beam_groups]数字1

num_beams 分成组的数量,以确保不同束组之间的多样性。有关更多详细信息,请参阅本文

[penalty_alpha]数字

这些值平衡了模型置信度和对比搜索解码中的退化惩罚。

[use_cache]booleantrue

模型是否应使用过去的最后键/值注意力(如果适用于模型)以加快解码速度。

[temperature]数字1.0

用于调节下一个 token 概率的值。

[top_k]数字50

保留用于 top-k 过滤的最高概率词汇 token 数量。

[top_p]数字1.0

如果设置为小于 1 的浮点数,则仅保留概率总和达到 top_p 或更高的最小概率令牌集进行生成。

[typical_p]数字1.0

局部典型性衡量预测下一个目标令牌的条件概率与预测下一个随机令牌的预期条件概率的相似程度,给定已生成的文本片段。如果设置为小于 1 的浮点数,则保留局部典型性最高且概率总和达到 typical_p 或更高的最小令牌集进行生成。有关更多详细信息,请参阅本文

[epsilon_cutoff]数字0.0

如果设置为严格介于 0 和 1 之间的浮点数,则只采样条件概率大于 epsilon_cutoff 的令牌。在论文中,建议的值范围从 3e-4 到 9e-4,具体取决于模型的规模。有关更多详细信息,请参阅《截断采样作为语言模型去平滑》

[eta_cutoff]数字0.0

Eta 采样是局部典型性采样和 epsilon 采样的混合。如果设置为严格介于 0 和 1 之间的浮点数,只有当令牌大于 eta_cutoffsqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))) 时才考虑该令牌。后者直观地是预期的下一个令牌概率,由 sqrt(eta_cutoff) 进行缩放。在论文中,建议的值范围从 3e-4 到 2e-3,具体取决于模型的规模。有关更多详细信息,请参阅《截断采样作为语言模型去平滑》

[diversity_penalty]数字0.0

如果一个束在某个特定时间生成的令牌与来自其他组的任何束相同,则从该束的分数中减去此值。请注意,diversity_penalty 仅在启用 group beam search 时才有效。

[repetition_penalty]数字1.0

重复惩罚的参数。1.0 表示没有惩罚。有关更多详细信息,请参阅本文

[encoder_repetition_penalty]数字1.0

encoder_repetition_penalty 的参数。对原始输入中不存在的序列施加指数惩罚。1.0 表示没有惩罚。

[length_penalty]数字1.0

对基于束的生成中使用的长度施加指数惩罚。它作为序列长度的指数应用,然后用于除以序列的分数。由于分数是序列的对数似然(即负数),length_penalty > 0.0 促进较长的序列,而 length_penalty < 0.0 鼓励较短的序列。

[no_repeat_ngram_size]数字0

如果设置为大于 0 的整数,则该大小的所有 ngrams 只能出现一次。

[bad_words_ids]Array.<Array<number>>

不允许生成的令牌 ID 列表。要获取不应出现在生成文本中的单词的令牌 ID,请使用 (await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids

[force_words_ids]Array<Array<number>> | Array<Array<Array<number>>>

必须生成的令牌 ID 列表。如果给定 number[][],则将其视为必须包含的简单单词列表,与 bad_words_ids 相反。如果给定 number[][][],则触发不相容约束,允许每种单词的不同形式。

[renormalize_logits]booleanfalse

在应用所有 logits 处理器或扭曲器(包括自定义的)之后是否重新归一化 logits。强烈建议将此标志设置为 true,因为搜索算法假定分数 logits 已归一化,但某些 logit 处理器或扭曲器会破坏归一化。

[constraints]Array.<Object>

可以添加到生成中的自定义约束,以确保输出将以最合理的方式包含由 Constraint 对象定义的某些令牌的使用。

[forced_bos_token_id]数字

decoder_start_token_id 之后强制作为第一个生成的令牌的令牌 ID。对于像 mBART 这样的多语言模型很有用,其中第一个生成的令牌需要是目标语言令牌。

[forced_eos_token_id]number | Array<number>

当达到 max_length 时强制作为最后一个生成的令牌的令牌 ID。可选地,使用列表设置多个*序列结束*令牌。

[remove_invalid_values]booleanfalse

是否删除模型可能产生的*NaN*和*inf*输出,以防止生成方法崩溃。请注意,使用 remove_invalid_values 可能会减慢生成速度。

[exponential_decay_length_penalty]Array.<number>

此元组在生成一定数量的令牌后添加一个指数增长的长度惩罚。该元组应包含:(start_index, decay_factor),其中 start_index 表示惩罚开始的位置,decay_factor 表示指数衰减的因子。

[suppress_tokens]Array.<number>

在生成时将被抑制的令牌列表。SupressTokens logit 处理器将它们的对数概率设置为 -inf,以便它们不被采样。

[begin_suppress_tokens]Array.<number>

在生成开始时将被抑制的令牌列表。SupressBeginTokens logit 处理器将它们的对数概率设置为 -inf,以便它们不被采样。

[forced_decoder_ids]Array.<Array<number>>

整数对的列表,表示在采样之前将被强制的生成索引到令牌索引的映射。例如,[[1, 123]] 表示第二个生成的令牌将始终是索引为 123 的令牌。

[num_return_sequences]数字1

批处理中每个元素独立计算的返回序列数。

[output_attentions]booleanfalse

是否返回所有注意力层的注意力张量。有关更多详细信息,请参阅返回张量下的 attentions

[output_hidden_states]booleanfalse

是否返回所有层的隐藏状态。有关更多详细信息,请参阅返回张量下的 hidden_states

[output_scores]booleanfalse

是否返回预测分数。有关更多详细信息,请参阅返回张量下的 scores

[return_dict_in_generate]booleanfalse

是否返回 ModelOutput 而不是普通元组。

[pad_token_id]数字

*填充*令牌的 ID。

[bos_token_id]数字

*序列开始*令牌的 ID。

[eos_token_id]number | Array<number>

“序列结束”标记的 ID。可选地,使用列表来设置多个“序列结束”标记。

[encoder_no_repeat_ngram_size]数字0

如果设置为大于 0 的整数,则 encoder_input_ids 中出现的所有相同大小的 ngrams 不能在 decoder_input_ids 中出现。

[decoder_start_token_id]数字

如果编码器-解码器模型使用与*bos*不同的令牌开始解码,则该令牌的 ID。

[generation_kwargs]Object{}

额外的生成 kwargs 将转发到模型的 generate 函数。不在 generate 签名中的 kwargs 将在模型的前向传播中使用。


< > 在 GitHub 上更新