import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
import re
from PIL import Image
import spaces  # Hugging Face Spaces helper; provides the @spaces.GPU decorator
import os
from huggingface_hub import HfFolder

# Read the Hugging Face token from the API_KEY environment variable (Space secret)
hf_token = os.getenv("API_KEY")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# If the token is found, use it to authenticate for this session
# (HfFolder.save_token is a legacy helper; recent huggingface_hub versions prefer login())
if hf_token:
    HfFolder.save_token(hf_token)
else:
    print("No API_KEY found. Please set your Hugging Face API token as the API_KEY environment variable.")


# Model information
MODEL_ID = "DeepMount00/Smol-OCR-preview"
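# The OCR prompt below is deliberately kept in Italian ("You are an expert OCR
# assistant, convert the text to MD format"), since the model targets
# Italian-language documents.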
OCR_INSTRUCTION = "Sei un assistente esperto di OCR, converti il testo in formato MD."

# Load processor and model
processor = AutoProcessor.from_pretrained(MODEL_ID, token=hf_token)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    token=hf_token,
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
).to(DEVICE)  # Use the detected device so the script also runs on CPU-only machines

@spaces.GPU  # Allocates a GPU for the duration of each call on ZeroGPU Spaces
def process_image(image, progress=gr.Progress()):
    if image is None:
        # Raise so Gradio surfaces the error in the UI
        raise gr.Error("Please upload an image to process.")

    progress(0, desc="Starting OCR processing...")

    # gr.Image(type="pil") already yields a PIL image; handle a file path just in case
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")

    progress(0.2, desc="Preparing image...")

    # Create input messages - note that the instruction is included as part of the user message
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": OCR_INSTRUCTION}
            ]
        },
    ]

    # Prepare inputs
    progress(0.4, desc="Processing with model...")
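    # apply_chat_template renders the chat messages into the model's prompt
    # string, including an image placeholder that the processor pairs with the
    # PIL image passed alongside it.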
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = inputs.to(DEVICE)

    # Generate outputs
    progress(0.6, desc="Generating text...")
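    # Sampling at low temperature keeps the transcription near-deterministic;
    # max_new_tokens=4096 leaves headroom for long documents.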
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=4096,
            temperature=0.1,
            do_sample=True
        )

    # Decode outputs
    progress(0.8, desc="Finalizing results...")
    generated_text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True
    )[0]

    # batch_decode returns the full conversation (prompt echo plus response),
    # so strip the echoed "User:" turn and the "Assistant:" prefix to keep only
    # the model's answer
    cleaned_text = generated_text

    # Remove user prompt and "User:" prefix if present
    user_pattern = r"User:.*?(?=Assistant:|$)"
    cleaned_text = re.sub(user_pattern, "", cleaned_text, flags=re.DOTALL)

    # Remove "Assistant:" prefix if present
    assistant_pattern = r"Assistant:\s*"
    cleaned_text = re.sub(assistant_pattern, "", cleaned_text)

    # Clean up any extra whitespace
    cleaned_text = cleaned_text.strip()

    progress(1.0, desc="Done!")
    return cleaned_text  # Return only the cleaned text


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# OCR to Markdown Converter")
    gr.Markdown(f"Upload images of Italian text for instant Markdown conversion. Powered by {MODEL_ID} for high accuracy on Italian-language documents.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload an image containing text")
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Raw Text", lines=15)
            copy_btn = gr.Button("Select All Text", variant="secondary")

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_text,
        show_progress="full",
        queue=True  # Enable queue for Spaces
    )
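    # Note: queue=True above sends requests through Gradio's queue, which works
    # together with @spaces.GPU's per-call GPU allocation.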

    # Identity pass-through: Python callbacks cannot access the system
    # clipboard, so this simply re-emits the text for manual copying
    def copy_to_clipboard(text):
        return text

    copy_btn.click(
        fn=copy_to_clipboard,
        inputs=output_text,
        outputs=output_text
    )
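
    # Alternative: gr.Textbox(show_copy_button=True) provides a built-in copy
    # button in recent Gradio versions.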

# Launch the app with the default Spaces configuration
demo.launch()