# NOTE(review): removed web-page extraction artifacts ("Spaces:" / "Runtime error")
# that were not part of this module's source.
import io
from typing import List, Dict, Optional

import httpx
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

from folder_manager import FolderManager
from image_database import ImageDatabase
from qdrant_singleton import QdrantClientSingleton
class ImageSearch:
    """Semantic image search over CLIP embeddings stored in Qdrant.

    Vector search runs against per-folder Qdrant collections; image
    metadata (filename, dimensions, file size) is read from a SQLite
    database via ``ImageDatabase``.
    """

    def __init__(self, init_model: bool = True):
        """Set up backing services and optionally load the CLIP model.

        Args:
            init_model: When False, skip model loading so a caller
                (e.g. the indexer) can share an already-loaded model.
        """
        # Prefer GPU when available; CLIP inference is much faster on CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"ImageSearch using device: {self.device}")

        # Initialize Qdrant client, folder manager and image database first,
        # so the search infrastructure exists even if model loading fails.
        self.qdrant = QdrantClientSingleton.get_instance()
        self.folder_manager = FolderManager()
        self.image_db = ImageDatabase()

        # Model state; populated by _initialize_model().
        self.model_initialized = False
        self.processor = None
        self.model = None

        # Only initialize the model if requested (allows sharing from indexer).
        if init_model:
            self._initialize_model()
def _initialize_model(self): | |
"""Initialize the CLIP model and processor""" | |
try: | |
print("Loading CLIP model and processor...") | |
# Set environment variables to avoid tqdm threading issues | |
import os | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "true" | |
# Load model first, then processor with explicit settings | |
self.model = CLIPModel.from_pretrained( | |
"openai/clip-vit-base-patch16", | |
cache_dir="/tmp/transformers_cache", | |
local_files_only=False | |
) | |
self.processor = CLIPProcessor.from_pretrained( | |
"openai/clip-vit-base-patch16", | |
cache_dir="/tmp/transformers_cache", | |
use_fast=False, # Explicitly use slow processor to avoid compatibility issues | |
local_files_only=False | |
) | |
# Move model to device using to_empty() for meta tensors | |
try: | |
self.model = self.model.to(self.device) | |
except NotImplementedError: | |
# Handle meta tensor case | |
self.model = self.model.to_empty(device=self.device) | |
self.model.eval() # Set to evaluation mode | |
print("Model initialization complete") | |
self.model_initialized = True | |
except Exception as e: | |
print(f"Error initializing model: {e}") | |
print("Attempting fallback initialization...") | |
try: | |
# Simplest possible fallback with offline-first approach | |
import torch | |
torch.hub.set_dir('/tmp/torch_cache') | |
# Try loading without any extra parameters first | |
try: | |
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16") | |
self.processor = CLIPProcessor.from_pretrained( | |
"openai/clip-vit-base-patch16", | |
use_fast=False | |
) | |
except Exception: | |
# If that fails, try with cache directory | |
self.model = CLIPModel.from_pretrained( | |
"openai/clip-vit-base-patch16", | |
cache_dir="/tmp/transformers_cache" | |
) | |
self.processor = CLIPProcessor.from_pretrained( | |
"openai/clip-vit-base-patch16", | |
cache_dir="/tmp/transformers_cache", | |
use_fast=False | |
) | |
# Handle meta tensor case in fallback too | |
try: | |
self.model = self.model.to(self.device) | |
except NotImplementedError: | |
self.model = self.model.to_empty(device=self.device) | |
self.model.eval() | |
print("Fallback model initialization successful") | |
self.model_initialized = True | |
except Exception as e2: | |
print(f"Fallback also failed: {e2}") | |
print("Model initialization completely failed - search functionality will be disabled") | |
self.model_initialized = False | |
self.processor = None | |
self.model = None | |
def calculate_similarity_percentage(self, score: float) -> float: | |
"""Convert cosine similarity score to percentage""" | |
# Qdrant returns cosine similarity scores between -1 and 1 | |
# We want to convert this to a percentage between 0 and 100 | |
# First normalize to 0-1 range, then convert to percentage | |
normalized = (score + 1) / 2 | |
return normalized * 100 | |
def filter_results(self, search_results: list, threshold: float = 60) -> List[Dict]: | |
"""Filter and format search results""" | |
results = [] | |
for scored_point in search_results: | |
# Convert cosine similarity to percentage | |
similarity = self.calculate_similarity_percentage(scored_point.score) | |
# Only include results above threshold (60% similarity) | |
if similarity >= threshold: | |
# Get image data from SQLite database | |
image_id = scored_point.payload.get("image_id") | |
if image_id: | |
image_data = self.image_db.get_image(image_id) | |
if image_data: | |
results.append({ | |
"id": image_id, | |
"path": scored_point.payload["path"], | |
"filename": image_data["filename"], | |
"root_folder": scored_point.payload["root_folder"], | |
"similarity": round(similarity, 1), | |
"file_size": image_data["file_size"], | |
"width": image_data["width"], | |
"height": image_data["height"] | |
}) | |
return results | |
async def search_by_text(self, query: str, folder_path: Optional[str] = None, k: int = 10) -> List[Dict]: | |
"""Search images by text query""" | |
try: | |
print(f"\nSearching for text: '{query}'") | |
# Check if model is initialized | |
if not self.model_initialized or self.model is None or self.processor is None: | |
print("Model not initialized. Cannot perform text search.") | |
return [] | |
# Get collections to search | |
collections_to_search = [] | |
if folder_path: | |
# Search in specific folder's collection | |
collection_name = self.folder_manager.get_collection_for_path(folder_path) | |
if collection_name: | |
collections_to_search.append(collection_name) | |
print(f"Searching in specific folder collection: {collection_name}") | |
else: | |
# Search in all collections | |
folders = self.folder_manager.get_all_folders() | |
print(f"Found {len(folders)} folders") | |
for folder in folders: | |
print(f"Folder: {folder['path']}, Valid: {folder['is_valid']}, Collection: {folder.get('collection_name', 'None')}") | |
# Include all collections regardless of folder validity since images are in SQLite | |
collections_to_search.extend(folder["collection_name"] for folder in folders if folder.get("collection_name")) | |
print(f"Collections to search: {collections_to_search}") | |
if not collections_to_search: | |
print("No collections available to search") | |
return [] | |
# Generate text embedding | |
inputs = self.processor(text=[query], return_tensors="pt", padding=True).to(self.device) | |
with torch.no_grad(): | |
text_features = self.model.get_text_features(**inputs) | |
text_features = text_features / text_features.norm(dim=-1, keepdim=True) | |
text_embedding = text_features.cpu().numpy().flatten() | |
# Search in all relevant collections | |
all_results = [] | |
for collection_name in collections_to_search: | |
try: | |
# Get more results from each collection when searching multiple collections | |
collection_limit = k * 3 if len(collections_to_search) > 1 else k | |
search_result = self.qdrant.search( | |
collection_name=collection_name, | |
query_vector=text_embedding.tolist(), | |
limit=collection_limit, # Get more results from each collection | |
offset=0, # Explicitly set offset | |
score_threshold=0.2 # Corresponds to 60% similarity after normalization | |
) | |
# Filter and format results | |
results = self.filter_results(search_result) # Threshold is now default 60 in filter_results | |
all_results.extend(results) | |
print(f"Found {len(results)} matches in collection {collection_name}") | |
except Exception as e: | |
print(f"Error searching collection {collection_name}: {e}") | |
continue | |
# Sort all results by similarity | |
all_results.sort(key=lambda x: x["similarity"], reverse=True) | |
# Take top k results | |
final_results = all_results[:k] | |
print(f"Found {len(final_results)} total relevant matches across {len(collections_to_search)} collections") | |
return final_results | |
except Exception as e: | |
print(f"Error in text search: {e}") | |
import traceback | |
traceback.print_exc() | |
return [] | |
async def search_by_image(self, image: Image.Image, folder_path: Optional[str] = None, k: int = 10) -> List[Dict]: | |
"""Search images by similarity to uploaded image""" | |
try: | |
print(f"\nSearching by image...") | |
# Check if model is initialized | |
if not self.model_initialized or self.model is None or self.processor is None: | |
print("Model not initialized. Cannot perform image search.") | |
return [] | |
# Get collections to search | |
collections_to_search = [] | |
if folder_path: | |
# Search in specific folder's collection | |
collection_name = self.folder_manager.get_collection_for_path(folder_path) | |
if collection_name: | |
collections_to_search.append(collection_name) | |
print(f"Searching in specific folder collection: {collection_name}") | |
else: | |
# Search in all collections | |
folders = self.folder_manager.get_all_folders() | |
print(f"Found {len(folders)} folders") | |
print(f"Raw folder manager data: {self.folder_manager.folders}") | |
for folder in folders: | |
print(f"Folder: {folder['path']}, Valid: {folder['is_valid']}, Collection: {folder.get('collection_name', 'None')}") | |
# Include all collections regardless of folder validity since images are in SQLite | |
collections_to_search.extend(folder["collection_name"] for folder in folders if folder.get("collection_name")) | |
print(f"Collections to search: {collections_to_search}") | |
if not collections_to_search: | |
print("No collections available to search") | |
return [] | |
# Generate image embedding | |
inputs = self.processor(images=image, return_tensors="pt").to(self.device) | |
with torch.no_grad(): | |
image_features = self.model.get_image_features(**inputs) | |
image_features = image_features / image_features.norm(dim=-1, keepdim=True) | |
image_embedding = image_features.cpu().numpy().flatten() | |
# Search in all relevant collections | |
all_results = [] | |
for collection_name in collections_to_search: | |
try: | |
# Get more results from each collection when searching multiple collections | |
collection_limit = k * 3 if len(collections_to_search) > 1 else k | |
search_result = self.qdrant.search( | |
collection_name=collection_name, | |
query_vector=image_embedding.tolist(), | |
limit=collection_limit, # Get more results from each collection | |
offset=0, # Explicitly set offset | |
score_threshold=0.2 # Corresponds to 60% similarity after normalization | |
) | |
# Filter and format results | |
results = self.filter_results(search_result) # Threshold is now default 60 in filter_results | |
all_results.extend(results) | |
print(f"Found {len(results)} matches in collection {collection_name}") | |
except Exception as e: | |
print(f"Error searching collection {collection_name}: {e}") | |
continue | |
# Sort all results by similarity | |
all_results.sort(key=lambda x: x["similarity"], reverse=True) | |
# Take top k results | |
final_results = all_results[:k] | |
print(f"Found {len(final_results)} total relevant matches across {len(collections_to_search)} collections") | |
return final_results | |
except Exception as e: | |
print(f"Error in image search: {e}") | |
import traceback | |
traceback.print_exc() | |
return [] | |
async def download_image_from_url(self, url: str) -> Optional[Image.Image]: | |
"""Download and return an image from a URL""" | |
try: | |
print(f"Downloading image from URL: {url}") | |
# Use httpx for async HTTP requests | |
async with httpx.AsyncClient(timeout=30.0) as client: | |
response = await client.get(url) | |
response.raise_for_status() | |
# Check if the response is an image | |
content_type = response.headers.get('content-type', '') | |
if not content_type.startswith('image/'): | |
raise ValueError(f"URL does not point to an image. Content-Type: {content_type}") | |
# Load image from response content | |
image_bytes = io.BytesIO(response.content) | |
image = Image.open(image_bytes) | |
# Convert to RGB if necessary (for consistency with CLIP) | |
if image.mode != 'RGB': | |
image = image.convert('RGB') | |
print(f"Successfully downloaded image: {image.size}") | |
return image | |
except httpx.TimeoutException: | |
print(f"Timeout while downloading image from URL: {url}") | |
return None | |
except httpx.HTTPStatusError as e: | |
print(f"HTTP error {e.response.status_code} while downloading image from URL: {url}") | |
return None | |
except Exception as e: | |
print(f"Error downloading image from URL {url}: {e}") | |
return None | |
async def search_by_url(self, url: str, folder_path: Optional[str] = None, k: int = 10) -> List[Dict]: | |
"""Search images by downloading and comparing an image from a URL""" | |
try: | |
print(f"\nSearching by image URL: {url}") | |
# Download the image from URL | |
image = await self.download_image_from_url(url) | |
if image is None: | |
return [] | |
# Use the existing search_by_image method | |
return await self.search_by_image(image, folder_path, k) | |
except Exception as e: | |
print(f"Error in URL search: {e}") | |
import traceback | |
traceback.print_exc() | |
return [] |