"""Template Demo for IBM Granite Hugging Face spaces.""" | |
import html | |
import os | |
import random | |
import re | |
import time | |
from pathlib import Path | |
from threading import Thread | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
from docling_core.types.doc import DoclingDocument | |
from docling_core.types.doc.document import DocTagsDocument | |
from PIL import Image, ImageDraw, ImageOps | |
from transformers import ( | |
AutoProcessor, | |
Idefics3ForConditionalGeneration, | |
TextIteratorStreamer, | |
) | |
from themes.research_monochrome import theme | |

dir_ = Path(__file__).parent.parent

TITLE = "Granite-docling-258m demo"
DESCRIPTION = """
<p>This experimental demo highlights the capabilities of granite-docling-258M for document conversion,
showcasing Granite Docling's various features. Explore the sample document excerpts and try the sample
prompts or enter your own. Keep in mind that AI can occasionally make mistakes.</p>
"""

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

SAMPLES_PATH = dir_ / "data" / "images"

sample_data = [
    {
        "preview_image": str(SAMPLES_PATH / "new_arxiv.png"),
        "prompts": [
            "Convert this page to docling.",
            "Does the document contain tables?",
            "Can you extract the 2nd section header?",
            "What element is located at <loc_84><loc_403><loc_238><loc_419>",
            "How can effective temperature be computed?",
            "Extract all picture elements on the page.",
        ],
        "image": str(SAMPLES_PATH / "new_arxiv.png"),
        "name": "Doc Conversion",
        "pad": False,
    },
    {
        "preview_image": str(SAMPLES_PATH / "image-2.jpg"),
        "prompts": ["Convert this table to OTSL.", "What is the Net income in 2008?"],
        "image": str(SAMPLES_PATH / "image-2.jpg"),
        "name": "Table Recognition",
        "pad": True,
    },
    {
        "preview_image": str(SAMPLES_PATH / "code.jpg"),
        "prompts": ["Convert code to text."],
        "image": str(SAMPLES_PATH / "code.jpg"),
        "name": "Code Recognition",
        "pad": True,
    },
    {
        "preview_image": str(SAMPLES_PATH / "lake-zurich-switzerland-view-nature-landscapes-7bbda4-1024.jpg"),
        "prompts": ["Describe this image."],
        "image": str(SAMPLES_PATH / "lake-zurich-switzerland-view-nature-landscapes-7bbda4-1024.jpg"),
        "name": "Image Captioning",
        "pad": False,
    },
    {
        "preview_image": str(SAMPLES_PATH / "87664.png"),
        "prompts": ["Convert formula to latex."],
        "image": str(SAMPLES_PATH / "87664.png"),
        "name": "Formula Recognition",
        "pad": True,
    },
    {
        "preview_image": str(SAMPLES_PATH / "06236926002285.png"),
        "prompts": ["Convert chart to OTSL."],
        "image": str(SAMPLES_PATH / "06236926002285.png"),
        "name": "Chart Extraction",
        "pad": False,
    },
    {
        "preview_image": str(SAMPLES_PATH / "ar_page_0.png"),
        "prompts": ["Convert this page to docling."],
        "image": str(SAMPLES_PATH / "ar_page_0.png"),
        "name": "Arabic Conversion",
        "pad": False,
    },
    {
        "preview_image": str(SAMPLES_PATH / "japanse_4_ibm.png"),
        "prompts": ["Convert this page to docling."],
        "image": str(SAMPLES_PATH / "japanse_4_ibm.png"),
        "name": "Japanese Conversion",
        "pad": False,
    },
    {
        "preview_image": str(SAMPLES_PATH / "zh_page_0.png"),
        "prompts": ["Convert this page to docling."],
        "image": str(SAMPLES_PATH / "zh_page_0.png"),
        "name": "Chinese Conversion",
        "pad": False,
    },
]
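
# Each sample entry pairs a gallery thumbnail ("preview_image") with the image actually sent to the
# model ("image"), a list of example prompts, a display name, and a "pad" flag that applies
# add_random_padding() before inference (used for the tightly cropped table, code, and formula samples).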

# Initialize the model
model_id = "ibm-granite/granite-docling-258M"
if gr.NO_RELOAD:
    processor = AutoProcessor.from_pretrained(model_id, use_auth_token=True)
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id, device_map=device, torch_dtype=torch.bfloat16, use_auth_token=True
    )
    if not torch.cuda.is_available():
        model = model.to(device)
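
# The gr.NO_RELOAD guard keeps the weights from being reloaded on every hot reload during
# `gradio app.py` development; device_map places the model on CUDA when available, and the extra
# .to(device) covers the MPS/CPU fallbacks selected above.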


def lower_md_headers(md: str) -> str:
    """Demote level 1 and 2 markdown headers to level 3 headers."""
    return re.sub(r"(?:^|\n)##?\s(.+)", lambda m: "\n### " + m.group(1), md)
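
# Illustrative example (this helper is not called elsewhere in this file):
# lower_md_headers("# Title\n## Setup") returns "\n### Title\n### Setup", e.g. to keep
# model-produced headings from outranking the page title when rendered in the chat.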


def add_random_padding(image: Image.Image, min_percent: float = 0.1, max_percent: float = 0.10) -> Image.Image:
    """Add random padding to an image."""
    image = image.convert("RGB")
    width, height = image.size
    pad_w_percent = random.uniform(min_percent, max_percent)
    pad_h_percent = random.uniform(min_percent, max_percent)
    pad_w = int(width * pad_w_percent)
    pad_h = int(height * pad_h_percent)
    corner_pixel = image.getpixel((0, 0))  # Top-left corner
    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
    return padded_image
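
# Example (illustrative numbers): with the default 10% padding, a 1000x800 crop becomes a
# 1200x960 canvas (100 px added on the left/right, 80 px on the top/bottom), filled with the
# top-left corner color so tightly cropped tables, code, and formulas get some background context.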


def draw_bounding_boxes(image_path: str, response_text: str, is_doctag_response: bool = False) -> Image.Image:
    """Draw bounding boxes on the image based on loc tags and return the annotated image."""
    try:
        # Load the original image
        image = Image.open(image_path).convert("RGB")
        draw = ImageDraw.Draw(image)
        # Get image dimensions
        width, height = image.size
        # Color mapping for different classes (RGB values converted to hex)
        class_colors = {
            "caption": "#FFCC99",  # (255, 204, 153)
            "footnote": "#C8C8FF",  # (200, 200, 255)
            "formula": "#C0C0C0",  # (192, 192, 192)
            "list_item": "#9999FF",  # (153, 153, 255)
            "page_footer": "#CCFFCC",  # (204, 255, 204)
            "page_header": "#CCFFCC",  # (204, 255, 204)
            "picture": "#FFCCA4",  # (255, 204, 164)
            "chart": "#FFCCA4",  # (255, 204, 164)
            "section_header": "#FF9999",  # (255, 153, 153)
            "table": "#FFCCCC",  # (255, 204, 204)
            "text": "#FFFF99",  # (255, 255, 153)
            "title": "#FF9999",  # (255, 153, 153)
            "document_index": "#DCDCDC",  # (220, 220, 220)
            "code": "#7D7D7D",  # (125, 125, 125)
            "checkbox_selected": "#FFB6C1",  # (255, 182, 193)
            "checkbox_unselected": "#FFB6C1",  # (255, 182, 193)
            "form": "#C8FFFF",  # (200, 255, 255)
            "key_value_region": "#B7410E",  # (183, 65, 14)
            "paragraph": "#FFFF99",  # (255, 255, 153)
            "reference": "#B0E0E6",  # (176, 224, 230)
            "grading_scale": "#FFCCCC",  # (255, 204, 204)
            "handwritten_text": "#CCFFCC",  # (204, 255, 204)
            "empty_value": "#DCDCDC",  # (220, 220, 220)
        }
        # Full DocTags elements: <class><loc_..>x4 ... </class>
        doctag_class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>[^<]*</[^>]+>"
        doctag_matches = re.findall(doctag_class_pattern, response_text)
        # Class-prefixed loc tags without a closing element
        class_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        class_matches = re.findall(class_pattern, response_text)
        # Deduplicate by coordinates, preferring the full DocTags matches
        seen_coords = set()
        all_class_matches = []
        for match in doctag_matches:
            coords = (match[1], match[2], match[3], match[4])
            if coords not in seen_coords:
                seen_coords.add(coords)
                all_class_matches.append(match)
        for match in class_matches:
            coords = (match[1], match[2], match[3], match[4])
            if coords not in seen_coords:
                seen_coords.add(coords)
                all_class_matches.append(match)
        # Bare loc tags with no class prefix
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        loc_only_matches = re.findall(loc_only_pattern, response_text)
        for class_name, xmin, ymin, xmax, ymax in all_class_matches:
            if is_doctag_response:
                color = class_colors.get(class_name.lower(), None)
                if color is None:
                    # Fall back to a partial class-name match, then to gray
                    for key in class_colors:
                        if class_name.lower() in key or key in class_name.lower():
                            color = class_colors[key]
                            break
                if color is None:
                    color = "#808080"
            else:
                color = "#E0115F"
            # loc values are normalized to a 0-500 grid; scale them to pixel coordinates
            x1 = int((int(xmin) / 500) * width)
            y1 = int((int(ymin) / 500) * height)
            x2 = int((int(xmax) / 500) * width)
            y2 = int((int(ymax) / 500) * height)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        for xmin, ymin, xmax, ymax in loc_only_matches:
            if is_doctag_response:
                continue
            color = "#808080"
            x1 = int((int(xmin) / 500) * width)
            y1 = int((int(ymin) / 500) * height)
            x2 = int((int(xmax) / 500) * width)
            y2 = int((int(ymax) / 500) * height)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        return image
    except Exception:
        # Fall back to the unannotated image if parsing or drawing fails
        return Image.open(image_path)
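
# Worked example of the coordinate scaling above (illustrative numbers): granite-docling emits
# locations on a 0-500 grid, so a tag like <text><loc_84><loc_403><loc_238><loc_419>...</text>
# drawn on a 1000x800 px page maps to the pixel box (168, 644) - (476, 670).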


def clean_model_response(text: str) -> str:
    """Clean up model response by removing special tokens and formatting properly."""
    if not text:
        return "No response generated."
    special_tokens = [
        "<|end_of_text|>",
        "<|end|>",
        "<|assistant|>",
        "<|user|>",
        "<|system|>",
        "<pad>",
        "</s>",
        "<s>",
    ]
    cleaned = text
    for token in special_tokens:
        cleaned = cleaned.replace(token, "")
    cleaned = cleaned.strip()
    if not cleaned:
        return "The model generated a response, but it appears to be empty or contain only special tokens."
    return cleaned
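
# Example (illustrative): clean_model_response("<doctag><text>...</text></doctag><|end_of_text|>")
# drops the trailing <|end_of_text|> marker but leaves the DocTags markup intact, since the
# downstream parsing and bounding-box drawing still need those tags.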


def generate_with_model(question: str, image_path: str, apply_padding: bool = False) -> str:
    """Generate an answer using the Granite Docling model directly on the image (non-streaming)."""
    if os.environ.get("NO_LLM"):
        # Development shortcut: skip model inference entirely
        time.sleep(2)
        return "This is a simulated response from the Granite Docling model."
    try:
        image = Image.open(image_path).convert("RGB")
        if apply_padding:
            image = add_random_padding(image)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        temperature = 0.0
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=processor.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, keeping special tokens for DocTags parsing
        generated_texts = processor.batch_decode(
            generated_ids[:, inputs["input_ids"].shape[1] :],
            skip_special_tokens=False,
        )[0]
        cleaned_response = clean_model_response(generated_texts)
        return cleaned_response
    except Exception as e:
        return f"Error processing image: {e!s}"


# Raw (unescaped) text of the most recent streaming generation, consumed by send_generate below
_streaming_raw_output = ""


def generate_with_model_streaming(question: str, image_path: str, apply_padding: bool = False):
    """Generate an answer using the Granite Docling model, yielding HTML-escaped partial output."""
    global _streaming_raw_output
    _streaming_raw_output = ""
    try:
        image = Image.open(image_path).convert("RGB")
        if apply_padding:
            image = add_random_padding(image)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        temperature = 0.0
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
        generation_args = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=4096,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
        # Run generation in a background thread so we can iterate over the streamer here
        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()
        yield "..."
        full_output = ""
        escaped_output = ""
        for new_text in streamer:
            full_output += new_text
            escaped_output += html.escape(new_text)
            yield escaped_output
        _streaming_raw_output = full_output
    except Exception as e:
        yield f"Error generating response: {e!s}"


chatbot = gr.Chatbot(
    examples=[{"text": x} for x in sample_data[0]["prompts"]],
    type="messages",
    label=f"Q&A about {sample_data[0]['name']}",
    height=685,
    group_consecutive_messages=True,
    autoscroll=False,
    elem_classes=["chatbot_view"],
)
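
# The Chatbot is created outside the Blocks context and placed into the layout later with
# chatbot.render(), so it can be configured up front and referenced from the event handlers below.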

css_file_path = Path(Path(__file__).parent / "app.css")
head_file_path = Path(Path(__file__).parent / "app_head.html")

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    is_in_edit_mode = gr.State(True)  # in block to be reactive
    selected_doc = gr.State(0)
    current_question = gr.State("")
    uploaded_image_path = gr.State(None)  # Store path to uploaded image

    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # Create gallery with captions for hover effect
    gallery_with_captions = []
    for sd in sample_data:
        gallery_with_captions.append((sd["preview_image"], sd["name"]))

    document_gallery = gr.Gallery(
        gallery_with_captions,
        label="Select a document",
        rows=1,
        columns=9,
        height="125px",
        allow_preview=False,
        selected_index=0,
        elem_classes=["preview_im_element"],
        show_label=True,
    )

    with gr.Row():
        with gr.Column(), gr.Group():
            image_display = gr.Image(
                sample_data[0]["image"],
                label=f"Preview for {sample_data[0]['name']}",
                height=700,
                interactive=False,
                elem_classes=["image_viewer"],
            )
            # Upload button for custom images
            upload_button = gr.UploadButton(
                "📁 Upload Image", file_types=["image"], elem_classes=["upload_button"], scale=1
            )
        with gr.Column():
            chatbot.render()
            with gr.Row():
                tbb = gr.Textbox(submit_btn=True, show_label=False, placeholder="Type a message...", scale=4)
                fb = gr.Button("Ask new question", visible=False, scale=1)
                fb.click(lambda: [], outputs=[chatbot])

    def sample_image_selected(d: gr.SelectData) -> tuple:
        """Handle sample image selection."""
        dx = sample_data[d.index]
        return (
            gr.update(examples=[{"text": x} for x in dx["prompts"]], label=f"Q&A about {dx['name']}"),
            gr.update(value=dx["image"], label=f"Preview for {dx['name']}"),
            d.index,
        )

    document_gallery.select(lambda: [], outputs=[chatbot])
    document_gallery.select(sample_image_selected, inputs=[], outputs=[chatbot, image_display, selected_doc])

    def update_user_chat_x(x: gr.SelectData) -> list:
        """Update chat with user selection."""
        return [gr.ChatMessage(role="user", content=x.value["text"])]

    def question_from_selection(x: gr.SelectData) -> str:
        """Extract question text from selection."""
        return x.value["text"]

    def handle_image_upload(uploaded_file: str | None) -> tuple:
        """Handle uploaded image and update the display."""
        if uploaded_file is None:
            # No file: leave the display and examples unchanged, clear the chat, and store no path
            return gr.update(), gr.update(), [], None
        # Update the image display with the uploaded image
        image_update = gr.update(value=uploaded_file, label="Uploaded Image")
        # Update chatbot to show it's ready for questions about the uploaded image
        chatbot_update = gr.update(
            examples=[{"text": "Convert this page to docling."}], label="Q&A about uploaded image"
        )
        # Clear the chat history
        chat_update = []
        return image_update, chatbot_update, chat_update, uploaded_file

    # Connect upload button to handler
    upload_button.upload(
        handle_image_upload, inputs=[upload_button], outputs=[image_display, chatbot, chatbot, uploaded_image_path]
    )

    def send_generate(msg: str, cb: list, selected_sample: int, uploaded_img_path: str | None = None):
        """Generate a response using the model and stream it into the chat."""
        # Use uploaded image if available, otherwise use selected sample
        image_path = uploaded_img_path if uploaded_img_path is not None else sample_data[selected_sample]["image"]
        original_msg = gr.ChatMessage(role="user", content=msg)
        cb.append(original_msg)
        processing_msg = gr.ChatMessage(
            role="assistant",
            content='<span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> '
            '<span class="dot-3">.</span></span>',
        )
        cb.append(processing_msg)
        yield cb, gr.update()
        # Apply padding only for sample images, not uploaded images
        apply_padding = False if uploaded_img_path is not None else sample_data[selected_sample].get("pad", False)
        partial_answer = ""
        try:
            stream_gen = generate_with_model_streaming(msg.strip(), image_path, apply_padding)
            for partial_answer in stream_gen:
                cb[-1] = gr.ChatMessage(role="assistant", content=partial_answer)
                yield cb, gr.update()
        except Exception:
            # Fall back to non-streaming generation if streaming fails
            answer = generate_with_model(msg.strip(), image_path, apply_padding)
            cb[-1] = gr.ChatMessage(role="assistant", content=answer)
            yield cb, gr.update()
            partial_answer = answer
        global _streaming_raw_output
        answer = _streaming_raw_output if _streaming_raw_output else partial_answer
        answer = html.unescape(answer)
        answer = clean_model_response(answer)
        # Detect location tags so the preview image can be annotated
        class_loc_pattern = r"<([^>]+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        class_loc_matches = re.findall(class_loc_pattern, answer)
        loc_only_pattern = r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>"
        loc_only_matches = re.findall(loc_only_pattern, answer)
        has_doctag = "<doctag>" in answer
        has_loc_tags = class_loc_matches or loc_only_matches
        # Show raw DocTags/OTSL output in a code block, plain text otherwise
        xml_tags = ["<doctag>", "<otsl>", "<chart>", "<code>", "<loc_"]
        if any(tag in answer for tag in xml_tags):
            cb[-1] = gr.ChatMessage(role="assistant", content=f"```xml\n{answer}\n```")
        else:
            cb[-1] = gr.ChatMessage(role="assistant", content=answer)
        if "convert this page to docling" in msg.lower() or ("convert" in msg.lower() and "otsl" in msg.lower()):
            try:
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([answer], [Image.open(image_path)])
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                response = gr.ChatMessage(
                    role="assistant",
                    content=f"\nConverted to Markdown using docling.\n\n**MD Output:**\n\n{markdown_output}",
                )
                cb.append(response)
            except Exception as e:
                error_response = gr.ChatMessage(role="assistant", content=f"Error creating markdown output: {e!s}")
                cb.append(error_response)
        elif "convert formula to latex" in msg.lower():
            try:
                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([answer], [Image.open(image_path)])
                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
                markdown_output = doc.export_to_markdown()
                # Wrap the first display formula in an aligned environment for nicer rendering
                if markdown_output.count("$$") >= 2:
                    parts = markdown_output.split("$$", 2)
                    formula = parts[1].strip()
                    wrapped = f"$$\n\\begin{{aligned}}\n{formula}\n\\end{{aligned}}\n$$"
                    markdown_output = parts[0] + wrapped + parts[2]
                md_response = gr.ChatMessage(
                    role="assistant",
                    content=f"\nConverted to Markdown using docling.\n\n**LaTeX Output:**\n\n{markdown_output}",
                )
                cb.append(md_response)
            except Exception as e:
                error_response = gr.ChatMessage(role="assistant", content=f"Error creating LaTeX output: {e!s}")
                cb.append(error_response)
        if has_loc_tags:
            try:
                annotated_image = draw_bounding_boxes(image_path, answer, is_doctag_response=has_doctag)
                annotated_array = np.array(annotated_image)
                yield cb, gr.update(value=annotated_array, visible=True)
            except Exception:
                yield cb, gr.update(value=image_path)
        else:
            yield cb, gr.update(value=image_path)

    chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
    chatbot.example_select(question_from_selection, inputs=[], outputs=[current_question]).then(
        send_generate,
        inputs=[current_question, chatbot, selected_doc, uploaded_image_path],
        outputs=[chatbot, image_display],
    )

    def textbox_switch(e_mode: bool) -> list:
        """Switch textbox visibility based on edit mode."""
        if not e_mode:
            return [gr.update(visible=False), gr.update(visible=True)]
        return [gr.update(visible=True), gr.update(visible=False)]

    tbb.submit(lambda: False, outputs=[is_in_edit_mode])
    fb.click(lambda: True, outputs=[is_in_edit_mode])
    is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])

    tbb.submit(lambda x: x, inputs=[tbb], outputs=[current_question]).then(
        send_generate,
        inputs=[current_question, chatbot, selected_doc, uploaded_image_path],
        outputs=[chatbot, image_display],
    )


if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()
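
# Local run note (assumes the assets referenced above — data/images, themes/, app.css, app_head.html —
# are present): `python app.py` starts the Gradio server with a request queue of at most 20 pending
# jobs. Setting the NO_LLM environment variable only short-circuits generate_with_model (the
# non-streaming fallback), so the streaming path still loads and runs the model.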