import gradio as gr import numpy as np import faiss from sentence_transformers import SentenceTransformer import torch from PIL import Image import os from typing import List, Tuple, Optional import time # ============= DATASET SETUP FUNCTION ============= def setup_dataset(): """Download and prepare dataset if not exists.""" if not os.path.exists("dataset/images"): print("š„ First-time setup: downloading dataset...") # Import required modules for setup from datasets import load_dataset from tqdm import tqdm # Create directories os.makedirs("dataset/images", exist_ok=True) # 1. Download images (from download_images_hf.py) print("š„ Loading Caltech101 dataset...") dataset = load_dataset("flwrlabs/caltech101", split="train") dataset = dataset.shuffle(seed=42).select(range(min(500, len(dataset)))) print(f"š¾ Saving {len(dataset)} images locally...") for i, item in enumerate(tqdm(dataset)): img = item["image"] label = item["label"] label_name = dataset.features["label"].int2str(label) fname = f"{i:05d}_{label_name}.jpg" img.save(os.path.join("dataset/images", fname)) # 2. Generate embeddings (from embed_images_clip.py) print("š Generating image embeddings...") device = "cuda" if torch.cuda.is_available() else "cpu" model = SentenceTransformer("clip-ViT-B-32", device=device) image_files = [f for f in os.listdir("dataset/images") if f.lower().endswith((".jpg", ".png"))] embeddings = [] for fname in tqdm(image_files, desc="Encoding images"): img_path = os.path.join("dataset/images", fname) img = Image.open(img_path).convert("RGB") emb = model.encode(img, convert_to_numpy=True, show_progress_bar=False, normalize_embeddings=True) embeddings.append(emb) embeddings = np.array(embeddings, dtype="float32") np.save("dataset/image_embeddings.npy", embeddings) np.save("dataset/image_filenames.npy", np.array(image_files)) # 3. Build FAISS index (from build_faiss_index.py) print("š¦ Building FAISS index...") dim = embeddings.shape[1] index = faiss.IndexFlatIP(dim) index.add(embeddings) faiss.write_index(index, "dataset/faiss_index.bin") print("ā Dataset setup complete!") else: print("ā Dataset found, ready to go!") # Call setup before anything else setup_dataset() # Configuration META_PATH = "dataset/image_filenames.npy" INDEX_PATH = "dataset/faiss_index.bin" IMG_DIR = "dataset/images" class MultimodalSearchEngine: def __init__(self): """Initialize the search engine with pre-built index and model.""" self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"š Using device: {self.device}") # Load pre-built index and metadata self.index = faiss.read_index(INDEX_PATH) self.image_files = np.load(META_PATH) # Load CLIP model self.model = SentenceTransformer("clip-ViT-B-32", device=self.device) print(f"ā Loaded index with {self.index.ntotal} images") def search_by_text(self, query: str, k: int = 5) -> List[Tuple[str, float]]: """Search for images matching a text query.""" if not query.strip(): return [] start_time = time.time() query_emb = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True) scores, idxs = self.index.search(query_emb, k) search_time = time.time() - start_time results = [] for j, i in enumerate(idxs[0]): if i != -1: # Valid index img_path = os.path.join(IMG_DIR, self.image_files[i]) results.append((img_path, float(scores[0][j]), search_time)) return results def search_by_image(self, image: Image.Image, k: int = 5) -> List[Tuple[str, float]]: """Search for images visually similar to the given image.""" if image is None: return [] start_time = time.time() # Convert to RGB if necessary if image.mode != 'RGB': image = image.convert('RGB') query_emb = self.model.encode(image, convert_to_numpy=True, normalize_embeddings=True) query_emb = np.expand_dims(query_emb, axis=0) scores, idxs = self.index.search(query_emb, k) search_time = time.time() - start_time results = [] for j, i in enumerate(idxs[0]): if i != -1: # Valid index img_path = os.path.join(IMG_DIR, self.image_files[i]) results.append((img_path, float(scores[0][j]), search_time)) return results # Initialize the search engine try: search_engine = MultimodalSearchEngine() ENGINE_LOADED = True except Exception as e: print(f"ā Error loading search engine: {e}") ENGINE_LOADED = False def format_results(results: List[Tuple[str, float, float]]) -> Tuple[List[str], str]: """Format search results for Gradio display.""" if not results: return [], "No results found." image_paths = [result[0] for result in results] search_time = results[0][2] if results else 0 # Create detailed results text results_text = f"š **Search Results** (Search time: {search_time:.3f}s)\n\n" for i, (path, score, _) in enumerate(results, 1): filename = os.path.basename(path) # Extract label from filename (format: 00000_label.jpg) label = filename.split('_', 1)[1].rsplit('.', 1)[0] if '_' in filename else 'unknown' results_text += f"**{i}.** {label} (similarity: {score:.3f})\n" return image_paths, results_text def text_search_interface(query: str, num_results: int) -> Tuple[List[str], str]: """Interface function for text-based search.""" if not ENGINE_LOADED: return [], "ā Search engine not loaded. Please check if all files are available." if not query.strip(): return [], "Please enter a search query." try: results = search_engine.search_by_text(query, k=num_results) return format_results(results) except Exception as e: return [], f"ā Error during search: {str(e)}" def image_search_interface(image: Image.Image, num_results: int) -> Tuple[List[str], str]: """Interface function for image-based search.""" if not ENGINE_LOADED: return [], "ā Search engine not loaded. Please check if all files are available." if image is None: return [], "Please upload an image." try: results = search_engine.search_by_image(image, k=num_results) return format_results(results) except Exception as e: return [], f"ā Error during search: {str(e)}" def get_random_examples() -> List[str]: """Get random example queries.""" examples = [ "a cat sitting on a chair", "airplane in the sky", "red car on the road", "person playing guitar", "dog running in the park", "beautiful sunset landscape", "computer on a desk", "flowers in a garden" ] return examples # Create the Gradio interface with gr.Blocks( title="š Multimodal AI Search Engine", theme=gr.themes.Soft(), css=""" .gradio-container { max-width: 1200px !important; } .gallery img { border-radius: 8px; } """ ) as demo: gr.HTML("""
Search through 500 Caltech101 images using text descriptions or image similarity
Powered by CLIP-ViT-B-32 and FAISS for fast similarity search