使用蛋白质语言模型和线性求和分配预测蛋白质-蛋白质相互作用

社区文章 发布于2023年10月15日

TLDR: 这篇博客文章是关于使用ESM-2(一种蛋白质语言模型)通过掩码语言建模损失来评估蛋白质对,以预测有高结合可能性的蛋白质对。

image/png

在论文使用掩码语言建模配对相互作用的蛋白质序列中,作者提出了一种使用MSA Transformer或ESM-2这两种蛋白质语言模型之一来预测一对蛋白质相互结合可能性的方法。在这篇文章中,我们将重点关注ESM-2,因为与需要多序列比对(MSA)的MSA Transformer相比,ESM-2方法更快、更容易使用。该方法非常简单。我们取一个我们想要测试相互作用的蛋白质列表,然后将它们配对连接起来。一旦创建了蛋白质对,我们使用ESM-2的掩码语言建模功能,随机掩码残基,然后计算MLM损失。我们对每对蛋白质进行多次迭代的平均计算,获得一个分数,该分数表示两种蛋白质相互结合的可能性。然后,我们计算一个这样的分数矩阵。使用这个矩阵,我们能够解决相关的线性分配问题,以计算最佳结合伙伴。

引言

预测蛋白质-蛋白质相互作用是分子生物学中的一项关键任务。在这里,我们将使用Meta AI的ESM-2模型计算蛋白质对的掩码语言模型(MLM)损失,旨在找到损失最低的对。其原理是,实际相互作用的蛋白质将产生比不相互作用的蛋白质更低的MLM损失。

1. 导入必要的库

import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
  • numpy:用于数值运算,特别是处理矩阵。
  • linear_sum_assignment:这是SciPy库中的一个优化算法,用于解决线性求和分配问题。我们将用它来根据MLM损失找到蛋白质的最佳配对。
  • transformers:这是Hugging Face的库,提供了用于各种NLP任务的预训练模型。我们用它来加载ESM-2模型及其分词器。
  • torch:PyTorch库,transformers库基于它构建。

2. 初始化模型和分词器

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
  • 分词器:用于将蛋白质序列转换为适合模型处理的格式。
  • 模型:ESM-2模型,专门为蛋白质序列构建。

3. 设置GPU

使用GPU可以大大加快计算速度。以下代码检查GPU是否可用,并将模型设置为在GPU上运行。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

4. 定义蛋白质序列

在本教程中,我们选择了一组蛋白质序列进行分析。

all_proteins = [
    "MEESQSELNIDPPLSQETFSELWNLLPENNVLSSELCPAVDELLLPESVVNWLDEDSDDAPRMPATSAPTAPGPAPSWPLSSSVPSPKTYPGTYGFRLGFLHSGTAKSVTWTYSPLLNKLFCQLAKTCPVQLWVSSPPPPNTCVRAMAIYKKSEFVTEVVRRCPHHERCSDSSDGLAPPQHLIRVEGNLRAKYLDDRNTFRHSVVVPYEPPEVGSDYTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNVLGRNSFEVRVCACPGRDRRTEEENFHKKGEPCPEPPPGSTKRALPPSTSSSPPQKKKPLDGEYFTLQIRGRERYEMFRNLNEALELKDAQSGKEPGGSRAHSSHLKAKKGQSTSRHKKLMFKREGLDSD",
    "MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKDTYTMKEVLFYLGQYIMTKRLYDEKQQHIVYCSNDLLGDLFGVPSFSVKEHRKIYTMIYRNLVVVNQQESSDSGTSVSENRCHLEGGSDQKDLVQELQEEKPSSSHLVSRPSTSSRRRAISETEENSDELSGERQRKRHKSDSISLSFDESLALCVIREICCERSSSSESTGTPSNPDLDAGVSEHSGDWLDQDSVSDQFSVEFEVESLDSEDYSLSEEGQELSDEDDEVYQVTVYQAGESDTDSFEEDPEISLADYWKCTSCNEMNPPLPSHCNRCWALRENWLPEDKGKDKGEISEKAKLENSTQAEEGFDVPDCKKTIVNDSRESCVEENDDKITQASQSQESEDYSQPSTSSSIIYSSQEDVKEFEREETQDKEESVESSLPLNAIEPCVICQGRPKNGCIVHGKTGHLMACFTCAKKLKKRNKPCPVCRQPIQMIVLTYFP",
    "MNRGVPFRHLLLVLQLALLPAATQGKKVVLGKKGDTVELTCTASQKKSIQFHWKNSNQIKILGNQGSFLTKGPSKLNDRADSRRSLWDQGNFPLIIKNLKIEDSDTYICEVEDQKEEVQLLVFGLTANSDTHLLQGQSLTLTLESPPGSSPSVQCRSPRGKNIQGGKTLSVSQLELQDSGTWTCTVLQNQKKVEFKIDIVVLAFQKASSIVYKKEGEQVEFSFPLAFTVEKLTGSGELWWQAERASSSKSWITFDLKNKEVSVKRVTQDPKLQMGKKLPLHLTLPQALPQYAGSGNLTLALEAKTGKLHQEVNLVVMRATQLQKNLTCEVWGPTSPKLMLSLKLENKEAKVSKREKAVWVLNPEAGMWQCLLSDSGQVLLESNIKVLPTWSTPVQPMALIVLGGVAGLLLFIGLGIFFCVRCRHRRRQAERMSQIKRLLSEKKTCQCPHRFQKTCSPI",
    "MRVKEKYQHLWRWGWKWGTMLLGILMICSATEKLWVTVYYGVPVWKEATTTLFCASDAKAYDTEVHNVWATHACVPTDPNPQEVVLVNVTENFNMWKNDMVEQMHEDIISLWDQSLKPCVKLTPLCVSLKCTDLGNATNTNSSNTNSSSGEMMMEKGEIKNCSFNISTSIRGKVQKEYAFFYKLDIIPIDNDTTSYTLTSCNTSVITQACPKVSFEPIPIHYCAPAGFAILKCNNKTFNGTGPCTNVSTVQCTHGIRPVVSTQLLLNGSLAEEEVVIRSANFTDNAKTIIVQLNQSVEINCTRPNNNTRKSIRIQRGPGRAFVTIGKIGNMRQAHCNISRAKWNATLKQIASKLREQFGNNKTIIFKQSSGGDPEIVTHSFNCGGEFFYCNSTQLFNSTWFNSTWSTEGSNNTEGSDTITLPCRIKQFINMWQEVGKAMYAPPISGQIRCSSNITGLLLTRDGGNNNNGSEIFRPGGGDMRDNWRSELYKYKVVKIEPLGVAPTKAKRRVVQREKRAVGIGALFLGFLGAAGSTMGARSMTLTVQARQLLSGIVQQQNNLLRAIEAQQHLLQLTVWGIKQLQARILAVERYLKDQQLLGIWGCSGKLICTTAVPWNASWSNKSLEQIWNNMTWMEWDREINNYTSLIHSLIEESQNQQEKNEQELLELDKWASLWNWFNITNWLWYIKIFIMIVGGLVGLRIVFAVLSIVNRVRQGYSPLSFQTHLPTPRGPDRPEGIEEEGGERDRDRSIRLVNGSLALIWDDLRSLCLFSYHRLRDLLLIVTRIVELLGRRGWEALKYWWNLLQYWSQELKNSAVSLLNATAIAVAEGTDRVIEVVQGACRAIRHIPRRIRQGLERILL",
    "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", 
    "MATGGRRGAAAAPLLVAVAALLLGAAGHLYPGEVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSFPKLIMITDYLLLFRVYGLESLKDLFPNLTVIRGSRLFFNYALVIFEMVHLKELGLYNLMNITRGSVRIEKNNELCYLATIDWSRILDSVEDNYIVLNKDDNEECGDICPGTAKGKTNCPATVINGQFVERCWTHSHCQKVCPTICKSHGCTAEGLCCHSECLGNCSQPDDPTKCVACRNFYLDGRCVETCPPPYYHFQDWRCVNFSFCQDLHHKCKNSRRQGCHQYVIHNNKCIPECPSGYTMNSSNLLCTPCLGPCPKVCHLLEGEKTIDSVTSAQELRGCTVINGSLIINIRGGNNLAAELEANLGLIEEISGYLKIRRSYALVSLSFFRKLRLIRGETLEIGNYSFYALDNQNLRQLWDWSKHNLTITQGKLFFHYNPKLCLSEIHKMEEVSGTKGRQERNDIALKTNGDQASCENELLKFSYIRTSFDKILLRWEPYWPPDFRDLLGFMLFYKEAPYQNVTEFDGQDACGSNSWTVVDIDPPLRSNDPKSQNHPGWLMRGLKPWTQYAIFVKTLVTFSDERRTYGAKSDIIYVQTDATNPSVPLDPISVSNSSSQIILKWKPPSDPNGNITHYLVFWERQAEDSELFELDYCLKGLKLPSRTWSPPFESEDSQKHNQSEYEDSAGECCSCPKTDSQILKELEESSFRKTFEDYLHNVVFVPRKTSSGTGAEDPRPSRKRRSLGDVGNVTVAVPTVAAFPNTSSTSVPTSPEEHRPFEKVVNKESLVISGLRHFTGYRIELQACNQDTPEERCSVAAYVSARTMPEAKADDIVGPVTHEIFENNVVHLMWQEPKEPNGLIVLYEVSYRRYGDEELHLCVSRKHFALERGCRLRGLSPGNYSVRIRATSLAGNGSWTEPTYFYVTDYLDVPSNIAKIIIGPLIFVFLFSVVIGSIYLFLRKRQPDGPLGPLYASSNPEYLSASDVFPCSVYVPDEWEVSR" 
]  # A list of protein sequences

5. 定义MLM损失函数

我们的目标是找到产生最低MLM损失的蛋白质对。我们将定义一个函数来计算一批蛋白质对的MLM损失。

BATCH_SIZE = 2
NUM_MASKS = 10
P_MASK = 0.15

# Function to compute MLM loss for a batch of protein pairs
def compute_mlm_loss_batch(pairs):
    avg_losses = []
    for _ in range(NUM_MASKS):
        # Tokenize the concatenated protein pairs
        inputs = tokenizer(pairs, return_tensors="pt", truncation=True, padding=True, max_length=1022)
        
        # Move input tensors to GPU if available
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Get the mask token ID
        mask_token_id = tokenizer.mask_token_id
        
        # Clone input IDs for labels
        labels = inputs["input_ids"].clone()
        
        # Randomly mask 15% of the residues for each sequence in the batch
        for idx in range(inputs["input_ids"].shape[0]):
            mask_indices = np.random.choice(inputs["input_ids"].shape[1], size=int(P_MASK * inputs["input_ids"].shape[1]), replace=False)
            inputs["input_ids"][idx, mask_indices] = mask_token_id
            labels[idx, [i for i in range(inputs["input_ids"].shape[1]) if i not in mask_indices]] = -100
        
        # Compute the MLM loss
        outputs = model(**inputs, labels=labels)
        avg_losses.append(outputs.loss.item())
    
    # Return the average loss for the batch
    return sum(avg_losses) / NUM_MASKS

在函数内部

  • 我们对每个蛋白质对进行分词。
  • 每条序列中15%的残基被随机掩码。这可以根据您的偏好进行调整,并且应该测试不同的百分比。
  • 我们使用ESM-2模型计算MLM损失。
  • 此过程对每对执行NUM_MASKS次,并计算平均损失。

这里,需要注意的是,使用更大的模型和更长的上下文窗口可能会改善结果,但也会需要更多的计算资源。目前,这可以在免费的Colab实例中运行。如果您想使用更大的模型和更长的上下文窗口,您可以考虑使用其他ESM-2模型,例如esm2_t36_3B_UR50D。您还应该尝试调整上述代码中的max_length。上面的上下文窗口对于我们这里选择的长蛋白质来说并不真正足够,使用更大的模型和更长的上下文窗口,或者使用更小的蛋白质几乎肯定会产生更好的结果,但这应该能让您开始。

6. 构建损失矩阵

我们需要计算每种蛋白质配对的MLM损失。

# Compute loss matrix
loss_matrix = np.zeros((len(all_proteins), len(all_proteins)))

for i in range(len(all_proteins)):
    for j in range(i+1, len(all_proteins), BATCH_SIZE):  # to avoid self-pairing and use batches
        pairs = [all_proteins[i] + all_proteins[k] for k in range(j, min(j+BATCH_SIZE, len(all_proteins)))]
        batch_loss = compute_mlm_loss_batch(pairs)
        for k in range(len(pairs)):
            loss_matrix[i, j+k] = batch_loss
            loss_matrix[j+k, i] = batch_loss  # the matrix is symmetric

# Set the diagonal of the loss matrix to a large value to prevent self-pairings
np.fill_diagonal(loss_matrix, np.inf)

对于每一对,计算MLM损失并存储在`loss_matrix`中。

7. 寻找最佳配对

最后,我们使用linear_sum_assignment函数来找到蛋白质的最佳配对。

# Use the linear assignment problem to find the optimal pairing based on MLM loss
rows, cols = linear_sum_assignment(loss_matrix)
optimal_pairs = list(zip(rows, cols))

print(optimal_pairs)

这将打印以下内容

[(0, 1), (1, 0), (2, 5), (3, 4), (4, 3), (5, 2)]

此函数尝试找到最小化总MLM损失的最佳分配。


下一步

现在,您可以考虑选择一个感兴趣的蛋白质,然后按照本教程计算该蛋白质的结合位点。一旦您有了结合位点,您可以前往RFDiffusion笔记本,为您的蛋白质设计几个结合伙伴。一旦您设计了几个结合伙伴,您就可以使用此代码来测试哪些结合伙伴对您的蛋白质具有最高的结合亲和力,方法是计算如上所述的您感兴趣的蛋白质与各种潜在结合伙伴串联后的MLM损失。请记住,您的蛋白质越长,您需要的上下文窗口就越长。以下是使用此方法对固定蛋白质的结合伙伴进行排名的一个示例。

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM"
potential_binders = [
    "MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL", 
    "MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ", 
    "EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN"
]  # Add potential binding sequences here

def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0
    
    for _ in range(iterations):
        # Concatenate protein sequences with a separator
        concatenated_sequence = protein + ":" + binder

        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15  # For instance, masking 15% of the sequence
        num_mask = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)

        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token

        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

        # Compute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
        
        total_loss += loss.item()

    # Return the average loss
    return total_loss / iterations

# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss

# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")

这将打印以下内容

Ranking of Potential Binders:
1. MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL
- MLM Loss: 8.279285748799643
2. MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ
- MLM Loss: 12.046218872070312
3. EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN
- MLM Loss: 12.776518821716309

我们还可以使用文献中用于复合物的一种技巧,即用大约20个G氨基酸组成的长字符串来代替冒号(:)来分隔两种蛋白质。这可能会导致一些不同的预测,值得测试。我们还可以决定优先掩盖由我们的结合位点预测模型(如上所述)预测的结合位点,因为蛋白质语言模型已知会更多地关注结合残基和其他特殊残基。该方法的另一个变体可能是只对表示潜在蛋白质复合物的每个串联序列对中的其中一个序列应用掩码。所有这些都可能提高该方法的性能。尝试这些变体,看看哪种方法效果最好!

构建PPI网络

我们还可以基于这种方法构建蛋白质-蛋白质相互作用(PPI)网络。我们只需根据蛋白质对的MLM损失计算阈值来创建图。如果损失低于某个阈值,则图中这两蛋白质之间存在一条边。这提供了一种非常快速的方法来近似蛋白质相互作用组并找到候选相互作用。我们可以这样操作:

import networkx as nx
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import plotly.graph_objects as go
from ipywidgets import interact
from ipywidgets import widgets

# Check if CUDA is available and set the default device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"  # You can change this to your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Send the model to the device (GPU or CPU)
model.to(device)

# Ensure the model is in evaluation mode
model.eval()

# Define Protein Sequences (Replace with your list)
all_proteins = [
    "MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
    "MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV", 
    "MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ", 
    
    "MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL", 
    "MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS", 
    "MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV", 
    
    "MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN", 
    "MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE", 
    "MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND", 
    
    "MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV", 
    "MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH", 
    "MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD", 
    
    "MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS", 
    "MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD", 
    "MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV"
]

def compute_average_mlm_loss(protein1, protein2, iterations=10):
    total_loss = 0.0
    connector = "G" * 25  # Connector sequence of G's
    for _ in range(iterations):
        concatenated_sequence = protein1 + connector + protein2
        inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        
        mask_prob = 0.55
        mask_indices = torch.rand(inputs["input_ids"].shape, device=device) < mask_prob
        
        # Locate the positions of the connector 'G's and set their mask indices to False
        connector_indices = tokenizer.encode(connector, add_special_tokens=False)
        connector_length = len(connector_indices)
        start_connector = len(tokenizer.encode(protein1, add_special_tokens=False))
        end_connector = start_connector + connector_length
        
        # Avoid masking the connector 'G's
        mask_indices[0, start_connector:end_connector] = False
        
        # Apply the mask to the input IDs
        inputs["input_ids"][mask_indices] = tokenizer.mask_token_id
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Send inputs to the device

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        
        loss = outputs.loss
        total_loss += loss.item()
    
    return total_loss / iterations

# Compute all average losses to determine the maximum threshold for the slider
all_losses = []
for i, protein1 in enumerate(all_proteins):
    for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
        avg_loss = compute_average_mlm_loss(protein1, protein2)
        all_losses.append(avg_loss)

# Set the maximum threshold to the maximum loss computed
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")

def plot_graph(threshold):
    G = nx.Graph()

    # Add all protein nodes to the graph
    for i, protein in enumerate(all_proteins):
        G.add_node(f"protein {i+1}")

    # Loop through all pairs of proteins and calculate average MLM loss
    loss_idx = 0  # Index to keep track of the position in the all_losses list
    for i, protein1 in enumerate(all_proteins):
        for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
            avg_loss = all_losses[loss_idx]
            loss_idx += 1
            
            # Add an edge if the loss is below the threshold
            if avg_loss < threshold:
                G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))

    # 3D Network Plot
    # Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
    k_value = 2  # Lower value will bring nodes closer together
    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)

    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])
    
    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))
    
    node_x = []
    node_y = []
    node_z = []
    node_text = []
    for node in G.nodes():
        x, y, z = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(node)
    
    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)
    
    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))

    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

# Create an interactive slider for the threshold value with a default of 8.50
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=8.25))

这将打印一个表示PPI网络的交互式3D图。您可以调整MLM损失值的阈值滑块,使相互作用的要求更严格或更宽松。

image/png

例如,尝试将滑块设置为 8.208.30,看看这种MLM损失阈值选择会产生什么样的预测互作组。您还应该调整掩码令牌的数量,看看这对图有什么影响。一般来说,掩码的残基越多,预测的PPI图中出现连接所需的阈值就越高。请注意,此代码可能需要一些时间才能运行,因为我们正在计算列表中所有蛋白质对的损失,但总的来说,该方法是一种非常快速的零样本预测PPI网络的方法。下一步,您还可以考虑训练模型来预测已知相互作用对的掩码残基,以便进一步微调此任务。另一个有趣且重要的问题是蛋白质长度如何影响此计算。该方法对长度的巨大变化是否具有鲁棒性,或者蛋白质是否需要相似长度才能使该方法生效?

结论

现在,您应该能够使用ESM-2模型通过比较不同蛋白质配对的MLM损失来预测潜在的蛋白质-蛋白质相互作用。这种方法提供了一种利用深度学习技术推断相互作用的新颖方式。请记住,这种方法提供了一种启发式方法,应与实验验证相结合以获得结论性结果。下一步,您可以尝试实现PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling中的想法,该模型使用掩码语言建模微调ESM-2以生成结合伙伴,或者您可以尝试以类似方式在结合伙伴的串联对上微调ESM-2,其中每个结合伙伴中的一部分令牌被掩码,以查看性能是否有所提高。

社区

注册登录 发表评论