jawakja committed on
Commit 48fbb50 · verified · 1 Parent(s): 4ec7065

Create app.py

Files changed (1)
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
import gradio as gr
import fitz  # PyMuPDF
import torch
import cv2
import os
import tempfile
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss

# Load Qwen-VL-Chat
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-VL-Chat",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).eval()

# Embedding model
embed_model = SentenceTransformer('all-MiniLM-L6-v2')

# Global state for FAISS
chunks = []
index = None

# PDF processing
def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    # Step by (chunk_size - overlap) so consecutive chunks share `overlap` characters
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]

# Build a FAISS L2 index over the chunk embeddings
def build_faiss_index(chunks):
    embeddings = embed_model.encode(chunks, convert_to_numpy=True)
    dim = embeddings.shape[1]
    idx = faiss.IndexFlatL2(dim)
    idx.add(embeddings)
    return idx

# Retrieve the top_k most similar chunks for a query
def rag_query(query, chunks, index, top_k=3):
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    D, I = index.search(q_emb, top_k)
    return "\n\n".join([chunks[i] for i in I[0]])

# Vision/Text chat
def chat_with_qwen(text=None, image=None):
    elements = []
    if image:
        elements.append({"image": image})
    if text:
        elements.append({"text": text})
    if not elements:
        return "Please upload or type something."
    query = tokenizer.from_list_format(elements)
    response, _ = model.chat(tokenizer, query, history=None)
    return response

# Video frame extraction
def extract_video_frames(video_path, max_frames=3):
    cap = cv2.VideoCapture(video_path)
    frames, count = [], 0
    while len(frames) < max_frames:
        success, frame = cap.read()
        if not success:
            break
        frames.append(frame)
        count += 1
        # Jump ahead ~30 frames so the samples are spread across the video
        cap.set(cv2.CAP_PROP_POS_FRAMES, count * 30)
    cap.release()
    return frames

# Main chatbot logic
def multimodal_chat(message, history, image=None, video=None, pdf=None):
    global chunks, index

    # PDF-based RAG
    if pdf:
        chunks = extract_chunks_from_pdf(pdf.name)
        index = build_faiss_index(chunks)
        context = rag_query(message, chunks, index)
        final_prompt = f"Context:\n{context}\n\nQuestion: {message}"
        response = chat_with_qwen(final_prompt)
        return response

    # Image
    if image:
        response = chat_with_qwen(message, image)
        return response

    # Video (extract frames and send all in one call)
    if video:
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "vid.mp4")
        shutil.copy(video, video_path)
        frames = extract_video_frames(video_path)

        # Save and collect image paths (OpenCV frames are BGR, which is what cv2.imwrite expects)
        images = []
        for i, frame in enumerate(frames):
            temp_img_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            cv2.imwrite(temp_img_path, frame)
            images.append(temp_img_path)

        # Combine all frames and text into one query
        elements = [{"image": img} for img in images]
        if message:
            elements.append({"text": message})

        query = tokenizer.from_list_format(elements)
        response, _ = model.chat(tokenizer, query, history=None)
        return response

    # Text only
    if message:
        return chat_with_qwen(message)

    return "Please input a message, image, video, or PDF."

# ---- Gradio UI ---- #
with gr.Blocks(css="""
body {
    background-color: #f3f6fc;
}
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
}
h1 {
    background: linear-gradient(to right, #667eea, #764ba2);
    color: white !important;
    padding: 1rem;
    border-radius: 12px;
    margin-bottom: 0.5rem;
}
p {
    font-size: 1rem;
    color: white;
}
.gr-box {
    background-color: white;
    border-radius: 12px;
    box-shadow: 0 0 10px rgba(0,0,0,0.05);
    padding: 16px;
}
footer {display: none !important;}
""") as demo:
    gr.Markdown(
        "<h1 style='text-align: center;'>Multimodal Chatbot powered by LLAVACMVRL and QWEN-VL</h1>"
        "<p style='text-align: center;'>Ask questions with text, images, videos, or PDFs in a smart and multimodal way.</p>"
    )

    chatbot = gr.Chatbot(show_label=False, height=450)
    state = gr.State([])

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type a message...", scale=5)
        send_btn = gr.Button("🚀 Send", scale=1)

    with gr.Row():
        image_input = gr.Image(type="filepath", label="Upload Image")
        video_input = gr.Video(label="Upload Video")
        pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")

    def user_send(message, history, image, video, pdf):
        response = multimodal_chat(message, history, image, video, pdf)
        history.append((message, response))
        return "", history

    send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
    txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])

# Launch the app
demo.launch()