文档问答
文档问答,也称为文档视觉问答,是一项涉及对文档图像提出的问题提供答案的任务。支持此任务的模型的输入通常是图像和问题的组合,输出是以自然语言表达的答案。这些模型利用多种模态,包括文本、单词的位置(边界框)以及图像本身。
本指南说明如何
- 在 LayoutLMv2 上对 DocVQA 数据集 进行微调。
- 使用您微调后的模型进行推理。
要查看与此任务兼容的所有架构和检查点,我们建议您查看 任务页面
LayoutLMv2 通过在标记的最终隐藏状态之上添加一个问答头来解决文档问答任务,以预测答案的起始和结束标记的位置。换句话说,问题被视为抽取式问答:在给定上下文中,提取哪些信息可以回答问题。上下文来自 OCR 引擎的输出,这里使用的是 Google 的 Tesseract。
在开始之前,请确保您已安装所有必要的库。LayoutLMv2 依赖于 detectron2、torchvision 和 tesseract。
pip install -q transformers datasets
pip install 'git+https://github.com/facebookresearch/detectron2.git'
pip install torchvision
sudo apt install tesseract-ocr pip install -q pytesseract
安装完所有依赖项后,重新启动您的运行时。
我们鼓励您与社区分享您的模型。登录您的 Hugging Face 帐户将其上传到 🤗 Hub。出现提示时,输入您的令牌以登录
>>> from huggingface_hub import notebook_login
>>> notebook_login()
让我们定义一些全局变量。
>>> model_checkpoint = "microsoft/layoutlmv2-base-uncased"
>>> batch_size = 4
加载数据
在本指南中,我们使用了一个可以在 🤗 Hub 上找到的预处理 DocVQA 的小样本。如果您想使用完整的 DocVQA 数据集,您可以在 DocVQA 主页 上注册并下载。如果您这样做,请查看 如何将文件加载到 🤗 数据集中 以继续本指南。
>>> from datasets import load_dataset
>>> dataset = load_dataset("nielsr/docvqa_1200_examples")
>>> dataset
DatasetDict({
train: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 1000
})
test: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 200
})
})
如您所见,数据集已分成训练集和测试集。查看一个随机示例以熟悉其特征。
>>> dataset["train"].features
以下是各个字段的含义
id
:示例的 IDimage
:包含文档图像的 PIL.Image.Image 对象query
:问题字符串 - 用多种语言提出的自然语言问题answers
:人类注释者提供的正确答案列表words
和bounding_boxes
:OCR 的结果,我们这里不会使用它们answer
:由另一个模型匹配的答案,我们这里不会使用它
让我们只保留英文问题,并删除似乎包含其他模型预测的 answer
特征。我们还将从注释者提供的集合中获取第一个答案。或者,您可以随机抽取它。
>>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"])
>>> updated_dataset = updated_dataset.map(
... lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"]
... )
请注意,我们在此指南中使用的 LayoutLMv2 检查点已使用 max_position_embeddings = 512
进行训练(您可以在 检查点的 config.json
文件 中找到此信息)。我们可以截断示例,但为了避免答案可能位于大型文档末尾并最终被截断的情况,这里我们将删除嵌入可能最终长于 512 的少数示例。如果您的数据集中大多数文档都很长,您可以实现滑动窗口策略 - 查看 此笔记本 以获取详细信息。
>>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512)
在这一点上,让我们也从该数据集中删除 OCR 特征。这些是为微调不同模型而进行的 OCR 结果。如果我们想使用它们,它们仍然需要一些处理,因为它们与我们在此指南中使用的模型的输入要求不匹配。相反,我们可以对原始数据使用 LayoutLMv2Processor 来进行 OCR 和标记化。这样,我们将获得与模型预期输入匹配的输入。如果您想手动处理图像,请查看 LayoutLMv2
模型文档 以了解模型期望的输入格式。
>>> updated_dataset = updated_dataset.remove_columns("words")
>>> updated_dataset = updated_dataset.remove_columns("bounding_boxes")
最后,如果我们不查看图像示例,数据探索将不完整。
>>> updated_dataset["train"][11]["image"]
预处理数据
文档问答任务是一个多模态任务,您需要确保来自每个模态的输入都根据模型的期望进行预处理。让我们首先加载 LayoutLMv2Processor,它在内部结合了一个可以处理图像数据的图像处理器和一个可以编码文本数据的标记器。
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
预处理文档图像
首先,让我们借助处理器中的 image_processor
为模型准备文档图像。默认情况下,图像处理器将图像大小调整为 224x224,确保它们具有正确的颜色通道顺序,并使用 tesseract 应用 OCR 以获取单词和归一化边界框。在本教程中,所有这些默认设置正是我们需要的。编写一个函数,将默认图像处理应用于一批图像并返回 OCR 的结果。
>>> image_processor = processor.image_processor
>>> def get_ocr_words_and_boxes(examples):
... images = [image.convert("RGB") for image in examples["image"]]
... encoded_inputs = image_processor(images)
... examples["image"] = encoded_inputs.pixel_values
... examples["words"] = encoded_inputs.words
... examples["boxes"] = encoded_inputs.boxes
... return examples
要以快速的方式将此预处理应用于整个数据集,请使用 map。
>>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)
预处理文本数据
在对图像应用 OCR 后,我们需要对数据集的文本部分进行编码以将其准备用于模型。这涉及将我们在上一步中获得的单词和框转换为标记级别的 input_ids
、attention_mask
、token_type_ids
和 bbox
。对于预处理文本,我们将需要处理器中的 tokenizer
。
>>> tokenizer = processor.tokenizer
除了上面提到的预处理之外,我们还需要为模型添加标签。对于 🤗 Transformers 中的 xxxForQuestionAnswering
模型,标签由 start_positions
和 end_positions
组成,指示哪个标记是答案的开头,哪个标记是答案的结尾。
让我们从这里开始。定义一个可以找到较大列表(单词列表)中子列表(答案拆分为单词)的辅助函数。
此函数将以两个列表作为输入,words_list
和 answer_list
。然后它将遍历 words_list
并检查 words_list
中的当前单词 (words_list[i]) 是否等于 answer_list 的第一个单词 (answer_list[0]),以及从当前单词开始且与 answer_list
长度相同的 words_list
的子列表是否等于 answer_list
。如果此条件为真,则表示已找到匹配项,并且函数将记录匹配项、其起始索引 (idx) 及其结束索引 (idx + len(answer_list) - 1)。如果找到多个匹配项,则函数将只返回第一个匹配项。如果未找到匹配项,则函数返回 (None
、0 和 0)。
>>> def subfinder(words_list, answer_list):
... matches = []
... start_indices = []
... end_indices = []
... for idx, i in enumerate(range(len(words_list))):
... if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list:
... matches.append(answer_list)
... start_indices.append(idx)
... end_indices.append(idx + len(answer_list) - 1)
... if matches:
... return matches[0], start_indices[0], end_indices[0]
... else:
... return None, 0, 0
为了说明此函数如何查找答案的位置,让我们在一个示例上使用它
>>> example = dataset_with_ocr["train"][1]
>>> words = [word.lower() for word in example["words"]]
>>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split())
>>> print("Question: ", example["question"])
>>> print("Words:", words)
>>> print("Answer: ", example["answer"])
>>> print("start_index", word_idx_start)
>>> print("end_index", word_idx_end)
Question: Who is in cc in this letter?
Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', '«short', 'cigarette,', 'tobacco', 'section', '30', 'mm.', '«extremely', 'fast', 'buming', 'cigarette.', '«novel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', '«more', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498']
Answer: T.F. Riehl
start_index 17
end_index 18
但是,一旦示例被编码,它们将如下所示
>>> encoding = tokenizer(example["question"], example["words"], example["boxes"])
>>> tokenizer.decode(encoding["input_ids"])
[CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ...
我们需要在编码后的输入中找到答案的位置。
token_type_ids
告诉我们哪些标记是问题的一部分,哪些标记是文档单词的一部分。tokenizer.cls_token_id
将帮助查找输入开头的特殊标记。word_ids
将帮助将原始words
中找到的答案与完整编码输入中的相同答案匹配,并确定答案在编码输入中的起始/结束位置。
考虑到这一点,让我们创建一个函数来编码数据集中的一批示例
>>> def encode_dataset(examples, max_length=512):
... questions = examples["question"]
... words = examples["words"]
... boxes = examples["boxes"]
... answers = examples["answer"]
... # encode the batch of examples and initialize the start_positions and end_positions
... encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
... start_positions = []
... end_positions = []
... # loop through the examples in the batch
... for i in range(len(questions)):
... cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)
... # find the position of the answer in example's words
... words_example = [word.lower() for word in words[i]]
... answer = answers[i]
... match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
... if match:
... # if match is found, use `token_type_ids` to find where words start in the encoding
... token_type_ids = encoding["token_type_ids"][i]
... token_start_index = 0
... while token_type_ids[token_start_index] != 1:
... token_start_index += 1
... token_end_index = len(encoding["input_ids"][i]) - 1
... while token_type_ids[token_end_index] != 1:
... token_end_index -= 1
... word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]
... start_position = cls_index
... end_position = cls_index
... # loop over word_ids and increase `token_start_index` until it matches the answer position in words
... # once it matches, save the `token_start_index` as the `start_position` of the answer in the encoding
... for id in word_ids:
... if id == word_idx_start:
... start_position = token_start_index
... else:
... token_start_index += 1
... # similarly loop over `word_ids` starting from the end to find the `end_position` of the answer
... for id in word_ids[::-1]:
... if id == word_idx_end:
... end_position = token_end_index
... else:
... token_end_index -= 1
... start_positions.append(start_position)
... end_positions.append(end_position)
... else:
... start_positions.append(cls_index)
... end_positions.append(cls_index)
... encoding["image"] = examples["image"]
... encoding["start_positions"] = start_positions
... encoding["end_positions"] = end_positions
... return encoding
现在我们有了这个预处理函数,我们可以对整个数据集进行编码
>>> encoded_train_dataset = dataset_with_ocr["train"].map(
... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names
... )
>>> encoded_test_dataset = dataset_with_ocr["test"].map(
... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names
... )
让我们检查编码数据集的特征是什么样的
>>> encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
'start_positions': Value(dtype='int64', id=None),
'end_positions': Value(dtype='int64', id=None)}
评估
文档问答的评估需要大量的后处理。为了避免占用您过多的时间,本指南跳过了评估步骤。在训练过程中,Trainer 仍然会计算评估损失,因此您不会完全不清楚模型的性能。抽取式问答通常使用 F1/精确匹配进行评估。如果您想自己实现它,可以查看 Hugging Face 课程的问答章节 以获取灵感。
训练
恭喜!您已成功完成了本指南中最困难的部分,现在您已准备好训练自己的模型。训练包括以下步骤
- 使用 AutoModelForDocumentQuestionAnswering 加载模型,使用与预处理中相同的检查点。
- 在 TrainingArguments 中定义您的训练超参数。
- 定义一个函数将示例批量组合在一起,这里 DefaultDataCollator 就足够了。
- 将训练参数传递给 Trainer 以及模型、数据集和数据整理器。
- 调用 train() 来微调您的模型。
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)
在 TrainingArguments 中使用 output_dir
指定保存模型的位置,并根据需要配置超参数。如果您希望与社区共享您的模型,请将 push_to_hub
设置为 True
(您必须登录 Hugging Face 才能上传您的模型)。在这种情况下,output_dir
也会成为推送模型检查点的仓库名称。
>>> from transformers import TrainingArguments
>>> # REPLACE THIS WITH YOUR REPO ID
>>> repo_id = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa"
>>> training_args = TrainingArguments(
... output_dir=repo_id,
... per_device_train_batch_size=4,
... num_train_epochs=20,
... save_steps=200,
... logging_steps=50,
... eval_strategy="steps",
... learning_rate=5e-5,
... save_total_limit=2,
... remove_unused_columns=False,
... push_to_hub=True,
... )
定义一个简单的 数据整理器来将示例批量组合在一起。
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
最后,将所有内容整合在一起,并调用 train()
>>> from transformers import Trainer
>>> trainer = Trainer(
... model=model,
... args=training_args,
... data_collator=data_collator,
... train_dataset=encoded_train_dataset,
... eval_dataset=encoded_test_dataset,
... tokenizer=processor,
... )
>>> trainer.train()
要将最终模型添加到 🤗 Hub,请创建一个模型卡片并调用 push_to_hub
>>> trainer.create_model_card()
>>> trainer.push_to_hub()
推理
现在您已经微调了一个 LayoutLMv2 模型,并将其上传到 🤗 Hub,您可以将其用于推理。尝试使用微调模型进行推理的最简单方法是在 Pipeline 中使用它。
让我们举个例子
>>> example = dataset["test"][2]
>>> question = example["query"]["en"]
>>> image = example["image"]
>>> print(question)
>>> print(example["answers"])
'Who is ‘presiding’ TRRF GENERAL SESSION (PART 1)?'
['TRRF Vice President', 'lee a. waller']
接下来,使用您的模型实例化一个用于文档问答的管道,并将图像+问题组合传递给它。
>>> from transformers import pipeline
>>> qa_pipeline = pipeline("document-question-answering", model="MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> qa_pipeline(image, question)
[{'score': 0.9949808120727539,
'answer': 'Lee A. Waller',
'start': 55,
'end': 57}]
如果您愿意,也可以手动复制管道的结果
- 获取图像和问题,使用模型的处理器为模型准备它们。
- 将预处理的结果转发到模型。
- 模型返回
start_logits
和end_logits
,它们指示答案的开头和结尾是哪个标记。两者都具有形状 (batch_size, sequence_length)。 - 对
start_logits
和end_logits
的最后一个维度取 argmax 以获取预测的start_idx
和end_idx
。 - 使用分词器解码答案。
>>> import torch
>>> from transformers import AutoProcessor
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> processor = AutoProcessor.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> with torch.no_grad():
... encoding = processor(image.convert("RGB"), question, return_tensors="pt")
... outputs = model(**encoding)
... start_logits = outputs.start_logits
... end_logits = outputs.end_logits
... predicted_start_idx = start_logits.argmax(-1).item()
... predicted_end_idx = end_logits.argmax(-1).item()
>>> processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1])
'lee a. waller'