ESMBind (ESMB) 集成模型

社区文章 发布于 2023 年 9 月 22 日

内容提要: 在这篇文章中,我们将讨论如何使用 ESMBind (ESMB) 模型构建一个基本的集成模型。我们将采用“硬”投票和“软”投票两种策略。我们将向您展示如何在一个预处理过的、已分割为训练/测试集的蛋白质序列数据集上计算训练和测试指标。请注意,以下内容纯粹用于演示目的。这些模型未经充分测试,并且似乎存在过拟合(请参见下文的精确度、F1 分数和 MCC)。

image/png

介绍

请注意,由于使用集成模型带来的内存限制,您可能需要在本地或 Google Colab Pro 实例上运行此代码示例。您也可以尝试使用 P100 GPU 的 Kaggle notebook。另一种选择是,使用我们之前的文章,以 esm2_t6_8M_UR50D 为基础模型,训练两个或更多个较小的 ESMB 模型。 回想一下,在这篇文章中,我们展示了如何使用低秩适应 (LoRA) 来微调一个结合位点预测器。我们会在这里回顾一些信息,但除非您已经熟悉 LoRA 和集成模型的基础知识,否则在继续之前最好先阅读那篇文章。此外,请注意本文纯粹用于演示目的。为了获得更好的集成模型,您应该使用上一篇文章中给出的示例,并采用不同的超参数来训练您自己的模型。

ESMBind(或 ESMB)是一系列微调模型的集合,它们在基础模型 ESM-2 之上使用低秩适应 (LoRA) 进行微调,旨在仅基于单个蛋白质的序列来预测其结合位点。它不需要多序列比对 (MSA) 或任何关于蛋白质 3D 折叠或主链结构的结构信息。这使得 ESMB 模型易于访问、使用简单,并且应用和理解所需的领域知识较少,从而更具可解释性。然而,这可能会以牺牲性能为代价。

请记住,我们在上面链接的文章中展示了如何对蛋白质语言模型 (pLM) ESM-2 使用低秩适应 (LoRA)。LoRA 是一种技术,已被证明可以显著改善 pLM esm2_t12_35M_UR50D 的过拟合问题(另请参阅 Hugging Face 上的 ESM)。这也使我们能够以参数高效的方式微调更大的模型。下面,我们将为您提供代码,既可以在用于集成中单个模型训练/测试分割的预处理数据集上获取训练/测试指标,也可以在您自己的蛋白质序列上运行推理。

训练/测试数据集

在开始之前,请下载以下 pickle 文件,然后在代码中调整下面的路径以匹配您的本地文件路径。

在一个大型预处理数据集上获取训练/测试指标

注意,这段代码可以在 Google Colab 或 Kaggle 实例中运行。但是,如果使用 Colab 提供的免费 GPU,代码的第一部分需要几个小时才能运行完毕。用于在单个蛋白质序列或少量蛋白质上测试集成模型的推理部分代码,运行时间仅需几秒钟。因此,如果您只想在少数蛋白质序列上测试模型,可以跳到最后一节“推理”。

步骤 0:安装和导入

!pip install transformers -q 
!pip install datasets -q 
!pip install accelerate -q 
!pip install scipy -q
!pip install scikit-learn -q
!pip install peft -q 
import os
import pickle
import numpy as np
from scipy import stats
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import AutoModelForTokenClassification, Trainer, AutoTokenizer, DataCollatorForTokenClassification
from datasets import Dataset, concatenate_datasets
from accelerate import Accelerator
from peft import PeftModel
import gc

步骤 1:加载数据

在这一步,您将从 pickle 文件中加载训练和测试数据集的序列及标签。这些数据集将分别用于训练和评估您的模型。

# Step 1: Load train/test data and labels from pickle files
with open("/content/drive/MyDrive/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("/content/drive/MyDrive/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("/content/drive/MyDrive/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("/content/drive/MyDrive/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

步骤 2:批量分词和数据集创建

在这一步,使用预训练的分词器对序列进行分词。分词是将输入文本转换为标记(即整数值)的过程。然后,分词后的序列和标签被用来创建数据集。

# Step 2: Define the Tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
max_sequence_length = tokenizer.model_max_length

步骤 3:分批计算指标以节省内存

# Step 3: Define a `compute_metrics_for_batch` function.
def compute_metrics_for_batch(sequences_batch, labels_batch, models, voting='hard'):
    # Tokenize batch
    batch_tokenized = tokenizer(sequences_batch, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    
    batch_dataset = Dataset.from_dict({k: v for k, v in batch_tokenized.items()})
    batch_dataset = batch_dataset.add_column("labels", labels_batch[:len(batch_dataset)])
    
    # Convert labels to numpy array of shape (1000, 1002)
    labels_array = np.array([np.pad(label, (0, 1002 - len(label)), constant_values=-100) for label in batch_dataset["labels"]])
    
    # Initialize a trainer for each model
    data_collator = DataCollatorForTokenClassification(tokenizer)
    trainers = [Trainer(model=model, data_collator=data_collator) for model in models]
    
    # Get the predictions from each model
    all_predictions = [trainer.predict(test_dataset=batch_dataset)[0] for trainer in trainers]
    
    if voting == 'hard':
        # Hard voting
        hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
        ensemble_predictions = stats.mode(hard_predictions, axis=0)[0][0]
    elif voting == 'soft':
        # Soft voting
        avg_predictions = np.mean(all_predictions, axis=0)
        ensemble_predictions = np.argmax(avg_predictions, axis=2)
    else:
        raise ValueError("Voting must be either 'hard' or 'soft'")
    
    print("Shape of ensemble_predictions:", ensemble_predictions.shape)  # Debug print
    
    # Use broadcasting to create 2D mask
    mask_2d = labels_array != -100
    
    # Filter true labels and predictions using the mask
    true_labels_list = [label[mask_2d[idx]] for idx, label in enumerate(labels_array)]
    true_labels = np.concatenate(true_labels_list)
    flat_predictions_list = [ensemble_predictions[idx][mask_2d[idx]] for idx in range(ensemble_predictions.shape[0])]
    flat_predictions = np.concatenate(flat_predictions_list).tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute MCC
    
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

步骤 4:定义一个函数以分批进行评估

#Step 4: Evaluate in Batches
def evaluate_in_batches(sequences, labels, models, voting='hard', batch_size=1000):
    num_batches = len(sequences) // batch_size + int(len(sequences) % batch_size != 0)
    metrics_list = []
    
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        batch_metrics = compute_metrics_for_batch(sequences[start_idx:end_idx], labels[start_idx:end_idx], models, voting)
        
        # Print metrics for the first five batches
        if i < 5:
            print(f"Batch {i+1}/{num_batches} metrics: {batch_metrics}")
        
        metrics_list.append(batch_metrics)
    
    # Average metrics over all batches
    avg_metrics = {key: np.mean([metrics[key] for metrics in metrics_list]) for key in metrics_list[0]}
    return avg_metrics

步骤 5:定义集成模型

# Load pre-trained base model and fine-tuned LoRA models
accelerator = Accelerator()
base_model_path = "facebook/esm2_t12_35M_UR50D"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
lora_model_paths = [
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
    # Add more models or swap out for your own models
]
models = [PeftModel.from_pretrained(base_model, path) for path in lora_model_paths]
models = [accelerator.prepare(model) for model in models]

步骤 6:集成投票和指标计算

# Step 5: Compute and print the metrics
train_metrics_hard = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='hard')
test_metrics_hard = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='hard')
train_metrics_soft = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='soft')
test_metrics_soft = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='soft')

train_metrics_hard, test_metrics_hard, train_metrics_soft, test_metrics_soft

然后,这将打印出类似以下内容:

train - Batch 1/451 metrics: {'accuracy': 0.9907783025067246, 'precision': 0.7792440817271516, 'recall': 0.9714265098491954, 'f1': 0.8647867420349434, 'auc': 0.9814053346312887, 'mcc': 0.8656769123429833}

train - Batch 2/451 metrics: {'accuracy': 0.9906862419735746, 'precision': 0.7686626071267478, 'recall': 0.9822046109510086, 'f1': 0.8624114372469636, 'auc': 0.9865753167670478, 'mcc': 0.8645747724704963}

train - Batch 3/451 metrics: {'accuracy': 0.9907034630406232, 'precision': 0.7662082514734774, 'recall': 0.9884141926140478, 'f1': 0.8632411067193676, 'auc': 0.9895938451445732, 'mcc': 0.8659743174909746}

train - Batch 4/451 metrics: {'accuracy': 0.991028787153535, 'precision': 0.7751964275620372, 'recall': 0.9881115354132142, 'f1': 0.8687994931897371, 'auc': 0.9896153675458282, 'mcc': 0.871052392709521}

train - Batch 5/451 metrics: {'accuracy': 0.9901174908557153, 'precision': 0.7585922916437905, 'recall': 0.9865762227775794, 'f1': 0.8576926658183058, 'auc': 0.988401969496207, 'mcc': 0.8605718730416185}

之后,将需要漫长等待训练批次完成,然后会打印出前五个测试批次的指标,这些指标将与训练指标相似。

test - Batch 1/114 metrics: {'accuracy': 0.9410464672512716, 'precision': 0.37514282087088996, 'recall': 0.8439481350317016, 'f1': 0.5194051887787388, 'auc': 0.8944018149939027, 'mcc': 0.5392923907809524}

test - Batch 2/114 metrics: {'accuracy': 0.938214353140821, 'precision': 0.361414131305044, 'recall': 0.8304587788892721, 'f1': 0.5036435270736724, 'auc': 0.886450001724052, 'mcc': 0.5233747173742583}

test - Batch 3/114 metrics: {'accuracy': 0.9411384591024733, 'precision': 0.3683750578316969, 'recall': 0.8300225864365552, 'f1': 0.5102807398572268, 'auc': 0.8877119446522322, 'mcc': 0.5294666106367614}

test - Batch 4/114 metrics: {'accuracy': 0.9403683315585174, 'precision': 0.369614054572532, 'recall': 0.8394290300389818, 'f1': 0.5132402166102942, 'auc': 0.8918623875782199, 'mcc': 0.5334084101768152}

test - Batch 5/114 metrics: {'accuracy': 0.9400765476285562, 'precision': 0.37219051467245823, 'recall': 0.8356296422294041, 'f1': 0.514999563204333, 'auc': 0.8899200984461443, 'mcc': 0.5337721026971387}

对于软投票策略,这个过程也将重复。在为软投票策略的每个训练和测试批次再次漫长等待后,您应该会得到所有批次的训练和测试指标的平均值。

推理

最后,我们可以像下面的代码一样,对感兴趣的蛋白质进行推理。这部分代码可以独立于本文中的其他代码运行,并且应该只需要几秒钟。

from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer
from datasets import Dataset
from peft import PeftModel
import numpy as np
from scipy import stats

# ESM-2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Paths to the saved LoRA models
lora_model_paths = [
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
    # add paths to other models
]

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the models
models = [PeftModel.from_pretrained(base_model, path) for path in lora_model_paths]

# Define the new protein sequence
new_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"

# Step 1 and 2: Tokenization and Dataset creation
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
tokenized_inputs = tokenizer(new_sequence, return_tensors="pt", truncation=True, padding=True, is_split_into_words=False)
new_dataset = Dataset.from_dict({k: v for k, v in tokenized_inputs.items()})

# Step 3: Create trainer objects for each model in the ensemble
data_collator = DataCollatorForTokenClassification(tokenizer)
trainers = [Trainer(model=model, data_collator=data_collator) for model in models]

# Step 4: Getting predictions from each model and applying voting strategies
all_predictions = [trainer.predict(test_dataset=new_dataset)[0] for trainer in trainers]

# Hard voting
hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
ensemble_predictions_hard = stats.mode(hard_predictions, axis=0)[0][0]

# Soft voting
avg_predictions = np.mean(all_predictions, axis=0)
ensemble_predictions_soft = np.argmax(avg_predictions, axis=2)

# Print the final predictions obtained using hard and soft voting
print("Hard voting predictions:", ensemble_predictions_hard)
print("Soft voting predictions:", ensemble_predictions_soft)

这将打印出类似以下内容:

Hard voting predictions: [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 1 1 0 0 1 1 1 1 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0]
Soft voting predictions: [[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 1 1 0 0 1 1 1 1 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0]]

在这里,1 表示集成模型预测的结合位点,0 表示集成模型预测的非结合位点。接下来,为了得到更适合为您的蛋白质设计结合配体的信息,请运行以下代码:

# Convert token IDs back to amino acid residues
residues = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][0])

# Print the amino acid residues and their positions for binding sites using hard voting
binding_sites_hard = [(idx, residue) for idx, (label, residue) in enumerate(zip(ensemble_predictions_hard[0], residues)) if label == 1]
print("Binding sites (Hard voting):")
for position, residue in binding_sites_hard:
    print(f"{residue}{position}")

# Print the amino acid residues and their positions for binding sites using soft voting
binding_sites_soft = [(idx, residue) for idx, (label, residue) in enumerate(zip(ensemble_predictions_soft[0], residues)) if label == 1]
print("\nBinding sites (Soft voting):")
for position, residue in binding_sites_soft:
    print(f"{residue}{position}")

这将打印出类似以下内容:

Binding sites (Hard voting):
P8
N9
H10
I12
Y13
I14
N15
N16
L17
N18
E19
K20
K22
F34
G38
L41
L44
V45
S46
R47
S48
L49
K50
M51
R52
G53
Q54
A55
F59
Q73
G74
Y78
D79
K80
P81
M82
I84
Q85
Y86
A87
K88
T89
D90

Binding sites (Soft voting):
P8
N9
H10
I12
Y13
I14
N15
N16
L17
N18
E19
K20
K22
F34
G38
L41
L44
V45
S46
R47
S48
L49
K50
M51
R52
G53
Q54
A55
F59
Q73
G74
Y78
D79
K80
P81
M82
I84
Q85
Y86
A87
K88
T89
D90

使用 RFDiffusion 为您的蛋白质设计结合物

RFDiffusion 是一个生成 3D 蛋白质结构的扩散模型。这在概念上类似于 Stable Diffusion 和 Dall-E 等扩散模型,但它是针对蛋白质的。它的架构与 Stable Diffusion 不同(使用 RosettaFold 作为骨干模型,而不是 Stable Diffusion 中使用的 UNet)。

一旦您获得了结合位点的预测,您应该前往 RFDiffusion Notebook,并使用模型预测的结合位点的某个子集作为结合物的“热点”来为您的蛋白质设计一个结合物。您首先需要一个蛋白质的 PDB 文件。要获取一个,请前往 ESM Metagenomic Atlas 网站上的 ESMFold 工具。选择“Fold Sequence”,然后粘贴您的蛋白质序列以进行折叠,并按回车。一旦您的蛋白质被折叠,您应该会得到一个 3D 结构

image/png

现在您可以下载您的 PDB 文件了。下载后,将其上传到 RFDiffusion Google Colab notebook,并在 RFDiffusion notebook 中使用您上传的 PDB 文件的路径来为您的蛋白质设计一个结合物。使用以下设置

%%time
#@title run **RFdiffusion** to generate a backbone
name = "test" #@param {type:"string"}
contigs = "100" #@param {type:"string"}
pdb = "/content/unnamed.pdb" #@param {type:"string"}
iterations = 50 #@param ["25", "50", "100", "150", "200"] {type:"raw"}
hotspot = "A41,A44,A45,A46" #@param {type:"string"}
num_designs = 1 #@param ["1", "2", "4", "8", "16", "32"] {type:"raw"}
visual = "interactive" #@param ["none", "image", "interactive"]
#@markdown ---
#@markdown **symmetry** settings
#@markdown ---
symmetry = "cyclic" #@param ["none", "auto", "cyclic", "dihedral"]
order = 3 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"] {type:"raw"}
chains = "" #@param {type:"string"}
add_potential = True #@param {type:"boolean"}
#@markdown - `symmetry='auto'` enables automatic symmetry dectection with [AnAnaS](https://team.inria.fr/nano-d/software/ananas/).
#@markdown - `chains="A,B"` filter PDB input to these chains (may help auto-symm detector)
#@markdown - `add_potential` to discourage clashes between chains

# determine where to save
path = name
while os.path.exists(f"outputs/{path}_0.pdb"):
  path = name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))

flags = {"contigs":contigs,
         "pdb":pdb,
         "order":order,
         "iterations":iterations,
         "symmetry":symmetry,
         "hotspot":hotspot,
         "path":path,
         "chains":chains,
         "add_potential":add_potential,
         "num_designs":num_designs,
         "visual":visual}

for k,v in flags.items():
  if isinstance(v,str):
    flags[k] = v.replace("'","").replace('"','')

contigs, copies = run_diffusion(**flags)

您将得到一个像下面这样的环状蛋白质

image/png

您可以运行 RFDiffusion Colab notebook 的其余部分,以获得一个能够折叠成您生成的结构的序列并进行验证。就是这样!您已成功设计出一个被预测能够与您感兴趣的蛋白质沿着“热点”结合的蛋白质,也就是沿着由 ESMBind 模型或模型集成预测的结合位点子集给出的感兴趣位点。请务必阅读 RFDiffusion Github 上链接的 RFDiffusion 论文,并通过给他们的 Github 点赞来向 RFDiffusion 的开发者们表达支持。他们构建了一个了不起的蛋白质扩散模型!您还可以在 Neurosnap 上与更多蛋白质相关模型(包括 RFDiffusion)互动。

社区

注册登录 发表评论