This model is a fine-tuned version of naver-clova-ix/donut-base on the hf-tuner/rvl-cdip-document-classification dataset.
Donut consists of a vision encoder (Swin Transformer) and a text decoder (BART).
Given an image, the encoder first encodes it into a tensor of embeddings of shape (batch_size, seq_len, hidden_size).
The decoder then autoregressively generates text, conditioned on the encoder's output.
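The encoder output shape follows from the Swin design. A patch embedding turns each patch_size x patch_size patch into one token, and each patch-merging layer between stages halves the token grid in both directions while doubling the channel width. The function below is a hypothetical helper that computes the resulting (batch_size, seq_len, hidden_size) under those assumptions. The example values (a 2560 x 1920 input, 4 x 4 patches, embed_dim 128, four stages) are assumed to resemble donut-base; the actual values for this checkpoint should be read from its encoder config.

```python
import math


def donut_encoder_output_shape(batch_size, image_size, patch_size=4, embed_dim=128, num_stages=4):
    """Estimate (batch_size, seq_len, hidden_size) for a Swin-style encoder.

    Assumes a patch embedding with stride `patch_size`, followed by one 2x2
    patch-merging layer (halve height and width, double channels) between
    consecutive stages. Odd sizes are padded, so every division rounds up.
    """
    height, width = image_size
    # Patch embedding: one token per patch_size x patch_size patch.
    h = math.ceil(height / patch_size)
    w = math.ceil(width / patch_size)
    # Each of the (num_stages - 1) patch-merging layers halves the token grid.
    for _ in range(num_stages - 1):
        h = math.ceil(h / 2)
        w = math.ceil(w / 2)
    seq_len = h * w
    hidden_size = embed_dim * 2 ** (num_stages - 1)
    return (batch_size, seq_len, hidden_size)


# Example values are assumptions (donut-base-like), not read from this checkpoint.
print(donut_encoder_output_shape(1, (2560, 1920)))  # (1, 4800, 1024)
```

With these example values the encoder emits an 80 x 60 grid, i.e. 4800 tokens of width 1024. The decoder cross-attends to this sequence at every generation step.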

```python
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

model_id = "hf-tuner/donut-base-finetuned-rvl-cdip"
processor = DonutProcessor.from_pretrained(model_id)

# Half precision on GPU, full precision on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=dtype)
model.to(device)

# inference
task_start_token = "<classification>"
image = Image.open("test-document.png").convert("RGB")
# The pixel values must be on the same device and in the same dtype as the model
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device, dtype=dtype)
# Prompt the decoder with the task start token
decoder_input_ids = processor.tokenizer(
    task_start_token, add_special_tokens=False, return_tensors="pt"
).input_ids.to(device)

generated_ids = model.generate(
    pixel_values,
    decoder_input_ids=decoder_input_ids,
    max_length=8,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never generate <unk>
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
print(processor.token2json(generated_text))
# Output:
# {'class': 'handwritten'}
```
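`processor.token2json` turns the generated tag sequence into a dict. To illustrate roughly what happens for this single-field output, here is a simplified parser. `parse_flat_donut_sequence` is a hypothetical helper, not part of transformers. It handles only flat `<s_key>value</s_key>` fields and unwraps values that may be emitted as self-closing tokens such as `<handwritten/>`. The library version also handles nested fields and `<sep/>`-separated lists, so use it in practice.

```python
import re


def parse_flat_donut_sequence(sequence):
    """Parse flat <s_key>value</s_key> fields from a generated Donut sequence.

    Simplified sketch: nested fields and <sep/>-separated lists are not handled.
    Tokens without the <s_...> form (e.g. a task start token or </s>) are ignored.
    """
    result = {}
    for key, value in re.findall(r"<s_(.+?)>(.*?)</s_\1>", sequence, flags=re.DOTALL):
        value = value.strip()
        # A categorical value may be a single self-closing token, e.g. <handwritten/>.
        match = re.fullmatch(r"<(.+?)/>", value)
        result[key] = match.group(1) if match else value
    return result


# Hypothetical raw decoder output, before any cleanup.
raw = "<classification><s_class><handwritten/></s_class></s>"
print(parse_flat_donut_sequence(raw))  # {'class': 'handwritten'}
```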
This model is fine-tuned to classify document images into one of the 16 classes of the aharley/rvl_cdip dataset:
```json
{
  "0": "letter",
  "1": "form",
  "2": "email",
  "3": "handwritten",
  "4": "advertisement",
  "5": "scientific report",
  "6": "scientific publication",
  "7": "specification",
  "8": "file folder",
  "9": "news article",
  "10": "budget",
  "11": "invoice",
  "12": "presentation",
  "13": "questionnaire",
  "14": "resume",
  "15": "memo"
}
```
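Evaluating the model against aharley/rvl_cdip requires mapping each predicted class name back to an integer id. The sketch below assumes two things: the ids in the map above are the dataset's integer labels, and the model emits class names spelled exactly as listed. `prediction_to_label_id` and `accuracy` are hypothetical helpers.

```python
# id -> label map copied from the list above.
ID2LABEL = {
    0: "letter", 1: "form", 2: "email", 3: "handwritten",
    4: "advertisement", 5: "scientific report", 6: "scientific publication",
    7: "specification", 8: "file folder", 9: "news article", 10: "budget",
    11: "invoice", 12: "presentation", 13: "questionnaire", 14: "resume", 15: "memo",
}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}


def prediction_to_label_id(prediction):
    """Map a parsed prediction such as {'class': 'handwritten'} to its integer id.

    Returns None if the 'class' key is missing or the name is not one of the
    16 known labels (assumes names match the map's spelling exactly).
    """
    return LABEL2ID.get(prediction.get("class"))


def accuracy(predictions, gold_ids):
    """Fraction of predictions whose mapped id equals the gold id."""
    correct = sum(prediction_to_label_id(p) == g for p, g in zip(predictions, gold_ids))
    return correct / len(gold_ids)


print(prediction_to_label_id({"class": "handwritten"}))  # 3
```

Predictions that fall outside the 16 known names map to None, so they count as errors instead of raising an exception.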
The following hyperparameters were used during training:
Base model: naver-clova-ix/donut-base