donut-base-finetuned-docvqa

This model is a fine-tuned version of naver-clova-ix/donut-base on the hf-tuner/docvqa-10k-donut dataset.

Model description

Donut consists of a vision encoder (Swin Transformer) and a text decoder (BART). Given an image, the encoder first encodes it into a tensor of embeddings of shape (batch_size, seq_len, hidden_size). The decoder then generates text autoregressively, one token at a time, conditioned on the encoder's output.
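The autoregressive step can be illustrated with a toy greedy loop. This is a minimal sketch, not the Donut or transformers implementation: `toy_decoder_step` is a hypothetical stand-in for the BART decoder that scores the next token from the encoder embeddings and the tokens generated so far, and all token ids and embedding values are made up.

```python
from typing import Callable, List, Sequence

# Hypothetical signature: (encoder embeddings, decoder tokens so far) -> one score per vocab entry.
DecoderStep = Callable[[Sequence[Sequence[float]], List[int]], List[float]]


def greedy_decode(
    encoder_states: Sequence[Sequence[float]],
    prompt_ids: List[int],
    decoder_step: DecoderStep,
    eos_id: int,
    max_length: int,
) -> List[int]:
    """Append the highest-scoring token until EOS is produced or max_length is reached."""
    ids = list(prompt_ids)
    while len(ids) < max_length:
        scores = decoder_step(encoder_states, ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


VOCAB_SIZE = 10  # hypothetical toy vocabulary
EOS_ID = 9       # hypothetical end-of-sequence id


def toy_decoder_step(encoder_states, ids):
    # Hypothetical decoder that "reads" the encoder output: at step i it favours
    # the token id stored in the first feature of embedding i, then EOS once all
    # embeddings have been consumed. A real decoder learns this mapping.
    i = len(ids) - 1  # tokens generated after the single start token
    scores = [0.0] * VOCAB_SIZE
    if i < len(encoder_states):
        scores[int(encoder_states[i][0])] = 1.0
    else:
        scores[EOS_ID] = 1.0
    return scores


# One encoded "image": seq_len=3 embeddings of hidden_size=2 (batch dimension omitted).
encoder_states = [[7.0, 0.5], [3.0, 0.1], [5.0, 0.9]]
print(greedy_decode(encoder_states, [0], toy_decoder_step, eos_id=EOS_ID, max_length=10))
# → [0, 7, 3, 5, 9]
```

In the real model, `model.generate` runs this kind of loop over the decoder's logits (greedy by default unless the generation config enables beam search or sampling), and `max_length` caps the total decoder sequence, prompt included.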

How to use

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

model_id = "hf-tuner/donut-base-finetuned-docvqa"

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

processor = DonutProcessor.from_pretrained(model_id)
model = VisionEncoderDecoderModel.from_pretrained(model_id, torch_dtype=dtype)
model.to(device)
model.eval()

# Task prompt: the model writes the answer after <s_answer>.
PROMPT_TEMPLATE = "<doc_vqa><s><s_question>{}</s_question><s_answer>"

def predict(image, question):
    prompt = PROMPT_TEMPLATE.format(question)

    # Resize and normalize the page image; match the model's device and dtype.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device, dtype)
    # The prompt already contains its special tokens, so don't add more.
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    generated_ids = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=64,
        # Never generate the <unk> token.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    # Keep the <s_...> field tags (special tokens) so token2json can parse them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.token2json(generated_text)

output = predict(
    image=Image.open("doc-image.png").convert("RGB"),
    question="Which is the date of the approval form?",
)
# {"question": "Which is the date of the approval form?", "answer": "april 5, 1995"}
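`processor.token2json` turns the generated tag sequence into a dictionary. For a flat output like this model's question/answer pair, the idea can be sketched with a regular expression. This is a simplified illustration, not the library's implementation (which also handles nested fields and lists); `parse_flat_fields` is a hypothetical helper, and the sample string is an assumed model output that closes the answer with `</s_answer>` followed by `</s>`.

```python
import re
from typing import Dict

# Matches <s_KEY>VALUE</s_KEY>; the backreference \1 requires the same KEY in the
# closing tag. DOTALL lets a value span several lines.
FIELD_RE = re.compile(r"<s_(\w+)>(.*?)</s_\1>", re.DOTALL)


def parse_flat_fields(sequence: str) -> Dict[str, str]:
    """Collect every non-nested <s_key>value</s_key> pair into a dict."""
    return {key: value.strip() for key, value in FIELD_RE.findall(sequence)}


# Hypothetical decoded output for the example question.
generated = (
    "<doc_vqa><s><s_question>Which is the date of the approval form?</s_question>"
    "<s_answer>april 5, 1995</s_answer></s>"
)
print(parse_flat_fields(generated))
# → {'question': 'Which is the date of the approval form?', 'answer': 'april 5, 1995'}
```

A field whose closing tag never appears (for example, because generation hit `max_length` mid-answer) is simply dropped by this sketch, so raising `max_length` is worth trying if answers come back missing.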

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2e-05
  • train_batch_size: 1
  • eval_batch_size: 8
  • seed: 42
  • optimizer: AdamW (PyTorch implementation, OptimizerNames.ADAMW_TORCH) with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
  • lr_scheduler_type: linear
  • num_epochs: 1
  • mixed_precision_training: Native AMP
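With lr_scheduler_type set to linear, the learning rate falls linearly from 2e-05 toward 0 over training. The sketch below assumes zero warmup steps (no warmup is listed) and a hypothetical total step count, since the number of optimizer steps depends on the training-set size and any gradient accumulation, neither of which is reported. To our understanding it follows the same formula as transformers' `get_linear_schedule_with_warmup`.

```python
def linear_lr(step: int, total_steps: int, base_lr: float = 2e-05, warmup_steps: int = 0) -> float:
    """Learning rate after `step` optimizer steps under a linear warmup-then-decay schedule.

    warmup_steps=0 is an assumption: the card does not report warmup.
    """
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Linear decay from base_lr to 0, clamped at 0 after total_steps.
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


TOTAL_STEPS = 10_000  # hypothetical; the card does not report the number of optimizer steps
for step in (0, 2_500, 5_000, 10_000):
    print(step, linear_lr(step, TOTAL_STEPS))
```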

Framework versions

  • Transformers 4.56.1
  • Pytorch 2.6.0+cu124
  • Datasets 3.6.0
  • Tokenizers 0.22.1

Model size

  • Parameters: 0.2B
  • Weights format: Safetensors
  • Tensor types: I64, F32