import gradio as gr
import torch
from PIL import Image
import logging
from typing import Optional, Union
import os
import spaces
from dotenv import load_dotenv

load_dotenv()

# Disable torch compilation to avoid dynamo issues
torch._dynamo.config.disable = True
torch.backends.cudnn.allow_tf32 = True

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Messages returned to the UI *in place of* OCR text when something fails.
# They are defined once so that the functions that emit them and
# `_is_failure_message` (which picks the status-box text) cannot drift apart.
_LOAD_ERROR_MESSAGE = "Error: Failed to load model. Please check the logs."
_PREDICT_ERROR_PREFIX = "Error processing image: "
_OCR_ERROR_PREFIX = "An error occurred: "

# Idle status text shown at startup and after the output is cleared.
_READY_STATUS = "Ready to process images"


class AtlasOCR:
    """Wrapper around the AtlasOCR vision-language model for image-to-text OCR."""

    def __init__(self, model_name: str = "atlasia/AtlasOCR", max_tokens: int = 2000):
        """Load the model and processor via unsloth's ``FastVisionModel``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            max_tokens: Passed as ``max_new_tokens`` to ``generate``.

        Raises:
            ImportError: If ``unsloth`` cannot be imported.
            Exception: Any error raised while loading is logged and re-raised.
        """
        # Guard only the import itself: an ImportError raised later during
        # loading (e.g. by a missing backend) must not be reported as
        # "unsloth not found".
        try:
            from unsloth import FastVisionModel
        except ImportError:
            logger.error("unsloth not found. Please install it: pip install unsloth")
            raise

        try:
            logger.info(f"Loading model: {model_name}")

            # A missing variable used to surface as a bare KeyError('HF_API_KEY').
            # With token=None the Hugging Face client presumably falls back to
            # its default credentials (HF_TOKEN / cached login).
            token = os.environ.get("HF_API_KEY")
            if token is None:
                logger.warning(
                    "HF_API_KEY is not set; falling back to default Hugging Face credentials"
                )

            # Disable compilation for the model
            with torch._dynamo.config.patch(disable=True):
                self.model, self.processor = FastVisionModel.from_pretrained(
                    model_name,
                    device_map="auto",
                    load_in_4bit=True,
                    use_gradient_checkpointing="unsloth",
                    token=token,
                )

            # Ensure model is not compiled
            if hasattr(self.model, "_dynamo_compile"):
                self.model._dynamo_compile = False

            self.max_tokens = max_tokens
            # Text sent next to the image in the chat message (empty by default).
            self.prompt = ""
            # Device of the first parameter; inputs are moved here before generate().
            self.device = next(self.model.parameters()).device
            logger.info(f"Model loaded successfully on device: {self.device}")
        except Exception as e:
            logger.exception(f"Error loading model: {e}")
            raise

    def prepare_inputs(self, image: Image.Image) -> dict:
        """Build processor inputs for one image plus ``self.prompt``.

        Returns:
            The processor output (dict-like, PyTorch tensors).

        Raises:
            Exception: Any processor error is logged and re-raised.
        """
        try:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                        },
                        {"type": "text", "text": self.prompt},
                    ],
                }
            ]
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = self.processor(
                image,
                text,
                add_special_tokens=False,
                return_tensors="pt",
            )
            return inputs
        except Exception as e:
            logger.exception(f"Error preparing inputs: {e}")
            raise

    def predict(self, image: Image.Image) -> str:
        """Run OCR on ``image`` and return the decoded text.

        Accepts a PIL image or anything with a ``shape`` attribute (converted
        with ``Image.fromarray``). Never raises: on failure the traceback is
        logged and a string starting with ``"Error processing image: "`` is
        returned instead.
        """
        try:
            if image is None:
                return "Please upload an image."

            # Convert numpy array to PIL Image if needed
            if hasattr(image, "shape"):
                image = Image.fromarray(image)

            inputs = self.prepare_inputs(image)

            # Move every tensor-like input to the model's device.
            device = self.device
            logger.info(f"Moving inputs to device: {device}")
            for key in inputs:
                if hasattr(inputs[key], "to"):
                    inputs[key] = inputs[key].to(device)

            # NOTE(review): the original code deliberately forces the attention
            # mask to float32, presumably to avoid a dtype mismatch inside the
            # model. Confirm before changing it.
            if "attention_mask" in inputs:
                inputs["attention_mask"] = inputs["attention_mask"].to(
                    dtype=torch.float32, device=device
                )

            logger.info(f"Generating text with max_tokens={self.max_tokens}")
            # Disable compilation during generation
            with torch.no_grad(), torch._dynamo.config.patch(disable=True):
                # Greedy decoding. No temperature is passed: it is ignored when
                # do_sample=False and only makes transformers emit a warning.
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_tokens,
                    use_cache=True,
                    do_sample=False,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Drop the prompt tokens so only newly generated tokens are decoded
            # (assumes each output row starts with its input row).
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            result = output_text[0].strip()
            logger.info(f"Generated text: {result[:100]}...")
            return result
        except Exception as e:
            logger.exception(f"Error during prediction: {e}")
            return f"{_PREDICT_ERROR_PREFIX}{e}"

    def __call__(self, image: Union[Image.Image, str]) -> str:
        """Callable interface for the model; string inputs are rejected with a message."""
        if isinstance(image, str):
            return "Please upload an image file."
        return self.predict(image)


# Global model instance, created lazily by load_model().
atlas_ocr = None


def load_model():
    """Create the global ``AtlasOCR`` instance on first use.

    Returns:
        True if the model is available. False if loading failed; the error is
        logged, ``atlas_ocr`` stays None, and loading is retried on the next call.
    """
    global atlas_ocr
    if atlas_ocr is None:
        try:
            atlas_ocr = AtlasOCR()
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False
    return True


@spaces.GPU
def perform_ocr(image):
    """Run OCR on ``image`` inside the ``spaces.GPU`` context.

    Never raises: returns the OCR text, or one of this module's error /
    prompt messages.
    """
    try:
        if not load_model():
            return _LOAD_ERROR_MESSAGE

        if image is None:
            return "Please upload an image to extract text."

        return atlas_ocr(image)
    except Exception as e:
        logger.exception(f"Error in perform_ocr: {e}")
        return f"{_OCR_ERROR_PREFIX}{e}"


def _is_failure_message(text: str) -> bool:
    """Return True if ``text`` is one of this module's failure messages, not OCR output.

    An OCR result that literally starts with one of these prefixes would be
    misclassified. This affects only the status box, never the extracted text.
    """
    return text == _LOAD_ERROR_MESSAGE or text.startswith(
        (_PREDICT_ERROR_PREFIX, _OCR_ERROR_PREFIX)
    )


def process_with_status(image):
    """Run OCR and return ``(output_text, status_text)`` for the UI.

    Kept at module level (not nested in ``create_interface``) to avoid
    pickling issues.
    """
    if image is None:
        return "Please upload an image.", "No image provided"

    try:
        result = perform_ocr(image)
    except Exception as e:
        return f"Error: {str(e)}", f"Error occurred: {str(e)}"

    # perform_ocr reports failures as text instead of raising, so the result
    # must be inspected; otherwise every failure would be shown as a success.
    if _is_failure_message(result):
        return result, "Processing failed - see the output for details"
    return result, "Processing completed successfully"


def _on_image_change(image):
    """Handle ``image_input.change``.

    Gradio fires ``change`` on programmatic updates as well, including when
    the Clear button sets the image to None. Resetting on None keeps the
    cleared panel cleared instead of showing the "Please upload an image."
    prompt over it.
    """
    if image is None:
        return "", _READY_STATUS
    return process_with_status(image)


def create_interface():
    """Build and return the Gradio Blocks UI (not launched)."""
    with gr.Blocks(
        title="AtlasOCR - Darija Document OCR",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        """,
    ) as demo:
        gr.Markdown("""
        # AtlasOCR - Darija Document OCR
        Upload an image to extract Darija text in real-time. This model is specialized for Darija document OCR.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input image
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    height=400,
                )

                # Submit button
                submit_btn = gr.Button(
                    "Extract Text",
                    variant="primary",
                    size="lg",
                )

                # Clear button
                clear_btn = gr.Button("Clear", variant="secondary")

            with gr.Column(scale=1):
                # Output text
                output = gr.Textbox(
                    label="Extracted Text",
                    lines=20,
                    show_copy_button=True,
                    placeholder="Extracted text will appear here...",
                )

                # Status indicator
                status = gr.Textbox(
                    label="Status",
                    value=_READY_STATUS,
                    interactive=False,
                )

        # Model details
        with gr.Accordion("Model Information", open=False):
            gr.Markdown("""
            **Model:** AtlasOCR-v0
            **Description:** Specialized Darija OCR model for Arabic dialect text extraction
            **Size:** 3B parameters
            **Context window:** Supports up to 2000 output tokens
            **Optimization:** 4-bit quantization for efficient inference
            """)

        gr.Examples(
            examples=[
                ["i3.png"],
                ["i6.png"],
            ],
            inputs=image_input,
            outputs=[output, status],  # required because cache_examples=True
            fn=process_with_status,  # required because cache_examples=True
            label="Example Images",
            examples_per_page=4,
            cache_examples=True,
        )

        # Set up processing flow
        submit_btn.click(
            fn=process_with_status,
            inputs=image_input,
            outputs=[output, status],
        )

        image_input.change(
            fn=_on_image_change,
            inputs=image_input,
            outputs=[output, status],
        )

        clear_btn.click(
            fn=lambda: (None, "", _READY_STATUS),
            outputs=[image_input, output, status],
        )

    return demo


# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
    )