Arc 虚拟细胞挑战赛:入门指南
Arc 研究所 (Arc Institute) 最近推出了虚拟细胞挑战赛。参赛者需要训练一个模型,该模型能够预测在(部分)未见过的细胞类型中沉默某个基因的效果,他们将这项任务称为情境泛化。对于几乎没有生物学背景的机器学习工程师来说,其中的术语和所需的背景知识可能相当令人生畏。为了鼓励参与,我们以更适合其他学科工程师理解的形式重新阐述了这项挑战。
目标
训练一个模型,以预测使用 CRISPR 技术沉默一个基因对细胞产生的影响。
在原子世界中进行实验成本高昂、耗时费力且容易出错。如果我们能够在不接触培养皿的情况下测试数千种候选药物,那会怎么样?这就是虚拟细胞挑战赛的目标——建立一个模型(很可能是一个神经网络),能够精确模拟当我们改变某个参数时细胞会发生什么。鉴于缩短反馈循环通常是加速进展的最佳方式,一个能够准确完成这项任务的模型将产生重大影响。
要训练这个神经网络,我们需要数据。针对这项挑战,Arc 整理了一个包含约 30 万个单细胞 RNA 测序图谱的数据集。在继续之前,重温中心法则可能会有所帮助。本文将以此为基础,提供你参加挑战赛所需的最少生物学知识。
训练数据
训练集由一个稀疏矩阵和一些相关的元数据组成。具体来说,我们有 22 万个细胞,每个细胞都有一个转录组。这个转录组是一个稀疏行向量,其中每个条目是相应基因(我们的列)编码的 RNA 分子(转录本)的原始计数。在这 22 万个细胞中,约有 3.8 万个是未扰动的,意味着没有基因被 CRISPR 技术沉默。这些对照细胞至关重要,我们很快就会看到原因。
为了更具体地理解数据集,让我们选择一个基因 TMSB4X(数据集中被沉默最频繁的基因),并比较一个对照细胞和一个受扰动细胞中检测到的 RNA 分子数量。
我们可以看到,与对照细胞相比,TMSB4X 基因被沉默的细胞的转录本数量大大减少了。
为挑战赛建模
敏锐的读者可能会想,为什么不直接测量基因沉默前后的 RNA 分子数量——为什么我们还需要对照细胞呢?不幸的是,读取转录组会破坏细胞,这个问题让人想起了观察者效应。
这种无法在扰动前后测量细胞状态的情况带来了许多问题,因为我们被迫使用一群基线(也称为对照、未扰动)细胞作为参考点。即使在扰动之前,对照细胞和受扰动细胞也并非完全同质。这意味着我们现在必须从异质性引起的噪声中分离出我们的真实信号,即扰动的影响。
更正式地,我们可以将被扰动细胞中观察到的基因表达建模为
其中
- : 在受到扰动 的细胞中观察到的基因表达测量值。
- : 未受扰动的基线细胞群体的分布。
- : 扰动 对该群体造成的真实效应。
- : 基线群体的生物学异质性。
- : 特定于实验的技术噪声,假定其独立于未受扰动的细胞状态和 。
STATE:Arc 提供的基线模型
在虚拟细胞挑战赛之前,Arc 发布了 STATE,这是他们自己尝试用一对基于 Transformer 的模型来解决这个挑战。这为参赛者提供了一个强大的起点,因此我们将详细探讨它。
STATE 由两个模型组成:状态转换模型 (State Transition Model, ST) 和状态嵌入模型 (State Embedding Model, SE)。SE 旨在为细胞生成丰富的语义嵌入,以提高跨细胞类型的泛化能力。ST 是“细胞模拟器”,它接收对照细胞的转录组或由 SE 生成的细胞嵌入,以及一个表示目标扰动的独热编码向量,然后输出受扰动后的转录组。
状态转换模型 (ST)
状态转换模型是一个相对简单的 Transformer,它使用 Llama 骨干网络,处理以下输入:
- 一组协变量匹配的基线细胞的转录组(或 SE 嵌入)。
- 一组表示每个细胞基因扰动的独热向量。
使用一组协变量匹配的对照细胞与配对的目标细胞,应有助于模型识别出我们预期扰动的实际效果。对照组张量和扰动张量都通过独立的编码器(即带有 GELU 激活函数的 4 层 MLP)进行处理。如果直接在基因表达空间中工作(即生成完整的转录组),它们会将输出通过一个学习到的解码器。
ST 使用最大均值差异 (Maximum Mean Discrepancy) 进行训练。简而言之,模型学习最小化两个概率分布之间的差异。
状态嵌入模型 (SE)
状态嵌入模型是一个类似 BERT 的自编码器。要更深入地理解这一点,我们首先需要补充一些生物学基础知识。
一点生物学知识补充
一个基因由外显子(蛋白质编码区)和内含子(非蛋白质编码区)组成。DNA 首先被转录成前体 mRNA,如上图所示。然后细胞进行选择性剪接。这基本上是“挑选外显子”,并剪掉所有内含子。你可以把基因想象成一份制作桌子的宜家说明书。通过省略一些部件,人们也可以造出一张三条腿的桌子,或者稍加努力还能造出一个奇怪的书架。这些不同的物体就类似于蛋白质异构体,即由同一个基因编码的不同蛋白质。
回到模型
有了这个基本理解,我们就可以继续讨论 SE 模型的工作原理了。记住,我们 SE 的核心目标是创建有意义的细胞嵌入。要做到这一点,我们必须首先创建有意义的基因嵌入。
为了生成单个基因嵌入,我们首先获取该基因编码的所有不同蛋白质异构体的氨基酸序列(例如,TMSB4X 的序列为 ...)。然后,我们将这些序列输入到 ESM2,这是一个来自 FAIR 的 150 亿参数的蛋白质语言模型。ESM 为每个氨基酸生成一个嵌入,我们将它们进行均值池化,以获得一个“转录本”(即蛋白质异构体)的嵌入。
现在我们有了所有这些蛋白质异构体的嵌入,我们只需对它们进行均值池化,就能得到基因嵌入。接下来,我们使用一个学习到的编码器将这些基因嵌入投影到我们的模型维度,如下所示:
我们现在已经得到了一个基因嵌入,但我们真正想要的是一个细胞嵌入。为此,Arc 将每个细胞表示为按对数倍数表达水平排名的前 2048 个基因。
然后,我们从我们的 2048 个基因嵌入中构建一个“细胞句子”,如下所示:
我们在句子中添加一个 标记和 标记。 标记最终被用作我们的“细胞嵌入”(非常像 BERT),而 标记用于“解耦特定于数据集的影响”。虽然基因是按对数倍数表达水平排序的,但 Arc 通过一种类似于位置嵌入的方式将转录组信息融入,从而进一步强调每个基因的表达强度。通过一个奇特的“软分箱”算法和两个 MLP,他们创建了一些“表达编码”,然后将这些编码添加到每个基因嵌入中。这应该会根据基因在转录组中的表达强度来调节每个基因嵌入的量级。
为了训练模型,他们对每个细胞遮蔽 1280 个基因,模型的任务是预测它们。这 1280 个基因的选择标准是使其具有广泛的表达强度范围。对于喜欢图形化解释的读者,下图展示了细胞句子的构建过程。
评估
了解你的提交将如何被评估是成功的关键。Arc 选择的 3 个评估指标是扰动判别、差异表达和平均绝对误差。鉴于平均绝对误差很简单,且顾名思义,我们将在分析中省略它。
扰动判别
扰动判别旨在评估你的模型在揭示不同扰动之间的相对差异方面的表现。为此,我们计算测试集中所有测得的受扰动转录组(我们试图预测的真实值,)和所有其他受扰动转录组()与我们预测的转录组 之间的曼哈顿距离。然后,我们按照以下方式对真实值相对于所有转录组的位置进行排名:
之后,我们用转录组的总数进行归一化:
其中 将是一个完美的匹配。你的预测的总体得分是所有 $$\text{PDisc}_t$$ 的平均值。然后将其归一化为:
我们乘以 2 是因为对于随机预测,大约一半的结果会更近,一半会更远。
差异表达
差异表达旨在评估你正确识别出真正受影响的基因中有多大比例被显著影响。首先,对每个基因使用带并列校正的 Wilcoxon 秩和检验来计算一个 值。我们对我们预测的扰动分布和真实扰动分布都进行此操作。
接下来,我们应用 Benjamini-Hochberg 程序,基本上是一些统计方法来调整 值,因为对于 个基因和一个 值阈值为 的情况下,你预计会有 个假阳性。我们将我们预测的差异表达基因集表示为 ,真实差异表达基因集表示为 。
如果我们的集合大小小于真实集合大小,则取集合的交集,然后除以真实差异表达基因的数量,如下所示:
如果我们的预测集大小大于真实集大小,我们选择预测中差异表达最显著的子集(我们“最确信”的预测,记作 ),取其与真实集的交集,然后除以真实集的数量。
对所有预测的扰动重复此操作,并取平均值以获得最终分数。
结论
如果这项挑战激起了您的兴趣,该如何开始呢?幸运的是,Arc 提供了一个 Colab 笔记,其中详细介绍了训练其 STATE 模型的全过程。此外,STATE 很快将登陆 transformers
库,因此,使用他们的预训练模型将会非常简单,只需
import torch
from transformers import StateEmbeddingModel
model_name = "arcinstitute/SE-600M"
model = StateEmbeddingModel.from_pretrained(model_name)
input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)
mask = torch.ones((1, 1, 5120), dtype=torch.bool)
mask[:, :, 2560:] = False
outputs = model(input_ids, mask)
祝所有参赛者好运!
本文最初发布于此。