astronolan Claude committed on
Commit c89f65f · 1 Parent(s): dda65a0

Add AION-Search Dash app for Hugging Face Spaces


- Add complete application code (app.py, src/, clip/)
- Add Dockerfile configured for HF Spaces deployment
- Add requirements.txt with torch-cpu and all dependencies
- Add cleaned model checkpoint (46MB, inference-only)
- Configure for port 7860 with gunicorn

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,16 @@
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ venv/
+ .env
+ .venv
+ *.log
+ .DS_Store
+ tmp/data/processed/*
+ .python-version
+ pyproject.toml
+ uv.lock
+ .claude
Dockerfile ADDED
@@ -0,0 +1,28 @@
+ FROM python:3.9-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY app.py .
+ COPY src/ ./src/
+ COPY clip/ ./clip/
+ COPY aionsearchmodel.pt .
+
+ # Create necessary directories
+ RUN mkdir -p data/processed logs
+
+ # Expose port for Hugging Face Spaces
+ EXPOSE 7860
+
+ # Run the application with gunicorn for production
+ # Increased timeout and workers for HF Spaces
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "1", "--threads", "2", "--timeout", "600", "app:server"]
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: AION Search
- emoji: 🦀
- colorFrom: red
- colorTo: pink
+ emoji: 🌌
+ colorFrom: blue
+ colorTo: purple
  sdk: docker
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ AION-Search
aionsearchmodel.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e91a0b8e1f632165d62aff10dc598674a35a92e28a4312af220a606bd44664f6
+ size 48614488
app.py ADDED
@@ -0,0 +1,137 @@
+ #!/usr/bin/env python3
+ """AION Search - Galaxy Semantic Search Application.
+
+ A Dash web application for semantic search over galaxy images using CLIP embeddings.
+ """
+
+ import os
+ import logging
+ import argparse
+
+ # Fix OpenMP conflict - MUST be set before importing torch/numpy
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+ import dash
+ import dash_bootstrap_components as dbc
+
+ import src.config as config
+ from src.config import FEATURE_VECTOR_ADDITION
+ from src.components import get_app_theme, create_layout
+ from src.services import CLIPModelService, EmbeddingService, ZillizService, SearchService, ImageProcessingService
+ from src.callbacks import register_callbacks
+
+ # Set up logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def create_app(checkpoint_path: str) -> dash.Dash:
+     """Create and configure the Dash application.
+
+     Args:
+         checkpoint_path: Path to the CLIP model checkpoint
+
+     Returns:
+         Configured Dash app instance
+     """
+     # Initialize Dash app
+     app = dash.Dash(
+         __name__,
+         external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.FONT_AWESOME],
+         suppress_callback_exceptions=True
+     )
+
+     # Set custom theme
+     app.index_string = get_app_theme()
+
+     # Set app title
+     app.title = "AION Galaxy Search"
+
+     # Initialize services
+     logger.info("Initializing services...")
+
+     # Load CLIP model
+     clip_service = CLIPModelService()
+     clip_service.load_model(checkpoint_path)
+
+     # Create service instances
+     embedding_service = EmbeddingService(clip_service)
+     zilliz_service = ZillizService()
+
+     # Initialize image processing service for advanced search
+     # (now uses pre-existing embeddings from Zilliz, no model loading needed)
+     image_service = ImageProcessingService()
+     logger.info("Image processing service initialized successfully")
+
+     search_service = SearchService(embedding_service, zilliz_service, image_service)
+
+     # Get actual count from Zilliz and update config
+     actual_count = zilliz_service.get_collection_count()
+     if actual_count > 0:
+         config.TOTAL_GALAXIES = actual_count
+         logger.info(f"Services initialized. Total galaxies: {config.TOTAL_GALAXIES:,}")
+     else:
+         logger.warning(f"Failed to get collection count from Zilliz, using default: {config.TOTAL_GALAXIES:,}")
+
+     # Create app layout
+     app.layout = create_layout()
+
+     # Register callbacks
+     register_callbacks(app, search_service)
+
+     logger.info("App initialization complete!")
+
+     return app
+
+
+ def main():
+     """Main entry point for the application."""
+     parser = argparse.ArgumentParser(description='AION Galaxy Search App')
+     parser.add_argument(
+         '--checkpoint',
+         type=str,
+         default='aionsearchmodel.pt',
+         help='Path to CLIP model checkpoint'
+     )
+     parser.add_argument(
+         '--port',
+         type=int,
+         default=7860,
+         help='Port to run the app on'
+     )
+     parser.add_argument(
+         '--debug',
+         action='store_true',
+         help='Run in debug mode'
+     )
+     parser.add_argument(
+         '--host',
+         type=str,
+         default='0.0.0.0',
+         help='Host to run the app on'
+     )
+
+     args = parser.parse_args()
+
+     # Create and run app
+     logger.info("Starting AION Galaxy Search...")
+     app = create_app(args.checkpoint)
+
+     logger.info(f"Server starting on {args.host}:{args.port}")
+     app.run_server(
+         debug=args.debug,
+         host=args.host,
+         port=args.port
+     )
+
+
+ if __name__ == "__main__":
+     main()
+ else:
+     # When imported by a WSGI server, expose module-level `app`/`server` so the
+     # Dockerfile's gunicorn CMD ("app:server") resolves; uses the bundled checkpoint.
+     app = create_app("aionsearchmodel.pt")
+     server = app.server
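
With the port and entry point above, a minimal smoke test against a running instance looks like this (a sketch; it assumes the app is already serving on localhost:7860 and uses the `requests` package pinned in requirements.txt):

```python
# Smoke-test sketch: confirm the Dash app answers on the HF Spaces port.
import requests

resp = requests.get("http://localhost:7860/", timeout=10)
print(resp.status_code)  # expect 200 once model loading has finished
```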
clip/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """
+ CLIP alignment for galaxy images and text descriptions.
+
+ This package provides tools for training and using CLIP-style alignment
+ between AION galaxy embeddings and text descriptions.
+ """
+
+ __version__ = "0.1.0"
clip/evaluation/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Evaluation utilities for CLIP model."""
+
+ from .inference import ClipInferenceModel
+
+ __all__ = ["ClipInferenceModel"]
clip/evaluation/inference.py ADDED
@@ -0,0 +1,82 @@
+ """
+ Inference utilities for trained CLIP model.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from pathlib import Path
+ from typing import Union, List, Dict, Tuple
+ import logging
+
+ from ..models import GalaxyClipModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class ClipInferenceModel:
+     """Wrapper for using trained CLIP model for inference and search."""
+
+     def __init__(self, model_path: str, device: str = "cpu"):
+         """
+         Initialize inference model.
+
+         Args:
+             model_path: Path to saved model (.pt file)
+             device: Device to use for inference
+         """
+         self.device = torch.device(device)
+
+         # Load model
+         checkpoint = torch.load(model_path, map_location=self.device)
+         model_config = checkpoint['model_config']
+
+         # Create model with same config
+         self.model = GalaxyClipModel(
+             image_input_dim=model_config['image_input_dim'],
+             text_input_dim=model_config['text_input_dim'],
+             embedding_dim=model_config['embedding_dim']
+         )
+
+         # Load weights
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.model.to(self.device)
+         self.model.eval()
+
+         self.config = model_config
+         logger.info(f"Loaded CLIP model on {device}")
+         logger.info(f"Model config: {model_config}")
+
+     def encode_images(self, image_embeddings):
+         """Encode image embeddings to shared space."""
+
+         tensor = torch.as_tensor(image_embeddings, dtype=torch.float, device=self.device)
+
+         if tensor.ndim == 1:
+             tensor = tensor.unsqueeze(0)
+             squeeze = True
+         else:
+             squeeze = False
+
+         with torch.no_grad():
+             # Use image_projector (which normalizes its output internally)
+             out = self.model.image_projector(tensor)
+
+         return out.squeeze(0).cpu() if squeeze else out.cpu()
+
+     def encode_texts(self, text_embeddings):
+         """Encode text embeddings to shared space."""
+
+         tensor = torch.as_tensor(text_embeddings, dtype=torch.float, device=self.device)
+
+         if tensor.ndim == 1:
+             tensor = tensor.unsqueeze(0)
+             squeeze = True
+         else:
+             squeeze = False
+
+         with torch.no_grad():
+             # Use text_projector (which normalizes its output internally)
+             out = self.model.text_projector(tensor)
+
+         return out.squeeze(0).cpu() if squeeze else out.cpu()
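
For reference, a minimal usage sketch for `ClipInferenceModel` (the random arrays are placeholders for real AION and text embeddings; 768/3072 match the defaults below, but the effective dimensions come from the checkpoint's `model_config`):

```python
# Sketch: rank image embeddings against a text query in the shared CLIP space.
import numpy as np
import torch
from clip.evaluation.inference import ClipInferenceModel

model = ClipInferenceModel("aionsearchmodel.pt", device="cpu")

image_embs = np.random.randn(100, 768).astype(np.float32)  # placeholder AION embeddings
text_emb = np.random.randn(3072).astype(np.float32)        # placeholder text embedding

img_proj = model.encode_images(image_embs)  # (100, embedding_dim), unit-norm rows
txt_proj = model.encode_texts(text_emb)     # (embedding_dim,), unit-norm
scores = img_proj @ txt_proj                # cosine similarity per galaxy
print(torch.topk(scores, k=5).indices)      # indices of the 5 closest matches
```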
clip/models/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """CLIP model architecture for galaxy embeddings."""
+
+ from .clip_model import GalaxyClipModel
+ from .projections import CrossAttentionImageProjector, TextProjector
+
+ __all__ = ["GalaxyClipModel", "CrossAttentionImageProjector", "TextProjector"]
clip/models/clip_model.py ADDED
@@ -0,0 +1,118 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Dict
+
+ from .projections import TextProjector, CrossAttentionImageProjector, SimpleImageProjector
+
+ class GalaxyClipModel(nn.Module):
+     """CLIP model for aligning galaxy images and text descriptions."""
+
+     def __init__(
+         self,
+         image_input_dim: int = 768,
+         text_input_dim: int = 3072,
+         embedding_dim: int = 1024,
+         image_hidden_dim: int = 768,
+         text_hidden_dim: int = 1024,
+         dropout: float = 0.1,
+         use_mean_embeddings: bool = True
+     ):
+         """
+         Initialize CLIP model.
+
+         Args:
+             image_input_dim: AION embedding dimension
+             text_input_dim: Text embedding dimension
+             embedding_dim: Shared embedding space dimension
+             image_hidden_dim: Hidden dimension for image projector
+             text_hidden_dim: Hidden dimension for text projector
+             dropout: Dropout rate
+             use_mean_embeddings: Whether using mean embeddings (True) or full embeddings (False)
+         """
+         super().__init__()
+
+         self.embedding_dim = embedding_dim
+         self.use_mean_embeddings = use_mean_embeddings
+
+         # Choose appropriate image projector based on embedding type
+         if use_mean_embeddings:
+             # Simple projector for mean embeddings (1D vectors)
+             self.image_projector = SimpleImageProjector(
+                 input_dim=image_input_dim,
+                 output_dim=embedding_dim,
+                 hidden_dim=image_hidden_dim,
+                 dropout=dropout
+             )
+         else:
+             # Cross-attention projector for full embeddings (2D sequences)
+             self.image_projector = CrossAttentionImageProjector(
+                 input_dim=image_input_dim,
+                 output_dim=embedding_dim,
+                 hidden_dim=image_hidden_dim,
+                 dropout=dropout
+             )
+
+         self.text_projector = TextProjector(
+             input_dim=text_input_dim,
+             output_dim=embedding_dim,
+             hidden_dim=text_hidden_dim,
+             dropout=dropout
+         )
+
+         # Learnable logit scale parameter initialized to standard CLIP temperature 1/0.07
+         # Using log parameterization for numerical stability
+         self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07, dtype=torch.float32)))
+
+     def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+         """
+         Forward pass for CLIP training.
+
+         Args:
+             batch: Dictionary containing 'image_embedding' and 'text_embedding'
+
+         Returns:
+             Dictionary with projected embeddings and logits
+         """
+         image_features = batch['image_embedding']
+         text_features = batch['text_embedding']
+
+         # Project to shared space and normalize
+         image_features = self.image_projector(image_features)
+         text_features = self.text_projector(text_features)
+
+         # Compute similarity matrix with learnable logit scale
+         # Clamp after exp to preserve gradients
+         logit_scale = self.logit_scale.exp().clamp(max=100)
+         logits_per_image = logit_scale * image_features @ text_features.T
+         logits_per_text = logits_per_image.T
+
+         return {
+             'image_features': image_features,
+             'text_features': text_features,
+             'logits_per_image': logits_per_image,
+             'logits_per_text': logits_per_text,
+             'logit_scale': logit_scale
+         }
+
+     def compute_contrastive_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+         """
+         Compute contrastive loss (InfoNCE).
+
+         Args:
+             outputs: Model outputs from forward pass
+
+         Returns:
+             Contrastive loss
+         """
+         logits_per_image = outputs['logits_per_image']
+         logits_per_text = outputs['logits_per_text']
+
+         batch_size = logits_per_image.shape[0]
+         labels = torch.arange(batch_size, device=logits_per_image.device)
+
+         # Cross-entropy loss for both directions
+         loss_i2t = F.cross_entropy(logits_per_image, labels)
+         loss_t2i = F.cross_entropy(logits_per_text, labels)
+
+         return (loss_i2t + loss_t2i) / 2
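
A single contrastive training step with this model looks roughly like the following (a sketch; random tensors stand in for a real batch of mean AION and text embeddings):

```python
# Sketch: one symmetric InfoNCE step with GalaxyClipModel.
import torch
from clip.models import GalaxyClipModel

model = GalaxyClipModel(image_input_dim=768, text_input_dim=3072, embedding_dim=1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

batch = {
    'image_embedding': torch.randn(64, 768),   # placeholder mean AION embeddings
    'text_embedding': torch.randn(64, 3072),   # placeholder text embeddings
}
outputs = model(batch)                          # projected features + scaled logits
loss = model.compute_contrastive_loss(outputs)  # (image->text + text->image) / 2
loss.backward()
optimizer.step()
optimizer.zero_grad()
```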
clip/models/projections.py ADDED
@@ -0,0 +1,270 @@
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ import torch.nn.functional as F
+
+ class TextProjector(nn.Module):
+     """Projects text embeddings to shared space."""
+
+     def __init__(
+         self,
+         input_dim: int = 3072,
+         output_dim: int = 1024,
+         hidden_dim: Optional[int] = None,
+         dropout: float = 0.1,
+         num_layers: int = 4,
+     ):
+         """
+         Initialize text projector.
+
+         Args:
+             input_dim: Dimension of text embeddings (3072)
+             output_dim: Dimension of shared embedding space
+             hidden_dim: Hidden layer dimension (default: 1024)
+             dropout: Dropout rate
+             num_layers: Number of residual layers (default: 4)
+         """
+         super().__init__()
+
+         if hidden_dim is None:
+             hidden_dim = 1024
+
+         self.fc_in = nn.Linear(input_dim, hidden_dim)
+         self.blocks = nn.ModuleList([
+             nn.Sequential(
+                 nn.LayerNorm(hidden_dim),
+                 nn.GELU(),
+                 nn.Dropout(dropout),
+                 nn.Linear(hidden_dim, hidden_dim),
+             ) for _ in range(num_layers)
+         ])
+         self.fc_out = nn.Linear(hidden_dim, output_dim)
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize projection weights."""
+         for module in self.modules():
+             if isinstance(module, nn.Linear):
+                 nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.zeros_(module.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Project text embeddings to shared space.
+
+         Args:
+             x: Text embeddings (batch_size, input_dim)
+
+         Returns:
+             Projected embeddings (batch_size, output_dim)
+         """
+         h = self.fc_in(x)
+         for blk in self.blocks:  # residual MLP stack
+             h = h + blk(h)
+         h = self.fc_out(h)
+         return F.normalize(h, dim=-1, eps=1e-3)
+
+
+ class CrossAttentionImageProjector(nn.Module):
+     """Simplified projector with self-attention + cross-attention."""
+
+     def __init__(
+         self,
+         input_dim: int = 768,
+         output_dim: int = 1024,
+         hidden_dim: Optional[int] = None,
+         dropout: float = 0.1,
+         num_layers: int = 2,  # Kept for compatibility, not used
+         num_heads: int = 4,  # Reduced from 8
+     ):
+         """
+         Initialize simplified cross-attention image projector.
+
+         Args:
+             input_dim: Dimension of AION embeddings (768)
+             output_dim: Dimension of shared embedding space (default: 1024)
+             hidden_dim: Hidden dimension for attention (default: output_dim)
+             dropout: Dropout rate
+             num_layers: Kept for compatibility but not used
+             num_heads: Number of attention heads (reduced to 4)
+         """
+         super().__init__()
+
+         if hidden_dim is None:
+             hidden_dim = output_dim
+
+         self.input_dim = input_dim
+         self.hidden_dim = hidden_dim
+         self.output_dim = output_dim
+
+         # Project input to hidden dim
+         self.input_proj = nn.Linear(input_dim, hidden_dim)
+
+         # Token pooling to reduce sequence length
+         # 576 tokens -> 64 tokens (9x reduction)
+         self.token_pool = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, stride=9, padding=0)
+
+         # Single self-attention layer
+         self.self_attn_norm = nn.LayerNorm(hidden_dim)
+         self.self_attn = nn.MultiheadAttention(
+             embed_dim=hidden_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True
+         )
+
+         # MLP after self-attention
+         self.mlp1_norm = nn.LayerNorm(hidden_dim)
+         self.mlp1 = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim * 2),  # Reduced from 4x
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.Dropout(dropout)
+         )
+
+         # Learned query vector
+         self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))
+
+         # Single cross-attention layer
+         self.cross_attn_norm = nn.LayerNorm(hidden_dim)
+         self.cross_attn = nn.MultiheadAttention(
+             embed_dim=hidden_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True
+         )
+
+         # Final MLP
+         self.final_norm = nn.LayerNorm(hidden_dim)
+         self.final_mlp = nn.Sequential(
+             nn.Linear(hidden_dim, hidden_dim * 2),  # Reduced from 4x
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim * 2, output_dim)
+         )
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize weights."""
+         # Initialize query vector
+         nn.init.normal_(self.query, std=0.02)
+
+         # Initialize other weights
+         for module in self.modules():
+             if isinstance(module, nn.Linear):
+                 nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.zeros_(module.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Project image embeddings to shared space using self-attention + cross-attention.
+
+         Args:
+             x: Image embeddings (batch_size, n_tokens, input_dim)
+
+         Returns:
+             Projected embeddings (batch_size, output_dim)
+         """
+         batch_size = x.shape[0]
+         x = F.normalize(x, dim=-1, eps=1e-6)  # Normalize AION embeddings input (handles [B, N, D])
+
+         # Project input
+         x = self.input_proj(x)  # (B, N, hidden_dim)
+
+         # Pool tokens to reduce sequence length
+         x = x.transpose(1, 2)  # (B, hidden_dim, N)
+         x = self.token_pool(x)  # (B, hidden_dim, N//9)
+         x = x.transpose(1, 2)  # (B, N//9, hidden_dim)
+
+         # Self-attention with residual on pooled tokens
+         x_norm = self.self_attn_norm(x)
+         x_attn, _ = self.self_attn(x_norm, x_norm, x_norm, need_weights=False)
+         x = x + x_attn
+
+         # MLP with residual
+         x = x + self.mlp1(self.mlp1_norm(x))
+
+         # Cross-attention with learned query
+         query = self.query.expand(batch_size, -1, -1)  # (B, 1, hidden_dim)
+         q_norm = self.cross_attn_norm(query)
+         attended, _ = self.cross_attn(q_norm, x, x, need_weights=False)
+         query = query + attended
+
+         # Final processing
+         output = self.final_norm(query).squeeze(1)  # (B, hidden_dim)
+         output = self.final_mlp(output)  # (B, output_dim)
+
+         return F.normalize(output, dim=-1, eps=1e-3)
+
+
+ class SimpleImageProjector(nn.Module):
+     """Simple projector for mean AION embeddings."""
+
+     def __init__(
+         self,
+         input_dim: int = 768,
+         output_dim: int = 1024,
+         hidden_dim: Optional[int] = None,
+         dropout: float = 0.1,
+         num_layers: int = 4,
+     ):
+         """
+         Initialize simple image projector.
+
+         Args:
+             input_dim: Dimension of AION embeddings (768)
+             output_dim: Dimension of shared embedding space
+             hidden_dim: Hidden layer dimension (default: 1024)
+             dropout: Dropout rate
+             num_layers: Number of residual layers (default: 4)
+         """
+         super().__init__()
+
+         if hidden_dim is None:
+             hidden_dim = 1024
+
+         self.fc_in = nn.Linear(input_dim, hidden_dim)
+         self.blocks = nn.ModuleList([
+             nn.Sequential(
+                 nn.LayerNorm(hidden_dim),
+                 nn.GELU(),
+                 nn.Dropout(dropout),
+                 nn.Linear(hidden_dim, hidden_dim),
+             ) for _ in range(num_layers)
+         ])
+         self.fc_out = nn.Linear(hidden_dim, output_dim)
+
+         # Initialize weights
+         self._init_weights()
+
+     def _init_weights(self):
+         """Initialize projection weights."""
+         for module in self.modules():
+             if isinstance(module, nn.Linear):
+                 nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.zeros_(module.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Project image embeddings to shared space.
+
+         Args:
+             x: Image embeddings (batch_size, input_dim)
+
+         Returns:
+             Projected embeddings (batch_size, output_dim)
+         """
+         x = F.normalize(x, dim=-1, eps=1e-6)  # Normalize AION embeddings input
+         h = self.fc_in(x)
+         for blk in self.blocks:  # residual MLP stack
+             h = h + blk(h)
+         h = self.fc_out(h)
+         return F.normalize(h, dim=-1, eps=1e-3)
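
A quick shape check for the two image projectors (a sketch; the 576-token sequence length matches the Conv1d pooling comment above, kernel 9 / stride 9 -> 64 tokens):

```python
# Sketch: verify both projectors map to (batch, output_dim) unit-norm vectors.
import torch
from clip.models.projections import CrossAttentionImageProjector, SimpleImageProjector

cross = CrossAttentionImageProjector(input_dim=768, output_dim=1024)
simple = SimpleImageProjector(input_dim=768, output_dim=1024)

tokens = torch.randn(2, 576, 768)  # full AION token sequences
means = torch.randn(2, 768)        # mean AION embeddings

print(cross(tokens).shape)  # torch.Size([2, 1024])
print(simple(means).shape)  # torch.Size([2, 1024])
```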
clip/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """Utility functions for CLIP training and evaluation."""
+
+ from .logging_utils import setup_logging
+ from .io_utils import save_clip_embeddings_hdf5, inspect_generated_files
+
+ __all__ = [
+     "setup_logging",
+     "save_clip_embeddings_hdf5",
+     "inspect_generated_files"
+ ]
clip/utils/data_loader.py ADDED
@@ -0,0 +1,250 @@
+ """
+ Data loader for multi-text training using unified parquet file with nested text embeddings.
+ This loader handles the new unified format from 05_generate_unified_embeddings.py.
+ """
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import logging
+ from pathlib import Path
+ import random
+
+ logger = logging.getLogger(__name__)
+
+
+ class UnifiedMultiTextDataset(Dataset):
+     """Dataset for unified parquet file with multiple text embeddings per galaxy."""
+
+     def __init__(self, parquet_path, split="train", train_ratio=0.8,
+                  text_sampling_strategy="random", epoch=0, max_train_samples=None,
+                  num_embedding=None):
+         self.parquet_path = Path(parquet_path)
+         self.split = split
+         self.train_ratio = train_ratio
+         self.text_sampling_strategy = text_sampling_strategy
+         self.epoch = epoch
+         self.max_train_samples = max_train_samples
+         self.num_embedding = num_embedding
+
+         # Load the parquet file
+         logger.info(f"Loading unified embeddings from {self.parquet_path}")
+         self.df = pd.read_parquet(self.parquet_path)
+
+         # Create train/val split based on galaxy_index
+         n_samples = len(self.df)
+         indices = np.arange(n_samples)
+         self.seed = 42
+
+         # Deterministic split based on galaxy_index
+         split_mask = []
+         for idx in range(n_samples):
+             galaxy_idx = self.df.iloc[idx]['galaxy_index']
+             # Hash the galaxy index for deterministic assignment
+             sample_hash = hash((galaxy_idx, self.seed)) % 10000 / 10000.0
+             is_train = sample_hash < self.train_ratio
+             split_mask.append(is_train)
+
+         split_mask = np.array(split_mask)
+
+         if split == "train":
+             self.indices = indices[split_mask]
+             # Limit training samples if specified
+             if self.max_train_samples is not None and len(self.indices) > self.max_train_samples:
+                 rng = np.random.RandomState(self.seed)
+                 selected_indices = rng.choice(self.indices, size=self.max_train_samples, replace=False)
+                 self.indices = np.sort(selected_indices)  # Sort for reproducibility
+                 logger.info(f"Limited training set to {self.max_train_samples} samples")
+         else:
+             self.indices = indices[~split_mask]
+
+         logger.info(f"Dataset initialized: {len(self.indices)} samples for {split} split")
+         logger.info(f"Text sampling strategy: {text_sampling_strategy}")
+
+         # Validate num_embedding parameter for specific_summary strategy
+         if text_sampling_strategy == "specific_summary" and num_embedding is None:
+             raise ValueError("num_embedding parameter is required when using 'specific_summary' strategy")
+
+         # Check data structure
+         sample_row = self.df.iloc[0]
+         n_augmented = len(sample_row['augmented_embeddings'])
+         logger.info(f"Each galaxy has 1 original + {n_augmented} augmented embeddings = {1 + n_augmented} total")
+
+         # Validate num_embedding is within valid range
+         if text_sampling_strategy == "specific_summary":
+             total_embeddings = 1 + n_augmented
+             if num_embedding < 0 or num_embedding >= total_embeddings:
+                 raise ValueError(f"num_embedding must be between 0 and {total_embeddings-1}, got {num_embedding}")
+             logger.info(f"Using specific embedding at index {num_embedding}")
+
+     def __len__(self):
+         return len(self.indices)
+
+     def set_epoch(self, epoch):
+         """Set current epoch for round-robin sampling."""
+         self.epoch = epoch
+
+     def _get_all_embeddings_and_sources(self, row):
+         """Combine original and augmented embeddings into single lists."""
+         # Start with original embedding
+         all_embeddings = [np.array(row['text_embedding'], dtype=np.float32)]
+         all_sources = [row['description_sources'][0]]  # 'original'
+
+         # Add augmented embeddings
+         for aug_emb, aug_source in zip(row['augmented_embeddings'], row['description_sources'][1:]):
+             all_embeddings.append(np.array(aug_emb, dtype=np.float32))
+             all_sources.append(aug_source)
+
+         return all_embeddings, all_sources
+
+     def _sample_text_embedding(self, text_embeddings, text_sources, galaxy_idx):
+         """Sample one text embedding from multiple options."""
+         n_texts = len(text_embeddings)
+
+         if self.text_sampling_strategy == "original":
+             # Always use original text (index 0)
+             idx = 0
+         elif self.text_sampling_strategy == "summaries-only":
+             # Only use summaries (exclude original at index 0)
+             if n_texts > 1:
+                 rng = random.Random(galaxy_idx + self.epoch * 1000000)
+                 idx = rng.randint(1, n_texts - 1)  # Start from 1 to exclude original
+             else:
+                 # Fallback to original if no summaries available
+                 idx = 0
+         elif self.text_sampling_strategy == "specific_summary":
+             # Use the specific embedding index provided
+             if self.num_embedding < n_texts:
+                 idx = self.num_embedding
+             else:
+                 # Fallback to original if index out of range
+                 logger.warning(f"Requested embedding index {self.num_embedding} out of range for {n_texts} embeddings, using original")
+                 idx = 0
+         elif self.text_sampling_strategy == "random":
+             # Random sampling with seed based on galaxy_idx and epoch
+             rng = random.Random(galaxy_idx + self.epoch * 1000000)
+             idx = rng.randint(0, n_texts - 1)
+         elif self.text_sampling_strategy == "round-robin":
+             # Cycle through texts based on epoch
+             idx = (self.epoch + galaxy_idx) % n_texts
+         elif self.text_sampling_strategy == "weighted":
+             # Weight towards original (50%) and summaries (50% / n_summaries each)
+             rng = random.Random(galaxy_idx + self.epoch * 1000000)
+             n_summaries = n_texts - 1
+             if n_summaries > 0:
+                 summary_weight = 0.5 / n_summaries
+                 weights = [0.5] + [summary_weight] * n_summaries
+             else:
+                 weights = [1.0]
+             idx = rng.choices(range(n_texts), weights=weights)[0]
+         else:
+             idx = 0  # Default to original
+
+         return text_embeddings[idx], text_sources[idx], idx
+
+     def __getitem__(self, idx):
+         """Get a single sample with randomly selected text embedding."""
+         actual_idx = self.indices[idx]
+         row = self.df.iloc[actual_idx]
+
+         # Get AION embedding
+         aion_embedding = np.array(row['aion_embedding'], dtype=np.float32)
+
+         # Get all text embeddings and sources
+         text_embeddings, text_sources = self._get_all_embeddings_and_sources(row)
+
+         # Sample one text embedding
+         galaxy_idx = row['galaxy_index']
+         selected_text, selected_source, text_idx = self._sample_text_embedding(
+             text_embeddings, text_sources, galaxy_idx
+         )
+
+         # Log selection details periodically (every 100th sample)
+         if idx % 100 == 0:
+             logger.debug(f"Galaxy {galaxy_idx}: Selected {selected_source} (index {text_idx}) from {len(text_sources)} options")
+
+         return {
+             'aion_embedding': torch.from_numpy(aion_embedding),
+             'text_embedding': torch.from_numpy(selected_text),
+             'galaxy_index': galaxy_idx,
+             'text_source': selected_source,
+             'text_index': text_idx,
+             'object_id': row['object_id']
+         }
+
+
+ def create_unified_multi_text_loaders(
+     unified_embeddings_path,
+     batch_size=64,
+     train_ratio=0.8,
+     pin_memory=True,
+     text_sampling_strategy="random",
+     num_workers=4,
+     max_train_samples=None,
+     num_embedding=None,
+     **kwargs
+ ):
+     """
+     Create train and validation data loaders for multi-text training from unified parquet.
+
+     Args:
+         unified_embeddings_path: Path to unified parquet file
+         batch_size: Batch size for training
+         train_ratio: Fraction of samples for training
+         pin_memory: Whether to pin memory for GPU transfer
+         text_sampling_strategy: How to sample text embeddings ("original", "summaries-only", "specific_summary", "random", "round-robin", "weighted")
+         num_workers: Number of data loading workers
+         max_train_samples: Maximum number of training samples (for data scaling experiments)
+         num_embedding: When using "specific_summary" strategy, the index of the embedding to use
+         **kwargs: Additional arguments
+     """
+
+     # Convert to Path
+     parquet_path = Path(unified_embeddings_path)
+
+     if not parquet_path.exists():
+         raise ValueError(f"Unified embeddings file not found: {parquet_path}")
+
+     logger.info(f"Creating unified multi-text data loaders from {parquet_path}")
+     logger.info(f"Batch size: {batch_size}, Workers: {num_workers}")
+     logger.info(f"Text sampling strategy: {text_sampling_strategy}")
+
+     # Create datasets
+     train_dataset = UnifiedMultiTextDataset(
+         parquet_path=parquet_path,
+         split="train",
+         train_ratio=train_ratio,
+         text_sampling_strategy=text_sampling_strategy,
+         max_train_samples=max_train_samples,
+         num_embedding=num_embedding
+     )
+
+     val_dataset = UnifiedMultiTextDataset(
+         parquet_path=parquet_path,
+         split="val",
+         train_ratio=train_ratio,
+         text_sampling_strategy=text_sampling_strategy,
+         num_embedding=num_embedding
+     )
+
+     # Create loaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,  # Shuffle within the train split
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         drop_last=True  # Drop incomplete batches for stable training
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,  # No shuffle for validation
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         drop_last=False
+     )
+
+     return train_loader, val_loader
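
Typical use of the loader factory (a sketch; the parquet path is illustrative — the real file is produced by 05_generate_unified_embeddings.py):

```python
# Sketch: build train/val loaders and peek at one batch.
from clip.utils.data_loader import create_unified_multi_text_loaders

train_loader, val_loader = create_unified_multi_text_loaders(
    "tmp/data/processed/unified_embeddings.parquet",  # hypothetical location
    batch_size=64,
    train_ratio=0.8,
    text_sampling_strategy="random",
    num_workers=4,
)
batch = next(iter(train_loader))
print(batch['aion_embedding'].shape, batch['text_embedding'].shape)
```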
clip/utils/io_utils.py ADDED
@@ -0,0 +1,103 @@
+ """
+ I/O utilities for saving and loading CLIP embeddings.
+ """
+
+ import h5py
+ import numpy as np
+ from pathlib import Path
+ from datetime import datetime
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def save_clip_embeddings_hdf5(
+     object_ids,
+     galaxy_data,
+     text_data,
+     aion_clip_embeddings,
+     text_clip_embeddings,
+     output_dir="data/processed"
+ ):
+     """Save CLIP embeddings to separate HDF5 files."""
+     output_dir = Path(output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # File paths (standardized names)
+     aion_clip_path = output_dir / "galaxy_aion_clip_embeddings.hdf5"
+     text_clip_path = output_dir / "galaxy_text_clip_embeddings.hdf5"
+
+     logger.info(f"Saving AION CLIP embeddings to: {aion_clip_path}")
+
+     # Save AION CLIP embeddings
+     with h5py.File(aion_clip_path, 'w') as f:
+         # Object IDs
+         dt = h5py.special_dtype(vlen=str)
+         f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)
+
+         # Coordinates and metadata
+         ra_values = np.array([galaxy_data[oid]['ra'] for oid in object_ids])
+         dec_values = np.array([galaxy_data[oid]['dec'] for oid in object_ids])
+         healpix_values = np.array([galaxy_data[oid]['healpix'] for oid in object_ids])
+
+         f.create_dataset('ra', data=ra_values, dtype=np.float64)
+         f.create_dataset('dec', data=dec_values, dtype=np.float64)
+         f.create_dataset('healpix', data=healpix_values, dtype=np.int64)
+
+         # AION CLIP embeddings
+         f.create_dataset('AION_clip_embedding', data=aion_clip_embeddings, dtype=np.float32)
+
+         # Metadata
+         f.attrs['description'] = 'AION embeddings encoded through trained CLIP model'
+         f.attrs['embedding_dim'] = aion_clip_embeddings.shape[1]
+         f.attrs['num_objects'] = len(object_ids)
+         f.attrs['created'] = datetime.now().isoformat()
+
+     logger.info(f"Saving text CLIP embeddings to: {text_clip_path}")
+
+     # Save text CLIP embeddings
+     with h5py.File(text_clip_path, 'w') as f:
+         # Object IDs
+         dt = h5py.special_dtype(vlen=str)
+         f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)
+
+         # Coordinates and metadata (use text data for consistency)
+         ra_values = np.array([text_data[oid]['ra'] for oid in object_ids])
+         dec_values = np.array([text_data[oid]['dec'] for oid in object_ids])
+         healpix_values = np.array([text_data[oid]['healpix'] for oid in object_ids])
+
+         f.create_dataset('ra', data=ra_values, dtype=np.float64)
+         f.create_dataset('dec', data=dec_values, dtype=np.float64)
+         f.create_dataset('healpix', data=healpix_values, dtype=np.int64)
+
+         # Text CLIP embeddings
+         f.create_dataset('text_clip_embedding', data=text_clip_embeddings, dtype=np.float32)
+
+         # Metadata
+         f.attrs['description'] = 'Text embeddings encoded through trained CLIP model'
+         f.attrs['embedding_dim'] = text_clip_embeddings.shape[1]
+         f.attrs['num_objects'] = len(object_ids)
+         f.attrs['created'] = datetime.now().isoformat()
+
+     return aion_clip_path, text_clip_path
+
+
+ def inspect_generated_files(aion_clip_path, text_clip_path):
+     """Inspect the generated HDF5 files."""
+     logger.info("Inspecting generated AION CLIP embeddings file...")
+
+     with h5py.File(aion_clip_path, 'r') as f:
+         logger.info(f"AION file datasets: {list(f.keys())}")
+         for key in f.keys():
+             dataset = f[key]
+             logger.info(f"  {key}: shape={dataset.shape}, dtype={dataset.dtype}")
+         logger.info(f"  Attributes: {dict(f.attrs)}")
+
+     logger.info("Inspecting generated text CLIP embeddings file...")
+
+     with h5py.File(text_clip_path, 'r') as f:
+         logger.info(f"Text file datasets: {list(f.keys())}")
+         for key in f.keys():
+             dataset = f[key]
+             logger.info(f"  {key}: shape={dataset.shape}, dtype={dataset.dtype}")
+         logger.info(f"  Attributes: {dict(f.attrs)}")
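
A round-trip sketch for these helpers (toy data; the metadata dicts must expose 'ra', 'dec', and 'healpix' per object ID, as the code above reads):

```python
# Sketch: save two toy CLIP embeddings and inspect the resulting HDF5 files.
import numpy as np
from clip.utils.io_utils import save_clip_embeddings_hdf5, inspect_generated_files

object_ids = ["galaxy_0", "galaxy_1"]
meta = {oid: {"ra": 150.1, "dec": 2.2, "healpix": 0} for oid in object_ids}
aion_clip = np.random.randn(2, 1024).astype(np.float32)  # placeholder embeddings
text_clip = np.random.randn(2, 1024).astype(np.float32)

aion_path, text_path = save_clip_embeddings_hdf5(
    object_ids, meta, meta, aion_clip, text_clip, output_dir="tmp/data/processed"
)
inspect_generated_files(aion_path, text_path)
```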
clip/utils/logging_utils.py ADDED
@@ -0,0 +1,42 @@
+ """Logging utilities."""
+
+ import logging
+ import sys
+ from pathlib import Path
+
+
+ def setup_logging(log_level: str = "INFO", log_file: str = None):
+     """
+     Setup logging configuration.
+
+     Args:
+         log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
+         log_file: Optional path to log file
+     """
+     # Clear any existing handlers
+     logging.getLogger().handlers.clear()
+
+     # Create formatter
+     formatter = logging.Formatter(
+         '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+     )
+
+     # Console handler
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setFormatter(formatter)
+
+     # Setup root logger
+     logger = logging.getLogger()
+     logger.setLevel(getattr(logging, log_level.upper()))
+     logger.addHandler(console_handler)
+
+     # File handler if specified
+     if log_file:
+         log_path = Path(log_file)
+         log_path.parent.mkdir(parents=True, exist_ok=True)
+
+         file_handler = logging.FileHandler(log_path)
+         file_handler.setFormatter(formatter)
+         logger.addHandler(file_handler)
+
+     return logger
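
Usage is a single call at program start (a sketch; the log path is illustrative, and its parent directory is created automatically):

```python
# Sketch: log to both stdout and a file.
from clip.utils.logging_utils import setup_logging

logger = setup_logging(log_level="DEBUG", log_file="logs/clip_training.log")
logger.info("Logging configured")
```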
main.py ADDED
@@ -0,0 +1,6 @@
+ def main():
+     print("Hello from aion-search!")
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ dash==2.14.1
+ dash-bootstrap-components==1.5.0
+ h5py==3.10.0
+ numpy==1.24.3
+ openai==1.10.0
+ httpx==0.26.0
+ gunicorn==21.2.0
+ huggingface-hub==0.20.1
+ pandas==2.0.3
+ faiss-cpu==1.7.4
+ python-dotenv==1.1.1
+ # CPU-only torch wheels (pip does not accept --index-url inline on a requirement line)
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch
+ requests
src/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """AION Search - Galaxy Semantic Search Application."""
+
+ __version__ = "0.2.0"
src/callbacks.py ADDED
@@ -0,0 +1,775 @@
+ """Dash callbacks for AION Search."""
+
+ import json
+ import time
+ import logging
+ import traceback
+ import pandas as pd
+ import dash
+ from dash import Input, Output, State, callback_context, html
+ import dash_bootstrap_components as dbc
+
+ import src.config as config
+ from src.config import (
+     DEFAULT_DISPLAY_COUNT,
+     LOAD_MORE_COUNT,
+     IMAGE_HEIGHT,
+     IMAGE_WIDTH,
+     ZILLIZ_PRIMARY_KEY,
+ )
+ from src.components import create_vector_input_row
+ from src.services import SearchService
+
+ logger = logging.getLogger(__name__)
+
+
+ def register_callbacks(app, search_service: SearchService):
+     """Register all Dash callbacks with the app.
+
+     Args:
+         app: Dash app instance
+         search_service: SearchService instance for performing searches
+     """
+
+     @app.callback(
+         Output("galaxy-count", "children"),
+         Input("galaxy-count", "id")
+     )
+     def update_galaxy_count(_):
+         """Update the galaxy count display."""
+         if search_service and config.TOTAL_GALAXIES > 0:
+             return f"{config.TOTAL_GALAXIES:,} galaxies"
+         else:
+             return "loading..."
+
+     @app.callback(
+         [Output("vector-collapse", "is_open"),
+          Output("vector-arrow", "className")],
+         Input("vector-toggle", "n_clicks"),
+         State("vector-collapse", "is_open"),
+         prevent_initial_call=True
+     )
+     def toggle_vector_section(n_clicks, is_open):
+         """Toggle the vector addition section."""
+         new_state = not is_open
+         arrow_class = "fas fa-chevron-up" if new_state else "fas fa-chevron-down"
+         return new_state, arrow_class
+
+     @app.callback(
+         Output("vector-inputs", "children", allow_duplicate=True),
+         Input({"type": "vector-delete", "index": dash.dependencies.ALL}, "n_clicks"),
+         State("vector-inputs", "children"),
+         prevent_initial_call=True
+     )
+     def delete_vector_input(n_clicks_list, current_children):
+         """Handle deletion of vector input rows."""
+         if not n_clicks_list or not any(n_clicks_list):
+             return dash.no_update
+
+         ctx = callback_context
+         if not ctx.triggered:
+             return dash.no_update
+
+         if ctx.triggered[0]["value"] is None or ctx.triggered[0]["value"] == 0:
+             return dash.no_update
+
+         button_id = ctx.triggered[0]["prop_id"]
+         index_to_delete = json.loads(button_id.split(".")[0])["index"]
+
+         logger.info(f"Delete button clicked for index: {index_to_delete}")
+
+         # Filter out the row with the matching index
+         new_children = []
+         for child in current_children:
+             should_keep = True
+
+             if isinstance(child, dict):
+                 if 'props' in child and 'id' in child['props']:
+                     child_id = child['props']['id']
+                     if isinstance(child_id, dict) and child_id.get("type") == "vector-row" and child_id.get("index") == index_to_delete:
+                         should_keep = False
+             elif hasattr(child, 'id') and isinstance(child.id, dict):
+                 if child.id.get("type") == "vector-row" and child.id.get("index") == index_to_delete:
+                     should_keep = False
+
+             if should_keep:
+                 new_children.append(child)
+
+         # Ensure at least one input remains
+         if len(new_children) == 0:
+             new_children = [create_vector_input_row(0)]
+
+         return new_children
+
+     @app.callback(
+         [Output("vector-inputs", "children"),
+          Output("vector-inputs-count", "data")],
+         Input("add-vector-input", "n_clicks"),
+         [State("vector-inputs", "children"),
+          State("vector-inputs-count", "data")],
+         prevent_initial_call=True
+     )
+     def add_vector_input(n_clicks, current_children, count):
+         """Add a new vector input row."""
+         if n_clicks:
+             new_input = create_vector_input_row(count)
+             current_children.append(new_input)
+             return current_children, count + 1
+
+         return dash.no_update, dash.no_update
+
+     @app.callback(
+         [Output({"type": "text-input-container", "index": dash.dependencies.ALL}, "style"),
+          Output({"type": "image-input-container", "index": dash.dependencies.ALL}, "style")],
+         Input({"type": "vector-query-type", "index": dash.dependencies.ALL}, "value"),
+         prevent_initial_call=False
+     )
+     def toggle_query_type_inputs(query_types):
+         """Toggle visibility of text vs image inputs based on query type selection."""
+         text_styles = []
+         image_styles = []
+
+         for query_type in query_types:
+             if query_type == "text":
+                 text_styles.append({"display": "block"})
+                 image_styles.append({"display": "none"})
+             else:  # image
+                 text_styles.append({"display": "none"})
+                 image_styles.append({"display": "block"})
+
+         return text_styles, image_styles
+
+     @app.callback(
+         [Output("search-button", "n_clicks"),
+          Output("search-input", "value")],
+         [Input("example-1", "n_clicks"),
+          Input("example-2", "n_clicks"),
+          Input("example-3", "n_clicks"),
+          Input("example-4", "n_clicks"),
+          Input("example-5", "n_clicks"),
+          Input("example-6", "n_clicks"),
+          Input("example-7", "n_clicks")],
+         [State("search-button", "n_clicks")],
+         prevent_initial_call=True
+     )
+     def trigger_search_from_examples(click1, click2, click3, click4, click5, click6, click7, current_clicks):
+         """Trigger search when example buttons are clicked."""
+         ctx = callback_context
+         if not ctx.triggered:
+             return dash.no_update, dash.no_update
+
+         button_id = ctx.triggered[0]["prop_id"].split(".")[0]
+
+         example_queries = {
+             "example-1": "Merging edge-on galaxy",
+             "example-2": "A peculiar interacting galaxy system featuring plenty of tidal tails and a disturbed morphology",
+             "example-3": "a faint tidal stream wrapping around",
+             "example-4": "Strong gravitational lens",
+             "example-5": "A violent merger in progress with visible tidal features",
+             "example-6": "Low surface brightness",
+             "example-7": "Ring galaxy"
+         }
+
+         search_query = example_queries.get(button_id, "")
+
+         if search_query:
+             return (current_clicks or 0) + 1, search_query
+
+         return dash.no_update, dash.no_update
+
+     @app.callback(
+         [Output("search-time", "children"),
+          Output("search-results", "children"),
+          Output("search-data", "data"),
+          Output("download-button", "disabled")],
+         [Input("search-button", "n_clicks"),
+          Input("search-input", "n_submit")],
+         [State("search-input", "value"),
+          State("rmag-slider", "value")],
+         prevent_initial_call=True
+     )
+     def perform_search(n_clicks, n_submit, query, rmag_range):
+         """Perform text search."""
+         if not query or not query.strip():
+             return "", dbc.Alert("Please enter a search query", color="warning"), None, True
+
+         try:
+             # Extract min and max from slider range
+             rmag_min, rmag_max = rmag_range if rmag_range else (None, None)
+
+             start_time = time.time()
+             df = search_service.search_text(query, rmag_min=rmag_min, rmag_max=rmag_max)
+             search_time = time.time() - start_time
+
+             # Log query to XML/CSV
+             from src.utils import build_query_xml, log_query_to_csv
+             query_xml = build_query_xml(
+                 text_queries=[query],
+                 text_weights=[1.0],
+                 rmag_min=rmag_min,
+                 rmag_max=rmag_max
+             )
+             log_query_to_csv(query_xml)
+
+             # Build results grid - only load first 60 images
+             grid_items = build_galaxy_grid(df.head(DEFAULT_DISPLAY_COUNT))
+
+             # Prepare data for store
+             search_data = prepare_search_data(df, query)
+
+             # Create load more button
+             load_more_button = create_load_more_button(len(df), DEFAULT_DISPLAY_COUNT) if len(df) > DEFAULT_DISPLAY_COUNT else None
+
+             # Build filter description
+             filter_desc = ""
+             if rmag_min is not None and rmag_max is not None and (rmag_min != 13.0 or rmag_max != 20.0):
+                 filter_desc = f" + r-mag: [{rmag_min:.1f}, {rmag_max:.1f}]"
+
+             # Build complete results container
+             results_container = html.Div([
+                 html.P(f"Top {len(df)} matching galaxies (showing {min(DEFAULT_DISPLAY_COUNT, len(df))})",
+                        className="results-header mb-2 text-center"),
+                 html.P(f"'{query}'{filter_desc}",
+                        className="text-center mb-3",
+                        style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}),
+                 dbc.Row(grid_items, justify="center", id="search-results-grid"),
+                 load_more_button
+             ])
+
+             return "", results_container, search_data, False
+
+         except Exception as e:
+             error_msg = dbc.Alert(f"Search failed: {str(e)}", color="danger")
+             logger.error(f"Search error: {e}")
+             logger.error(f"Full traceback:\n{traceback.format_exc()}")
+             return "", error_msg, None, True
+
+     @app.callback(
+         [Output("galaxy-modal", "is_open"),
+          Output("modal-title", "children"),
+          Output("modal-image", "children"),
+          Output("modal-description", "children"),
+          Output("current-galaxy-data", "data")],
+         [Input({"type": "galaxy-image", "index": dash.dependencies.ALL}, "n_clicks"),
+          Input("close-modal", "n_clicks")],
+         [State("galaxy-modal", "is_open"),
+          State("search-data", "data")],
+         prevent_initial_call=True
+     )
+     def toggle_modal(image_clicks, close_click, is_open, search_data):
+         """Toggle galaxy detail modal."""
+         ctx = callback_context
+
+         if not ctx.triggered:
+             return False, "", "", "", None
+
+         if ctx.triggered[0]["prop_id"] == "close-modal.n_clicks":
+             return False, "", "", "", None
+
+         if search_data:
+             triggered_prop = ctx.triggered[0]["prop_id"]
+             triggered_value = ctx.triggered[0]["value"]
+
+             if triggered_value is None or triggered_value == 0:
+                 return False, "", "", "", None
+
+             if "galaxy-image" in triggered_prop:
+                 try:
+                     prop_dict = json.loads(triggered_prop.split(".n_clicks")[0])
+                     clicked_idx = prop_dict["index"]
+
+                     if clicked_idx < len(search_data["ra"]):
+                         galaxy_info = extract_galaxy_info(search_data, clicked_idx)
+                         image_element, description_element = build_modal_content(galaxy_info)
+
+                         galaxy_data = {
+                             ZILLIZ_PRIMARY_KEY: galaxy_info[ZILLIZ_PRIMARY_KEY],
+                             "ra": galaxy_info["ra"],
+                             "dec": galaxy_info["dec"],
+                             "distance": galaxy_info["distance"],
+                             "r_mag": galaxy_info["r_mag"]
+                         }
+
+                         return (
+                             True,
+                             f"Galaxy at RA={galaxy_info['ra']:.6f}, Dec={galaxy_info['dec']:.6f}",
+                             image_element,
+                             description_element,
+                             galaxy_data
+                         )
+                 except Exception:
+                     pass
+
+         return False, "", "", "", None
+
+     @app.callback(
+         Output("info-modal", "is_open"),
+         [Input("info-button", "n_clicks"),
+          Input("close-info-modal", "n_clicks")],
+         State("info-modal", "is_open"),
+         prevent_initial_call=True
+     )
+     def toggle_info_modal(info_click, close_click, is_open):
+         """Toggle info modal."""
+         ctx = callback_context
+         if ctx.triggered:
+             button_id = ctx.triggered[0]["prop_id"].split(".")[0]
+             if button_id == "info-button":
+                 return True
+             elif button_id == "close-info-modal":
+                 return False
+         return is_open
+
+     @app.callback(
+         [Output("search-results", "children", allow_duplicate=True),
+          Output("search-data", "data", allow_duplicate=True)],
+         Input("load-more-button", "n_clicks"),
+         State("search-data", "data"),
+         prevent_initial_call=True
+     )
+     def load_more_galaxies(n_clicks, search_data):
+         """Load more galaxies when the load more button is clicked."""
+         if n_clicks and search_data and "loaded_count" in search_data:
+             current_count = search_data["loaded_count"]
+             total_count = len(search_data["ra"])
+             next_count = min(current_count + LOAD_MORE_COUNT, total_count)
+
+             # Build ALL grid items (existing + new)
+             all_grid_items = []
+             for i in range(next_count):
+                 galaxy_info = extract_galaxy_info(search_data, i)
+                 grid_item = build_galaxy_card(galaxy_info, i)
+                 all_grid_items.append(grid_item)
+
+             search_data["loaded_count"] = next_count
+
+             load_more_button = create_load_more_button(total_count, next_count) if next_count < total_count else None
+
+             results_container = html.Div([
+                 html.P(f"Top {total_count} matching galaxies (showing {next_count})",
+                        className="results-header mb-2 text-center"),
+                 html.P(f"'{search_data['query']}'",
+                        className="text-center mb-3",
+                        style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}),
+                 dbc.Row(all_grid_items, justify="center", id="search-results-grid"),
+                 load_more_button
+             ])
+
+             return results_container, search_data
+
+         return dash.no_update, dash.no_update
+
+     @app.callback(
+         [Output("vector-inputs", "children", allow_duplicate=True),
+          Output("vector-inputs-count", "data", allow_duplicate=True),
+          Output("vector-collapse", "is_open", allow_duplicate=True),
+          Output("galaxy-modal", "is_open", allow_duplicate=True)],
+         Input("add-to-advanced-search", "n_clicks"),
+         [State("current-galaxy-data", "data"),
+          State("vector-inputs", "children"),
+          State("vector-inputs-count", "data")],
+         prevent_initial_call=True
+     )
+     def add_galaxy_to_advanced_search(n_clicks, galaxy_data, current_children, count):
+         """Add the current galaxy's RA/Dec to advanced search."""
+         if not n_clicks or not galaxy_data:
+             return dash.no_update, dash.no_update, dash.no_update, dash.no_update
+
+         # Extract galaxy coordinates
+         ra = galaxy_data.get('ra')
+         dec = galaxy_data.get('dec')
+
+         if ra is None or dec is None:
+             return dash.no_update, dash.no_update, dash.no_update, dash.no_update
+
+         # Create a new image input row with the galaxy's RA/Dec pre-filled
+         new_row = create_vector_input_row(
+             index=count,
+             query_type="image",
+             ra=ra,
+             dec=dec,
+             fov=0.025
+         )
+
+         current_children.append(new_row)
+
+         # Return updated children, incremented count, open vector panel, close modal
+         return current_children, count + 1, True, False
+
+     @app.callback(
+         [Output("search-time", "children", allow_duplicate=True),
+          Output("search-results", "children", allow_duplicate=True),
+          Output("search-data", "data", allow_duplicate=True),
+          Output("download-button", "disabled", allow_duplicate=True)],
+         Input("vector-search-button", "n_clicks"),
+         [State({"type": "vector-query-type", "index": dash.dependencies.ALL}, "value"),
+          State({"type": "vector-text", "index": dash.dependencies.ALL}, "value"),
+          State({"type": "vector-ra", "index": dash.dependencies.ALL}, "value"),
+          State({"type": "vector-dec", "index": dash.dependencies.ALL}, "value"),
+          State({"type": "vector-fov", "index": dash.dependencies.ALL}, "value"),
+          State({"type": "vector-operation", "index": dash.dependencies.ALL}, "value"),
+          State("rmag-slider", "value")],
+         prevent_initial_call=True
+     )
+     def perform_vector_search(n_clicks, query_types, text_values, ra_values, dec_values, fov_values, operations, rmag_range):
+         """Perform advanced vector search with multiple text and/or image queries."""
+         if not n_clicks:
+             return dash.no_update, dash.no_update, dash.no_update, dash.no_update
+
+         def operation_to_weight(op_str):
+             """Convert operation string to float weight."""
+             if op_str == "+":
+                 return 1.0
+             elif op_str == "-":
+                 return -1.0
+             else:
+                 # For magnitude values like "+2", "-5", etc.
+                 return float(op_str)
+
+         def weight_to_display(weight):
+             """Convert weight back to display string."""
+             if weight == 1.0:
+                 return "+"
+             elif weight == -1.0:
+                 return "-"
+             elif weight > 0:
+                 return f"+{int(weight)}"
+             else:
+                 return str(int(weight))
+
+         # Parse inputs to separate text and image queries
+         text_queries = []
+         text_weights = []
+         image_queries = []
+         image_weights = []
+
+         for i, query_type in enumerate(query_types):
+             operation = operations[i]
+             weight = operation_to_weight(operation)
+
+             if query_type == "text":
+                 text_value = text_values[i]
+                 if text_value and text_value.strip():
+                     text_queries.append(text_value.strip())
+                     text_weights.append(weight)
+             else:  # image
+                 ra = ra_values[i]
+                 dec = dec_values[i]
+                 fov = fov_values[i] if fov_values[i] else 0.025
+
+                 if ra is not None and dec is not None:
+                     image_queries.append({
+                         'ra': float(ra),
+                         'dec': float(dec),
+                         'fov': float(fov)
+                     })
+                     image_weights.append(weight)
+
+         # Validate that we have at least one query
+         if not text_queries and not image_queries:
+             return "", dbc.Alert("Please enter at least one text or image query", color="warning"), None, True
+
+         try:
+             # Extract min and max from slider range
+             rmag_min, rmag_max = rmag_range if rmag_range else (None, None)
+
+             # Perform advanced search
+             start_time = time.time()
+             df = search_service.search_advanced(
+                 text_queries=text_queries if text_queries else None,
+                 text_weights=text_weights if text_weights else None,
+                 image_queries=image_queries if image_queries else None,
+                 image_weights=image_weights if image_weights else None,
+                 rmag_min=rmag_min,
+                 rmag_max=rmag_max
+             )
+             search_time = time.time() - start_time
+
+             # Log query to XML/CSV
+             from src.utils import build_query_xml, log_query_to_csv
+             query_xml = build_query_xml(
+                 text_queries=text_queries if text_queries else None,
+                 text_weights=text_weights if text_weights else None,
+                 image_queries=image_queries if image_queries else None,
+                 image_weights=image_weights if image_weights else None,
+                 rmag_min=rmag_min,
+                 rmag_max=rmag_max
+             )
+             log_query_to_csv(query_xml)
+
+             # Build results grid
+             grid_items = build_galaxy_grid(df.head(DEFAULT_DISPLAY_COUNT))
+
+             # Build query description for storage (simple text)
+             query_desc_parts = []
+             for query, weight in zip(text_queries, text_weights):
+                 op_display = weight_to_display(weight)
+                 query_desc_parts.append(f"{op_display} text:'{query}'")
+             for img_query, weight in zip(image_queries, image_weights):
+                 op_display = weight_to_display(weight)
+                 query_desc_parts.append(f"{op_display} image:(RA={img_query['ra']:.2f}, Dec={img_query['dec']:.2f})")
+             query_description = " ".join(query_desc_parts)
+
+             # Build query display with thumbnails for images
+             query_display_parts = []
+             for query, weight in zip(text_queries, text_weights):
+                 op_display = weight_to_display(weight)
+                 query_display_parts.append(html.Span(f"{op_display} text:'{query}' ", style={"margin-right": "8px"}))
+
+             for img_query, weight in zip(image_queries, image_weights):
+                 op_display = weight_to_display(weight)
521
+ # Generate thumbnail URL
522
+ from src.utils import cutout_url
523
+ thumbnail_url = cutout_url(
524
+ img_query['ra'],
525
+ img_query['dec'],
526
+ fov=img_query.get('fov', 0.025),
527
+ size=64
528
+ )
529
+ query_display_parts.append(html.Span([
530
+ f"{op_display} ",
531
+ html.Img(
532
+ src=thumbnail_url,
533
+ style={
534
+ "width": "128px",
535
+ "height": "128px",
536
+ "vertical-align": "middle",
537
+ "margin": "0 4px",
538
+ "border-radius": "4px",
539
+ "border": "1px solid rgba(255, 255, 255, 0.2)"
540
+ }
541
+ )
542
+ ], style={"margin-right": "8px", "display": "inline-block"}))
543
+
544
+ # Build filter description
545
+ filter_desc = ""
546
+ if rmag_min is not None and rmag_max is not None and (rmag_min != 13.0 or rmag_max != 20.0):
547
+ filter_desc = f" + r-mag: [{rmag_min:.1f}, {rmag_max:.1f}]"
548
+
549
+ # Prepare data for store
550
+ search_data = prepare_search_data(df, query_description, is_vector_search=True)
551
+ search_data["text_queries"] = text_queries
552
+ search_data["text_weights"] = text_weights
553
+ search_data["image_queries"] = image_queries
554
+ search_data["image_weights"] = image_weights
555
+
556
+ # Create load more button
557
+ load_more_button = create_load_more_button(len(df), DEFAULT_DISPLAY_COUNT) if len(df) > DEFAULT_DISPLAY_COUNT else None
558
+
559
+ # Build results container
560
+ results_container = html.Div([
561
+ html.P(f"Top {len(df)} matching galaxies (showing {min(DEFAULT_DISPLAY_COUNT, len(df))})",
562
+ className="results-header mb-2 text-center"),
563
+ html.P(
564
+ query_display_parts + ([f"{filter_desc}"] if filter_desc else []),
565
+ className="text-center mb-3",
566
+ style={"color": "rgba(245, 245, 247, 0.6)", "font-size": "0.9rem"}
567
+ ),
568
+ dbc.Row(grid_items, justify="center", id="search-results-grid"),
569
+ load_more_button
570
+ ])
571
+
572
+ return "", results_container, search_data, False
573
+
574
+ except Exception as e:
575
+ error_msg = dbc.Alert(f"Advanced search failed: {str(e)}", color="danger")
576
+ logger.error(f"Advanced search error: {e}")
577
+ logger.error(f"Full traceback:\n{traceback.format_exc()}")
578
+ return "", error_msg, None, True
579
+
580
+ @app.callback(
581
+ Output("download-csv", "data"),
582
+ Input("download-button", "n_clicks"),
583
+ State("search-data", "data"),
584
+ prevent_initial_call=True
585
+ )
586
+ def download_csv(n_clicks, search_data):
587
+ """Download search results as CSV."""
588
+ if n_clicks and search_data:
589
+ # Create DataFrame with the search results
590
+ df = pd.DataFrame({
591
+ ZILLIZ_PRIMARY_KEY: search_data[ZILLIZ_PRIMARY_KEY],
592
+ 'ra': search_data['ra'],
593
+ 'dec': search_data['dec'],
594
+ 'r_mag': search_data['r_mag'],
595
+ 'distance': search_data['distance'],
596
+ 'cutout_url': search_data['cutout_url']
597
+ })
598
+
599
+ # Create CSV string
600
+ csv_string = df.to_csv(index=False)
601
+
602
+ # Return download data
603
+ return dict(content=csv_string, filename="galaxy_search_results.csv")
604
+
605
+ return dash.no_update
606
+
607
+
608
+ # Helper functions for callbacks
609
+
610
+ def build_galaxy_grid(df: pd.DataFrame) -> list:
611
+ """Build galaxy grid items from DataFrame.
612
+
613
+ Args:
614
+ df: DataFrame with galaxy data
615
+
616
+ Returns:
617
+ List of Dash components
618
+ """
619
+ grid_items = []
620
+ for i, row in df.iterrows():
621
+ galaxy_info = {
622
+ ZILLIZ_PRIMARY_KEY: row[ZILLIZ_PRIMARY_KEY],
623
+ "ra": row['ra'],
624
+ "dec": row['dec'],
625
+ "distance": row['distance'],
626
+ "r_mag": row['r_mag'],
627
+ "cutout_url": row['cutout_url']
628
+ }
629
+ grid_item = build_galaxy_card(galaxy_info, i)
630
+ grid_items.append(grid_item)
631
+ return grid_items
632
+
633
+
634
+ def build_galaxy_card(galaxy_info: dict, index: int):
635
+ """Build a single galaxy card component.
636
+
637
+ Args:
638
+ galaxy_info: Dictionary with galaxy information
639
+ index: Index of the galaxy in the results
640
+
641
+ Returns:
642
+ Dash Bootstrap Col component
643
+ """
644
+ return dbc.Col([
645
+ html.Div([
646
+ html.Div([
647
+ html.Img(
648
+ src=galaxy_info["cutout_url"],
649
+ style={
650
+ "width": IMAGE_WIDTH,
651
+ "height": IMAGE_HEIGHT,
652
+ "object-fit": "cover",
653
+ "cursor": "pointer",
654
+ "border-radius": "8px"
655
+ },
656
+ id={"type": "galaxy-image", "index": index},
657
+ className="hover-shadow"
658
+ ),
659
+ html.Div([
660
+ html.Small(f"r = {galaxy_info['r_mag']:.2f} mag", className="score-badge")
661
+ ], style={
662
+ "position": "absolute",
663
+ "bottom": "8px",
664
+ "right": "8px"
665
+ })
666
+ ], style={"position": "relative"})
667
+ ])
668
+ ], width=6, md=4, lg=2, className="mb-2 px-1")
669
+
670
+
671
+ def prepare_search_data(df: pd.DataFrame, query: str, is_vector_search: bool = False) -> dict:
672
+ """Prepare search data for storage.
673
+
674
+ Args:
675
+ df: DataFrame with search results
676
+ query: Search query string
677
+ is_vector_search: Whether this is a vector search
678
+
679
+ Returns:
680
+ Dictionary with search data
681
+ """
682
+ return {
683
+ ZILLIZ_PRIMARY_KEY: df[ZILLIZ_PRIMARY_KEY].tolist(),
684
+ "ra": df['ra'].tolist(),
685
+ "dec": df['dec'].tolist(),
686
+ "distance": df['distance'].tolist(),
687
+ "r_mag": df['r_mag'].tolist(),
688
+ "cutout_url": df['cutout_url'].tolist(),
689
+ "loaded_count": DEFAULT_DISPLAY_COUNT,
690
+ "query": query,
691
+ "is_vector_search": is_vector_search
692
+ }
693
+
694
+
695
+ def extract_galaxy_info(search_data: dict, index: int) -> dict:
696
+ """Extract galaxy info from search data at given index.
697
+
698
+ Args:
699
+ search_data: Dictionary with search data
700
+ index: Index of the galaxy
701
+
702
+ Returns:
703
+ Dictionary with galaxy information
704
+ """
705
+ return {
706
+ ZILLIZ_PRIMARY_KEY: search_data[ZILLIZ_PRIMARY_KEY][index],
707
+ "ra": search_data["ra"][index],
708
+ "dec": search_data["dec"][index],
709
+ "distance": search_data["distance"][index],
710
+ "r_mag": search_data["r_mag"][index],
711
+ "cutout_url": search_data["cutout_url"][index]
712
+ }
713
+
714
+
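prepare_search_data and extract_galaxy_info are inverse views of the same column-oriented store: the former flattens a results DataFrame into per-column lists for the dcc.Store, the latter reassembles one row. A minimal sketch of that invariant (the one-row frame and its values are made up for illustration; "ra_dec" is the primary key of the default collection per src/config.py):

    import pandas as pd

    # Hypothetical one-row results frame with the columns the helpers expect
    df = pd.DataFrame([{"ra_dec": "150.1_2.2", "ra": 150.1, "dec": 2.2,
                        "distance": 0.42, "r_mag": 17.3,
                        "cutout_url": "https://example.org/cut.jpg"}])

    data = prepare_search_data(df, query="ring galaxy")
    info = extract_galaxy_info(data, 0)
    assert info["ra"] == 150.1
    assert info["cutout_url"] == df["cutout_url"][0]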
+ def build_modal_content(galaxy_info: dict) -> tuple:
+     """Build modal image and description content.
+
+     Args:
+         galaxy_info: Dictionary with galaxy information
+
+     Returns:
+         Tuple of (image_element, description_element)
+     """
+     image_element = html.Img(
+         src=galaxy_info["cutout_url"],
+         style={"width": "100%", "max-width": "500px", "height": "auto"}
+     )
+
+     # Format primary key label (convert snake_case to Title Case)
+     pk_label = ZILLIZ_PRIMARY_KEY.replace("_", " ").title()
+
+     description_element = html.Div([
+         html.Div([
+             html.Span(f"{pk_label}: {galaxy_info[ZILLIZ_PRIMARY_KEY]}", className="d-inline-block mb-0",
+                       style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
+         ], className="mb-2"),
+         html.Div([
+             html.Span(f"RA: {galaxy_info['ra']:.6f}", className="d-inline-block mb-0",
+                       style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
+             html.Span(" • ", className="mx-2", style={"color": "rgba(245, 245, 247, 0.5)"}),
+             html.Span(f"Dec: {galaxy_info['dec']:.6f}", className="d-inline-block mb-0",
+                       style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
+         ], className="mb-2"),
+         html.Div([
+             html.Span(f"r_mag: {galaxy_info['r_mag']:.2f}", className="d-inline-block mb-0",
+                       style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
+             html.Span(" • ", className="mx-2", style={"color": "rgba(245, 245, 247, 0.5)"}),
+             html.Span(f"Distance: {galaxy_info['distance']:.4f}", className="d-inline-block mb-0",
+                       style={"color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"}),
+         ], className="mb-3"),
+     ])
+
+     return image_element, description_element
+
+
+ def create_load_more_button(total_count: int, current_count: int):
+     """Create a load more button.
+
+     Args:
+         total_count: Total number of results
+         current_count: Number of currently loaded results
+
+     Returns:
+         Dash Bootstrap Button component
+     """
+     remaining = total_count - current_count
+     button_text = f"Load next {min(LOAD_MORE_COUNT, remaining)} galaxies"
+
+     return dbc.Button(
+         button_text,
+         id="load-more-button",
+         color="secondary",
+         className="mt-3",
+         style={"width": "100%"}
+     )
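For reference, the operation strings offered by the dropdown in src/components.py round-trip cleanly through the operation_to_weight and weight_to_display helpers defined inside perform_vector_search above. A standalone sketch of the same mapping:

    # Round-trip check for the operation/weight encoding (illustrative only)
    for op in ["+", "-", "+2", "-5", "+10"]:
        weight = 1.0 if op == "+" else -1.0 if op == "-" else float(op)
        display = ("+" if weight == 1.0 else "-" if weight == -1.0
                   else f"+{int(weight)}" if weight > 0 else str(int(weight)))
        assert display == op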
src/components.py ADDED
@@ -0,0 +1,821 @@
+ """UI components for AION Search."""
2
+
3
+ from dash import dcc, html
4
+ import dash_bootstrap_components as dbc
5
+ from src.config import TOTAL_GALAXIES
6
+
7
+
8
+ def get_app_theme() -> str:
9
+ """Get the custom CSS theme for the app.
10
+
11
+ Returns:
12
+ HTML string with embedded CSS
13
+ """
14
+ return '''
15
+ <!DOCTYPE html>
16
+ <html>
17
+ <head>
18
+ {%metas%}
19
+ <title>galaxy semantic search</title>
20
+ {%favicon%}
21
+ {%css%}
22
+ <style>
23
+ @import url('https://fonts.googleapis.com/css2?family=SF+Pro+Display:wght@200;300;400;500;600&display=swap');
24
+
25
+ * {
26
+ -webkit-font-smoothing: antialiased;
27
+ -moz-osx-font-smoothing: grayscale;
28
+ }
29
+
30
+ body {
31
+ font-family: -apple-system, BlinkMacSystemFont, 'SF Pro Display', 'Inter', sans-serif;
32
+ background: #000000;
33
+ color: #F5F5F7;
34
+ min-height: 100vh;
35
+ margin: 0;
36
+ overflow-x: hidden;
37
+ }
38
+
39
+ body::before {
40
+ content: '';
41
+ position: fixed;
42
+ top: -50%;
43
+ left: -50%;
44
+ width: 200%;
45
+ height: 200%;
46
+ background: radial-gradient(circle at 20% 80%, #1C1C1E 0%, transparent 50%),
47
+ radial-gradient(circle at 80% 20%, #161618 0%, transparent 50%),
48
+ radial-gradient(circle at 40% 40%, #0A0A0B 0%, transparent 50%);
49
+ z-index: -1;
50
+ }
51
+
52
+ .container-fluid {
53
+ background-color: transparent !important;
54
+ padding-top: 2rem !important;
55
+ }
56
+
57
+ .hover-shadow {
58
+ transition: all 0.4s cubic-bezier(0.25, 0.46, 0.45, 0.94);
59
+ border: 0.5px solid rgba(255, 255, 255, 0.1);
60
+ background: #0A0A0B;
61
+ overflow: hidden;
62
+ position: relative;
63
+ }
64
+
65
+ .hover-shadow::before {
66
+ content: '';
67
+ position: absolute;
68
+ top: 0;
69
+ left: 0;
70
+ right: 0;
71
+ bottom: 0;
72
+ background: linear-gradient(135deg, rgba(255,255,255,0.05) 0%, transparent 100%);
73
+ opacity: 0;
74
+ transition: opacity 0.4s ease;
75
+ }
76
+
77
+ .hover-shadow:hover {
78
+ transform: translateY(-4px) scale(1.02);
79
+ box-shadow: 0 20px 40px rgba(0, 0, 0, 0.8),
80
+ 0 0 60px rgba(255, 255, 255, 0.05) !important;
81
+ border-color: rgba(255, 255, 255, 0.2);
82
+ }
83
+
84
+ .hover-shadow:hover::before {
85
+ opacity: 1;
86
+ }
87
+
88
+ .search-container {
89
+ background: rgba(255, 255, 255, 0.05);
90
+ backdrop-filter: blur(40px) saturate(180%);
91
+ -webkit-backdrop-filter: blur(40px) saturate(180%);
92
+ border-radius: 16px;
93
+ padding: 1.25rem;
94
+ border: 0.5px solid rgba(255, 255, 255, 0.1);
95
+ box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4),
96
+ inset 0 1px 0 rgba(255, 255, 255, 0.1);
97
+ }
98
+
99
+ .example-button {
100
+ background: #E5E5E7 !important;
101
+ background-color: #E5E5E7 !important;
102
+ border: 0.5px solid #D1D1D3 !important;
103
+ color: #1A1A1A !important;
104
+ font-weight: 500;
105
+ font-size: 0.75rem !important;
106
+ padding: 0.4rem 0.9rem !important;
107
+ transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
108
+ letter-spacing: 0.01em;
109
+ }
110
+
111
+ .example-button:hover {
112
+ background: #F0F0F2 !important;
113
+ background-color: #F0F0F2 !important;
114
+ border-color: #C0C0C2 !important;
115
+ color: #000000 !important;
116
+ transform: translateY(-1px);
117
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
118
+ }
119
+
120
+ .example-button i {
121
+ color: #2A2A2A !important;
122
+ }
123
+
124
+ .galaxy-title {
125
+ color: #F5F5F7;
126
+ font-weight: 200;
127
+ font-size: 1.75rem;
128
+ letter-spacing: -0.03em;
129
+ background: linear-gradient(180deg, #F5F5F7 0%, rgba(245, 245, 247, 0.6) 100%);
130
+ -webkit-background-clip: text;
131
+ -webkit-text-fill-color: transparent;
132
+ animation: float 6s ease-in-out infinite;
133
+ }
134
+
135
+ .modal-content {
136
+ background: #1C1C1E;
137
+ border: 0.5px solid rgba(255, 255, 255, 0.1);
138
+ border-radius: 16px;
139
+ backdrop-filter: blur(20px);
140
+ }
141
+
142
+ .modal-header, .modal-footer {
143
+ border-color: rgba(255, 255, 255, 0.05);
144
+ }
145
+
146
+ .form-control:focus, .form-control:active {
147
+ background-color: rgba(255, 255, 255, 0.05) !important;
148
+ border-color: rgba(255, 255, 255, 0.3) !important;
149
+ box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.05) !important;
150
+ color: #F5F5F7 !important;
151
+ }
152
+
153
+ .form-control {
154
+ background-color: rgba(255, 255, 255, 0.03) !important;
155
+ border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
156
+ color: #F5F5F7 !important;
157
+ font-size: 0.95rem !important;
158
+ font-weight: 300;
159
+ letter-spacing: 0.01em;
160
+ }
161
+
162
+ .form-control::placeholder {
163
+ color: rgba(245, 245, 247, 0.4) !important;
164
+ }
165
+
166
+ .btn-primary {
167
+ background: rgba(255, 255, 255, 0.8);
168
+ color: #000;
169
+ border: none;
170
+ font-weight: 600;
171
+ font-size: 0.9rem;
172
+ padding: 0.6rem 1.8rem;
173
+ transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
174
+ letter-spacing: 0.02em;
175
+ }
176
+
177
+ .btn-primary:hover {
178
+ background: rgba(255, 255, 255, 0.95);
179
+ transform: translateY(-1px);
180
+ box-shadow: 0 8px 24px rgba(255, 255, 255, 0.15);
181
+ }
182
+
183
+ .results-header {
184
+ color: rgba(245, 245, 247, 0.6);
185
+ font-weight: 300;
186
+ font-size: 0.85rem !important;
187
+ letter-spacing: 0.05em;
188
+ text-transform: uppercase;
189
+ }
190
+
191
+ .time-breakdown {
192
+ color: rgba(245, 245, 247, 0.4);
193
+ font-size: 0.7rem;
194
+ font-weight: 300;
195
+ letter-spacing: 0.02em;
196
+ }
197
+
198
+ .galaxy-count {
199
+ color: rgba(245, 245, 247, 0.5);
200
+ font-weight: 300;
201
+ font-size: 0.85rem;
202
+ letter-spacing: 0.05em;
203
+ text-transform: uppercase;
204
+ }
205
+
206
+ .score-badge {
207
+ background: rgba(255, 255, 255, 0.1);
208
+ backdrop-filter: blur(10px);
209
+ color: rgba(245, 245, 247, 0.9);
210
+ font-size: 0.65rem !important;
211
+ padding: 3px 8px !important;
212
+ border-radius: 6px;
213
+ font-weight: 500;
214
+ letter-spacing: 0.02em;
215
+ border: 0.5px solid rgba(255, 255, 255, 0.1);
216
+ }
217
+
218
+ .info-button {
219
+ color: rgba(245, 245, 247, 0.5) !important;
220
+ font-size: 0.75rem !important;
221
+ opacity: 0.8;
222
+ transition: all 0.3s ease;
223
+ letter-spacing: 0.02em;
224
+ }
225
+
226
+ .info-button:hover {
227
+ opacity: 1;
228
+ color: #F5F5F7 !important;
229
+ }
230
+
231
+ ::-webkit-scrollbar {
232
+ width: 8px;
233
+ }
234
+
235
+ ::-webkit-scrollbar-track {
236
+ background: rgba(255, 255, 255, 0.02);
237
+ }
238
+
239
+ ::-webkit-scrollbar-thumb {
240
+ background: rgba(255, 255, 255, 0.1);
241
+ border-radius: 4px;
242
+ }
243
+
244
+ ::-webkit-scrollbar-thumb:hover {
245
+ background: rgba(255, 255, 255, 0.2);
246
+ }
247
+
248
+ .btn-link {
249
+ text-decoration: none !important;
250
+ }
251
+
252
+ .input-group-text {
253
+ background: rgba(255, 255, 255, 0.03) !important;
254
+ border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
255
+ color: rgba(245, 245, 247, 0.5) !important;
256
+ }
257
+
258
+ .spinner-border {
259
+ color: rgba(245, 245, 247, 0.5) !important;
260
+ }
261
+
262
+ @supports (backdrop-filter: blur(40px)) {
263
+ .search-container {
264
+ background: rgba(255, 255, 255, 0.03);
265
+ }
266
+ }
267
+
268
+ @keyframes float {
269
+ 0%, 100% { transform: translateY(0px); }
270
+ 50% { transform: translateY(-3px); }
271
+ }
272
+
273
+ .download-button {
274
+ background: rgba(255, 255, 255, 0.05);
275
+ border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
276
+ color: rgba(245, 245, 247, 0.6) !important;
277
+ font-size: 0.75rem !important;
278
+ padding: 0.4rem 0.8rem !important;
279
+ transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
280
+ letter-spacing: 0.01em;
281
+ margin-left: 0.5rem;
282
+ }
283
+
284
+ .download-button:hover {
285
+ background: rgba(255, 255, 255, 0.08) !important;
286
+ border-color: rgba(255, 255, 255, 0.15) !important;
287
+ color: rgba(245, 245, 247, 0.8) !important;
288
+ transform: translateY(-1px);
289
+ }
290
+
291
+ .download-button i {
292
+ color: rgba(245, 245, 247, 0.6) !important;
293
+ }
294
+
295
+ .download-button:hover i {
296
+ color: rgba(245, 245, 247, 0.8) !important;
297
+ }
298
+
299
+ .refinement-toggle {
300
+ background: rgba(255, 255, 255, 0.03);
301
+ border: 0.5px solid rgba(255, 255, 255, 0.1);
302
+ color: rgba(245, 245, 247, 0.5);
303
+ font-size: 0.7rem;
304
+ padding: 0.5rem 0.75rem;
305
+ transition: all 0.3s ease;
306
+ cursor: pointer;
307
+ display: flex;
308
+ align-items: center;
309
+ justify-content: center;
310
+ gap: 0.4rem;
311
+ margin: 0;
312
+ letter-spacing: 0.05em;
313
+ text-transform: uppercase;
314
+ border-radius: 8px;
315
+ height: 100%;
316
+ }
317
+
318
+ .refinement-toggle:hover {
319
+ background: rgba(255, 255, 255, 0.05);
320
+ border-color: rgba(255, 255, 255, 0.15);
321
+ color: rgba(245, 245, 247, 0.8);
322
+ transform: translateY(-1px);
323
+ }
324
+
325
+ .refinement-toggle i {
326
+ transition: transform 0.3s ease;
327
+ font-size: 0.65rem;
328
+ }
329
+
330
+ .refinement-toggle.expanded i {
331
+ transform: rotate(180deg);
332
+ }
333
+
334
+ .refinement-container {
335
+ background: rgba(255, 255, 255, 0.03);
336
+ border: 0.5px solid rgba(255, 255, 255, 0.1);
337
+ border-radius: 12px;
338
+ padding: 1.5rem;
339
+ margin-top: 1rem;
340
+ backdrop-filter: blur(20px);
341
+ }
342
+
343
+ .refinement-label {
344
+ color: rgba(245, 245, 247, 0.6);
345
+ font-size: 0.85rem;
346
+ font-weight: 300;
347
+ margin-bottom: 0.5rem;
348
+ }
349
+
350
+ .vector-operation-select {
351
+ background-color: rgba(255, 255, 255, 0.03) !important;
352
+ border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
353
+ color: #F5F5F7 !important;
354
+ font-size: 0.9rem !important;
355
+ font-weight: 500;
356
+ text-align: center;
357
+ }
358
+
359
+ .vector-operation-select option {
360
+ background-color: #1C1C1E;
361
+ color: #F5F5F7;
362
+ }
363
+
364
+ .vector-query-type-select {
365
+ background-color: rgba(255, 255, 255, 0.03) !important;
366
+ border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
367
+ color: #F5F5F7 !important;
368
+ font-size: 0.9rem !important;
369
+ font-weight: 500;
370
+ }
371
+
372
+ .vector-query-type-select option {
373
+ background-color: #1C1C1E;
374
+ color: #F5F5F7;
375
+ }
376
+
377
+ .btn-add-vector {
378
+ background: rgba(255, 255, 255, 0.05);
379
+ border: 0.5px solid rgba(255, 255, 255, 0.1) !important;
380
+ color: rgba(245, 245, 247, 0.6) !important;
381
+ font-size: 0.8rem !important;
382
+ padding: 0.4rem 0.8rem !important;
383
+ transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94);
384
+ letter-spacing: 0.01em;
385
+ }
386
+
387
+ .btn-add-vector:hover {
388
+ background: rgba(255, 255, 255, 0.08) !important;
389
+ border-color: rgba(255, 255, 255, 0.15) !important;
390
+ color: rgba(245, 245, 247, 0.8) !important;
391
+ transform: translateY(-1px);
392
+ }
393
+
394
+ .btn-add-vector i {
395
+ color: rgba(245, 245, 247, 0.6) !important;
396
+ }
397
+
398
+ .btn-add-vector:hover i {
399
+ color: rgba(245, 245, 247, 0.8) !important;
400
+ }
401
+
402
+ .vector-delete-btn {
403
+ opacity: 0.5;
404
+ transition: opacity 0.2s ease;
405
+ padding: 0.25rem 0.5rem !important;
406
+ border: none !important;
407
+ background: none !important;
408
+ }
409
+
410
+ .vector-delete-btn:hover {
411
+ opacity: 1;
412
+ background: rgba(220, 53, 69, 0.1) !important;
413
+ border-radius: 4px;
414
+ }
415
+
416
+ .vector-delete-btn i {
417
+ font-size: 0.9rem;
418
+ }
419
+
420
+ /* Range slider styling */
421
+ .rmag-slider .rc-slider-rail {
422
+ background-color: rgba(255, 255, 255, 0.1);
423
+ height: 4px;
424
+ }
425
+
426
+ .rmag-slider .rc-slider-track {
427
+ background: linear-gradient(90deg, rgba(255, 255, 255, 0.6) 0%, rgba(255, 255, 255, 0.8) 100%);
428
+ height: 4px;
429
+ }
430
+
431
+ .rmag-slider .rc-slider-handle {
432
+ border: 2px solid rgba(255, 255, 255, 0.8);
433
+ background-color: #F5F5F7;
434
+ opacity: 1;
435
+ width: 16px;
436
+ height: 16px;
437
+ margin-top: -6px;
438
+ }
439
+
440
+ .rmag-slider .rc-slider-handle:hover,
441
+ .rmag-slider .rc-slider-handle:active,
442
+ .rmag-slider .rc-slider-handle:focus {
443
+ border-color: rgba(255, 255, 255, 0.95);
444
+ box-shadow: 0 0 0 5px rgba(255, 255, 255, 0.1);
445
+ }
446
+
447
+ .rmag-slider .rc-slider-mark-text {
448
+ color: rgba(245, 245, 247, 0.5);
449
+ font-size: 0.75rem;
450
+ font-weight: 300;
451
+ }
452
+
453
+ .rmag-slider .rc-slider-tooltip-inner {
454
+ background-color: rgba(255, 255, 255, 0.9);
455
+ color: #000;
456
+ font-size: 0.75rem;
457
+ font-weight: 500;
458
+ padding: 4px 8px;
459
+ border-radius: 4px;
460
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
461
+ }
462
+
463
+ .rmag-slider .rc-slider-tooltip-arrow {
464
+ border-top-color: rgba(255, 255, 255, 0.9);
465
+ }
466
+ </style>
467
+ </head>
468
+ <body>
469
+ {%app_entry%}
470
+ <footer>
471
+ {%config%}
472
+ {%scripts%}
473
+ {%renderer%}
474
+ </footer>
475
+ </body>
476
+ </html>
477
+ '''
478
+
479
+
480
+ def create_header():
481
+ """Create the app header with title and galaxy count."""
482
+ return dbc.Row([
483
+ dbc.Col([
484
+ html.Div([
485
+ html.H1("galaxy semantic search", className="galaxy-title text-center mb-1"),
486
+ html.Div(id="galaxy-count", className="galaxy-count text-center")
487
+ ], className="text-center mb-3")
488
+ ])
489
+ ])
490
+
491
+
492
+ def create_rmag_filter_panel():
493
+ """Create the r_mag filter panel."""
494
+ return dbc.Row([
495
+ dbc.Col([
496
+ html.Div([
497
+ dbc.Row([
498
+ dbc.Col([
499
+ html.Div("r-mag",
500
+ style={"color": "rgba(245, 245, 247, 0.6)",
501
+ "font-size": "0.85rem",
502
+ "font-weight": "300",
503
+ "text-align": "center"})
504
+ ], width=1, className="d-flex align-items-center justify-content-center", style={"padding": "0"}),
505
+ dbc.Col([
506
+ dcc.RangeSlider(
507
+ id="rmag-slider",
508
+ min=13.0,
509
+ max=20.0,
510
+ step=0.1,
511
+ value=[13.0, 20.0],
512
+ marks={13: '13', 15: '15', 17: '17', 19: '19', 20: '20'},
513
+ tooltip={"placement": "bottom", "always_visible": True},
514
+ className="rmag-slider"
515
+ )
516
+ ], width=11)
517
+ ], className="align-items-center", style={"margin": "0"})
518
+ ], style={"padding": "0.5rem 0"})
519
+ ], width=12)
520
+ ], className="mt-2")
521
+
522
+
523
+ def create_search_container():
524
+ """Create the main search input container with examples and input."""
525
+ return dbc.Row([
526
+ dbc.Col([
527
+ html.Div([
528
+ # Info button in top right
529
+ html.Div([
530
+ dbc.Button([
531
+ html.I(className="fas fa-info-circle")
532
+ ], id="info-button", color="link", size="sm",
533
+ className="info-button")
534
+ ], style={"position": "absolute", "top": "8px", "right": "8px", "z-index": "1000"}),
535
+
536
+ # Example search buttons
537
+ html.Div([
538
+ html.P("Try these examples:", className="text-center mb-2",
539
+ style={"color": "rgba(245, 245, 247, 0.5)", "font-weight": "300",
540
+ "font-size": "0.75rem", "letter-spacing": "0.02em"}),
541
+ html.Div([
542
+ dbc.Button([html.I(className="fas fa-satellite me-2"), "Merging edge-on galaxy"],
543
+ id="example-1", className="example-button me-2 mb-2", size="sm", color="light"),
544
+ dbc.Button([html.I(className="fas fa-water me-2"), "Tidal"],
545
+ id="example-2", className="example-button me-2 mb-2", size="sm", color="light"),
546
+ dbc.Button([html.I(className="fas fa-stream me-2"), "Stream"],
547
+ id="example-3", className="example-button me-2 mb-2", size="sm", color="light"),
548
+ dbc.Button([html.I(className="fas fa-glasses me-2"), "Gravitational lens"],
549
+ id="example-4", className="example-button me-2 mb-2", size="sm", color="light"),
550
+ dbc.Button([html.I(className="fas fa-explosion me-2"), "A violent merger"],
551
+ id="example-5", className="example-button me-2 mb-2", size="sm", color="light"),
552
+ dbc.Button([html.I(className="fas fa-moon me-2"), "Low surface brightness"],
553
+ id="example-6", className="example-button me-2 mb-2", size="sm", color="light"),
554
+ dbc.Button([html.I(className="fas fa-ring me-2"), "Ring galaxy"],
555
+ id="example-7", className="example-button mb-2", size="sm", color="light")
556
+ ], className="text-center")
557
+ ], className="mb-3"),
558
+
559
+ # Search input
560
+ dbc.InputGroup([
561
+ dbc.InputGroupText(html.I(className="fas fa-search")),
562
+ dbc.Input(
563
+ id="search-input",
564
+ placeholder="Describe the galaxy you're looking for...",
565
+ type="text",
566
+ n_submit=0
567
+ ),
568
+ dbc.Button("Search",
569
+ id="search-button", color="primary", n_clicks=0),
570
+ dbc.Button([
571
+ html.I(className="fas fa-download")
572
+ ], id="download-button", color="secondary", n_clicks=0,
573
+ className="download-button", size="sm",
574
+ disabled=True)
575
+ ])
576
+ ], className="search-container", style={"position": "relative"}),
577
+
578
+ # r_mag filter
579
+ create_rmag_filter_panel(),
580
+
581
+ # Vector Addition toggle button
582
+ dbc.Row([
583
+ dbc.Col([
584
+ html.Button([
585
+ html.I(className="fas fa-chevron-down", id="vector-arrow"),
586
+ "Advanced Search (Vector Addition / Images)"
587
+ ], id="vector-toggle", className="refinement-toggle w-100")
588
+ ], width=12)
589
+ ], className="mt-3"),
590
+
591
+ # Vector Addition UI - Collapsible section
592
+ create_vector_addition_panel()
593
+ ], width=12, lg=11, className="mx-auto")
594
+ ], className="mb-3")
595
+
596
+
597
+ def create_vector_addition_panel():
598
+ """Create the advanced search (vector addition) collapsible panel."""
599
+ return dbc.Collapse([
600
+ html.Div([
601
+ html.P("Advanced Search: Combine multiple text and/or image queries using vector addition/subtraction:", className="refinement-label"),
602
+ html.Div(id="vector-inputs", children=[
603
+ # Initial input
604
+ create_vector_input_row(0)
605
+ ]),
606
+ dbc.Row([
607
+ dbc.Col([
608
+ dbc.Button(
609
+ [html.I(className="fas fa-plus me-2"), "Add Query"],
610
+ id="add-vector-input",
611
+ color="secondary",
612
+ size="sm",
613
+ className="me-2 btn-add-vector"
614
+ )
615
+ ], width=6),
616
+ dbc.Col([
617
+ dbc.Button(
618
+ "Advanced Search",
619
+ id="vector-search-button",
620
+ className="btn-primary w-100",
621
+ n_clicks=0
622
+ )
623
+ ], width=6)
624
+ ], className="mt-3")
625
+ ], className="refinement-container")
626
+ ], id="vector-collapse", is_open=False)
627
+
628
+
629
+ def create_vector_input_row(index: int, query_type: str = "text", ra: float = None, dec: float = None, fov: float = 0.025):
630
+ """Create a single vector input row with operation selector, query type toggle, and conditional inputs.
631
+
632
+ Args:
633
+ index: Index of the vector input row
634
+ query_type: Type of query - "text" or "image" (default: "text")
635
+ ra: Initial RA value for image queries (default: None)
636
+ dec: Initial Dec value for image queries (default: None)
637
+ fov: Initial FoV value for image queries (default: 0.025)
638
+
639
+ Returns:
640
+ Dash Bootstrap Row component with text/image mode toggle
641
+ """
642
+ # Determine display styles based on query type
643
+ text_display = {"display": "block"} if query_type == "text" else {"display": "none"}
644
+ image_display = {"display": "none"} if query_type == "text" else {"display": "block"}
645
+
646
+ return dbc.Row([
647
+ # Operation column with magnitude support
648
+ dbc.Col([
649
+ dbc.Select(
650
+ id={"type": "vector-operation", "index": index},
651
+ options=[
652
+ {"label": "+10", "value": "+10"},
653
+ {"label": "+5", "value": "+5"},
654
+ {"label": "+2", "value": "+2"},
655
+ {"label": "+", "value": "+"},
656
+ {"label": "-", "value": "-"},
657
+ {"label": "-2", "value": "-2"},
658
+ {"label": "-5", "value": "-5"},
659
+ {"label": "-10", "value": "-10"}
660
+ ],
661
+ value="+",
662
+ style={"width": "70px"},
663
+ className="d-inline-block vector-operation-select"
664
+ )
665
+ ], width=1),
666
+ # Query type toggle (Text/Image)
667
+ dbc.Col([
668
+ dbc.Select(
669
+ id={"type": "vector-query-type", "index": index},
670
+ options=[
671
+ {"label": "Text", "value": "text"},
672
+ {"label": "Image", "value": "image"}
673
+ ],
674
+ value=query_type,
675
+ style={"width": "100px"},
676
+ className="d-inline-block vector-query-type-select"
677
+ )
678
+ ], width=2),
679
+ # Input area (text or image fields)
680
+ dbc.Col([
681
+ # Text input (shown when type is "text")
682
+ html.Div([
683
+ dbc.Input(
684
+ id={"type": "vector-text", "index": index},
685
+ placeholder="Enter text query...",
686
+ type="text"
687
+ )
688
+ ], id={"type": "text-input-container", "index": index}, style=text_display),
689
+ # Image inputs (shown when type is "image")
690
+ html.Div([
691
+ dbc.Row([
692
+ dbc.Col([
693
+ dbc.Input(
694
+ id={"type": "vector-ra", "index": index},
695
+ placeholder="ra:",
696
+ type="number",
697
+ step="any",
698
+ value=ra
699
+ )
700
+ ], width=4),
701
+ dbc.Col([
702
+ dbc.Input(
703
+ id={"type": "vector-dec", "index": index},
704
+ placeholder="dec:",
705
+ type="number",
706
+ step="any",
707
+ value=dec
708
+ )
709
+ ], width=4),
710
+ dbc.Col([
711
+ dbc.Input(
712
+ id={"type": "vector-fov", "index": index},
713
+ placeholder="fov:",
714
+ type="number",
715
+ value=fov,
716
+ step="any"
717
+ )
718
+ ], width=4)
719
+ ])
720
+ ], id={"type": "image-input-container", "index": index}, style=image_display)
721
+ ], width=8),
722
+ # Delete button
723
+ dbc.Col([
724
+ dbc.Button(
725
+ html.I(className="fas fa-times"),
726
+ id={"type": "vector-delete", "index": index},
727
+ color="link",
728
+ size="sm",
729
+ className="text-danger vector-delete-btn",
730
+ style={"padding": "0.25rem 0.5rem"}
731
+ )
732
+ ], width=1, className="d-flex align-items-center justify-content-end")
733
+ ], className="mb-2", id={"type": "vector-row", "index": index})
734
+
735
+
736
+ def create_results_container():
737
+ """Create the search results display container."""
738
+ return dbc.Row([
739
+ dbc.Col([
740
+ html.Div(id="search-time", className="time-breakdown text-center mb-2"),
741
+ html.Div(id="search-results")
742
+ ])
743
+ ])
744
+
745
+
746
+ def create_stores():
747
+ """Create Dash Store components for data persistence."""
748
+ return [
749
+ dcc.Store(id="search-data"),
750
+ dcc.Store(id="current-galaxy-data"),
751
+ dcc.Store(id="vector-inputs-count", data=1),
752
+ dcc.Download(id="download-csv")
753
+ ]
754
+
755
+
756
+ def create_galaxy_modal():
757
+ """Create the modal for displaying galaxy details."""
758
+ return dbc.Modal([
759
+ dbc.ModalHeader(dbc.ModalTitle(id="modal-title")),
760
+ dbc.ModalBody([
761
+ html.Div(id="modal-image", className="text-center mb-3"),
762
+ html.Div(id="modal-description")
763
+ ]),
764
+ dbc.ModalFooter([
765
+ dbc.Button(
766
+ [html.I(className="fas fa-plus-circle me-2"), "Add to Advanced Search"],
767
+ id="add-to-advanced-search",
768
+ color="primary",
769
+ className="me-2"
770
+ ),
771
+ dbc.Button("Close", id="close-modal", className="ms-auto")
772
+ ])
773
+ ], id="galaxy-modal", size="lg", is_open=False)
774
+
775
+
776
+ def create_info_modal():
777
+ """Create the info modal explaining the app."""
778
+ return dbc.Modal([
779
+ dbc.ModalHeader(dbc.ModalTitle([html.I(className="fas fa-info-circle me-2"), "About Galaxy Search"])),
780
+ dbc.ModalBody([
781
+ html.P("This app performs semantic search over galaxy images using CLIP embeddings and BigQuery.",
782
+ style={"color": "rgba(245, 245, 247, 0.8)", "margin-bottom": "1rem", "font-size": "0.9rem"}),
783
+ html.Div([
784
+ html.P("The search uses contrastive language-image pre-training (CLIP) to match text descriptions with galaxy images. "
785
+ "The model was trained on galaxy descriptions and can understand various astronomical features and characteristics.",
786
+ style={"margin-bottom": "1rem", "color": "rgba(245, 245, 247, 0.7)"}),
787
+
788
+ html.H6("Search Tips:", style={"color": "#F5F5F7", "font-weight": "500", "margin-bottom": "0.5rem"}),
789
+ html.Ul([
790
+ html.Li("Describe morphological features (spiral, elliptical, irregular, merging)",
791
+ style={"color": "rgba(245, 245, 247, 0.6)", "margin-bottom": "0.3rem"}),
792
+ html.Li("Mention specific features (tidal tails, dust lanes, star-forming regions)",
793
+ style={"color": "rgba(245, 245, 247, 0.6)", "margin-bottom": "0.3rem"}),
794
+ html.Li("Use color descriptions or brightness characteristics",
795
+ style={"color": "rgba(245, 245, 247, 0.6)", "margin-bottom": "0.3rem"}),
796
+ html.Li("Combine multiple features for more specific results",
797
+ style={"color": "rgba(245, 245, 247, 0.6)"}),
798
+ ], style={"margin-left": "1rem"}),
799
+ ], style={"background": "rgba(255, 255, 255, 0.05)", "padding": "1.5rem", "border-radius": "12px",
800
+ "border": "0.5px solid rgba(255, 255, 255, 0.1)", "color": "rgba(245, 245, 247, 0.7)", "font-size": "0.9rem"})
801
+ ]),
802
+ dbc.ModalFooter(
803
+ dbc.Button("Close", id="close-info-modal", className="ms-auto")
804
+ )
805
+ ], id="info-modal", size="lg", is_open=False)
806
+
807
+
808
+ def create_layout():
809
+ """Create the complete app layout.
810
+
811
+ Returns:
812
+ Dash Container with the full app layout
813
+ """
814
+ return dbc.Container([
815
+ create_header(),
816
+ create_search_container(),
817
+ create_results_container(),
818
+ *create_stores(),
819
+ create_galaxy_modal(),
820
+ create_info_modal()
821
+ ], fluid=True, className="py-2")
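Every per-row control built by create_vector_input_row carries a dictionary id such as {"type": "vector-text", "index": index}; that is what lets perform_vector_search in the callbacks module gather all rows at once with dash.dependencies.ALL. A minimal sketch of the pattern (the component ids are the real ones above; the callback body and its output are illustrative only, and an existing Dash app object is assumed):

    from dash import Input, Output, State
    from dash.dependencies import ALL

    @app.callback(
        Output("search-results", "children", allow_duplicate=True),
        Input("vector-search-button", "n_clicks"),
        State({"type": "vector-text", "index": ALL}, "value"),
        prevent_initial_call=True
    )
    def show_row_count(n_clicks, text_values):
        # text_values arrives as a list, one entry per matching row id
        return f"{len(text_values)} query rows"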
src/config.py ADDED
@@ -0,0 +1,68 @@
+ """Configuration settings, environment variables, and constants."""
2
+
3
+ import os
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+ # Environment Variables
9
+ ZILLIZ_BEARER = os.getenv("ZILLIZ_BEARER")
10
+ ZILLIZ_ENDPOINT = os.getenv("ZILLIZ_ENDPOINT")
11
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
12
+
13
+ # App Constants
14
+ # Note: TOTAL_GALAXIES is dynamically updated from Zilliz at startup (see app.py)
15
+ # This is just a fallback default value
16
+ TOTAL_GALAXIES = 0
17
+ DEFAULT_TOP_K = 300
18
+ DEFAULT_DISPLAY_COUNT = 60
19
+ LOAD_MORE_COUNT = 120
20
+
21
+ # Zilliz Configuration
22
+ ZILLIZ_COLLECTION_NAME = "aionsearch"
23
+ # Image search always uses legacy collection which has pre-existing embeddings
24
+ ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME = ZILLIZ_COLLECTION_NAME
25
+
26
+ # Collection-specific configurations
27
+ COLLECTION_CONFIGS = {
28
+ "legacy_5": {
29
+ "anns_field": "aion_search_embedding",
30
+ "primary_key": "object_id",
31
+ "output_fields": ["object_id", "ra", "dec", "r_mag"]
32
+ },
33
+ "aionsearch": {
34
+ "anns_field": "clip_embedding",
35
+ "primary_key": "ra_dec",
36
+ "output_fields": ["ra_dec", "ra", "dec", "r_mag"]
37
+ }
38
+ }
39
+
40
+ # Get configuration for the selected collection
41
+ _collection_config = COLLECTION_CONFIGS.get(ZILLIZ_COLLECTION_NAME, COLLECTION_CONFIGS[ZILLIZ_COLLECTION_NAME])
42
+ ZILLIZ_ANNS_FIELD = _collection_config["anns_field"]
43
+ ZILLIZ_PRIMARY_KEY = _collection_config["primary_key"]
44
+ ZILLIZ_OUTPUT_FIELDS = _collection_config["output_fields"]
45
+
46
+ # OpenAI Configuration
47
+ OPENAI_EMBEDDING_MODEL = "text-embedding-3-large"
48
+
49
+ # CLIP Model Configuration
50
+ CLIP_EMBEDDING_DIM = 1024
51
+ CLIP_NORMALIZE_EPS = 1e-3
52
+
53
+ # UI Configuration
54
+ IMAGE_HEIGHT = "160px"
55
+ IMAGE_WIDTH = "100%"
56
+ CUTOUT_FOV = 0.025
57
+ CUTOUT_SIZE = 256
58
+
59
+ # Logging Configuration
60
+ VCU_COST_PER_MILLION = 4.0 # $4 per 1 million vCU
61
+
62
+ # Feature Flags (for future features)
63
+ FEATURE_IMAGE_SEARCH = False
64
+ FEATURE_AUTH = False
65
+ FEATURE_CACHE = False
66
+ FEATURE_RERANKING = False
67
+ FEATURE_TRACKING = True
68
+ FEATURE_VECTOR_ADDITION = True
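The three values read from the environment above must be supplied at deploy time (on Hugging Face Spaces, as repository secrets; locally, via a .env file picked up by load_dotenv). Note that the services in src/services.py derive the Zilliz query route from ZILLIZ_ENDPOINT by string replacement, so the endpoint is assumed to end in /search. A hypothetical startup sanity check (the variable names are the real ones above; everything else is illustrative):

    import os

    # Fail fast if required configuration is missing or malformed (sketch only)
    for var in ("ZILLIZ_BEARER", "ZILLIZ_ENDPOINT", "OPENAI_API_KEY"):
        assert os.getenv(var), f"{var} is not set"
    # The /query route is derived via str.replace("/search", "/query")
    assert os.environ["ZILLIZ_ENDPOINT"].endswith("/search")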
src/services.py ADDED
@@ -0,0 +1,538 @@
+ """Backend services for AION Search."""
2
+
3
+ import time
4
+ import logging
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import pandas as pd
9
+ import requests
10
+ from typing import List
11
+ from openai import OpenAI
12
+
13
+ from src.config import (
14
+ ZILLIZ_BEARER,
15
+ ZILLIZ_ENDPOINT,
16
+ ZILLIZ_COLLECTION_NAME,
17
+ ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
18
+ ZILLIZ_ANNS_FIELD,
19
+ ZILLIZ_PRIMARY_KEY,
20
+ ZILLIZ_OUTPUT_FIELDS,
21
+ COLLECTION_CONFIGS,
22
+ OPENAI_API_KEY,
23
+ OPENAI_EMBEDDING_MODEL,
24
+ CLIP_NORMALIZE_EPS,
25
+ DEFAULT_TOP_K,
26
+ )
27
+ from src.utils import cutout_url, log_zilliz_query
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class CLIPModelService:
33
+ """Service for managing CLIP model loading and inference."""
34
+
35
+ def __init__(self):
36
+ self.model = None
37
+ self.device = None
38
+ self.loaded = False
39
+
40
+ def load_model(self, checkpoint_path: str) -> None:
41
+ """Load the CLIP model from checkpoint.
42
+
43
+ Args:
44
+ checkpoint_path: Path to the CLIP model checkpoint file
45
+ """
46
+ logger.info(f"Loading CLIP model from {checkpoint_path}...")
47
+
48
+ from clip.models.clip_model import GalaxyClipModel
49
+
50
+ # Set device
51
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
+
53
+ # Load checkpoint
54
+ checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
55
+ model_config = checkpoint['model_config']
56
+
57
+ # Initialize model with saved configuration
58
+ self.model = GalaxyClipModel(
59
+ image_input_dim=model_config['image_input_dim'],
60
+ text_input_dim=model_config['text_input_dim'],
61
+ embedding_dim=model_config['embedding_dim'],
62
+ use_mean_embeddings=model_config.get('use_mean_embeddings', True)
63
+ )
64
+
65
+ self.model.load_state_dict(checkpoint['model_state_dict'])
66
+ self.model.to(self.device)
67
+ self.model.eval()
68
+ self.loaded = True
69
+
70
+ logger.info("CLIP model loaded successfully")
71
+
72
+ def encode_text(self, text_embedding: np.ndarray) -> np.ndarray:
73
+ """Project text embedding through CLIP text projector.
74
+
75
+ Args:
76
+ text_embedding: OpenAI text embedding (1536-dim)
77
+
78
+ Returns:
79
+ CLIP-projected embedding (1024-dim)
80
+ """
81
+ if not self.loaded:
82
+ raise RuntimeError("CLIP model not loaded. Call load_model() first.")
83
+
84
+ with torch.no_grad():
85
+ text_tensor = torch.from_numpy(text_embedding).float().unsqueeze(0).to(self.device)
86
+ clip_features = self.model.text_projector(text_tensor)
87
+ # Normalize as per CLIP
88
+ clip_features = F.normalize(clip_features, dim=-1, eps=CLIP_NORMALIZE_EPS)
89
+ query_embedding = clip_features.cpu().numpy().squeeze(0)
90
+
91
+ return query_embedding
92
+
93
+
94
+ class ImageProcessingService:
95
+ """Service for retrieving pre-existing image embeddings from Zilliz."""
96
+
97
+ def __init__(self):
98
+ pass
99
+
100
+ def encode_image(self, ra: float, dec: float, fov: float = 0.025, size: int = 256) -> np.ndarray:
101
+ """Query Zilliz for pre-existing embedding at the given coordinates.
102
+
103
+ Args:
104
+ ra: Right ascension in degrees
105
+ dec: Declination in degrees
106
+ fov: Field of view in degrees (used to define search box)
107
+ size: Image size in pixels (unused, kept for API compatibility)
108
+
109
+ Returns:
110
+ Pre-existing AION-Search embedding vector (1024-dim) from Zilliz
111
+ """
112
+ logger.info(f"Querying Zilliz for pre-existing embedding at RA={ra}, Dec={dec}, FoV={fov}")
113
+
114
+ # Calculate bounding box based on field of view
115
+ ra_min = ra - fov/2
116
+ ra_max = ra + fov/2
117
+ dec_min = dec - fov/2
118
+ dec_max = dec + fov/2
119
+
120
+ # Build filter expression for coordinate range
121
+ filter_expr = f"ra > {ra_min} AND ra < {ra_max} AND dec > {dec_min} AND dec < {dec_max}"
122
+
123
+ # Get the ANNS field for the image search collection
124
+ image_search_config = COLLECTION_CONFIGS.get(ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME)
125
+ image_anns_field = image_search_config["anns_field"]
126
+
127
+ # Prepare query payload - always use the image search collection (legacy)
128
+ payload = {
129
+ "collectionName": ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
130
+ "filter": filter_expr,
131
+ "outputFields": [image_anns_field],
132
+ "limit": 1
133
+ }
134
+
135
+ headers = {
136
+ "Authorization": f"Bearer {ZILLIZ_BEARER}",
137
+ "Accept": "application/json",
138
+ "Content-Type": "application/json"
139
+ }
140
+
141
+ try:
142
+ # Use query endpoint (replace /search with /query)
143
+ query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
144
+ response = requests.post(query_endpoint, json=payload, headers=headers)
145
+ response.raise_for_status()
146
+
147
+ result = response.json()
148
+
149
+ if result.get("code") == 0 and "data" in result:
150
+ data = result["data"]
151
+ if data and len(data) > 0:
152
+ # Extract the embedding from the first result using the image search ANNS field
153
+ embedding = data[0].get(image_anns_field)
154
+ if embedding:
155
+ embedding_array = np.array(embedding, dtype=np.float32)
156
+ logger.info(f"Retrieved pre-existing embedding with shape: {embedding_array.shape}")
157
+ return embedding_array
158
+ else:
159
+ logger.error(f"No embedding field found in result: {data[0].keys()}")
160
+ raise RuntimeError(f"No embedding found at coordinates RA={ra}, Dec={dec}")
161
+ else:
162
+ logger.error(f"No galaxies found at coordinates RA={ra}, Dec={dec} with FoV={fov}")
163
+ raise RuntimeError(f"No galaxies found at coordinates RA={ra}, Dec={dec}")
164
+ else:
165
+ logger.error(f"Zilliz query failed: {result}")
166
+ raise RuntimeError(f"Failed to query Zilliz: {result}")
167
+
168
+ except Exception as e:
169
+ logger.error(f"Error querying Zilliz for embedding: {e}")
170
+ raise
171
+
172
+
173
+ class EmbeddingService:
174
+ """Service for encoding text queries into embeddings."""
175
+
176
+ def __init__(self, clip_service: CLIPModelService):
177
+ self.clip_service = clip_service
178
+ self.openai_client = None
179
+
180
+ def _get_openai_client(self) -> OpenAI:
181
+ """Get or create OpenAI client."""
182
+ if self.openai_client is None:
183
+ if not OPENAI_API_KEY:
184
+ raise ValueError("OPENAI_API_KEY environment variable not set")
185
+ self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
186
+ return self.openai_client
187
+
188
+ def encode_text_query(self, query: str) -> np.ndarray:
189
+ """Encode text query using OpenAI embeddings + CLIP text projector.
190
+
191
+ Args:
192
+ query: Text search query
193
+
194
+ Returns:
195
+ CLIP embedding vector
196
+ """
197
+ client = self._get_openai_client()
198
+
199
+ # Get OpenAI text embedding
200
+ response = client.embeddings.create(
201
+ input=query,
202
+ model=OPENAI_EMBEDDING_MODEL
203
+ )
204
+ text_embedding = np.array(response.data[0].embedding)
205
+
206
+ # Project through CLIP text projector
207
+ return self.clip_service.encode_text(text_embedding)
208
+
209
+ def encode_vector_queries(
210
+ self,
211
+ queries: List[str],
212
+ operations: List[str]
213
+ ) -> np.ndarray:
214
+ """Encode multiple text queries and combine them using vector addition/subtraction.
215
+
216
+ Args:
217
+ queries: List of text queries
218
+ operations: List of operations ('+' or '-') for each query
219
+
220
+ Returns:
221
+ Combined normalized embedding vector
222
+ """
223
+ client = self._get_openai_client()
224
+
225
+ # Get all embeddings at once for efficiency
226
+ response = client.embeddings.create(
227
+ input=queries,
228
+ model=OPENAI_EMBEDDING_MODEL
229
+ )
230
+
231
+ # Initialize combined embedding
232
+ combined_embedding = None
233
+
234
+ # Process each embedding with its operation
235
+ for embedding_data, operation in zip(response.data, operations):
236
+ text_embedding = np.array(embedding_data.embedding)
237
+
238
+ # Project through CLIP text projector
239
+ query_embedding = self.clip_service.encode_text(text_embedding)
240
+
241
+ # Apply operation
242
+ if combined_embedding is None:
243
+ combined_embedding = query_embedding if operation == "+" else -query_embedding
244
+ else:
245
+ if operation == "+":
246
+ combined_embedding += query_embedding
247
+ else:
248
+ combined_embedding -= query_embedding
249
+
250
+ # Normalize the final combined embedding
251
+ norm = np.linalg.norm(combined_embedding)
252
+ if norm > 0:
253
+ combined_embedding = combined_embedding / norm
254
+
255
+ return combined_embedding
256
+
257
+
258
+ class ZillizService:
259
+ """Service for interacting with Zilliz vector database."""
260
+
261
+ def get_collection_count(self) -> int:
262
+ """Get the total number of entities in the collection.
263
+
264
+ Returns:
265
+ Total count of entities in the collection
266
+ """
267
+ logger.info("Getting collection count from Zilliz...")
268
+
269
+ # Use query endpoint with count to get total entities
270
+ payload = {
271
+ "collectionName": ZILLIZ_COLLECTION_NAME,
272
+ "filter": "", # Empty filter to count all entities
273
+ "outputFields": ["count(*)"]
274
+ }
275
+
276
+ headers = {
277
+ "Authorization": f"Bearer {ZILLIZ_BEARER}",
278
+ "Accept": "application/json",
279
+ "Content-Type": "application/json"
280
+ }
281
+
282
+ try:
283
+ # Use the query endpoint (replace /search with /query in the endpoint)
284
+ query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
285
+ response = requests.post(query_endpoint, json=payload, headers=headers)
286
+ response.raise_for_status()
287
+
288
+ result = response.json()
289
+
290
+ if result.get("code") == 0 and "data" in result:
291
+ # The count should be in the response data
292
+ data = result["data"]
293
+ if data and len(data) > 0:
294
+ count = data[0].get("count(*)", 0)
295
+ logger.info(f"Collection count: {count:,}")
296
+ return count
297
+ else:
298
+ logger.error(f"Failed to get collection count: {result}")
299
+ return 0
300
+
301
+ except Exception as e:
302
+ logger.error(f"Error getting collection count: {e}")
303
+ return 0
304
+
305
+ def search(self, query_embedding: np.ndarray, top_k: int = DEFAULT_TOP_K, filter_expr: str = None) -> pd.DataFrame:
306
+ """Search Zilliz for top-k most similar galaxies.
307
+
308
+ Args:
309
+ query_embedding: Query embedding vector
310
+ top_k: Number of results to return
311
+ filter_expr: Optional filter expression for filtering results
312
+
313
+ Returns:
314
+ DataFrame with search results
315
+ """
316
+ logger.info("Querying Zilliz...")
317
+ start_time = time.time()
318
+
319
+ # Prepare the search payload
320
+ payload = {
321
+ "collectionName": ZILLIZ_COLLECTION_NAME,
322
+ "data": [query_embedding.tolist()],
323
+ "annsField": ZILLIZ_ANNS_FIELD,
324
+ "limit": top_k,
325
+ "outputFields": ZILLIZ_OUTPUT_FIELDS
326
+ }
327
+
328
+ # Add filter if provided
329
+ if filter_expr:
330
+ payload["filter"] = filter_expr
331
+ logger.info(f"Applying filter: {filter_expr}")
332
+
333
+ headers = {
334
+ "Authorization": f"Bearer {ZILLIZ_BEARER}",
335
+ "Accept": "application/json",
336
+ "Content-Type": "application/json"
337
+ }
338
+
339
+ try:
340
+ response = requests.post(ZILLIZ_ENDPOINT, json=payload, headers=headers)
341
+ response.raise_for_status()
342
+
343
+ result = response.json()
344
+
345
+ if result.get("code") == 0 and "data" in result:
346
+ # Extract cost from response
347
+ cost_vcu = result.get("cost", 0)
348
+
349
+ # Convert to DataFrame
350
+ data_list = result["data"]
351
+ df = pd.DataFrame(data_list)
352
+
353
+ # Add cutout URLs
354
+ if not df.empty:
355
+ df["cutout_url"] = [cutout_url(ra, dec) for ra, dec in zip(df["ra"], df["dec"])]
356
+
357
+ query_time = time.time() - start_time
358
+
359
+ # Log the query
360
+ log_zilliz_query(
361
+ query_type="vector_search",
362
+ query_info={
363
+ "top_k": top_k,
364
+ "embedding_dim": len(query_embedding)
365
+ },
366
+ result_count=len(df),
367
+ query_time=query_time,
368
+ cost_vcu=cost_vcu
369
+ )
370
+
371
+ return df
372
+ else:
373
+ logger.error(f"Zilliz search failed: {result}")
374
+ return pd.DataFrame()
375
+
376
+ except Exception as e:
377
+ logger.error(f"Zilliz search error: {e}")
378
+ return pd.DataFrame()
379
+
380
+
381
+ class SearchService:
382
+ """High-level search orchestration service."""
383
+
384
+ def __init__(
385
+ self,
386
+ embedding_service: EmbeddingService,
387
+ zilliz_service: ZillizService,
388
+ image_service: 'ImageProcessingService' = None
389
+ ):
390
+ self.embedding_service = embedding_service
391
+ self.zilliz_service = zilliz_service
392
+ self.image_service = image_service
393
+
394
+ def _build_rmag_filter(self, rmag_min=None, rmag_max=None) -> str:
395
+ """Build r_mag filter expression.
396
+
397
+ Args:
398
+ rmag_min: Minimum r_mag value (inclusive)
399
+ rmag_max: Maximum r_mag value (inclusive)
400
+
401
+ Returns:
402
+ Filter expression string, or None if no filter
403
+ """
404
+ filter_parts = []
405
+
406
+ if rmag_min is not None:
407
+ filter_parts.append(f"r_mag >= {rmag_min}")
408
+
409
+ if rmag_max is not None:
410
+ filter_parts.append(f"r_mag <= {rmag_max}")
411
+
412
+ if filter_parts:
413
+ return " AND ".join(filter_parts)
414
+
415
+ return None
416
+
417
+ def search_text(self, query: str, top_k: int = DEFAULT_TOP_K, rmag_min=None, rmag_max=None) -> pd.DataFrame:
418
+ """Search galaxies using text query.
419
+
420
+ Args:
421
+ query: Text search query
422
+ top_k: Number of results to return
423
+ rmag_min: Minimum r_mag value (inclusive)
424
+ rmag_max: Maximum r_mag value (inclusive)
425
+
426
+ Returns:
427
+ DataFrame with search results
428
+ """
429
+ # Encode query
430
+ query_embedding = self.embedding_service.encode_text_query(query)
431
+
432
+ # Build filter
433
+ filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
434
+
435
+ # Search Zilliz
436
+ return self.zilliz_service.search(query_embedding, top_k, filter_expr)
437
+
438
+ def search_vector(
439
+ self,
440
+ queries: List[str],
441
+ operations: List[str],
442
+ top_k: int = DEFAULT_TOP_K,
443
+ rmag_min=None,
444
+ rmag_max=None
445
+ ) -> pd.DataFrame:
446
+ """Search galaxies using vector addition/subtraction.
447
+
448
+ Args:
449
+ queries: List of text queries
450
+ operations: List of operations ('+' or '-') for each query
451
+ top_k: Number of results to return
452
+ rmag_min: Minimum r_mag value (inclusive)
453
+ rmag_max: Maximum r_mag value (inclusive)
454
+
455
+ Returns:
456
+ DataFrame with search results
457
+ """
458
+ # Encode and combine vectors
459
+ combined_embedding = self.embedding_service.encode_vector_queries(queries, operations)
460
+
461
+ # Build filter
462
+ filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
463
+
464
+ # Search Zilliz
465
+ return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
466
+
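search_vector performs CLIP-style vector arithmetic: each query is encoded, then added or subtracted according to its operation before the nearest-neighbour lookup. A sketch, assuming the same search_service instance:

results = search_service.search_vector(
    queries=["spiral galaxy", "spiral arms"],
    operations=["+", "-"],   # one '+' or '-' per query
    top_k=20,
)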
467
+ def search_advanced(
468
+ self,
469
+ text_queries: List[str] = None,
470
+ text_weights: List[float] = None,
471
+ image_queries: List[dict] = None,
472
+ image_weights: List[float] = None,
473
+ top_k: int = DEFAULT_TOP_K,
474
+ rmag_min=None,
475
+ rmag_max=None
476
+ ) -> pd.DataFrame:
477
+ """Search galaxies using advanced vector addition/subtraction with text and/or images.
478
+
479
+ Args:
480
+ text_queries: List of text query strings
481
+ text_weights: List of signed weights for text queries (e.g., 1.0, -1.0, 2.0, -5.0)
482
+ image_queries: List of dicts with 'ra', 'dec', 'fov' keys
483
+ image_weights: List of signed weights for image queries (e.g., 1.0, -1.0, 2.0, -5.0)
484
+ top_k: Number of results to return
485
+ rmag_min: Minimum r_mag value (inclusive)
486
+ rmag_max: Maximum r_mag value (inclusive)
487
+
488
+ Returns:
489
+ DataFrame with search results
490
+ """
491
+ combined_embedding = None
492
+
493
+ # Process text queries
494
+ if text_queries:
495
+ for query, weight in zip(text_queries, text_weights):
496
+ query_embedding = self.embedding_service.encode_text_query(query)
497
+
498
+ # Apply weight
499
+ weighted_embedding = query_embedding * weight
500
+
501
+ if combined_embedding is None:
502
+ combined_embedding = weighted_embedding
503
+ else:
504
+ combined_embedding += weighted_embedding
505
+
506
+ # Process image queries
507
+ if image_queries:
508
+ if self.image_service is None:
509
+ raise RuntimeError("Image service not initialized")
510
+
511
+ for img_query, weight in zip(image_queries, image_weights):
512
+ # Encode image
513
+ image_embedding = self.image_service.encode_image(
514
+ ra=img_query['ra'],
515
+ dec=img_query['dec'],
516
+ fov=img_query.get('fov', 0.025),
517
+ size=256
518
+ )
519
+
520
+ # Apply weight
521
+ weighted_embedding = image_embedding * weight
522
+
523
+ if combined_embedding is None:
524
+ combined_embedding = weighted_embedding
525
+ else:
526
+ combined_embedding += weighted_embedding
527
+
528
+ # Normalize the final combined embedding
529
+ if combined_embedding is not None:
530
+ norm = np.linalg.norm(combined_embedding)
531
+ if norm > 0:
532
+ combined_embedding = combined_embedding / norm
533
+
534
+ # Build filter
535
+ filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
536
+
537
+ # Search Zilliz
538
+ return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
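A usage sketch for the advanced path, combining a positively weighted text prompt with a negatively weighted reference image (the coordinates are placeholders). Two caveats that follow from the code above: weights must be supplied alongside their queries, since the lists are zipped directly, and if no queries are given at all, combined_embedding stays None.

results = search_service.search_advanced(
    text_queries=["galaxy merger"],
    text_weights=[1.0],
    image_queries=[{"ra": 150.1192, "dec": 2.2058, "fov": 0.025}],
    image_weights=[-0.5],
    top_k=30,
    rmag_min=16.0,
)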
src/utils.py ADDED
@@ -0,0 +1,195 @@
1
+ """Utility functions for AION Search."""
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from datetime import datetime
7
+ from typing import Dict, Any
8
+ from src.config import CUTOUT_FOV, CUTOUT_SIZE, VCU_COST_PER_MILLION
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def cutout_url(ra: float, dec: float, fov: float = CUTOUT_FOV, size: int = CUTOUT_SIZE) -> str:
14
+ """Generate Legacy Survey cutout URL from RA/Dec coordinates.
15
+
16
+ Args:
17
+ ra: Right Ascension in degrees
18
+ dec: Declination in degrees
19
+ fov: Field of view in degrees
20
+ size: Image size in pixels
21
+
22
+ Returns:
23
+ URL string for the cutout image
24
+ """
25
+ return (
26
+ f"https://alasky.cds.unistra.fr/hips-image-services/hips2fits"
27
+ f"?hips=CDS/P/DESI-Legacy-Surveys/DR10/color"
28
+ f"&ra={ra}&dec={dec}&fov={fov}&width={size}&height={size}&format=jpg"
29
+ )
30
+
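For example, cutout_url(150.1192, 2.2058, fov=0.025, size=256) evaluates to:

https://alasky.cds.unistra.fr/hips-image-services/hips2fits?hips=CDS/P/DESI-Legacy-Surveys/DR10/color&ra=150.1192&dec=2.2058&fov=0.025&width=256&height=256&format=jpg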
31
+
32
+ def log_zilliz_query(
33
+ query_type: str,
34
+ query_info: Dict[str, Any],
35
+ result_count: int,
36
+ query_time: float,
37
+ cost_vcu: int = 0
38
+ ) -> None:
39
+ """Log Zilliz queries to a file in logs/ directory.
40
+
41
+ Args:
42
+ query_type: Type of query (e.g., "vector_search", "text_search")
43
+ query_info: Dictionary containing query details
44
+ result_count: Number of results returned
45
+ query_time: Query execution time in seconds
46
+ cost_vcu: Cost in vCU units
47
+ """
48
+ logs_dir = Path("logs")
49
+ logs_dir.mkdir(parents=True, exist_ok=True)
50
+
51
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
52
+ log_file = logs_dir / f"zilliz_query_{timestamp}.json"
53
+
54
+ # Convert vCU cost to dollars
55
+ cost_usd = (cost_vcu / 1e6) * VCU_COST_PER_MILLION
56
+
57
+ log_data = {
58
+ "timestamp": datetime.now().isoformat(),
59
+ "query_type": query_type,
60
+ "query_info": query_info,
61
+ "result_count": result_count,
62
+ "query_time_seconds": query_time,
63
+ "cost_vCU": cost_vcu,
64
+ "cost_usd": cost_usd
65
+ }
66
+
67
+ with open(log_file, 'w') as f:
68
+ json.dump(log_data, f, indent=2)
69
+
70
+ logger.info(
71
+ f"Query logged to {log_file} | {result_count} results in {query_time:.3f}s | "
72
+ f"{cost_vcu} vCU (${cost_usd:.6f})"
73
+ )
74
+
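Each call writes one JSON file. With an illustrative VCU_COST_PER_MILLION of 4.0 (the real value lives in src/config.py), a 120 vCU query costs (120 / 1e6) * 4.0 = $0.00048, and the record would look like this (all values made up for illustration):

{
  "timestamp": "2025-01-01T12:00:00.000000",
  "query_type": "vector_search",
  "query_info": {"top_k": 20, "embedding_dim": 512},
  "result_count": 20,
  "query_time_seconds": 0.412,
  "cost_vCU": 120,
  "cost_usd": 0.00048
}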
75
+
76
+ def format_galaxy_count(count: int) -> str:
77
+ """Format galaxy count with thousands separator.
78
+
79
+ Args:
80
+ count: Number of galaxies
81
+
82
+ Returns:
83
+ Formatted string (e.g., "259,636 galaxies")
84
+ """
85
+ return f"{count:,} galaxies"
86
+
87
+
88
+ def build_query_xml(
89
+ text_queries: list = None,
90
+ text_weights: list = None,
91
+ image_queries: list = None,
92
+ image_weights: list = None,
93
+ rmag_min: float = None,
94
+ rmag_max: float = None
95
+ ) -> str:
96
+ """Build XML representation of a query according to aql.md specification.
97
+
98
+ Args:
99
+ text_queries: List of text query strings
100
+ text_weights: List of signed weights for text queries (e.g., 1.0, -1.0, 2.0, -5.0)
101
+ image_queries: List of dicts with 'ra', 'dec', 'fov' keys
102
+ image_weights: List of signed weights for image queries (e.g., 1.0, -1.0, 2.0, -5.0)
103
+ rmag_min: Minimum r_mag filter value
104
+ rmag_max: Maximum r_mag filter value
105
+
106
+ Returns:
107
+ XML string representation of the query
108
+ """
109
+ xml_parts = ['<query>']
110
+
111
+ # Add text queries
112
+ if text_queries:
113
+ xml_parts.append(' <text>')
114
+ for query, weight in zip(text_queries, text_weights):
115
+ xml_parts.append(' <term>')
116
+ xml_parts.append(f' <weight>{weight}</weight>')
117
+ xml_parts.append(f' <content>{query}</content>')
118
+ xml_parts.append(' </term>')
119
+ xml_parts.append(' </text>')
120
+
121
+ # Add image queries
122
+ if image_queries:
123
+ xml_parts.append(' <image>')
124
+ for img_query, weight in zip(image_queries, image_weights):
125
+ xml_parts.append(' <reference>')
126
+ xml_parts.append(f' <ra>{img_query["ra"]}</ra>')
127
+ xml_parts.append(f' <dec>{img_query["dec"]}</dec>')
128
+ xml_parts.append(f' <fov>{img_query["fov"]}</fov>')
129
+ xml_parts.append(f' <weight>{weight}</weight>')
130
+ xml_parts.append(' </reference>')
131
+ xml_parts.append(' </image>')
132
+
133
+ # Add filters
134
+ if rmag_min is not None or rmag_max is not None:
135
+ xml_parts.append(' <filters>')
136
+ if rmag_min is not None and rmag_max is not None:
137
+ xml_parts.append(' <filter>')
138
+ xml_parts.append(' <column>r_mag</column>')
139
+ xml_parts.append(' <operator>between</operator>')
140
+ xml_parts.append(f' <value_min>{rmag_min}</value_min>')
141
+ xml_parts.append(f' <value_max>{rmag_max}</value_max>')
142
+ xml_parts.append(' </filter>')
143
+ elif rmag_min is not None:
144
+ xml_parts.append(' <filter>')
145
+ xml_parts.append(' <column>r_mag</column>')
146
+ xml_parts.append(' <operator>gte</operator>')
147
+ xml_parts.append(f' <value>{rmag_min}</value>')
148
+ xml_parts.append(' </filter>')
149
+ elif rmag_max is not None:
150
+ xml_parts.append(' <filter>')
151
+ xml_parts.append(' <column>r_mag</column>')
152
+ xml_parts.append(' <operator>lte</operator>')
153
+ xml_parts.append(f' <value>{rmag_max}</value>')
154
+ xml_parts.append(' </filter>')
155
+ xml_parts.append(' </filters>')
156
+
157
+ xml_parts.append('</query>')
158
+ return '\n'.join(xml_parts)
159
+
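For example, build_query_xml(text_queries=["spiral galaxy"], text_weights=[1.0], rmag_max=19.0) returns (indentation as emitted by the string literals above):

<query>
  <text>
    <term>
      <weight>1.0</weight>
      <content>spiral galaxy</content>
    </term>
  </text>
  <filters>
    <filter>
      <column>r_mag</column>
      <operator>lte</operator>
      <value>19.0</value>
    </filter>
  </filters>
</query>

Note that the query content is interpolated verbatim, so text containing XML-special characters such as '<' or '&' would produce malformed XML.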
160
+
161
+ def log_query_to_csv(
162
+ query_xml: str,
163
+ csv_path: str = "logs/query_log.csv"
164
+ ) -> None:
165
+ """Log a query to CSV file with datetime and XML string.
166
+
167
+ Args:
168
+ query_xml: XML string representation of the query
169
+ csv_path: Path to the CSV log file
170
+ """
171
+ import csv
172
+ import os
173
+
174
+ # Create logs directory if it doesn't exist
175
+ log_dir = Path(csv_path).parent
176
+ log_dir.mkdir(parents=True, exist_ok=True)
177
+
178
+ # Prepare log entry
179
+ timestamp = datetime.now().isoformat()
180
+
181
+ # Check if file exists to determine if we need to write header
182
+ file_exists = Path(csv_path).exists()
183
+
184
+ # Append to CSV
185
+ with open(csv_path, 'a', newline='', encoding='utf-8') as f:
186
+ writer = csv.writer(f)
187
+
188
+ # Write header if file is new
189
+ if not file_exists:
190
+ writer.writerow(['datetime', 'query'])
191
+
192
+ # Write the query log
193
+ writer.writerow([timestamp, query_xml])
194
+
195
+ logger.info(f"Query logged to {csv_path}")