import gradio as gr
import fitz  # PyMuPDF
import torch
import cv2
import os
import tempfile
import shutil
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check CUDA
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
# BitsAndBytes config for quantized model loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
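# Rough memory sketch: NF4 stores weights at ~0.5 bytes per parameter, so a
# 3B-parameter model needs on the order of 1.5 GB for weights (plus activation
# and KV-cache overhead), versus roughly 6 GB at fp16.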
# Load Qwen model (assumes the checkpoint's remote code is compatible with
# AutoModelForCausalLM and provides a chat() helper; recent transformers
# releases also expose Qwen2.5-Omni through a dedicated model class)
try:
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Omni-3B", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-Omni-3B",
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    ).eval()
    logger.info("Qwen model loaded.")
except Exception as e:
    logger.error(f"Failed to load Qwen: {e}")
    model, tokenizer = None, None
# Load SentenceTransformer for RAG
try:
    embed_model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
    logger.info("Embedding model loaded.")
except Exception as e:
    logger.error(f"Failed to load embedding model: {e}")
    embed_model = None

# Global index state
chunks = []
index = None
# PDF text chunking
def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
    try:
        doc = fitz.open(pdf_path)
        text = "".join(page.get_text() for page in doc)
        doc.close()
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
    except Exception as e:
        logger.error(f"PDF error: {e}")
        return ["Error extracting content."]
# Build FAISS index
def build_faiss_index(chunks):
    if embed_model is None:
        return None
    try:
        embeddings = embed_model.encode(chunks, convert_to_numpy=True)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index
    except Exception as e:
        logger.error(f"FAISS index error: {e}")
        return None
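# IndexFlatL2 does exact (brute-force) L2 search, which is fine at this scale;
# for cosine similarity one would L2-normalize the embeddings and use
# faiss.IndexFlatIP instead.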
# RAG retrieval
def rag_query(query, chunks, index, top_k=3):
    try:
        q_emb = embed_model.encode([query], convert_to_numpy=True)
        D, I = index.search(q_emb, top_k)
        return "\n\n".join(chunks[i] for i in I[0] if i != -1)
    except Exception as e:
        logger.error(f"RAG query error: {e}")
        return "Error retrieving context."
# Qwen chat
def chat_with_qwen(text, image=None):
    if not model or not tokenizer:
        return "Model not loaded."
    try:
        messages = [{"role": "user", "content": text}]
        if image:
            # Qwen-VL-style multimodal content: a list of {"image": path} and
            # {"text": ...} parts (assumes the remote code accepts this format)
            messages[0]["content"] = [{"image": image}, {"text": text}]
        response, _ = model.chat(tokenizer, messages, history=None)
        return response
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return f"Chat error: {e}"
# Extract representative frames
def extract_video_frames(video_path, max_frames=2):
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return []
        frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            success, frame = cap.read()
            if success:
                frames.append(frame)
        return frames
    except Exception as e:
        logger.error(f"Frame extraction error: {e}")
        return []
    finally:
        cap.release()
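# Note: OpenCV decodes frames in BGR channel order, and cv2.imwrite expects
# BGR as well, so extracted frames can be written back out without any color
# conversion; convert to RGB only for RGB-based libraries such as PIL.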
# Multimodal chat logic
def multimodal_chat(message, history, image=None, video=None, pdf=None):
    global chunks, index
    if not model:
        return "Model not available."
    try:
        # PDF + question
        if pdf and message:
            pdf_path = pdf.name if hasattr(pdf, 'name') else None
            if not pdf_path:
                return "Invalid PDF input."
            chunks = extract_chunks_from_pdf(pdf_path)
            index = build_faiss_index(chunks)
            if index is not None:
                context = rag_query(message, chunks, index)
                user_prompt = f"Context:\n{context}\n\nQuestion: {message}"
                return chat_with_qwen(user_prompt)
            else:
                return "Failed to process PDF."
        # Image + question
        if image and message:
            return chat_with_qwen(message, image)
        # Video + question
        if video and message:
            with tempfile.TemporaryDirectory() as temp_dir:
                video_path = os.path.join(temp_dir, "video.mp4")
                shutil.copy(video.name if hasattr(video, 'name') else video, video_path)
                frames = extract_video_frames(video_path)
                if not frames:
                    return "Could not extract video frames."
                temp_img_path = os.path.join(temp_dir, "frame.jpg")
                # The frame is already BGR, which is what cv2.imwrite expects;
                # converting to RGB here would swap the red and blue channels.
                cv2.imwrite(temp_img_path, frames[0])
                return chat_with_qwen(message, temp_img_path)
        # Text only
        if message:
            return chat_with_qwen(message)
        return "Please enter a question and optionally upload a file."
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return f"Error: {e}"
# Gradio UI
with gr.Blocks(css="""
body { background-color: #f3f6fc; }
.gradio-container { font-family: 'Segoe UI', sans-serif; }
h1 {
    background: linear-gradient(to right, #667eea, #764ba2);
    color: white !important;
    padding: 1rem; border-radius: 12px; margin-bottom: 0.5rem;
}
.gr-box {
    background-color: white; border-radius: 12px;
    box-shadow: 0 0 10px rgba(0,0,0,0.05); padding: 16px;
}
footer { display: none !important; }
""") as demo:
    gr.Markdown("""
    <h1 style='text-align: center;'>Multimodal Chatbot powered by Qwen2.5-Omni-3B</h1>
    <p style='text-align: center;'>Ask your own questions with optional image, video, or PDF context.</p>
    """)
    chatbot = gr.Chatbot(show_label=False, height=450)
    state = gr.State([])

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type your question...", scale=5)
        send_btn = gr.Button("🚀 Send", scale=1)

    with gr.Row():
        image_input = gr.Image(type="filepath", label="Upload Image")
        video_input = gr.Video(label="Upload Video")
        pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")
    def user_send(message, history, image, video, pdf):
        if not message and not image and not video and not pdf:
            return "", history, history
        response = multimodal_chat(message, history, image, video, pdf)
        history.append((message, response))
        return "", history, history

    send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
    txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
logger.info("Launching Gradio app") | |
demo.launch() | |