jawakja committed (verified)
Commit 03f20a2 · 1 Parent(s): fbf5d04

Update app.py

Files changed (1): app.py (+84, -137)
app.py CHANGED
@@ -14,21 +14,20 @@ import faiss
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Check available resources
+# Check CUDA
 logger.info(f"CUDA available: {torch.cuda.is_available()}")
 if torch.cuda.is_available():
     logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
-    logger.info(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9} GB")
 
-# Configure quantization for lower memory usage
+# BitsAndBytes config for quantized model loading
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
     bnb_4bit_compute_dtype=torch.float16,
 )
 
+# Load Qwen model
 try:
-    # Load Qwen-2.5-Omni-3B with memory optimizations
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Omni-3B", trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(
         "Qwen/Qwen2.5-Omni-3B",
@@ -36,211 +35,160 @@ try:
         quantization_config=bnb_config,
         trust_remote_code=True
     ).eval()
-    logger.info("Model loaded successfully")
+    logger.info("Qwen model loaded.")
 except Exception as e:
-    logger.error(f"Error loading model: {e}")
-    model = None
-    tokenizer = None
+    logger.error(f"Failed to load Qwen: {e}")
+    model, tokenizer = None, None
 
-# Use a smaller embedding model
+# Load SentenceTransformer for RAG
 try:
     embed_model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
-    logger.info("Embedding model loaded successfully")
+    logger.info("Embedding model loaded.")
 except Exception as e:
-    logger.error(f"Error loading embedding model: {e}")
+    logger.error(f"Failed to load embedding model: {e}")
     embed_model = None
 
-# Global state for FAISS
+# Global index state
 chunks = []
 index = None
 
-# PDF processing
+# PDF text chunking
 def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
     try:
         doc = fitz.open(pdf_path)
-        text = ""
-        for page in doc:
-            text += page.get_text()
+        text = "".join([page.get_text() for page in doc])
         return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
     except Exception as e:
-        logger.error(f"PDF extraction error: {e}")
-        return ["Error extracting PDF content"]
+        logger.error(f"PDF error: {e}")
+        return ["Error extracting content."]
 
+# Build FAISS index
 def build_faiss_index(chunks):
     try:
-        if not embed_model:
-            return None
         embeddings = embed_model.encode(chunks, convert_to_numpy=True)
-        dim = embeddings.shape[1]
-        idx = faiss.IndexFlatL2(dim)
-        idx.add(embeddings)
-        return idx
+        index = faiss.IndexFlatL2(embeddings.shape[1])
+        index.add(embeddings)
+        return index
     except Exception as e:
         logger.error(f"FAISS index error: {e}")
         return None
 
+# RAG retrieval
 def rag_query(query, chunks, index, top_k=3):
-    if not index or not embed_model:
-        return "Embedding model not available"
     try:
         q_emb = embed_model.encode([query], convert_to_numpy=True)
         D, I = index.search(q_emb, top_k)
         return "\n\n".join([chunks[i] for i in I[0]])
     except Exception as e:
         logger.error(f"RAG query error: {e}")
-        return "Error retrieving context"
+        return "Error retrieving context."
 
-# Vision/Text chat with Qwen-2.5-Omni
-def chat_with_qwen(text=None, image=None):
+# Qwen chat
+def chat_with_qwen(text, image=None):
     if not model or not tokenizer:
-        return "Model failed to load due to resource constraints. Try a smaller model or upgrade your space."
-
+        return "Model not loaded."
     try:
-        # For Qwen-2.5-Omni-3B
-        messages = []
-
+        messages = [{"role": "user", "content": text}]
         if image:
-            # Add the image as a message
-            messages.append({"role": "user", "content": [
-                {"image": image},
-                {"text": text if text else "Please describe this image."}
-            ]})
-        else:
-            # Text-only query
-            messages.append({"role": "user", "content": text})
-
-        # Generate response
-        response = model.chat(tokenizer, messages)
+            messages[0]["content"] = [{"image": image}, {"text": text}]
+        response, _ = model.chat(tokenizer, messages, history=None)
         return response
     except Exception as e:
         logger.error(f"Chat error: {e}")
-        return f"Error generating response: {str(e)}"
+        return f"Chat error: {e}"
 
-# Video frame extraction - more memory efficient
+# Extract representative frames
 def extract_video_frames(video_path, max_frames=2):
     try:
         cap = cv2.VideoCapture(video_path)
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
        frames = []
-
-        # Take fewer, evenly distributed frames
-        if total_frames > 0:
-            frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
-            for idx in frame_indices:
-                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
-                success, frame = cap.read()
-                if success:
-                    frames.append(frame)
+        for idx in frame_indices:
+            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+            success, frame = cap.read()
+            if success:
+                frames.append(frame)
         cap.release()
         return frames
     except Exception as e:
-        logger.error(f"Video frame extraction error: {e}")
+        logger.error(f"Frame extraction error: {e}")
         return []
 
-# Main chatbot logic with error handling
+# Multimodal chat logic
 def multimodal_chat(message, history, image=None, video=None, pdf=None):
     global chunks, index
-
+
     if not model:
-        return "Model not loaded due to memory constraints. Try upgrading your Hugging Face space."
+        return "Model not available."
 
     try:
-        # PDF-based RAG
-        if pdf:
-            chunks = extract_chunks_from_pdf(pdf.name)
+        # PDF + question
+        if pdf and message:
+            pdf_path = pdf.name if hasattr(pdf, 'name') else None
+            if not pdf_path:
+                return "Invalid PDF input."
+            chunks = extract_chunks_from_pdf(pdf_path)
             index = build_faiss_index(chunks)
             if index:
                 context = rag_query(message, chunks, index)
-                final_prompt = f"I'll provide some context, then ask a question. Context:\n{context}\n\nQuestion: {message}"
-                response = chat_with_qwen(final_prompt)
+                user_prompt = f"Context:\n{context}\n\nQuestion: {message}"
+                return chat_with_qwen(user_prompt)
             else:
-                response = "Could not process PDF due to resource constraints"
-            return response
+                return "Failed to process PDF."
 
-        # Image
-        if image:
-            response = chat_with_qwen(message, image)
-            return response
+        # Image + question
+        if image and message:
+            return chat_with_qwen(message, image)
 
-        # Video (extract frames and process one by one)
-        if video:
-            temp_dir = tempfile.mkdtemp()
-            try:
-                video_path = os.path.join(temp_dir, "vid.mp4")
-                shutil.copy(video, video_path)
+        # Video + question
+        if video and message:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                video_path = os.path.join(temp_dir, "video.mp4")
+                shutil.copy(video.name if hasattr(video, 'name') else video, video_path)
                 frames = extract_video_frames(video_path)
+                if not frames:
+                    return "Could not extract video frames."
 
-                # Only process if we got frames
-                if frames:
-                    # Save frames and process them
-                    frame_descriptions = []
-                    for i, frame in enumerate(frames):
-                        temp_img_path = os.path.join(temp_dir, f"frame_{i}.jpg")
-                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                        cv2.imwrite(temp_img_path, frame_rgb)
-
-                        # Get description for this frame
-                        frame_query = "Describe this video frame in detail."
-                        frame_description = chat_with_qwen(frame_query, temp_img_path)
-                        frame_descriptions.append(f"Frame {i+1}: {frame_description}")
-
-                    # Combine frame descriptions and answer the user's question
-                    combined_context = "\n\n".join(frame_descriptions)
-                    final_prompt = f"I analyzed some video frames and here's what I found:\n\n{combined_context}\n\nBased on these video frames, {message if message else 'please describe what\'s happening in this video.'}"
-                    response = chat_with_qwen(final_prompt)
-                    return response
-                else:
-                    return "Could not extract video frames"
-            finally:
-                # Cleanup temp files
-                shutil.rmtree(temp_dir, ignore_errors=True)
+                temp_img_path = os.path.join(temp_dir, "frame.jpg")
+                cv2.imwrite(temp_img_path, cv2.cvtColor(frames[0], cv2.COLOR_BGR2RGB))
+                return chat_with_qwen(message, temp_img_path)
 
         # Text only
         if message:
             return chat_with_qwen(message)
 
-        return "Please input a message, image, video, or PDF."
+        return "Please enter a question and optionally upload a file."
     except Exception as e:
-        logger.error(f"General error in multimodal_chat: {e}")
-        return f"Error processing your request: {str(e)}. This may be due to memory constraints."
+        logger.error(f"Chat error: {e}")
+        return f"Error: {e}"
 
-# ---- Gradio UI ---- #
+# Gradio UI
 with gr.Blocks(css="""
-body {
-    background-color: #f3f6fc;
-}
-.gradio-container {
-    font-family: 'Segoe UI', sans-serif;
-}
+body { background-color: #f3f6fc; }
+.gradio-container { font-family: 'Segoe UI', sans-serif; }
 h1 {
-    background: linear-gradient(to right, #667eea, #764ba2);
-    color: white !important;
-    padding: 1rem;
-    border-radius: 12px;
-    margin-bottom: 0.5rem;
-}
-p {
-    font-size: 1rem;
-    color: white;
+    background: linear-gradient(to right, #667eea, #764ba2);
+    color: white !important;
+    padding: 1rem; border-radius: 12px; margin-bottom: 0.5rem;
 }
 .gr-box {
-    background-color: white;
-    border-radius: 12px;
-    box-shadow: 0 0 10px rgba(0,0,0,0.05);
-    padding: 16px;
+    background-color: white; border-radius: 12px;
+    box-shadow: 0 0 10px rgba(0,0,0,0.05); padding: 16px;
 }
-footer {display: none !important;}
+footer { display: none !important; }
 """) as demo:
-    gr.Markdown(
-        "<h1 style='text-align: center;'>Multimodal Chatbot powered by Qwen-2.5-Omni-3B</h1>"
-        "<p style='text-align: center;'>Ask questions with text, images, videos, or PDFs in a smart and multimodal way.</p>"
-    )
+
+    gr.Markdown("""
+    <h1 style='text-align: center;'>Multimodal Chatbot powered by Qwen-2.5-Omni-3B</h1>
+    <p style='text-align: center;'>Ask your own questions with optional image, video, or PDF context.</p>
+    """)
 
     chatbot = gr.Chatbot(show_label=False, height=450)
     state = gr.State([])
 
     with gr.Row():
-        txt = gr.Textbox(show_label=False, placeholder="Type a message...", scale=5)
+        txt = gr.Textbox(show_label=False, placeholder="Type your question...", scale=5)
         send_btn = gr.Button("🚀 Send", scale=1)
 
     with gr.Row():
@@ -250,14 +198,13 @@ footer {display: none !important;}
 
     def user_send(message, history, image, video, pdf):
         if not message and not image and not video and not pdf:
-            return "", history
+            return "", history, history
         response = multimodal_chat(message, history, image, video, pdf)
         history.append((message, response))
-        return "", history
+        return "", history, history
 
-    send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
-    txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
+    send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
+    txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
 
-# Launch the app with memory logging
-logger.info("Starting Gradio app")
-demo.launch()
+logger.info("Launching Gradio app")
+demo.launch()
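
A few of the patterns in this commit can be tried in isolation. First, the 4-bit loading path: the sketch below mirrors the diff's BitsAndBytesConfig usage. The device_map="auto" argument is an assumption (that line of app.py falls outside the hunks shown), and a CUDA device plus the bitsandbytes and accelerate packages are assumed to be available.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Same 4-bit NF4 settings as the commit: weights stored in 4-bit,
# matmuls computed in fp16 after dequantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_id = "Qwen/Qwen2.5-Omni-3B"  # same checkpoint as the diff
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",  # assumption: not visible in the hunks above
    trust_remote_code=True,
).eval()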
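
The retrieval path that the new build_faiss_index and rag_query wire together (chunk, embed, exact L2 search) can also be exercised standalone. A minimal sketch, assuming the same paraphrase-MiniLM-L3-v2 embedder; the toy chunks stand in for extract_chunks_from_pdf output and are purely illustrative.

import faiss
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer("paraphrase-MiniLM-L3-v2")

# Toy chunks standing in for extract_chunks_from_pdf() output.
chunks = [
    "FAISS performs efficient similarity search over dense vectors.",
    "Qwen2.5-Omni-3B is a multimodal language model.",
    "Gradio turns Python functions into web demos.",
]

embeddings = embed_model.encode(chunks, convert_to_numpy=True)
index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 search, no training step
index.add(embeddings)

q_emb = embed_model.encode(["What is FAISS for?"], convert_to_numpy=True)
D, I = index.search(q_emb, 2)  # distances and row ids of the 2 nearest chunks
print("\n\n".join(chunks[i] for i in I[0]))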
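
One detail of the rewritten extract_video_frames: it drops the old total_frames > 0 guard. If OpenCV reports a frame count of 0, frame_indices becomes [0, 0], the reads simply fail, and an empty list comes back, so nothing crashes, but the failure is silent. A defensive variant that keeps the guard might look like this sketch:

import cv2

def sample_frames(video_path, max_frames=2):
    # Evenly spaced sampling, with the old total_frames > 0 guard kept.
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total > 0:
        for i in range(max_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i * total / max_frames))
            ok, frame = cap.read()
            if ok:
                frames.append(frame)
    cap.release()
    return frames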
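
Finally, the user_send fix: Gradio expects a handler to return exactly one value per component in its outputs list, which is why adding state to the outputs means returning three values. A minimal self-contained illustration with hypothetical component names, using tuple-style chat history as the diff does:

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    txt = gr.Textbox(placeholder="Type a message...")

    def respond(message, history):
        history.append((message, f"Echo: {message}"))
        # One return value per output component:
        # clear the textbox, render the history, persist it in state.
        return "", history, history

    txt.submit(respond, inputs=[txt, state], outputs=[txt, chatbot, state])

demo.launch()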