Transformers 文档

MatCha

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

MatCha

PyTorch

概述

MatCha 在论文 MatCha: Enhancing Visual Language Pretraining with Math Reasoning and Chart Derendering 中提出,作者为 Fangyu Liu, Francesco Piccinno, Syrine Krichene, Chenxi Pang, Kenton Lee, Mandar Joshi, Yasemin Altun, Nigel Collier, Julian Martin Eisenschlos。

该论文的摘要如下:

视觉语言数据(如图表和信息图)在人类世界中无处不在。然而,最先进的视觉语言模型在这些数据上的表现不佳。我们提出了 MatCha(数学推理和图表去渲染预训练),以增强视觉语言模型在联合建模图表和语言数据方面的能力。具体来说,我们提出了几个预训练任务,涵盖图表解构和数值推理,这是视觉语言建模的关键能力。我们从最近提出的图像到文本视觉语言模型 Pix2Struct 开始执行 MatCha 预训练。在 PlotQA 和 ChartQA 等标准基准测试中,MatCha 模型优于最先进的方法,提升幅度高达近 20%。我们还研究了 MatCha 预训练如何迁移到屏幕截图、教科书图表和文档图形等领域,并观察到总体改进,验证了 MatCha 预训练在更广泛的视觉语言任务中的实用性。

模型描述

MatCha 是一个使用 Pix2Struct 架构训练的模型。您可以在 Pix2Struct 文档中找到有关 Pix2Struct 的更多信息。MatCha 是 Pix2Struct 架构的视觉问答子集。它在图像上渲染输入问题并预测答案。

使用方法

目前有 6 个 MatCha 的检查点可用:

  • google/matcha:基础 MatCha 模型,用于在下游任务上微调 MatCha
  • google/matcha-chartqa:在 ChartQA 数据集上微调的 MatCha 模型。它可用于回答有关图表的问题。
  • google/matcha-plotqa-v1:在 PlotQA 数据集上微调的 MatCha 模型。它可用于回答有关绘图的问题。
  • google/matcha-plotqa-v2:在 PlotQA 数据集上微调的 MatCha 模型。它可用于回答有关绘图的问题。
  • google/matcha-chart2text-statista:在 Statista 数据集上微调的 MatCha 模型。
  • google/matcha-chart2text-pew:在 Pew 数据集上微调的 MatCha 模型。

chart2text-pewchart2text-statista 上微调的模型更适合用于摘要,而在 plotqachartqa 上微调的模型更适合用于问答。

您可以按如下方式使用这些模型(以 ChatQA 数据集为例):

from transformers import AutoProcessor, Pix2StructForConditionalGeneration
import requests
from PIL import Image

model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-chartqa").to(0)
processor = AutoProcessor.from_pretrained("google/matcha-chartqa")
url = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/20294671002019.png"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="Is the sum of all 4 places greater than Laos?", return_tensors="pt").to(0)
predictions = model.generate(**inputs, max_new_tokens=512)
print(processor.decode(predictions[0], skip_special_tokens=True))

微调

要微调 MatCha,请参考 pix2struct 微调 notebook。对于 Pix2Struct 模型,我们发现使用 Adafactor 和余弦学习率调度器微调模型可以更快地收敛。

from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup

optimizer = Adafactor(self.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=40000)

MatCha 是一个使用 Pix2Struct 架构训练的模型。您可以在 Pix2Struct 文档中找到有关 Pix2Struct 的更多信息。

< > 在 GitHub 上更新