import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import faiss
from datasets import load_dataset
import requests
import io

# --- Configuration ---
MODEL_PATH = "clip_finetuned"  # directory containing the fine-tuned CLIP checkpoint
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FAISS_INDEX_PATH = "gallery.index"
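# NOTE: gallery.index is assumed to hold one L2-normalized image embedding per
# COCO validation image, stored in the same order as the dataset split loaded
# below, so a FAISS result index maps directly onto a dataset row.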

# --- Load Model, Processor, and FAISS Index ---
print("Loading model and processor...")
model = CLIPModel.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()  # inference only; disable training-time behavior such as dropout
processor = CLIPProcessor.from_pretrained(MODEL_PATH)

print("Loading FAISS index...")
faiss_index = faiss.read_index(FAISS_INDEX_PATH)
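
# For reference, a minimal sketch (an assumption, not the actual build script)
# of how a compatible index could have been created offline:
#
#   image_embs = ...  # (N, dim) float32 image embeddings, L2-normalized
#   index = faiss.IndexFlatIP(image_embs.shape[1])  # inner product = cosine on unit vectors
#   index.add(image_embs)
#   faiss.write_index(index, FAISS_INDEX_PATH)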

# --- Connect to the COCO dataset on the Hub ---
print("Connecting to COCO dataset on the Hub...")
val_dataset = load_dataset("phiyodr/coco2017", split="validation", trust_remote_code=True)

print(f"Successfully connected to dataset with {len(val_dataset)} images.")

# --- The Search Function ---
def image_search(query_text: str, top_k: int):
    """Embed the query text with CLIP and return the top_k closest gallery images."""
    with torch.no_grad():
        inputs = processor(text=query_text, return_tensors="pt").to(DEVICE)
        text_embedding = model.get_text_features(**inputs)
        # L2-normalize so the FAISS scores match the normalized image embeddings.
        text_embedding /= text_embedding.norm(p=2, dim=-1, keepdim=True)

    distances, indices = faiss_index.search(text_embedding.cpu().numpy(), int(top_k))

    results = []
    for i in indices[0]:
        if i < 0:  # FAISS pads with -1 when it finds fewer than top_k hits
            continue
        item = val_dataset[int(i)]
        image_url = item["coco_url"]
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()  # fail on a bad download rather than on a corrupt image
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        results.append(image)

    return results

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🖼️ CLIP-Powered Image Search Engine")
    gr.Markdown("Enter a text description to search for matching images.")
    
    with gr.Row():
        query_input = gr.Textbox(label="Search Query", placeholder="e.g., a red car parked near a building", scale=4)
        k_slider = gr.Slider(minimum=1, maximum=12, value=4, step=1, label="Number of Results")
        submit_btn = gr.Button("Search", variant="primary")

    gallery_output = gr.Gallery(label="Search Results", show_label=False, columns=4, height="auto")

    submit_btn.click(fn=image_search, inputs=[query_input, k_slider], outputs=gallery_output)
    
    gr.Examples(
        examples=[["a dog catching a frisbee", 4], ["two people eating pizza", 8]],
        inputs=[query_input, k_slider]
    )

iface.launch()
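
# launch() serves the app on localhost by default; pass share=True for a
# temporary public link, or server_name="0.0.0.0" when running in a container.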