MatCha
概述
MatCha 由 Fangyu Liu、Francesco Piccinno、Syrine Krichene、Chenxi Pang、Kenton Lee、Mandar Joshi、Yasemin Altun、Nigel Collier 和 Julian Martin Eisenschlos 在论文 MatCha: Enhancing Visual Language Pretraining with Math Reasoning and Chart Derendering 中提出。
论文的摘要如下:
人类世界中普遍存在着诸如图表、图和信息图表之类的视觉语言数据。然而,最先进的视觉语言模型在这些数据上表现不佳。我们提出了 MatCha(数学推理和图表去渲染预训练),以增强视觉语言模型在联合建模图表/图和语言数据方面的能力。具体来说,我们提出了一些预训练任务,涵盖了图表分解和数值推理,这是视觉语言建模的关键能力。我们从 Pix2Struct 开始进行 MatCha 预训练,Pix2Struct 是最近提出的一个图像到文本的视觉语言模型。在 PlotQA 和 ChartQA 等标准基准上,MatCha 模型的性能优于最先进的方法,提高幅度高达 20%。我们还检查了 MatCha 预训练如何迁移到诸如截图、教科书图表和文档图片等领域,并观察到整体改进,验证了 MatCha 预训练在更广泛的视觉语言任务中的实用性。
模型描述
MatCha 是一个使用 `Pix2Struct` 架构训练的模型。您可以在 Pix2Struct 文档 中找到有关 `Pix2Struct` 的更多信息。MatCha 是 `Pix2Struct` 架构的视觉问答子集。它会在图像上渲染输入问题并预测答案。
使用
目前 MatCha 可用 6 个检查点
google/matcha
:基础 MatCha 模型,用于在下游任务中微调 MatChagoogle/matcha-chartqa
:在 ChartQA 数据集上微调的 MatCha 模型。它可用于回答有关图表的问题。google/matcha-plotqa-v1
:在 PlotQA 数据集上微调的 MatCha 模型。它可用于回答有关图的问题。google/matcha-plotqa-v2
:在 PlotQA 数据集上微调的 MatCha 模型。它可用于回答有关图的问题。google/matcha-chart2text-statista
:在 Statista 数据集上微调的 MatCha 模型。google/matcha-chart2text-pew
:在 Pew 数据集上微调的 MatCha 模型。
在 `chart2text-pew` 和 `chart2text-statista` 上微调的模型更适合进行摘要,而在 `plotqa` 和 `chartqa` 上微调的模型更适合进行问答。
您可以按照以下方式使用这些模型(在 ChatQA 数据集上的示例):
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
import requests
from PIL import Image
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-chartqa").to(0)
processor = AutoProcessor.from_pretrained("google/matcha-chartqa")
url = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/20294671002019.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Is the sum of all 4 places greater than Laos?", return_tensors="pt").to(0)
predictions = model.generate(**inputs, max_new_tokens=512)
print(processor.decode(predictions[0], skip_special_tokens=True))
微调
要微调 MatCha,请参考 pix2struct 微调笔记本。对于 `Pix2Struct` 模型,我们发现使用 Adafactor 和余弦学习率调度程序微调模型会导致更快的收敛。
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
optimizer = Adafactor(self.parameters(), scale_parameter=False, relative_step=False, lr=0.01, weight_decay=1e-05)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=40000)
MatCha 是一个使用 `Pix2Struct` 架构训练的模型。您可以在 Pix2Struct 文档 中找到有关 `Pix2Struct` 的更多信息。