# find_it_or_create_it / siglip_search_2.py
# Uploaded by: manganapallydeepa
# Commit message: "Upload 8 files"
# Commit: 32199b1 (verified)
import faiss
import numpy as np
import json
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
import argparse
import requests,os
from dotenv import load_dotenv
import tempfile
# ------------------------------
# 1️⃣ Setup Model and Device
# ------------------------------
print("[INFO] Loading model...")
# Single SigLIP checkpoint shared by the processor and the model.
_SIGLIP_CHECKPOINT = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(_SIGLIP_CHECKPOINT)
# Evaluation mode for inference (nn.Module.eval() returns the module itself).
model = AutoModel.from_pretrained(_SIGLIP_CHECKPOINT).eval()
# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"[INFO] Using device: {device}")
# ------------------------------
# 2️⃣ Load Manifest
# ------------------------------
load_dotenv()  # Load .env file
# All secrets are packed into ONE environment variable, "find_it_or_create",
# whose value holds KEY=VALUE lines (values may be wrapped in double quotes).
env_blob = os.environ.get("find_it_or_create", "")
secrets = {}
for line in env_blob.splitlines():
    line = line.strip()
    # Skip blank lines and '#' comments rather than parsing them as keys.
    if not line or line.startswith("#") or "=" not in line:
        continue
    k, v = line.split("=", 1)
    secrets[k.strip()] = v.strip().strip('"')
json_url = secrets.get("JSON_URL")
faiss_image_url = secrets.get("faiss_image")
faiss_text_url = secrets.get("faiss_text")
# Fail fast with an actionable message; otherwise a missing key only surfaces
# later as an opaque "Invalid URL 'None'" error from requests.
_required_urls = {
    "JSON_URL": json_url,
    "faiss_image": faiss_image_url,
    "faiss_text": faiss_text_url,
}
_missing_keys = [name for name, value in _required_urls.items() if not value]
if _missing_keys:
    raise ValueError(
        "Missing required key(s) in the 'find_it_or_create' environment "
        f"variable: {', '.join(_missing_keys)}"
    )
print("✅ Extracted JSON_URL:")
# Without a timeout, requests can block indefinitely on an unresponsive server.
response = requests.get(json_url, timeout=60)
response.raise_for_status()
manifest_data = response.json()
def download_faiss_index(url, timeout=60, chunk_size=1 << 20):
    """Download a serialized FAISS index into a local temporary file.

    The response body is streamed to disk in chunks, so the whole index is
    never held in memory as one ``bytes`` object.

    Args:
        url: HTTP(S) URL of the serialized index.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes (not a cap on total download time).
        chunk_size: Number of bytes read per chunk while streaming.

    Returns:
        Path (``str``) of a ``.faiss`` temporary file holding the index. The
        file is not deleted automatically; the caller owns it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # NamedTemporaryFile creates the file atomically; tempfile.mktemp() is
        # deprecated because another process can claim the name in between.
        tmp = tempfile.NamedTemporaryFile(suffix=".faiss", delete=False)
        try:
            with tmp:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    tmp.write(chunk)
        except BaseException:
            # Do not leave a truncated index file behind on a failed download.
            os.remove(tmp.name)
            raise
    return tmp.name
# Fetch both serialized indexes to local files: image index first, then text.
faiss_image, faiss_text = map(download_faiss_index, (faiss_image_url, faiss_text_url))
print("✅ Extracted index:")
print("[DEBUG] manifest_data type:", type(manifest_data))
# The downloaded manifest is either the product mapping itself or wraps it
# under a top-level "products" key.
manifest = manifest_data["products"] if "products" in manifest_data else manifest_data
# Reverse lookups from a FAISS row position to its product id. Products
# without the position field are left out; if two products share a position,
# the one iterated last wins.
image_pos_to_id = {
    entry["image_pos"]: pid for pid, entry in manifest.items() if "image_pos" in entry
}
text_pos_to_id = {
    entry["text_pos"]: pid for pid, entry in manifest.items() if "text_pos" in entry
}
# ------------------------------
# 3️⃣ Load FAISS Indexes
# ------------------------------
print("[INFO] Loading FAISS indexes...")
# faiss_image / faiss_text are temporary files written by download_faiss_index.
# faiss.read_index (without the mmap IO flag) reads the index fully into
# memory, so the files can be removed afterwards instead of piling up in the
# temp directory on every run.
try:
    image_index = faiss.read_index(faiss_image)
    text_index = faiss.read_index(faiss_text)
finally:
    for _tmp_index_path in (faiss_image, faiss_text):
        try:
            os.remove(_tmp_index_path)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless
# ------------------------------
# 4️⃣ Search Functions
# ------------------------------
def search_image(image_path, topk=5):
    """Search similar images and return full manifest data.

    The query image is embedded with the SigLIP image tower, L2-normalised,
    and searched in ``image_index``. Each hit is mapped back to its product
    through ``image_pos_to_id``. Every matched manifest entry is also printed
    as JSON.

    Args:
        image_path: Path to the query image (anything ``PIL.Image.open``
            accepts).
        topk: Number of nearest neighbours to request from FAISS.

    Returns:
        A list of ``{"score": float, "data": <manifest entry>}`` dicts in the
        order FAISS returned them. Hits whose position maps to no product are
        skipped, so the list can be shorter than ``topk``.
    """
    print(f"[DEBUG] Running image search on: {image_path}")
    # Close the source file deterministically; Pillow keeps it open after
    # loading for multi-frame formats.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        image_embedding = model.get_image_features(**inputs)
    vec = image_embedding.cpu().numpy()
    # L2-normalise each row (presumably the index scores by inner product, so
    # this yields cosine similarity -- confirm against the index builder).
    vec /= np.linalg.norm(vec, axis=1, keepdims=True)
    scores, positions = image_index.search(vec, topk)
    results = []
    for pos, score in zip(positions[0], scores[0]):
        if pos < 0:
            continue  # FAISS pads with -1 when it has fewer than topk hits
        pid = image_pos_to_id.get(int(pos))
        if pid is not None:
            results.append({"score": float(score), "data": manifest[pid]})
    for r in results:
        print(json.dumps(r["data"], indent=2, ensure_ascii=False))
    return results
def search_text(query, topk=5):
    """Search similar texts and return full manifest data.

    The query is embedded with the SigLIP text tower, L2-normalised, and
    searched in ``text_index``. Each hit is mapped back to its product
    through ``text_pos_to_id``. Unlike ``search_image``, nothing is printed
    besides the debug line.

    Args:
        query: Free-text query string.
        topk: Number of nearest neighbours to request from FAISS.

    Returns:
        A list of ``{"score": float, "data": <manifest entry>}`` dicts in the
        order FAISS returned them. Hits whose position maps to no product are
        skipped, so the list can be shorter than ``topk``.
    """
    print(f"[DEBUG] Running text search for query: '{query}'")
    # NOTE(review): the SigLIP model docs recommend padding="max_length" for
    # text inputs, as the model was trained that way. Only change this together
    # with the script that built the text index, so both sides embed text
    # identically.
    inputs = processor(text=query, return_tensors="pt").to(device)
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)
    vec = text_embedding.cpu().numpy()
    # L2-normalise each row (presumably the index scores by inner product, so
    # this yields cosine similarity -- confirm against the index builder).
    vec /= np.linalg.norm(vec, axis=1, keepdims=True)
    scores, positions = text_index.search(vec, topk)
    results = []
    for pos, score in zip(positions[0], scores[0]):
        if pos < 0:
            continue  # FAISS pads with -1 when it has fewer than topk hits
        pid = text_pos_to_id.get(int(pos))
        if pid is not None:
            results.append({"score": float(score), "data": manifest[pid]})
    return results
# ------------------------------
# 5️⃣ Command-Line Interface
# ------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SigLIP Image/Text Search")
    parser.add_argument("--image", type=str, help="Path to query image")
    parser.add_argument("--text", type=str, help="Query text")
    parser.add_argument("--topk", type=int, default=5, help="Number of results")
    args = parser.parse_args()
    # FAISS cannot search for zero or a negative number of neighbours.
    if args.topk < 1:
        parser.error("--topk must be a positive integer")
    if args.image:
        # search_image prints every matched manifest entry itself.
        results = search_image(args.image, topk=args.topk)
        print(f"\n🔎 Image Search Results ({len(results)}):")
    if args.text:
        # search_text does not print its hits, so show them here.
        results = search_text(args.text, topk=args.topk)
        print(f"\n🔎 Text Search Results ({len(results)}):")
        for r in results:
            print(json.dumps(r["data"], indent=2, ensure_ascii=False))
    if not args.image and not args.text:
        print("Please provide --image or --text for search.")