donut-base-finetuned-rvl-cdip

This model is a fine-tuned version of naver-clova-ix/donut-base, trained on the hf-tuner/rvl-cdip-document-classification dataset.

Model description

Donut consists of a vision encoder (Swin Transformer) and a text decoder (BART). Given an image, the encoder first encodes it into a tensor of embeddings of shape (batch_size, seq_len, hidden_size). The decoder then autoregressively generates text, conditioned on the encoder output.
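To make the encoder output shape concrete, here is a minimal sketch of how seq_len and hidden_size follow from the image size in a Swin-style encoder. The helper name swin_output_shape and all numeric values are illustrative assumptions, not values read from this checkpoint; the real ones live in the model's encoder config. The sketch also assumes the image sides divide evenly by the total reduction factor, whereas real implementations pad.

```python
def swin_output_shape(image_size, patch_size=4, embed_dim=96, num_stages=4, batch_size=1):
    """Return (batch_size, seq_len, hidden_size) for a Swin-style encoder.

    Hypothetical helper for illustration only. Assumes both image sides are
    divisible by the total spatial reduction (real encoders pad instead).
    """
    height, width = image_size
    # Patch embedding shrinks each side by patch_size. Every stage except the
    # last then merges 2x2 patches: each side halves, the channel count doubles.
    merges = num_stages - 1
    reduction = patch_size * 2 ** merges
    seq_len = (height // reduction) * (width // reduction)
    hidden_size = embed_dim * 2 ** merges
    return (batch_size, seq_len, hidden_size)


# Illustrative values: a 224x224 input with a 96-dim patch embedding.
print(swin_output_shape((224, 224)))
# A larger document-sized input with a 128-dim patch embedding (assumed values).
print(swin_output_shape((2560, 1920), embed_dim=128))
```

The decoder attends over all seq_len positions at every generation step, so larger input resolutions increase memory use and latency.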

How to use

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

model_id = "hf-tuner/donut-base-finetuned-rvl-cdip"

config = VisionEncoderDecoderConfig.from_pretrained(model_id)
processor = DonutProcessor.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
config.dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = VisionEncoderDecoderModel.from_pretrained(model_id, config=config)
model.to(device)

# inference
task_start_token = "<classification>"

image = Image.open("test-document.png")

pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
decoder_input_ids = processor.tokenizer(task_start_token, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

generated_ids = model.generate(pixel_values,
                               decoder_input_ids=decoder_input_ids,
                               max_length=8,
                               bad_words_ids=[[processor.tokenizer.unk_token_id]]
                               )

generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
processor.token2json(generated_text)

# Output:
# {'class': 'handwritten'}
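For readers curious about what the last step does, the sketch below shows a simplified stand-in for the tag parsing. It is not the library's implementation of token2json: it handles only flat <s_key>value</s_key> pairs, and both the function name parse_donut_fields and the raw sequence shown are assumptions made for illustration.

```python
import re

# Matches a flat <s_key>value</s_key> pair; nested fields are not handled.
FIELD_PATTERN = re.compile(r"<s_(?P<key>[^>]+)>(?P<value>.*?)</s_(?P=key)>", re.DOTALL)


def parse_donut_fields(sequence):
    """Extract flat key/value pairs from a decoded Donut-style sequence.

    Hypothetical, simplified helper; use processor.token2json for real output.
    """
    return {
        match.group("key"): match.group("value").strip()
        for match in FIELD_PATTERN.finditer(sequence)
    }


# Assumed shape of the decoded text when skip_special_tokens=False.
raw = "<classification><s_class>handwritten</s_class></s>"
print(parse_donut_fields(raw))
```

Task and end-of-sequence tokens are ignored here because they do not follow the <s_key>...</s_key> pattern.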

Training Data

This model is fine-tuned to classify document images into one of the 16 classes of the aharley/rvl_cdip dataset:

{
  "0": "letter",
  "1": "form",
  "2": "email",
  "3": "handwritten",
  "4": "advertisement",
  "5": "scientific report",
  "6": "scientific publication",
  "7": "specification",
  "8": "file folder",
  "9": "news article",
  "10": "budget",
  "11": "invoice",
  "12": "presentation",
  "13": "questionnaire",
  "14": "resume",
  "15": "memo"
}
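If downstream code needs the integer class id rather than the label string, the mapping above can be inverted. The following is a minimal sketch; the names ID2LABEL, LABEL2ID and label_to_id are hypothetical, and it assumes the prediction is a dict shaped like the output of the usage example ({'class': ...}).

```python
# Label table copied from the mapping listed in this card.
ID2LABEL = {
    0: "letter", 1: "form", 2: "email", 3: "handwritten",
    4: "advertisement", 5: "scientific report", 6: "scientific publication",
    7: "specification", 8: "file folder", 9: "news article", 10: "budget",
    11: "invoice", 12: "presentation", 13: "questionnaire", 14: "resume",
    15: "memo",
}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}


def label_to_id(prediction):
    """Map {'class': label} to its integer id, or None if the label is unknown."""
    return LABEL2ID.get(prediction.get("class"))


print(label_to_id({"class": "handwritten"}))
```

Returning None for an unrecognized label lets the caller decide how to handle generations that fall outside the 16 classes.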

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2e-05
  • train_batch_size: 1
  • eval_batch_size: 1
  • seed: 42
  • optimizer: ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08; no additional optimizer arguments
  • lr_scheduler_type: linear
  • num_epochs: 1
  • mixed_precision_training: Native AMP
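As a rough illustration of what a linear scheduler does with the listed learning rate, the sketch below computes the learning rate at a given step. The function name linear_lr is hypothetical, and the warmup length and total step count are assumptions, since neither is stated in this card.

```python
def linear_lr(step, total_steps, base_lr=2e-5, warmup_steps=0):
    """Learning rate under linear warmup followed by linear decay to zero.

    Hypothetical helper; warmup_steps=0 and total_steps are assumptions,
    not values reported for this training run.
    """
    if step < warmup_steps:
        # Ramp up linearly from 0 to base_lr during warmup.
        return base_lr * (step / max(1, warmup_steps))
    # Decay linearly from base_lr to 0 over the remaining steps.
    remaining = max(0, total_steps - step)
    return base_lr * (remaining / max(1, total_steps - warmup_steps))


# Illustrative run of 1000 optimizer steps (assumed length).
for step in (0, 500, 1000):
    print(step, linear_lr(step, total_steps=1000))
```

With a batch size of 1 and a single epoch, the number of optimizer steps equals the number of training examples unless gradient accumulation is used, which this card does not mention.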

Framework versions

  • Transformers 4.56.1
  • Pytorch 2.8.0+cu126
  • Datasets 4.0.0
  • Tokenizers 0.22.0