File size: 5,212 Bytes
dcc824b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import json
import os
import re
import time

import streamlit as st
import torch
from byaldi import RAGMultiModalModel
from pdf2image import convert_from_path
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
# Pick the compute device: prefer the GPU whenever CUDA is usable.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Heavy model loading happens exactly once; st.cache_resource reuses the
# objects across Streamlit reruns instead of reloading on every interaction.
@st.cache_resource
def initialize_models():
    """Load and cache the retriever, the Qwen2-VL model, and its processor.

    Returns:
        tuple: (RAG multimodal retriever, Qwen2-VL generation model moved to
        `device` and put in eval mode, matching AutoProcessor).
    """
    retriever = RAGMultiModalModel.from_pretrained("vidore/colpali")
    processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True
    )
    vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    vl_model = vl_model.to(device).eval()
    return retriever, vl_model, processor
# One cached initialization shared by every rerun of the script.
multimodal_model, qwen_model, qwen_processor = initialize_models()

# --- UI: upload section ---
st.title("Document Text Extraction")
doc_file = st.file_uploader("Upload Image File", type=["png", "jpg", "jpeg"])

# Keep the extracted text alive across Streamlit reruns.
st.session_state.setdefault("document_text", None)
if doc_file is not None:
    # Bug fix: re-run extraction whenever a *different* file is uploaded.
    # Previously `document_text` was never reset, so the text cached from an
    # earlier document was shown for every subsequent upload.
    if st.session_state.get("document_source") != doc_file.name:
        st.session_state.document_source = doc_file.name
        st.session_state.document_text = None

    # The uploader only accepts png/jpg/jpeg, so PIL can open the file directly.
    document_image = Image.open(doc_file)

    # Display the uploaded document image.
    # NOTE(review): `use_column_width` is deprecated in newer Streamlit in
    # favour of `use_container_width`; switch once the minimum version allows.
    st.image(document_image, caption="Document Preview", use_column_width=True)

    # Timestamp-based unique index name so repeated runs never collide.
    index_id = f"doc_index_{int(time.time())}"

    # Only run the (expensive) extraction once per uploaded document.
    if st.session_state.document_text is None:
        st.write(f"Indexing document with unique ID: {index_id}...")

        # byaldi indexes from a file path, so persist the image temporarily
        # and always clean it up, even if indexing raises.
        temp_image_path = "temp_image.png"
        document_image.save(temp_image_path)
        try:
            multimodal_model.index(
                input_path=temp_image_path,
                index_name=index_id,
                store_collection_with_index=False,
                overwrite=False,
            )
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

        # Query sent both to the retriever and to the vision-language model.
        extraction_query = "Extract all English and Hindi text from this document"

        st.write("Querying the document with the extraction query...")
        # NOTE(review): these retrieval results are never consumed below — the
        # image is handed to Qwen2-VL directly.  Kept for parity with the
        # original pipeline; consider removing if it stays unused.
        search_results = multimodal_model.search(extraction_query, k=1)

        # Build the chat-style multimodal prompt for Qwen2-VL.
        input_message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": document_image},
                    {"type": "text", "text": extraction_query},
                ],
            }
        ]

        input_text = qwen_processor.apply_chat_template(
            input_message, tokenize=False, add_generation_prompt=True
        )
        vision_inputs, _ = process_vision_info(input_message)
        model_inputs = qwen_processor(
            text=[input_text],
            images=vision_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to(device)

        st.write("Generating extracted text...")
        # NOTE(review): 100 new tokens may truncate dense documents — raise if
        # full-page extraction is required.
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=100)

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]
        extracted_output = qwen_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # Cache the result so reruns (e.g. the search form below) skip extraction.
        st.session_state.document_text = extracted_output[0]

    # Present the extracted text as JSON.
    structured_data = {"extracted_text": st.session_state.document_text}
    st.subheader("Extracted Text in JSON:")
    st.json(structured_data)
# Keyword search within the extracted text (form prevents reruns per keystroke).
if st.session_state.document_text:
    with st.form(key='text_search_form'):
        search_input = st.text_input("Enter a keyword to search within the extracted text:")
        search_action = st.form_submit_button("Search")

    if search_action and search_input:
        # Compile the escaped, case-insensitive pattern ONCE — the original
        # rebuilt it twice per line (re.search + re.sub).  The capture group
        # lets the substitution wrap each match in ** ** for bold markdown.
        keyword_pattern = re.compile(f"({re.escape(search_input)})", re.IGNORECASE)

        # Collect every line that contains the keyword, with matches bolded.
        results = []
        for line in st.session_state.document_text.split('\n'):
            if keyword_pattern.search(line):
                results.append(keyword_pattern.sub(r"**\1**", line))

        st.subheader("Search Results:")
        if not results:
            st.write("No matches found.")
        else:
            for result in results:
                st.markdown(result)
|