Visual-Image / app.py
VesperAI's picture
added a Production Branch
55f2687
import os
from pathlib import Path
from typing import List, Optional
import io
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, Request, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks
from fastapi.responses import HTMLResponse, FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from image_indexer import ImageIndexer
from image_search import ImageSearch
from image_database import ImageDatabase
# Module-level singletons, created once when this module is imported.
# NOTE(review): the image database handle (`image_db`) is created further down,
# not in this group.
indexer = ImageIndexer()
searcher = ImageSearch(init_model=False)  # Skip model loading here; the indexer's model is assigned to the searcher later
# Point the searcher at the indexer's folder manager so both use the very same
# object (the /debug/folder-managers endpoint reports on this identity).
searcher.folder_manager = indexer.folder_manager
# Wait for indexer model to initialize, then share it with searcher
import time
import threading
def wait_and_share_model():
    """Wait for the indexer's model to load, then hand it to the searcher.

    Meant to run on a background thread. Blocks for up to 60 seconds on
    ``indexer.model_initialized`` (when the indexer exposes it). If the indexer
    then has a non-``None`` ``model``, its ``model``, ``processor`` and
    ``device`` are assigned to the searcher and ``searcher.model_initialized``
    is set to ``True``. If no model is available, a message is printed so the
    failure is visible instead of silent.
    """
    ready_signal = getattr(indexer, "model_initialized", None)
    if ready_signal is not None:
        # NOTE(review): presumably a threading.Event set by the indexer — confirm.
        ready_signal.wait(timeout=60)  # Wait up to 60 seconds

    if getattr(indexer, "model", None) is None:
        # Without this message the searcher would be left without a model and
        # nothing in the output would explain why.
        print("Indexer model not available; searcher model was not shared")
        return

    print("Sharing model from indexer to searcher...")
    searcher.model = indexer.model
    searcher.processor = indexer.processor
    searcher.device = indexer.device
    searcher.model_initialized = True
    print("Model sharing complete")
# Run the model hand-off on a daemon thread so the (up to 60 second) wait
# happens off the importing thread.
threading.Thread(target=wait_and_share_model, daemon=True).start()
# Database handle used by the /image, /thumbnail and /stats endpoints.
image_db = ImageDatabase()
# Lower-case file suffixes treated as images by the file-serving and browse
# helpers (compared against `Path.suffix.lower()`).
image_extensions = [".jpg", ".jpeg", ".png", ".gif"]
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook; performs no startup or shutdown work itself."""
    # The indexer, searcher and database are created at import time, so there
    # is nothing to set up before or tear down after the yield.
    yield
# The ASGI application; `lifespan` is registered as its lifespan handler.
app = FastAPI(title="Visual Product Search", lifespan=lifespan)
# Jinja2 templates are loaded from "templates" and static assets are served
# under /static from "static" (both given as relative directory names).
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``index.html`` with the current indexing status and folder list.

    The template context carries the request, an ``initial_status`` snapshot of
    the indexer (status value, current file, totals and a percentage rounded to
    two decimals, 0 when there are no files) and the folders known to the
    folder manager.
    """
    folders = indexer.folder_manager.get_all_folders()

    initial_status = {
        "status": indexer.status.value,
        "current_file": indexer.current_file,
        "total_files": indexer.total_files,
        "processed_files": indexer.processed_files,
    }
    # Guard against division by zero before any files have been counted.
    if indexer.total_files > 0:
        percentage = round(indexer.processed_files / indexer.total_files * 100, 2)
    else:
        percentage = round(0, 2)
    initial_status["progress_percentage"] = percentage

    context = {
        "request": request,
        "initial_status": initial_status,
        "folders": folders,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/health")
async def health_check():
    """Return a static health payload plus the searcher's device, if it has one."""
    # Falls back to "unknown" when the searcher has no `device` attribute.
    device = getattr(searcher, "device", "unknown")
    payload = {
        "status": "healthy",
        "service": "Visual Image Search",
        "device": device,
    }
    return payload
@app.post("/folders")
async def add_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Register a folder and schedule it for indexing.

    The folder is first added to the folder manager, then ``indexer.index_folder``
    is queued as a background task. Returns whatever the folder manager returned
    for the new folder.

    Raises:
        HTTPException: 400 with the error text if registering or scheduling fails.
    """
    try:
        # Registration happens before indexing is queued (the original comment
        # notes that this step creates the collection).
        info = indexer.folder_manager.add_folder(folder_path)
        background_tasks.add_task(indexer.index_folder, folder_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return info
@app.delete("/folders/{folder_path:path}")
async def remove_folder(folder_path: str):
    """Ask the indexer to remove a folder.

    Returns ``{"status": "success"}`` once the indexer call completes.

    Raises:
        HTTPException: 400 with the error text if removal fails.
    """
    try:
        await indexer.remove_folder(folder_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"status": "success"}
@app.get("/folders")
async def list_folders():
    """Return every folder known to the indexer's folder manager."""
    all_folders = indexer.folder_manager.get_all_folders()
    return all_folders
@app.get("/search/text")
async def search_by_text(query: str, folder: Optional[str] = None) -> List[dict]:
    """Run a text query through the searcher, optionally limited to one folder."""
    return await searcher.search_by_text(query, folder)
@app.post("/search/image")
async def search_by_image(
    file: UploadFile = File(...),
    folder: Optional[str] = None
) -> List[dict]:
    """Search for images similar to an uploaded one, optionally within a folder.

    The upload is read fully into memory and opened with Pillow before being
    passed to the searcher.

    Raises:
        HTTPException: 400 if the uploaded bytes cannot be opened as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
    except OSError as exc:
        # Pillow's UnidentifiedImageError is an OSError subclass. A bad upload
        # is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    return await searcher.search_by_image(image, folder)
@app.get("/search/url")
async def search_by_url(
    url: str,
    folder: Optional[str] = None
) -> List[dict]:
    """Search for images similar to the one at ``url``, optionally within a folder."""
    return await searcher.search_by_url(url, folder)
@app.get("/images")
async def list_images(folder: Optional[str] = None) -> List[dict]:
    """Return the indexer's images, optionally limited to one folder."""
    images = await indexer.get_all_images(folder)
    return images
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time indexing status updates.

    Registers the socket with the indexer, then keeps reading (and discarding)
    incoming text so a disconnect is noticed. The socket is always unregistered
    on exit, whether the client disconnected normally or the receive loop failed
    for another reason.
    """
    await indexer.add_websocket_connection(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client disconnect; nothing to report.
        pass
    finally:
        # Unregister on every exit path so a failed socket is not left in the
        # indexer's connection set.
        await indexer.remove_websocket_connection(websocket)
@app.get("/image/{image_id}")
async def serve_image(image_id: str):
    """Serve an image from the database by ID.

    Looks the record up via ``image_db.get_image`` and streams its
    ``image_data`` bytes inline, cacheable for 24 hours.

    Raises:
        HTTPException: 404 if no record exists for ``image_id``; 500 with the
            error text for any other failure.
    """
    try:
        image_data = image_db.get_image(image_id)
        if not image_data:
            raise HTTPException(status_code=404, detail="Image not found")

        # "image/jpg" is not a registered MIME type; map the .jpg suffix to jpeg.
        subtype = image_data['file_extension'].lstrip('.').lower()
        if subtype == "jpg":
            subtype = "jpeg"

        return StreamingResponse(
            io.BytesIO(image_data["image_data"]),
            media_type=f"image/{subtype}",
            headers={
                "Cache-Control": "max-age=86400",  # Cache for 24 hours
                "Content-Disposition": f"inline; filename=\"{image_data['filename']}\""
            }
        )
    except HTTPException:
        # Let deliberate HTTP errors (the 404 above) through unchanged instead
        # of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/thumbnail/{image_id}")
async def serve_thumbnail_by_id(image_id: str):
    """Serve a thumbnail from the database by ID.

    Streams the bytes returned by ``image_db.get_thumbnail`` as ``image/jpeg``,
    cacheable for 24 hours.

    Raises:
        HTTPException: 404 if no thumbnail exists for ``image_id``; 500 with the
            error text for any other failure.
    """
    try:
        thumbnail_data = image_db.get_thumbnail(image_id)
        if not thumbnail_data:
            raise HTTPException(status_code=404, detail="Thumbnail not found")
        return StreamingResponse(
            io.BytesIO(thumbnail_data),
            media_type="image/jpeg",
            headers={"Cache-Control": "max-age=86400"}  # Cache for 24 hours
        )
    except HTTPException:
        # Keep the 404 above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/stats")
async def get_database_stats():
    """Return the statistics reported by the image database.

    Raises:
        HTTPException: 500 with the error text if the database call fails.
    """
    try:
        stats = image_db.get_database_stats()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return stats
@app.get("/debug/collections")
async def debug_collections():
    """Report Qdrant collection names alongside the folder manager's folders.

    Returns a dict with both listings and their counts, or ``{"error": ...}``
    with the error text if anything fails (no HTTP error is raised).
    """
    try:
        known_collections = indexer.qdrant.get_collections().collections
        tracked_folders = indexer.folder_manager.get_all_folders()

        collection_names = []
        for collection in known_collections:
            collection_names.append(collection.name)

        return {
            "qdrant_collections": collection_names,
            "folder_manager_folders": tracked_folders,
            "collections_count": len(known_collections),
            "folders_count": len(tracked_folders),
        }
    except Exception as exc:
        return {"error": str(exc)}
@app.get("/debug/folder-managers")
async def debug_folder_managers():
    """Report whether the indexer and searcher hold the same folder manager object.

    Includes each manager's ``id()``, an identity flag, and the folder list seen
    through each of them.
    """
    indexer_manager = indexer.folder_manager
    searcher_manager = searcher.folder_manager
    report = {
        "indexer_folder_manager_id": id(indexer_manager),
        "searcher_folder_manager_id": id(searcher_manager),
        "are_same_instance": indexer_manager is searcher_manager,
    }
    report["indexer_folders"] = indexer_manager.get_all_folders()
    report["searcher_folders"] = searcher_manager.get_all_folders()
    return report
# Keep the old endpoints for backward compatibility but mark as deprecated
@app.get("/thumbnail/{folder_path:path}/{file_path:path}")
async def serve_thumbnail(folder_path: str, file_path: str):
    """Serve resized image thumbnails (DEPRECATED - use /thumbnail/{image_id} instead)

    Verifies that ``folder_path`` is known to the folder manager, loads
    ``folder_path / file_path``, shrinks it to fit within 200x200 while keeping
    the aspect ratio, and streams it as JPEG, cacheable for 1 hour.

    Raises:
        HTTPException: 404 if the folder is not indexed or the file is missing;
            400 if the suffix is not an accepted image extension; 500 with the
            error text for any other failure.
    """
    try:
        # Get folder info to verify it's an indexed folder
        folder_info = indexer.folder_manager.get_folder_info(folder_path)
        if not folder_info:
            raise HTTPException(status_code=404, detail="Folder not found")

        # Construct full file path
        full_path = Path(folder_path) / file_path
        if not full_path.exists():
            raise HTTPException(status_code=404, detail="File not found")

        # Only serve image files
        if full_path.suffix.lower() not in image_extensions:
            raise HTTPException(status_code=400, detail="Invalid file type")

        img_byte_arr = io.BytesIO()
        # The context manager closes the underlying file handle once the
        # thumbnail has been written.
        with Image.open(full_path) as img:
            img.thumbnail((200, 200))  # Resize maintaining aspect ratio
            # JPEG cannot store alpha or palette modes (common for PNG/GIF), so
            # convert anything that is not already RGB or greyscale.
            thumb = img if img.mode in ("RGB", "L") else img.convert("RGB")
            thumb.save(img_byte_arr, format="JPEG")
        img_byte_arr.seek(0)

        return StreamingResponse(
            img_byte_arr,
            media_type="image/jpeg",
            headers={"Cache-Control": "max-age=3600"},  # Cache for 1 hour
        )
    except HTTPException:
        # Keep the deliberate 404/400 responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/files/{folder_path:path}/{file_path:path}")
async def serve_file(folder_path: str, file_path: str):
    """Serve files from indexed folders (DEPRECATED - use /image/{image_id} instead)

    Verifies that ``folder_path`` is known to the folder manager and that
    ``folder_path / file_path`` exists with an accepted image suffix, then
    returns it as a file response.

    Raises:
        HTTPException: 404 if the folder is not indexed or the file is missing;
            400 if the suffix is not an accepted image extension; 500 with the
            error text for any other failure.
    """
    try:
        # Get folder info to verify it's an indexed folder
        folder_info = indexer.folder_manager.get_folder_info(folder_path)
        if not folder_info:
            raise HTTPException(status_code=404, detail="Folder not found")

        # Construct full file path
        full_path = Path(folder_path) / file_path
        if not full_path.exists():
            raise HTTPException(status_code=404, detail="File not found")

        # Only serve image files
        if full_path.suffix.lower() not in image_extensions:
            raise HTTPException(status_code=400, detail="Invalid file type")

        return FileResponse(full_path)
    except HTTPException:
        # Keep the deliberate 404/400 responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def get_windows_drives():
    """Return the root paths of the logical drives reported by Windows.

    Reads the drive bitmask from ``kernel32.GetLogicalDrives`` and maps each set
    bit (bit 0 for A through bit 25 for Z) to its drive root.
    """
    # Imported here because `windll` only exists on Windows.
    from ctypes import windll

    mask = windll.kernel32.GetLogicalDrives()
    return [
        chr(ord("A") + bit) + ":\\"
        for bit in range(26)
        if mask & (1 << bit)
    ]
def get_directory_item(item):
    """Describe one directory entry, or return ``None`` if it should be skipped.

    Directories and files whose lower-cased suffix is in ``image_extensions``
    produce a dict with ``name``, absolute ``path``, ``type`` ("directory" or
    "file") and ``size`` (bytes for files, ``None`` for directories). Anything
    else, and any entry that raises while being inspected, yields ``None``.
    """
    try:
        is_folder = item.is_dir()
        if not is_folder and item.suffix.lower() not in image_extensions:
            return None

        if is_folder:
            kind, size = "directory", None
        else:
            kind, size = "file", item.stat().st_size

        return {
            "name": item.name,
            "path": str(item.absolute()),
            "type": kind,
            "size": size,
        }
    except Exception:
        # Best effort: entries that cannot be inspected are skipped.
        return None
def get_directory_contents(path: str):
    """List the sub-directories and image files directly inside ``path``.

    Returns a dict with the absolute ``current_path``, the ``parent_path``
    (``None`` when the path is its own parent) and ``contents`` sorted with
    directories first, then case-insensitively by name. On failure returns
    ``{"error": ...}``: a fixed message when the path does not exist, otherwise
    the text of the exception.
    """
    try:
        target = Path(path)
        if not target.exists():
            return {"error": "Path does not exist"}

        parent_path = None
        if target.parent != target:
            parent_path = str(target.parent)

        entries = []
        for child in target.iterdir():
            described = get_directory_item(child)
            if described is not None:
                entries.append(described)

        # False sorts before True, so directories come first.
        entries.sort(key=lambda x: (x["type"] != "directory", x["name"].lower()))

        return {
            "current_path": str(target.absolute()),
            "parent_path": parent_path,
            "contents": entries,
        }
    except Exception as exc:
        return {"error": str(exc)}
@app.get("/browse")
async def browse_folders():
    """Return the top-level browse listing for the host system.

    On Windows (``os.name == "nt"``) this is ``{"drives": [...]}``; elsewhere it
    is the contents of the root directory.
    """
    if os.name != "nt":
        return get_directory_contents("/")  # Unix-like
    return {"drives": get_windows_drives()}  # Windows
@app.get("/browse/{path:path}")
async def browse_path(path: str):
    """Browse a specific path.

    Returns a dict with the absolute ``current_path``, the ``parent_path``
    (``None`` when the path is its own parent) and ``contents``: the
    sub-directories and image files directly inside it, directories first and
    then case-insensitively by name. Entries that cannot be inspected are
    skipped.

    Raises:
        HTTPException: 404 if the path does not exist; 400 if it is not a
            directory; 500 with the error text for any other failure.
    """
    try:
        path_obj = Path(path)
        if not path_obj.exists():
            raise HTTPException(status_code=404, detail="Path not found")
        # Reject files up front; otherwise iterdir() fails and surfaces as a 500.
        if not path_obj.is_dir():
            raise HTTPException(status_code=400, detail="Path is not a directory")

        # Get parent directory for navigation
        parent = str(path_obj.parent) if path_obj.parent != path_obj else None

        # List directories and files
        contents = []
        for item in path_obj.iterdir():
            try:
                is_dir = item.is_dir()
                if is_dir or item.suffix.lower() in image_extensions:
                    contents.append({
                        "name": item.name,
                        "path": str(item.absolute()),
                        "type": "directory" if is_dir else "file",
                        "size": item.stat().st_size if not is_dir else None
                    })
            except Exception:
                continue  # Skip items we can't access

        return {
            "current_path": str(path_obj.absolute()),
            "parent_path": parent,
            "contents": sorted(contents, key=lambda x: (x["type"] != "directory", x["name"].lower()))
        }
    except HTTPException:
        # Keep the deliberate 404/400 responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/upload")
async def upload_folder(folder_path: str, background_tasks: BackgroundTasks):
    """Register a folder and schedule it for indexing (alternative to POST /folders).

    Performs the same steps as the ``/folders`` POST endpoint: add the folder to
    the folder manager, queue ``indexer.index_folder`` as a background task, and
    return the folder manager's result.

    Raises:
        HTTPException: 400 with the error text if registering or scheduling fails.
    """
    try:
        # Registration happens before indexing is queued (the original comment
        # notes that this step creates the collection).
        info = indexer.folder_manager.add_folder(folder_path)
        background_tasks.add_task(indexer.index_folder, folder_path)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return info
@app.post("/demo/create")
async def create_demo_folder(background_tasks: BackgroundTasks):
    """Create a demo folder with sample images for testing.

    Ensures ``/tmp/demo_images`` exists, downloads any of the three sample
    images that are not already present, registers the folder with the folder
    manager and queues it for indexing. Individual download failures are printed
    and skipped. The returned message reports how many sample files are actually
    on disk.

    Raises:
        HTTPException: 500 with the error text if anything other than an
            individual download fails.
    """
    try:
        import asyncio
        import urllib.request

        # Create demo folder
        demo_path = Path("/tmp/demo_images")
        demo_path.mkdir(parents=True, exist_ok=True)

        # Sample image URLs (small images for demo)
        sample_images = [
            ("https://picsum.photos/300/200?random=1", "demo1.jpg"),
            ("https://picsum.photos/300/200?random=2", "demo2.jpg"),
            ("https://picsum.photos/300/200?random=3", "demo3.jpg"),
        ]

        # Download sample images. urlretrieve blocks, so it runs in the default
        # executor to keep the event loop free for other requests.
        loop = asyncio.get_running_loop()
        for url, filename in sample_images:
            file_path = demo_path / filename
            if file_path.exists():
                continue
            try:
                await loop.run_in_executor(
                    None, urllib.request.urlretrieve, url, file_path
                )
            except Exception as e:
                print(f"Could not download {filename}: {e}")

        # Count what is really on disk rather than assuming every download worked.
        available = sum(
            1 for _, filename in sample_images if (demo_path / filename).exists()
        )

        # Add folder for indexing
        folder_info = indexer.folder_manager.add_folder(str(demo_path))

        # Start indexing in the background
        background_tasks.add_task(indexer.index_folder, str(demo_path))

        return {
            "status": "success",
            "message": f"Created demo folder with {available} images",
            "folder_info": folder_info
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Port 7860 is the default used for Hugging Face Spaces; the PORT
    # environment variable overrides it.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)