"""Gradio demo: text-to-image search over COCO val2017 with a fine-tuned CLIP.

The query text is embedded with CLIP, L2-normalized, and matched against a
prebuilt FAISS index of gallery image embeddings; the top hits are fetched
from their COCO URLs and shown in a gallery.
"""

import io
import os

import faiss
import gradio as gr
import requests
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# --- Configuration ---
MODEL_PATH = "clip_finetuned"       # directory holding the fine-tuned CLIP weights
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FAISS_INDEX_PATH = "gallery.index"  # FAISS index built over the gallery image embeddings
REQUEST_TIMEOUT = 10                # seconds allowed per image download

# --- Load Model, Processor, and FAISS Index ---
print("Loading model and processor...")
model = CLIPModel.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()  # inference only; disables dropout etc.
processor = CLIPProcessor.from_pretrained(MODEL_PATH)

print("Loading FAISS index...")
faiss_index = faiss.read_index(FAISS_INDEX_PATH)

# --- Connect to the COCO dataset on the Hub ---
print("Connecting to COCO dataset on the Hub...")
val_dataset = load_dataset("phiyodr/coco2017", split="validation", trust_remote_code=True)
print(f"Successfully connected to dataset with {len(val_dataset)} images.")


def image_search(query_text: str, top_k: int) -> list:
    """Return up to ``top_k`` PIL images from COCO matching ``query_text``.

    The text is embedded with CLIP and normalized so the FAISS search acts as
    a cosine-similarity lookup. Images whose download fails (HTTP error,
    timeout, or corrupt payload) are skipped rather than aborting the search.
    """
    with torch.no_grad():
        inputs = processor(text=query_text, return_tensors="pt").to(DEVICE)
        text_embedding = model.get_text_features(**inputs)
        # Normalize to unit length so inner-product search == cosine similarity.
        text_embedding = text_embedding / text_embedding.norm(p=2, dim=-1, keepdim=True)

    _distances, indices = faiss_index.search(text_embedding.cpu().numpy(), int(top_k))

    results = []
    for idx in indices[0]:
        # FAISS pads with -1 when the index holds fewer than top_k vectors;
        # val_dataset[-1] would silently return the wrong (last) item.
        if idx < 0:
            continue
        image_url = val_dataset[int(idx)]["coco_url"]
        try:
            # Bounded timeout so a stalled CDN request cannot hang the worker.
            response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()  # surface 4xx/5xx instead of decoding an error page
            results.append(Image.open(io.BytesIO(response.content)).convert("RGB"))
        except (requests.RequestException, OSError) as err:
            # Best-effort: skip unreachable/corrupt images, keep the rest.
            print(f"Skipping {image_url}: {err}")
    return results


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🖼️ CLIP-Powered Image Search Engine")
    gr.Markdown("Enter a text description to search for matching images.")
    with gr.Row():
        query_input = gr.Textbox(
            label="Search Query",
            placeholder="e.g., a red car parked near a building",
            scale=4,
        )
        k_slider = gr.Slider(minimum=1, maximum=12, value=4, step=1, label="Number of Results")
    submit_btn = gr.Button("Search", variant="primary")
    gallery_output = gr.Gallery(label="Search Results", show_label=False, columns=4, height="auto")
    submit_btn.click(fn=image_search, inputs=[query_input, k_slider], outputs=gallery_output)
    gr.Examples(
        examples=[["a dog catching a frisbee", 4], ["two people eating pizza", 8]],
        inputs=[query_input, k_slider],
    )

iface.launch()