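"""Streamlit demo: index an uploaded document image with ColPali (via byaldi),
extract its English/Hindi text with Qwen2-VL, and search within the result.

Run with: streamlit run <path-to-this-file>
"""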
import streamlit as st
from PIL import Image
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import time
import re

# Check device availability (GPU/CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Function to load models only once
@st.cache_resource
def initialize_models():
    # Load models for text extraction
    multimodal_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
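    # bfloat16 halves memory versus float32; fall back to float16/float32 if
    # your hardware lacks bf16 support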
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

    return multimodal_model, qwen_model, qwen_processor

multimodal_model, qwen_model, qwen_processor = initialize_models()

# Upload section
st.title("Document Text Extraction")
doc_file = st.file_uploader("Upload Image File", type=["png", "jpg", "jpeg"])

# Store extracted text across reruns
if "document_text" not in st.session_state:
    st.session_state.document_text = None

if doc_file is not None:
    # Reset the cached extraction whenever a different file is uploaded
    if st.session_state.get("last_file_name") != doc_file.name:
        st.session_state.document_text = None
        st.session_state.last_file_name = doc_file.name

    document_image = Image.open(doc_file)  # The uploader only accepts image files
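    # (Optional) To also accept PDFs, add "pdf" to the uploader's type list and
    # convert the first page to an image, e.g. with pdf2image:
    #   from pdf2image import convert_from_bytes
    #   document_image = convert_from_bytes(doc_file.read())[0]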

    # Display uploaded document image
    st.image(document_image, caption="Document Preview", use_column_width=True)

    # Create a unique index name for the document
    index_id = f"doc_index_{int(time.time())}"  # Timestamp-based unique index

    # Only process if text hasn't been extracted yet
    if st.session_state.document_text is None:
        st.write(f"Indexing document with unique ID: {index_id}...")
        temp_image_path = "temp_image.png"
        document_image.save(temp_image_path)
        
        # Index the image using multimodal model
        multimodal_model.index(
            input_path=temp_image_path,
            index_name=index_id,
            store_collection_with_index=False,
            overwrite=False
        )

        # Define the extraction query
        extraction_query = "Extract all English and Hindi text from this document"
        st.write("Querying the document with the extraction query...")

        # Retrieve the indexed page most relevant to the query (the retrieval
        # result itself is not used below; the full image goes to Qwen2-VL)
        search_results = multimodal_model.search(extraction_query, k=1)
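        # Each byaldi result exposes doc_id / page_num / score; surfacing the
        # score is one way to show retrieval confidence, e.g.:
        #   st.write(f"Top match score: {search_results[0].score}")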

        # Prepare input data for Qwen model
        input_message = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": document_image},
                    {"type": "text", "text": extraction_query},
                ],
            }
        ]

        # Prepare inputs for Qwen2-VL
        input_text = qwen_processor.apply_chat_template(input_message, tokenize=False, add_generation_prompt=True)
        vision_inputs, _ = process_vision_info(input_message)

        model_inputs = qwen_processor(
            text=[input_text],
            images=vision_inputs,
            padding=True,
            return_tensors="pt",
        )

        model_inputs = model_inputs.to(device)

        # Generate text output from the image using Qwen2-VL model
        st.write("Generating extracted text...")
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)  # allow room for a full page of text

        # Keep only the newly generated tokens, stripping the echoed prompt ids
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]

        extracted_output = qwen_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Store the extracted text in session state
        st.session_state.document_text = extracted_output[0]

    # Display extracted text in JSON format
    extracted_text_content = st.session_state.document_text
    structured_data = {"extracted_text": extracted_text_content}

    st.subheader("Extracted Text in JSON:")
    st.json(structured_data)

# Keyword search within the extracted text
if st.session_state.document_text:
    with st.form(key='text_search_form'):
        search_input = st.text_input("Enter a keyword to search within the extracted text:")
        search_action = st.form_submit_button("Search")

    if search_action and search_input:
        # Split the extracted text into lines for searching
        full_text = st.session_state.document_text
        lines = full_text.split('\n')

        results = []
        # Search for keyword in each line and collect lines that contain the keyword
        for line in lines:
            if re.search(re.escape(search_input), line, re.IGNORECASE):
                # Wrap each match in ** so st.markdown renders it in bold
                highlighted_line = re.sub(f"({re.escape(search_input)})", r"**\1**", line, flags=re.IGNORECASE)
                results.append(highlighted_line)

        # Display search results
        st.subheader("Search Results:")
        if not results:
            st.write("No matches found.")
        else:
            for result in results:
                st.markdown(result)