import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import spaces
import os
import tempfile
from PIL import Image

# --- 1. Load Model and Tokenizer (Done only once at startup) ---
print("Loading model and tokenizer...")
model_name = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the model to CPU first; it will be moved to the GPU during processing
model = AutoModel.from_pretrained(
    model_name,
    _attn_implementation="flash_attention_2",
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.eval()
print("✅ Model loaded successfully.")


# --- 2. Main Processing Function ---
@spaces.GPU
def process_ocr_task(image, model_size, task_type, ref_text):
    """
    Processes an image with DeepSeek-OCR for all supported tasks.

    Args:
        image (PIL.Image): The input image.
        model_size (str): The model size configuration.
        task_type (str): The type of OCR task to perform.
        ref_text (str): The reference text for the 'Locate' task.
    """
    if image is None:
        return "Please upload an image first.", None

    # Move the model to the GPU and use bfloat16 for better performance
    print("🚀 Moving model to GPU...")
    model_gpu = model.cuda().to(torch.bfloat16)
    print("✅ Model is on GPU.")

    # Create a temporary directory to store files
    with tempfile.TemporaryDirectory() as output_path:
        # --- Build the prompt based on the selected task type ---
        if task_type == "📝 Free OCR":
            prompt = "<image>\nFree OCR."
        elif task_type == "📄 Convert to Markdown":
            prompt = "<image>\n<|grounding|>Convert the document to markdown."
        elif task_type == "📈 Parse Figure":
            prompt = "<image>\nParse the figure."
        elif task_type == "🔍 Locate Object by Reference":
            if not ref_text or ref_text.strip() == "":
                raise gr.Error("For the 'Locate' task, you must provide the reference text to find!")
            # Use an f-string to embed the user's reference text into the prompt
            prompt = f"<image>\nLocate <|ref|>{ref_text.strip()}<|/ref|> in the image."
        else:
            prompt = "<image>\nFree OCR."  # Default fallback
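        # Prompt tokens follow the DeepSeek-OCR model card: "<image>" marks
        # where the image embeddings are injected, "<|grounding|>" asks the
        # model to emit bounding boxes alongside the text, and
        # "<|ref|>...<|/ref|>" wraps the phrase to locate.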
        # Save the uploaded image into the temporary directory
        temp_image_path = os.path.join(output_path, "temp_image.png")
        image.save(temp_image_path)

        # Configure the model size parameters
        size_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        config = size_configs.get(model_size, size_configs["Gundam (Recommended)"])

        print(f"🏃 Running inference with prompt: {prompt}")

        # --- Run the model's inference method ---
        text_result = model_gpu.infer(
            tokenizer,
            prompt=prompt,
            image_file=temp_image_path,
            output_path=output_path,
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            save_results=True,  # Important: must be True to get the output image
            test_compress=True,
            eval_mode=True,
        )
        print(f"====\n📄 Text Result: {text_result}\n====")

        # --- Handle the output (both text and image) ---
        image_result_path = None
        # Tasks that generate a visual output usually create a 'grounding' or 'result' image
        if task_type in ["🔍 Locate Object by Reference", "📄 Convert to Markdown", "📈 Parse Figure"]:
            # Find the result image in the output directory
            for filename in os.listdir(output_path):
                if "grounding" in filename or "result" in filename:
                    image_result_path = os.path.join(output_path, filename)
                    break

        # If an image was found, open it with PIL; otherwise, return None.
        # PIL opens files lazily, so force the pixels into memory with .load()
        # before the temporary directory (and the file in it) is deleted.
        result_image_pil = Image.open(image_result_path) if image_result_path else None
        if result_image_pil is not None:
            result_image_pil.load()

        return text_result, result_image_pil


# --- 3. Build the Gradio Interface ---
with gr.Blocks(title="🐳DeepSeek-OCR🐳", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🐳 Full Demo of DeepSeek-OCR 🐳

        Upload an image to explore the document recognition and understanding capabilities of DeepSeek-OCR.

        **💡 How to use:**
        1. **Upload an image** using the upload box.
        2. Select a **Model Size**. `Gundam` is recommended, as it offers a good balance of speed and accuracy for most documents.
        3. Choose a **Task Type**:
            - **📝 Free OCR**: Extracts raw text from the image. Best for simple text extraction.
            - **📄 Convert to Markdown**: Converts the entire document into Markdown, preserving structure such as headers, lists, and tables.
            - **📈 Parse Figure**: Analyzes and extracts structured data from charts, graphs, and geometric figures.
            - **🔍 Locate Object by Reference**: Finds a specific object or piece of text in the image. You **must** type what you're looking for into the **"Reference Text"** box that appears.
        """
    )
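    # The size presets exposed below correspond to DeepSeek-OCR's native
    # resolution modes; "Gundam (Recommended)" pairs a 1024px global view with
    # 640px tiles (crop_mode=True), which tends to handle dense documents well.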
""" ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="pil", label="πŸ–ΌοΈ Upload Image", sources=["upload", "clipboard"]) model_size = gr.Dropdown( choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"], value="Gundam (Recommended)", label="βš™οΈ Model Size", ) task_type = gr.Dropdown( choices=["πŸ“ Free OCR", "πŸ“„ Convert to Markdown", "πŸ“ˆ Parse Figure", "πŸ” Locate Object by Reference"], value="πŸ“„ Convert to Markdown", label="πŸš€ Task Type", ) ref_text_input = gr.Textbox( label="πŸ“ Reference Text (for Locate task)", placeholder="e.g., the teacher, 11-2=, a red car...", visible=False, # Initially hidden ) submit_btn = gr.Button("Process Image", variant="primary") with gr.Column(scale=2): output_text = gr.Textbox(label="πŸ“„ Text Result", lines=15, show_copy_button=True) output_image = gr.Image(label="πŸ–ΌοΈ Image Result (if any)", type="pil") # --- UI Interaction Logic --- def toggle_ref_text_visibility(task): # If the user selects the 'Locate' task, make the reference textbox visible if task == "πŸ” Locate Object by Reference": return gr.Textbox(visible=True) else: return gr.Textbox(visible=False) # When the 'task_type' dropdown changes, call the function to update the visibility task_type.change( fn=toggle_ref_text_visibility, inputs=task_type, outputs=ref_text_input, ) # Define what happens when the submit button is clicked submit_btn.click( fn=process_ocr_task, inputs=[image_input, model_size, task_type, ref_text_input], outputs=[output_text, output_image], ) # --- Example Images and Tasks --- gr.Examples( examples=[ ["./examples/doc_markdown.png", "Gundam (Recommended)", "πŸ“„ Convert to Markdown", ""], ["./examples/chart.png", "Gundam (Recommended)", "πŸ“ˆ Parse Figure", ""], ["./examples/teacher.png", "Base", "πŸ” Locate Object by Reference", "the teacher"], ["./examples/math_locate.png", "Small", "πŸ” Locate Object by Reference", "11-2="], ["./examples/receipt.jpg", "Base", "πŸ“ Free OCR", ""], ], inputs=[image_input, model_size, task_type, ref_text_input], outputs=[output_text, output_image], fn=process_ocr_task, cache_examples=False, # Disable caching to ensure examples run every time ) # --- 4. Launch the App --- if __name__ == "__main__": # Create an 'examples' directory if it doesn't exist if not os.path.exists("examples"): os.makedirs("examples") # Please manually download the example images into the "examples" folder. # e.g., doc_markdown.png, chart.png, teacher.png, math_locate.png, receipt.jpg demo.queue(max_size=20) demo.launch(share=True) # Set share=True to create a public link