"""Gradio app that runs the AtlasOCR vision model to extract Darija text from images."""

import os

import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from PIL import Image

load_dotenv()

# Disable TorchDynamo (torch.compile) to avoid compilation issues.
torch._dynamo.config.disable = True
torch.backends.cudnn.allow_tf32 = True

IS_CUDA = torch.cuda.is_available()
# Any non-empty SPACES_ZERO_GPU value is treated as "running on ZeroGPU".
IS_ZEROGPU = bool(os.getenv("SPACES_ZERO_GPU"))

if IS_ZEROGPU:
    torch.compiler.set_stance("force_eager")
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True

MODEL_NAME = "atlasia/AtlasOCR"
MAX_TOKENS = 4096


def _load_model():
    """Load the AtlasOCR model and processor through Unsloth in 4-bit mode.

    Returns:
        The ``(model, processor)`` pair returned by
        ``FastVisionModel.from_pretrained``.

    Raises:
        gr.Error: If loading fails for any reason; the original exception is
            chained as ``__cause__`` so it stays visible in logs.
    """
    try:
        # Imported lazily so the heavy dependency is only touched inside the
        # GPU-decorated call path.
        from unsloth import FastVisionModel

        return FastVisionModel.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            load_in_4bit=True,
            use_gradient_checkpointing="unsloth",
            # .get() so an unset HF_API_KEY does not abort loading a public
            # model with a KeyError; token=None defers to the hub client's
            # default credential lookup.
            token=os.environ.get("HF_API_KEY"),
        )
    except Exception as e:
        print(f"[Error] Failed to load model: {e}")
        # gr.Error is shown to the user in the UI (and is an Exception subclass).
        raise gr.Error(f"❌ Model failed to load: {e}") from e


@spaces.GPU()
@torch.inference_mode()
def predict(image: Image.Image) -> str:
    """Run AtlasOCR on a single image and return the extracted text.

    Args:
        image: The uploaded image as a PIL image (the input component uses
            ``type="pil"``), or ``None`` if the user submitted nothing.

    Returns:
        The decoded generation with leading/trailing whitespace stripped.

    Raises:
        gr.Error: If no image was uploaded or the model failed to load.
    """
    # Validate first so an empty submission does not pay for a model load.
    if image is None:
        raise gr.Error("Please upload an image.")

    # NOTE(review): the model is loaded on every call. This is presumably
    # intentional for ZeroGPU, where the GPU is only attached inside
    # @spaces.GPU calls; confirm before moving this to module level or caching.
    model, processor = _load_model()

    # Chat-style prompt: one image slot followed by an empty text instruction.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": ""}],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The templated prompt is tokenized without extra special tokens
    # (presumably the chat template already supplies them).
    inputs = processor(
        image,
        text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding. No temperature: it is ignored when do_sample=False.
    # No torch.no_grad(): @torch.inference_mode() already disables autograd.
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=MAX_TOKENS,
        do_sample=False,
        pad_token_id=processor.tokenizer.eos_token_id,
    )

    # generate() returns prompt + completion; keep only the new tokens.
    generated_ids = [
        out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)
    ]
    text_out = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return text_out[0].strip()


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image", height=400),
    outputs=gr.Textbox(
        label="Extracted Text",
        lines=20,
        show_copy_button=True,
        placeholder="Extracted text will appear here...",
    ),
    title="AtlasOCR - Darija Document OCR",
    description="Upload an image to extract Darija text.",
    examples=[["i3.png"], ["i6.png"]],
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)