"""Gradio demo for Microsoft's KOSMOS-2.5 document-reading models.

The app exposes three tasks, each on its own tab:

* markdown generation (base model, ``<md>`` prompt),
* OCR with bounding boxes (base model, ``<ocr>`` prompt),
* document question answering (chat model).

Models are loaded lazily on first use and cached in module-level globals.
"""

# NOTE(review): `spaces` is kept as the very first import, as in the original
# file; ZeroGPU guidance is to import it before anything initialises CUDA --
# confirm before reordering.
import spaces

import re
from typing import Optional

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

# Run on the GPU in bfloat16 when CUDA is available, otherwise CPU/float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

BASE_REPO = "microsoft/kosmos-2.5"
CHAT_REPO = "microsoft/kosmos-2.5-chat"

# Task prompts understood by KOSMOS-2.5.
MARKDOWN_PROMPT = "<md>"
OCR_PROMPT = "<ocr>"
CHAT_TEMPLATE = (
    "<md>A chat between a curious user and an artificial intelligence "
    "assistant. The assistant gives helpful, detailed, and polite answers to "
    "the user's questions. USER: {} ASSISTANT:"
)

MAX_NEW_TOKENS = 1024
EXAMPLE_IMAGE_URL = (
    "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
)

# One OCR region marker: four coordinates wrapped in a <bbox> tag.
_BBOX_PATTERN = re.compile(r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>")
_NUMBER_PATTERN = re.compile(r"\d+")

INTRO_MARKDOWN = """
# KOSMOS-2.5 Document AI Demo

Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images! This demo showcases three capabilities:

1. **Markdown Generation**: Convert document images to markdown format
2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat

Upload a document image (receipt, form, article, etc.) and try different tasks!
"""

NOTES_MARKDOWN = """
## Example Use Cases:
- **Receipts**: Extract itemized information or ask about totals
- **Forms**: Convert to structured format or answer specific questions
- **Articles**: Get markdown format or ask about content
- **Screenshots**: Extract text or get information about specific elements

## Note:
This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
"""


def is_flash_attention_available():
    """Return True when the ``flash_attn`` package can be imported."""
    try:
        import flash_attn  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True


# Lazily initialised models and processors (populated by the loaders below).
base_model = None
base_processor = None
chat_model = None
chat_processor = None


def _load_model_and_processor(repo_id):
    """Load a KOSMOS-2.5 checkpoint and its processor from ``repo_id``.

    Flash Attention 2 is requested only when running on CUDA and the
    ``flash_attn`` package is importable; otherwise the default attention
    implementation is used.
    """
    model_kwargs = {"device_map": device, "dtype": dtype}
    if device == "cuda" and is_flash_attention_available():
        model_kwargs["attn_implementation"] = "flash_attention_2"
    model = Kosmos2_5ForConditionalGeneration.from_pretrained(
        repo_id, **model_kwargs
    )
    processor = AutoProcessor.from_pretrained(repo_id)
    return model, processor


def load_base_model():
    """Return the cached ``(model, processor)`` pair for the base checkpoint."""
    global base_model, base_processor
    if base_model is None or base_processor is None:
        # Assign both together so a failed processor load cannot leave a
        # half-initialised cache behind.
        base_model, base_processor = _load_model_and_processor(BASE_REPO)
    return base_model, base_processor


def load_chat_model():
    """Return the cached ``(model, processor)`` pair for the chat checkpoint."""
    global chat_model, chat_processor
    if chat_model is None or chat_processor is None:
        chat_model, chat_processor = _load_model_and_processor(CHAT_REPO)
    return chat_model, chat_processor


def _parse_ocr_regions(y, scale_height, scale_width):
    """Split raw OCR output into ``(quad, text)`` pairs.

    ``quad`` is the tuple ``(x0, y0, x1, y0, x1, y1, x0, y1)`` -- the four
    corners of the box after multiplying by the given scale factors. Boxes
    with non-positive width or height are dropped. Text that precedes the
    first ``<bbox>`` marker is discarded.
    """
    raw_boxes = _BBOX_PATTERN.findall(y)
    # The pattern has no capture groups, so split() yields exactly one text
    # chunk after each marker (plus a leading chunk, skipped here).
    texts = _BBOX_PATTERN.split(y)[1:]

    regions = []
    for raw_box, text in zip(raw_boxes, texts):
        x0, y0, x1, y1 = (int(n) for n in _NUMBER_PATTERN.findall(raw_box))
        if x0 >= x1 or y0 >= y1:
            continue
        x0 = int(x0 * scale_width)
        y0 = int(y0 * scale_height)
        x1 = int(x1 * scale_width)
        y1 = int(y1 * scale_height)
        regions.append(((x0, y0, x1, y0, x1, y1, x0, y1), text))
    return regions


def _format_ocr_regions(regions):
    """Render regions as ``x0,y0,x1,y0,x1,y1,x0,y1,text`` lines."""
    rows = (
        ",".join(str(value) for value in quad) + f",{text}\n"
        for quad, text in regions
    )
    return "".join(rows).strip()


def post_process_ocr(y, scale_height, scale_width, prompt=""):
    """Convert raw model output into coordinate-prefixed text lines.

    Args:
        y: Decoded model output.
        scale_height: Factor applied to the y coordinates of each box.
        scale_width: Factor applied to the x coordinates of each box.
        prompt: Prompt text to strip from ``y``. When it contains the
            markdown prompt, ``y`` is returned without box parsing.

    Returns:
        One ``x0,y0,x1,y0,x1,y1,x0,y1,text`` entry per detected region, or
        the prompt-stripped text for markdown prompts.
    """
    if prompt:
        y = y.replace(prompt, "")
    if MARKDOWN_PROMPT in prompt:
        return y
    return _format_ocr_regions(_parse_ocr_regions(y, scale_height, scale_width))


def _run_model(model, processor, image, prompt):
    """Run one generation pass and return ``(text, scale_height, scale_width)``.

    The scale factors are the original image size divided by the
    ``height``/``width`` values returned by the processor.
    """
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # height/width are removed here so they are not forwarded to generate().
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width

    inputs = {
        key: value.to(device) if value is not None else None
        for key, value in inputs.items()
    }
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0], scale_height, scale_width


@spaces.GPU(duration=120)
def generate_markdown(image: Optional[Image.Image]):
    """Return the markdown transcription of ``image`` (or a usage hint)."""
    if image is None:
        return "Please upload an image."

    model, processor = load_base_model()
    text, _, _ = _run_model(model, processor, image, MARKDOWN_PROMPT)
    return text.replace(MARKDOWN_PROMPT, "").strip()


@spaces.GPU(duration=120)
def generate_ocr(image: Optional[Image.Image]):
    """Run OCR on ``image``.

    Returns:
        A ``(text, visualization)`` pair: the coordinate-prefixed text lines
        and a copy of the image with each detected region outlined in red.
        When no image is given, a usage hint and ``None``.
    """
    if image is None:
        return "Please upload an image.", None

    model, processor = load_base_model()
    text, scale_height, scale_width = _run_model(
        model, processor, image, OCR_PROMPT
    )

    # Parse once and reuse the regions for both the text and the drawing,
    # instead of re-parsing the formatted text.
    regions = _parse_ocr_regions(
        text.replace(OCR_PROMPT, ""), scale_height, scale_width
    )
    output_text = _format_ocr_regions(regions)

    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)
    for quad, _ in regions:
        draw.polygon(list(quad), outline="red", width=2)

    return output_text, vis_image


@spaces.GPU(duration=120)
def generate_chat_response(image: Optional[Image.Image], question):
    """Answer ``question`` about ``image`` using the chat model.

    Returns a usage hint when the image or the question is missing.
    """
    if image is None:
        return "Please upload an image."
    if not question or not question.strip():
        return "Please ask a question."

    model, processor = load_chat_model()
    prompt = CHAT_TEMPLATE.format(question)
    result, _, _ = _run_model(model, processor, image, prompt)

    # Keep only the assistant's reply when the prompt is echoed back.
    if "ASSISTANT:" in result:
        result = result.split("ASSISTANT:")[-1].strip()
    return result


def _image_input_with_example():
    """Create the document-image upload widget plus its example image."""
    image_input = gr.Image(type="pil", label="Upload Document Image")
    gr.Examples(examples=[EXAMPLE_IMAGE_URL], inputs=image_input)
    return image_input


# Create Gradio interface
with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Tabs():
        # Markdown Generation Tab
        with gr.TabItem("📝 Markdown Generation"):
            with gr.Row():
                with gr.Column():
                    md_image = _image_input_with_example()
                    md_button = gr.Button("Generate Markdown", variant="primary")
                with gr.Column():
                    md_output = gr.Textbox(
                        label="Generated Markdown",
                        lines=15,
                        max_lines=20,
                        show_copy_button=True,
                    )

        # OCR Tab
        with gr.TabItem("🔍 OCR with Bounding Boxes"):
            with gr.Row():
                with gr.Column():
                    ocr_image = _image_input_with_example()
                    ocr_button = gr.Button(
                        "Extract Text with Coordinates", variant="primary"
                    )
                with gr.Column():
                    with gr.Row():
                        ocr_text = gr.Textbox(
                            label="Extracted Text with Coordinates",
                            lines=10,
                            show_copy_button=True,
                        )
                        ocr_vis = gr.Image(
                            label="Visualization (Red boxes show detected text)"
                        )

        # Chat Tab
        with gr.TabItem("💬 Document Q&A (Chat)"):
            with gr.Row():
                with gr.Column():
                    chat_image = _image_input_with_example()
                    chat_question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="e.g., What is the total amount on this receipt?",
                        lines=2,
                    )
                    gr.Examples(
                        examples=[
                            "What is the total amount on this receipt?",
                            "What items were purchased?",
                            "When was this receipt issued?",
                            "What is the subtotal?",
                        ],
                        inputs=chat_question,
                    )
                    chat_button = gr.Button("Get Answer", variant="primary")
                with gr.Column():
                    chat_output = gr.Textbox(
                        label="Answer",
                        lines=8,
                        show_copy_button=True,
                    )

    # Event handlers
    md_button.click(
        fn=generate_markdown,
        inputs=[md_image],
        outputs=[md_output],
    )
    ocr_button.click(
        fn=generate_ocr,
        inputs=[ocr_image],
        outputs=[ocr_text, ocr_vis],
    )
    chat_button.click(
        fn=generate_chat_response,
        inputs=[chat_image, chat_question],
        outputs=[chat_output],
    )

    # Examples section
    gr.Markdown(NOTES_MARKDOWN)

if __name__ == "__main__":
    demo.launch()