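"""Gradio demo for Microsoft's KOSMOS-2.5 document AI models.

Three tabs: markdown generation (<md> prompt), OCR with bounding boxes
(<ocr> prompt), and document Q&A via the kosmos-2.5-chat checkpoint.
Written for a Hugging Face ZeroGPU Space (hence `import spaces`).
"""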
import re

import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

# Pick device and dtype: bfloat16 on CUDA, float32 on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Flash Attention 2 is optional; fall back to default attention if missing
def is_flash_attention_available():
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False
# Models and processors are loaded lazily on first use
base_model = None
base_processor = None
chat_model = None
chat_processor = None
def load_base_model():
    global base_model, base_processor
    if base_model is None:
        base_repo = "microsoft/kosmos-2.5"
        model_kwargs = {
            "device_map": device,  # falls back to CPU when CUDA is unavailable
            "dtype": dtype,
        }
        # Use Flash Attention 2 if available, otherwise keep the default attention
        if is_flash_attention_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"
        base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            base_repo,
            **model_kwargs,
        )
        base_processor = AutoProcessor.from_pretrained(base_repo)
    return base_model, base_processor
def load_chat_model():
    global chat_model, chat_processor
    if chat_model is None:
        chat_repo = "microsoft/kosmos-2.5-chat"
        model_kwargs = {
            "device_map": device,  # falls back to CPU when CUDA is unavailable
            "dtype": dtype,
        }
        # Use Flash Attention 2 if available, otherwise keep the default attention
        if is_flash_attention_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"
        chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            chat_repo,
            **model_kwargs,
        )
        chat_processor = AutoProcessor.from_pretrained(chat_repo)
    return chat_model, chat_processor
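# Both loaders share one architecture: "kosmos-2.5" handles the <ocr>/<md>
# prompts, while "kosmos-2.5-chat" is tuned for free-form questions about an image.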
def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
    """Turn raw <bbox> tokens into "x0,y0,x1,y0,x1,y1,x0,y1,text" lines."""
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        if i < len(bboxs):
            x0, y0, x1, y1 = bboxs[i]
            # Skip degenerate boxes, then rescale to the original image size
            if not (x0 >= x1 or y0 >= y1):
                x0 = int(x0 * scale_width)
                y0 = int(y0 * scale_height)
                x1 = int(x1 * scale_width)
                y1 = int(y1 * scale_height)
                info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}\n"
    return info.strip()
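# One post-processed line looks like (values illustrative):
#   17,25,430,25,430,58,17,58,TOTAL $12.40
# i.e. the box's four corners, clockwise in image pixels, then the recognized text.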
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces
def generate_markdown(image):
    if image is None:
        return "Please upload an image."
    model, processor = load_base_model()
    prompt = "<md>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # The processor resizes the image; height/width are its working resolution
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    result = generated_text[0].replace(prompt, "").strip()
    return result
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces
def generate_ocr(image):
    if image is None:
        return "Please upload an image.", None
    model, processor = load_base_model()
    prompt = "<ocr>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Convert <bbox> tokens to pixel coordinates in the original image
    output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
    # Draw each detected text box on a copy of the input image
    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)
    for line in output_text.split("\n"):
        if not line.strip():
            continue
        parts = line.split(",")
        if len(parts) >= 8:
            try:
                coords = list(map(int, parts[:8]))
                draw.polygon(coords, outline="red", width=2)
            except ValueError:
                continue
    return output_text, vis_image
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces
def generate_chat_response(image, question):
    if image is None:
        return "Please upload an image."
    if not question.strip():
        return "Please ask a question."
    model, processor = load_chat_model()
    template = (
        "<md>A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's "
        "questions. USER: {} ASSISTANT:"
    )
    prompt = template.format(question)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The model echoes the prompt; keep only the text after the final "ASSISTANT:"
    result = generated_text[0]
    if "ASSISTANT:" in result:
        result = result.split("ASSISTANT:")[-1].strip()
    return result
# Build the Gradio interface
with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # KOSMOS-2.5 Document AI Demo

    Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!
    This demo showcases three capabilities:

    1. **Markdown Generation**: Convert document images to markdown format
    2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
    3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat

    Upload a document image (receipt, form, article, etc.) and try different tasks!
    """)
    with gr.Tabs():
        # Markdown Generation tab
        with gr.TabItem("📝 Markdown Generation"):
            with gr.Row():
                with gr.Column():
                    md_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=["https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"],
                        inputs=md_image,
                    )
                    md_button = gr.Button("Generate Markdown", variant="primary")
                with gr.Column():
                    md_output = gr.Textbox(
                        label="Generated Markdown",
                        lines=15,
                        max_lines=20,
                        show_copy_button=True,
                    )
        # OCR tab
        with gr.TabItem("🔍 OCR with Bounding Boxes"):
            with gr.Row():
                with gr.Column():
                    ocr_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=["https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"],
                        inputs=ocr_image,
                    )
                    ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
                with gr.Column():
                    with gr.Row():
                        ocr_text = gr.Textbox(
                            label="Extracted Text with Coordinates",
                            lines=10,
                            show_copy_button=True,
                        )
                        ocr_vis = gr.Image(label="Visualization (red boxes show detected text)")
        # Chat tab
        with gr.TabItem("💬 Document Q&A (Chat)"):
            with gr.Row():
                with gr.Column():
                    chat_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=["https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"],
                        inputs=chat_image,
                    )
                    chat_question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="e.g., What is the total amount on this receipt?",
                        lines=2,
                    )
                    gr.Examples(
                        examples=[
                            "What is the total amount on this receipt?",
                            "What items were purchased?",
                            "When was this receipt issued?",
                            "What is the subtotal?",
                        ],
                        inputs=chat_question,
                    )
                    chat_button = gr.Button("Get Answer", variant="primary")
                with gr.Column():
                    chat_output = gr.Textbox(
                        label="Answer",
                        lines=8,
                        show_copy_button=True,
                    )
    # Event handlers
    md_button.click(
        fn=generate_markdown,
        inputs=[md_image],
        outputs=[md_output],
    )
    ocr_button.click(
        fn=generate_ocr,
        inputs=[ocr_image],
        outputs=[ocr_text, ocr_vis],
    )
    chat_button.click(
        fn=generate_chat_response,
        inputs=[chat_image, chat_question],
        outputs=[chat_output],
    )
    # Usage notes
    gr.Markdown("""
    ## Example Use Cases
    - **Receipts**: Extract itemized information or ask about totals
    - **Forms**: Convert to structured format or answer specific questions
    - **Articles**: Get markdown format or ask about content
    - **Screenshots**: Extract text or get information about specific elements

    ## Note
    This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
    """)
if __name__ == "__main__":
    demo.launch()