# kosmos-2.5-demo / app.py
# Author: nielsr (HF Staff)
# Last commit dce53e9: "Fix Flash Attention 2 import error with conditional loading"
import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
import re
# Select the compute device and a matching precision in one step:
# bfloat16 when a CUDA GPU is visible, plain float32 on CPU.
device, dtype = (
    ("cuda", torch.bfloat16) if torch.cuda.is_available() else ("cpu", torch.float32)
)
# Check if Flash Attention 2 is available
def is_flash_attention_available():
try:
import flash_attn
return True
except ImportError:
return False
# Model/processor caches, left empty at import time; load_base_model() and
# load_chat_model() fill them on first use so startup stays cheap.
base_model = base_processor = None
chat_model = chat_processor = None
def load_base_model():
    """Lazily load and cache the KOSMOS-2.5 base model and its processor.

    On the first call, loads ``microsoft/kosmos-2.5`` with ``device_map="cuda"``
    and the module-level ``dtype``, enabling Flash Attention 2 when the
    ``flash_attn`` package is importable. Later calls return the cached pair.

    Returns:
        tuple: ``(base_model, base_processor)``.
    """
    global base_model, base_processor
    if base_model is None:
        base_repo = "microsoft/kosmos-2.5"
        # Use Flash Attention 2 if available, otherwise use default attention
        model_kwargs = {
            "device_map": "cuda",
            "dtype": dtype,
        }
        if is_flash_attention_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"
        model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            base_repo,
            **model_kwargs,
        )
        processor = AutoProcessor.from_pretrained(base_repo)
        # Publish both together only after both loads succeeded. Assigning the
        # model global first would leave the cache half-initialized if the
        # processor load raised, and later calls would return (model, None).
        base_model, base_processor = model, processor
    return base_model, base_processor
def load_chat_model():
    """Lazily load and cache the KOSMOS-2.5 Chat model and its processor.

    On the first call, loads ``microsoft/kosmos-2.5-chat`` with
    ``device_map="cuda"`` and the module-level ``dtype``, enabling Flash
    Attention 2 when the ``flash_attn`` package is importable. Later calls
    return the cached pair.

    Returns:
        tuple: ``(chat_model, chat_processor)``.
    """
    global chat_model, chat_processor
    if chat_model is None:
        chat_repo = "microsoft/kosmos-2.5-chat"
        # Use Flash Attention 2 if available, otherwise use default attention
        model_kwargs = {
            "device_map": "cuda",
            "dtype": dtype,
        }
        if is_flash_attention_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"
        model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            chat_repo,
            **model_kwargs,
        )
        processor = AutoProcessor.from_pretrained(chat_repo)
        # Publish both together only after both loads succeeded. Assigning the
        # model global first would leave the cache half-initialized if the
        # processor load raised, and later calls would return (model, None).
        chat_model, chat_processor = model, processor
    return chat_model, chat_processor
def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
    """Convert raw KOSMOS-2.5 OCR output into one quadrilateral row per box.

    Args:
        y: Decoded model output in which each
            ``<bbox><x_N><y_N><x_N><y_N></bbox>`` marker is followed by the
            text it encloses.
        scale_height: Multiplier applied to the y coordinates.
        scale_width: Multiplier applied to the x coordinates.
        prompt: Task prompt removed from ``y``. If it contains ``<md>``, the
            remaining text is returned unchanged (no box parsing).

    Returns:
        str: Newline-separated rows ``x0,y0,x1,y0,x1,y1,x0,y1,text`` with
        scaled integer coordinates. Boxes whose x0 >= x1 or y0 >= y1 are
        dropped together with their text. The result is stripped of
        surrounding whitespace.
    """
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    bbox_pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    raw_boxes = re.findall(bbox_pattern, y)
    # re.split returns the text before the first box plus one chunk after each
    # box; dropping the leading chunk pairs each box with the text following it.
    texts = re.split(bbox_pattern, y)[1:]
    rows = []
    for raw_box, text in zip(raw_boxes, texts):
        x0, y0, x1, y1 = (int(v) for v in re.findall(r"\d+", raw_box))
        if x0 >= x1 or y0 >= y1:
            continue  # degenerate box: skip it and its text
        x0, x1 = int(x0 * scale_width), int(x1 * scale_width)
        y0, y1 = int(y0 * scale_height), int(y1 * scale_height)
        rows.append(f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{text}")
    return "\n".join(rows).strip()
@spaces.GPU(duration=120)
def generate_markdown(image):
    """Convert a document image to Markdown with the KOSMOS-2.5 base model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            nothing was uploaded.

    Returns:
        str: The generated Markdown with the ``<md>`` prompt removed, or an
        upload reminder when ``image`` is None.
    """
    if image is None:
        return "Please upload an image."
    model, processor = load_base_model()
    prompt = "<md>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Pop the processor's height/width so they are not forwarded to generate().
    # Markdown output has no coordinates, so no rescaling factors are needed.
    inputs.pop("height")
    inputs.pop("width")
    inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_text[0].replace(prompt, "").strip()
@spaces.GPU(duration=120)
def generate_ocr(image):
    """Run KOSMOS-2.5 OCR and draw the detected text boxes on the image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            nothing was uploaded.

    Returns:
        tuple: ``(output_text, vis_image)``, where ``output_text`` is the
        row-per-box text produced by ``post_process_ocr`` and ``vis_image`` is a
        converted copy of ``image`` with red box outlines. Returns
        ``("Please upload an image.", None)`` when ``image`` is None.
    """
    if image is None:
        return "Please upload an image.", None
    model, processor = load_base_model()
    prompt = "<ocr>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Ratio between the original image size and the processor's height/width,
    # used to map predicted box coordinates back onto the uploaded image.
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Post-process OCR output
    output_text = post_process_ocr(generated_text[0], scale_height, scale_width)

    # Create visualization
    from PIL import ImageDraw

    # Draw on a color copy so "red" outlines are actually red even for
    # grayscale ("L") or palette ("P") uploads. Keep an alpha channel when the
    # upload has one. convert() always returns a new image.
    vis_mode = "RGBA" if "A" in image.getbands() else "RGB"
    vis_image = image.convert(vis_mode)
    draw = ImageDraw.Draw(vis_image)
    for line in output_text.split("\n"):
        if not line.strip():
            continue
        parts = line.split(",")
        if len(parts) < 8:
            continue
        try:
            coords = [int(p) for p in parts[:8]]
            draw.polygon(coords, outline="red", width=2)
        except (ValueError, TypeError):
            # Recognized text may itself contain newlines/commas, producing
            # fragments whose first fields are not integers; skip those rows.
            # Best-effort drawing: never fail the whole request on one box.
            continue
    return output_text, vis_image
@spaces.GPU(duration=120)
def generate_chat_response(image, question):
    """Answer a free-form question about a document image with KOSMOS-2.5 Chat.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None.
        question: The user's question text; may be None or blank.

    Returns:
        str: The model's answer (text after the last ``ASSISTANT:`` marker when
        present), or a reminder to upload an image / ask a question.
    """
    if image is None:
        return "Please upload an image."
    # Treat None the same as blank text instead of crashing on None.strip().
    if not question or not question.strip():
        return "Please ask a question."
    model, processor = load_chat_model()
    template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
    prompt = template.format(question)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Pop the processor's height/width so they are not forwarded to generate().
    # Chat answers have no coordinates, so no rescaling factors are needed.
    inputs.pop("height")
    inputs.pop("width")
    inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded text may echo the prompt; keep only the assistant's reply.
    result = generated_text[0]
    if "ASSISTANT:" in result:
        result = result.split("ASSISTANT:")[-1].strip()
    return result
# Create Gradio interface
# Sample document offered as an example input in every tab.
_EXAMPLE_RECEIPT_URL = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"

with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
    # Blank lines separate the paragraphs and the list. Without the one before
    # the last sentence, Markdown would fold it into list item 3.
    gr.Markdown("""
# KOSMOS-2.5 Document AI Demo

Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!

This demo showcases three capabilities:

1. **Markdown Generation**: Convert document images to markdown format
2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat

Upload a document image (receipt, form, article, etc.) and try different tasks!
""")
    with gr.Tabs():
        # Markdown Generation Tab
        with gr.TabItem("📝 Markdown Generation"):
            with gr.Row():
                with gr.Column():
                    md_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=[_EXAMPLE_RECEIPT_URL],
                        inputs=md_image,
                    )
                    md_button = gr.Button("Generate Markdown", variant="primary")
                with gr.Column():
                    md_output = gr.Textbox(
                        label="Generated Markdown",
                        lines=15,
                        max_lines=20,
                        show_copy_button=True,
                    )
        # OCR Tab
        with gr.TabItem("🔍 OCR with Bounding Boxes"):
            with gr.Row():
                with gr.Column():
                    ocr_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=[_EXAMPLE_RECEIPT_URL],
                        inputs=ocr_image,
                    )
                    ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
                with gr.Column():
                    with gr.Row():
                        ocr_text = gr.Textbox(
                            label="Extracted Text with Coordinates",
                            lines=10,
                            show_copy_button=True,
                        )
                        ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
        # Chat Tab
        with gr.TabItem("💬 Document Q&A (Chat)"):
            with gr.Row():
                with gr.Column():
                    chat_image = gr.Image(type="pil", label="Upload Document Image")
                    gr.Examples(
                        examples=[_EXAMPLE_RECEIPT_URL],
                        inputs=chat_image,
                    )
                    chat_question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="e.g., What is the total amount on this receipt?",
                        lines=2,
                    )
                    gr.Examples(
                        examples=[
                            "What is the total amount on this receipt?",
                            "What items were purchased?",
                            "When was this receipt issued?",
                            "What is the subtotal?",
                        ],
                        inputs=chat_question,
                    )
                    chat_button = gr.Button("Get Answer", variant="primary")
                with gr.Column():
                    chat_output = gr.Textbox(
                        label="Answer",
                        lines=8,
                        show_copy_button=True,
                    )

    # Event handlers
    md_button.click(
        fn=generate_markdown,
        inputs=[md_image],
        outputs=[md_output],
    )
    ocr_button.click(
        fn=generate_ocr,
        inputs=[ocr_image],
        outputs=[ocr_text, ocr_vis],
    )
    chat_button.click(
        fn=generate_chat_response,
        inputs=[chat_image, chat_question],
        outputs=[chat_output],
    )

    # Use-case notes and disclaimer shown below the tabs
    gr.Markdown("""
## Example Use Cases:

- **Receipts**: Extract itemized information or ask about totals
- **Forms**: Convert to structured format or answer specific questions
- **Articles**: Get markdown format or ask about content
- **Screenshots**: Extract text or get information about specific elements

## Note:

This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
""")
# Start the Gradio server only when this module is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()