AION-Search / app.py
astronolan's picture
Enhance search functionality and UI components
5a63add
#!/usr/bin/env python3
"""AION Search - Galaxy Semantic Search Application.
A Dash web application for semantic search over galaxy images using CLIP embeddings.
"""
import os
import logging
import argparse
# Work around the OpenMP runtime conflict (duplicate libiomp/libomp loaded by
# more than one native library). This MUST be set before torch/numpy are
# imported — presumably they are pulled in by the src.* imports below, which is
# why this assignment sits between the import groups rather than after them.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import dash
import dash_bootstrap_components as dbc
import src.config as config
from src.config import FEATURE_VECTOR_ADDITION
from src.components import get_app_theme, create_layout
from src.services import CLIPModelService, EmbeddingService, ZillizService, SearchService, ImageProcessingService
from src.callbacks import register_callbacks
# Root logging configuration: INFO and above, each record prefixed with a
# timestamp, the emitting logger's name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used throughout this file.
logger = logging.getLogger(__name__)
def create_app(checkpoint_path: str) -> dash.Dash:
    """Create and configure the Dash application.

    Builds the Dash app, loads the CLIP checkpoint, constructs the embedding,
    Zilliz, image-processing and search services, refreshes
    ``config.TOTAL_GALAXIES`` from the Zilliz collection size, and finally
    installs the layout and registers the callbacks.

    Args:
        checkpoint_path: Path to the CLIP model checkpoint.

    Returns:
        Configured Dash app instance. Its underlying Flask server is available
        as ``app.server``.
    """
    app = dash.Dash(
        __name__,
        external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME],
        # Skip Dash's validation error for callbacks whose component IDs are
        # not present in the initial layout.
        suppress_callback_exceptions=True,
    )

    # Custom HTML shell / theme, and the browser tab title.
    app.index_string = get_app_theme()
    app.title = "AION Galaxy Search"

    logger.info("Initializing services...")

    # Text/query embeddings come from the CLIP model loaded here.
    clip_service = CLIPModelService()
    clip_service.load_model(checkpoint_path)

    embedding_service = EmbeddingService(clip_service)
    zilliz_service = ZillizService()

    # Advanced (image-based) search works from embeddings already stored in
    # Zilliz, so this service does not load a model of its own.
    image_service = ImageProcessingService()
    logger.info("Image processing service initialized successfully")

    search_service = SearchService(embedding_service, zilliz_service, image_service)

    # Prefer the live collection size over the configured default; a
    # non-positive count is treated as "unavailable" and the default is kept.
    # NOTE(review): assumes get_collection_count() always returns an int —
    # confirm it cannot return None on connection failure.
    actual_count = zilliz_service.get_collection_count()
    if actual_count > 0:
        config.TOTAL_GALAXIES = actual_count
        logger.info(f"Services initialized. Total galaxies: {config.TOTAL_GALAXIES:,}")
    else:
        logger.warning(f"Failed to get collection count from Zilliz, using default: {config.TOTAL_GALAXIES:,}")

    app.layout = create_layout()
    register_callbacks(app, search_service)

    logger.info("App initialization complete!")
    return app
def main():
    """Command-line entry point: parse options, build the app and serve it.

    Options:
        --checkpoint  Path to the CLIP model checkpoint (default: aionsearchmodel.pt).
        --port        Port to listen on (default: 7860).
        --debug       Run the Dash server in debug mode.
        --host        Interface to bind (default: 0.0.0.0).
    """
    parser = argparse.ArgumentParser(description='AION Galaxy Search App')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='aionsearchmodel.pt',
        help='Path to CLIP model checkpoint'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=7860,
        help='Port to run the app on'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Run in debug mode'
    )
    parser.add_argument(
        '--host',
        type=str,
        default='0.0.0.0',
        help='Host to run the app on'
    )
    args = parser.parse_args()

    logger.info("Starting AION Galaxy Search...")
    app = create_app(args.checkpoint)

    logger.info(f"Server starting on {args.host}:{args.port}")
    # Dash.run supersedes Dash.run_server, which was deprecated in Dash 2.x and
    # removed in Dash 3.0; it accepts the same debug/host/port keywords.
    app.run(
        debug=args.debug,
        host=args.host,
        port=args.port
    )
if __name__ == "__main__":
    # Executed as a script: main() builds the app from the CLI arguments
    # (honouring --checkpoint) and serves it.
    main()
else:
    # Imported as a module — presumably by a WSGI server targeting ``server``
    # (e.g. ``gunicorn app:server``). Build the app once with the default
    # checkpoint and expose it plus its Flask server at module level.
    # This is done only on import: building it unconditionally would load the
    # CLIP model and connect to Zilliz a second time when run as a script.
    app = create_app('aionsearchmodel.pt')
    server = app.server