# image_indexer.py
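"""Index images into Qdrant using CLIP embeddings, with metadata in SQLite.

ImageIndexer loads a CLIP model in a background thread, scans registered
folders for images, stores file metadata through ImageDatabase, and upserts
L2-normalized embeddings into per-folder Qdrant collections (resolved via
FolderManager). Indexing progress is broadcast to connected WebSocket
clients. ImageEventHandler indexes files created while a folder is watched.
"""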
from pathlib import Path
from typing import List, Dict, Set, Optional
import torch
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import asyncio
from concurrent.futures import ThreadPoolExecutor
import threading
from qdrant_client.http.models import PointStruct
import uuid
from qdrant_singleton import QdrantClientSingleton, CURRENT_SCHEMA_VERSION
from fastapi import WebSocket
from enum import Enum
import qdrant_client
import time
import os
import traceback
from folder_manager import FolderManager
from image_database import ImageDatabase
class IndexingStatus(Enum):
IDLE = "idle"
INDEXING = "indexing"
MONITORING = "monitoring"
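# Status lifecycle: IDLE until the model is ready (or after a failed init),
# INDEXING while a folder scan is in progress, then MONITORING once the scan
# finishes and the folder is watched for new files.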
class ImageIndexer:
def __init__(self):
# Initialize folder manager and image database
self.folder_manager = FolderManager()
self.image_db = ImageDatabase()
# Initialize status tracking
self.status = IndexingStatus.IDLE
self.current_file: Optional[str] = None
self.total_files = 0
self.processed_files = 0
self.websocket_connections: Set[WebSocket] = set()
# Thread synchronization
self.collection_initialized = threading.Event()
self.model_initialized = threading.Event()
# Initialize Qdrant client
self.qdrant = QdrantClientSingleton.get_instance()
# Thread pool for background processing
self.executor = ThreadPoolExecutor(max_workers=4)
# Cache of indexed paths per collection
self.indexed_paths: Dict[str, Set[str]] = {}
# Model initialization flags
self.model = None
self.processor = None
self.device = None
        # Capture the server's event loop (if we are constructed inside one) so
        # background threads can schedule coroutines back onto it later
        try:
            self._loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None
        # Start model initialization in a separate thread
        threading.Thread(target=self._initialize_model_thread, daemon=True).start()
    def _load_indexed_paths(self, collection_name: str):
        """Load the set of already indexed paths from a collection"""
        try:
            paths: Set[str] = set()
            offset = None
            # scroll() returns one page at a time as (points, next_offset);
            # keep paging until next_offset is None so large collections are
            # loaded completely instead of silently truncated at one page
            while True:
                points, offset = self.qdrant.scroll(
                    collection_name=collection_name,
                    limit=10000,
                    offset=offset,
                    with_payload=True,
                    with_vectors=False
                )
                paths.update(point.payload["path"] for point in points if "path" in point.payload)
                if offset is None:
                    break
            self.indexed_paths[collection_name] = paths
        except Exception as e:
            print(f"Error loading indexed paths for collection {collection_name}: {e}")
            self.indexed_paths[collection_name] = set()
    async def broadcast_status(self):
        """Broadcast current status to all connected WebSocket clients"""
        status_data = {
            "status": self.status.value,
            "current_file": self.current_file,
            "total_files": self.total_files,
            "processed_files": self.processed_files,
            "progress_percentage": round(
                (self.processed_files / self.total_files * 100) if self.total_files > 0 else 0, 2
            )
        }
        # Iterate over a copy: removing a dead connection from the set while
        # iterating over it raises "Set changed size during iteration"
        for connection in list(self.websocket_connections):
            try:
                await connection.send_json(status_data)
            except Exception as e:
                print(f"Error broadcasting to WebSocket: {e}")
                self.websocket_connections.discard(connection)
async def add_websocket_connection(self, websocket: WebSocket):
"""Add a new WebSocket connection"""
await websocket.accept()
self.websocket_connections.add(websocket)
await self.broadcast_status()
    async def remove_websocket_connection(self, websocket: WebSocket):
        """Remove a WebSocket connection"""
        # discard() instead of remove(): the connection may already have been
        # dropped by broadcast_status after a failed send
        self.websocket_connections.discard(websocket)
async def add_folder(self, folder_path: str) -> Dict:
"""Add a new folder to index"""
folder_info = self.folder_manager.add_folder(folder_path)
# Start indexing the new folder
await self.index_folder(folder_path)
return folder_info
async def remove_folder(self, folder_path: str):
"""Remove a folder from indexing"""
# First remove from the folder manager
self.folder_manager.remove_folder(folder_path)
# Clean up SQLite database
folder_abs_path = str(Path(folder_path).absolute())
deleted_count = self.image_db.delete_images_by_folder(folder_abs_path)
print(f"Deleted {deleted_count} images from database for folder: {folder_path}")
async def index_folder(self, folder_path: str):
"""Index all images in a specific folder"""
        # Give the background model initialization a bounded chance to finish,
        # so folders added during startup are still indexed once the model loads
        if not self.model_initialized.is_set():
            print("Waiting for model initialization...")
        deadline = time.monotonic() + 120  # generous bound for a first model download
        while not self.model_initialized.is_set() and time.monotonic() < deadline:
            await asyncio.sleep(0.1)
        if not self.model_initialized.is_set() or not self.model or not self.processor:
            print("Model not initialized. Skipping indexing.")
            self.status = IndexingStatus.IDLE
            await self.broadcast_status()
            return
folder_path = Path(folder_path)
if not folder_path.exists():
print(f"Folder not found: {folder_path}")
return
collection_name = self.folder_manager.get_collection_for_path(folder_path)
if not collection_name:
print(f"No collection found for folder: {folder_path}")
return
print(f"Starting to index folder: {folder_path}")
self.status = IndexingStatus.INDEXING
self.processed_files = 0
self.current_file = None
await self.broadcast_status() # Broadcast initial status
# Load indexed paths for this collection if not already loaded
if collection_name not in self.indexed_paths:
self._load_indexed_paths(collection_name)
        # Use rglob for recursive scanning; is_file() guards against oddly
        # named directories such as "assets.png/"
        image_files = [
            f for f in folder_path.rglob("*")
            if f.is_file() and f.suffix.lower() in {".jpg", ".jpeg", ".png", ".gif"}
        ]
self.total_files = len(image_files)
print(f"Found {self.total_files} images to index")
await self.broadcast_status() # Broadcast after finding total files
try:
for i, image_file in enumerate(image_files, 1):
relative_path = str(image_file.relative_to(folder_path))
self.current_file = str(image_file)
self.processed_files = i - 1 # Update before processing
await self.broadcast_status() # Broadcast before processing each file
if relative_path not in self.indexed_paths[collection_name]:
print(f"Indexing image {i}/{self.total_files}: {image_file.name}")
await self.index_image(image_file, folder_path)
else:
print(f"Skipping already indexed image {i}/{self.total_files}: {image_file.name}")
self.processed_files = i # Update after processing
await self.broadcast_status() # Broadcast after processing each file
# Small delay to allow other tasks to run
await asyncio.sleep(0)
except Exception as e:
print(f"Error during indexing: {e}")
            traceback.print_exc()
finally:
# Update last indexed timestamp
self.folder_manager.update_last_indexed(str(folder_path))
# Reset status
self.status = IndexingStatus.MONITORING
self.current_file = None
await self.broadcast_status() # Final status broadcast
print("Finished indexing folder")
async def index_image(self, image_path: Path, root_folder: Path):
"""Index a single image"""
if not self.model_initialized.is_set() or not self.model or not self.processor:
print("Model not initialized. Skipping indexing image.")
return
try:
# Get the collection for this path
collection_name = self.folder_manager.get_collection_for_path(str(root_folder))
if not collection_name:
print(f"No collection found for image: {image_path}")
return
# Convert to relative path from root folder
try:
relative_path = str(image_path.relative_to(root_folder))
except ValueError:
print(f"Image {image_path} is not under root folder {root_folder}")
return
print(f"Indexing image: {relative_path}")
self.current_file = str(image_path)
await self.broadcast_status()
# Check if image already exists in database
existing_image_id = self.image_db.image_exists_by_path(relative_path, str(root_folder.absolute()))
if existing_image_id:
# Check if it exists in Qdrant with current schema version
existing_points = self.qdrant.scroll(
collection_name=collection_name,
scroll_filter=qdrant_client.http.models.Filter(
must=[
qdrant_client.http.models.FieldCondition(
key="image_id",
match={"value": existing_image_id}
),
qdrant_client.http.models.FieldCondition(
key="schema_version",
match={"value": CURRENT_SCHEMA_VERSION}
)
]
),
limit=1
)[0]
if existing_points:
print(f"Skipping {relative_path} - already indexed with current schema version")
return
# Store image in SQLite database first
image_id = self.image_db.store_image(image_path, root_folder)
if not image_id:
print(f"Failed to store image in database: {relative_path}")
return
            # Load and preprocess image for embedding; the context manager
            # closes the underlying file handle promptly
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
# Generate image embedding
with torch.no_grad():
image_features = self.model.get_image_features(**inputs)
            # L2-normalize so that dot product in Qdrant equals cosine similarity
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            embedding = image_features.cpu().numpy().flatten()
# Verify embedding is valid
if np.isnan(embedding).any() or np.isinf(embedding).any():
print(f"Warning: Invalid embedding generated for {relative_path}")
return
# Delete any old versions from Qdrant if they exist
self.qdrant.delete(
collection_name=collection_name,
points_selector=qdrant_client.http.models.FilterSelector(
filter=qdrant_client.http.models.Filter(
must=[
qdrant_client.http.models.FieldCondition(
key="path",
match={"value": relative_path}
)
]
)
)
)
# Store in Qdrant with image ID reference and minimal metadata
point_id = str(uuid.uuid4())
self.qdrant.upsert(
collection_name=collection_name,
points=[
PointStruct(
id=point_id,
vector=embedding.tolist(),
payload={
"image_id": image_id, # Reference to SQLite database
"path": relative_path, # Relative path from root folder
"root_folder": str(root_folder.absolute()), # Store root folder path
"schema_version": CURRENT_SCHEMA_VERSION,
"indexed_at": int(time.time())
}
)
]
)
# Update indexed paths cache
if collection_name not in self.indexed_paths:
self.indexed_paths[collection_name] = set()
self.indexed_paths[collection_name].add(relative_path)
print(f"Stored embedding in Qdrant for {relative_path} (Image ID: {image_id})")
except Exception as e:
print(f"Error indexing image {image_path}: {e}")
            traceback.print_exc()
finally:
# Don't reset current_file here as it's managed by index_folder
await self.broadcast_status()
def _initialize_model_thread(self):
"""Initialize model in a separate thread"""
try:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {self.device}")
print("Loading CLIP model and processor...")
            # Set environment variables to avoid tqdm threading issues
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "true"
# Load model first, then processor with explicit settings
self.model = CLIPModel.from_pretrained(
"openai/clip-vit-base-patch16",
cache_dir="/tmp/transformers_cache",
local_files_only=False
)
self.processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch16",
cache_dir="/tmp/transformers_cache",
use_fast=False, # Explicitly use slow processor to avoid compatibility issues
local_files_only=False
)
            # Move the model to the target device; fall back to to_empty() if
            # the weights are still meta tensors (e.g. deferred-loading paths)
            try:
                self.model = self.model.to(self.device)
            except NotImplementedError:
                # to_empty() allocates real storage for meta tensors
                self.model = self.model.to_empty(device=self.device)
self.model.eval() # Set to evaluation mode
self.model_initialized.set()
print("Model initialization complete")
except Exception as e:
print(f"Error initializing model: {e}")
print("Attempting fallback initialization...")
            try:
                # Simplest possible fallback with offline-first approach
                torch.hub.set_dir('/tmp/torch_cache')
# Try loading without any extra parameters first
try:
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
self.processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch16",
use_fast=False
)
except Exception:
# If that fails, try with cache directory
self.model = CLIPModel.from_pretrained(
"openai/clip-vit-base-patch16",
cache_dir="/tmp/transformers_cache"
)
self.processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch16",
cache_dir="/tmp/transformers_cache",
use_fast=False
)
# Handle meta tensor case in fallback too
try:
self.model = self.model.to(self.device)
except NotImplementedError:
self.model = self.model.to_empty(device=self.device)
self.model.eval()
self.model_initialized.set()
print("Fallback model initialization successful")
except Exception as e2:
print(f"Fallback also failed: {e2}")
print("Model initialization completely failed - indexing functionality will be disabled")
                traceback.print_exc()
                self.status = IndexingStatus.IDLE
                # This runs on a worker thread with no event loop of its own;
                # asyncio.run() here would touch WebSockets owned by the
                # server's loop, so hand the broadcast back to that loop
                if self._loop is not None:
                    asyncio.run_coroutine_threadsafe(self.broadcast_status(), self._loop)
async def get_all_images(self, folder_path: Optional[str] = None) -> List[Dict]:
"""Get all indexed images, optionally filtered by folder"""
try:
if folder_path:
# Get images from specific folder
results = self.image_db.get_images_by_folder(str(Path(folder_path).absolute()))
else:
# Get images from all folders
results = self.image_db.get_all_images()
# Convert to API format
api_results = []
for image_data in results:
api_results.append({
"id": image_data["id"],
"path": image_data["relative_path"],
"filename": image_data["filename"],
"root_folder": image_data["root_folder"],
"file_size": image_data["file_size"],
"width": image_data["width"],
"height": image_data["height"],
"created_at": image_data["created_at"]
})
return api_results
except Exception as e:
print(f"Error getting images: {e}")
            traceback.print_exc()
return []
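    # Illustrative only: a minimal sketch of how a text query could be matched
    # against the points this class writes, assuming the same normalized CLIP
    # space. The method name, signature, and result shape are assumptions, not
    # part of the existing API; the app's real search code may live elsewhere.
    async def search_images_sketch(self, query: str, collection_name: str, limit: int = 10) -> List[Dict]:
        """Embed a text query with CLIP and return the closest indexed images."""
        if not self.model_initialized.is_set() or not self.model or not self.processor:
            return []
        inputs = self.processor(text=[query], return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
        # Normalize exactly like the image embeddings so dot product == cosine
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        hits = self.qdrant.search(
            collection_name=collection_name,
            query_vector=text_features.cpu().numpy().flatten().tolist(),
            limit=limit
        )
        return [{"path": h.payload.get("path"), "score": h.score} for h in hits]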
class ImageEventHandler(FileSystemEventHandler):
    def __init__(self, indexer: ImageIndexer, root_folder: Path):
        self.indexer = indexer
        self.root_folder = root_folder
    def on_created(self, event):
        # Watchdog callbacks fire on the observer thread, which has no running
        # event loop; asyncio.create_task would raise RuntimeError here, so
        # schedule the coroutine onto the server's loop instead
        if event.is_directory:
            return
        if self.indexer._loop is not None:
            asyncio.run_coroutine_threadsafe(
                self.indexer.index_image(Path(event.src_path), self.root_folder),
                self.indexer._loop
            )
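# Illustrative only: a minimal sketch of how this module could be wired into a
# FastAPI app with a watchdog Observer. The app object, endpoint path, and
# watched folder are assumptions about the surrounding application, not code
# that exists in this repository.
#
#     from fastapi import FastAPI
#
#     app = FastAPI()
#     indexer = ImageIndexer()  # construct inside the running event loop
#
#     @app.websocket("/ws/status")
#     async def status_ws(websocket: WebSocket):
#         await indexer.add_websocket_connection(websocket)
#         try:
#             while True:
#                 await websocket.receive_text()
#         except Exception:
#             await indexer.remove_websocket_connection(websocket)
#
#     observer = Observer()
#     root = Path("/photos")  # hypothetical watched folder
#     observer.schedule(ImageEventHandler(indexer, root), str(root), recursive=True)
#     observer.start()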