import gradio as gr
from PIL import Image
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

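# Load the ColPali retriever (byaldi) and the Qwen2-VL model and processor once at startup.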
def load_models():
    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return RAG, model, processor


RAG, model, processor = load_models()


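# Run Qwen2-VL OCR on the uploaded image, split the text into English/Hindi, and look for the keyword.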
def ocr_and_search(image, keyword):
    text_query = "Extract all the text in Hindi and English from the image."

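    # Build a chat-style message that pairs the uploaded image with the OCR prompt.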
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text_query},
            ],
        }
    ]

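    # Apply the chat template and prepare tensor inputs for the model (CPU in this demo).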
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cpu")

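    # Generate the OCR output and strip the prompt tokens from each returned sequence.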
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2000)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

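    # Decode the generated ids and drop the boilerplate prefix from the response.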
    extracted_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    extracted_text = extracted_text.replace("The text in the image is:", "").strip()

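    # Remove any word that contains a digit.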
    extracted_text = ' '.join(filter(lambda x: not any(char.isdigit() for char in x), extracted_text.split()))

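    # Split words into English (ASCII letters and apostrophes only) and Hindi (any non-ASCII character).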
    english_text = ' '.join(filter(lambda x: all((char.islower() or char.isupper()) or char == "'" for char in x), extracted_text.split()))
    hindi_text = ' '.join(filter(lambda x: any(ord(char) >= 128 for char in x), extracted_text.split()))

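    # Keyword search: report the keyword if it appears anywhere in the extracted text.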
    keyword_lower = keyword.lower().strip()
    matched_keywords = []
    if keyword_lower and keyword_lower in extracted_text.lower():
        matched_keywords = [keyword]

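    # Format the per-language results for display.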
    plain_text_output = (
        f"- English: {english_text if english_text else 'No English text found.'}\n\n"
        f"- Hindi: {hindi_text if hindi_text else 'No Hindi text found.'}"
    )

    return extracted_text, matched_keywords, plain_text_output


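# Gradio callback: run OCR plus keyword search and return the three text outputs.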
def app(image, keyword):
    extracted_text, matched_keywords, plain_text_output = ocr_and_search(image, keyword)
    search_results_str = "\n".join(matched_keywords) if matched_keywords else "No matches found for the keyword."
    return extracted_text, search_results_str, plain_text_output


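# Wire the pieces together in a simple Gradio interface.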
iface = gr.Interface(
    fn=app,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Enter keyword to search in extracted text", placeholder="Keyword"),
    ],
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.Textbox(label="Search Results"),
        gr.Textbox(label="Plain Text Output", lines=10),
    ],
    title="Optical Character Recognition with Keyword Search from Images",
)

iface.launch()