|
|
import faiss
|
|
|
import numpy as np
|
|
|
import json
|
|
|
from transformers import AutoProcessor, AutoModel
|
|
|
from PIL import Image
|
|
|
import torch
|
|
|
import argparse
|
|
|
import requests,os
|
|
|
from dotenv import load_dotenv
|
|
|
import tempfile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One SigLIP checkpoint shared by both the processor and the model.
_SIGLIP_CHECKPOINT = "google/siglip-base-patch16-224"

print("[INFO] Loading model...")
processor = AutoProcessor.from_pretrained(_SIGLIP_CHECKPOINT)
model = AutoModel.from_pretrained(_SIGLIP_CHECKPOINT)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"[INFO] Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
load_dotenv()

# All deployment settings arrive packed into a single environment variable,
# "find_it_or_create", as newline-separated KEY=VALUE pairs. load_dotenv()
# above allows a local .env file to provide that variable.
env_blob = os.environ.get("find_it_or_create", "")

secrets = {}
for line in env_blob.splitlines():
    stripped = line.strip()
    # Ignore blank lines, "#" comments and anything that is not KEY=VALUE.
    if not stripped or stripped.startswith("#") or "=" not in stripped:
        continue
    k, v = stripped.split("=", 1)
    secrets[k.strip()] = v.strip().strip('"')

json_url = secrets.get("JSON_URL")
faiss_image_url = secrets.get("faiss_image")
faiss_text_url = secrets.get("faiss_text")

# Fail fast with a readable message instead of an opaque requests error
# (e.g. from requests.get(None)) further down the start-up sequence.
_missing_settings = [
    key
    for key, value in (
        ("JSON_URL", json_url),
        ("faiss_image", faiss_image_url),
        ("faiss_text", faiss_text_url),
    )
    if not value
]
if _missing_settings:
    raise RuntimeError(
        "Missing required setting(s) in 'find_it_or_create': "
        + ", ".join(_missing_settings)
    )

print("✅ Extracted JSON_URL:")

# Bounded timeout so a stalled server cannot hang start-up indefinitely.
response = requests.get(json_url, timeout=60)
response.raise_for_status()
manifest_data = response.json()
|
|
|
|
|
|
def download_faiss_index(url, timeout=60):
    """Download a serialized FAISS index from *url* into a new temporary file.

    Args:
        url: HTTP(S) URL of the serialized index.
        timeout: Seconds to wait when connecting and between received bytes.
            Defaults to 60 so a stalled server cannot hang the caller forever.

    Returns:
        Path of the temporary ``.faiss`` file holding the downloaded bytes.
        The file is not deleted automatically; the caller owns it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # NamedTemporaryFile creates the file atomically. The deprecated
        # tempfile.mktemp only returned a name, which another process could
        # grab before we opened it.
        tmp = tempfile.NamedTemporaryFile(suffix=".faiss", delete=False)
        try:
            with tmp:
                # Stream to disk in 1 MiB chunks instead of holding the
                # whole index in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    tmp.write(chunk)
        except BaseException:
            # Do not leave a truncated index behind on failure.
            os.remove(tmp.name)
            raise
    return tmp.name
|
|
|
|
|
|
# Fetch both serialized indexes (image first, then text) to local temp files.
faiss_image, faiss_text = (
    download_faiss_index(index_url)
    for index_url in (faiss_image_url, faiss_text_url)
)

print("✅ Extracted index:")
|
|
|
|
|
|
print("[DEBUG] manifest_data type:", type(manifest_data))

# The product table may be wrapped under a top-level "products" key.
manifest = manifest_data["products"] if "products" in manifest_data else manifest_data

# Reverse lookups from a FAISS row position to its product id, one per index.
# Entries lacking the relevant position key are simply not indexed there.
image_pos_to_id = {
    entry["image_pos"]: product_id
    for product_id, entry in manifest.items()
    if "image_pos" in entry
}
text_pos_to_id = {
    entry["text_pos"]: product_id
    for product_id, entry in manifest.items()
    if "text_pos" in entry
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("[INFO] Loading FAISS indexes...")

# The paths point at throwaway temp files produced by download_faiss_index.
# faiss.read_index (called without mmap flags) deserializes the index into
# memory, so the files can be removed afterwards. Removing them in `finally`
# keeps repeated runs from piling up index copies in the temp directory.
try:
    image_index = faiss.read_index(faiss_image)
    text_index = faiss.read_index(faiss_text)
finally:
    for _tmp_index_path in (faiss_image, faiss_text):
        try:
            os.remove(_tmp_index_path)
        except OSError:
            # Best-effort cleanup: a missing or locked file must not mask
            # the outcome of loading the indexes.
            pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def search_image(image_path, topk=5):
    """Search similar images and return full manifest data.

    Embeds the query image with SigLIP, L2-normalizes the embedding and queries
    the image FAISS index.

    Args:
        image_path: Path (or file object accepted by PIL) of the query image.
        topk: Maximum number of neighbours to request from FAISS.

    Returns:
        A list of ``{"score": float, "data": manifest entry}`` dicts in FAISS
        result order. Positions with no product in the manifest are dropped,
        so the list may be shorter than ``topk``. This includes FAISS's ``-1``
        filler when the index holds fewer than ``topk`` vectors.

    Side effects:
        Prints a debug line, then each hit's manifest entry as JSON.
    """
    print(f"[DEBUG] Running image search on: {image_path}")

    # Close the underlying file handle once the pixels are decoded; convert()
    # returns an independent image, so it stays usable afterwards.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        image_embedding = model.get_image_features(**inputs)

    # FAISS requires a C-contiguous float32 (n, d) matrix.
    vec = np.ascontiguousarray(image_embedding.cpu().numpy(), dtype=np.float32)
    # Per-row L2 normalization; the floor avoids NaNs for an all-zero vector.
    norms = np.linalg.norm(vec, axis=1, keepdims=True)
    vec /= np.maximum(norms, 1e-12)

    scores, positions = image_index.search(vec, topk)

    results = []
    for pos, score in zip(positions[0], scores[0]):
        pid = image_pos_to_id.get(int(pos))
        # Explicit None check: a falsy-but-valid id must not be skipped.
        if pid is None:
            continue
        results.append({
            "score": float(score),
            "data": manifest[pid],
        })

    for r in results:
        print(json.dumps(r["data"], indent=2, ensure_ascii=False))

    return results
|
|
|
|
|
|
|
|
|
def search_text(query, topk=5):
    """Search similar texts and return full manifest data.

    Embeds the query with SigLIP's text tower, L2-normalizes the embedding and
    queries the text FAISS index.

    Args:
        query: Free-text query string.
        topk: Maximum number of neighbours to request from FAISS.

    Returns:
        A list of ``{"score": float, "data": manifest entry}`` dicts in FAISS
        result order. Positions with no product in the manifest are dropped,
        so the list may be shorter than ``topk``. This includes FAISS's ``-1``
        filler when the index holds fewer than ``topk`` vectors. Only a debug
        line is printed; the hits themselves are returned, not printed.
    """
    print(f"[DEBUG] Running text search for query: '{query}'")

    # Truncate to the text tower's position-embedding length so long queries
    # cannot overflow it; queries that already fit tokenize unchanged.
    # NOTE(review): SigLIP was trained with padding="max_length". If the text
    # index was built that way, pass the same padding here so query and index
    # embeddings are comparable. Confirm against the indexing script.
    inputs = processor(
        text=query,
        return_tensors="pt",
        truncation=True,
        max_length=model.config.text_config.max_position_embeddings,
    ).to(device)

    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)

    # FAISS requires a C-contiguous float32 (n, d) matrix.
    vec = np.ascontiguousarray(text_embedding.cpu().numpy(), dtype=np.float32)
    # Per-row L2 normalization; the floor avoids NaNs for an all-zero vector.
    norms = np.linalg.norm(vec, axis=1, keepdims=True)
    vec /= np.maximum(norms, 1e-12)

    scores, positions = text_index.search(vec, topk)

    results = []
    for pos, score in zip(positions[0], scores[0]):
        pid = text_pos_to_id.get(int(pos))
        # Explicit None check: a falsy-but-valid id must not be skipped.
        if pid is None:
            continue
        results.append({
            "score": float(score),
            "data": manifest[pid],
        })

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: run an image and/or text query against the
    # indexes loaded at import time.
    parser = argparse.ArgumentParser(description="SigLIP Image/Text Search")
    parser.add_argument("--image", type=str, help="Path to query image")
    parser.add_argument("--text", type=str, help="Query text")
    parser.add_argument("--topk", type=int, default=5, help="Number of results")
    args = parser.parse_args()

    # FAISS rejects k <= 0; report it as a normal CLI usage error instead.
    if args.topk < 1:
        parser.error("--topk must be a positive integer")

    if args.image:
        results = search_image(args.image, topk=args.topk)
        # search_image already printed each hit's manifest entry.
        print(f"\n🔎 Image Search Results ({len(results)}):")

    if args.text:
        results = search_text(args.text, topk=args.topk)
        # search_text only returns its hits, so display them here.
        print(f"\n🔎 Text Search Results ({len(results)}):")
        for r in results:
            print(json.dumps(r, indent=2, ensure_ascii=False))

    if not args.image and not args.text:
        print("Please provide --image or --text for search.")