import gradio as gr
import fitz # PyMuPDF
import torch
import cv2
import os
import tempfile
import shutil
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Check CUDA
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
# BitsAndBytes config for quantized model loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
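# NF4 4-bit weights keep the 3B model small enough for a single modest GPU, with
# float16 used for the dequantized compute (assumption: a CUDA GPU with bitsandbytes
# support is available at runtime; on CPU-only hardware this config will fail to load).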
# Load Qwen model
try:
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Omni-3B", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-Omni-3B",
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True
    ).eval()
    logger.info("Qwen model loaded.")
except Exception as e:
    logger.error(f"Failed to load Qwen: {e}")
    model, tokenizer = None, None
# Load SentenceTransformer for RAG
try:
    embed_model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
    logger.info("Embedding model loaded.")
except Exception as e:
    logger.error(f"Failed to load embedding model: {e}")
    embed_model = None
# Global index state
chunks = []
index = None
# PDF text chunking
def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
    try:
        doc = fitz.open(pdf_path)
        text = "".join([page.get_text() for page in doc])
        doc.close()
        # Sliding-window chunks over the raw text, stepping by chunk_size - overlap.
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
    except Exception as e:
        logger.error(f"PDF error: {e}")
        return ["Error extracting content."]
# Build FAISS index
def build_faiss_index(chunks):
    try:
        embeddings = embed_model.encode(chunks, convert_to_numpy=True)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index
    except Exception as e:
        logger.error(f"FAISS index error: {e}")
        return None
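# IndexFlatL2 is an exact (brute-force) L2 index, which is plenty for one PDF's worth
# of chunks. paraphrase-MiniLM-L3-v2 typically produces 384-dimensional vectors, and the
# index dimension is taken directly from embeddings.shape[1].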
# RAG retrieval
def rag_query(query, chunks, index, top_k=3):
    try:
        q_emb = embed_model.encode([query], convert_to_numpy=True)
        D, I = index.search(q_emb, top_k)
        return "\n\n".join([chunks[i] for i in I[0]])
    except Exception as e:
        logger.error(f"RAG query error: {e}")
        return "Error retrieving context."
# Qwen chat
def chat_with_qwen(text, image=None):
    if not model or not tokenizer:
        return "Model not loaded."
    try:
        messages = [{"role": "user", "content": text}]
        if image:
            messages[0]["content"] = [{"image": image}, {"text": text}]
        # Uses the chat() helper exposed via trust_remote_code; the expected message
        # format can differ between Qwen releases.
        response, _ = model.chat(tokenizer, messages, history=None)
        return response
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return f"Chat error: {e}"
# Extract representative frames
def extract_video_frames(video_path, max_frames=2):
    try:
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            success, frame = cap.read()
            if success:
                frames.append(frame)
        cap.release()
        return frames
    except Exception as e:
        logger.error(f"Frame extraction error: {e}")
        return []
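# Frames are sampled at evenly spaced positions: with max_frames=2 that is frame 0 and
# the frame at the halfway point of the clip. OpenCV returns frames in BGR order.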
# Multimodal chat logic
def multimodal_chat(message, history, image=None, video=None, pdf=None):
    global chunks, index
    if not model:
        return "Model not available."
    try:
        # PDF + question: build a fresh FAISS index and answer over retrieved context
        if pdf and message:
            pdf_path = pdf.name if hasattr(pdf, 'name') else pdf
            if not pdf_path:
                return "Invalid PDF input."
            chunks = extract_chunks_from_pdf(pdf_path)
            index = build_faiss_index(chunks)
            if index:
                context = rag_query(message, chunks, index)
                user_prompt = f"Context:\n{context}\n\nQuestion: {message}"
                return chat_with_qwen(user_prompt)
            else:
                return "Failed to process PDF."
        # Image + question
        if image and message:
            return chat_with_qwen(message, image)
        # Video + question: answer from a representative frame
        if video and message:
            with tempfile.TemporaryDirectory() as temp_dir:
                video_path = os.path.join(temp_dir, "video.mp4")
                shutil.copy(video.name if hasattr(video, 'name') else video, video_path)
                frames = extract_video_frames(video_path)
                if not frames:
                    return "Could not extract video frames."
                temp_img_path = os.path.join(temp_dir, "frame.jpg")
                # The frame is already BGR, which is what cv2.imwrite expects,
                # so no colour conversion is needed before writing.
                cv2.imwrite(temp_img_path, frames[0])
                return chat_with_qwen(message, temp_img_path)
        # Text only
        if message:
            return chat_with_qwen(message)
        return "Please enter a question and optionally upload a file."
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return f"Error: {e}"
# Gradio UI
with gr.Blocks(css="""
body { background-color: #f3f6fc; }
.gradio-container { font-family: 'Segoe UI', sans-serif; }
h1 {
background: linear-gradient(to right, #667eea, #764ba2);
color: white !important;
padding: 1rem; border-radius: 12px; margin-bottom: 0.5rem;
}
.gr-box {
background-color: white; border-radius: 12px;
box-shadow: 0 0 10px rgba(0,0,0,0.05); padding: 16px;
}
footer { display: none !important; }
""") as demo:
gr.Markdown("""
<h1 style='text-align: center;'>Multimodal Chatbot powered by Qwen-2.5-Omni-3B</h1>
<p style='text-align: center;'>Ask your own questions with optional image, video, or PDF context.</p>
""")
chatbot = gr.Chatbot(show_label=False, height=450)
state = gr.State([])
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Type your question...", scale=5)
send_btn = gr.Button("🚀 Send", scale=1)
with gr.Row():
image_input = gr.Image(type="filepath", label="Upload Image")
video_input = gr.Video(label="Upload Video")
pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")
def user_send(message, history, image, video, pdf):
if not message and not image and not video and not pdf:
return "", history, history
response = multimodal_chat(message, history, image, video, pdf)
history.append((message, response))
return "", history, history
send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
logger.info("Launching Gradio app")
demo.launch()