Wang Peng、Cheng Da 和 Cong Yao 在《对场景文本识别的多粒度预测》一文中提出了 MGP-STR 模型,该文链接为 https://arxiv.org/abs/2209.03592。MGP-STR 是一个概念上 简单 但 强大 的视觉场景文本识别(STR)模型,它建立在 视觉 Transformer (ViT) 之上。为了整合语言知识,提出了一种多粒度预测(MGP)策略,以隐式方式将来自语言模态的信息注入模型。
该论文的摘要如下
场景文本识别(STR)一直是计算机视觉中的一个活跃的研究领域。为了解决这个具有挑战性的问题,已经提出了许多创新的方法,将语言知识整合到 STR 模型中已经成为一个显著的趋势。在这项工作中,我们首先从近期视觉 Transformer (ViT) 的进展中汲取灵感,构建了一个概念上简单但强大的视觉 STR 模型,该模型建立在 ViT 之上,并且在场景文本识别中优于以前的最佳模型,包括纯视觉模型和语言增强方法。为了整合语言知识,我们进一步提出了一种多粒度预测策略,以隐式方式将来自语言模态的信息注入模型,即在输出空间中引入了 NLP 中广泛使用的子词表示(BPE 和 WordPiece),除了传统的字符级表示外,而不采用独立的语言模型(LM)。结果算法(称为 MGP-STR)能够将 STR 性能推向更高的水平。具体来说,它在标准基准测试上实现了 93.35% 的平均识别准确率。
MGP-STR 架构图。源自 原始论文。MGP-STR 在两个合成数据集上进行了训练,即 MJSynth (MJ) 和 SynthText (ST),而没有在其他数据集上进行微调。在六个标准拉丁场景文本基准测试中实现了最先进的成果,包括 3 个常规文本数据集(IC13、SVT、IIIT)和 3 个不规则数据集(IC15、SVTP、CUTE)。此模型由 yuekun 贡献。原始代码可以在 此处 找到。
推理示例
MgpstrModel 接受图像作为输入,并生成三种类型的预测,代表不同粒度的文本信息。这三种预测类型融合后给出最终预测结果。
ViTImageProcessor 类负责预处理输入图像,MgpstrTokenizer 将生成的字符标记解码为目标字符串。每个标记解码为原来的字符串,MgpstrProcessor 将 ViTImageProcessor 和 MgpstrTokenizer 包装到单个实例中,以同时提取输入特征和解码预测的标记 id。
- 逐步光学字符识别(OCR)
>>> from transformers import MgpstrProcessor, MgpstrForSceneTextRecognition
>>> import requests
>>> from PIL import Image
>>> processor = MgpstrProcessor.from_pretrained('alibaba-damo/mgp-str-base')
>>> model = MgpstrForSceneTextRecognition.from_pretrained('alibaba-damo/mgp-str-base')
>>> # load image from the IIIT-5k dataset
>>> url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png"
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(pixel_values)
>>> generated_text = processor.batch_decode(outputs.logits)['generated_text']
MgpstrConfig
类 transformers.MgpstrConfig
< 源代码 >( image_size = [32, 128] patch_size = 4 num_channels = 3 max_token_length = 27 num_character_labels = 38 num_bpe_labels = 50257 num_wordpiece_labels = 30522 hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 mlp_ratio = 4.0 qkv_bias = True distilled = False layer_norm_eps = 1e-05 drop_rate = 0.0 attn_drop_rate = 0.0 drop_path_rate = 0.0 output_a3_attentions = False initializer_range = 0.02 **kwargs )
参数
- image_size (
List[int]
, 可选, 默认值为[32, 128]
) — 每张图像的大小(分辨率)。 - patch_size (
int
, Optional, 默认为 4) — 每个补丁的大小(分辨率)。 - num_channels (
int
, Optional, 默认为 3) — 输入通道数。 - max_token_length (
int
, Optional, 默认为 27) — 最大输出标记数。 - num_character_labels (
int
, 可选, 默认值为 38) — 字符首的类别数。 - num_bpe_labels (
int
, 可选, 默认值为 50257) — bpe 首的类别数。 - num_wordpiece_labels (
int
, 可选, 默认值为 30522) — 词元首的类别数。 - hidden_size (
int
, 可选, 默认为768) — 嵌入维度。 - num_hidden_layers (
int
, 可选, 默认为12) — Transformer编码器中的隐藏层数量。 - num_attention_heads (
int
, 可选, 默认为12) — Transformer编码器中每个注意层中的注意力头数。 - mlp_ratio (
float
, 可选, 默认为4.0) — mlp隐藏维度与嵌入维度的比例。 - qkv_bias (
bool
, 可选,默认为True
) — 是否向查询、键和值添加偏置。 - distilled (
bool
, 可选,默认为False
) — 模型包含与DeiT模型类似的蒸馏标记和头部。 - layer_norm_eps (
float
, 可选,默认为 1e-05) — 层归一化层使用的epsilon。 - drop_rate (
float
,可选,默认为0.0)— 在嵌入、编码器中所有全连接层的dropout概率。 - attn_drop_rate (
float
,可选,默认为0.0)— 注意力概率的dropout比例。 - drop_path_rate (
float
,可选,默认为0.0)— 随机深度率。 - output_a3_attentions (
bool
,可选,默认为False
)— 模型是否应该返回A^3模块的注意力。 - initializer_range (
float
, optional, defaults to 0.02) — 用于初始化所有权重矩阵的截断正态初始化器的标准差。
这是一个存储MgpstrModel配置的配置类。它根据指定的参数实例化一个MGP-STR模型,定义模型架构。使用默认值实例化配置将产生类似于MGP-STR alibaba-damo/mgp-str-base架构的配置。
配置对象继承自PretrainedConfig,并可用于控制模型输出。更多详细信息,请参阅PretrainedConfig的文档。
示例
>>> from transformers import MgpstrConfig, MgpstrForSceneTextRecognition
>>> # Initializing a Mgpstr mgp-str-base style configuration
>>> configuration = MgpstrConfig()
>>> # Initializing a model (with random weights) from the mgp-str-base style configuration
>>> model = MgpstrForSceneTextRecognition(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
MgpstrTokenizer
类 transformers.MgpstrTokenizer
< 源代码 >( vocab_file unk_token = '[GO]' bos_token = '[GO]' eos_token = '[s]' pad_token = '[GO]' **kwargs )
构建一个MGP-STR字符标记器。
这个标记器继承自PreTrainedTokenizer,其中包含大多数主要方法。用户应参阅这个父类以了解更多关于这些方法的信息。
MgpstrProcessor
类 transformers.MgpstrProcessor
< 源代码 >( image_processor = None tokenizer = None **kwargs )
参数
- image_processor (
ViTImageProcessor
, 可选) — 一个ViTImageProcessor
实例。图像处理器是必要输入。 - 分词器 (MgpstrTokenizer, 可选) — 分词器是必需的输入。
构建一个 MGP-STR 处理器,该处理器将图像处理器和 MGP-STR 分词器包装到一个组件中。
MgpstrProcessor 提供了 ViTImageProcessor
和 MgpstrTokenizer 的所有功能。请参阅 call() 和 batch_decode() 以获取更多信息。
当以正常模式使用时,此方法将所有其参数传递给 ViTImageProcessor 的 call() 并返回其输出。如果 text
不是 None
,此方法还将 text
和 kwargs
参数传递给 MgpstrTokenizer 的 call() 以对文本进行编码。请参阅上述方法的文档以获取更多信息。
batch_decode
< source >( sequences ) → Dict[str, any]
通过调用decode将token id列表转换为字符串列表。
此方法将所有参数传递给PreTrainedTokenizer的batch_decode()。请参阅此方法的docstring以获取更多信息。
MgpstrModel
类 transformers.MgpstrModel
参数
forward
( pixel_values: FloatTensor output_attentions: 可选 = None output_hidden_states: 可选 = None return_dict: 可选 = None )
参数
- pixel_values (
torch.FloatTensor
shape(batch_size, num_channels, height, width)
) — 像素值。可以使用 AutoImageProcessor 获取像素值。详见 ViTImageProcessor.call()。 - output_attentions (
bool
,可选) — 是否返回所有注意力层的注意力张量。更多信息请查看返回的张量下的attentions
。 - output_hidden_states (
bool
,可选) — 是否返回所有层的隐藏状态。更多信息请查看返回的张量下的hidden_states
。 - return_dict (
bool
, optional) — 是否返回一个ModelOutput(而非一个简单的元组)。
MgpstrModel 的前进方法,覆盖了 __call__
特殊方法。
尽管需要在此函数中定义前向传递的配方,但应该在此之后调用 Module
实例,而不是此函数,因为前者会处理前向和后向处理步骤,而后者会默默地忽略它们。
MgpstrForSceneTextRecognition
class transformers.MgpstrForSceneTextRecognition
< source >参数
MGP-STR 模型,顶部有三个分类头(三个 A^3 模块和转换器编码器输出之上的三个线性层),用于场景文本识别 (STR)。
此模型是 PyTorch torch.nn.Module 子类。将其用作常规的 PyTorch 模块,并参考 PyTorch 文档以获取所有与通用用途和行为相关的内容。
forward
参数
- pixel_values (
torch.FloatTensor
形状(batch_size, num_channels, height, width)
) — 像素值。可以通过 AutoImageProcessor 获取像素值。有关详细信息,请参阅 ViTImageProcessor.call()。 - output_attentions (
bool
,可选项) — 是否返回所有注意力层的注意力张量。更详细的说明请参阅返回张量下的attentions
。 - <-- HTML_TAG_START -->output_hidden_states (
bool
,可选项) — 是否返回所有层的隐藏状态。更详细的说明请参阅返回张量下的hidden_states
。 - return_dict(《布尔值》,可选)—— 是否返回ModelOutput而不是普通元组。
- output_a3_attentions(《布尔值》,可选)—— 是否返回a3模块的注意力张量。请参阅返回张量下的
a3_attentions
获取更多细节。
返回值
transformers.models.mgp_str.modeling_mgp_str.MgpstrModelOutput
或tuple(torch.FloatTensor)
根据配置(<class 'transformers.models.mgp_str.configuration_mgp_str.MgpstrConfig'>
)和输入,返回一个transformers.models.mgp_str.modeling_mgp_str.MgpstrModelOutput
或一个包含各种元素的torch.FloatTensor
元组的张量(如果传递了return_dict=False
或者当config.return_dict=False
)。
-
logits(
tuple(torch.FloatTensor)
的形状为(batch_size, config.num_character_labels)
)—— 包含形态特征、bpe特征和wordpiece特征的分类得分的张量。这些特征在SoftMax之前。字符、bpe和wordpiece的分类得分(SoftMax之前)。
-
hidden_states(《torch.FloatTensor》的元组,可选,在传递
output_hidden_states=True
时返回,或者当config.output_hidden_states=True
时)—— 包含每个层的输出张量的元组(如果模型有嵌入层,则为嵌入层的输出,然后为每个层的输出)的形状为(batch_size, sequence_length, hidden_size)
。模型在每一层的输出以及可选的初始嵌入输出的隐藏状态。
-
attentions(《torch.FloatTensor》的元组,可选,在传递
output_attentions=True
时返回,或者当config.output_attentions=True
时)—— 每层的注意力张量的元组,形状为(batch_size, config.max_token_length, sequence_length, sequence_length)
。在注意力softmax之后的注意力权重,用于在自注意力头中计算加权平均值。
-
a3_attentions (
torch.FloatTensor
元组,可选,当传递output_a3_attentions=True
或当config.output_a3_attentions=True
时返回) — 一个形状为(batch_size, config.max_token_length, sequence_length)
的torch.FloatTensor
元组,分别为字符、bpe和wordpiece的注意力。在注意力softmax之后的注意力权重,用于在自注意力头中计算加权平均值。
MgpstrForSceneTextRecognition的前向方法,覆盖了特殊方法__call__
。
尽管需要在此函数中定义前向传递的配方,但应该在此之后调用 Module
实例,而不是此函数,因为前者会处理前向和后向处理步骤,而后者会默默地忽略它们。
示例
>>> from transformers import (
... MgpstrProcessor,
... MgpstrForSceneTextRecognition,
... )
>>> import requests
>>> from PIL import Image
>>> # load image from the IIIT-5k dataset
>>> url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png"
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
>>> processor = MgpstrProcessor.from_pretrained("alibaba-damo/mgp-str-base")
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
>>> model = MgpstrForSceneTextRecognition.from_pretrained("alibaba-damo/mgp-str-base")
>>> # inference
>>> outputs = model(pixel_values)
>>> out_strs = processor.batch_decode(outputs.logits)
>>> out_strs["generated_text"]
'["ticket"]'