Transformers.js文档

utils/generation

Hugging Face's logo
加入Hugging Face社区

并获得增强文档体验访问权限

开始使用

utils/generation

生成相关的类、函数和实用工具。

待办事项

  • 描述如何创建自定义的 GenerationConfig

utils/generation.LogitsProcessorList ⇐ <code> 可调用 </code>

表示一组 logits 处理器的类。logits 处理器是修改语言模型 logits 输出的函数。此类提供了添加新处理器以及将所有处理器应用于 logits 批处理的方法。

类型:utils/generation 模块的静态类
扩展Callable


new LogitsProcessorList()

构建一个新的 LogitsProcessorList 实例。


logitsProcessorList.push(item)

向列表添加新的 logits 处理器。

类型LogitsProcessorList 实例方法

参数类型描述
itemLogitsProcessor

要添加的 logits 处理器函数


logitsProcessorList.extend(items)

向列表添加多个 logits 处理器。

类型LogitsProcessorList 实例方法

参数类型描述
itemsArray.

Logits处理器用于添加功能。


logitsProcessorList._call(input_ids, batchedLogits)

将列表中的所有logits处理器应用到一批logits上,就地修改。

类型LogitsProcessorList 实例方法

参数类型描述
input_idsArray.

语言模型的输入ID。

batchedLogitsArray.>

一个2维的logits数组,其中每一行对应于批次中单一输入序列。


utils/generation.LogitsProcessor ⇐ Callable

处理logits的基类。

类型:utils/generation 模块的静态类
扩展Callable


logitsProcessor._call(input_ids, logits)

将处理器应用到输入的logits上。

类型LogitsProcessor 的实例抽象方法
抛出:

  • Error 如果子类中没有实现 `_call`,将抛出错误。
参数类型描述
input_idsArray

输入ID。

logitsTensor

要处理的logits。


utils/generation.ForceTokensLogitsProcessor ⇐ <code> LogitsProcessor </code>

强制解码器生成特定令牌的令牌处理器。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new ForceTokensLogitsProcessor(forced_decoder_ids)

构建一个新的 ForceTokensLogitsProcessor 实例。

参数类型描述
forced_decoder_idsArray

应强制生成的令牌的 ID。


forceTokensLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

将处理器应用到输入的logits上。

类型: ForceTokensLogitsProcessor 的实例方法
返回值: Tensor - 处理后的 logits。

参数类型描述
input_idsArray

输入ID。

logitsTensor

要处理的logits。


utils/generation.ForcedBOSTokenLogitsProcessor ⇐ <code> LogitsProcessor </code>

一种LogitsProcessor,它强制在生成的序列开头添加一个BOS标记。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new ForcedBOSTokenLogitsProcessor(bos_token_id)

创建一个ForcedBOSTokenLogitsProcessor。

参数类型描述
bos_token_id数字

要强制添加的序列开头标记的ID。


forcedBOSTokenLogitsProcessor._call(input_ids, logits) ⇒ <code> 对象 </code>

将BOS标记强制应用到logits上。

类型: ForcedBOSTokenLogitsProcessor 的实例方法
返回: 对象 - 应用了BOS标记强制的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.ForcedEOSTokenLogitsProcessor ⇐ <code> LogitsProcessor </code>

一种logits处理器,将序列结束标记的概率强制设置为1。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)

创建一个ForcedEOSTokenLogitsProcessor。

参数类型描述
max_length数字

序列的最大长度。

forced_eos_token_idnumber | Array<number>

强制结束序列标记的ID。


forcedEOSTokenLogitsProcessor._call(input_ids, logits)

将处理程序应用于input_ids和logits。

类型ForcedEOSTokenLogitsProcessor的实例方法

参数类型描述
input_idsArray.

输入ID。

logitsTensor

logits张量。


utils/generation.SuppressTokensAtBeginLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个LogitsProcessor,在开始生成时使用begin_index标记抑制一系列标记。这应该确保不会在生成的开始采样由begin_suppress_tokens定义的标记。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)

创建一个SuppressTokensAtBeginLogitsProcessor。

参数类型描述
begin_suppress_tokensArray.

要抑制的标记的ID。

begin_index数字

在抑制标记之前生成的标记数量。


suppressTokensAtBeginLogitsProcessor._call(input_ids, logits) ⇒ <code> 对象 </code>

将BOS标记强制应用到logits上。

类型: SuppressTokensAtBeginLogitsProcessor 的实例方法
返回: 对象 - 应用了BOS标记强制的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.WhisperTimeStampLogitsProcessor ⇐ <code> LogitsProcessor </code>

处理向生成的文本添加时间戳的 LogitsProcessor

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new WhisperTimeStampLogitsProcessor(generate_config)

构建一个新的 WhisperTimeStampLogitsProcessor

参数类型描述
generate_config对象

传递给变换器模型 generate() 方法的配置对象

generate_config.eos_token_id数字

序列结束标记的 ID

generate_config.no_timestamps_token_id数字

用于表示标记不应有时间戳的标记 ID

[generate_config.forced_decoder_ids]Array.>

表示必须出现在输出中的解码器 ID 的二维数组。每个数组中的第二个元素指示是否为时间戳。

[generate_config.max_initial_timestamp_index]数字

初始时间戳可以出现的最大索引。


whisperTimestampLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

修改logits以处理时间戳标记。

类型: WhisperTimeStampLogitsProcessor 的实例方法
返回: Tensor - 经过修改的logits。

参数类型描述
input_idsArray

标记输入序列。

logitsTensor

模型的输出logits。


utils/generation/NoRepeatNGramLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个不允许某些规模的n-gram重复的logits处理器。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)

创建一个NoRepeatNGramLogitsProcessor对象。

参数类型描述
no_repeat_ngram_size数字

不重复n-gram的大小。此大小范围内的所有n-gram只能出现一次。


noRepeatNGramLogitsProcessor.getNgrams(prevInputIds) ⇒ <code> Map. < string, Array < number > > </code>

从序列的标记ID生成n-gram。

类型: NoRepeatNGramLogitsProcessor类的实例方法
返回值: Map.<string, Array<number>> - 生成的n-gram的映射

参数类型描述
prevInputIdsArray.

先前输入ID的列表


noRepeatNGramLogitsProcessor.getGeneratedNgrams(bannedNgrams, prevInputIds) ⇒ <code> Array. < number > </code>

从序列的标记ID生成n-gram。

类型: NoRepeatNGramLogitsProcessor类的实例方法
返回值: Array.<number> - 生成n-gram的映射

参数类型描述
bannedNgramsMap.<string, Array<number>>

禁止的n-gram的映射

prevInputIdsArray.

先前输入ID的列表


noRepeatNGramLogitsProcessor.calcBannedNgramTokens(prevInputIds) ⇒ <code> Array. < number > </code>

计算禁止的n-gram标记

类型: NoRepeatNGramLogitsProcessor类的实例方法
返回值: Array.<number> - 生成n-gram的映射

参数类型描述
prevInputIdsArray.

先前输入ID的列表


noRepeatNGramLogitsProcessor._call(input_ids, logits) ⇒ <code> 对象 </code>

将对logits应用无重复ngram处理器。

类型: NoRepeatNGramLogitsProcessor类的实例方法
返回值: 对象 - 经过无重复ngram处理的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.RepetitionPenaltyLogitsProcessor ⇐ <code> LogitsProcessor </code>

一种惩罚重复输出标记的logits处理器。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new RepetitionPenaltyLogitsProcessor(penalty)

创建一个RepetitionPenaltyLogitsProcessor。

参数类型描述
penalty数字

应用于重复标记的惩罚。


repetitionPenaltyLogitsProcessor._call(input_ids, logits) ⇒ <code> 对象 </code>

将对logits应用重复惩罚。

类型: RepetitionPenaltyLogitsProcessor 的实例方法
返回: Object - 经过重复惩罚处理后得到的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.MinLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制执行最小token数量的logits处理器。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new MinLengthLogitsProcessor(min_length, eos_token_id)

创建MinLengthLogitsProcessor实例。

参数类型描述
min_length数字

设置eos_token_id的分数为负无穷大的最小长度。

eos_token_idnumber | Array<number>

结束序列token的ID/ID们。


minLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用logit处理器。

类型: MinLengthLogitsProcessor的实例方法
返回: Object - 处理后的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.MinNewTokensLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一个强制执行最小新标记数的目标 logits 处理器。

类型:utils/generation 模块的静态类
扩展: LogitsProcessor


new MinNewTokensLengthLogitsProcessor(prompt_length_to_skip, min_new_tokens, eos_token_id)

创建 MinNewTokensLengthLogitsProcessor。

参数类型描述
prompt_length_to_skip数字

输入标记长度。

min_new_tokens数字

设置 eos_token_id 的得分为负无穷大的最小新 标记 长度。

eos_token_idnumber | Array<number>

结束序列token的ID/ID们。


minNewTokensLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用logit处理器。

类型: MinNewTokensLengthLogitsProcessor 的实例方法
返回: Object - 处理后的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.NoBadWordsLogitsProcessor

类型:utils/generation 模块的静态类


new NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)

创建 NoBadWordsLogitsProcessor 对象。

参数类型描述
bad_words_idsArray.>

不允许生成的 token id 的列表。

eos_token_idnumber | Array<number>

序列结束标记 token 的 id。可选地,可以使用列表来设置多个序列结束标记 token。


noBadWordsLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

应用logit处理器。

类型: NoBadWordsLogitsProcessor 的实例方法
返回: Object - 处理后的logits。

参数类型描述
input_idsArray

输入ID。

logits对象

logits。


utils/generation.Sampler

Sampler 是所有用于文本生成的采样方法的基类。

类型:utils/generation 模块的静态类


new Sampler(generation_config)

使用指定的生成配置创建一个新的 Sampler 对象。

参数类型描述
generation_config生成配置类型

生成配置。


sampler._call(logits, index) ⇒ <code> void </code>

执行采样器,使用指定的logits。

类型:类 Sampler 的实例方法

参数类型
logitsTensor
index数字

sampler.sample(logits, index)

采样logits的抽象方法。

类型:类 Sampler 的实例方法
抛出:

  • 错误
参数类型
logitsTensor
index数字

sampler.getLogits(logits, index) ⇒ <code> Float32Array </code>

返回应用了温度的指定logits数组。

类型:类 Sampler 的实例方法

参数类型
logitsTensor
index数字

sampler.randomSelect(probabilities) ⇒ <code> number </code>

根据指定的概率随机选择一个项目。

类型:类 Sampler 的实例方法
返回: number - 被选中项目的索引。

参数类型描述
probabilitiesArray

用于选择的概率数组。


Sampler.getSampler(generation_config) ⇒ <code> Sampler </code>

根据指定的选项返回一个Sampler对象。

类型Sampler 的静态方法
返回值Sampler - 一个Sampler对象。

参数类型描述
generation_config生成配置类型

包含采样器选项的对象。


utils/generation.GenerationConfig : <code> * </code>

包含生成任务的配置的类。

类型utils/generation 的静态常量


utils/generation~GenerationConfig

类型utils/generation 的内部类


new GenerationConfig(kwargs)

创建一个新的GenerationConfig对象。

参数类型
kwargs生成配置类型

utils/generation~GreedySampler ⇐ <code> 样本器 </code>

表示贪婪样本器的类。

类型utils/generation 的内部类
扩展: Sample


greedySampler.sample(logits, [index]) ⇒ <code> 数组 </code>

采样给定logits张量的最大概率。

类型: GreedySampler 的实例方法
返回值: 数组 - 包含最大值的索引和无意义的分数(因为这是一个贪婪搜索)的唯一元组所需的数组。

参数类型默认值
logitsTensor
[index]数字-1

utils/generation~MultinomialSampler ⇐ <code> 样本器 </code>

表示多项式样本器的类。

类型utils/generation 的内部类
扩展: Sample


multinomialSampler.sample(logits, index) ⇒ <code> 数组 </code>

从logits中抽取样本。

类型:类 MultinomialSampler 的实例方法

参数类型
logitsTensor
index数字

utils/generation~BeamSearchSampler ⇐ <code> 样本器 </code>

表示BeamSearchSampler的类。

类型utils/generation 的内部类
扩展: Sample


beamSearchSampler.sample(logits, index) ⇒ <code> 数组 </code>

从logits中抽取样本。

类型:类 BeamSearchSampler 的实例方法

参数类型
logitsTensor
index数字

utils/generation~GenerationConfigType : <code> 对象 </code>

默认配置参数。

类型: utils/generation 的内部typedef
属性

名称类型默认值描述
[最大长度]数字20

生成令牌的最大长度。对应于输入提示的长度 + max_new_tokens。如果也设置了,则其效果会被 max_new_tokens 覆盖。

[最大新令牌数]数字

要生成的最大令牌数量,忽略提示中的令牌数量。

[最小长度]数字0

要生成的序列的最小长度。对应于输入提示的长度 + min_new_tokens。如果也设置了,则其效果会被 min_new_tokens 覆盖。

[最小新令牌数]数字

要生成的最小令牌数量,忽略提示中的令牌数量。

[提前停止]布尔值 | "never"false

控制基于束方法(如束搜索)的停止条件。它接受以下值

  • true,表示一旦有 num_beams 个完整的候选者,生成将停止;
  • false,其中应用了一种启发式算法,当很难找到更好的候选者时,生成将停止;
  • "never",其中束搜索过程仅在找不到更好的候选者时停止(规范束搜索算法)。
[最大时间]数字

允许计算运行的最大时间(以秒为单位)。即使在分配的时间已通过后,生成还将完成当前的遍历。

[do_sample]布尔值false

是否使用采样;否则使用贪婪解码。

[束数]数字1

束搜索的束数。1 表示没有束搜索。

[束组数]数字1

num_beams 分成多少组,以确保不同组的束之间的多样性。有关更多详细信息,请参阅 该论文

[惩罚_alpha]数字

在对比搜索解码中,这些值的平衡可以平衡模型信心和退化惩罚。

[使用缓存]布尔值true

模型是否应该使用过去最后的关键/值注意力(如果适用于模型)来加速解码。

[温度]数字1.0

用来模调整个概率的值。

[top_k]数字50

用于 top-k 过滤的最高概率词汇令牌数量。

[top_p]数字1.0

如果设置为小于 1 的浮点数,则只保留生成中最可能的最小集合的令牌,其概率之和为 top_p 或更高。

[典型_p]数字1.0

局部典型性衡量预测下一目标的条件概率与给已生成的部分文本期望的条件概率预测的随机令牌的相似程度。如果设置为小于 1 的浮点数,则保留最小的最典型令牌集,其概率之和为 typical_p 或更高。有关更多详细信息,请参阅 该论文

[epsilon_cutoff]数字0.0

如果严格设置为介于0和1之间的浮点数,则只会采样条件概率大于epsilon_cutoff的标记。在论文中,建议的值从3e-4到9e-4不等,具体取决于模型的大小。更多详情请参阅“截断采样作为语言模型反平滑”

[eta_cutoff]数字0.0

Eta采样是局部典型采样和epsilon采样的混合体。如果严格设置为介于0和1之间的浮点数,则只有当标记大于eta_cutoffsqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))时,才会考虑标记。后一个术语直观上是下一个标记的期望概率,经过sqrt(eta_cutoff)缩放。在论文中,建议的值从3e-4到2e-3不等,具体取决于模型的大小。更多详情请参阅“截断采样作为语言模型反平滑”

[diversity_penalty]数字0.0

如果生成与特定时间点上其他组中任何beam的标记相同的标记,则从beam的得分中减去此值。请注意,只有当启用group beam search时,diversity_penalty才有效。

[repetition_penalty]数字1.0

重复惩罚的参数。1.0表示没有惩罚。更多详情请参阅这篇论文

[encoder_repetition_penalty]数字1.0

encoder_repetition_penalty的参数。对原始输入中不存在的序列的指数惩罚。1.0表示没有惩罚。

[length_penalty]数字1.0

使用基于beam的生成所使用的长度的指数惩罚。它作为序列长度的指数,然后用来除以序列的得分。因为得分是序列的对数似然(即负的),所以当length_penalty>0.0会促进较长序列,而当length_penalty<0.0会鼓励较短的序列。

[no_repeat_ngram_size]数字0

如果设置为大于0的整数值,则该大小的所有ngram只能出现一次。

[bad_words_ids]Array.>

不允许生成的标记的标记id列表。为了获取不应出现在生成的文本中的单词的标记id,请使用(await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids

[force_words_ids]Array<Array<number>> | Array<Array<Array<number>>>

必须生成的标记id的列表。如果提供一个number[][],则它被视为必须包含的简单单词列表,与bad_words_ids相反。如果提供一个number[][][],这会触发一个析取约束,可以允许每个单词的不同形式。

[renormalize_logits]布尔值false

是否在应用所有logits处理器或wrappers(包括自定义的)后重新规范化logits。强烈建议将此标志设置为true,因为搜索算法假设得分logits已经归一化,但某些logits处理器或wrappers可能会破坏归一化。

[constraints]Array.<Object>

可以添加到生成中的自定义约束,以确保输出将以尽可能合理的方式包含特定标记的使用,如由Constraint对象定义所述。

[forced_bos_token_id]数字

强制作为在decoder_start_token_id之后首次生成的标记的标记id。对于mBART等多语言模型非常有用,其中首先生成的标记需要是目标语言标记。

[forced_eos_token_id]number | Array<number>

在达到max_length时强制作为最后一个生成的标记的标记id。可选地,可以使用列表来设置多个end-of-sequence标记。

[remove_invalid_values]布尔值false

是否删除模型可能产生的naninf输出,以防止生成方法崩溃。请注意,使用remove_invalid_values可能会减慢生成速度。

[exponential_decay_length_penalty]Array.

这个元组在生成了一定数量的标记后添加一个指数级增加的长度惩罚。该元组应包含:(start_index, decay_factor),其中start_index表示惩罚开始位置,decay_factor表示指数衰减的因素。

[suppress_tokens]Array.

要在生成过程中抑制的标记列表。SupressTokens对数概率处理器会将它们的对数概率设置为-inf,以便它们不会被采样。

[begin_suppress_tokens]Array.

要在生成开始时抑制的标记列表。SupressBeginTokens对数概率处理器会将它们的对数概率设置为-inf,以便它们不会被采样。

[forced_decoder_ids]Array.>

一个整数对的列表,指示在采样之前将从生成索引到标记索引的映射。例如,[[1, 123]]表示第二个生成的标记总是标记索引为123的标记。

[num_return_sequences]数字1

为批处理中的每个元素返回独立计算出的序列数。

[output_attentions]布尔值false

是否返回所有注意力的张量。有关更多详细信息,请参阅返回的张量中的attentions

[output_hidden_states]布尔值false

是否返回所有层的隐藏状态。有关更多详细信息,请参阅返回的张量中的hidden_states

[output_scores]布尔值false

是否返回预测分数。有关更多详细信息,请参阅返回的张量中的scores

[return_dict_in_generate]布尔值false

是否返回一个ModelOutput而不是一个普通的元组。

[pad_token_id]数字

填充标记的ID。

[bos_token_id]数字

序列开始标记的ID。

[eos_token_id]number | Array<number>

序列结束标记 token 的 id。可选地,可以使用列表来设置多个序列结束标记 token。

[encoder_no_repeat_ngram_size]数字0

如果设置为大于0的整数,则在encoder_input_ids中出现的该大小ngram不能出现在decoder_input_ids中。

[decoder_start_token_id]数字

如果编码器-解码器模型以不同的标记开始解码而不是bos,则该标记的ID。

[generation_kwargs]对象{}

额外的生成kwargs将被转发到模型的generate函数。不在generate签名的kwargs将在模型前向传递时使用。


< > 更新GitHub上的信息