ESM-2 的 QLoRA 和翻译后修饰位点预测

社区文章 发布于 2023 年 11 月 11 日

在这篇文章中,我们将向您展示如何使用来自 UniProt 的翻译后修饰位点数据来训练您自己的 ESM-2 QLoRA 模型,该任务被视为一个二元标记分类任务。我们将首先介绍如何从 UniProt 收集数据并根据 UniProt 家族创建训练/测试数据集分割。这将有助于避免由于标准随机训练/测试分割中可能出现的序列相似性导致的过拟合。一旦我们创建了训练和测试数据集,我们将向您展示如何对蛋白质语言模型 ESM-2 进行 QLoRA 微调,以预测蛋白质序列中可能发生翻译后修饰的位置。

image/png

什么是 ESM-2?

ESM-2(进化尺度建模)等蛋白质语言模型代表了计算生物学的一项重大进步。ESM-2 是一种深度学习模型,旨在理解蛋白质的“语言”,即控制蛋白质中氨基酸序列的结构和功能的模式和规则,这有点类似于 ChatGPT 理解人类语言的方式,但其目标是掩码语言建模而不是自回归(因果语言模型)目标,这更适合某些与蛋白质相关的任务。该模型可以通过将问题视为二元标记分类任务来微调,以预测一般的翻译后修饰位点,其中蛋白质序列中的每个氨基酸都被视为一个标记。

微调过程涉及在已知 PTM 位点的数据集上训练模型,使模型能够学习与这些修饰相关的上下文模式。通过这样做,ESM-2 可以预测新颖的、未见的蛋白质序列中的每个氨基酸(标记)是否可能发生特定修饰。这种二元分类对于识别蛋白质中潜在的 PTM 位点至关重要,这有助于更详细地理解蛋白质功能和调控。

翻译后修饰 (PTM) 简介

蛋白质的翻译后修饰 (PTM) 是细胞生物学的一个关键方面,它显著影响蛋白质的功能和调控。PTM 是指蛋白质合成后的化学修饰。这些修饰通常发生在蛋白质在核糖体合成(蛋白质以氨基酸线性链的形式生成)之后。PTM 的最常见形式包括磷酸化、糖基化、泛素化、亚硝化、甲基化、乙酰化、脂化和蛋白水解切割。

PTM 的重要性在于它们能够使蛋白质功能多样化,超越基因序列本身的规定。它们在调节蛋白质活性、稳定性、定位以及与其他细胞分子的相互作用中发挥着至关重要的作用。PTM 可以改变蛋白质的物理和化学性质,从而影响它们的折叠、构象、分布以及与其他蛋白质和 DNA 的相互作用。这对于无数的细胞过程至关重要,包括信号转导、细胞周期控制、代谢途径和免疫反应。

PTM 用于各种生物学和医学应用。在药物发现和开发中,了解 PTM 可以导致新药靶点和治疗策略的识别。此外,异常 PTM 通常与癌症、神经退行性疾病和代谢疾病等疾病相关,这使得它们成为诊断和治疗的潜在生物标志物。

数据整理和预处理

首先,前往 UniProt,并在搜索栏中选择“高级”。接下来,当出现高级搜索选项时,选择“PTM/处理”,然后选择“修饰残基”。在搜索字段中输入*(删除所有额外的搜索字段后),然后选择“搜索”。完成此操作后,您将获得一个包含修饰氨基酸残基的蛋白质列表。您可以通过在表格视图中选择“自定义列”来定制表格布局以反映此内容。您应该自定义列,使其仅包含蛋白质序列、蛋白质家族和修饰残基。接下来,下载此数据,确保仅包含蛋白质序列、“蛋白质家族”和修饰残基。请务必包含蛋白质家族,因为这将用于创建训练/测试数据集分割。将此文件下载为包含这些列的 TSV 后,您可以运行以下数据预处理步骤来创建训练/测试数据集分割。

import pandas as pd

# Load the TSV file
file_path = 'PTM/uniprotkb_family_AND_ft_mod_res_AND_pro_2023_10_07.tsv'
data = pd.read_csv(file_path, sep='\t')

# Display the first few rows of the data
data.head()

这应该会打印出类似以下内容:

image/png

import re

def get_ptm_sites(row):
    # Extract the positions of modified residues from the 'Modified residue' column
    modified_positions = [int(i) for i in re.findall(r'MOD_RES (\d+)', row['Modified residue'])]
    
    # Create a list of zeros of length equal to the protein sequence
    ptm_sites = [0] * len(row['Sequence'])
    
    # Replace the zeros with ones at the positions of modified residues
    for position in modified_positions:
        # Subtracting 1 because positions are 1-indexed, but lists are 0-indexed
        ptm_sites[position - 1] = 1
    
    return ptm_sites

# Apply the function to each row in the DataFrame
data['PTM sites'] = data.apply(get_ptm_sites, axis=1)

# Display the first few rows of the updated DataFrame
data.head()

此下一个单元格将把较长的蛋白质序列及其标签分割成长度为 512 或更短的非重叠块,以适应较小的 ESM-2 模型的 1024 上下文窗口。如果您愿意,可以将其调整为更长的长度。大多数蛋白质序列平均约 350 个残基,因此更长的上下文窗口通常不必要,尽管我们观察到 1000 个上下文窗口的性能更好。请记住,这会影响训练时间和批次大小。

# Function to split sequences and PTM sites into chunks
def split_into_chunks(row):
    sequence = row['Sequence']
    ptm_sites = row['PTM sites']
    chunk_size = 512
    
    # Calculate the number of chunks
    num_chunks = (len(sequence) + chunk_size - 1) // chunk_size
    
    # Split sequences and PTM sites into chunks
    sequence_chunks = [sequence[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)]
    ptm_sites_chunks = [ptm_sites[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)]
    
    # Create new rows for each chunk
    rows = []
    for i in range(num_chunks):
        new_row = row.copy()
        new_row['Sequence'] = sequence_chunks[i]
        new_row['PTM sites'] = ptm_sites_chunks[i]
        rows.append(new_row)
    
    return rows

# Create a new DataFrame to store the chunks
chunks_data = []

# Iterate through each row of the original DataFrame and split into chunks
for _, row in data.iterrows():
    chunks_data.extend(split_into_chunks(row))

# Convert the list of chunks into a DataFrame
chunks_df = pd.DataFrame(chunks_data)

# Reset the index of the DataFrame
chunks_df.reset_index(drop=True, inplace=True)

# Display the first few rows of the new DataFrame
chunks_df.head()

接下来,我们根据 UniProt 家族创建训练/测试分割。

from tqdm import tqdm
import numpy as np

# Function to split data into train and test based on families
def split_data(df):
    # Get a unique list of protein families
    unique_families = df['Protein families'].unique().tolist()
    np.random.shuffle(unique_families)  # Shuffle the list to randomize the order of families
    
    test_data = []
    test_families = []
    total_entries = len(df)
    total_families = len(unique_families)
    
    # Set up tqdm progress bar
    with tqdm(total=total_families) as pbar:
        for family in unique_families:
            # Separate out all proteins in the current family into the test data
            family_data = df[df['Protein families'] == family]
            test_data.append(family_data)
            
            # Update the list of test families
            test_families.append(family)
            
            # Remove the current family data from the original DataFrame
            df = df[df['Protein families'] != family]
            
            # Calculate the percentage of test data and the percentage of families in the test data
            percent_test_data = sum(len(data) for data in test_data) / total_entries * 100
            percent_test_families = len(test_families) / total_families * 100
            
            # Update tqdm progress bar with readout of percentages
            pbar.set_description(f'% Test Data: {percent_test_data:.2f}% | % Test Families: {percent_test_families:.2f}%')
            pbar.update(1)
            
            # Check if the 20% threshold for test data is crossed
            if percent_test_data >= 20:
                break
    
    # Concatenate the list of test data DataFrames into a single DataFrame
    test_df = pd.concat(test_data, ignore_index=True)
    
    return df, test_df  # Return the remaining data and the test data

# Split the data into train and test based on families
train_df, test_df = split_data(chunks_df)

如果您想在保持训练/测试分割的同时减小数据集的大小,可以将下面的百分比调整到小于 100%

import pandas as pd

# Assuming train_df and test_df are your dataframes
fraction = 1.00  # 100.0%

# Randomly select 100% of the data
reduced_train_df = train_df.sample(frac=fraction, random_state=42)
reduced_test_df = test_df.sample(frac=fraction, random_state=42)
import pickle 

# Extract sequences and PTM site labels from the reduced train and test DataFrames
train_sequences_reduced = reduced_train_df['Sequence'].tolist()
train_labels_reduced = reduced_train_df['PTM sites'].tolist()
test_sequences_reduced = reduced_test_df['Sequence'].tolist()
test_labels_reduced = reduced_test_df['PTM sites'].tolist()

# Save the lists to the specified pickle files
pickle_file_path = "2100K_ptm_data_512/"

with open(pickle_file_path + "train_sequences_chunked_by_family.pkl", "wb") as f:
    pickle.dump(train_sequences_reduced, f)

with open(pickle_file_path + "test_sequences_chunked_by_family.pkl", "wb") as f:
    pickle.dump(test_sequences_reduced, f)

with open(pickle_file_path + "train_labels_chunked_by_family.pkl", "wb") as f:
    pickle.dump(train_labels_reduced, f)

with open(pickle_file_path + "test_labels_chunked_by_family.pkl", "wb") as f:
    pickle.dump(test_labels_reduced, f)

# Return the paths to the saved pickle files
saved_files = [
    pickle_file_path + "train_sequences_chunked_by_family.pkl",
    pickle_file_path + "test_sequences_chunked_by_family.pkl",
    pickle_file_path + "train_labels_chunked_by_family.pkl",
    pickle_file_path + "test_labels_chunked_by_family.pkl"
]
saved_files

训练 QLoRA

导入库和模块

第一个单元格导入必要的库和模块

  • oswandb 用于环境和实验跟踪。
  • numpytorch 用于数值和张量运算。
  • transformersdatasetsaccelerate 中的各种模块,用于处理标记分类和模型加速。
  • peft 用于 PEFT(参数高效微调)配置。
  • pickle 用于加载数据集。
import os
import wandb
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import pickle

初始化加速器和 Weights & Biases

第二个单元格设置 Accelerator,用于在可用硬件上高效训练,并初始化 Weights & Biases (W&B) 平台以进行实验跟踪。

# Initialize accelerator and Weights & Biases
accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'qlora_ptm_v2.py'
wandb.init(project='ptm_site_prediction')

辅助函数和数据准备

第三个单元格定义了几个辅助函数

  • print_trainable_parameters:显示可训练参数的数量。
  • save_config_to_txt:将模型配置保存为文本文件。
  • truncate_labels:截断长度超过最大长度的序列的标签。
  • compute_metrics:计算评估指标,如准确率、精确率、召回率、F1 分数、AUC 和 MCC。
  • compute_loss:考虑类别权重的自定义损失计算。
# Helper Functions and Data Preparation
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

def save_config_to_txt(config, filename):
    """Save the configuration dictionary to a text file."""
    with open(filename, 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

def truncate_labels(labels, max_length):
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, logits, inputs):
    # logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

加载数据

第四个单元格从 pickle 文件加载训练和测试数据集,确保数据已准备好进行处理和模型训练。

# Load data from pickle files
with open("2100K_ptm_data/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
    
with open("2100K_ptm_data/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("2100K_ptm_data/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("2100K_ptm_data/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

分词

第五个单元格涉及使用 ESM-2 模型的 AutoTokenizer 对蛋白质序列进行分词。此过程将序列转换为适合模型的格式,同时考虑填充、截断和最大序列长度等方面。

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")

# Set max_sequence_length to the tokenizer's max input length
max_sequence_length = 1024

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

创建数据集

第六个单元格创建用于训练和测试的 Dataset 对象,其中包含分词数据和相应的标签。

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

计算类别权重

第七个单元格计算类别权重以解决类别不平衡问题,这对于二元分类任务中平衡的训练过程至关重要。由于 PTM 位点比非 PTM 位点少得多,我们需要这样做以确保模型不会只学习预测多数类,从而获得高准确率并止步于此。

# Compute Class Weights
classes = [0, 1]  
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)

定义自定义训练器类

第八个单元格引入了一个自定义 Trainer 类,用于在模型训练期间合并加权损失函数。

# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss

配置量化设置

第九个单元格设置模型的量化设置,这有助于减小模型大小并提高推理效率。

# Configure the quantization settings
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

无扫描的训练函数

第十个单元格定义了主要的训练函数

  • 设置模型配置并将其记录到 W&B 中。
  • 初始化用于标记分类的 ESM-2 模型,并带有特定标签并应用量化。
  • 准备模型进行 PEFT 和 4 位量化训练。
  • 配置训练参数,如学习率、批次大小、时期等。
  • 初始化自定义的 WeightedTrainer
  • 执行训练过程并保存模型。
def train_function_no_sweeps(train_dataset, test_dataset):
    
    # Directly set the config
    config = {
        "lora_alpha": 1, 
        "lora_dropout": 0.5,
        "lr": 3.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 36,
        "r": 2,
        "weight_decay": 0.3,
        # Add other hyperparameters as needed
    }

    # Log the config to W&B
    wandb.config.update(config)

    # Save the config to a text file
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t30_150M_qlora_ptm_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)
    
        
    model_checkpoint = "facebook/esm2_t30_150M_UR50D"  
    
    # Define labels and model
    id2label = {0: "No ptm site", 1: "ptm site"}
    label2id = {v: k for k, v in id2label.items()}
    
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        quantization_config=bnb_config  # Apply quantization here
    )

    # Prepare the model for 4-bit quantization training
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    
    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=[
            "query",
            "key",
            "value",
            "EsmSelfOutput.dense",
            "EsmIntermediate.dense",
            "EsmOutput.dense",
            "EsmContactPredictionHead.regression",
            "classifier"
        ],
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
        # modules_to_save=["classifier"]
    )
    model = get_peft_model(model, peft_config)
    print_trainable_parameters(model) # added this in

    # Use the accelerator
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t30_150M_qlora_ptm_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1, # changed from 1 to 4
        # warmup_steps=2, # added this in 
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=3,
        no_cuda=False,
        seed=8893,
        fp16=True,
        report_to='wandb', 
        optim="paged_adamw_8bit" # added this in 

    )
    
    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("qlora_ptm_sites", f"best_model_esm2_t30_150M_qlora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

主执行

最后一个单元格是训练脚本的入口点,它使用准备好的数据集调用训练函数。

# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset, test_dataset)

结论

本笔记本展示了一种复杂的方法,利用最先进的蛋白质语言模型在生物化学中预测蛋白质序列的翻译后修饰位点。通过微调蛋白质语言模型 ESM-2,本笔记本将使您能够将深度学习集成到蛋白质生物信息学中,为理解蛋白质功能和相互作用的更高级研究铺平道路。一旦您训练好用于预测翻译后修饰的新 ESM-2 模型,请务必将其上传到 Hugging Face 并分享!

要试用此模型的一个版本,请访问 neurosnap,查看他们提供的各种模型,或者访问 Hugging Face 集合 ESM-PTM。您还可以阅读最近发布的应用 LoRA 到蛋白质语言模型的研究 通过参数高效微调实现蛋白质语言模型的民主化,以及 探索蛋白质语言模型的训练后量化

社区

注册登录 发表评论