Spaces: Running
Commit: c89f65f
Parent(s): dda65a0
Add AION-Search Dash app for Hugging Face Spaces
- Add complete application code (app.py, src/, clip/)
- Add Dockerfile configured for HF Spaces deployment
- Add requirements.txt with torch-cpu and all dependencies
- Add cleaned model checkpoint (46MB, inference-only)
- Configure for port 7860 with gunicorn
Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- .gitignore +16 -0
- Dockerfile +28 -0
- README.md +4 -4
- aionsearchmodel.pt +3 -0
- app.py +133 -0
- clip/__init__.py +8 -0
- clip/evaluation/__init__.py +5 -0
- clip/evaluation/inference.py +82 -0
- clip/models/__init__.py +6 -0
- clip/models/clip_model.py +118 -0
- clip/models/projections.py +270 -0
- clip/utils/__init__.py +10 -0
- clip/utils/data_loader.py +250 -0
- clip/utils/io_utils.py +103 -0
- clip/utils/logging_utils.py +42 -0
- main.py +6 -0
- requirements.txt +13 -0
- src/__init__.py +3 -0
- src/callbacks.py +775 -0
- src/components.py +821 -0
- src/config.py +68 -0
- src/services.py +538 -0
- src/utils.py +195 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+venv/
+.env
+.venv
+*.log
+.DS_Store
+tmp/data/processed/*
+.python-version
+pyproject.toml
+uv.lock
+.claude
Dockerfile ADDED
@@ -0,0 +1,28 @@
+FROM python:3.9-slim
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements and install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY app.py .
+COPY src/ ./src/
+COPY clip/ ./clip/
+COPY aionsearchmodel.pt .
+
+# Create necessary directories
+RUN mkdir -p data/processed logs
+
+# Expose port for Hugging Face Spaces
+EXPOSE 7860
+
+# Run the application with gunicorn for production
+# Increased timeout and workers for HF Spaces
+CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "1", "--threads", "2", "--timeout", "600", "app:server"]
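To sanity-check this image outside of Spaces, it can be built and run locally; this is a minimal smoke test, assuming Docker is installed and any Zilliz credentials the services need are supplied via the environment: `docker build -t aion-search .` followed by `docker run -p 7860:7860 aion-search`, then open http://localhost:7860.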
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
 title: AION Search
-emoji:
-colorFrom:
-colorTo:
+emoji: π
+colorFrom: blue
+colorTo: purple
 sdk: docker
 pinned: false
 ---

-
+AION-Search
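The YAML front matter above is the Hugging Face Spaces configuration card: `sdk: docker` tells Spaces to build the repository's Dockerfile and route traffic to the container, which is why the Dockerfile exposes port 7860, the default application port for Docker Spaces.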
aionsearchmodel.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e91a0b8e1f632165d62aff10dc598674a35a92e28a4312af220a606bd44664f6
+size 48614488
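The three lines above are a Git LFS pointer, not the checkpoint itself: Git keeps this small text stub in the repository while the actual ~46 MB file (48,614,488 bytes) lives in LFS storage and is fetched on clone or via `git lfs pull`.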
app.py ADDED
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+"""AION Search - Galaxy Semantic Search Application.
+
+A Dash web application for semantic search over galaxy images using CLIP embeddings.
+"""
+
+import os
+import logging
+import argparse
+
+# Fix OpenMP conflict - MUST be set before importing torch/numpy
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+import dash
+import dash_bootstrap_components as dbc
+
+import src.config as config
+from src.config import FEATURE_VECTOR_ADDITION
+from src.components import get_app_theme, create_layout
+from src.services import CLIPModelService, EmbeddingService, ZillizService, SearchService, ImageProcessingService
+from src.callbacks import register_callbacks
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+def create_app(checkpoint_path: str) -> dash.Dash:
+    """Create and configure the Dash application.
+
+    Args:
+        checkpoint_path: Path to the CLIP model checkpoint
+
+    Returns:
+        Configured Dash app instance
+    """
+    # Initialize Dash app
+    app = dash.Dash(
+        __name__,
+        external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME],
+        suppress_callback_exceptions=True
+    )
+    server = app.server
+
+    # Set custom theme
+    app.index_string = get_app_theme()
+
+    # Set app title
+    app.title = "AION Galaxy Search"
+
+    # Initialize services
+    logger.info("Initializing services...")
+
+    # Load CLIP model
+    clip_service = CLIPModelService()
+    clip_service.load_model(checkpoint_path)
+
+    # Create service instances
+    embedding_service = EmbeddingService(clip_service)
+    zilliz_service = ZillizService()
+
+    # Initialize image processing service for advanced search
+    # (now uses pre-existing embeddings from Zilliz, no model loading needed)
+    image_service = ImageProcessingService()
+    logger.info("Image processing service initialized successfully")
+
+    search_service = SearchService(embedding_service, zilliz_service, image_service)
+
+    # Get actual count from Zilliz and update config
+    actual_count = zilliz_service.get_collection_count()
+    if actual_count > 0:
+        config.TOTAL_GALAXIES = actual_count
+        logger.info(f"Services initialized. Total galaxies: {config.TOTAL_GALAXIES:,}")
+    else:
+        logger.warning(f"Failed to get collection count from Zilliz, using default: {config.TOTAL_GALAXIES:,}")
+
+    # Create app layout
+    app.layout = create_layout()
+
+    # Register callbacks
+    register_callbacks(app, search_service)
+
+    logger.info("App initialization complete!")
+
+    return app
+
+
+def main():
+    """Main entry point for the application."""
+    parser = argparse.ArgumentParser(description='AION Galaxy Search App')
+    parser.add_argument(
+        '--checkpoint',
+        type=str,
+        default='aionsearchmodel.pt',
+        help='Path to CLIP model checkpoint'
+    )
+    parser.add_argument(
+        '--port',
+        type=int,
+        default=7860,
+        help='Port to run the app on'
+    )
+    parser.add_argument(
+        '--debug',
+        action='store_true',
+        help='Run in debug mode'
+    )
+    parser.add_argument(
+        '--host',
+        type=str,
+        default='0.0.0.0',
+        help='Host to run the app on'
+    )
+
+    args = parser.parse_args()
+
+    # Create and run app
+    logger.info("Starting AION Galaxy Search...")
+    app = create_app(args.checkpoint)
+
+    logger.info(f"Server starting on {args.host}:{args.port}")
+    app.run_server(
+        debug=args.debug,
+        host=args.host,
+        port=args.port
+    )
+
+
+if __name__ == "__main__":
+    main()
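Note that the Dockerfile's gunicorn command binds `app:server`, i.e. a module-level name `server` in `app.py`, while in the code above `server` is assigned only inside `create_app()`. A minimal sketch of module-level lines that would satisfy that entry point (hypothetical; the deployed Space may wire this differently):

```python
# Hypothetical module-level entry point for gunicorn's `app:server` target:
# build the Dash app at import time and expose its underlying Flask server.
dash_app = create_app('aionsearchmodel.pt')
server = dash_app.server
```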
clip/__init__.py ADDED
@@ -0,0 +1,8 @@
+"""
+CLIP alignment for galaxy images and text descriptions.
+
+This package provides tools for training and using CLIP-style alignment
+between AION galaxy embeddings and text descriptions.
+"""
+
+__version__ = "0.1.0"
clip/evaluation/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Evaluation utilities for CLIP model."""
+
+from .inference import ClipInferenceModel
+
+__all__ = ["ClipInferenceModel"]
clip/evaluation/inference.py ADDED
@@ -0,0 +1,82 @@
+"""
+Inference utilities for trained CLIP model.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from pathlib import Path
+from typing import Union, List, Dict, Tuple
+import logging
+
+from ..models import GalaxyClipModel
+
+logger = logging.getLogger(__name__)
+
+
+class ClipInferenceModel:
+    """Wrapper for using trained CLIP model for inference and search."""
+
+    def __init__(self, model_path: str, device: str = "cpu"):
+        """
+        Initialize inference model.
+
+        Args:
+            model_path: Path to saved model (.pt file)
+            device: Device to use for inference
+        """
+        self.device = torch.device(device)
+
+        # Load model
+        checkpoint = torch.load(model_path, map_location=self.device)
+        model_config = checkpoint['model_config']
+
+        # Create model with same config
+        self.model = GalaxyClipModel(
+            image_input_dim=model_config['image_input_dim'],
+            text_input_dim=model_config['text_input_dim'],
+            embedding_dim=model_config['embedding_dim']
+        )
+
+        # Load weights
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.model.to(self.device)
+        self.model.eval()
+
+        self.config = model_config
+        logger.info(f"Loaded CLIP model on {device}")
+        logger.info(f"Model config: {model_config}")
+
+    def encode_images(self, image_embeddings):
+        """Encode image embeddings to shared space."""
+
+        tensor = torch.as_tensor(image_embeddings, dtype=torch.float, device=self.device)
+
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+            squeeze = True
+        else:
+            squeeze = False
+
+        with torch.no_grad():
+            # Use image_projector and normalize
+            out = self.model.image_projector(tensor)
+
+        return out.squeeze(0).cpu() if squeeze else out.cpu()
+
+    def encode_texts(self, text_embeddings):
+        """Encode text embeddings to shared space."""
+
+        tensor = torch.as_tensor(text_embeddings, dtype=torch.float, device=self.device)
+
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+            squeeze = True
+        else:
+            squeeze = False
+
+        with torch.no_grad():
+            # Use text_projector and normalize
+            out = self.model.text_projector(tensor)
+
+        return out.squeeze(0).cpu() if squeeze else out.cpu()
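A quick usage sketch of `ClipInferenceModel`, assuming a checkpoint with the default dimensions (768-d AION image embeddings, 3072-d text embeddings, 1024-d shared space); the random vectors below are placeholders:

```python
import numpy as np
from clip.evaluation.inference import ClipInferenceModel

model = ClipInferenceModel("aionsearchmodel.pt", device="cpu")

img = np.random.randn(768).astype(np.float32)   # placeholder AION embedding
txt = np.random.randn(3072).astype(np.float32)  # placeholder text embedding

img_z = model.encode_images(img)  # torch.Tensor, shape (1024,)
txt_z = model.encode_texts(txt)   # torch.Tensor, shape (1024,)

# The projectors L2-normalize their outputs, so a dot product is cosine similarity.
print(float(img_z @ txt_z))
```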
clip/models/__init__.py ADDED
@@ -0,0 +1,6 @@
+"""CLIP model architecture for galaxy embeddings."""
+
+from .clip_model import GalaxyClipModel
+from .projections import CrossAttentionImageProjector, TextProjector
+
+__all__ = ["GalaxyClipModel", "CrossAttentionImageProjector", "TextProjector"]
clip/models/clip_model.py ADDED
@@ -0,0 +1,118 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict
+
+from .projections import TextProjector, CrossAttentionImageProjector, SimpleImageProjector
+
+class GalaxyClipModel(nn.Module):
+    """CLIP model for aligning galaxy images and text descriptions."""
+
+    def __init__(
+        self,
+        image_input_dim: int = 768,
+        text_input_dim: int = 3072,
+        embedding_dim: int = 1024,
+        image_hidden_dim: int = 768,
+        text_hidden_dim: int = 1024,
+        dropout: float = 0.1,
+        use_mean_embeddings: bool = True
+    ):
+        """
+        Initialize CLIP model.
+
+        Args:
+            image_input_dim: AION embedding dimension
+            text_input_dim: Text embedding dimension
+            embedding_dim: Shared embedding space dimension
+            image_hidden_dim: Hidden dimension for image projector
+            text_hidden_dim: Hidden dimension for text projector
+            dropout: Dropout rate
+            use_mean_embeddings: Whether using mean embeddings (True) or full embeddings (False)
+        """
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+        self.use_mean_embeddings = use_mean_embeddings
+
+        # Choose appropriate image projector based on embedding type
+        if use_mean_embeddings:
+            # Simple projector for mean embeddings (1D vectors)
+            self.image_projector = SimpleImageProjector(
+                input_dim=image_input_dim,
+                output_dim=embedding_dim,
+                hidden_dim=image_hidden_dim,
+                dropout=dropout
+            )
+        else:
+            # Cross-attention projector for full embeddings (2D sequences)
+            self.image_projector = CrossAttentionImageProjector(
+                input_dim=image_input_dim,
+                output_dim=embedding_dim,
+                hidden_dim=image_hidden_dim,
+                dropout=dropout
+            )
+
+        self.text_projector = TextProjector(
+            input_dim=text_input_dim,
+            output_dim=embedding_dim,
+            hidden_dim=text_hidden_dim,
+            dropout=dropout
+        )
+
+        # Learnable logit scale parameter initialized to standard CLIP temperature 1/0.07
+        # Using log parameterization for numerical stability
+        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07, dtype=torch.float32)))
+
+    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass for CLIP training.
+
+        Args:
+            batch: Dictionary containing 'image_embedding' and 'text_embedding'
+
+        Returns:
+            Dictionary with projected embeddings and logits
+        """
+        image_features = batch['image_embedding']
+        text_features = batch['text_embedding']
+
+        # Project to shared space and normalize
+        image_features = self.image_projector(image_features)
+        text_features = self.text_projector(text_features)
+
+        # Compute similarity matrix with learnable logit scale
+        # Clamp after exp to preserve gradients
+        logit_scale = self.logit_scale.exp().clamp(max=100)
+        logits_per_image = logit_scale * image_features @ text_features.T
+        logits_per_text = logits_per_image.T
+
+        return {
+            'image_features': image_features,
+            'text_features': text_features,
+            'logits_per_image': logits_per_image,
+            'logits_per_text': logits_per_text,
+            'logit_scale': logit_scale
+        }
+
+    def compute_contrastive_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        Compute contrastive loss (InfoNCE).
+
+        Args:
+            outputs: Model outputs from forward pass
+
+        Returns:
+            Contrastive loss
+        """
+        logits_per_image = outputs['logits_per_image']
+        logits_per_text = outputs['logits_per_text']
+
+        batch_size = logits_per_image.shape[0]
+        labels = torch.arange(batch_size, device=logits_per_image.device)
+
+        # Cross-entropy loss for both directions
+        loss_i2t = F.cross_entropy(logits_per_image, labels)
+        loss_t2i = F.cross_entropy(logits_per_text, labels)
+
+        return (loss_i2t + loss_t2i) / 2
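For illustration, a toy forward pass and loss computation with random tensors, following the default constructor arguments above (mean embeddings, so `SimpleImageProjector` is used):

```python
import torch
from clip.models import GalaxyClipModel

model = GalaxyClipModel(image_input_dim=768, text_input_dim=3072, embedding_dim=1024)
batch = {
    'image_embedding': torch.randn(8, 768),   # 8 galaxies
    'text_embedding': torch.randn(8, 3072),   # 8 paired descriptions
}
out = model(batch)
print(out['logits_per_image'].shape)          # torch.Size([8, 8])

# Symmetric InfoNCE: the diagonal entries are the matching image-text pairs.
loss = model.compute_contrastive_loss(out)
print(loss.item())
```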
clip/models/projections.py ADDED
@@ -0,0 +1,270 @@
+import torch
+import torch.nn as nn
+from typing import Optional
+import torch.nn.functional as F
+
+class TextProjector(nn.Module):
+    """Projects text embeddings to shared space."""
+
+    def __init__(
+        self,
+        input_dim: int = 3072,
+        output_dim: int = 1024,
+        hidden_dim: Optional[int] = None,
+        dropout: float = 0.1,
+        num_layers: int = 4,
+    ):
+        """
+        Initialize text projector.
+
+        Args:
+            input_dim: Dimension of text embeddings (3072)
+            output_dim: Dimension of shared embedding space
+            hidden_dim: Hidden layer dimension (default: 1024)
+            dropout: Dropout rate
+            num_layers: Number of residual layers (default: 4)
+        """
+        super().__init__()
+
+        if hidden_dim is None:
+            hidden_dim = 1024
+
+        self.fc_in = nn.Linear(input_dim, hidden_dim)
+        self.blocks = nn.ModuleList([
+            nn.Sequential(
+                nn.LayerNorm(hidden_dim),
+                nn.GELU(),
+                nn.Dropout(dropout),
+                nn.Linear(hidden_dim, hidden_dim),
+            ) for _ in range(num_layers)
+        ])
+        self.fc_out = nn.Linear(hidden_dim, output_dim)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize projection weights."""
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Project text embeddings to shared space.
+
+        Args:
+            x: Text embeddings (batch_size, input_dim)
+
+        Returns:
+            Projected embeddings (batch_size, output_dim)
+        """
+        h = self.fc_in(x)
+        for blk in self.blocks:  # residual MLP stack
+            h = h + blk(h)
+        h = self.fc_out(h)
+        return F.normalize(h, dim=-1, eps=1e-3)
+
+
+class CrossAttentionImageProjector(nn.Module):
+    """Simplified projector with self-attention + cross-attention."""
+
+    def __init__(
+        self,
+        input_dim: int = 768,
+        output_dim: int = 1024,
+        hidden_dim: Optional[int] = None,
+        dropout: float = 0.1,
+        num_layers: int = 2,  # Kept for compatibility, not used
+        num_heads: int = 4,  # Reduced from 8
+    ):
+        """
+        Initialize simplified cross-attention image projector.
+
+        Args:
+            input_dim: Dimension of AION embeddings (768)
+            output_dim: Dimension of shared embedding space (default: 1024)
+            hidden_dim: Hidden dimension for attention (default: output_dim)
+            dropout: Dropout rate
+            num_layers: Kept for compatibility but not used
+            num_heads: Number of attention heads (reduced to 4)
+        """
+        super().__init__()
+
+        if hidden_dim is None:
+            hidden_dim = output_dim
+
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.output_dim = output_dim
+
+        # Project input to hidden dim
+        self.input_proj = nn.Linear(input_dim, hidden_dim)
+
+        # Token pooling to reduce sequence length
+        # 576 tokens -> 64 tokens (9x reduction)
+        self.token_pool = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, stride=9, padding=0)
+
+        # Single self-attention layer
+        self.self_attn_norm = nn.LayerNorm(hidden_dim)
+        self.self_attn = nn.MultiheadAttention(
+            embed_dim=hidden_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            batch_first=True
+        )
+
+        # MLP after self-attention
+        self.mlp1_norm = nn.LayerNorm(hidden_dim)
+        self.mlp1 = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim * 2),  # Reduced from 4x
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.Dropout(dropout)
+        )
+
+        # Learned query vector
+        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))
+
+        # Single cross-attention layer
+        self.cross_attn_norm = nn.LayerNorm(hidden_dim)
+        self.cross_attn = nn.MultiheadAttention(
+            embed_dim=hidden_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            batch_first=True
+        )
+
+        # Final MLP
+        self.final_norm = nn.LayerNorm(hidden_dim)
+        self.final_mlp = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim * 2),  # Reduced from 4x
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim * 2, output_dim)
+        )
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize weights."""
+        # Initialize query vector
+        nn.init.normal_(self.query, std=0.02)
+
+        # Initialize other weights
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Project image embeddings to shared space using self-attention + cross-attention.
+
+        Args:
+            x: Image embeddings (batch_size, n_tokens, input_dim)
+
+        Returns:
+            Projected embeddings (batch_size, output_dim)
+        """
+        batch_size = x.shape[0]
+        x = F.normalize(x, dim=-1, eps=1e-6)  # Normalize AION embeddings input (handles [B, N, D])
+
+        # Project input
+        x = self.input_proj(x)  # (B, N, hidden_dim)
+
+        # Pool tokens to reduce sequence length
+        x = x.transpose(1, 2)  # (B, hidden_dim, N)
+        x = self.token_pool(x)  # (B, hidden_dim, N//9)
+        x = x.transpose(1, 2)  # (B, N//9, hidden_dim)
+
+        # Self-attention with residual on pooled tokens
+        x_norm = self.self_attn_norm(x)
+        x_attn, _ = self.self_attn(x_norm, x_norm, x_norm, need_weights=False)
+        x = x + x_attn
+
+        # MLP with residual
+        x = x + self.mlp1(self.mlp1_norm(x))
+
+        # Cross-attention with learned query
+        query = self.query.expand(batch_size, -1, -1)  # (B, 1, hidden_dim)
+        q_norm = self.cross_attn_norm(query)
+        attended, _ = self.cross_attn(q_norm, x, x, need_weights=False)
+        query = query + attended
+
+        # Final processing
+        output = self.final_norm(query).squeeze(1)  # (B, hidden_dim)
+        output = self.final_mlp(output)  # (B, output_dim)
+
+        return F.normalize(output, dim=-1, eps=1e-3)
+
+
+class SimpleImageProjector(nn.Module):
+    """Simple projector for mean AION embeddings."""
+
+    def __init__(
+        self,
+        input_dim: int = 768,
+        output_dim: int = 1024,
+        hidden_dim: Optional[int] = None,
+        dropout: float = 0.1,
+        num_layers: int = 4,
+    ):
+        """
+        Initialize simple image projector.
+
+        Args:
+            input_dim: Dimension of AION embeddings (768)
+            output_dim: Dimension of shared embedding space
+            hidden_dim: Hidden layer dimension (default: 1024)
+            dropout: Dropout rate
+            num_layers: Number of residual layers (default: 4)
+        """
+        super().__init__()
+
+        if hidden_dim is None:
+            hidden_dim = 1024
+
+        self.fc_in = nn.Linear(input_dim, hidden_dim)
+        self.blocks = nn.ModuleList([
+            nn.Sequential(
+                nn.LayerNorm(hidden_dim),
+                nn.GELU(),
+                nn.Dropout(dropout),
+                nn.Linear(hidden_dim, hidden_dim),
+            ) for _ in range(num_layers)
+        ])
+        self.fc_out = nn.Linear(hidden_dim, output_dim)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        """Initialize projection weights."""
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Project image embeddings to shared space.
+
+        Args:
+            x: Image embeddings (batch_size, input_dim)
+
+        Returns:
+            Projected embeddings (batch_size, output_dim)
+        """
+        x = F.normalize(x, dim=-1, eps=1e-6)  # Normalize AION embeddings input
+        h = self.fc_in(x)
+        for blk in self.blocks:  # residual MLP stack
+            h = h + blk(h)
+        h = self.fc_out(h)
+        return F.normalize(h, dim=-1, eps=1e-3)
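A shape check for the cross-attention path; note the `Conv1d` token pool uses `kernel_size=stride=9`, so it is sized for the 576-token AION sequences (576 → 64 pooled tokens, other lengths floor-divide by 9):

```python
import torch
from clip.models.projections import CrossAttentionImageProjector

proj = CrossAttentionImageProjector(input_dim=768, output_dim=1024)
x = torch.randn(2, 576, 768)   # (batch, tokens, dim) full AION embeddings
z = proj(x)
print(z.shape)                 # torch.Size([2, 1024]); rows are unit-norm
```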
clip/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
+"""Utility functions for CLIP training and evaluation."""
+
+from .logging_utils import setup_logging
+from .io_utils import save_clip_embeddings_hdf5, inspect_generated_files
+
+__all__ = [
+    "setup_logging",
+    "save_clip_embeddings_hdf5",
+    "inspect_generated_files"
+]
clip/utils/data_loader.py ADDED
@@ -0,0 +1,250 @@
+"""
+Data loader for multi-text training using unified parquet file with nested text embeddings.
+This loader handles the new unified format from 05_generate_unified_embeddings.py.
+"""
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import Dataset, DataLoader
+import logging
+from pathlib import Path
+import random
+
+logger = logging.getLogger(__name__)
+
+
+class UnifiedMultiTextDataset(Dataset):
+    """Dataset for unified parquet file with multiple text embeddings per galaxy."""
+
+    def __init__(self, parquet_path, split="train", train_ratio=0.8,
+                 text_sampling_strategy="random", epoch=0, max_train_samples=None,
+                 num_embedding=None):
+        self.parquet_path = Path(parquet_path)
+        self.split = split
+        self.train_ratio = train_ratio
+        self.text_sampling_strategy = text_sampling_strategy
+        self.epoch = epoch
+        self.max_train_samples = max_train_samples
+        self.num_embedding = num_embedding
+
+        # Load the parquet file
+        logger.info(f"Loading unified embeddings from {self.parquet_path}")
+        self.df = pd.read_parquet(self.parquet_path)
+
+        # Create train/val split based on galaxy_index
+        n_samples = len(self.df)
+        indices = np.arange(n_samples)
+        self.seed = 42
+
+        # Deterministic split based on galaxy_index
+        split_mask = []
+        for idx in range(n_samples):
+            galaxy_idx = self.df.iloc[idx]['galaxy_index']
+            # Hash the galaxy index for deterministic assignment
+            sample_hash = hash((galaxy_idx, self.seed)) % 10000 / 10000.0
+            is_train = sample_hash < self.train_ratio
+            split_mask.append(is_train)
+
+        split_mask = np.array(split_mask)
+
+        if split == "train":
+            self.indices = indices[split_mask]
+            # Limit training samples if specified
+            if self.max_train_samples is not None and len(self.indices) > self.max_train_samples:
+                rng = np.random.RandomState(self.seed)
+                selected_indices = rng.choice(self.indices, size=self.max_train_samples, replace=False)
+                self.indices = np.sort(selected_indices)  # Sort for reproducibility
+                logger.info(f"Limited training set to {self.max_train_samples} samples")
+        else:
+            self.indices = indices[~split_mask]
+
+        logger.info(f"Dataset initialized: {len(self.indices)} samples for {split} split")
+        logger.info(f"Text sampling strategy: {text_sampling_strategy}")
+
+        # Validate num_embedding parameter for specific_summary strategy
+        if text_sampling_strategy == "specific_summary" and num_embedding is None:
+            raise ValueError("num_embedding parameter is required when using 'specific_summary' strategy")
+
+        # Check data structure
+        sample_row = self.df.iloc[0]
+        n_augmented = len(sample_row['augmented_embeddings'])
+        logger.info(f"Each galaxy has 1 original + {n_augmented} augmented embeddings = {1 + n_augmented} total")
+
+        # Validate num_embedding is within valid range
+        if text_sampling_strategy == "specific_summary":
+            total_embeddings = 1 + n_augmented
+            if num_embedding < 0 or num_embedding >= total_embeddings:
+                raise ValueError(f"num_embedding must be between 0 and {total_embeddings-1}, got {num_embedding}")
+            logger.info(f"Using specific embedding at index {num_embedding}")
+
+    def __len__(self):
+        return len(self.indices)
+
+    def set_epoch(self, epoch):
+        """Set current epoch for round-robin sampling."""
+        self.epoch = epoch
+
+    def _get_all_embeddings_and_sources(self, row):
+        """Combine original and augmented embeddings into single lists."""
+        # Start with original embedding
+        all_embeddings = [np.array(row['text_embedding'], dtype=np.float32)]
+        all_sources = [row['description_sources'][0]]  # 'original'
+
+        # Add augmented embeddings
+        for aug_emb, aug_source in zip(row['augmented_embeddings'], row['description_sources'][1:]):
+            all_embeddings.append(np.array(aug_emb, dtype=np.float32))
+            all_sources.append(aug_source)
+
+        return all_embeddings, all_sources
+
+    def _sample_text_embedding(self, text_embeddings, text_sources, galaxy_idx):
+        """Sample one text embedding from multiple options."""
+        n_texts = len(text_embeddings)
+
+        if self.text_sampling_strategy == "original":
+            # Always use original text (index 0)
+            idx = 0
+        elif self.text_sampling_strategy == "summaries-only":
+            # Only use summaries (exclude original at index 0)
+            if n_texts > 1:
+                rng = random.Random(galaxy_idx + self.epoch * 1000000)
+                idx = rng.randint(1, n_texts - 1)  # Start from 1 to exclude original
+            else:
+                # Fallback to original if no summaries available
+                idx = 0
+        elif self.text_sampling_strategy == "specific_summary":
+            # Use the specific embedding index provided
+            if self.num_embedding < n_texts:
+                idx = self.num_embedding
+            else:
+                # Fallback to original if index out of range
+                logger.warning(f"Requested embedding index {self.num_embedding} out of range for {n_texts} embeddings, using original")
+                idx = 0
+        elif self.text_sampling_strategy == "random":
+            # Random sampling with seed based on galaxy_idx and epoch
+            rng = random.Random(galaxy_idx + self.epoch * 1000000)
+            idx = rng.randint(0, n_texts - 1)
+        elif self.text_sampling_strategy == "round-robin":
+            # Cycle through texts based on epoch
+            idx = (self.epoch + galaxy_idx) % n_texts
+        elif self.text_sampling_strategy == "weighted":
+            # Weight towards original (50%) and summaries (50% / n_summaries each)
+            rng = random.Random(galaxy_idx + self.epoch * 1000000)
+            n_summaries = n_texts - 1
+            if n_summaries > 0:
+                summary_weight = 0.5 / n_summaries
+                weights = [0.5] + [summary_weight] * n_summaries
+            else:
+                weights = [1.0]
+            idx = rng.choices(range(n_texts), weights=weights)[0]
+        else:
+            idx = 0  # Default to original
+
+        return text_embeddings[idx], text_sources[idx], idx
+
+    def __getitem__(self, idx):
+        """Get a single sample with randomly selected text embedding."""
+        actual_idx = self.indices[idx]
+        row = self.df.iloc[actual_idx]
+
+        # Get AION embedding
+        aion_embedding = np.array(row['aion_embedding'], dtype=np.float32)
+
+        # Get all text embeddings and sources
+        text_embeddings, text_sources = self._get_all_embeddings_and_sources(row)
+
+        # Sample one text embedding
+        galaxy_idx = row['galaxy_index']
+        selected_text, selected_source, text_idx = self._sample_text_embedding(
+            text_embeddings, text_sources, galaxy_idx
+        )
+
+        # Log selection details periodically (every 100th sample)
+        if idx % 100 == 0:
+            logger.debug(f"Galaxy {galaxy_idx}: Selected {selected_source} (index {text_idx}) from {len(text_sources)} options")
+
+        return {
+            'aion_embedding': torch.from_numpy(aion_embedding),
+            'text_embedding': torch.from_numpy(selected_text),
+            'galaxy_index': galaxy_idx,
+            'text_source': selected_source,
+            'text_index': text_idx,
+            'object_id': row['object_id']
+        }
+
+
+def create_unified_multi_text_loaders(
+    unified_embeddings_path,
+    batch_size=64,
+    train_ratio=0.8,
+    pin_memory=True,
+    text_sampling_strategy="random",
+    num_workers=4,
+    max_train_samples=None,
+    num_embedding=None,
+    **kwargs
+):
+    """
+    Create train and validation data loaders for multi-text training from unified parquet.
+
+    Args:
+        unified_embeddings_path: Path to unified parquet file
+        batch_size: Batch size for training
+        train_ratio: Fraction of samples for training
+        pin_memory: Whether to pin memory for GPU transfer
+        text_sampling_strategy: How to sample text embeddings ("original", "summaries-only", "specific_summary", "random", "round-robin", "weighted")
+        num_workers: Number of data loading workers
+        max_train_samples: Maximum number of training samples (for data scaling experiments)
+        num_embedding: When using "specific_summary" strategy, the index of the embedding to use
+        **kwargs: Additional arguments
+    """
+
+    # Convert to Path
+    parquet_path = Path(unified_embeddings_path)
+
+    if not parquet_path.exists():
+        raise ValueError(f"Unified embeddings file not found: {parquet_path}")
+
+    logger.info(f"Creating unified multi-text data loaders from {parquet_path}")
+    logger.info(f"Batch size: {batch_size}, Workers: {num_workers}")
+    logger.info(f"Text sampling strategy: {text_sampling_strategy}")
+
+    # Create datasets
+    train_dataset = UnifiedMultiTextDataset(
+        parquet_path=parquet_path,
+        split="train",
+        train_ratio=train_ratio,
+        text_sampling_strategy=text_sampling_strategy,
+        max_train_samples=max_train_samples,
+        num_embedding=num_embedding
+    )
+
+    val_dataset = UnifiedMultiTextDataset(
+        parquet_path=parquet_path,
+        split="val",
+        train_ratio=train_ratio,
+        text_sampling_strategy=text_sampling_strategy,
+        num_embedding=num_embedding
+    )
+
+    # Create loaders
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,  # Shuffle within the train split
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        drop_last=True  # Drop incomplete batches for stable training
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,  # No shuffle for validation
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        drop_last=False
+    )
+
+    return train_loader, val_loader
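Typical usage, assuming the unified parquet produced upstream by `05_generate_unified_embeddings.py` exists at a path like the one below (hypothetical):

```python
from clip.utils.data_loader import create_unified_multi_text_loaders

train_loader, val_loader = create_unified_multi_text_loaders(
    "data/processed/unified_embeddings.parquet",  # hypothetical path
    batch_size=64,
    text_sampling_strategy="random",
    num_workers=0,  # 0 keeps a local sanity check single-process
)
batch = next(iter(train_loader))
print(batch['aion_embedding'].shape, batch['text_embedding'].shape)
```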
clip/utils/io_utils.py ADDED
@@ -0,0 +1,103 @@
+"""
+I/O utilities for saving and loading CLIP embeddings.
+"""
+
+import h5py
+import numpy as np
+from pathlib import Path
+from datetime import datetime
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def save_clip_embeddings_hdf5(
+    object_ids,
+    galaxy_data,
+    text_data,
+    aion_clip_embeddings,
+    text_clip_embeddings,
+    output_dir="data/processed"
+):
+    """Save CLIP embeddings to separate HDF5 files."""
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # File paths (standardized names)
+    aion_clip_path = output_dir / "galaxy_aion_clip_embeddings.hdf5"
+    text_clip_path = output_dir / "galaxy_text_clip_embeddings.hdf5"
+
+    logger.info(f"Saving AION CLIP embeddings to: {aion_clip_path}")
+
+    # Save AION CLIP embeddings
+    with h5py.File(aion_clip_path, 'w') as f:
+        # Object IDs
+        dt = h5py.special_dtype(vlen=str)
+        f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)
+
+        # Coordinates and metadata
+        ra_values = np.array([galaxy_data[oid]['ra'] for oid in object_ids])
+        dec_values = np.array([galaxy_data[oid]['dec'] for oid in object_ids])
+        healpix_values = np.array([galaxy_data[oid]['healpix'] for oid in object_ids])
+
+        f.create_dataset('ra', data=ra_values, dtype=np.float64)
+        f.create_dataset('dec', data=dec_values, dtype=np.float64)
+        f.create_dataset('healpix', data=healpix_values, dtype=np.int64)
+
+        # AION CLIP embeddings
+        f.create_dataset('AION_clip_embedding', data=aion_clip_embeddings, dtype=np.float32)
+
+        # Metadata
+        f.attrs['description'] = 'AION embeddings encoded through trained CLIP model'
+        f.attrs['embedding_dim'] = aion_clip_embeddings.shape[1]
+        f.attrs['num_objects'] = len(object_ids)
+        f.attrs['created'] = datetime.now().isoformat()
+
+    logger.info(f"Saving text CLIP embeddings to: {text_clip_path}")
+
+    # Save text CLIP embeddings
+    with h5py.File(text_clip_path, 'w') as f:
+        # Object IDs
+        dt = h5py.special_dtype(vlen=str)
+        f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)
+
+        # Coordinates and metadata (use text data for consistency)
+        ra_values = np.array([text_data[oid]['ra'] for oid in object_ids])
+        dec_values = np.array([text_data[oid]['dec'] for oid in object_ids])
+        healpix_values = np.array([text_data[oid]['healpix'] for oid in object_ids])
+
+        f.create_dataset('ra', data=ra_values, dtype=np.float64)
+        f.create_dataset('dec', data=dec_values, dtype=np.float64)
+        f.create_dataset('healpix', data=healpix_values, dtype=np.int64)
+
+        # Text CLIP embeddings
+        f.create_dataset('text_clip_embedding', data=text_clip_embeddings, dtype=np.float32)
+
+        # Metadata
+        f.attrs['description'] = 'Text embeddings encoded through trained CLIP model'
+        f.attrs['embedding_dim'] = text_clip_embeddings.shape[1]
+        f.attrs['num_objects'] = len(object_ids)
+        f.attrs['created'] = datetime.now().isoformat()
+
+    return aion_clip_path, text_clip_path
+
+
+def inspect_generated_files(aion_clip_path, text_clip_path):
+    """Inspect the generated HDF5 files."""
+    logger.info("Inspecting generated AION CLIP embeddings file...")
+
+    with h5py.File(aion_clip_path, 'r') as f:
+        logger.info(f"AION file datasets: {list(f.keys())}")
+        for key in f.keys():
+            dataset = f[key]
+            logger.info(f"  {key}: shape={dataset.shape}, dtype={dataset.dtype}")
+        logger.info(f"  Attributes: {dict(f.attrs)}")
+
+    logger.info("Inspecting generated text CLIP embeddings file...")
+
+    with h5py.File(text_clip_path, 'r') as f:
+        logger.info(f"Text file datasets: {list(f.keys())}")
+        for key in f.keys():
+            dataset = f[key]
+            logger.info(f"  {key}: shape={dataset.shape}, dtype={dataset.dtype}")
+        logger.info(f"  Attributes: {dict(f.attrs)}")
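Reading the files back is symmetric; a short sketch with `h5py` (dataset names match the writer above):

```python
import h5py

with h5py.File("data/processed/galaxy_aion_clip_embeddings.hdf5", "r") as f:
    print(dict(f.attrs))                 # description, embedding_dim, num_objects, created
    emb = f["AION_clip_embedding"][:]    # float32, (num_objects, embedding_dim)
    # Variable-length strings may come back as bytes depending on h5py version.
    ids = [o.decode() if isinstance(o, bytes) else o for o in f["object_id"][:]]
```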
clip/utils/logging_utils.py ADDED
@@ -0,0 +1,42 @@
+"""Logging utilities."""
+
+import logging
+import sys
+from pathlib import Path
+
+
+def setup_logging(log_level: str = "INFO", log_file: str = None):
+    """
+    Setup logging configuration.
+
+    Args:
+        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
+        log_file: Optional path to log file
+    """
+    # Clear any existing handlers
+    logging.getLogger().handlers.clear()
+
+    # Create formatter
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    )
+
+    # Console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(formatter)
+
+    # Setup root logger
+    logger = logging.getLogger()
+    logger.setLevel(getattr(logging, log_level.upper()))
+    logger.addHandler(console_handler)
+
+    # File handler if specified
+    if log_file:
+        log_path = Path(log_file)
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = logging.FileHandler(log_path)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    return logger
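Example call; the file handler's parent directory is created if needed:

```python
from clip.utils.logging_utils import setup_logging

logger = setup_logging(log_level="DEBUG", log_file="logs/train.log")
logger.info("logging configured")   # goes to stdout and logs/train.log
```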
main.py ADDED
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from aion-search!")
+
+
+if __name__ == "__main__":
+    main()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+dash==2.14.1
+dash-bootstrap-components==1.5.0
+h5py==3.10.0
+numpy==1.24.3
+openai==1.10.0
+httpx==0.26.0
+gunicorn==21.2.0
+huggingface-hub==0.20.1
+pandas==2.0.3
+faiss-cpu==1.7.4
+python-dotenv==1.1.1
+torch --index-url https://download.pytorch.org/whl/cpu
+requests
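All application dependencies are pinned except `torch` and `requests`; the `--index-url` option points pip at the PyTorch CPU wheel index so the image avoids pulling the much larger CUDA build, matching the commit message's "torch-cpu". A plain `pip install -r requirements.txt` reproduces the environment locally.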
src/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""AION Search - Galaxy Semantic Search Application."""
+
+__version__ = "0.2.0"
src/callbacks.py
ADDED
|
@@ -0,0 +1,775 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Dash callbacks for AION Search."""

import json
import time
import logging
import traceback
import pandas as pd
import dash
from dash import Input, Output, State, callback_context, html
import dash_bootstrap_components as dbc

import src.config as config
from src.config import (
    DEFAULT_DISPLAY_COUNT,
    LOAD_MORE_COUNT,
    IMAGE_HEIGHT,
    IMAGE_WIDTH,
    ZILLIZ_PRIMARY_KEY,
)
from src.components import create_vector_input_row
from src.services import SearchService

logger = logging.getLogger(__name__)


def register_callbacks(app, search_service: SearchService):
    """Register all Dash callbacks with the app.

    Args:
        app: Dash app instance
        search_service: SearchService instance for performing searches
    """

    @app.callback(
        Output("galaxy-count", "children"),
        # Using the component's own static `id` prop as the Input makes this
        # callback fire once when the page loads
        Input("galaxy-count", "id")
    )
    def update_galaxy_count(_):
        """Update the galaxy count display."""
        if search_service and config.TOTAL_GALAXIES > 0:
            return f"{config.TOTAL_GALAXIES:,} galaxies"
        else:
            return "loading..."

    @app.callback(
        [Output("vector-collapse", "is_open"),
         Output("vector-arrow", "className")],
        Input("vector-toggle", "n_clicks"),
        State("vector-collapse", "is_open"),
        prevent_initial_call=True
    )
    def toggle_vector_section(n_clicks, is_open):
        """Toggle the vector addition section."""
        new_state = not is_open
        arrow_class = "fas fa-chevron-up" if new_state else "fas fa-chevron-down"
        return new_state, arrow_class

    @app.callback(
        Output("vector-inputs", "children", allow_duplicate=True),
        Input({"type": "vector-delete", "index": dash.dependencies.ALL}, "n_clicks"),
        State("vector-inputs", "children"),
        prevent_initial_call=True
    )
    def delete_vector_input(n_clicks_list, current_children):
        """Handle deletion of vector input rows."""
        if not n_clicks_list or not any(n_clicks_list):
            return dash.no_update

        ctx = callback_context
        if not ctx.triggered:
            return dash.no_update

        if ctx.triggered[0]["value"] is None or ctx.triggered[0]["value"] == 0:
            return dash.no_update

        button_id = ctx.triggered[0]["prop_id"]
        index_to_delete = json.loads(button_id.split(".")[0])["index"]

        logger.info(f"Delete button clicked for index: {index_to_delete}")

        # Filter out the row with the matching index
        new_children = []
        for child in current_children:
            should_keep = True

            if isinstance(child, dict):
                if 'props' in child and 'id' in child['props']:
                    child_id = child['props']['id']
                    if isinstance(child_id, dict) and child_id.get("type") == "vector-row" and child_id.get("index") == index_to_delete:
                        should_keep = False
            elif hasattr(child, 'id') and isinstance(child.id, dict):
                if child.id.get("type") == "vector-row" and child.id.get("index") == index_to_delete:
                    should_keep = False

            if should_keep:
                new_children.append(child)

        # Ensure at least one input remains
        if len(new_children) == 0:
            new_children = [create_vector_input_row(0)]

        return new_children
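
    # A note on the pattern-matching IDs parsed above (standard Dash behavior,
    # not project-specific): for a component id like
    # {"type": "vector-delete", "index": 3}, ctx.triggered reports the prop_id
    # as the JSON-serialized id followed by the property name, e.g.
    #     '{"index":3,"type":"vector-delete"}.n_clicks'
    # which is why the index is recovered with
    # json.loads(prop_id.split(".")[0])["index"].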

    @app.callback(
        [Output("vector-inputs", "children"),
         Output("vector-inputs-count", "data")],
        Input("add-vector-input", "n_clicks"),
        [State("vector-inputs", "children"),
         State("vector-inputs-count", "data")],
        prevent_initial_call=True
    )
    def add_vector_input(n_clicks, current_children, count):
        """Add a new vector input row."""
        if n_clicks:
            new_input = create_vector_input_row(count)
            current_children.append(new_input)
            return current_children, count + 1

        return dash.no_update, dash.no_update

    @app.callback(
        [Output({"type": "text-input-container", "index": dash.dependencies.ALL}, "style"),
         Output({"type": "image-input-container", "index": dash.dependencies.ALL}, "style")],
        Input({"type": "vector-query-type", "index": dash.dependencies.ALL}, "value"),
        prevent_initial_call=False
    )
    def toggle_query_type_inputs(query_types):
        """Toggle visibility of text vs image inputs based on query type selection."""
        text_styles = []
        image_styles = []

        for query_type in query_types:
            if query_type == "text":
                text_styles.append({"display": "block"})
                image_styles.append({"display": "none"})
            else:  # image
                text_styles.append({"display": "none"})
                image_styles.append({"display": "block"})

        return text_styles, image_styles

    @app.callback(
        [Output("search-button", "n_clicks"),
         Output("search-input", "value")],
        [Input("example-1", "n_clicks"),
         Input("example-2", "n_clicks"),
         Input("example-3", "n_clicks"),
         Input("example-4", "n_clicks"),
         Input("example-5", "n_clicks"),
         Input("example-6", "n_clicks"),
         Input("example-7", "n_clicks")],
        [State("search-button", "n_clicks")],
        prevent_initial_call=True
    )
    def trigger_search_from_examples(click1, click2, click3, click4, click5, click6, click7, current_clicks):
        """Trigger search when example buttons are clicked."""
        ctx = callback_context
        if not ctx.triggered:
            return dash.no_update, dash.no_update

        button_id = ctx.triggered[0]["prop_id"].split(".")[0]

        example_queries = {
            "example-1": "Merging edge-on galaxy",
            "example-2": "A peculiar interacting galaxy system featuring plenty of tidal tails and a disturbed morphology",
            "example-3": "a faint tidal stream wrapping around",
            "example-4": "Strong gravitational lens",
            "example-5": "A violent merger in progress with visible tidal features",
            "example-6": "Low surface brightness",
            "example-7": "Ring galaxy"
        }

        search_query = example_queries.get(button_id, "")

        if search_query:
            # Incrementing the search button's n_clicks programmatically
            # re-triggers perform_search with the example query
            return (current_clicks or 0) + 1, search_query

        return dash.no_update, dash.no_update

    @app.callback(
        [Output("search-time", "children"),
         Output("search-results", "children"),
         Output("search-data", "data"),
         Output("download-button", "disabled")],
        [Input("search-button", "n_clicks"),
         Input("search-input", "n_submit")],
        [State("search-input", "value"),
         State("rmag-slider", "value")],
        prevent_initial_call=True
    )
    def perform_search(n_clicks, n_submit, query, rmag_range):
        """Perform text search."""
        if not query or not query.strip():
            return "", dbc.Alert("Please enter a search query", color="warning"), None, True

        try:
            # Extract min and max from slider range
            rmag_min, rmag_max = rmag_range if rmag_range else (None, None)

            start_time = time.time()
            df = search_service.search_text(query, rmag_min=rmag_min, rmag_max=rmag_max)
            search_time = time.time() - start_time
            logger.info(f"Text search for {query!r} took {search_time:.2f}s")

            # Log query to XML/CSV
            from src.utils import build_query_xml, log_query_to_csv
            query_xml = build_query_xml(
                text_queries=[query],
                text_weights=[1.0],
                rmag_min=rmag_min,
                rmag_max=rmag_max
            )
            log_query_to_csv(query_xml)

            # Build results grid - only render the first DEFAULT_DISPLAY_COUNT images
            grid_items = build_galaxy_grid(df.head(DEFAULT_DISPLAY_COUNT))

            # Prepare data for store
            search_data = prepare_search_data(df, query)

            # Create load more button
            load_more_button = create_load_more_button(len(df), DEFAULT_DISPLAY_COUNT) if len(df) > DEFAULT_DISPLAY_COUNT else None

            # Build filter description (the slider's full range is treated as "no filter")
            filter_desc = ""
            if rmag_min is not None and rmag_max is not None and (rmag_min != 13.0 or rmag_max != 20.0):
                filter_desc = f" + r-mag: [{rmag_min:.1f}, {rmag_max:.1f}]"

            # Build complete results container
            results_container = html.Div([
                html.P(f"Top {len(df)} matching galaxies (showing {min(DEFAULT_DISPLAY_COUNT, len(df))})",
                       className="results-header mb-2 text-center"),
                html.P(f"'{query}'{filter_desc}",
                       className="text-center mb-3",
                       style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}),
                dbc.Row(grid_items, justify="center", id="search-results-grid"),
                load_more_button
            ])

            return "", results_container, search_data, False

        except Exception as e:
            error_msg = dbc.Alert(f"Search failed: {str(e)}", color="danger")
            logger.error(f"Search error: {e}")
            logger.error(f"Full traceback:\n{traceback.format_exc()}")
            return "", error_msg, None, True

    @app.callback(
        [Output("galaxy-modal", "is_open"),
         Output("modal-title", "children"),
         Output("modal-image", "children"),
         Output("modal-description", "children"),
         Output("current-galaxy-data", "data")],
        [Input({"type": "galaxy-image", "index": dash.dependencies.ALL}, "n_clicks"),
         Input("close-modal", "n_clicks")],
        [State("galaxy-modal", "is_open"),
         State("search-data", "data")],
        prevent_initial_call=True
    )
    def toggle_modal(image_clicks, close_click, is_open, search_data):
        """Toggle galaxy detail modal."""
        ctx = callback_context

        if not ctx.triggered:
            return False, "", "", "", None

        if ctx.triggered[0]["prop_id"] == "close-modal.n_clicks":
            return False, "", "", "", None

        if search_data:
            triggered_prop = ctx.triggered[0]["prop_id"]
            triggered_value = ctx.triggered[0]["value"]

            if triggered_value is None or triggered_value == 0:
                return False, "", "", "", None

            if "galaxy-image" in triggered_prop:
                try:
                    prop_dict = json.loads(triggered_prop.split(".n_clicks")[0])
                    clicked_idx = prop_dict["index"]

                    if clicked_idx < len(search_data["ra"]):
                        galaxy_info = extract_galaxy_info(search_data, clicked_idx)
                        image_element, description_element = build_modal_content(galaxy_info)

                        galaxy_data = {
                            ZILLIZ_PRIMARY_KEY: galaxy_info[ZILLIZ_PRIMARY_KEY],
                            "ra": galaxy_info["ra"],
                            "dec": galaxy_info["dec"],
                            "distance": galaxy_info["distance"],
                            "r_mag": galaxy_info["r_mag"]
                        }

                        return (
                            True,
                            f"Galaxy at RA={galaxy_info['ra']:.6f}, Dec={galaxy_info['dec']:.6f}",
                            image_element,
                            description_element,
                            galaxy_data
                        )
                except (KeyError, ValueError) as e:
                    # Malformed pattern-matching id or stale store data;
                    # fall through to a closed modal (json.JSONDecodeError is a ValueError)
                    logger.warning(f"Could not open galaxy modal: {e}")

        return False, "", "", "", None

    @app.callback(
        Output("info-modal", "is_open"),
        [Input("info-button", "n_clicks"),
         Input("close-info-modal", "n_clicks")],
        State("info-modal", "is_open"),
        prevent_initial_call=True
    )
    def toggle_info_modal(info_click, close_click, is_open):
        """Toggle info modal."""
        ctx = callback_context
        if ctx.triggered:
            button_id = ctx.triggered[0]["prop_id"].split(".")[0]
            if button_id == "info-button":
                return True
            elif button_id == "close-info-modal":
                return False
        return is_open

    @app.callback(
        [Output("search-results", "children", allow_duplicate=True),
         Output("search-data", "data", allow_duplicate=True)],
        Input("load-more-button", "n_clicks"),
        State("search-data", "data"),
        prevent_initial_call=True
    )
    def load_more_galaxies(n_clicks, search_data):
        """Load more galaxies when the load more button is clicked."""
        if n_clicks and search_data and "loaded_count" in search_data:
            current_count = search_data["loaded_count"]
            total_count = len(search_data["ra"])
            next_count = min(current_count + LOAD_MORE_COUNT, total_count)

            # Build ALL grid items (existing + new)
            all_grid_items = []
            for i in range(next_count):
                galaxy_info = extract_galaxy_info(search_data, i)
                grid_item = build_galaxy_card(galaxy_info, i)
                all_grid_items.append(grid_item)

            search_data["loaded_count"] = next_count

            load_more_button = create_load_more_button(total_count, next_count) if next_count < total_count else None

            results_container = html.Div([
                html.P(f"Top {total_count} matching galaxies (showing {next_count})",
                       className="results-header mb-2 text-center"),
                html.P(f"'{search_data['query']}'",
                       className="text-center mb-3",
                       style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}),
                dbc.Row(all_grid_items, justify="center", id="search-results-grid"),
                load_more_button
            ])

            return results_container, search_data

        return dash.no_update, dash.no_update

    @app.callback(
        [Output("vector-inputs", "children", allow_duplicate=True),
         Output("vector-inputs-count", "data", allow_duplicate=True),
         Output("vector-collapse", "is_open", allow_duplicate=True),
         Output("galaxy-modal", "is_open", allow_duplicate=True)],
        Input("add-to-advanced-search", "n_clicks"),
        [State("current-galaxy-data", "data"),
         State("vector-inputs", "children"),
         State("vector-inputs-count", "data")],
        prevent_initial_call=True
    )
    def add_galaxy_to_advanced_search(n_clicks, galaxy_data, current_children, count):
        """Add the current galaxy's RA/Dec to advanced search."""
        if not n_clicks or not galaxy_data:
            return dash.no_update, dash.no_update, dash.no_update, dash.no_update

        # Extract galaxy coordinates
        ra = galaxy_data.get('ra')
        dec = galaxy_data.get('dec')

        if ra is None or dec is None:
            return dash.no_update, dash.no_update, dash.no_update, dash.no_update

        # Create a new image input row with the galaxy's RA/Dec pre-filled
        new_row = create_vector_input_row(
            index=count,
            query_type="image",
            ra=ra,
            dec=dec,
            fov=0.025
        )

        current_children.append(new_row)

        # Return updated children, incremented count, open vector panel, close modal
        return current_children, count + 1, True, False

    @app.callback(
        [Output("search-time", "children", allow_duplicate=True),
         Output("search-results", "children", allow_duplicate=True),
         Output("search-data", "data", allow_duplicate=True),
         Output("download-button", "disabled", allow_duplicate=True)],
        Input("vector-search-button", "n_clicks"),
        [State({"type": "vector-query-type", "index": dash.dependencies.ALL}, "value"),
         State({"type": "vector-text", "index": dash.dependencies.ALL}, "value"),
         State({"type": "vector-ra", "index": dash.dependencies.ALL}, "value"),
         State({"type": "vector-dec", "index": dash.dependencies.ALL}, "value"),
         State({"type": "vector-fov", "index": dash.dependencies.ALL}, "value"),
         State({"type": "vector-operation", "index": dash.dependencies.ALL}, "value"),
         State("rmag-slider", "value")],
        prevent_initial_call=True
    )
    def perform_vector_search(n_clicks, query_types, text_values, ra_values, dec_values, fov_values, operations, rmag_range):
        """Perform advanced vector search with multiple text and/or image queries."""
        if not n_clicks:
            return dash.no_update, dash.no_update, dash.no_update, dash.no_update

        def operation_to_weight(op_str):
            """Convert an operation string ("+", "-", "+2", "-5", ...) to a float weight."""
            if op_str == "+":
                return 1.0
            elif op_str == "-":
                return -1.0
            else:
                # For magnitude values like "+2", "-5", etc.
                return float(op_str)

        def weight_to_display(weight):
            """Convert a weight back to its display string."""
            if weight == 1.0:
                return "+"
            elif weight == -1.0:
                return "-"
            elif weight > 0:
                return f"+{int(weight)}"
            else:
                return str(int(weight))

        # Parse inputs to separate text and image queries
        text_queries = []
        text_weights = []
        image_queries = []
        image_weights = []

        for i, query_type in enumerate(query_types):
            operation = operations[i]
            weight = operation_to_weight(operation)

            if query_type == "text":
                text_value = text_values[i]
                if text_value and text_value.strip():
                    text_queries.append(text_value.strip())
                    text_weights.append(weight)
            else:  # image
                ra = ra_values[i]
                dec = dec_values[i]
                fov = fov_values[i] if fov_values[i] else 0.025

                if ra is not None and dec is not None:
                    image_queries.append({
                        'ra': float(ra),
                        'dec': float(dec),
                        'fov': float(fov)
                    })
                    image_weights.append(weight)

        # Validate that we have at least one query
        if not text_queries and not image_queries:
            return "", dbc.Alert("Please enter at least one text or image query", color="warning"), None, True

        try:
            # Extract min and max from slider range
            rmag_min, rmag_max = rmag_range if rmag_range else (None, None)

            # Perform advanced search
            start_time = time.time()
            df = search_service.search_advanced(
                text_queries=text_queries if text_queries else None,
                text_weights=text_weights if text_weights else None,
                image_queries=image_queries if image_queries else None,
                image_weights=image_weights if image_weights else None,
                rmag_min=rmag_min,
                rmag_max=rmag_max
            )
            search_time = time.time() - start_time
            logger.info(f"Advanced search took {search_time:.2f}s")

            # Log query to XML/CSV
            from src.utils import build_query_xml, log_query_to_csv
            query_xml = build_query_xml(
                text_queries=text_queries if text_queries else None,
                text_weights=text_weights if text_weights else None,
                image_queries=image_queries if image_queries else None,
                image_weights=image_weights if image_weights else None,
                rmag_min=rmag_min,
                rmag_max=rmag_max
            )
            log_query_to_csv(query_xml)

            # Build results grid
            grid_items = build_galaxy_grid(df.head(DEFAULT_DISPLAY_COUNT))

            # Build query description for storage (simple text)
            query_desc_parts = []
            for query, weight in zip(text_queries, text_weights):
                op_display = weight_to_display(weight)
                query_desc_parts.append(f"{op_display} text:'{query}'")
            for img_query, weight in zip(image_queries, image_weights):
                op_display = weight_to_display(weight)
                query_desc_parts.append(f"{op_display} image:(RA={img_query['ra']:.2f}, Dec={img_query['dec']:.2f})")
            query_description = " ".join(query_desc_parts)

            # Build query display with thumbnails for images
            from src.utils import cutout_url
            query_display_parts = []
            for query, weight in zip(text_queries, text_weights):
                op_display = weight_to_display(weight)
                query_display_parts.append(html.Span(f"{op_display} text:'{query}' ", style={"margin-right": "8px"}))

            for img_query, weight in zip(image_queries, image_weights):
                op_display = weight_to_display(weight)
                # Generate thumbnail URL
                thumbnail_url = cutout_url(
                    img_query['ra'],
                    img_query['dec'],
                    fov=img_query.get('fov', 0.025),
                    size=64
                )
                query_display_parts.append(html.Span([
                    f"{op_display} ",
                    html.Img(
                        src=thumbnail_url,
                        style={
                            "width": "128px",
                            "height": "128px",
                            "vertical-align": "middle",
                            "margin": "0 4px",
                            "border-radius": "4px",
                            "border": "1px solid rgba(255, 255, 255, 0.2)"
                        }
                    )
                ], style={"margin-right": "8px", "display": "inline-block"}))

            # Build filter description
            filter_desc = ""
            if rmag_min is not None and rmag_max is not None and (rmag_min != 13.0 or rmag_max != 20.0):
                filter_desc = f" + r-mag: [{rmag_min:.1f}, {rmag_max:.1f}]"

            # Prepare data for store
            search_data = prepare_search_data(df, query_description, is_vector_search=True)
            search_data["text_queries"] = text_queries
            search_data["text_weights"] = text_weights
            search_data["image_queries"] = image_queries
            search_data["image_weights"] = image_weights

            # Create load more button
            load_more_button = create_load_more_button(len(df), DEFAULT_DISPLAY_COUNT) if len(df) > DEFAULT_DISPLAY_COUNT else None

            # Build results container
            results_container = html.Div([
                html.P(f"Top {len(df)} matching galaxies (showing {min(DEFAULT_DISPLAY_COUNT, len(df))})",
                       className="results-header mb-2 text-center"),
                html.P(
                    query_display_parts + ([f"{filter_desc}"] if filter_desc else []),
                    className="text-center mb-3",
                    style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}
                ),
                dbc.Row(grid_items, justify="center", id="search-results-grid"),
                load_more_button
            ])

            return "", results_container, search_data, False

        except Exception as e:
            error_msg = dbc.Alert(f"Advanced search failed: {str(e)}", color="danger")
            logger.error(f"Advanced search error: {e}")
            logger.error(f"Full traceback:\n{traceback.format_exc()}")
            return "", error_msg, None, True
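
    # The actual combination of the weighted queries happens inside
    # SearchService.search_advanced (src/services.py, not shown in this section).
    # A minimal sketch of what such a combination step could look like, assuming
    # the query embeddings arrive as numpy vectors (illustration only, not the
    # services.py code):
    #
    #     import numpy as np
    #     combined = sum(w * e for e, w in zip(embeddings, weights))
    #     combined = combined / np.linalg.norm(combined)  # re-normalize before the vector-DB lookup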

    @app.callback(
        Output("download-csv", "data"),
        Input("download-button", "n_clicks"),
        State("search-data", "data"),
        prevent_initial_call=True
    )
    def download_csv(n_clicks, search_data):
        """Download search results as CSV."""
        if n_clicks and search_data:
            # Create DataFrame with the search results
            df = pd.DataFrame({
                ZILLIZ_PRIMARY_KEY: search_data[ZILLIZ_PRIMARY_KEY],
                'ra': search_data['ra'],
                'dec': search_data['dec'],
                'r_mag': search_data['r_mag'],
                'distance': search_data['distance'],
                'cutout_url': search_data['cutout_url']
            })

            # Create CSV string
            csv_string = df.to_csv(index=False)

            # Return download data
            return dict(content=csv_string, filename="galaxy_search_results.csv")

        return dash.no_update


# Helper functions for callbacks

def build_galaxy_grid(df: pd.DataFrame) -> list:
    """Build galaxy grid items from DataFrame.

    Args:
        df: DataFrame with galaxy data

    Returns:
        List of Dash components
    """
    grid_items = []
    # enumerate() keeps the component index 0-based even if the DataFrame
    # carries a non-default index, so it stays consistent with the
    # position-based lookups in toggle_modal
    for i, (_, row) in enumerate(df.iterrows()):
        galaxy_info = {
            ZILLIZ_PRIMARY_KEY: row[ZILLIZ_PRIMARY_KEY],
            "ra": row['ra'],
            "dec": row['dec'],
            "distance": row['distance'],
            "r_mag": row['r_mag'],
            "cutout_url": row['cutout_url']
        }
        grid_item = build_galaxy_card(galaxy_info, i)
        grid_items.append(grid_item)
    return grid_items


def build_galaxy_card(galaxy_info: dict, index: int):
    """Build a single galaxy card component.

    Args:
        galaxy_info: Dictionary with galaxy information
        index: Index of the galaxy in the results

    Returns:
        Dash Bootstrap Col component
    """
    return dbc.Col([
        html.Div([
            html.Div([
                html.Img(
                    src=galaxy_info["cutout_url"],
                    style={
                        "width": IMAGE_WIDTH,
                        "height": IMAGE_HEIGHT,
                        "object-fit": "cover",
                        "cursor": "pointer",
                        "border-radius": "8px"
                    },
                    id={"type": "galaxy-image", "index": index},
                    className="hover-shadow"
                ),
                html.Div([
                    html.Small(f"r = {galaxy_info['r_mag']:.2f} mag", className="score-badge")
                ], style={
                    "position": "absolute",
                    "bottom": "8px",
                    "right": "8px"
                })
            ], style={"position": "relative"})
        ])
    ], width=6, md=4, lg=2, className="mb-2 px-1")


def prepare_search_data(df: pd.DataFrame, query: str, is_vector_search: bool = False) -> dict:
    """Prepare search data for storage.

    Args:
        df: DataFrame with search results
        query: Search query string
        is_vector_search: Whether this is a vector search

    Returns:
        Dictionary with search data
    """
    return {
        ZILLIZ_PRIMARY_KEY: df[ZILLIZ_PRIMARY_KEY].tolist(),
        "ra": df['ra'].tolist(),
        "dec": df['dec'].tolist(),
        "distance": df['distance'].tolist(),
        "r_mag": df['r_mag'].tolist(),
        "cutout_url": df['cutout_url'].tolist(),
        "loaded_count": DEFAULT_DISPLAY_COUNT,
        "query": query,
        "is_vector_search": is_vector_search
    }


def extract_galaxy_info(search_data: dict, index: int) -> dict:
    """Extract galaxy info from search data at given index.

    Args:
        search_data: Dictionary with search data
        index: Index of the galaxy

    Returns:
        Dictionary with galaxy information
    """
    return {
        ZILLIZ_PRIMARY_KEY: search_data[ZILLIZ_PRIMARY_KEY][index],
        "ra": search_data["ra"][index],
        "dec": search_data["dec"][index],
        "distance": search_data["distance"][index],
        "r_mag": search_data["r_mag"][index],
        "cutout_url": search_data["cutout_url"][index]
    }


def build_modal_content(galaxy_info: dict) -> tuple:
    """Build modal image and description content.

    Args:
        galaxy_info: Dictionary with galaxy information

    Returns:
        Tuple of (image_element, description_element)
    """
    image_element = html.Img(
        src=galaxy_info["cutout_url"],
        style={"width": "100%", "max-width": "500px", "height": "auto"}
    )

    # Format primary key label (convert snake_case to Title Case)
    pk_label = ZILLIZ_PRIMARY_KEY.replace("_", " ").title()

    description_element = html.Div([
        html.Div([
            html.Span(f"{pk_label}: {galaxy_info[ZILLIZ_PRIMARY_KEY]}", className="d-inline-block mb-0",
                      style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
        ], className="mb-2"),
        html.Div([
            html.Span(f"RA: {galaxy_info['ra']:.6f}", className="d-inline-block mb-0",
                      style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
            html.Span(" • ", className="mx-2", style={"color": "rgba(245, 245, 247, 0.5)"}),
            html.Span(f"Dec: {galaxy_info['dec']:.6f}", className="d-inline-block mb-0",
                      style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
        ], className="mb-2"),
        html.Div([
            html.Span(f"r_mag: {galaxy_info['r_mag']:.2f}", className="d-inline-block mb-0",
                      style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
            html.Span(" • ", className="mx-2", style={"color": "rgba(245, 245, 247, 0.5)"}),
            html.Span(f"Distance: {galaxy_info['distance']:.4f}", className="d-inline-block mb-0",
                      style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
        ], className="mb-3"),
    ])

    return image_element, description_element


def create_load_more_button(total_count: int, current_count: int):
    """Create a load more button.

    Args:
        total_count: Total number of results
        current_count: Number of currently loaded results

    Returns:
        Dash Bootstrap Button component
    """
    remaining = total_count - current_count
    button_text = f"Load next {min(LOAD_MORE_COUNT, remaining)} galaxies"

    return dbc.Button(
        button_text,
        id="load-more-button",
        color="secondary",
        className="mt-3",
        style={"width": "100%"}
    )
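

# A minimal sketch of how this module is presumably wired up from app.py (the
# real app.py is part of this commit but not reproduced in this section;
# everything below except register_callbacks and SearchService is an assumption):
#
#     import dash
#     import dash_bootstrap_components as dbc
#     from src.callbacks import register_callbacks
#     from src.services import SearchService
#
#     app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
#     server = app.server  # WSGI entry point for gunicorn
#     register_callbacks(app, SearchService())  # however SearchService is actually constructed
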
src/components.py ADDED @@ -0,0 +1,821 @@

"""UI components for AION Search."""

from dash import dcc, html
import dash_bootstrap_components as dbc
from src.config import TOTAL_GALAXIES


def get_app_theme() -> str:
    """Get the custom CSS theme for the app.

    Returns:
        HTML string with embedded CSS
    """
    return '''
<!DOCTYPE html>
<html>
<head>
    {%metas%}
    <title>galaxy semantic search</title>
    {%favicon%}
    {%css%}
    <style>
        @import url('https://fonts.googleapis.com/css2?family=SF+Pro+Display:wght@200;300;400;500;600&display=swap');

        * {
            -webkit-font-smoothing: antialiased;
            -moz-osx-font-smoothing: grayscale;
        }

        body {
            font-family: -apple-system, BlinkMacSystemFont, 'SF Pro Display', 'Inter', sans-serif;
            background: #000000;
            color: #F5F5F7;
            min-height: 100vh;
            margin: 0;
            overflow-x: hidden;
        }

        body::before {
            content: '';
            position: fixed;
            top: -50%;
            left: -50%;
            width: 200%;
            height: 200%;
            background: radial-gradient(circle at 20% 80%, #1C1C1E 0%, transparent 50%),
                        radial-gradient(circle at 80% 20%, #161618 0%, transparent 50%),
                        radial-gradient(circle at 40% 40%, #0A0A0B 0%, transparent 50%);
            z-index: -1;
        }

        .container-fluid {
            background-color: transparent !important;
            padding-top: 2rem !important;
        }

        .hover-shadow {
            transition: all 0.4s cubic-bezier(0.25, 0.46, 0.45, 0.94);
            border: 0.5px solid rgba(255, 255, 255, 0.1);
            background: #0A0A0B;
            overflow: hidden;
            position: relative;
        }

        .hover-shadow::before {
            content: '';
            position: absolute;
            top: 0;
            left: 0;
            right: 0;
            bottom: 0;
            background: linear-gradient(135deg, rgba(255,255,255,0.05) 0%, transparent 100%);
            opacity: 0;
            transition: opacity 0.4s ease;
        }

        .hover-shadow:hover {
            transform: translateY(-4px) scale(1.02);
            box-shadow: 0 20px 40px rgba(0, 0, 0, 0.8),
                        0 0 60px rgba(255, 255, 255, 0.05) !important;
            border-color: rgba(255, 255, 255, 0.2);
        }

        .hover-shadow:hover::before {
            opacity: 1;
        }

        .search-container {
            background: rgba(255, 255, 255, 0.05);
            backdrop-filter: blur(40px) saturate(180%);
            -webkit-backdrop-filter: blur(40px) saturate(180%);
            border-radius: 16px;
            padding: 1.25rem;
            border: 0.5px solid rgba(255, 255, 255, 0.1);
            box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4),
                        inset 0 1px 0 rgba(255, 255, 255, 0.1);
        }

        .example-button {
            background: #E5E5E7 !important;
            background-color: #E5E5E7 !important;
            border: 0.5px solid #D1D1D3 !important;
            color: #1A1A1A !important;
            font-weight: 500;
            font-size: 0.75rem !important;
            padding: 0.4rem 0.9rem !important;
            transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
            letter-spacing: 0.01em;
        }

        .example-button:hover {
            background: #F0F0F2 !important;
            background-color: #F0F0F2 !important;
            border-color: #C0C0C2 !important;
            color: #000000 !important;
            transform: translateY(-1px);
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
        }

        .example-button i {
            color: #2A2A2A !important;
        }

        .galaxy-title {
            color: #F5F5F7;
            font-weight: 200;
            font-size: 1.75rem;
            letter-spacing: -0.03em;
            background: linear-gradient(180deg, #F5F5F7 0%, rgba(245, 245, 247, 0.6) 100%);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            animation: float 6s ease-in-out infinite;
        }

        .modal-content {
            background: #1C1C1E;
            border: 0.5px solid rgba(255, 255, 255, 0.1);
            border-radius: 16px;
            backdrop-filter: blur(20px);
        }

        .modal-header, .modal-footer {
            border-color: rgba(255, 255, 255, 0.05);
        }

        .form-control:focus, .form-control:active {
            background-color: rgba(255, 255, 255, 0.05) !important;
            border-color: rgba(255, 255, 255, 0.3) !important;
            box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.05) !important;
            color: #F5F5F7 !important;
        }

        .form-control {
            background-color: rgba(255, 255, 255, 0.03) !important;
            border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
            color: #F5F5F7 !important;
            font-size: 0.95rem !important;
            font-weight: 300;
            letter-spacing: 0.01em;
        }

        .form-control::placeholder {
            color: rgba(245, 245, 247, 0.4) !important;
        }

        .btn-primary {
            background: rgba(255, 255, 255, 0.8);
            color: #000;
            border: none;
            font-weight: 600;
            font-size: 0.9rem;
            padding: 0.6rem 1.8rem;
            transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
            letter-spacing: 0.02em;
        }

        .btn-primary:hover {
            background: rgba(255, 255, 255, 0.95);
            transform: translateY(-1px);
            box-shadow: 0 8px 24px rgba(255, 255, 255, 0.15);
        }

        .results-header {
            color: rgba(245, 245, 247, 0.6);
            font-weight: 300;
            font-size: 0.85rem !important;
            letter-spacing: 0.05em;
            text-transform: uppercase;
        }

        .time-breakdown {
            color: rgba(245, 245, 247, 0.4);
            font-size: 0.7rem;
            font-weight: 300;
            letter-spacing: 0.02em;
        }

        .galaxy-count {
            color: rgba(245, 245, 247, 0.5);
            font-weight: 300;
            font-size: 0.85rem;
            letter-spacing: 0.05em;
            text-transform: uppercase;
        }

        .score-badge {
            background: rgba(255, 255, 255, 0.1);
            backdrop-filter: blur(10px);
            color: rgba(245, 245, 247, 0.9);
            font-size: 0.65rem !important;
            padding: 3px 8px !important;
            border-radius: 6px;
            font-weight: 500;
            letter-spacing: 0.02em;
            border: 0.5px solid rgba(255, 255, 255, 0.1);
        }

        .info-button {
            color: rgba(245, 245, 247, 0.5) !important;
            font-size: 0.75rem !important;
            opacity: 0.8;
            transition: all 0.3s ease;
            letter-spacing: 0.02em;
        }

        .info-button:hover {
            opacity: 1;
            color: #F5F5F7 !important;
        }

        ::-webkit-scrollbar {
            width: 8px;
        }

        ::-webkit-scrollbar-track {
            background: rgba(255, 255, 255, 0.02);
        }

        ::-webkit-scrollbar-thumb {
            background: rgba(255, 255, 255, 0.1);
            border-radius: 4px;
        }

        ::-webkit-scrollbar-thumb:hover {
            background: rgba(255, 255, 255, 0.2);
        }

        .btn-link {
            text-decoration: none !important;
        }

        .input-group-text {
            background: rgba(255, 255, 255, 0.03) !important;
            border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
            color: rgba(245, 245, 247, 0.5) !important;
        }

        .spinner-border {
            color: rgba(245, 245, 247, 0.5) !important;
        }

        @supports (backdrop-filter: blur(40px)) {
            .search-container {
                background: rgba(255, 255, 255, 0.03);
            }
        }

        @keyframes float {
            0%, 100% { transform: translateY(0px); }
            50% { transform: translateY(-3px); }
        }

        .download-button {
            background: rgba(255, 255, 255, 0.05);
            border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
            color: rgba(245, 245, 247, 0.6) !important;
            font-size: 0.75rem !important;
            padding: 0.4rem 0.8rem !important;
            transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
            letter-spacing: 0.01em;
            margin-left: 0.5rem;
        }

        .download-button:hover {
            background: rgba(255, 255, 255, 0.08) !important;
            border-color: rgba(255, 255, 255, 0.15) !important;
            color: rgba(245, 245, 247, 0.8) !important;
            transform: translateY(-1px);
        }

        .download-button i {
            color: rgba(245, 245, 247, 0.6) !important;
        }

        .download-button:hover i {
            color: rgba(245, 245, 247, 0.8) !important;
        }

        .refinement-toggle {
            background: rgba(255, 255, 255, 0.03);
            border: 0.5px solid rgba(255, 255, 255, 0.1);
            color: rgba(245, 245, 247, 0.5);
            font-size: 0.7rem;
            padding: 0.5rem 0.75rem;
            transition: all 0.3s ease;
            cursor: pointer;
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 0.4rem;
            margin: 0;
            letter-spacing: 0.05em;
            text-transform: uppercase;
            border-radius: 8px;
            height: 100%;
        }

        .refinement-toggle:hover {
            background: rgba(255, 255, 255, 0.05);
            border-color: rgba(255, 255, 255, 0.15);
            color: rgba(245, 245, 247, 0.8);
            transform: translateY(-1px);
        }

        .refinement-toggle i {
            transition: transform 0.3s ease;
            font-size: 0.65rem;
        }

        .refinement-toggle.expanded i {
            transform: rotate(180deg);
        }

        .refinement-container {
            background: rgba(255, 255, 255, 0.03);
            border: 0.5px solid rgba(255, 255, 255, 0.1);
            border-radius: 12px;
            padding: 1.5rem;
            margin-top: 1rem;
            backdrop-filter: blur(20px);
        }

        .refinement-label {
            color: rgba(245, 245, 247, 0.6);
            font-size: 0.85rem;
            font-weight: 300;
            margin-bottom: 0.5rem;
        }

        .vector-operation-select {
            background-color: rgba(255, 255, 255, 0.03) !important;
            border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
            color: #F5F5F7 !important;
            font-size: 0.9rem !important;
            font-weight: 500;
            text-align: center;
        }

        .vector-operation-select option {
            background-color: #1C1C1E;
            color: #F5F5F7;
        }

        .vector-query-type-select {
            background-color: rgba(255, 255, 255, 0.03) !important;
            border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
            color: #F5F5F7 !important;
            font-size: 0.9rem !important;
            font-weight: 500;
        }

        .vector-query-type-select option {
            background-color: #1C1C1E;
            color: #F5F5F7;
        }

        .btn-add-vector {
            background: rgba(255, 255, 255, 0.05);
            border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
            color: rgba(245, 245, 247, 0.6) !important;
            font-size: 0.8rem !important;
            padding: 0.4rem 0.8rem !important;
            transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
            letter-spacing: 0.01em;
        }

        .btn-add-vector:hover {
            background: rgba(255, 255, 255, 0.08) !important;
            border-color: rgba(255, 255, 255, 0.15) !important;
            color: rgba(245, 245, 247, 0.8) !important;
            transform: translateY(-1px);
        }

        .btn-add-vector i {
            color: rgba(245, 245, 247, 0.6) !important;
        }

        .btn-add-vector:hover i {
            color: rgba(245, 245, 247, 0.8) !important;
        }

        .vector-delete-btn {
            opacity: 0.5;
            transition: opacity 0.2s ease;
            padding: 0.25rem 0.5rem !important;
            border: none !important;
            background: none !important;
        }

        .vector-delete-btn:hover {
            opacity: 1;
            background: rgba(220, 53, 69, 0.1) !important;
            border-radius: 4px;
        }

        .vector-delete-btn i {
            font-size: 0.9rem;
        }

        /* Range slider styling */
        .rmag-slider .rc-slider-rail {
            background-color: rgba(255, 255, 255, 0.1);
            height: 4px;
        }

        .rmag-slider .rc-slider-track {
            background: linear-gradient(90deg, rgba(255, 255, 255, 0.6) 0%, rgba(255, 255, 255, 0.8) 100%);
            height: 4px;
        }

        .rmag-slider .rc-slider-handle {
            border: 2px solid rgba(255, 255, 255, 0.8);
            background-color: #F5F5F7;
            opacity: 1;
            width: 16px;
            height: 16px;
            margin-top: -6px;
        }

        .rmag-slider .rc-slider-handle:hover,
        .rmag-slider .rc-slider-handle:active,
        .rmag-slider .rc-slider-handle:focus {
            border-color: rgba(255, 255, 255, 0.95);
            box-shadow: 0 0 0 5px rgba(255, 255, 255, 0.1);
        }

        .rmag-slider .rc-slider-mark-text {
            color: rgba(245, 245, 247, 0.5);
            font-size: 0.75rem;
            font-weight: 300;
        }

        .rmag-slider .rc-slider-tooltip-inner {
            background-color: rgba(255, 255, 255, 0.9);
            color: #000;
            font-size: 0.75rem;
            font-weight: 500;
            padding: 4px 8px;
            border-radius: 4px;
            box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
        }

        .rmag-slider .rc-slider-tooltip-arrow {
            border-top-color: rgba(255, 255, 255, 0.9);
        }
    </style>
</head>
<body>
    {%app_entry%}
    <footer>
        {%config%}
        {%scripts%}
        {%renderer%}
    </footer>
</body>
</html>
'''
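

# get_app_theme() returns a Dash index-page template: {%metas%}, {%favicon%},
# {%css%}, {%app_entry%}, {%config%}, {%scripts%} and {%renderer%} are the
# standard Dash index_string placeholders, so app.py presumably assigns
# app.index_string = get_app_theme() when constructing the app (an assumption;
# app.py is not shown in this section).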


def create_header():
    """Create the app header with title and galaxy count."""
    return dbc.Row([
        dbc.Col([
            html.Div([
                html.H1("galaxy semantic search", className="galaxy-title text-center mb-1"),
                html.Div(id="galaxy-count", className="galaxy-count text-center")
            ], className="text-center mb-3")
        ])
    ])


def create_rmag_filter_panel():
    """Create the r_mag filter panel."""
    return dbc.Row([
        dbc.Col([
            html.Div([
                dbc.Row([
                    dbc.Col([
                        html.Div("r-mag",
                                 style={"color": "rgba(245, 245, 247, 0.6)",
                                        "font-size": "0.85rem",
                                        "font-weight": "300",
                                        "text-align": "center"})
                    ], width=1, className="d-flex align-items-center justify-content-center", style={"padding": "0"}),
                    dbc.Col([
                        dcc.RangeSlider(
                            id="rmag-slider",
                            min=13.0,
                            max=20.0,
                            step=0.1,
                            value=[13.0, 20.0],
                            marks={13: '13', 15: '15', 17: '17', 19: '19', 20: '20'},
                            tooltip={"placement": "bottom", "always_visible": True},
                            className="rmag-slider"
                        )
                    ], width=11)
                ], className="align-items-center", style={"margin": "0"})
            ], style={"padding": "0.5rem 0"})
        ], width=12)
    ], className="mt-2")
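
# Note: the slider's full range [13.0, 20.0] doubles as "no filter" in the UI —
# the search callbacks in src/callbacks.py only render an r-mag filter
# description when the selected bounds differ from these defaults.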


def create_search_container():
    """Create the main search input container with examples and input."""
    return dbc.Row([
        dbc.Col([
            html.Div([
                # Info button in top right
                html.Div([
                    dbc.Button([
                        html.I(className="fas fa-info-circle")
                    ], id="info-button", color="link", size="sm",
                       className="info-button")
                ], style={"position": "absolute", "top": "8px", "right": "8px", "z-index": "1000"}),

                # Example search buttons
                html.Div([
                    html.P("Try these examples:", className="text-center mb-2",
                           style={"color": "rgba(245, 245, 247, 0.5)", "font-weight": "300",
                                  "font-size": "0.75rem", "letter-spacing": "0.02em"}),
                    html.Div([
                        dbc.Button([html.I(className="fas fa-satellite me-2"), "Merging edge-on galaxy"],
                                   id="example-1", className="example-button me-2 mb-2", size="sm", color="light"),
                        dbc.Button([html.I(className="fas fa-water me-2"), "Tidal"],
                                   id="example-2", className="example-button me-2 mb-2", size="sm", color="light"),
                        dbc.Button([html.I(className="fas fa-stream me-2"), "Stream"],
                                   id="example-3", className="example-button me-2 mb-2", size="sm", color="light"),
                        dbc.Button([html.I(className="fas fa-glasses me-2"), "Gravitational lens"],
                                   id="example-4", className="example-button me-2 mb-2", size="sm", color="light"),
                        dbc.Button([html.I(className="fas fa-explosion me-2"), "A violent merger"],
                                   id="example-5", className="example-button me-2 mb-2", size="sm", color="light"),
                        dbc.Button([html.I(className="fas fa-moon me-2"), "Low surface brightness"],
                                   id="example-6", className="example-button me-2 mb-2", size="sm", color="light"),
                        dbc.Button([html.I(className="fas fa-ring me-2"), "Ring galaxy"],
                                   id="example-7", className="example-button mb-2", size="sm", color="light")
                    ], className="text-center")
                ], className="mb-3"),

                # Search input
                dbc.InputGroup([
                    dbc.InputGroupText(html.I(className="fas fa-search")),
                    dbc.Input(
                        id="search-input",
                        placeholder="Describe the galaxy you're looking for...",
                        type="text",
                        n_submit=0
                    ),
                    dbc.Button("Search",
                               id="search-button", color="primary", n_clicks=0),
                    dbc.Button([
                        html.I(className="fas fa-download")
                    ], id="download-button", color="secondary", n_clicks=0,
                       className="download-button", size="sm",
                       disabled=True)
                ])
            ], className="search-container", style={"position": "relative"}),

            # r_mag filter
            create_rmag_filter_panel(),

            # Vector Addition toggle button
            dbc.Row([
                dbc.Col([
                    html.Button([
                        html.I(className="fas fa-chevron-down", id="vector-arrow"),
                        "Advanced Search (Vector Addition / Images)"
                    ], id="vector-toggle", className="refinement-toggle w-100")
                ], width=12)
            ], className="mt-3"),

            # Vector Addition UI - Collapsible section
            create_vector_addition_panel()
        ], width=12, lg=11, className="mx-auto")
    ], className="mb-3")
| 597 |
+
def create_vector_addition_panel():
|
| 598 |
+
"""Create the advanced search (vector addition) collapsible panel."""
|
| 599 |
+
return dbc.Collapse([
|
| 600 |
+
html.Div([
|
| 601 |
+
html.P("Advanced Search: Combine multiple text and/or image queries using vector addition/subtraction:", className="refinement-label"),
|
| 602 |
+
html.Div(id="vector-inputs", children=[
|
| 603 |
+
# Initial input
|
| 604 |
+
create_vector_input_row(0)
|
| 605 |
+
]),
|
| 606 |
+
dbc.Row([
|
| 607 |
+
dbc.Col([
|
| 608 |
+
dbc.Button(
|
| 609 |
+
[html.I(className="fas fa-plus me-2"), "Add Query"],
|
| 610 |
+
id="add-vector-input",
|
| 611 |
+
color="secondary",
|
| 612 |
+
size="sm",
|
| 613 |
+
className="me-2 btn-add-vector"
|
| 614 |
+
)
|
| 615 |
+
], width=6),
|
| 616 |
+
dbc.Col([
|
| 617 |
+
dbc.Button(
|
| 618 |
+
"Advanced Search",
|
| 619 |
+
id="vector-search-button",
|
| 620 |
+
className="btn-primary w-100",
|
| 621 |
+
n_clicks=0
|
| 622 |
+
)
|
| 623 |
+
], width=6)
|
| 624 |
+
], className="mt-3")
|
| 625 |
+
], className="refinement-container")
|
| 626 |
+
], id="vector-collapse", is_open=False)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def create_vector_input_row(index: int, query_type: str = "text", ra: float = None, dec: float = None, fov: float = 0.025):
|
| 630 |
+
"""Create a single vector input row with operation selector, query type toggle, and conditional inputs.
|
| 631 |
+
|
| 632 |
+
Args:
|
| 633 |
+
index: Index of the vector input row
|
| 634 |
+
query_type: Type of query - "text" or "image" (default: "text")
|
| 635 |
+
ra: Initial RA value for image queries (default: None)
|
| 636 |
+
dec: Initial Dec value for image queries (default: None)
|
| 637 |
+
fov: Initial FoV value for image queries (default: 0.025)
|
| 638 |
+
|
| 639 |
+
Returns:
|
| 640 |
+
Dash Bootstrap Row component with text/image mode toggle
|
| 641 |
+
"""
|
| 642 |
+
# Determine display styles based on query type
|
| 643 |
+
text_display = {"display": "block"} if query_type == "text" else {"display": "none"}
|
| 644 |
+
image_display = {"display": "none"} if query_type == "text" else {"display": "block"}
|
| 645 |
+
|
| 646 |
+
return dbc.Row([
|
| 647 |
+
# Operation column with magnitude support
|
| 648 |
+
dbc.Col([
|
| 649 |
+
dbc.Select(
|
| 650 |
+
id={"type": "vector-operation", "index": index},
|
| 651 |
+
options=[
|
| 652 |
+
{"label": "+10", "value": "+10"},
|
| 653 |
+
{"label": "+5", "value": "+5"},
|
| 654 |
+
{"label": "+2", "value": "+2"},
|
| 655 |
+
{"label": "+", "value": "+"},
|
| 656 |
+
{"label": "-", "value": "-"},
|
| 657 |
+
{"label": "-2", "value": "-2"},
|
| 658 |
+
{"label": "-5", "value": "-5"},
|
| 659 |
+
{"label": "-10", "value": "-10"}
|
| 660 |
+
],
|
| 661 |
+
value="+",
|
| 662 |
+
style={"width": "70px"},
|
| 663 |
+
className="d-inline-block vector-operation-select"
|
| 664 |
+
)
|
| 665 |
+
], width=1),
|
| 666 |
+
# Query type toggle (Text/Image)
|
| 667 |
+
dbc.Col([
|
| 668 |
+
dbc.Select(
|
| 669 |
+
id={"type": "vector-query-type", "index": index},
|
| 670 |
+
options=[
|
| 671 |
+
{"label": "Text", "value": "text"},
|
| 672 |
+
{"label": "Image", "value": "image"}
|
| 673 |
+
],
|
| 674 |
+
value=query_type,
|
| 675 |
+
style={"width": "100px"},
|
| 676 |
+
className="d-inline-block vector-query-type-select"
|
| 677 |
+
)
|
| 678 |
+
], width=2),
|
| 679 |
+
# Input area (text or image fields)
|
| 680 |
+
dbc.Col([
|
| 681 |
+
# Text input (shown when type is "text")
|
| 682 |
+
html.Div([
|
| 683 |
+
dbc.Input(
|
| 684 |
+
id={"type": "vector-text", "index": index},
|
| 685 |
+
placeholder="Enter text query...",
|
| 686 |
+
type="text"
|
| 687 |
+
)
|
| 688 |
+
], id={"type": "text-input-container", "index": index}, style=text_display),
|
| 689 |
+
# Image inputs (shown when type is "image")
|
| 690 |
+
html.Div([
|
| 691 |
+
dbc.Row([
|
| 692 |
+
dbc.Col([
|
| 693 |
+
dbc.Input(
|
| 694 |
+
id={"type": "vector-ra", "index": index},
|
| 695 |
+
placeholder="ra:",
|
| 696 |
+
type="number",
|
| 697 |
+
step="any",
|
| 698 |
+
value=ra
|
| 699 |
+
)
|
| 700 |
+
], width=4),
|
| 701 |
+
dbc.Col([
|
| 702 |
+
dbc.Input(
|
| 703 |
+
id={"type": "vector-dec", "index": index},
|
| 704 |
+
placeholder="dec:",
|
| 705 |
+
type="number",
|
| 706 |
+
step="any",
|
| 707 |
+
value=dec
|
| 708 |
+
)
|
| 709 |
+
], width=4),
|
| 710 |
+
dbc.Col([
|
| 711 |
+
dbc.Input(
|
| 712 |
+
id={"type": "vector-fov", "index": index},
|
| 713 |
+
placeholder="fov:",
|
| 714 |
+
type="number",
|
| 715 |
+
value=fov,
|
| 716 |
+
step="any"
|
| 717 |
+
)
|
| 718 |
+
], width=4)
|
| 719 |
+
])
|
| 720 |
+
], id={"type": "image-input-container", "index": index}, style=image_display)
|
| 721 |
+
], width=8),
|
| 722 |
+
# Delete button
|
| 723 |
+
dbc.Col([
|
| 724 |
+
dbc.Button(
|
| 725 |
+
html.I(className="fas fa-times"),
|
| 726 |
+
id={"type": "vector-delete", "index": index},
|
| 727 |
+
color="link",
|
| 728 |
+
size="sm",
|
| 729 |
+
className="text-danger vector-delete-btn",
|
| 730 |
+
style={"padding": "0.25rem 0.5rem"}
|
| 731 |
+
)
|
| 732 |
+
], width=1, className="d-flex align-items-center justify-content-end")
|
| 733 |
+
], className="mb-2", id={"type": "vector-row", "index": index})
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
def create_results_container():
|
| 737 |
+
"""Create the search results display container."""
|
| 738 |
+
return dbc.Row([
|
| 739 |
+
dbc.Col([
|
| 740 |
+
html.Div(id="search-time", className="time-breakdown text-center mb-2"),
|
| 741 |
+
html.Div(id="search-results")
|
| 742 |
+
])
|
| 743 |
+
])
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def create_stores():
|
| 747 |
+
"""Create Dash Store components for data persistence."""
|
| 748 |
+
return [
|
| 749 |
+
dcc.Store(id="search-data"),
|
| 750 |
+
dcc.Store(id="current-galaxy-data"),
|
| 751 |
+
dcc.Store(id="vector-inputs-count", data=1),
|
| 752 |
+
dcc.Download(id="download-csv")
|
| 753 |
+
]
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def create_galaxy_modal():
|
| 757 |
+
"""Create the modal for displaying galaxy details."""
|
| 758 |
+
return dbc.Modal([
|
| 759 |
+
dbc.ModalHeader(dbc.ModalTitle(id="modal-title")),
|
| 760 |
+
dbc.ModalBody([
|
| 761 |
+
html.Div(id="modal-image", className="text-center mb-3"),
|
| 762 |
+
html.Div(id="modal-description")
|
| 763 |
+
]),
|
| 764 |
+
dbc.ModalFooter([
|
| 765 |
+
dbc.Button(
|
| 766 |
+
[html.I(className="fas fa-plus-circle me-2"), "Add to Advanced Search"],
|
| 767 |
+
id="add-to-advanced-search",
|
| 768 |
+
color="primary",
|
| 769 |
+
className="me-2"
|
| 770 |
+
),
|
| 771 |
+
dbc.Button("Close", id="close-modal", className="ms-auto")
|
| 772 |
+
])
|
| 773 |
+
], id="galaxy-modal", size="lg", is_open=False)
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def create_info_modal():
|
| 777 |
+
"""Create the info modal explaining the app."""
|
| 778 |
+
return dbc.Modal([
|
| 779 |
+
dbc.ModalHeader(dbc.ModalTitle([html.I(className="fas fa-info-circle me-2"), "About Galaxy Search"])),
|
| 780 |
+
dbc.ModalBody([
|
| 781 |
+
html.P("This app performs semantic search over galaxy images using CLIP embeddings and BigQuery.",
|
| 782 |
+
style={"color": "rgba(245, 245, 247, 0.8)", "margin-bottom": "1rem", "font-size": "0.9rem"}),
|
| 783 |
+
html.Div([
|
| 784 |
+
html.P("The search uses contrastive language-image pre-training (CLIP) to match text descriptions with galaxy images. "
|
| 785 |
+
"The model was trained on galaxy descriptions and can understand various astronomical features and characteristics.",
|
| 786 |
+
style={"margin-bottom": "1rem", "color": "rgba(245, 245, 247, 0.7)"}),
|
| 787 |
+
|
| 788 |
+
html.H6("Search Tips:", style={"color": "#F5F5F7", "font-weight": "500", "margin-bottom": "0.5rem"}),
|
| 789 |
+
html.Ul([
|
| 790 |
+
html.Li("Describe morphological features (spiral, elliptical, irregular, merging)",
|
| 791 |
+
style={"color": "rgba(245, 245, 247, 0.6)", "margin-bottom": "0.3rem"}),
|
| 792 |
+
html.Li("Mention specific features (tidal tails, dust lanes, star-forming regions)",
|
| 793 |
+
style={"color": "rgba(245, 245, 247, 0.6)", "margin-bottom": "0.3rem"}),
|
| 794 |
+
html.Li("Use color descriptions or brightness characteristics",
|
| 795 |
+
style={"color": "rgba(245, 245, 247, 0.6)", "margin-bottom": "0.3rem"}),
|
| 796 |
+
html.Li("Combine multiple features for more specific results",
|
| 797 |
+
style={"color": "rgba(245, 245, 247, 0.6)"}),
|
| 798 |
+
], style={"margin-left": "1rem"}),
|
| 799 |
+
], style={"background": "rgba(255, 255, 255, 0.05)", "padding": "1.5rem", "border-radius": "12px",
|
| 800 |
+
"border": "0.5px solid rgba(255, 255, 255, 0.1)", "color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"})
|
| 801 |
+
]),
|
| 802 |
+
dbc.ModalFooter(
|
| 803 |
+
dbc.Button("Close", id="close-info-modal", className="ms-auto")
|
| 804 |
+
)
|
| 805 |
+
], id="info-modal", size="lg", is_open=False)
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def create_layout():
|
| 809 |
+
"""Create the complete app layout.
|
| 810 |
+
|
| 811 |
+
Returns:
|
| 812 |
+
Dash Container with the full app layout
|
| 813 |
+
"""
|
| 814 |
+
return dbc.Container([
|
| 815 |
+
create_header(),
|
| 816 |
+
create_search_container(),
|
| 817 |
+
create_results_container(),
|
| 818 |
+
*create_stores(),
|
| 819 |
+
create_galaxy_modal(),
|
| 820 |
+
create_info_modal()
|
| 821 |
+
], fluid=True, className="py-2")
|
src/config.py
ADDED
@@ -0,0 +1,68 @@
"""Configuration settings, environment variables, and constants."""

import os
from dotenv import load_dotenv

load_dotenv()

# Environment Variables
ZILLIZ_BEARER = os.getenv("ZILLIZ_BEARER")
ZILLIZ_ENDPOINT = os.getenv("ZILLIZ_ENDPOINT")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# App Constants
# Note: TOTAL_GALAXIES is dynamically updated from Zilliz at startup (see app.py);
# this is just a fallback default value.
TOTAL_GALAXIES = 0
DEFAULT_TOP_K = 300
DEFAULT_DISPLAY_COUNT = 60
LOAD_MORE_COUNT = 120

# Zilliz Configuration
ZILLIZ_COLLECTION_NAME = "aionsearch"
# Image search always uses the legacy collection, which has pre-existing embeddings
ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME = ZILLIZ_COLLECTION_NAME

# Collection-specific configurations
COLLECTION_CONFIGS = {
    "legacy_5": {
        "anns_field": "aion_search_embedding",
        "primary_key": "object_id",
        "output_fields": ["object_id", "ra", "dec", "r_mag"]
    },
    "aionsearch": {
        "anns_field": "clip_embedding",
        "primary_key": "ra_dec",
        "output_fields": ["ra_dec", "ra", "dec", "r_mag"]
    }
}

# Get configuration for the selected collection
# (direct indexing: an unknown collection name should fail loudly here)
_collection_config = COLLECTION_CONFIGS[ZILLIZ_COLLECTION_NAME]
ZILLIZ_ANNS_FIELD = _collection_config["anns_field"]
ZILLIZ_PRIMARY_KEY = _collection_config["primary_key"]
ZILLIZ_OUTPUT_FIELDS = _collection_config["output_fields"]

# OpenAI Configuration
OPENAI_EMBEDDING_MODEL = "text-embedding-3-large"

# CLIP Model Configuration
CLIP_EMBEDDING_DIM = 1024
CLIP_NORMALIZE_EPS = 1e-3

# UI Configuration
IMAGE_HEIGHT = "160px"
IMAGE_WIDTH = "100%"
CUTOUT_FOV = 0.025
CUTOUT_SIZE = 256

# Logging Configuration
VCU_COST_PER_MILLION = 4.0  # $4 per 1 million vCU

# Feature Flags (for future features)
FEATURE_IMAGE_SEARCH = False
FEATURE_AUTH = False
FEATURE_CACHE = False
FEATURE_RERANKING = False
FEATURE_TRACKING = True
FEATURE_VECTOR_ADDITION = True
src/services.py
ADDED
@@ -0,0 +1,538 @@
"""Backend services for AION Search."""

import time
import logging
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import requests
from typing import List
from openai import OpenAI

from src.config import (
    ZILLIZ_BEARER,
    ZILLIZ_ENDPOINT,
    ZILLIZ_COLLECTION_NAME,
    ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
    ZILLIZ_ANNS_FIELD,
    ZILLIZ_PRIMARY_KEY,
    ZILLIZ_OUTPUT_FIELDS,
    COLLECTION_CONFIGS,
    OPENAI_API_KEY,
    OPENAI_EMBEDDING_MODEL,
    CLIP_NORMALIZE_EPS,
    DEFAULT_TOP_K,
)
from src.utils import cutout_url, log_zilliz_query

logger = logging.getLogger(__name__)


class CLIPModelService:
    """Service for managing CLIP model loading and inference."""

    def __init__(self):
        self.model = None
        self.device = None
        self.loaded = False

    def load_model(self, checkpoint_path: str) -> None:
        """Load the CLIP model from a checkpoint.

        Args:
            checkpoint_path: Path to the CLIP model checkpoint file
        """
        logger.info(f"Loading CLIP model from {checkpoint_path}...")

        from clip.models.clip_model import GalaxyClipModel

        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        model_config = checkpoint['model_config']

        # Initialize model with saved configuration
        self.model = GalaxyClipModel(
            image_input_dim=model_config['image_input_dim'],
            text_input_dim=model_config['text_input_dim'],
            embedding_dim=model_config['embedding_dim'],
            use_mean_embeddings=model_config.get('use_mean_embeddings', True)
        )

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        self.loaded = True

        logger.info("CLIP model loaded successfully")

    def encode_text(self, text_embedding: np.ndarray) -> np.ndarray:
        """Project a text embedding through the CLIP text projector.

        Args:
            text_embedding: OpenAI text embedding (its dimension must match the
                model's text_input_dim from the checkpoint)

        Returns:
            CLIP-projected embedding (1024-dim)
        """
        if not self.loaded:
            raise RuntimeError("CLIP model not loaded. Call load_model() first.")

        with torch.no_grad():
            text_tensor = torch.from_numpy(text_embedding).float().unsqueeze(0).to(self.device)
            clip_features = self.model.text_projector(text_tensor)
            # Normalize as per CLIP
            clip_features = F.normalize(clip_features, dim=-1, eps=CLIP_NORMALIZE_EPS)
            query_embedding = clip_features.cpu().numpy().squeeze(0)

        return query_embedding
+
class ImageProcessingService:
|
| 95 |
+
"""Service for retrieving pre-existing image embeddings from Zilliz."""
|
| 96 |
+
|
| 97 |
+
def __init__(self):
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
def encode_image(self, ra: float, dec: float, fov: float = 0.025, size: int = 256) -> np.ndarray:
|
| 101 |
+
"""Query Zilliz for pre-existing embedding at the given coordinates.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
ra: Right ascension in degrees
|
| 105 |
+
dec: Declination in degrees
|
| 106 |
+
fov: Field of view in degrees (used to define search box)
|
| 107 |
+
size: Image size in pixels (unused, kept for API compatibility)
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Pre-existing AION-Search embedding vector (1024-dim) from Zilliz
|
| 111 |
+
"""
|
| 112 |
+
logger.info(f"Querying Zilliz for pre-existing embedding at RA={ra}, Dec={dec}, FoV={fov}")
|
| 113 |
+
|
| 114 |
+
# Calculate bounding box based on field of view
|
| 115 |
+
ra_min = ra - fov/2
|
| 116 |
+
ra_max = ra + fov/2
|
| 117 |
+
dec_min = dec - fov/2
|
| 118 |
+
dec_max = dec + fov/2
|
| 119 |
+
|
| 120 |
+
# Build filter expression for coordinate range
|
| 121 |
+
filter_expr = f"ra > {ra_min} AND ra < {ra_max} AND dec > {dec_min} AND dec < {dec_max}"
|
| 122 |
+
|
| 123 |
+
# Get the ANNS field for the image search collection
|
| 124 |
+
image_search_config = COLLECTION_CONFIGS.get(ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME)
|
| 125 |
+
image_anns_field = image_search_config["anns_field"]
|
| 126 |
+
|
| 127 |
+
# Prepare query payload - always use the image search collection (legacy)
|
| 128 |
+
payload = {
|
| 129 |
+
"collectionName": ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
|
| 130 |
+
"filter": filter_expr,
|
| 131 |
+
"outputFields": [image_anns_field],
|
| 132 |
+
"limit": 1
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
headers = {
|
| 136 |
+
"Authorization": f"Bearer {ZILLIZ_BEARER}",
|
| 137 |
+
"Accept": "application/json",
|
| 138 |
+
"Content-Type": "application/json"
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
# Use query endpoint (replace /search with /query)
|
| 143 |
+
query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
|
| 144 |
+
response = requests.post(query_endpoint, json=payload, headers=headers)
|
| 145 |
+
response.raise_for_status()
|
| 146 |
+
|
| 147 |
+
result = response.json()
|
| 148 |
+
|
| 149 |
+
if result.get("code") == 0 and "data" in result:
|
| 150 |
+
data = result["data"]
|
| 151 |
+
if data and len(data) > 0:
|
| 152 |
+
# Extract the embedding from the first result using the image search ANNS field
|
| 153 |
+
embedding = data[0].get(image_anns_field)
|
| 154 |
+
if embedding:
|
| 155 |
+
embedding_array = np.array(embedding, dtype=np.float32)
|
| 156 |
+
logger.info(f"Retrieved pre-existing embedding with shape: {embedding_array.shape}")
|
| 157 |
+
return embedding_array
|
| 158 |
+
else:
|
| 159 |
+
logger.error(f"No embedding field found in result: {data[0].keys()}")
|
| 160 |
+
raise RuntimeError(f"No embedding found at coordinates RA={ra}, Dec={dec}")
|
| 161 |
+
else:
|
| 162 |
+
logger.error(f"No galaxies found at coordinates RA={ra}, Dec={dec} with FoV={fov}")
|
| 163 |
+
raise RuntimeError(f"No galaxies found at coordinates RA={ra}, Dec={dec}")
|
| 164 |
+
else:
|
| 165 |
+
logger.error(f"Zilliz query failed: {result}")
|
| 166 |
+
raise RuntimeError(f"Failed to query Zilliz: {result}")
|
| 167 |
+
|
| 168 |
+
except Exception as e:
|
| 169 |
+
logger.error(f"Error querying Zilliz for embedding: {e}")
|
| 170 |
+
raise
|
| 171 |
+
|
| 172 |
+
|
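In other words, "encoding" an image here is a coordinate lookup against embeddings that already live in Zilliz, not an on-the-fly model pass. A usage sketch, separate from services.py, with placeholder coordinates (any object actually present in the collection would work):

# Placeholder RA/Dec; encode_image raises RuntimeError if no object falls
# inside the fov-sized box around the coordinates.
img_service = ImageProcessingService()
embedding = img_service.encode_image(ra=180.0, dec=0.5, fov=0.025)
print(embedding.shape)  # (1024,) for this collection's embedding field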

class EmbeddingService:
    """Service for encoding text queries into embeddings."""

    def __init__(self, clip_service: CLIPModelService):
        self.clip_service = clip_service
        self.openai_client = None

    def _get_openai_client(self) -> OpenAI:
        """Get or create the OpenAI client."""
        if self.openai_client is None:
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY environment variable not set")
            self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        return self.openai_client

    def encode_text_query(self, query: str) -> np.ndarray:
        """Encode a text query using OpenAI embeddings + the CLIP text projector.

        Args:
            query: Text search query

        Returns:
            CLIP embedding vector
        """
        client = self._get_openai_client()

        # Get OpenAI text embedding
        response = client.embeddings.create(
            input=query,
            model=OPENAI_EMBEDDING_MODEL
        )
        text_embedding = np.array(response.data[0].embedding)

        # Project through the CLIP text projector
        return self.clip_service.encode_text(text_embedding)

    def encode_vector_queries(
        self,
        queries: List[str],
        operations: List[str]
    ) -> np.ndarray:
        """Encode multiple text queries and combine them using vector addition/subtraction.

        Args:
            queries: List of text queries
            operations: List of operations ('+' or '-') for each query

        Returns:
            Combined, normalized embedding vector
        """
        client = self._get_openai_client()

        # Get all embeddings at once for efficiency
        response = client.embeddings.create(
            input=queries,
            model=OPENAI_EMBEDDING_MODEL
        )

        # Initialize combined embedding
        combined_embedding = None

        # Process each embedding with its operation
        for embedding_data, operation in zip(response.data, operations):
            text_embedding = np.array(embedding_data.embedding)

            # Project through the CLIP text projector
            query_embedding = self.clip_service.encode_text(text_embedding)

            # Apply operation
            if combined_embedding is None:
                combined_embedding = query_embedding if operation == "+" else -query_embedding
            else:
                if operation == "+":
                    combined_embedding += query_embedding
                else:
                    combined_embedding -= query_embedding

        # Normalize the final combined embedding
        norm = np.linalg.norm(combined_embedding)
        if norm > 0:
            combined_embedding = combined_embedding / norm

        return combined_embedding
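encode_vector_queries is the backend for the +/- selectors in the advanced-search rows: each query is embedded, projected, signed, summed, and the sum is re-normalized. A minimal sketch of a query-arithmetic call (not part of services.py), assuming a CLIPModelService has already been loaded:

# Sketch: "merging galaxies" minus "spiral arms" as a single query vector.
clip_service = CLIPModelService()
clip_service.load_model("aionsearchmodel.pt")  # checkpoint shipped with this commit

embedding_service = EmbeddingService(clip_service)
combined = embedding_service.encode_vector_queries(
    queries=["merging galaxies", "spiral arms"],
    operations=["+", "-"],
)
# combined is unit-normalized and ready to pass to ZillizService.search().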

class ZillizService:
    """Service for interacting with the Zilliz vector database."""

    def get_collection_count(self) -> int:
        """Get the total number of entities in the collection.

        Returns:
            Total count of entities in the collection (0 on failure)
        """
        logger.info("Getting collection count from Zilliz...")

        # Use the query endpoint with count(*) to get total entities
        payload = {
            "collectionName": ZILLIZ_COLLECTION_NAME,
            "filter": "",  # Empty filter to count all entities
            "outputFields": ["count(*)"]
        }

        headers = {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }

        try:
            # Use the query endpoint (replace /search with /query in the endpoint)
            query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            response = requests.post(query_endpoint, json=payload, headers=headers)
            response.raise_for_status()

            result = response.json()

            if result.get("code") == 0 and "data" in result:
                # The count should be in the response data
                data = result["data"]
                if data and len(data) > 0:
                    count = data[0].get("count(*)", 0)
                    logger.info(f"Collection count: {count:,}")
                    return count

            # Any other shape is a failure; return 0 instead of falling through to None
            logger.error(f"Failed to get collection count: {result}")
            return 0

        except Exception as e:
            logger.error(f"Error getting collection count: {e}")
            return 0

    def search(self, query_embedding: np.ndarray, top_k: int = DEFAULT_TOP_K, filter_expr: str = None) -> pd.DataFrame:
        """Search Zilliz for the top-k most similar galaxies.

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results to return
            filter_expr: Optional filter expression for filtering results

        Returns:
            DataFrame with search results
        """
        logger.info("Querying Zilliz...")
        start_time = time.time()

        # Prepare the search payload
        payload = {
            "collectionName": ZILLIZ_COLLECTION_NAME,
            "data": [query_embedding.tolist()],
            "annsField": ZILLIZ_ANNS_FIELD,
            "limit": top_k,
            "outputFields": ZILLIZ_OUTPUT_FIELDS
        }

        # Add filter if provided
        if filter_expr:
            payload["filter"] = filter_expr
            logger.info(f"Applying filter: {filter_expr}")

        headers = {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }

        try:
            response = requests.post(ZILLIZ_ENDPOINT, json=payload, headers=headers)
            response.raise_for_status()

            result = response.json()

            if result.get("code") == 0 and "data" in result:
                # Extract cost from response
                cost_vcu = result.get("cost", 0)

                # Convert to DataFrame
                data_list = result["data"]
                df = pd.DataFrame(data_list)

                # Add cutout URLs
                if not df.empty:
                    df["cutout_url"] = [cutout_url(ra, dec) for ra, dec in zip(df["ra"], df["dec"])]

                query_time = time.time() - start_time

                # Log the query
                log_zilliz_query(
                    query_type="vector_search",
                    query_info={
                        "top_k": top_k,
                        "embedding_dim": len(query_embedding)
                    },
                    result_count=len(df),
                    query_time=query_time,
                    cost_vcu=cost_vcu
                )

                return df
            else:
                logger.error(f"Zilliz search failed: {result}")
                return pd.DataFrame()

        except Exception as e:
            logger.error(f"Zilliz search error: {e}")
            return pd.DataFrame()

class SearchService:
    """High-level search orchestration service."""

    def __init__(
        self,
        embedding_service: EmbeddingService,
        zilliz_service: ZillizService,
        image_service: ImageProcessingService = None
    ):
        self.embedding_service = embedding_service
        self.zilliz_service = zilliz_service
        self.image_service = image_service

    def _build_rmag_filter(self, rmag_min=None, rmag_max=None) -> str:
        """Build an r_mag filter expression.

        Args:
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            Filter expression string, or None if no filter
        """
        filter_parts = []

        if rmag_min is not None:
            filter_parts.append(f"r_mag >= {rmag_min}")

        if rmag_max is not None:
            filter_parts.append(f"r_mag <= {rmag_max}")

        if filter_parts:
            return " AND ".join(filter_parts)

        return None

    def search_text(self, query: str, top_k: int = DEFAULT_TOP_K, rmag_min=None, rmag_max=None) -> pd.DataFrame:
        """Search galaxies using a text query.

        Args:
            query: Text search query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results
        """
        # Encode query
        query_embedding = self.embedding_service.encode_text_query(query)

        # Build filter
        filter_expr = self._build_rmag_filter(rmag_min, rmag_max)

        # Search Zilliz
        return self.zilliz_service.search(query_embedding, top_k, filter_expr)

    def search_vector(
        self,
        queries: List[str],
        operations: List[str],
        top_k: int = DEFAULT_TOP_K,
        rmag_min=None,
        rmag_max=None
    ) -> pd.DataFrame:
        """Search galaxies using vector addition/subtraction.

        Args:
            queries: List of text queries
            operations: List of operations ('+' or '-') for each query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results
        """
        # Encode and combine vectors
        combined_embedding = self.embedding_service.encode_vector_queries(queries, operations)

        # Build filter
        filter_expr = self._build_rmag_filter(rmag_min, rmag_max)

        # Search Zilliz
        return self.zilliz_service.search(combined_embedding, top_k, filter_expr)

    def search_advanced(
        self,
        text_queries: List[str] = None,
        text_weights: List[float] = None,
        image_queries: List[dict] = None,
        image_weights: List[float] = None,
        top_k: int = DEFAULT_TOP_K,
        rmag_min=None,
        rmag_max=None
    ) -> pd.DataFrame:
        """Search galaxies using advanced vector addition/subtraction with text and/or images.

        Args:
            text_queries: List of text query strings
            text_weights: List of weight magnitudes for text queries (e.g., 1.0, -1.0, 2.0, -5.0)
            image_queries: List of dicts with 'ra', 'dec', 'fov' keys
            image_weights: List of weight magnitudes for image queries (e.g., 1.0, -1.0, 2.0, -5.0)
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results
        """
        combined_embedding = None

        # Process text queries
        if text_queries and len(text_queries) > 0:
            for query, weight in zip(text_queries, text_weights):
                query_embedding = self.embedding_service.encode_text_query(query)

                # Apply weight
                weighted_embedding = query_embedding * weight

                if combined_embedding is None:
                    combined_embedding = weighted_embedding
                else:
                    combined_embedding += weighted_embedding

        # Process image queries
        if image_queries and len(image_queries) > 0:
            if self.image_service is None:
                raise RuntimeError("Image service not initialized")

            for img_query, weight in zip(image_queries, image_weights):
                # Encode image
                image_embedding = self.image_service.encode_image(
                    ra=img_query['ra'],
                    dec=img_query['dec'],
                    fov=img_query.get('fov', 0.025),
                    size=256
                )

                # Apply weight
                weighted_embedding = image_embedding * weight

                if combined_embedding is None:
                    combined_embedding = weighted_embedding
                else:
                    combined_embedding += weighted_embedding

        # Require at least one query; Zilliz search cannot run on an empty embedding
        if combined_embedding is None:
            raise ValueError("search_advanced requires at least one text or image query")

        # Normalize the final combined embedding
        norm = np.linalg.norm(combined_embedding)
        if norm > 0:
            combined_embedding = combined_embedding / norm

        # Build filter
        filter_expr = self._build_rmag_filter(rmag_min, rmag_max)

        # Search Zilliz
        return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
src/utils.py
ADDED
@@ -0,0 +1,195 @@
"""Utility functions for AION Search."""

import json
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, Any

from src.config import CUTOUT_FOV, CUTOUT_SIZE, VCU_COST_PER_MILLION

logger = logging.getLogger(__name__)


def cutout_url(ra: float, dec: float, fov: float = CUTOUT_FOV, size: int = CUTOUT_SIZE) -> str:
    """Generate a Legacy Survey cutout URL from RA/Dec coordinates.

    Args:
        ra: Right Ascension in degrees
        dec: Declination in degrees
        fov: Field of view in degrees
        size: Image size in pixels

    Returns:
        URL string for the cutout image
    """
    return (
        "https://alasky.cds.unistra.fr/hips-image-services/hips2fits"
        "?hips=CDS/P/DESI-Legacy-Surveys/DR10/color"
        f"&ra={ra}&dec={dec}&fov={fov}&width={size}&height={size}&format=jpg"
    )


def log_zilliz_query(
    query_type: str,
    query_info: Dict[str, Any],
    result_count: int,
    query_time: float,
    cost_vcu: int = 0
) -> None:
    """Log Zilliz queries to a file in the logs/ directory.

    Args:
        query_type: Type of query (e.g., "vector_search", "text_search")
        query_info: Dictionary containing query details
        result_count: Number of results returned
        query_time: Query execution time in seconds
        cost_vcu: Cost in vCU units
    """
    logs_dir = Path("logs")
    logs_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    log_file = logs_dir / f"zilliz_query_{timestamp}.json"

    # Convert vCU cost to dollars
    cost_usd = (cost_vcu / 1e6) * VCU_COST_PER_MILLION

    log_data = {
        "timestamp": datetime.now().isoformat(),
        "query_type": query_type,
        "query_info": query_info,
        "result_count": result_count,
        "query_time_seconds": query_time,
        "cost_vCU": cost_vcu,
        "cost_usd": cost_usd
    }

    with open(log_file, 'w') as f:
        json.dump(log_data, f, indent=2)

    logger.info(
        f"Query logged to {log_file} | {result_count} results in {query_time:.3f}s | "
        f"{cost_vcu} vCU (${cost_usd:.6f})"
    )


def format_galaxy_count(count: int) -> str:
    """Format galaxy count with a thousands separator.

    Args:
        count: Number of galaxies

    Returns:
        Formatted string (e.g., "259,636 galaxies")
    """
    return f"{count:,} galaxies"


def build_query_xml(
    text_queries: list = None,
    text_weights: list = None,
    image_queries: list = None,
    image_weights: list = None,
    rmag_min: float = None,
    rmag_max: float = None
) -> str:
    """Build an XML representation of a query according to the aql.md specification.

    Args:
        text_queries: List of text query strings
        text_weights: List of weight magnitudes for text queries (e.g., 1.0, -1.0, 2.0, -5.0)
        image_queries: List of dicts with 'ra', 'dec', 'fov' keys
        image_weights: List of weight magnitudes for image queries (e.g., 1.0, -1.0, 2.0, -5.0)
        rmag_min: Minimum r_mag filter value
        rmag_max: Maximum r_mag filter value

    Returns:
        XML string representation of the query
    """
    xml_parts = ['<query>']

    # Add text queries (note: values are interpolated without XML escaping)
    if text_queries and len(text_queries) > 0:
        xml_parts.append('  <text>')
        for query, weight in zip(text_queries, text_weights):
            xml_parts.append('    <term>')
            xml_parts.append(f'      <weight>{weight}</weight>')
            xml_parts.append(f'      <content>{query}</content>')
            xml_parts.append('    </term>')
        xml_parts.append('  </text>')

    # Add image queries
    if image_queries and len(image_queries) > 0:
        xml_parts.append('  <image>')
        for img_query, weight in zip(image_queries, image_weights):
            xml_parts.append('    <reference>')
            xml_parts.append(f'      <ra>{img_query["ra"]}</ra>')
            xml_parts.append(f'      <dec>{img_query["dec"]}</dec>')
            xml_parts.append(f'      <fov>{img_query["fov"]}</fov>')
            xml_parts.append(f'      <weight>{weight}</weight>')
            xml_parts.append('    </reference>')
        xml_parts.append('  </image>')

    # Add filters
    if rmag_min is not None or rmag_max is not None:
        xml_parts.append('  <filters>')
        if rmag_min is not None and rmag_max is not None:
            xml_parts.append('    <filter>')
            xml_parts.append('      <column>r_mag</column>')
            xml_parts.append('      <operator>between</operator>')
            xml_parts.append(f'      <value_min>{rmag_min}</value_min>')
            xml_parts.append(f'      <value_max>{rmag_max}</value_max>')
            xml_parts.append('    </filter>')
        elif rmag_min is not None:
            xml_parts.append('    <filter>')
            xml_parts.append('      <column>r_mag</column>')
            xml_parts.append('      <operator>gte</operator>')
            xml_parts.append(f'      <value>{rmag_min}</value>')
            xml_parts.append('    </filter>')
        elif rmag_max is not None:
            xml_parts.append('    <filter>')
            xml_parts.append('      <column>r_mag</column>')
            xml_parts.append('      <operator>lte</operator>')
            xml_parts.append(f'      <value>{rmag_max}</value>')
            xml_parts.append('    </filter>')
        xml_parts.append('  </filters>')

    xml_parts.append('</query>')
    return '\n'.join(xml_parts)


def log_query_to_csv(
    query_xml: str,
    csv_path: str = "logs/query_log.csv"
) -> None:
    """Log a query to a CSV file with datetime and XML string.

    Args:
        query_xml: XML string representation of the query
        csv_path: Path to the CSV log file
    """
    import csv

    # Create logs directory if it doesn't exist
    log_dir = Path(csv_path).parent
    log_dir.mkdir(parents=True, exist_ok=True)

    # Prepare log entry
    timestamp = datetime.now().isoformat()

    # Check if file exists to determine if we need to write the header
    file_exists = Path(csv_path).exists()

    # Append to CSV
    with open(csv_path, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)

        # Write header if file is new
        if not file_exists:
            writer.writerow(['datetime', 'query'])

        # Write the query log
        writer.writerow([timestamp, query_xml])

    logger.info(f"Query logged to {csv_path}")