import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import torch
from PIL import Image
import os
from typing import List, Tuple
import time


# ============= DATASET SETUP FUNCTION =============
def setup_dataset():
    """Download and prepare the dataset if it does not already exist."""
    if not os.path.exists("dataset/images"):
        print("📥 First-time setup: downloading dataset...")
        # Import setup-only dependencies lazily
        from datasets import load_dataset
        from tqdm import tqdm

        # Create directories
        os.makedirs("dataset/images", exist_ok=True)

        # 1. Download images (from download_images_hf.py)
        print("📥 Loading Caltech101 dataset...")
        dataset = load_dataset("flwrlabs/caltech101", split="train")
        dataset = dataset.shuffle(seed=42).select(range(min(500, len(dataset))))

        print(f"💾 Saving {len(dataset)} images locally...")
        for i, item in enumerate(tqdm(dataset)):
            img = item["image"].convert("RGB")  # ensure a JPEG-compatible mode
            label = item["label"]
            label_name = dataset.features["label"].int2str(label)
            fname = f"{i:05d}_{label_name}.jpg"
            img.save(os.path.join("dataset/images", fname))

        # 2. Generate embeddings (from embed_images_clip.py)
        print("🔄 Generating image embeddings...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = SentenceTransformer("clip-ViT-B-32", device=device)
        image_files = [f for f in os.listdir("dataset/images") if f.lower().endswith((".jpg", ".png"))]
        embeddings = []
        for fname in tqdm(image_files, desc="Encoding images"):
            img_path = os.path.join("dataset/images", fname)
            img = Image.open(img_path).convert("RGB")
            emb = model.encode(img, convert_to_numpy=True, show_progress_bar=False, normalize_embeddings=True)
            embeddings.append(emb)
        embeddings = np.array(embeddings, dtype="float32")
        np.save("dataset/image_embeddings.npy", embeddings)
        np.save("dataset/image_filenames.npy", np.array(image_files))

        # 3. Build FAISS index (from build_faiss_index.py)
        # IndexFlatIP does exact inner-product search; with L2-normalized
        # embeddings, inner product equals cosine similarity.
        print("📦 Building FAISS index...")
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        faiss.write_index(index, "dataset/faiss_index.bin")
        print("✅ Dataset setup complete!")
    else:
        print("✅ Dataset found, ready to go!")


# Run setup before anything else
setup_dataset()

# Configuration
META_PATH = "dataset/image_filenames.npy"
INDEX_PATH = "dataset/faiss_index.bin"
IMG_DIR = "dataset/images"


class MultimodalSearchEngine:
    def __init__(self):
        """Initialize the search engine with the pre-built index and model."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Using device: {self.device}")
        # Load pre-built index and metadata
        self.index = faiss.read_index(INDEX_PATH)
        self.image_files = np.load(META_PATH)
        # Load CLIP model
        self.model = SentenceTransformer("clip-ViT-B-32", device=self.device)
        print(f"✅ Loaded index with {self.index.ntotal} images")

    def search_by_text(self, query: str, k: int = 5) -> List[Tuple[str, float, float]]:
        """Search for images matching a text query."""
        if not query.strip():
            return []
        start_time = time.time()
        query_emb = self.model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
        scores, idxs = self.index.search(query_emb, k)
        search_time = time.time() - start_time
        results = []
        for j, i in enumerate(idxs[0]):
            if i != -1:  # FAISS pads missing neighbors with -1
                img_path = os.path.join(IMG_DIR, self.image_files[i])
                results.append((img_path, float(scores[0][j]), search_time))
        return results

    def search_by_image(self, image: Image.Image, k: int = 5) -> List[Tuple[str, float, float]]:
        """Search for images visually similar to the given image."""
        if image is None:
            return []
        start_time = time.time()
        # Convert to RGB if necessary
        if image.mode != "RGB":
            image = image.convert("RGB")
        query_emb = self.model.encode(image, convert_to_numpy=True, normalize_embeddings=True)
        query_emb = np.expand_dims(query_emb, axis=0)  # FAISS expects a 2-D batch
        scores, idxs = self.index.search(query_emb, k)
        search_time = time.time() - start_time
        results = []
        for j, i in enumerate(idxs[0]):
            if i != -1:  # FAISS pads missing neighbors with -1
                img_path = os.path.join(IMG_DIR, self.image_files[i])
                results.append((img_path, float(scores[0][j]), search_time))
        return results
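
# Usage sketch (illustrative only; "query.jpg" is a hypothetical file). One
# engine serves both modalities because CLIP maps text and images into the
# same embedding space:
#   engine = MultimodalSearchEngine()
#   engine.search_by_text("a red car", k=3)                # text -> image retrieval
#   engine.search_by_image(Image.open("query.jpg"), k=3)   # image -> image retrieval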

# Initialize the search engine
try:
    search_engine = MultimodalSearchEngine()
    ENGINE_LOADED = True
except Exception as e:
    print(f"❌ Error loading search engine: {e}")
    ENGINE_LOADED = False


def format_results(results: List[Tuple[str, float, float]]) -> Tuple[List[str], str]:
    """Format search results for Gradio display."""
    if not results:
        return [], "No results found."
    image_paths = [result[0] for result in results]
    search_time = results[0][2]
    # Create detailed results text
    results_text = f"🔍 **Search Results** (Search time: {search_time:.3f}s)\n\n"
    for i, (path, score, _) in enumerate(results, 1):
        filename = os.path.basename(path)
        # Extract label from filename (format: 00000_label.jpg)
        label = filename.split("_", 1)[1].rsplit(".", 1)[0] if "_" in filename else "unknown"
        results_text += f"**{i}.** {label} (similarity: {score:.3f})\n"
    return image_paths, results_text


def text_search_interface(query: str, num_results: int) -> Tuple[List[str], str]:
    """Interface function for text-based search."""
    if not ENGINE_LOADED:
        return [], "❌ Search engine not loaded. Please check if all files are available."
    if not query.strip():
        return [], "Please enter a search query."
    try:
        results = search_engine.search_by_text(query, k=num_results)
        return format_results(results)
    except Exception as e:
        return [], f"❌ Error during search: {e}"


def image_search_interface(image: Image.Image, num_results: int) -> Tuple[List[str], str]:
    """Interface function for image-based search."""
    if not ENGINE_LOADED:
        return [], "❌ Search engine not loaded. Please check if all files are available."
    if image is None:
        return [], "Please upload an image."
    try:
        results = search_engine.search_by_image(image, k=num_results)
        return format_results(results)
    except Exception as e:
        return [], f"❌ Error during search: {e}"


def get_random_examples() -> List[str]:
    """Return a fixed set of example queries for the text-search tab."""
    examples = [
        "a cat sitting on a chair",
        "airplane in the sky",
        "red car on the road",
        "person playing guitar",
        "dog running in the park",
        "beautiful sunset landscape",
        "computer on a desk",
        "flowers in a garden",
    ]
    return examples


# Create the Gradio interface
with gr.Blocks(
    title="🔍 Multimodal AI Search Engine",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .gallery img {
        border-radius: 8px;
    }
    """
) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
        <h1>🔍 Multimodal AI Search Engine</h1>
        <p style="font-size: 18px; color: #666;">
            Search through 500 Caltech101 images using text descriptions or image similarity
        </p>
        <p style="font-size: 14px; color: #888;">
            Powered by CLIP-ViT-B-32 and FAISS for fast similarity search
        </p>
    </div>
    """)

    with gr.Tabs() as tabs:
        # Text Search Tab
        with gr.Tab("🔍 Text Search", id="text_search"):
            gr.Markdown("### Search images using natural language descriptions")
            with gr.Row():
                with gr.Column(scale=2):
                    text_query = gr.Textbox(
                        label="Search Query",
                        placeholder="Describe what you're looking for (e.g., 'a red car', 'person with guitar')",
                        lines=2
                    )
                with gr.Column(scale=1):
                    text_num_results = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Number of Results"
                    )
                    text_search_btn = gr.Button("🔍 Search", variant="primary", size="lg")

            # Examples
            gr.Examples(
                examples=get_random_examples()[:4],
                inputs=text_query,
                label="Example Queries"
            )

            with gr.Row():
                text_results = gr.Gallery(
                    label="Search Results",
                    show_label=True,
                    elem_id="text_gallery",
                    columns=5,
                    rows=1,
                    height="auto",
                    object_fit="contain"
                )
            text_info = gr.Markdown(label="Details")

        # Image Search Tab
        with gr.Tab("🖼️ Image Search", id="image_search"):
            gr.Markdown("### Find visually similar images")
            with gr.Row():
                with gr.Column(scale=2):
                    image_query = gr.Image(
                        label="Upload Query Image",
                        type="pil",
                        height=300
                    )
                with gr.Column(scale=1):
                    image_num_results = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1,
                        label="Number of Results"
                    )
                    image_search_btn = gr.Button("🔍 Search Similar", variant="primary", size="lg")

            with gr.Row():
                image_results = gr.Gallery(
                    label="Similar Images",
                    show_label=True,
                    elem_id="image_gallery",
                    columns=5,
                    rows=1,
                    height="auto",
                    object_fit="contain"
                )
            image_info = gr.Markdown(label="Details")

        # About Tab
        with gr.Tab("ℹ️ About", id="about"):
            gr.Markdown("""
            ### 🔬 Technical Details

            This multimodal search engine demonstrates AI techniques for content-based image retrieval:

            **🧠 Model Architecture:**
            - **CLIP-ViT-B-32**: OpenAI's Contrastive Language-Image Pre-training model
            - **Vision Transformer**: Processes images using attention mechanisms
            - **Dual-encoder**: Separate text and image encoders mapping into a shared embedding space (see the sketch below)
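
            Both encoders emit vectors in the same 512-dimensional space, so a text
            query can be compared directly against image embeddings. A minimal sketch
            (`dog.jpg` is an assumed local file):

            ```python
            from sentence_transformers import SentenceTransformer
            from PIL import Image

            model = SentenceTransformer("clip-ViT-B-32")
            txt_emb = model.encode(["a photo of a dog"], normalize_embeddings=True)   # shape (1, 512)
            img_emb = model.encode(Image.open("dog.jpg"), normalize_embeddings=True)  # shape (512,)
            similarity = float(txt_emb[0] @ img_emb)  # cosine similarity of unit vectors
            ```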

            **⚡ Search Infrastructure:**
            - **FAISS**: Facebook AI Similarity Search for efficient vector operations
            - **Cosine Similarity**: Measures semantic similarity in embedding space
            - **Inner Product Index**: Optimized for normalized embeddings (illustrated below)
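
            For L2-normalized vectors, the inner product equals cosine similarity,
            which is why an `IndexFlatIP` index returns cosine scores here. A minimal sketch:

            ```python
            import numpy as np

            a = np.random.rand(512).astype("float32")
            b = np.random.rand(512).astype("float32")
            cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

            a /= np.linalg.norm(a)  # L2-normalize
            b /= np.linalg.norm(b)
            assert np.isclose(np.dot(a, b), cosine)  # inner product == cosine once normalized
            ```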

            **📊 Dataset:**
            - **Caltech101**: 500 randomly sampled images from 101 object categories
            - **Preprocessing**: RGB conversion, CLIP-compatible normalization
            - **Embeddings**: a 512-dimensional feature vector per image

            **🚀 Performance Features:**
            - **GPU Acceleration**: CUDA support for faster inference
            - **Batch Processing**: Efficient embedding computation
            - **Real-time Search**: Sub-second query response times
            - **Normalized Embeddings**: L2 normalization for consistent similarity scores

            **🎯 Applications:**
            - Content-based image retrieval
            - Visual search engines
            - Cross-modal similarity matching
            - Dataset exploration and analysis

            ### 🛠️ Implementation Highlights
            - Modular architecture with separate indexing and search components
            - Error handling and graceful degradation
            - Configurable result counts
            - Professional UI with responsive design
            """)

    # Event handlers
    text_search_btn.click(
        fn=text_search_interface,
        inputs=[text_query, text_num_results],
        outputs=[text_results, text_info]
    )
    image_search_btn.click(
        fn=image_search_interface,
        inputs=[image_query, image_num_results],
        outputs=[image_results, image_info]
    )
    # Auto-search on Enter key for text
    text_query.submit(
        fn=text_search_interface,
        inputs=[text_query, text_num_results],
        outputs=[text_results, text_info]
    )

# Launch configuration
if __name__ == "__main__":
    print("\n" + "=" * 50)
    print("🚀 Starting Multimodal AI Search Engine")
    print("=" * 50)
    if ENGINE_LOADED:
        print(f"✅ Search engine ready with {search_engine.index.ntotal} images")
        print(f"✅ Using device: {search_engine.device}")
    else:
        print("❌ Search engine failed to load")
    print("\n💡 Usage Tips:")
    print("- Text search: Use natural language descriptions")
    print("- Image search: Upload any image to find similar ones")
    print("- Adjust result count using the slider")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )