Transformers 文档

DiT

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

PyTorch Flax

DiT

DiT 是一种图像 Transformer 模型,在大规模未标注文档图像上进行预训练。它学习从损坏的输入图像中预测缺失的视觉标记。预训练的 DiT 模型可用作其他模型的骨干网络,用于文档图像分类和表格检测等视觉文档任务。

你可以在 Microsoft 组织下找到所有原始的 DiT 检查点。

请参阅 BEiT 文档,了解如何将 DiT 应用于不同视觉任务的更多示例。

以下示例展示了如何使用 PipelineAutoModel 类对图像进行分类。

<hfoptions id="usage"> <hfoption id="Pipeline">
import torch
from transformers import pipeline

pipeline = pipeline(
    task="image-classification",
    model="microsoft/dit-base-finetuned-rvlcdip",
    torch_dtype=torch.float16,
    device=0
)
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dit-example.jpg")
</hfoption> <hfoption id="AutoModel">
import torch
import requests
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained(
    "microsoft/dit-base-finetuned-rvlcdip",
    use_fast=True,
)
model = AutoModelForImageClassification.from_pretrained(
    "microsoft/dit-base-finetuned-rvlcdip",
    device_map="auto",
)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dit-example.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(image, return_tensors="pt").to("cuda")

with torch.no_grad():
  logits = model(**inputs).logits
predicted_class_id = logits.argmax(dim=-1).item()

class_labels = model.config.id2label
predicted_class_label = class_labels[predicted_class_id]
print(f"The predicted class label is: {predicted_class_label}")
</hfoption>

注意事项

  • 预训练的 DiT 权重可以加载到带有建模头的 [BEiT] 模型中以预测视觉标记。
    from transformers import BeitForMaskedImageModeling
    
    model = BeitForMaskedImageModeling.from_pretraining("microsoft/dit-base")

资源

  • 有关文档图像分类推理的示例,请参阅此笔记本
< > 在 GitHub 上更新