通过掩码存储简化加速LLM代码生成

**强制上下文无关文法(CFG)约束的结构化文本生成技术**(Willard & Louf, 2023;Gerganov & et. al., 2024;Lundberg & Ribeiro, 2023;Geng et al., 2024;Beurer-Kellner et al., 2024;Ugare et al., 2024;Dong et al., 2024)在基于LLM的编码工具中特别有用,可以生成语法正确的计算机代码。这些技术保证了完全的符合性,但在推理时引入了计算开销。
**最小化推理开销**对于高效的开发人员体验至关重要,尤其是在实时建议生成的代码时。这具有挑战性,因为编程语言的CFG很复杂,而且CFG约束比正则表达式约束更难施加。
在这篇博客文章中,我提议**加速**_一些_**CFG约束的解码技术**。具体来说,我将:
- 简要描述这些技术的原理;
- 提供输入CFG中可用于减少推理开销的模式示例;
- 描述自动检测这些模式的算法;
- 呈现初步实验结果。
重现实验结果的代码和算法的证明在Python notebook和技术附录中提供。
CFG约束的LLM解码
本篇博客文章重点关注**CFG约束的解码技术**(Beurer-Kellner et al., 2024;Willard & Louf, 2023),这些技术利用:
- **基于自动机的词法分析器**,以确保生成的字符串可以转换为终端序列;
- **增量解析器**,以保证生成的终端序列符合语法。
在本节中,我将以Python语法为例,说明如何使用这两个组件。我只提供一个高层次的概述,不涉及次要的技术细节(忽略的终端、终端优先级、缩进)。我建议您阅读上面引用的论文,以及相关的Python notebook,以获得更全面的介绍。
Python语法在`lark`包中的定义包括:
- 100个**终端**(例如`DEF`、`NAME`、`LPAR`、`RPAR`),每个都由正则表达式描述(例如`DEF`为`def`或`NAME`为`[^\W\d]\w*]`);
- 176个**非终端**(例如`funcdef`、`parameters`、`test`);
- 536条**规则**,例如`funcdef: "def" name "(" parameters ")" "->" test ":" suite`或`name: NAME`。
生成可转换为终端序列的字符串
在使用增量解析器时,确保语法正确的Python代码生成的第一步是构建一个**非确定性有限自动机**(NFA),该自动机**识别可转换为终端序列的字符串**。为此,我们只需将每个终端的正则表达式对应的确定性有限自动机(DFA)与转换和两个额外的节点连接起来,如图1所示的几个终端。
由于LLM操作的是词元而非字符,我们需要将这个基于字符的NFA转换为等效的基于词元的NFA。这通过简单地将每个词元视为一个字符序列并应用相应的转换来完成(Beurer-Kellner et al., 2024;Willard & Louf, 2023)。例如,如果`()`是一个词元,则基于词元的NFA将包含从和到的`()`转换。
在遵循字符转换的同时,我们跟踪遍历的DFA,以用相应的新终端序列注释基于词元的NFA的转换。例如,`()`转换将被注释为`[LPAR, RPAR]`。换句话说,我们创建`nfa_transition`,一个基于词元的NFA的注释转换函数,从到,其中是NFA的状态集,是词元集,是终端集。
...
nfa_transition[q4][1270] = [(q4, ("DEF",)), (q6, ("NAME",))] # 1270 is the token id for `def`
...
nfa_transition[q4][470] = [(q10, ("LPAR", "RPAR"))] # 470 is the token id for `()`
...
# For the Python grammar, `nfa_transitions` is defined with 2,837,801 rows.
检查生成的终端序列是否符合语法
在解码过程的每一步,`nfa_transitions`都会显示哪些词元最终可以导致终端序列。当然,并非所有终端序列都语法有效。因此,我们需要使用一个**增量解析器**,只保留与CFG兼容的终端序列。在本篇博客文章的上下文中,我们可以将这样一个增量解析器视为一个具有内部状态`incremental_parser.state`和两个方法的对象`incremental_parser`:
- `incremental_parser.accepts(new_terminals) -> bool`:该方法接受一个终端序列作为输入,并根据到目前为止消耗的终端和`new_terminals`是否构成有效终端序列的前缀来返回`True`或`False`;
- `incremental_parser.consumes(new_terminals) -> None`:该方法使用`new_terminals`更新内部状态。
增量解析器用于检查添加某个新令牌后将添加的终端是否可接受。我们可以将导致相同附加终端的令牌组合在一起,只检查一次这个终端序列,如果结果为正,则将所有这些令牌标记为可接受,而不是逐个评估这种兼容性。根据(Ugare et al., 2024)和(Beurer-Kellner et al., 2024),我们可以构建一个`mask_store`函数,将映射到。
...
mask_store[q4][("NAME",)] = [0, ... 1, ..., 1, ..., 0]
# ↑ ↑ ↑ ↑
# 0 1,270 18,641 |V|-1
...
mask_store[q4][("LPAR", "RPAR")] = [0, ... 1, ..., 0]
# ↑ ↑ ↑
# 0 470 |V|-1
...
# For the Python grammar, `mask_store` is defined with 112,283 rows.
因此,整体的受约束解码算法可以用以下伪Python代码表示。
def generate(llm, prompt, grammar):
token_ids = tokenize(prompt)
new_token_id = None
# Initialize the states of the lexer and the incremental parser
lexer_states = [(0, [])] # list of (nfa_state, terminals) tuples
incremental_parser = IncrementalParser(grammar)
# Generate one token for each iteration of the while loop
while new_token_id != eos_token_id:
# Initialize an empty mask
mask = [0]*len(tokenizer_vocabulary)
# Filter the potential next tokens with the incremental parser
for nfa_state, terminals in lexer_states:
for new_terminals in mask_store[nfa_state]:
if incremental_parser.accepts(terminals + new_terminals):
mask = element_wise_binary_or(mask, mask[nfa_state][new_terminals])
# Sample a new token with the LLM and the mask
# Cf. algorithm 2 of arxiv.org/abs/2307.09702
new_token_id = llm.sample(token_ids, mask)
token_ids.append(new_token_id)
# Update the states of the lexer and the incremental parser
new_states = [
(new_nfa_state, terminals + new_terminals)
for new_nfa_state, new_terminals in nfa_transition[nfa_state][token_id]
]
length_common_prefix = get_length_common_prefix(
[terminals for new_nfa_state, terminals in new_states[new_token_id]]
)
incremental_parser.consumes(length_common_prefix)
lexer_states = [
(new_nfa_state, terminals[len(length_common_prefix):])
for new_nfa_state, terminals in new_states[new_token_id]
]
return token_ids
用于加速解码的感兴趣模式
约束解码比正则表达式约束慢得多的原因是,在每个解码步骤中,**增量解析器可能为当前NFA状态的每个掩码调用一次**,如上面的双`for`循环所示。下面,我们将看到如何显著减少对增量解析器的调用次数。
考虑图1中的状态。该状态的掩码存储包含1,036个终端序列,对应于`DEF`终端。然而,`DEF`在Python语法的EBNF形式化中只出现一次——`funcdef: "def" name "(" [parameters] ")" ["->" test] ":" suite`——而`name`被定义为`name: NAME | "match" | "case"`。因此,只有`NAME`(或`"match"`或`"case"`)终端可以跟在`DEF`之后。结果,`DEF`后面跟着`NAME`(或`"match"`或`"case"`)在`DEF`被接受时总是被接受,而`DEF`后面跟着`ASYNC`、`CLASS`或任何其他终端则永远不会被接受,无论当前的解析器状态如何。这意味着我们可以简化掩码存储,以避免对增量解析器进行不必要的调用。
...
mask_store[q4][()] = m1
mask_store[q4][("NAME",)] = m2
mask_store[q4][("MATCH",)] = m3
mask_store[q4][("CASE",)] = m4
mask_store[q4][("ASYNC",)] = m5
mask_store[q4][("CLASS",)] = m6
...
... 应该简化为
...
mask_store[q4][()] = element_wise_binary_or(m1, m2, m3, m4)
...
现在我们来看Python语法的另一个有趣的模式。`MINUS`和`PLUS`终端只出现在两条规则中:
!_unary_op: "+"|"-"|"~"
!_add_op: "+"|"-"
鉴于`MINUS`和`PLUS`在这些规则中可以互换,用一个替换另一个不会影响字符串的语法正确性。这是简化掩码存储的另一个机会,通过合并`PLUS`和`MINUS`的条目。例如:
...
mask_store[q6][("MINUS",)] = m1
mask_store[q6][("MINUS", "MINUS")] = m2
mask_store[q6][("PLUS",)] = m3
mask_store[q6][("PLUS", "PLUS")] = m4
...
... 应该简化为
...
mask_store[q6][("MINUS",)] = element_wise_binary_or(m1, m3)
mask_store[q6][("MINUS", "MINUS")] = element_wise_binary_or(m2, m4)
...
如果我们能访问以下函数,我们就可以系统地识别相似的模式:
...其中、和分别表示与语法对应的上下文无关语言、的前缀集和的终端集。
掩码存储可以根据这三个函数进行简化:
- 如果,我们可以将`mask_store(q, ())`替换为`termwise_binary_or(mask_store(q, ()), mask_store(q, S))`,并为对应于终端的所有NFA状态`q`删除`mask_store(q, S)`;
- 如果,我们可以为对应于终端的所有NFA状态`q`删除`mask_store(q, S)`;
- 如果,我们可以将`mask_store(q, S1)`替换为`termwise_binary_or(mask_store(q, S1), mask_store(q, S2))`,并为对应于终端的所有NFA状态`q`删除`mask_store(q, S2)`。
现在我们来研究是否真的有可能计算、和。
始终非法的后续
的定义意味着:
鉴于是上的正则语言,计算等价于确定上下文无关语言和正则语言的交集,并测试这个新的上下文无关语言是否为空。由于这些操作都有标准算法,这提供了一种直接获取的方法。
def is_never_legal(current_terminal, new_terminals, grammar):
"""
Return True if `new_terminals` can never follow `current_terminal` given the grammar
"""
# We create a regular expression defined over the set of terminals, that recognizes
# any sequence of terminals including `current_terminal` immediately followed by
# `new_terminals`.
regex = f".*{current_terminal}{''.join(new_terminals)}.*"
# The intersection of a context-free grammar and a regular language is a
# context-free grammar which can be efficiently computed with the Bar-Hillel
# construction. We can then test its emptiness with a standard CFG algorithm.
return is_empty(intersection(grammar, regex))
始终合法的后续
相反,是不可计算的,这已在本博客文章的技术附录中得到证明。因此,一个通用的算法是无法实现的,但我将在下面概述一种方法,以获取该函数的一些值。为简洁起见,我仅限于一个非常简单的示例,并着重传达主要直觉,而不对该方法进行形式化或证明其正确性。如果您对这些细节感兴趣,我邀请您阅读技术附录。
我们以对应于规则的上下文无关文法为例。相关的上下文无关语言显然是。如果我们想确定的一些值,了解哪些符号可以跟在另一个符号后面是很有用的。我们可以用有向图来描述这些关系,如图2所示。
然而,这样的有向图通常不足以得出关于值的结论。为了进一步深入,我们将这种表示扩展为不仅跟踪符号之间的后续关系,还跟踪这些关系产生的规则。更准确地说,我们使用图3所示的下推自动机,其中和作为初始状态和堆栈符号,作为唯一的接受状态。在这个下推自动机中,状态表示正在生成的终端或非终端符号,而堆栈符号表示G规则中的特定位置,并充当生成符号后的“返回地址”。
这个下推自动机对于我们的目的具有有趣的特性:
- 下推自动机接受的单词正是语法生成的单词;
- 一个被接受单词的有效路径对应于该单词语法树的深度优先遍历(参见图4);
- 一系列有效步骤总是可以完成,使得生成的路径对应于一个被接受的单词。
如果我们想确定的值,其中,**第一步**是从状态开始,堆栈为空,并列出通过下推自动机读取的所有路径。为此,我们遵循下推自动机的步进关系,但如果堆栈为空且需要弹出堆栈符号,我们仍然继续转换并记录缺失的堆栈符号。这就像我们在堆栈为空时拥有一条**无限信用额度来借用堆栈符号**。我们还在搜索期间定义了最大堆栈大小,以保证终止。
**第二步**是将所有这些路径表示为一个非确定性有限自动机(NFA),如图5所示:节点对应于下推自动机状态和相关堆栈的组合,起始节点是,接受节点是带任何堆栈的,转换是或沿路径识别出的缺失堆栈符号。
**第三步**,也是最后一步,是识别NFA的可达节点,即满足以下任一属性的节点:
- 该节点是一个最终节点,但不是初始节点;
- 从该节点到可达节点存在转换;
- 对于从该节点下推自动机状态可弹出的每个堆栈符号,从该节点到可达节点存在一个带有该堆栈符号的转换。
对于 的一个充分条件是 是 -coaccessible。图 5 显示了 、、 和 获得的 NFA。由于 是唯一可以从下推自动机中的 弹出的堆栈符号,我们可以得出结论:。对于 和 则不是这种情况,而且很容易证明 。
这种方法的直觉是,NFA 捕获了从 到 所需的堆栈符号。沿着这样的轨迹,上述定义的条件确保我们总能朝着 更进一步,原因可能是我们可以通过 转换向前移动,或者堆栈顶部的每个可能的堆栈符号都有一个相关的转换。
联合合法延续
与 一样, 也是不可计算的,因为 的值可以直接从 的值中导出。实际上,对于所有 :
但是,有一些简单的方法可以获取 的某些值。首先,我们可以使用 的已知值,例如 ,得出结论:
例如,使用上一节中描述的方法,我们可以证明对于 Python 语法,。因此:
此外,我们可以利用某些终结符在语法规则中可以互换的事实。例如,Python 语法中包含 和 终结符的唯一规则是
!_unary_op: "+"|"-"|"~"
!_add_op: "+"|"-"
这意味着,在不影响终结符序列的语法正确性的情况下,总是可以将一个 终结符替换为 终结符(反之亦然)。如果我们能确定终结符 和 是可互换的,我们可以得出结论:
...其中如果 ,则 且 。
对于 Python 语法,共有七组可互换的终结符,总计包含 35 个终结符。
1: "/=", "%=", "^=", "**=", "*=", ">>=", "|=", "-=", "+=", "//=", "<<=", "@=", "&="
2: FALSE, NONE, TRUE
3: PLUS, MINUS
4: "<<", ">>"
5: "//", PERCENT
6: "!=", LESSTHAN, ">=", "==", "<=", MORETHAN, "<>"
7: HEX_NUMBER, FLOAT_NUMBER, BIN_NUMBER, OCT_NUMBER, DEC_NUMBER, IMAG_NUMBER
实验结果
随附的 Python notebook 评估了掩码存储优化的有效性。由于某些操作需要计算时间,特别是 的计算(因为它需要考虑整个语法,而使用 仅利用下推自动机中的局部关系),我按以下方式进行:
- 步骤 1:识别相互可互换的终端集,以合并这些终端对应的掩码存储条目;
- 步骤 2:使用上面提到的有向图(参见图 2)来识别哪些终端可能跟在另一个终端之后,并删除不符合这些约束的掩码存储条目;
- 步骤 3:识别终端 ,使得 或 ,并在此基础上合并或删除掩码存储条目;
- 步骤 4:计算 (如果可能)和 的值,用于掩码存储的剩余条目,并在此基础上删除和合并这些条目。
图 6 显示,这些步骤结合起来,使掩码存储的大小减少了十倍。
此外,我进行了三项实验,以确认精简掩码存储会产生有效结果
- 实验 1:我对来自四个 GitHub 仓库的系列 Python 文件(总计超过一百万个令牌)进行了分词,并对每个令牌,我使用精简后的掩码存储计算了掩码,并检查了下一个令牌是否确实被掩码允许;
- 实验 2:我检查了使用精简掩码存储生成的 400 个字符串(总计超过一百万个字符)是否都是语法正确的 Python 代码;
- 实验 3:对于从与之前相同的 Python 文件中提取的超过 100,000 个令牌,我分别使用原始掩码存储和精简掩码存储计算了掩码,并检查了所得掩码是否系统地相同。
这三个实验都取得了成功,这表明,正如预期,计算出的掩码是合适的(即,既不过度限制,也不过度宽松),并且精简掩码存储不会改变在解码时计算的掩码。
结论与未来潜在工作
本文介绍了一种新颖的方法,用于简化在利用基于自动机的词法分析器和增量解析器的 CFG 约束解码技术中的掩码存储。该方法显著降低了推理开销,代价是针对给定语法和分词器进行一次适度的额外计算。
该方法的潜在改进方向如下:
- 我们可以修改 、 和 的定义,使其第一个参数是终结符序列而不是单个终结符。这种调整可以增加这些函数返回 的频率,从而为简化掩码存储创造更多机会。这可能还会增加预处理步骤的计算时间,但可以通过限制输入终结符序列的大小来控制。对于 Python 语法,仅将第一个参数从单个终结符更改为一对终结符可能会非常有帮助。例如,
NAME
终结符可以在各种上下文中使用,但如果我们知道它前面是DEF
终结符或IMPORT
终结符,则会大大缩小下一个终结符的选项; - 用于获取 的 方法等同于检查是否有可能在下推自动机中从状态 到状态 的路径,无论初始堆栈是什么。我们可以使用类似的方法来尝试计算 :我们需要确定从 到 或从 到 的路径创建的 NFA 是否等效;
- 我们可以尝试同时强制执行 CFG 约束和*正确的*分词,如上一篇博客文章中针对正则表达式约束所述。这可以带来与正则表达式约束相同的好处:更快的解码和更少的语言模型分布失真风险;
- 在实际应用中,我们可以考虑修改输入语法的定义,以在语法约束之上强制执行语义或风格约束。这既可以加速解码,又可以生成更符合我们需求的代码。例如,对于 Python 语法
- 我们可以将 `NAME` 终端替换为多个终端(`FUNCTION_NAME`、`CLASS_NAME`、`VARIABLE_NAME`...),以便可以应用命名约定(例如,函数名小写,类名 CamelCase...)。此外,更少的歧义将导致更简单的掩码存储;
- 我们可以指定一个可导入模块的允许列表,这对于控制代码的依赖性很有用。这也可以增加只允许一个令牌的可能性,在这种情况下无需调用 LLM;
- 我们可以强制为每个函数(或类或方法)添加文档字符串。这有助于开发人员更好地理解生成的代码,但也可以加速解码:如果我们在函数的第一行末尾排除注释,我们知道右括号后应该跟一个冒号、一个换行符、一个缩进和一个长字符串。然后,我们可以添加相应的令牌而无需调用 LLM;
- 我们可以尝试根据已生成的代码动态调整约束。例如,在 `from collections import` 之后,我们可能只允许 `collections` 模块中包含的类,或者在 `object.` 之后,我们可能只允许 `object` 的属性或方法的名称。这与本文讨论的情况有显著不同,因为它对于控制预处理时间至关重要。
- 在这篇博客文章中,我计算了掩码存储在精简前后的条目数量,并检查了计算出的掩码是否正确,但我没有测量实际的推理开销。后续工作将严格比较推理开销与现有软件包的开销。
感谢下面提到的论文作者以及本项目中使用的各种软件包(其中包括 、、 和 )的贡献者。我还要特别感谢 Martin Berglund,他得知我的研究后立即提出了潜在的不可判定性问题。
如果您想引用这篇博客文章,欢迎使用以下 BibTeX 条目
@misc{Tran-Thien_2025,
title={Accelerating LLM Code Generation Through Mask Store Streamlining},
url={https://vivien000.github.io/blog/journal/grammar-llm-decoding.html},
journal={Unsupervised Thoughts (blog)},
author={Tran-Thien, Vivien},
year={2025}
}
参考文献
- Willard, B. T., & Louf, R. (2023). 大型语言模型的高效引导生成。
- Gerganov, G., & et. al. (2024). llama.cpp: Meta 的 LLaMA 模型(及其他)的纯 C/C++ 推理。 https://github.com/ggerganov/llama.cpp
- Lundberg, S., & Ribeiro, M. T. C. et al. (2023). Guidanceai/guidance: 用于控制大型语言模型的引导语言。 https://github.com/guidance-ai/guidance
- Geng, S., Josifoski, M., Peyrard, M., & West, R. (2024). 无需微调的结构化 NLP 任务的语法约束解码。 https://arxiv.org/abs/2305.13971
- Beurer-Kellner, L., Fischer, M., & Vechev, M. (2024). 正确引导 LLM:快速、非侵入式约束生成。 https://arxiv.org/abs/2403.06988
- Ugare, S., Suresh, T., Kang, H., Misailovic, S., & Singh, G. (2024). SynCode: 带语法增强的 LLM 生成。 https://arxiv.org/abs/2403.01632
- Dong, Y., Ruan, C. F., Cai, Y., Lai, R., Xu, Z., Zhao, Y., & Chen, T. (2024). XGrammar:大型语言模型的灵活高效结构化生成引擎。https://arxiv.org/abs/2411.15100
这篇博文最初发布在我的个人博客上。