# Spaces: Runtime error
import os | |
from pathlib import Path | |
from typing import List, Optional | |
import io | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, File, UploadFile, Request, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks | |
from fastapi.responses import HTMLResponse, FileResponse, StreamingResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
from PIL import Image | |
from image_indexer import ImageIndexer | |
from image_search import ImageSearch | |
from image_database import ImageDatabase | |
# Service objects used by every endpoint in this module.
indexer = ImageIndexer()

# The searcher is built without loading a model of its own; it is handed the
# indexer's model later, once that one has finished initializing.
searcher = ImageSearch(init_model=False)

# Indexer and searcher must operate on the very same folder manager instance.
searcher.folder_manager = indexer.folder_manager
# Wait for indexer model to initialize, then share it with searcher | |
import time | |
import threading | |
def wait_and_share_model():
    """Block until the indexer's model is ready, then hand it to the searcher.

    If the indexer exposes ``model_initialized``, its ``wait`` method is
    called with a 60 second timeout. Afterwards, when the indexer has a
    non-``None`` ``model``, the model, processor and device are copied onto
    the searcher and the searcher is flagged as initialized. Nothing is
    copied when the indexer has no model.
    """
    # Give the indexer at most a minute to finish loading.
    if hasattr(indexer, 'model_initialized'):
        indexer.model_initialized.wait(timeout=60)

    shared_model = getattr(indexer, 'model', None)
    if shared_model is None:
        # Indexer never produced a model; leave the searcher untouched.
        return

    print("Sharing model from indexer to searcher...")
    searcher.model = shared_model
    searcher.processor = indexer.processor
    searcher.device = indexer.device
    searcher.model_initialized = True
    print("Model sharing complete")
# Hand the model over on a daemon thread so importing this module does not
# block on model initialization.
threading.Thread(
    target=wait_and_share_model,
    daemon=True,
).start()

image_db = ImageDatabase()

# Lower-case file suffixes treated as images by the browse/serve endpoints.
image_extensions = [
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
]
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    The hook currently performs no startup or shutdown work: it yields
    immediately. The indexer, searcher and database are created at module
    import time, not here.

    Args:
        _: The FastAPI application instance (unused).
    """
    # FastAPI expects an async context manager here, hence the decorator.
    yield
app = FastAPI(
    title="Visual Product Search",
    lifespan=lifespan,
)

# HTML templates are loaded from ./templates; files under ./static are
# served at the /static URL prefix.
templates = Jinja2Templates(directory="templates")
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
async def home(request: Request):
    """Render ``index.html`` with the indexer's current status and folders.

    The template context carries the request, an ``initial_status`` dict
    (status value, current file, totals and a progress percentage rounded
    to two decimals) and the list of folders known to the folder manager.
    """
    folders = indexer.folder_manager.get_all_folders()

    total = indexer.total_files
    processed = indexer.processed_files
    # Avoid dividing by zero before anything has been queued for indexing.
    if total > 0:
        progress = processed / total * 100
    else:
        progress = 0

    initial_status = {
        "status": indexer.status.value,
        "current_file": indexer.current_file,
        "total_files": total,
        "processed_files": processed,
        "progress_percentage": round(progress, 2),
    }
    context = {
        "request": request,
        "initial_status": initial_status,
        "folders": folders,
    }
    return templates.TemplateResponse("index.html", context)
async def health_check():
    """Report service health for monitoring.

    Returns a dict with a fixed status and service name plus the searcher's
    ``device`` attribute, or ``"unknown"`` when the searcher has none yet.
    """
    device = getattr(searcher, 'device', "unknown")
    return {
        "status": "healthy",
        "service": "Visual Image Search",
        "device": device,
    }
async def add_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Register a folder and schedule it for indexing.

    The folder is first added to the folder manager, then
    ``indexer.index_folder`` is queued as a background task. The value
    returned by the folder manager is sent back to the caller.

    Raises:
        HTTPException: 400 with the error text if registration or
            scheduling fails.
    """
    try:
        # Registration has to happen before indexing is queued.
        registered = indexer.folder_manager.add_folder(folder_path)
        background_tasks.add_task(indexer.index_folder, folder_path)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return registered
async def remove_folder(folder_path: str):
    """Remove a folder from the indexer.

    Returns ``{"status": "success"}`` once ``indexer.remove_folder`` has
    completed.

    Raises:
        HTTPException: 400 with the error text if removal fails.
    """
    try:
        await indexer.remove_folder(folder_path)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "success"}
async def list_folders():
    """Return every folder currently known to the indexer's folder manager."""
    folder_manager = indexer.folder_manager
    return folder_manager.get_all_folders()
async def search_by_text(query: str, folder: Optional[str] = None) -> List[dict]:
    """Run a text query through the searcher.

    Args:
        query: Free-text search string.
        folder: Optional folder to restrict the search to.

    Returns:
        Whatever ``searcher.search_by_text`` produces for the query.
    """
    return await searcher.search_by_text(query, folder)
async def search_by_image(
    file: UploadFile = File(...),
    folder: Optional[str] = None,
) -> List[dict]:
    """Find images similar to an uploaded one.

    The upload is read fully into memory, opened with Pillow and passed to
    ``searcher.search_by_image`` together with the optional folder filter.
    """
    payload = await file.read()
    query_image = Image.open(io.BytesIO(payload))
    return await searcher.search_by_image(query_image, folder)
async def search_by_url(
    url: str,
    folder: Optional[str] = None,
) -> List[dict]:
    """Find images similar to the one located at ``url``.

    The URL and the optional folder filter are forwarded unchanged to
    ``searcher.search_by_url``.
    """
    return await searcher.search_by_url(url, folder)
async def list_images(folder: Optional[str] = None) -> List[dict]:
    """Return the indexed images, optionally restricted to one folder."""
    images = await indexer.get_all_images(folder)
    return images
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time indexing status updates.

    The socket is registered with the indexer, after which incoming text
    frames are read and discarded until the loop ends. The socket is
    unregistered on every exit path, not only on a clean disconnect.

    Args:
        websocket: The client connection handed over by FastAPI.
    """
    await indexer.add_websocket_connection(websocket)
    try:
        # Incoming messages are ignored; receiving is what lets a client
        # disconnect surface as WebSocketDisconnect.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in ``finally``.
        pass
    finally:
        # Always unregister so a failed socket is never left registered.
        await indexer.remove_websocket_connection(websocket)
async def serve_image(image_id: str):
    """Serve an image from the database by ID.

    Args:
        image_id: Identifier passed to ``image_db.get_image``.

    Returns:
        A ``StreamingResponse`` over the stored bytes. The media type is
        built from the stored file extension, and the response carries a
        24-hour cache header and an inline content disposition.

    Raises:
        HTTPException: 404 when the database returns nothing for the ID,
            500 for any other failure.
    """
    try:
        image_data = image_db.get_image(image_id)
        if not image_data:
            raise HTTPException(status_code=404, detail="Image not found")

        extension = image_data['file_extension'].lstrip('.')
        return StreamingResponse(
            io.BytesIO(image_data["image_data"]),
            media_type=f"image/{extension}",
            headers={
                "Cache-Control": "max-age=86400",  # Cache for 24 hours
                "Content-Disposition": f"inline; filename=\"{image_data['filename']}\""
            }
        )
    except HTTPException:
        # Pass deliberate HTTP errors (the 404 above) through unchanged;
        # the generic handler below would otherwise report them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def serve_thumbnail_by_id(image_id: str):
    """Serve a thumbnail from the database by ID.

    Args:
        image_id: Identifier passed to ``image_db.get_thumbnail``.

    Returns:
        A ``StreamingResponse`` over the thumbnail bytes, sent as
        ``image/jpeg`` with a 24-hour cache header.

    Raises:
        HTTPException: 404 when the database returns nothing for the ID,
            500 for any other failure.
    """
    try:
        thumbnail_data = image_db.get_thumbnail(image_id)
        if not thumbnail_data:
            raise HTTPException(status_code=404, detail="Thumbnail not found")

        return StreamingResponse(
            io.BytesIO(thumbnail_data),
            media_type="image/jpeg",
            headers={"Cache-Control": "max-age=86400"}  # Cache for 24 hours
        )
    except HTTPException:
        # Keep the 404 a 404 instead of letting the generic handler
        # below turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_database_stats():
    """Return the statistics reported by the image database.

    Raises:
        HTTPException: 500 with the error text if the database call fails.
    """
    try:
        stats = image_db.get_database_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return stats
async def debug_collections():
    """Debug endpoint comparing Qdrant collections with registered folders.

    Returns the collection names, the folder manager's folders and the size
    of each list. Any failure is reported as ``{"error": <message>}``
    rather than raised.
    """
    try:
        # Collections as seen by the indexer's Qdrant client.
        collections = indexer.qdrant.get_collections().collections
        collection_names = [collection.name for collection in collections]

        # Folders as seen by the folder manager.
        folders = indexer.folder_manager.get_all_folders()

        return {
            "qdrant_collections": collection_names,
            "folder_manager_folders": folders,
            "collections_count": len(collections),
            "folders_count": len(folders),
        }
    except Exception as e:
        return {"error": str(e)}
async def debug_folder_managers():
    """Debug endpoint showing whether indexer and searcher share one folder manager.

    Reports the ``id()`` of each component's folder manager, whether the
    two are the same object, and the folder list each one returns.
    """
    indexer_manager = indexer.folder_manager
    searcher_manager = searcher.folder_manager
    return {
        "indexer_folder_manager_id": id(indexer_manager),
        "searcher_folder_manager_id": id(searcher_manager),
        "are_same_instance": indexer_manager is searcher_manager,
        "indexer_folders": indexer_manager.get_all_folders(),
        "searcher_folders": searcher_manager.get_all_folders(),
    }
# Keep the old endpoints for backward compatibility but mark as deprecated | |
async def serve_thumbnail(folder_path: str, file_path: str):
    """Serve resized image thumbnails (DEPRECATED - use /thumbnail/{image_id} instead).

    Args:
        folder_path: Folder that must be known to the folder manager.
        file_path: Path of the image relative to ``folder_path``.

    Returns:
        A ``StreamingResponse`` with a JPEG no larger than 200x200
        (aspect ratio preserved) and a one-hour cache header.

    Raises:
        HTTPException: 404 when the folder is not indexed, the file does
            not exist, or ``file_path`` points outside the folder; 400
            when the suffix is not in ``image_extensions``; 500 for any
            other failure.
    """
    try:
        # Get folder info to verify it's an indexed folder
        folder_info = indexer.folder_manager.get_folder_info(folder_path)
        if not folder_info:
            raise HTTPException(status_code=404, detail="Folder not found")

        # Normalize both paths so ".." segments or an absolute file_path
        # cannot reach files outside the indexed folder.
        base_dir = Path(os.path.abspath(folder_path))
        full_path = Path(os.path.abspath(base_dir / file_path))
        if base_dir not in full_path.parents:
            raise HTTPException(status_code=404, detail="File not found")

        if not full_path.exists():
            raise HTTPException(status_code=404, detail="File not found")

        # Only serve image files
        if full_path.suffix.lower() not in image_extensions:
            raise HTTPException(status_code=400, detail="Invalid file type")

        img_byte_arr = io.BytesIO()
        # The context manager closes the source file once encoding is done.
        with Image.open(full_path) as img:
            img.thumbnail((200, 200))  # Resize maintaining aspect ratio
            # JPEG cannot store alpha or palette images (typical for
            # PNG/GIF), so convert those to RGB before saving.
            thumb = img if img.mode in ("RGB", "L") else img.convert("RGB")
            thumb.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)

        return StreamingResponse(
            img_byte_arr,
            media_type="image/jpeg",
            headers={"Cache-Control": "max-age=3600"},  # Cache for 1 hour
        )
    except HTTPException:
        # Preserve the specific 404/400 responses raised above instead of
        # letting the generic handler below report them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def serve_file(folder_path: str, file_path: str):
    """Serve files from indexed folders (DEPRECATED - use /image/{image_id} instead).

    Args:
        folder_path: Folder that must be known to the folder manager.
        file_path: Path of the image relative to ``folder_path``.

    Returns:
        A ``FileResponse`` for the requested image file.

    Raises:
        HTTPException: 404 when the folder is not indexed, the file does
            not exist, or ``file_path`` points outside the folder; 400
            when the suffix is not in ``image_extensions``; 500 for any
            other failure.
    """
    try:
        # Get folder info to verify it's an indexed folder
        folder_info = indexer.folder_manager.get_folder_info(folder_path)
        if not folder_info:
            raise HTTPException(status_code=404, detail="Folder not found")

        # Normalize both paths so ".." segments or an absolute file_path
        # cannot reach files outside the indexed folder.
        base_dir = Path(os.path.abspath(folder_path))
        full_path = Path(os.path.abspath(base_dir / file_path))
        if base_dir not in full_path.parents:
            raise HTTPException(status_code=404, detail="File not found")

        if not full_path.exists():
            raise HTTPException(status_code=404, detail="File not found")

        # Only serve image files
        if full_path.suffix.lower() not in image_extensions:
            raise HTTPException(status_code=400, detail="Invalid file type")

        return FileResponse(full_path)
    except HTTPException:
        # Preserve the specific 404/400 responses raised above instead of
        # letting the generic handler below report them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def get_windows_drives():
    """List the drive roots reported by ``GetLogicalDrives``.

    Bit ``n`` of the returned mask corresponds to the n-th letter of the
    alphabet (bit 0 is ``A``). Each set bit yields an entry such as
    ``"C:\\"``, in alphabetical order. Relies on ``ctypes.windll``.
    """
    from ctypes import windll

    drive_mask = windll.kernel32.GetLogicalDrives()
    return [
        chr(ord("A") + bit) + ":\\"
        for bit in range(26)  # A-Z
        if drive_mask & (1 << bit)
    ]
def get_directory_item(item):
    """Describe one directory entry for the folder browser.

    Args:
        item: A path object for the entry.

    Returns:
        A dict with ``name``, absolute ``path``, ``type`` (``"directory"``
        or ``"file"``) and ``size`` (bytes for files, ``None`` for
        directories). ``None`` is returned for files whose suffix is not
        in ``image_extensions`` and for entries that raise while being
        inspected.
    """
    try:
        if item.is_dir():
            kind, size = "directory", None
        elif item.suffix.lower() in image_extensions:
            kind, size = "file", item.stat().st_size
        else:
            # Neither a directory nor an image: not shown in the browser.
            return None
        return {
            "name": item.name,
            "path": str(item.absolute()),
            "type": kind,
            "size": size,
        }
    except Exception:
        # Best effort: entries that cannot be inspected are left out.
        return None
def get_directory_contents(path: str):
    """List the browsable entries of a directory.

    Args:
        path: Directory to list.

    Returns:
        A dict with ``current_path`` (absolute), ``parent_path`` (``None``
        when the path is its own parent) and ``contents``: the entries
        accepted by ``get_directory_item``, directories first, then by
        case-insensitive name. On failure a dict with a single ``error``
        key is returned instead.
    """
    try:
        target = Path(path)
        if not target.exists():
            return {"error": "Path does not exist"}

        parent_path = None if target.parent == target else str(target.parent)

        entries = []
        for child in target.iterdir():
            described = get_directory_item(child)
            if described is not None:
                entries.append(described)
        # Directories sort ahead of files; names compare case-insensitively.
        entries.sort(key=lambda entry: (entry["type"] != "directory", entry["name"].lower()))

        return {
            "current_path": str(target.absolute()),
            "parent_path": parent_path,
            "contents": entries,
        }
    except Exception as e:
        return {"error": str(e)}
async def browse_folders():
    """Return the starting point for the folder browser.

    On Windows (``os.name == "nt"``) this is ``{"drives": [...]}``; on
    every other platform it is the listing of ``/``.
    """
    if os.name != "nt":
        return get_directory_contents("/")  # Unix-like
    return {"drives": get_windows_drives()}  # Windows
async def browse_path(path: str):
    """Browse a specific path.

    Args:
        path: Directory to list.

    Returns:
        A dict with ``current_path`` (absolute), ``parent_path`` (``None``
        when the path is its own parent) and ``contents``: sub-directories
        and files whose suffix is in ``image_extensions``, directories
        first, then by case-insensitive name.

    Raises:
        HTTPException: 404 when the path does not exist, 500 for any
            other failure.
    """
    try:
        path_obj = Path(path)
        if not path_obj.exists():
            raise HTTPException(status_code=404, detail="Path not found")

        # Get parent directory for navigation
        parent = str(path_obj.parent) if path_obj.parent != path_obj else None

        # List directories and files
        contents = []
        for item in path_obj.iterdir():
            try:
                is_dir = item.is_dir()
                if is_dir or item.suffix.lower() in image_extensions:
                    contents.append({
                        "name": item.name,
                        "path": str(item.absolute()),
                        "type": "directory" if is_dir else "file",
                        "size": item.stat().st_size if not is_dir else None
                    })
            except Exception:
                continue  # Skip items we can't access

        contents.sort(key=lambda x: (x["type"] != "directory", x["name"].lower()))
        return {
            "current_path": str(path_obj.absolute()),
            "parent_path": parent,
            "contents": contents
        }
    except HTTPException:
        # Keep the 404 a 404 instead of letting the generic handler
        # below turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def upload_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Register a folder for indexing (alternative endpoint name).

    Adds the folder to the folder manager, queues ``indexer.index_folder``
    as a background task and returns the folder manager's result.

    Raises:
        HTTPException: 400 with the error text if either step fails.
    """
    try:
        # The folder has to be registered before indexing is queued.
        folder_info = indexer.folder_manager.add_folder(folder_path)
        background_tasks.add_task(indexer.index_folder, folder_path)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    else:
        return folder_info
async def create_demo_folder(background_tasks: BackgroundTasks):
    """Create a demo folder with sample images for testing.

    Downloads up to three sample images into ``/tmp/demo_images`` (files
    already present are not downloaded again), registers the folder with
    the folder manager and queues it for indexing.

    Args:
        background_tasks: FastAPI task queue used to schedule indexing.

    Returns:
        A dict with ``status``, a ``message`` stating how many sample
        images are actually present in the folder, and ``folder_info``
        as returned by the folder manager.

    Raises:
        HTTPException: 500 with the error text if folder creation,
            registration or scheduling fails. Individual download
            failures are logged and do not abort the request.
    """
    try:
        import urllib.request

        # Create demo folder
        demo_path = Path("/tmp/demo_images")
        demo_path.mkdir(exist_ok=True)

        # Sample image URLs (small images for demo)
        sample_images = [
            ("https://picsum.photos/300/200?random=1", "demo1.jpg"),
            ("https://picsum.photos/300/200?random=2", "demo2.jpg"),
            ("https://picsum.photos/300/200?random=3", "demo3.jpg"),
        ]

        # Download sample images, counting the ones that end up on disk.
        # NOTE(review): urlretrieve is a blocking call inside an async
        # handler; consider moving it to a worker thread.
        available = 0
        for url, filename in sample_images:
            file_path = demo_path / filename
            if file_path.exists():
                available += 1
                continue
            try:
                urllib.request.urlretrieve(url, file_path)
                available += 1
            except Exception as e:
                # Remove any partial download so the next call retries it
                # instead of skipping it because the file "exists".
                try:
                    file_path.unlink()
                except OSError:
                    pass
                print(f"Could not download {filename}: {e}")

        # Add folder for indexing
        folder_info = indexer.folder_manager.add_folder(str(demo_path))

        # Start indexing in the background
        background_tasks.add_task(indexer.index_folder, str(demo_path))

        return {
            "status": "success",
            "message": f"Created demo folder with {available} images",
            "folder_info": folder_info
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces serves on 7860; the PORT environment variable
    # overrides it when set.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,
    )