anvi27 committed on
Commit dcc824b · verified
1 Parent(s): 75f3224
Files changed (1)
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
import streamlit as st
from PIL import Image
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import time
import re
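
# Likely runtime dependencies (not pinned in this commit): streamlit, pillow,
# byaldi, transformers, qwen-vl-utils, torch. Launch with:
#   streamlit run app.py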

# Select the compute device (GPU if available, else CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the models only once per session
@st.cache_resource
def initialize_models():
    # Retrieval model for indexing/search, vision-language model for extraction
    multimodal_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device).eval()

    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

    return multimodal_model, qwen_model, qwen_processor

multimodal_model, qwen_model, qwen_processor = initialize_models()
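
# Note: Qwen2-VL-7B in bfloat16 needs roughly 15 GB of memory for the weights
# alone, so a GPU is effectively required; on CPU the app will load, but
# generation will be extremely slow.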

# Upload section
st.title("Document Text Extraction")
doc_file = st.file_uploader("Upload Image File", type=["png", "jpg", "jpeg"])

# Store extracted text across reruns
if "document_text" not in st.session_state:
    st.session_state.document_text = None

if doc_file is not None:
    document_image = Image.open(doc_file)  # Open the uploaded image directly

    # Display the uploaded document image
    st.image(document_image, caption="Document Preview", use_column_width=True)

    # Create a unique index name for the document
    index_id = f"doc_index_{int(time.time())}"  # Timestamp-based unique index

    # Only process if text hasn't been extracted yet
    if st.session_state.document_text is None:
        st.write(f"Indexing document with unique ID: {index_id}...")
        temp_image_path = "temp_image.png"
        document_image.save(temp_image_path)

        # Index the image with the multimodal retrieval model
        multimodal_model.index(
            input_path=temp_image_path,
            index_name=index_id,
            store_collection_with_index=False,
            overwrite=False,
        )

        # Define the extraction query
        extraction_query = "Extract all English and Hindi text from this document"
        st.write("Querying the document with the extraction query...")

        # Retrieve the most relevant indexed page for the query
        search_results = multimodal_model.search(extraction_query, k=1)
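
        # The retrieval result is not consumed below: the uploaded image is
        # passed straight to Qwen2-VL. With a multi-page index, search_results
        # would select which page image to send to the model.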

        # Prepare input data for the Qwen model
        input_message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": document_image},
                    {"type": "text", "text": extraction_query},
                ],
            }
        ]

        # Prepare inputs for Qwen2-VL (process_vision_info also returns video
        # inputs, which are unused here)
        input_text = qwen_processor.apply_chat_template(input_message, tokenize=False, add_generation_prompt=True)
        vision_inputs, _ = process_vision_info(input_message)

        model_inputs = qwen_processor(
            text=[input_text],
            images=vision_inputs,
            padding=True,
            return_tensors="pt",
        )

        model_inputs = model_inputs.to(device)

        # Generate text output from the image using the Qwen2-VL model
        st.write("Generating extracted text...")
        # Allow enough new tokens to cover a full page of extracted text
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)

        # Trim the prompt tokens so only newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]

        extracted_output = qwen_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Store the extracted text in session state so reruns skip re-extraction
        st.session_state.document_text = extracted_output[0]
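
    # Qwen2-VL returns plain text, so the "JSON" below is just a dictionary
    # wrapping that single string, not structured output from the model.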
    # Display the extracted text in JSON format
    extracted_text_content = st.session_state.document_text
    structured_data = {"extracted_text": extracted_text_content}

    st.subheader("Extracted Text in JSON:")
    st.json(structured_data)

# Search within the extracted text
if st.session_state.document_text:
    with st.form(key='text_search_form'):
        search_input = st.text_input("Enter a keyword to search within the extracted text:")
        search_action = st.form_submit_button("Search")

    if search_action and search_input:
        # Split the extracted text into lines for searching
        full_text = st.session_state.document_text
        lines = full_text.split('\n')

        results = []
        # Collect the lines that contain the keyword (case-insensitive)
        for line in lines:
            if re.search(re.escape(search_input), line, re.IGNORECASE):
                # Highlight the keyword in bold for st.markdown
                highlighted_line = re.sub(f"({re.escape(search_input)})", r"**\1**", line, flags=re.IGNORECASE)
                results.append(highlighted_line)

        # Display search results
        st.subheader("Search Results:")
        if not results:
            st.write("No matches found.")
        else:
            for result in results:
                st.markdown(result)
+ st.markdown(result)