使用 ESM-2 预测突变对蛋白质功能的影响

社区文章 发布于 2023 年 12 月 13 日

在文章语言模型实现对蛋白质功能突变效应的零样本预测中,作者引入了几种评分函数来确定突变对蛋白质序列的影响。在这里,我们将使用 Hugging Face 的 ESM-2 蛋白质语言模型(来自 Transformers 库)重新实现这些评分方法。我们还将讨论如何使用和解释它们。

image/png

引言

蛋白质序列的突变可能相当复杂,它们对蛋白质的影响范围从对功能有害,到中性且无关紧要,再到导致功能改善。研究表明,即使是单点突变或少量突变也可能导致剧烈的构象变化,从而导致“折叠转换”和折叠蛋白质三维结构的变化。判断突变的影响是困难的,但像 ESM-2 系列模型这样的蛋白质语言模型可以提供大量关于突变对蛋白质折叠和功能影响的信息。

特别是,在语言模型实现对蛋白质功能突变效应的零样本预测中,作者引入了几种评分函数,其得分与功能效应高度相关。其中第一个函数是掩码边际评分函数

iMlogp(xi=ximtxM)logp(xi=xiwtxM) \sum_{i \in M} \log p(x_i = x_i^{mt}|x_{-M}) - \log p(x_i = x_i^{wt}|x_{-M})

其中 MM 是发生突变的掩码残基,ximtx_i^{mt} 是位置 ii 处的突变型残基,xiwtx_i^{wt} 是位置 ii 处的野生型残基。该函数被证明表现最佳。

对数似然比和点突变

我们还可以使用每个单点突变的对数似然比(LLR)来理解突变的影响,并将结果以热图形式表示,显示对蛋白质功能有利或有害的突变热点。HuggingFace 空间ESM Variants中对此进行了示例,其中计算了人类蛋白质所有点突变的 LLR。对于一般蛋白质,您可以尝试 HuggingFace 空间Variant Effects LLR

from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display

def generate_heatmap(protein_sequence, start_pos=1, end_pos=None):
    # Load the model and tokenizer
    model_name = "facebook/esm2_t6_8M_UR50D"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)

    # Tokenize the input sequence
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # Excluding the special tokens

    # Adjust end position if not specified
    if end_pos is None:
        end_pos = sequence_length

    # List of amino acids
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")

    # Initialize heatmap
    heatmap = np.zeros((20, end_pos - start_pos + 1))

    # Calculate LLRs for each position and amino acid
    for position in range(start_pos, end_pos + 1):
        # Mask the target position
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id
        
        # Get logits for the masked token
        with torch.no_grad():
            logits = model(masked_input_ids).logits
            
        # Calculate log probabilities
        probabilities = torch.nn.functional.softmax(logits[0, position], dim=0)
        log_probabilities = torch.log(probabilities)
        
        # Get the log probability of the wild-type residue
        wt_residue = input_ids[0, position].item()
        log_prob_wt = log_probabilities[wt_residue].item()
        
        # Calculate LLR for each variant
        for i, amino_acid in enumerate(amino_acids):
            log_prob_mt = log_probabilities[tokenizer.convert_tokens_to_ids(amino_acid)].item()
            heatmap[i, position - start_pos] = log_prob_mt - log_prob_wt

    # Visualize the heatmap
    plt.figure(figsize=(15, 5))
    plt.imshow(heatmap, cmap="viridis", aspect="auto")
    plt.xticks(range(end_pos - start_pos + 1), list(protein_sequence[start_pos-1:end_pos]))
    plt.yticks(range(20), amino_acids)
    plt.xlabel("Position in Protein Sequence")
    plt.ylabel("Amino Acid Mutations")
    plt.title("Predicted Effects of Mutations on Protein Sequence (LLR)")
    plt.colorbar(label="Log Likelihood Ratio (LLR)")
    plt.show()

def interactive_heatmap(protein_sequence):
    # Define interactive widgets
    start_slider = widgets.IntSlider(value=1, min=1, max=len(protein_sequence), step=1, description='Start:')
    end_slider = widgets.IntSlider(value=len(protein_sequence), min=1, max=len(protein_sequence), step=1, description='End:')

    ui = widgets.HBox([start_slider, end_slider])

    def update_heatmap(start, end):
        if start <= end:
            generate_heatmap(protein_sequence, start, end)

    out = widgets.interactive_output(update_heatmap, {'start': start_slider, 'end': end_slider})

    # Display the interactive widgets
    display(ui, out)

下面我们来看看如何使用它

# Example usage:
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
interactive_heatmap(protein_sequence)

这应该返回一些可调节左右端点的内容,以防您想放大到蛋白质序列中的特定位置范围。下面,我们看到位置 40-70 的范围

image/png

请注意深蓝色区域,LLR 值在此处为负,表明突变可能对功能有害;以及浅黄色区域,LLR 值为正,表明突变可能对蛋白质功能有益。此外,请注意垂直排列的深色条带,表示可能在进化上保守的区域,以及垂直排列的较亮条带,表示蛋白质中可能比野生型序列更优的区域。另请注意,对于蛋白质的某些区域,存在可能对整个蛋白质区域的功能有害的氨基酸突变,由沿蛋白质大部分区域水平排列的深色条带表示。同样,我们看到水平排列的较亮黄色条带,表示几乎任何突变到该氨基酸的残基都将优于野生型。一旦我们应用了其中一个突变,我们将获得突变蛋白质的不同热图。例如,将残基 57 处的 D 氨基酸突变为 L,会改变热图。现在,可视化残基 40-70,我们看到以下内容

image/png

下面,我们看到论文中的一张图,展示了 LLR 热图如何提示蛋白质功能上有益或有害的突变。这里,红色代表较低的 LLR 值,蓝色代表较高的 LLR 值,因此该图像与上面蓝色代表较高 LLR 值而黄色代表较低 LLR 值的情况有些反转。

image/png

深度突变扫描评分

我们可以修改ESM Github 仓库中(参见 `predict.py` 文件)用于零样本变异效应评分的脚本。在该脚本中,使用了 ESM-1v,但我们将使用较新的 ESM-2 系列模型。下面,我们提供了一个脚本,用于使用三种不同的评分方法对突变效应进行评分:伪困惑度 (PPPL)、野生型边际 (wt-marginal) 和掩码边际。该脚本需要一个 CSV 文件,其中包含指示要应用于野生型蛋白质序列的突变列。一旦用户选择了模型和评分方法,该脚本将创建一个 `output.csv` 文件,其中包含模型预测的评分。

import argparse
import pathlib
import string
import torch
from esm import Alphabet, pretrained, MSATransformer
import pandas as pd
from tqdm import tqdm
from Bio import SeqIO
import itertools

def remove_insertions(sequence: str) -> str:
    deletekeys = dict.fromkeys(string.ascii_lowercase)
    deletekeys["."] = None
    deletekeys["*"] = None
    translation = str.maketrans(deletekeys)
    return sequence.translate(translation)

def create_parser():
    parser = argparse.ArgumentParser(description="Label a deep mutational scan with predictions from an ensemble of ESM-1v models.")
    parser.add_argument("--model-location", type=str, help="PyTorch model file OR name of pretrained model to download", nargs="+")
    parser.add_argument("--sequence", type=str, help="Base sequence to which mutations were applied")
    parser.add_argument("--dms-input", type=pathlib.Path, help="CSV file containing the deep mutational scan")
    parser.add_argument("--mutation-col", type=str, default="mutant", help="column in the deep mutational scan labeling the mutation as 'AiB'")
    parser.add_argument("--dms-output", type=pathlib.Path, help="Output file containing the deep mutational scan along with predictions")
    parser.add_argument("--offset-idx", type=int, default=0, help="Offset of the mutation positions in `--mutation-col`")
    parser.add_argument("--scoring-strategy", type=str, default="wt-marginals", choices=["wt-marginals", "pseudo-ppl", "masked-marginals"], help="")
    parser.add_argument("--msa-path", type=pathlib.Path, help="path to MSA in a3m format (required for MSA Transformer)")
    parser.add_argument("--msa-samples", type=int, default=400, help="number of sequences to select from the start of the MSA")
    parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
    return parser

def label_row(row, sequence, token_probs, alphabet, offset_idx):
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"
    wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)
    score = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded]
    return score.item()

def compute_pppl(row, sequence, model, alphabet, offset_idx):
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"
    sequence = sequence[:idx] + mt + sequence[(idx + 1):]
    data = [("protein1", sequence)]
    batch_converter = alphabet.get_batch_converter()
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    log_probs = []
    for i in range(1, len(sequence) - 1):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked.cuda())["logits"], dim=-1)
        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i])].item())
    return sum(log_probs)

def main(args):
    df = pd.read_csv(args.dms_input)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.nogpu else "cpu")

    for model_location in args.model_location:
        model, alphabet = pretrained.load_model_and_alphabet(model_location)
        model = model.to(device)
        model.eval()
        batch_converter = alphabet.get_batch_converter()

        if isinstance(model, MSATransformer):
            data = [read_msa(args.msa_path, args.msa_samples)]
            assert args.scoring_strategy == "masked-marginals", "MSA Transformer only supports masked marginal strategy"
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            batch_tokens = batch_tokens.to(device)

            all_token_probs = []
            for i in tqdm(range(batch_tokens.size(2))):
                batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, 0, i] = alphabet.mask_idx
                with torch.no_grad():
                    token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
                all_token_probs.append(token_probs[:, 0, i])
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
            df[model_location] = df.apply(
                lambda row: label_row(row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx),
                axis=1,
            )
        else:
            data = [("protein1", args.sequence)]
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            batch_tokens = batch_tokens.to(device)

            if args.scoring_strategy == "wt-marginals":
                with torch.no_grad():
                    token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
                df[model_location] = df.apply(
                    lambda row: label_row(row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx),
                    axis=1,
                )
            elif args.scoring_strategy == "masked-marginals":
                all_token_probs = []
                for i in tqdm(range(batch_tokens.size(1))):
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = alphabet.mask_idx
                    with torch.no_grad():
                        token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
                    all_token_probs.append(token_probs[:, i])
                token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
                df[model_location] = df.apply(
                    lambda row: label_row(row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx),
                    axis=1,
                )
            elif args.scoring_strategy == "pseudo-ppl":
                tqdm.pandas()
                df[model_location] = df.progress_apply(
                    lambda row: compute_pppl(row[args.mutation_col], args.sequence, model, alphabet, args.offset_idx),
                    axis=1,
                )

    df.to_csv(args.dms_output)

if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)

示例用法

在您的终端中,您可以按如下方式运行脚本

python scoring_esm2.py \  
    --model-location esm2_t12_35M_UR50D \
    --sequence "MKTIIALSYIFCLVFA" \
    --dms-input "mutations.csv" \
    --mutation-col "mutant" \
    --dms-output "output_2.csv" \
    --offset-idx 0 \
    --scoring-strategy "masked-marginals" \
    --nogpu

调整模型、序列、突变文件和评分策略以满足您的需求,但请记住,掩码边际评分策略被证明表现最佳。这应该会创建一个名为 `output.csv` 的文件,如下所示

,mutant,esm2_t12_35M_UR50D
0,T2B,-10.990091323852539
1,I3A,-0.5448870658874512
2,A5M,-0.8617167472839355

我们可以看到其中一些突变个体上比其他突变更有害,其中第三个残基处的 TBT \to B 突变(或第二个,因为我们这里从 `0` 开始索引)的分数远低于其他两个突变,后者更接近于零,表明它们更中性。使用 LLR 值,我们还可以选择可能提供改进功能的点突变。当使用野生型边际评分策略时,我们得到以下结果

,mutant,esm2_t12_35M_UR50D
0,T2B,-13.739639282226562
1,I3A,-3.976250171661377
2,A5M,-4.413556098937988

正如我们所看到的,wt-marginal 和掩码边际评分策略都表明突变对蛋白质功能有害,但第一个突变比第二个和第三个突变更有害。在上面的蛋白质上运行脚本

MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE

对于突变 `D56L`(对于上一节中关于 LLR 热图提到的蛋白质序列),我们得到以下结果

,mutant,facebook/esm2_t12_35M_UR50D
0,D56L,1.3842720985412598

这证实了 LLR 预测,并表明该突变可能对蛋白质的功能有益,并且突变体可能比野生型更具适应性。这种评分可以用来确定蛋白质进化的方向性,提供一种进化向量场或流描述。例如,Evolocity 就采用了这种方法,其中发现了蛋白质进化的模式。Evolocity 在蛋白质语言模型的进化速度预测多种蛋白质的进化动力学中被引入,并使用了略旧的 ESM-1b 蛋白质语言模型,但这些方法可以适应 ESM-2 或其他蛋白质语言模型。

"关键的概念进步在于,通过学习局部进化的基本规则,我们可以构建一个全局进化“矢量场”,我们展示它可以(1)预测观察到的进化轨迹的根源(或潜在的多个根源),(2)按进化时间对蛋白质序列进行排序,以及(3)识别驱动这些轨迹的突变策略。"

如果使用掩码边际分数而不是伪似然分数会提供显著不同或改进的结果,这将很有趣。我们还应该注意到,脚本中上述 PPPL 的实现与掩码语言模型评分中第 2.3 节建议的定义并不完全匹配。在这里,作者将 PPPL 定义为

PPPL(T)=exp(1NtTPLL(t)) PPPL(T) = \exp\left(- \frac{1}{N} \sum_{t \in T} PLL(t)\right)

其中 PLL(t)PLL(t) 表示令牌 tTt \in T 的伪对数似然。因此,我们通常会看到类似以下的代码作为 PPPL 的实现

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

def calculate_pppl(model, tokenizer, sequence):
    token_ids = tokenizer.encode(sequence, return_tensors='pt')
    input_length = token_ids.size(1)
    log_likelihood = 0.0

    for i in range(input_length):
        # Create a copy of the token IDs
        masked_token_ids = token_ids.clone()
        # Mask a token that we will try to predict back
        masked_token_ids[0, i] = tokenizer.mask_token_id

        with torch.no_grad():
            output = model(masked_token_ids)
            logit_prob = torch.nn.functional.log_softmax(output.logits, dim=-1)
        
        log_likelihood += logit_prob[0, i, token_ids[0, i]]

    # Calculate the average log likelihood per token
    avg_log_likelihood = log_likelihood / input_length

    # Compute and return the pseudo-perplexity
    pppl = torch.exp(-avg_log_likelihood)
    return pppl.item()

# Load the model and tokenizer
model_name = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Protein sequence
protein_sequence = "MKTIIALSYIFCLVFA"

# Calculate PPPL
pppl = calculate_pppl(model, tokenizer, protein_sequence)
print(f"Pseudo-Perplexity of the sequence: {pppl}")

它将返回整个序列的单个值

Pseudo-Perplexity of the sequence: 9.073078155517578

结论

在本文中,我们探讨了 ESM-2 模型的能力,特别是在预测突变对蛋白质功能影响方面的能力。利用生物信息学中的高级语言模型,我们实现了多种评分方法,包括掩码边际评分函数、伪困惑度 (PPPL) 和野生型边际 (wt-marginal) 评分。这些方法有助于更深入地理解突变如何影响蛋白质结构和功能,为蛋白质工程和疾病分析研究提供宝贵见解。

其中,掩码边际评分函数因其与功能效应的显著相关性而尤其突出。通过计算野生型和突变型残基之间的对数概率差,该函数提供了一种量化突变影响的方法。单个突变得分的总和给出了它们集体效应的总体估计,提供了一种方便的方法来同时评估多个突变。

我们的脚本便于这些评分方法的轻松集成和应用。通过输入蛋白质序列和所需的突变列表,用户可以快速获得指示潜在功能变化的评分。该脚本在选择不同 ESM 模型和评分策略方面的灵活性允许根据特定的研究需求进行定制分析。此外,使用 ipywidgets 实现的交互式热图可视化提供了突变景观的直观图形表示。通过突出显示潜在功能意义的区域,研究人员可以快速识别需要进一步调查的关键区域。

总而言之,本文讨论的工具和方法为评估蛋白质突变的影响提供了一种强大的方法。它们融合了计算效率和生物学洞察力,这在快速发展的生物信息学领域是无价的。随着这些模型不断改进,我们可以期待更准确、更深入的预测,进一步增进我们对蛋白质动力学和疾病机制的理解。

社区

注册登录 以评论