Transformers 文档
文档问答
并获得增强的文档体验
开始使用
文档问答
文档问答,也称为文档视觉问答,是一项涉及回答关于文档图像的问题的任务。支持此任务的模型的输入通常是图像和问题的组合,输出是自然语言表达的答案。这些模型利用多种模态,包括文本、单词位置(边界框)和图像本身。
本指南演示了如何
- 在 DocVQA 数据集 上微调 LayoutLMv2。
- 使用您的微调模型进行推理。
要查看与此任务兼容的所有架构和检查点,我们建议查看任务页面
LayoutLMv2 通过在令牌的最终隐藏状态之上添加一个问答头来解决文档问答任务,以预测答案的起始和结束令牌的位置。换句话说,该问题被视为抽取式问答:给定上下文,抽取哪段信息回答了问题。上下文来自 OCR 引擎的输出,这里是 Google 的 Tesseract。
开始之前,请确保已安装所有必需的库。LayoutLMv2 依赖于 detectron2、torchvision 和 tesseract。
pip install -q transformers datasets
pip install 'git+https://github.com/facebookresearch/detectron2.git'
pip install torchvision
sudo apt install tesseract-ocr pip install -q pytesseract
安装所有依赖项后,重启您的运行时。
我们鼓励您与社区分享您的模型。登录您的 Hugging Face 帐户以将其上传到 🤗 Hub。出现提示时,输入您的令牌登录
>>> from huggingface_hub import notebook_login
>>> notebook_login()
让我们定义一些全局变量。
>>> model_checkpoint = "microsoft/layoutlmv2-base-uncased"
>>> batch_size = 4
加载数据
在本指南中,我们使用了一个经过预处理的 DocVQA 小样本,您可以在 🤗 Hub 上找到它。如果您想使用完整的 DocVQA 数据集,您可以在 DocVQA 主页上注册并下载。如果这样做,要继续本指南,请查看 如何将文件加载到 🤗 数据集中。
>>> from datasets import load_dataset
>>> dataset = load_dataset("nielsr/docvqa_1200_examples")
>>> dataset
DatasetDict({
train: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 1000
})
test: Dataset({
features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],
num_rows: 200
})
})
如您所见,数据集已分为训练集和测试集。查看一个随机示例以熟悉其特征。
>>> dataset["train"].features
以下是各个字段的含义
id
: 示例的 IDimage
: 包含文档图像的 PIL.Image.Image 对象query
: 问题字符串 - 自然语言提问,支持多种语言answers
: 人工标注者提供的一系列正确答案words
和bounding_boxes
: OCR 结果,我们在此不使用answer
: 由另一个模型匹配的答案,我们在此不使用
我们只保留英语问题,并删除包含另一个模型预测的 `answer` 特征。我们还将从标注者提供的一组答案中取第一个。或者,您可以随机采样。
>>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"])
>>> updated_dataset = updated_dataset.map(
... lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"]
... )
请注意,本指南中使用的 LayoutLMv2 检查点已使用 `max_position_embeddings = 512` 进行训练(您可以在检查点的 `config.json` 文件中找到此信息)。我们可以截断示例,但为了避免答案可能位于大文档末尾并最终被截断的情况,这里我们将删除少数嵌入长度可能超过 512 的示例。如果您的数据集中大多数文档都很长,您可以实施滑动窗口策略 - 有关详细信息,请查看此 Notebook。
>>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512)
此时,我们还将从该数据集中删除 OCR 功能。这些功能是为微调不同模型而进行的 OCR 结果。如果我们要使用它们,它们仍然需要一些处理,因为它们与本指南中使用的模型的输入要求不匹配。相反,我们可以对原始数据使用 LayoutLMv2Processor 进行 OCR 和分词。这样,我们将获得与模型预期输入匹配的输入。如果您想手动处理图像,请查看 `LayoutLMv2` 模型文档,了解模型期望的输入格式。
>>> updated_dataset = updated_dataset.remove_columns("words")
>>> updated_dataset = updated_dataset.remove_columns("bounding_boxes")
最后,如果不查看图像示例,数据探索就不完整。
>>> updated_dataset["train"][11]["image"]

预处理数据
文档问答任务是一个多模态任务,您需要确保每个模态的输入都根据模型的预期进行预处理。让我们首先加载 LayoutLMv2Processor,它内部结合了一个可以处理图像数据的图像处理器和一个可以编码文本数据的分词器。
>>> from transformers import AutoProcessor
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
预处理文档图像
首先,让我们借助处理器中的 `image_processor` 为模型准备文档图像。默认情况下,图像处理器会将图像大小调整为 224x224,确保它们具有正确的颜色通道顺序,使用 Tesseract 应用 OCR 以获取单词和标准化边界框。在本教程中,所有这些默认设置都正是我们需要的。编写一个函数,将默认图像处理应用于一批图像并返回 OCR 结果。
>>> image_processor = processor.image_processor
>>> def get_ocr_words_and_boxes(examples):
... images = [image.convert("RGB") for image in examples["image"]]
... encoded_inputs = image_processor(images)
... examples["image"] = encoded_inputs.pixel_values
... examples["words"] = encoded_inputs.words
... examples["boxes"] = encoded_inputs.boxes
... return examples
要快速将此预处理应用于整个数据集,请使用 map。
>>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)
预处理文本数据
对图像应用 OCR 后,我们需要对数据集的文本部分进行编码,以准备它用于模型。这涉及将上一步中获得的单词和框转换为令牌级别的 `input_ids`、`attention_mask`、`token_type_ids` 和 `bbox`。对于文本预处理,我们需要处理器中的 `tokenizer`。
>>> tokenizer = processor.tokenizer
除了上面提到的预处理之外,我们还需要为模型添加标签。对于 🤗 Transformers 中的 `xxxForQuestionAnswering` 模型,标签由 `start_positions` 和 `end_positions` 组成,指示哪个令牌是答案的开始,哪个令牌是答案的结束。
让我们从这里开始。定义一个辅助函数,该函数可以在较大的列表(单词列表)中查找子列表(答案分成单词)。
该函数将接收两个列表作为输入,`words_list` 和 `answer_list`。然后,它将遍历 `words_list` 并检查 `words_list` 中的当前单词 (words_list[i]) 是否等于 `answer_list` 的第一个单词 (answer_list[0]),以及从当前单词开始的 `words_list` 子列表(与 `answer_list` 长度相同)是否等于 `answer_list`。如果此条件为真,则表示已找到匹配项,函数将记录该匹配项、其起始索引 (idx) 和其结束索引 (idx + len(answer_list) - 1)。如果找到多个匹配项,函数将仅返回第一个。如果未找到匹配项,函数将返回 (`None`, 0, 和 0)。
>>> def subfinder(words_list, answer_list):
... matches = []
... start_indices = []
... end_indices = []
... for idx, i in enumerate(range(len(words_list))):
... if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list:
... matches.append(answer_list)
... start_indices.append(idx)
... end_indices.append(idx + len(answer_list) - 1)
... if matches:
... return matches[0], start_indices[0], end_indices[0]
... else:
... return None, 0, 0
为了说明此函数如何找到答案的位置,我们以一个示例为例
>>> example = dataset_with_ocr["train"][1]
>>> words = [word.lower() for word in example["words"]]
>>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split())
>>> print("Question: ", example["question"])
>>> print("Words:", words)
>>> print("Answer: ", example["answer"])
>>> print("start_index", word_idx_start)
>>> print("end_index", word_idx_end)
Question: Who is in cc in this letter?
Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', '«short', 'cigarette,', 'tobacco', 'section', '30', 'mm.', '«extremely', 'fast', 'buming', 'cigarette.', '«novel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', '«more', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498']
Answer: T.F. Riehl
start_index 17
end_index 18
然而,一旦示例被编码,它们将看起来像这样
>>> encoding = tokenizer(example["question"], example["words"], example["boxes"])
>>> tokenizer.decode(encoding["input_ids"])
[CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ...
我们需要在编码后的输入中找到答案的位置。
- `token_type_ids` 告诉我们哪些 token 属于问题,哪些属于文档中的单词。
- `tokenizer.cls_token_id` 将有助于找到输入开头处的特殊 token。
- `word_ids` 将有助于将原始 `words` 中找到的答案与完整编码输入中的相同答案进行匹配,并确定答案在编码输入中的起始/结束位置。
考虑到这一点,让我们创建一个函数来编码数据集中的一批示例
>>> def encode_dataset(examples, max_length=512):
... questions = examples["question"]
... words = examples["words"]
... boxes = examples["boxes"]
... answers = examples["answer"]
... # encode the batch of examples and initialize the start_positions and end_positions
... encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True)
... start_positions = []
... end_positions = []
... # loop through the examples in the batch
... for i in range(len(questions)):
... cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id)
... # find the position of the answer in example's words
... words_example = [word.lower() for word in words[i]]
... answer = answers[i]
... match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())
... if match:
... # if match is found, use `token_type_ids` to find where words start in the encoding
... token_type_ids = encoding["token_type_ids"][i]
... token_start_index = 0
... while token_type_ids[token_start_index] != 1:
... token_start_index += 1
... token_end_index = len(encoding["input_ids"][i]) - 1
... while token_type_ids[token_end_index] != 1:
... token_end_index -= 1
... word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]
... start_position = cls_index
... end_position = cls_index
... # loop over word_ids and increase `token_start_index` until it matches the answer position in words
... # once it matches, save the `token_start_index` as the `start_position` of the answer in the encoding
... for id in word_ids:
... if id == word_idx_start:
... start_position = token_start_index
... else:
... token_start_index += 1
... # similarly loop over `word_ids` starting from the end to find the `end_position` of the answer
... for id in word_ids[::-1]:
... if id == word_idx_end:
... end_position = token_end_index
... else:
... token_end_index -= 1
... start_positions.append(start_position)
... end_positions.append(end_position)
... else:
... start_positions.append(cls_index)
... end_positions.append(cls_index)
... encoding["image"] = examples["image"]
... encoding["start_positions"] = start_positions
... encoding["end_positions"] = end_positions
... return encoding
现在我们有了这个预处理函数,我们可以对整个数据集进行编码
>>> encoded_train_dataset = dataset_with_ocr["train"].map(
... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names
... )
>>> encoded_test_dataset = dataset_with_ocr["test"].map(
... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names
... )
让我们看看编码后的数据集特征是什么样的
>>> encoded_train_dataset.features
{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
'start_positions': Value(dtype='int64', id=None),
'end_positions': Value(dtype='int64', id=None)}
评估
文档问答的评估需要大量的后处理。为了避免占用您太多时间,本指南跳过了评估步骤。 Trainer 仍然会在训练期间计算评估损失,因此您不会完全不知道模型的性能。抽取式问答通常使用 F1/精确匹配进行评估。如果您想自己实现,请查看 Hugging Face 课程的 问答章节 以获取灵感。
训练
恭喜您!您已经成功完成了本指南中最艰难的部分,现在可以训练自己的模型了。训练包括以下步骤
- 使用与预处理中相同的检查点,通过 AutoModelForDocumentQuestionAnswering 加载模型。
- 在 TrainingArguments 中定义您的训练超参数。
- 定义一个函数来批量处理示例,这里 DefaultDataCollator 就足够了
- 将训练参数与模型、数据集和数据整理器一起传递给 Trainer。
- 调用 train() 来微调您的模型。
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)
在 TrainingArguments 中,使用 `output_dir` 指定保存模型的位置,并根据需要配置超参数。如果您希望与社区共享模型,请将 `push_to_hub` 设置为 `True`(您必须登录 Hugging Face 才能上传模型)。在这种情况下,`output_dir` 也将是您的模型检查点将被推送到的仓库名称。
>>> from transformers import TrainingArguments
>>> # REPLACE THIS WITH YOUR REPO ID
>>> repo_id = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa"
>>> training_args = TrainingArguments(
... output_dir=repo_id,
... per_device_train_batch_size=4,
... num_train_epochs=20,
... save_steps=200,
... logging_steps=50,
... eval_strategy="steps",
... learning_rate=5e-5,
... save_total_limit=2,
... remove_unused_columns=False,
... push_to_hub=True,
... )
定义一个简单的数据整理器以将示例批处理在一起。
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
最后,将所有内容整合在一起,并调用 train()
>>> from transformers import Trainer
>>> trainer = Trainer(
... model=model,
... args=training_args,
... data_collator=data_collator,
... train_dataset=encoded_train_dataset,
... eval_dataset=encoded_test_dataset,
... processing_class=processor,
... )
>>> trainer.train()
要将最终模型添加到 🤗 Hub,请创建一个模型卡并调用 `push_to_hub`
>>> trainer.create_model_card()
>>> trainer.push_to_hub()
推理
现在您已经微调了 LayoutLMv2 模型并将其上传到 🤗 Hub,您可以将其用于推理。尝试微调模型进行推理的最简单方法是在 Pipeline 中使用它。
举个例子
>>> example = dataset["test"][2]
>>> question = example["query"]["en"]
>>> image = example["image"]
>>> print(question)
>>> print(example["answers"])
'Who is ‘presiding’ TRRF GENERAL SESSION (PART 1)?'
['TRRF Vice President', 'lee a. waller']
接下来,使用您的模型实例化一个用于文档问答的管道,并将图像 + 问题组合传递给它。
>>> from transformers import pipeline
>>> qa_pipeline = pipeline("document-question-answering", model="MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> qa_pipeline(image, question)
[{'score': 0.9949808120727539,
'answer': 'Lee A. Waller',
'start': 55,
'end': 57}]
您也可以手动重现管道的结果,如果您愿意的话
- 选取图像和问题,使用模型中的处理器准备它们,以便模型使用。
- 通过模型转发预处理的结果。
- 模型返回 `start_logits` 和 `end_logits`,它们指示答案的起始 token 和结束 token。两者形状均为 (batch_size, sequence_length)。
- 对 `start_logits` 和 `end_logits` 的最后一个维度进行 argmax,以获得预测的 `start_idx` 和 `end_idx`。
- 用分词器解码答案。
>>> import torch
>>> from transformers import AutoProcessor
>>> from transformers import AutoModelForDocumentQuestionAnswering
>>> processor = AutoProcessor.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("MariaK/layoutlmv2-base-uncased_finetuned_docvqa")
>>> with torch.no_grad():
... encoding = processor(image.convert("RGB"), question, return_tensors="pt")
... outputs = model(**encoding)
... start_logits = outputs.start_logits
... end_logits = outputs.end_logits
... predicted_start_idx = start_logits.argmax(-1).item()
... predicted_end_idx = end_logits.argmax(-1).item()
>>> processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1])
'lee a. waller'