gauri-sharan committed on
Commit
2f3144c
·
verified ·
1 Parent(s): 5af99a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from byaldi import RAGMultiModalModel
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
+ from qwen_vl_utils import process_vision_info
5
+ import torch
6
+ from PIL import Image
7
+ import os
8
+ import traceback
9
+ import spaces # Ensure import for GPU management
10
+
11
# Instantiate the Byaldi retriever, the Qwen2-VL processor and the Qwen2-VL
# model once at import time. Nothing is moved to CUDA here.
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Processor for Qwen2-VL (chat template + image/text preprocessing).
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    trust_remote_code=True,
)

# Qwen2-VL weights are loaded in bfloat16.
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
19
+
20
@spaces.GPU  # Decorate the function for GPU management
def ocr_and_extract(image, text_query):
    """Index an uploaded image with Byaldi and answer a query with Qwen2-VL.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when nothing was uploaded.
        text_query: Prompt sent to the model. When blank, a default
            "extract the text" instruction is used instead.

    Returns:
        The text generated by Qwen2-VL for the image, or a string starting
        with ``"Error: "`` describing what went wrong.
    """
    import tempfile  # local import: only this function needs it

    if image is None:
        return "Error: please upload an image."

    # The UI labels the query as optional, so fall back to a plain OCR prompt.
    query = (text_query or "").strip() or "Extract all the text from this image."

    temp_image_path = None
    try:
        # Use a unique temp file per call so concurrent requests cannot
        # overwrite (or delete) each other's image.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)

        # JPEG cannot store alpha/palette modes, so normalise to RGB first.
        image_data = image.convert("RGB")
        image_data.save(temp_image_path, format="JPEG")

        # Index the image with Byaldi
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",
            store_collection_with_index=False,
            overwrite=True,
        )

        # Perform the search query on the indexed image. The result is not
        # consumed by the generation step below, so it is not bound.
        rag_model.search(query, k=1)

        # Prepare the input for Qwen2-VL. The in-memory image is reused
        # instead of reopening the temp file, so no file handle stays open.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": query},
                ],
            }
        ]

        # Process the message and prepare for Qwen2-VL
        text_input = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, _ = process_vision_info(messages)

        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move the Qwen2-VL model and inputs to GPU
        qwen_model.to("cuda")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate the output with Qwen2-VL
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)

        # generate() returns prompt + completion; decode only the new tokens
        # so the chat prompt is not echoed back to the user.
        new_token_ids = generated_ids[:, prompt_length:]
        output_text = processor.batch_decode(
            new_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text[0]

    except Exception as e:
        # UI boundary: log the traceback and surface a readable message.
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"

    finally:
        # Clean up the temporary file on success and on failure alike.
        if temp_image_path is not None:
            try:
                os.remove(temp_image_path)
            except OSError:
                pass  # best-effort cleanup; the file may already be gone
81
+
82
# Input widgets: the uploaded image (delivered as a PIL image) and a query box.
image_input = gr.Image(type="pil")
query_input = gr.Textbox(label="Enter your query (optional)")

# Wire the widgets to ocr_and_extract and show its return value as text.
iface = gr.Interface(
    fn=ocr_and_extract,
    inputs=[image_input, query_input],
    outputs="text",
    title="Image OCR with Byaldi + Qwen2-VL",
    description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR.",
)

# Start the Gradio server.
iface.launch()