This model is a fine-tuned version of naver-clova-ix/donut-base on the hf-tuner/docvqa-10k-donut dataset.
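As a quick orientation, here is a minimal sketch of pulling the fine-tuning data with the datasets library (this assumes the dataset repo is public and in a standard format; split names may differ):

from datasets import load_dataset

# hf-tuner/docvqa-10k-donut is the fine-tuning corpus named above.
dataset = load_dataset("hf-tuner/docvqa-10k-donut")
print(dataset)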
Donut consists of a vision encoder (Swin Transformer) and a text decoder (BART).
Given an image, the encoder first encodes it into a tensor of embeddings of shape (batch_size, seq_len, hidden_size),
after which the decoder autoregressively generates text conditioned on the encoder's output.
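To make the shapes concrete, this minimal sketch runs the encoder alone; the blank stand-in image and the example shape (1, 4800, 1024) for donut-base are illustrative assumptions:

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Load the base checkpoint just to inspect the encoder output.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# A blank page stands in for a real document image here.
image = Image.new("RGB", (1920, 2560), "white")
pixel_values = processor(image, return_tensors="pt").pixel_values

with torch.no_grad():
    encoder_outputs = model.encoder(pixel_values)

# (batch_size, seq_len, hidden_size), e.g. (1, 4800, 1024) for donut-base
print(encoder_outputs.last_hidden_state.shape)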

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

model_id = "hf-tuner/donut-base-finetuned-docvqa"
processor = DonutProcessor.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=dtype)
model.to(device)

# Task prompt: the question is wrapped in the special tokens the decoder
# was trained with; generation continues after <s_answer>.
PROMPT_TEMPLATE = "<doc_vqa><s><s_question>{}</s_question><s_answer>"

def predict(image, question):
    prompt = PROMPT_TEMPLATE.format(question)
    # Preprocess the image into normalized pixel values, matching the model dtype.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device, dtype)
    # Tokenize the prompt to seed the decoder; no extra special tokens.
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)
    generated_ids = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=64,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never emit <unk>
    )
    # Keep special tokens so token2json can parse the tag structure.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.token2json(generated_text)

output = predict(image=Image.open("doc-image.png"),
                 question="Which is the date of the approval form?")
# {"question": "Which is the date of the approval form?", "answer": "april 5, 1995"}
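If answers get truncated at 64 tokens, the generate call can be made more generous along the lines of the upstream Donut examples; a hedged variant (the argument values here are illustrative, not taken from this model card):

generated_ids = model.generate(
    pixel_values,
    decoder_input_ids=decoder_input_ids,
    # Cap generation at the decoder's maximum sequence length instead of 64.
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
)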
The following hyperparameters were used during training: