File size: 3,694 Bytes
c89f65f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11f6de9
c89f65f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c56e632
 
c89f65f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
"""AION Search - Galaxy Semantic Search Application.

A Dash web application for semantic search over galaxy images using CLIP embeddings.
"""

import os
import logging
import argparse

# Fix OpenMP conflict - MUST be set before importing torch/numpy.
# This assignment sits between the stdlib imports and the third-party/project
# imports below on purpose: moving it after them would defeat its purpose.
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE presumably makes the OpenMP runtime
# tolerate a second copy of itself being loaded rather than aborting. That
# masks the duplicate-runtime conflict instead of resolving it — confirm it is
# still required in the deployment environment.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import dash
import dash_bootstrap_components as dbc

import src.config as config
from src.config import FEATURE_VECTOR_ADDITION
from src.components import get_app_theme, create_layout
from src.services import CLIPModelService, EmbeddingService, ZillizService, SearchService, ImageProcessingService
from src.callbacks import register_callbacks

# Process-wide logging configuration, applied once at import time:
# INFO level, "timestamp - logger name - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)


def create_app(checkpoint_path: str) -> dash.Dash:
    """Create and configure the Dash application.

    Builds the Dash app (theme, title, layout), loads the CLIP model from
    ``checkpoint_path``, wires up the embedding / Zilliz / image-processing /
    search services, and registers all callbacks. As a side effect it also
    overwrites ``config.TOTAL_GALAXIES`` with the live Zilliz collection count
    whenever that count is available.

    Args:
        checkpoint_path: Path to the CLIP model checkpoint.

    Returns:
        Configured Dash app instance. Its underlying Flask server is
        reachable as ``app.server``.
    """
    app = dash.Dash(
        __name__,
        external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME],
        # Don't raise for callbacks whose components are not part of the
        # initial layout.
        suppress_callback_exceptions=True,
    )

    # Custom HTML shell (theme) and browser-tab title.
    app.index_string = get_app_theme()
    app.title = "AION Galaxy Search"

    logger.info("Initializing services...")

    # Load the CLIP model up front so the first search doesn't pay for it.
    clip_service = CLIPModelService()
    clip_service.load_model(checkpoint_path)

    embedding_service = EmbeddingService(clip_service)
    zilliz_service = ZillizService()

    # Image processing service for advanced search
    # (uses pre-existing embeddings from Zilliz, no model loading needed).
    image_service = ImageProcessingService()
    logger.info("Image processing service initialized successfully")

    search_service = SearchService(embedding_service, zilliz_service, image_service)

    # Prefer the live collection size over the configured default. A falsy
    # result (0 or None) or a negative value means "count unavailable"; the
    # truthiness check also keeps a None return from raising TypeError on `>`.
    actual_count = zilliz_service.get_collection_count()
    if actual_count and actual_count > 0:
        config.TOTAL_GALAXIES = actual_count
        logger.info(f"Services initialized. Total galaxies: {config.TOTAL_GALAXIES:,}")
    else:
        logger.warning(f"Failed to get collection count from Zilliz, using default: {config.TOTAL_GALAXIES:,}")

    app.layout = create_layout()
    register_callbacks(app, search_service)

    logger.info("App initialization complete!")

    return app


def main(argv=None):
    """Main entry point: parse CLI options, build the app and serve it.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which matches the previous behaviour.
    """
    parser = argparse.ArgumentParser(description='AION Galaxy Search App')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='aionsearchmodel.pt',
        help='Path to CLIP model checkpoint'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=7860,
        help='Port to run the app on'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Run in debug mode'
    )
    parser.add_argument(
        '--host',
        type=str,
        default='0.0.0.0',
        help='Host to run the app on'
    )

    args = parser.parse_args(argv)

    logger.info("Starting AION Galaxy Search...")
    app = create_app(args.checkpoint)

    logger.info(f"Server starting on {args.host}:{args.port}")
    # Dash 3 removed ``Dash.run_server`` in favour of ``Dash.run``. Use
    # ``run`` when present and fall back to ``run_server`` only on older Dash
    # releases that predate ``run``.
    run = getattr(app, 'run', None) or app.run_server
    run(
        debug=args.debug,
        host=args.host,
        port=args.port
    )
# Entry guard.
# - Run as a script: `main()` builds the single app instance, honouring
#   --checkpoint and the other CLI flags.
# - Imported as a module: presumably by a WSGI server that targets
#   `server`, so the module-level `app` and `server` are built here with the
#   default checkpoint.
# Building the module-level app unconditionally (as before) loaded the CLIP
# model twice under `python app.py`. It also crashed when the default
# checkpoint was missing, even if --checkpoint pointed elsewhere.
if __name__ == "__main__":
    main()
else:
    app = create_app('aionsearchmodel.pt')
    server = app.server