# Captured from a Hugging Face Spaces file view (Space status at capture time: "Runtime error").
# File size: 16,046 bytes; revisions shown: 60444f3, 55f2687, 5146b66.
import os
from pathlib import Path
from typing import List, Optional
import io
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, Request, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks
from fastapi.responses import HTMLResponse, FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from image_indexer import ImageIndexer
from image_search import ImageSearch
from image_database import ImageDatabase
# Initialize the image indexer and searcher at import time (the image
# database is created a little further down). Importing this module therefore
# constructs both objects.
indexer = ImageIndexer()
searcher = ImageSearch(init_model=False) # Don't init model, will share from indexer
# Share the folder manager instance between indexer and searcher so both see
# the same set of folders (checked by the /debug/folder-managers endpoint).
searcher.folder_manager = indexer.folder_manager
# Wait for indexer model to initialize, then share it with searcher
import time
import threading


def wait_and_share_model(timeout: float = 60.0):
    """Wait for the indexer's model to load, then reuse it in the searcher.

    The searcher is created with ``init_model=False`` so that only one copy of
    the model is loaded. This function copies the indexer's ``model``,
    ``processor`` and ``device`` onto the searcher and marks it initialized.
    It runs in a daemon thread (started right after this definition) so that
    module import is not blocked while the model loads.

    Args:
        timeout: Maximum number of seconds to block on
            ``indexer.model_initialized`` before checking whether a model is
            available. Defaults to 60 (the previously hard-coded value).
    """
    # Wait for the indexer's readiness signal, if it exposes one.
    if hasattr(indexer, 'model_initialized'):
        indexer.model_initialized.wait(timeout=timeout)
    model = getattr(indexer, 'model', None)
    if model is None:
        # Previously this case produced no output at all, leaving the
        # searcher without a model and nothing in the logs to explain why.
        print("Indexer model not available; model was NOT shared with searcher")
        return
    print("Sharing model from indexer to searcher...")
    searcher.model = model
    searcher.processor = indexer.processor
    searcher.device = indexer.device
    searcher.model_initialized = True
    print("Model sharing complete")


# Start model sharing in background
threading.Thread(target=wait_and_share_model, daemon=True, name="share-model").start()
# Database handle used by the /image/{image_id}, /thumbnail/{image_id} and
# /stats endpoints.
image_db = ImageDatabase()
# Lower-case file suffixes treated as images by the folder browser and the
# deprecated file-serving endpoints (compared against Path.suffix.lower()).
# NOTE(review): presumably meant to match what ImageIndexer scans — confirm
# the two lists stay in sync.
image_extensions = [".jpg", ".jpeg", ".png", ".gif"]
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook.

    Performs no startup or shutdown work. The indexer, searcher and database
    are all created at module import time, so control is handed straight back
    to FastAPI.
    """
    yield


# Server-rendered page templates, the application object, and static assets.
templates = Jinja2Templates(directory="templates")
app = FastAPI(title="Visual Product Search", lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render index.html with the indexed folders and an indexing-status snapshot."""
    folder_list = indexer.folder_manager.get_all_folders()
    total = indexer.total_files
    done = indexer.processed_files
    # Avoid dividing by zero before any files have been discovered.
    percent = round(done / total * 100 if total > 0 else 0, 2)
    status_snapshot = {
        "status": indexer.status.value,
        "current_file": indexer.current_file,
        "total_files": total,
        "processed_files": done,
        "progress_percentage": percent,
    }
    context = {
        "request": request,
        "initial_status": status_snapshot,
        "folders": folder_list,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/health")
async def health_check():
    """Health check endpoint for monitoring.

    Returns:
        A JSON-serializable dict with ``status``, ``service`` and ``device``.
        ``device`` falls back to ``"unknown"`` when the searcher has no
        ``device`` attribute.
    """
    device = getattr(searcher, "device", "unknown")
    # The device's concrete type is not visible here. It may be an object
    # such as a torch.device, which FastAPI cannot JSON-encode and which would
    # turn the health check itself into a 500. Strings and None pass through
    # unchanged.
    if device is not None and not isinstance(device, str):
        device = str(device)
    return {
        "status": "healthy",
        "service": "Visual Image Search",
        "device": device,
    }
@app.post("/folders")
async def add_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Register *folder_path* and schedule it for background indexing.

    Registration with the folder manager happens synchronously (it creates the
    folder's collection). ``indexer.index_folder`` is queued as a background
    task. Any failure is reported as HTTP 400 carrying the error text.
    """
    try:
        registered = indexer.folder_manager.add_folder(folder_path)
        background_tasks.add_task(indexer.index_folder, folder_path)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return registered
@app.delete("/folders/{folder_path:path}")
async def remove_folder(folder_path: str):
    """Stop indexing *folder_path*; any failure is reported as HTTP 400."""
    try:
        await indexer.remove_folder(folder_path)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"status": "success"}
@app.get("/folders")
async def list_folders():
    """Return every folder known to the indexer's folder manager."""
    known_folders = indexer.folder_manager.get_all_folders()
    return known_folders
@app.get("/search/text")
async def search_by_text(query: str, folder: Optional[str] = None) -> List[dict]:
    """Run a free-text search, optionally restricted to one indexed folder."""
    return await searcher.search_by_text(query, folder)
@app.post("/search/image")
async def search_by_image(
    file: UploadFile = File(...),
    folder: Optional[str] = None
) -> List[dict]:
    """Search images by uploading a similar image, optionally filtered by folder.

    Raises:
        HTTPException: 400 if the uploaded bytes cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy and only reads the header; force a full decode
        # here so empty, corrupt or truncated uploads are rejected as client
        # errors rather than failing later inside the searcher as a 500.
        image.load()
    except (OSError, ValueError) as e:
        # PIL's UnidentifiedImageError subclasses OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e
    results = await searcher.search_by_image(image, folder)
    return results
@app.get("/search/url")
async def search_by_url(
    url: str,
    folder: Optional[str] = None
) -> List[dict]:
    """Search using the image found at *url*, optionally limited to *folder*."""
    return await searcher.search_by_url(url, folder)
@app.get("/images")
async def list_images(folder: Optional[str] = None) -> List[dict]:
    """Return the indexer's image listing, optionally restricted to *folder*."""
    images = await indexer.get_all_images(folder)
    return images
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time indexing status updates.

    Registers the socket with the indexer, then drains incoming text frames
    until the client goes away. The socket is always unregistered on exit.
    """
    # NOTE(review): presumably add_websocket_connection also accepts the
    # socket — confirm in ImageIndexer.
    await indexer.add_websocket_connection(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Unregister on every exit path. Before, only WebSocketDisconnect
        # cleaned up, so any other error (e.g. a KeyError from receive_text()
        # on a binary frame) left a dead socket registered with the indexer.
        await indexer.remove_websocket_connection(websocket)
def _inline_content_disposition(filename) -> str:
    """Build an ``inline`` Content-Disposition value that is safe for any filename.

    Starlette encodes header values as Latin-1, so a raw filename containing
    characters outside Latin-1 (e.g. CJK) would raise while the response is
    built. A ``"`` or line break in the name would also corrupt the header.
    The plain ``filename`` parameter therefore carries an ASCII-only fallback,
    and the exact name is sent percent-encoded in ``filename*``
    (RFC 6266 / RFC 8187).
    """
    from urllib.parse import quote

    name = str(filename)
    fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in name
    )
    return f"inline; filename=\"{fallback}\"; filename*=UTF-8''{quote(name, safe='')}"


@app.get("/image/{image_id}")
async def serve_image(image_id: str):
    """Serve an image from the database by ID.

    Raises:
        HTTPException: 404 if no image has this ID; 500 on any other failure.
    """
    try:
        image_data = image_db.get_image(image_id)
        if not image_data:
            raise HTTPException(status_code=404, detail="Image not found")
        extension = image_data["file_extension"].lstrip(".").lower()
        # "image/jpg" is not a registered MIME type; the correct one is image/jpeg.
        subtype = "jpeg" if extension == "jpg" else extension
        return StreamingResponse(
            io.BytesIO(image_data["image_data"]),
            media_type=f"image/{subtype}",
            headers={
                "Cache-Control": "max-age=86400",  # Cache for 24 hours
                "Content-Disposition": _inline_content_disposition(image_data["filename"]),
            },
        )
    except HTTPException:
        # Re-raise deliberate HTTP errors (the 404 above) unchanged. Otherwise
        # the generic handler below would turn them into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/thumbnail/{image_id}")
async def serve_thumbnail_by_id(image_id: str):
    """Serve a stored JPEG thumbnail from the database by image ID.

    Raises:
        HTTPException: 404 if no thumbnail exists for this ID; 500 on any
            other failure.
    """
    try:
        thumbnail_data = image_db.get_thumbnail(image_id)
        if not thumbnail_data:
            raise HTTPException(status_code=404, detail="Thumbnail not found")
        return StreamingResponse(
            io.BytesIO(thumbnail_data),
            media_type="image/jpeg",
            headers={"Cache-Control": "max-age=86400"}  # Cache for 24 hours
        )
    except HTTPException:
        # Keep the 404 a 404. HTTPException is an Exception subclass, so
        # without this clause it was rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/stats")
async def get_database_stats():
    """Return the statistics reported by the image database.

    Raises:
        HTTPException: 500 with the underlying error text if the database
            call fails. The original exception is chained as ``__cause__`` so
            it stays visible in server tracebacks.
    """
    try:
        return image_db.get_database_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/debug/collections")
async def debug_collections():
    """Compare the indexer's Qdrant collections with the folder manager's folders.

    Any failure is returned as ``{"error": <message>}`` (HTTP 200) instead of
    being raised.
    """
    try:
        collection_names = [
            entry.name for entry in indexer.qdrant.get_collections().collections
        ]
        known_folders = indexer.folder_manager.get_all_folders()
        return {
            "qdrant_collections": collection_names,
            "folder_manager_folders": known_folders,
            "collections_count": len(collection_names),
            "folders_count": len(known_folders),
        }
    except Exception as e:
        return {"error": str(e)}
@app.get("/debug/folder-managers")
async def debug_folder_managers():
    """Report whether the indexer and searcher share one folder-manager object."""
    indexer_fm = indexer.folder_manager
    searcher_fm = searcher.folder_manager
    return {
        "indexer_folder_manager_id": id(indexer_fm),
        "searcher_folder_manager_id": id(searcher_fm),
        "are_same_instance": indexer_fm is searcher_fm,
        "indexer_folders": indexer_fm.get_all_folders(),
        "searcher_folders": searcher_fm.get_all_folders(),
    }
# Keep the old endpoints for backward compatibility but mark as deprecated


def _resolve_indexed_image(folder_path: str, file_path: str) -> Path:
    """Map a (folder, relative file) pair from the URL to an image file on disk.

    Raises:
        HTTPException: 404 if the folder is not indexed or the file does not
            exist; 403 if *file_path* escapes *folder_path* (via ``..``
            segments or an absolute path); 400 if the file is not an image.
    """
    # Get folder info to verify it's an indexed folder
    folder_info = indexer.folder_manager.get_folder_info(folder_path)
    if not folder_info:
        raise HTTPException(status_code=404, detail="Folder not found")
    root = os.path.abspath(folder_path)
    # os.path.join drops `root` when file_path is absolute, and abspath
    # collapses "..". The commonpath test therefore rejects anything outside
    # the indexed folder. It is a lexical check, so symlinks inside the
    # folder keep working.
    candidate = os.path.abspath(os.path.join(root, file_path))
    try:
        inside = os.path.commonpath([root, candidate]) == root
    except ValueError:  # e.g. paths on different drives on Windows
        inside = False
    if not inside:
        raise HTTPException(status_code=403, detail="Access denied")
    full_path = Path(candidate)
    if not full_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    # Only serve image files
    if full_path.suffix.lower() not in image_extensions:
        raise HTTPException(status_code=400, detail="Invalid file type")
    return full_path


@app.get("/thumbnail/{folder_path:path}/{file_path:path}")
async def serve_thumbnail(folder_path: str, file_path: str):
    """Serve resized image thumbnails (DEPRECATED - use /thumbnail/{image_id} instead)"""
    try:
        full_path = _resolve_indexed_image(folder_path, file_path)
        img_byte_arr = io.BytesIO()
        # The context manager closes the underlying file handle promptly.
        with Image.open(full_path) as img:
            img.thumbnail((200, 200))  # Resize maintaining aspect ratio
            # JPEG cannot store alpha or palette modes (RGBA PNGs, P-mode
            # GIFs), so saving them directly raises OSError. Convert first.
            jpeg_ready = img if img.mode in ("RGB", "L") else img.convert("RGB")
            jpeg_ready.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/jpeg", headers={"Cache-Control": "max-age=3600"})  # Cache for 1 hour
    except HTTPException:
        # Preserve 4xx responses from validation instead of converting them to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/files/{folder_path:path}/{file_path:path}")
async def serve_file(folder_path: str, file_path: str):
    """Serve files from indexed folders (DEPRECATED - use /image/{image_id} instead)"""
    try:
        full_path = _resolve_indexed_image(folder_path, file_path)
        return FileResponse(full_path)
    except HTTPException:
        # Preserve 4xx responses from validation instead of converting them to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def get_windows_drives():
    """Return the root (e.g. ``"C:\\"``) of every logical drive on this Windows host.

    ``GetLogicalDrives`` returns a bitmask in which bit 0 stands for A:,
    bit 1 for B:, and so on through bit 25 for Z:.
    """
    from ctypes import windll

    mask = windll.kernel32.GetLogicalDrives()
    return [
        chr(ord("A") + offset) + ":\\"
        for offset in range(26)
        if (mask >> offset) & 1
    ]
def get_directory_item(item):
    """Describe *item* for the folder browser, or return None to hide it.

    Directories are always listed. Files are listed only when their suffix is
    in ``image_extensions``. Any error while inspecting the entry (for example
    a permission error) hides the entry instead of failing the listing.
    """
    try:
        if item.is_dir():
            kind, size = "directory", None
        elif item.suffix.lower() in image_extensions:
            kind, size = "file", item.stat().st_size
        else:
            return None
        return {
            "name": item.name,
            "path": str(item.absolute()),
            "type": kind,
            "size": size,
        }
    except Exception:
        return None
def get_directory_contents(path: str):
    """List a directory's subdirectories and image files for the browser UI.

    Returns a dict with ``current_path``, ``parent_path`` (None at a
    filesystem root) and ``contents``. Contents list directories first, then
    files, each group sorted case-insensitively by name. Problems are
    reported as ``{"error": <message>}`` rather than raised.
    """
    try:
        target = Path(path)
        if not target.exists():
            return {"error": "Path does not exist"}
        up = target.parent
        parent_path = None if up == target else str(up)
        entries = []
        for child in target.iterdir():
            described = get_directory_item(child)
            if described is not None:
                entries.append(described)
        entries.sort(key=lambda e: (e["type"] != "directory", e["name"].lower()))
        return {
            "current_path": str(target.absolute()),
            "parent_path": parent_path,
            "contents": entries,
        }
    except Exception as e:
        return {"error": str(e)}
@app.get("/browse")
async def browse_folders():
    """Entry point of the folder browser.

    Returns the drive list on Windows and the listing of ``/`` elsewhere.
    """
    if os.name != "nt":
        return get_directory_contents("/")  # Unix-like
    return {"drives": get_windows_drives()}
@app.get("/browse/{path:path}")
async def browse_path(path: str):
    """List the subdirectories and image files of *path* for the folder browser.

    Returns:
        A dict with ``current_path``, ``parent_path`` (None at a filesystem
        root) and ``contents``. Contents list directories first, then files,
        each group sorted case-insensitively by name.

    Raises:
        HTTPException: 404 if the path does not exist, 400 if it is not a
            directory, 500 for any other failure.
    """
    try:
        path_obj = Path(path)
        if not path_obj.exists():
            raise HTTPException(status_code=404, detail="Path not found")
        if not path_obj.is_dir():
            # iterdir() on a file would raise NotADirectoryError, which is a
            # client error rather than a server failure.
            raise HTTPException(status_code=400, detail="Path is not a directory")
        # Get parent directory for navigation
        parent = str(path_obj.parent) if path_obj.parent != path_obj else None
        # Reuse the shared per-entry helper (skips non-images and entries we
        # can't access) instead of duplicating its logic here.
        contents = [
            entry
            for entry in (get_directory_item(child) for child in path_obj.iterdir())
            if entry is not None
        ]
        return {
            "current_path": str(path_obj.absolute()),
            "parent_path": parent,
            "contents": sorted(contents, key=lambda x: (x["type"] != "directory", x["name"].lower()))
        }
    except HTTPException:
        # Before, the generic handler below turned the 404 into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/upload")
async def upload_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Alias of ``POST /folders``: register *folder_path* and index it in the background.

    Delegates to ``add_folder``, whose body was identical to this endpoint's,
    so both routes behave the same, including HTTP 400 on failure.
    """
    return await add_folder(folder_path, background_tasks)
def _download_demo_images(demo_path: Path, sample_images) -> int:
    """Download each ``(url, filename)`` pair into *demo_path* unless already present.

    Blocking; meant to run in a worker thread. Individual failures are
    printed and skipped so that one bad URL does not abort the demo.

    Returns:
        The number of sample images present in *demo_path* afterwards.
    """
    import urllib.request

    available = 0
    for url, filename in sample_images:
        file_path = demo_path / filename
        if file_path.exists():
            available += 1
            continue
        partial_path = file_path.with_name(file_path.name + ".part")
        try:
            # An explicit timeout: urlretrieve had none, so a stalled server
            # could hang the request indefinitely.
            with urllib.request.urlopen(url, timeout=30) as response:
                data = response.read()
            # Write under a temporary name, then rename. A failed or partial
            # download therefore never leaves a truncated file that the
            # exists() check above would treat as complete on the next call.
            partial_path.write_bytes(data)
            partial_path.replace(file_path)
            available += 1
        except Exception as e:
            print(f"Could not download {filename}: {e}")
    return available


@app.post("/demo/create")
async def create_demo_folder(background_tasks: BackgroundTasks):
    """Create a demo folder with sample images for testing.

    Downloads a few sample images into ``/tmp/demo_images``, registers the
    folder with the folder manager and queues it for background indexing.

    Raises:
        HTTPException: 500 if the folder cannot be created or registered.
    """
    import asyncio

    try:
        # Create demo folder
        demo_path = Path("/tmp/demo_images")
        demo_path.mkdir(parents=True, exist_ok=True)
        # Sample image URLs (small images for demo)
        sample_images = [
            ("https://picsum.photos/300/200?random=1", "demo1.jpg"),
            ("https://picsum.photos/300/200?random=2", "demo2.jpg"),
            ("https://picsum.photos/300/200?random=3", "demo3.jpg"),
        ]
        # The downloads are blocking network I/O. Run them in the default
        # thread pool so the event loop keeps serving other requests.
        loop = asyncio.get_running_loop()
        image_count = await loop.run_in_executor(
            None, _download_demo_images, demo_path, sample_images
        )
        # Add folder for indexing
        folder_info = indexer.folder_manager.add_folder(str(demo_path))
        # Start indexing in the background
        background_tasks.add_task(indexer.index_folder, str(demo_path))
        return {
            "status": "success",
            # Report images actually present, not the number attempted.
            "message": f"Created demo folder with {image_count} images",
            "folder_info": folder_info
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Use port 7860 for Hugging Face Spaces
    port = int(os.environ.get("PORT", 7860))
    # Pass the app object, not the "app:app" import string. When this file
    # runs as __main__, an import string makes uvicorn import the module a
    # second time under the name "app". That re-runs every import-time side
    # effect: a second ImageIndexer and model load, a second model-sharing
    # thread and a second database handle. The object form is only
    # incompatible with reload, which stays off.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)