"""Backend services for AION Search.""" import time import logging import torch import torch.nn.functional as F import numpy as np import pandas as pd import requests from typing import List from openai import OpenAI from src.config import ( ZILLIZ_BEARER, ZILLIZ_ENDPOINT, ZILLIZ_COLLECTION_NAME, ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME, ZILLIZ_ANNS_FIELD, ZILLIZ_PRIMARY_KEY, ZILLIZ_OUTPUT_FIELDS, COLLECTION_CONFIGS, OPENAI_API_KEY, OPENAI_EMBEDDING_MODEL, CLIP_NORMALIZE_EPS, DEFAULT_TOP_K, ) from src.utils import cutout_url, log_zilliz_query logger = logging.getLogger(__name__) class CLIPModelService: """Service for managing CLIP model loading and inference.""" def __init__(self): self.model = None self.device = None self.loaded = False def load_model(self, checkpoint_path: str) -> None: """Load the CLIP model from checkpoint. Args: checkpoint_path: Path to the CLIP model checkpoint file """ logger.info(f"Loading CLIP model from {checkpoint_path}...") from clip.models.clip_model import GalaxyClipModel # Set device self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) model_config = checkpoint['model_config'] # Initialize model with saved configuration self.model = GalaxyClipModel( image_input_dim=model_config['image_input_dim'], text_input_dim=model_config['text_input_dim'], embedding_dim=model_config['embedding_dim'], use_mean_embeddings=model_config.get('use_mean_embeddings', True) ) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.to(self.device) self.model.eval() self.loaded = True logger.info("CLIP model loaded successfully") def encode_text(self, text_embedding: np.ndarray) -> np.ndarray: """Project text embedding through CLIP text projector. Args: text_embedding: OpenAI text embedding (1536-dim) Returns: CLIP-projected embedding (1024-dim) """ if not self.loaded: raise RuntimeError("CLIP model not loaded. Call load_model() first.") with torch.no_grad(): text_tensor = torch.from_numpy(text_embedding).float().unsqueeze(0).to(self.device) clip_features = self.model.text_projector(text_tensor) # Normalize as per CLIP clip_features = F.normalize(clip_features, dim=-1, eps=CLIP_NORMALIZE_EPS) query_embedding = clip_features.cpu().numpy().squeeze(0) return query_embedding class ImageProcessingService: """Service for retrieving pre-existing image embeddings from Zilliz.""" def __init__(self): pass def encode_image(self, ra: float, dec: float, fov: float = 0.025, size: int = 256) -> np.ndarray: """Query Zilliz for pre-existing embedding at the given coordinates. 


class ImageProcessingService:
    """Service for retrieving pre-existing image embeddings from Zilliz."""

    def __init__(self):
        pass

    def encode_image(self, ra: float, dec: float, fov: float = 0.025, size: int = 256) -> np.ndarray:
        """Query Zilliz for pre-existing embedding at the given coordinates.

        Args:
            ra: Right ascension in degrees
            dec: Declination in degrees
            fov: Field of view in degrees (used to define search box)
            size: Image size in pixels (unused, kept for API compatibility)

        Returns:
            Pre-existing AION-Search embedding vector (1024-dim) from Zilliz
        """
        logger.info(f"Querying Zilliz for pre-existing embedding at RA={ra}, Dec={dec}, FoV={fov}")

        # Calculate bounding box based on field of view
        # (a simple box; does not handle RA wrap-around or cos(dec) scaling)
        ra_min = ra - fov / 2
        ra_max = ra + fov / 2
        dec_min = dec - fov / 2
        dec_max = dec + fov / 2

        # Build filter expression for coordinate range
        filter_expr = f"ra > {ra_min} AND ra < {ra_max} AND dec > {dec_min} AND dec < {dec_max}"

        # Get the ANNS field for the image search collection
        image_search_config = COLLECTION_CONFIGS.get(ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME)
        image_anns_field = image_search_config["anns_field"]

        # Prepare query payload - always use the image search collection (legacy)
        payload = {
            "collectionName": ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
            "filter": filter_expr,
            "outputFields": [image_anns_field],
            "limit": 1
        }

        headers = {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }

        try:
            # Use query endpoint (replace /search with /query)
            query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            # Bound the request time so a stalled call cannot hang the service
            response = requests.post(query_endpoint, json=payload, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()

            if result.get("code") == 0 and "data" in result:
                data = result["data"]
                if data and len(data) > 0:
                    # Extract the embedding from the first result using the image search ANNS field
                    embedding = data[0].get(image_anns_field)
                    if embedding:
                        embedding_array = np.array(embedding, dtype=np.float32)
                        logger.info(f"Retrieved pre-existing embedding with shape: {embedding_array.shape}")
                        return embedding_array
                    else:
                        logger.error(f"No embedding field found in result: {data[0].keys()}")
                        raise RuntimeError(f"No embedding found at coordinates RA={ra}, Dec={dec}")
                else:
                    logger.error(f"No galaxies found at coordinates RA={ra}, Dec={dec} with FoV={fov}")
                    raise RuntimeError(f"No galaxies found at coordinates RA={ra}, Dec={dec}")
            else:
                logger.error(f"Zilliz query failed: {result}")
                raise RuntimeError(f"Failed to query Zilliz: {result}")

        except Exception as e:
            logger.error(f"Error querying Zilliz for embedding: {e}")
            raise
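

# Illustrative sketch (never called): looking up a stored embedding by sky
# position. The coordinates are arbitrary placeholders, not known catalog
# entries; encode_image raises RuntimeError if no galaxy falls in the box.
def _example_image_embedding_lookup() -> None:
    image_service = ImageProcessingService()
    embedding = image_service.encode_image(ra=150.0, dec=2.2, fov=0.025)
    print(embedding.shape)  # expected (1024,) per the docstring above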


class EmbeddingService:
    """Service for encoding text queries into embeddings."""

    def __init__(self, clip_service: CLIPModelService):
        self.clip_service = clip_service
        self.openai_client = None

    def _get_openai_client(self) -> OpenAI:
        """Get or create OpenAI client."""
        if self.openai_client is None:
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY environment variable not set")
            self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        return self.openai_client

    def _moderate_content(self, text: str) -> bool:
        """Check if text content is appropriate using OpenAI Moderation API.

        Args:
            text: Text to moderate

        Returns:
            True if content is safe, False if flagged
        """
        try:
            client = self._get_openai_client()
            response = client.moderations.create(input=text)

            # If any category is flagged, reject the content
            if response.results[0].flagged:
                logger.warning("Content moderation flagged input")
                return False
            return True
        except Exception as e:
            logger.error(f"Moderation API error: {e}")
            # On error, allow the content through (fail open)
            return True

    def encode_text_query(self, query: str) -> np.ndarray:
        """Encode text query using OpenAI embeddings + CLIP text projector.

        Args:
            query: Text search query

        Returns:
            CLIP embedding vector
        """
        # Moderate content first
        if not self._moderate_content(query):
            raise ValueError("Content moderation filter triggered")

        client = self._get_openai_client()

        # Get OpenAI text embedding
        response = client.embeddings.create(
            input=query,
            model=OPENAI_EMBEDDING_MODEL
        )
        text_embedding = np.array(response.data[0].embedding)

        # Project through CLIP text projector
        return self.clip_service.encode_text(text_embedding)

    def encode_vector_queries(
        self,
        queries: List[str],
        operations: List[str]
    ) -> np.ndarray:
        """Encode multiple text queries and combine them using vector addition/subtraction.

        Args:
            queries: List of text queries
            operations: List of operations ('+' or '-') for each query

        Returns:
            Combined normalized embedding vector
        """
        # Moderate all queries first
        for query in queries:
            if not self._moderate_content(query):
                raise ValueError("Content moderation filter triggered")

        client = self._get_openai_client()

        # Get all embeddings at once for efficiency
        response = client.embeddings.create(
            input=queries,
            model=OPENAI_EMBEDDING_MODEL
        )

        # Initialize combined embedding
        combined_embedding = None

        # Process each embedding with its operation
        for embedding_data, operation in zip(response.data, operations):
            text_embedding = np.array(embedding_data.embedding)

            # Project through CLIP text projector
            query_embedding = self.clip_service.encode_text(text_embedding)

            # Apply operation
            if combined_embedding is None:
                combined_embedding = query_embedding if operation == "+" else -query_embedding
            else:
                if operation == "+":
                    combined_embedding += query_embedding
                else:
                    combined_embedding -= query_embedding

        # Normalize the final combined embedding
        norm = np.linalg.norm(combined_embedding)
        if norm > 0:
            combined_embedding = combined_embedding / norm

        return combined_embedding
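

# Illustrative sketch (never called) of CLIP-space vector arithmetic:
# "spiral galaxy" + "blue" - "red". The query strings are example prompts,
# and running this needs OPENAI_API_KEY plus a loaded CLIP model.
def _example_vector_arithmetic(embedding_service: EmbeddingService) -> None:
    combined = embedding_service.encode_vector_queries(
        queries=["spiral galaxy", "blue", "red"],
        operations=["+", "+", "-"],
    )
    # The combined vector is re-normalized, so cosine similarity in Zilliz
    # behaves the same as for single-query embeddings.
    print(np.linalg.norm(combined))  # ~1.0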


class ZillizService:
    """Service for interacting with Zilliz vector database."""

    def get_collection_count(self) -> int:
        """Get the total number of entities in the collection.

        Returns:
            Total count of entities in the collection
        """
        logger.info("Getting collection count from Zilliz...")

        # Use query endpoint with count to get total entities
        payload = {
            "collectionName": ZILLIZ_COLLECTION_NAME,
            "filter": "",  # Empty filter to count all entities
            "outputFields": ["count(*)"]
        }

        headers = {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }

        try:
            # Use the query endpoint (replace /search with /query in the endpoint)
            query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            response = requests.post(query_endpoint, json=payload, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()

            if result.get("code") == 0 and "data" in result:
                # The count should be in the response data
                data = result["data"]
                if data and len(data) > 0:
                    count = data[0].get("count(*)", 0)
                    logger.info(f"Collection count: {count:,}")
                    return count
                # An empty data list means there is nothing to count
                return 0
            else:
                logger.error(f"Failed to get collection count: {result}")
                return 0
        except Exception as e:
            logger.error(f"Error getting collection count: {e}")
            return 0

    def search(self, query_embedding: np.ndarray, top_k: int = DEFAULT_TOP_K,
               filter_expr: Optional[str] = None) -> pd.DataFrame:
        """Search Zilliz for top-k most similar galaxies.

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results to return
            filter_expr: Optional filter expression for filtering results

        Returns:
            DataFrame with search results
        """
        logger.info("Querying Zilliz...")
        start_time = time.time()

        # Prepare the search payload
        payload = {
            "collectionName": ZILLIZ_COLLECTION_NAME,
            "data": [query_embedding.tolist()],
            "annsField": ZILLIZ_ANNS_FIELD,
            "limit": top_k,
            "outputFields": ZILLIZ_OUTPUT_FIELDS
        }

        # Add filter if provided
        if filter_expr:
            payload["filter"] = filter_expr
            logger.info(f"Applying filter: {filter_expr}")

        headers = {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }

        try:
            response = requests.post(ZILLIZ_ENDPOINT, json=payload, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()

            if result.get("code") == 0 and "data" in result:
                # Extract cost from response
                cost_vcu = result.get("cost", 0)

                # Convert to DataFrame
                data_list = result["data"]
                df = pd.DataFrame(data_list)

                # Add cutout URLs
                if not df.empty:
                    df["cutout_url"] = [cutout_url(ra, dec) for ra, dec in zip(df["ra"], df["dec"])]

                query_time = time.time() - start_time

                # Log the query
                log_zilliz_query(
                    query_type="vector_search",
                    query_info={
                        "top_k": top_k,
                        "embedding_dim": len(query_embedding)
                    },
                    result_count=len(df),
                    query_time=query_time,
                    cost_vcu=cost_vcu
                )

                return df
            else:
                logger.error(f"Zilliz search failed: {result}")
                return pd.DataFrame()
        except Exception as e:
            logger.error(f"Zilliz search error: {e}")
            return pd.DataFrame()
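

# Illustrative sketch (never called): a filtered vector search. The r_mag
# bounds are arbitrary example values, and the random vector stands in for a
# real query embedding from EmbeddingService or ImageProcessingService.
def _example_filtered_search(zilliz_service: ZillizService) -> None:
    query_embedding = np.random.rand(1024).astype(np.float32)
    df = zilliz_service.search(
        query_embedding,
        top_k=10,
        filter_expr="r_mag >= 14.0 AND r_mag <= 18.0",
    )
    # Each row carries the configured output fields plus a derived cutout_url.
    print(df.columns.tolist())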


class SearchService:
    """High-level search orchestration service."""

    def __init__(
        self,
        embedding_service: EmbeddingService,
        zilliz_service: ZillizService,
        image_service: Optional[ImageProcessingService] = None
    ):
        self.embedding_service = embedding_service
        self.zilliz_service = zilliz_service
        self.image_service = image_service

    def _build_rmag_filter(self, rmag_min=None, rmag_max=None) -> Optional[str]:
        """Build r_mag filter expression.

        Args:
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            Filter expression string, or None if no filter
        """
        filter_parts = []
        if rmag_min is not None:
            filter_parts.append(f"r_mag >= {rmag_min}")
        if rmag_max is not None:
            filter_parts.append(f"r_mag <= {rmag_max}")
        if filter_parts:
            return " AND ".join(filter_parts)
        return None

    def search_text(self, query: str, top_k: int = DEFAULT_TOP_K,
                    rmag_min=None, rmag_max=None) -> pd.DataFrame:
        """Search galaxies using text query.

        Args:
            query: Text search query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results
        """
        try:
            # Encode query
            query_embedding = self.embedding_service.encode_text_query(query)

            # Build filter
            filter_expr = self._build_rmag_filter(rmag_min, rmag_max)

            # Search Zilliz
            return self.zilliz_service.search(query_embedding, top_k, filter_expr)
        except ValueError as e:
            # Content moderation triggered - return empty results silently
            if "moderation" in str(e).lower():
                logger.info("Search blocked by content moderation")
                return pd.DataFrame()
            raise

    def search_vector(
        self,
        queries: List[str],
        operations: List[str],
        top_k: int = DEFAULT_TOP_K,
        rmag_min=None,
        rmag_max=None
    ) -> pd.DataFrame:
        """Search galaxies using vector addition/subtraction.

        Args:
            queries: List of text queries
            operations: List of operations ('+' or '-') for each query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results
        """
        try:
            # Encode and combine vectors
            combined_embedding = self.embedding_service.encode_vector_queries(queries, operations)

            # Build filter
            filter_expr = self._build_rmag_filter(rmag_min, rmag_max)

            # Search Zilliz
            return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
        except ValueError as e:
            # Content moderation triggered - return empty results silently
            if "moderation" in str(e).lower():
                logger.info("Search blocked by content moderation")
                return pd.DataFrame()
            raise

    def search_advanced(
        self,
        text_queries: Optional[List[str]] = None,
        text_weights: Optional[List[float]] = None,
        image_queries: Optional[List[dict]] = None,
        image_weights: Optional[List[float]] = None,
        top_k: int = DEFAULT_TOP_K,
        rmag_min=None,
        rmag_max=None
    ) -> pd.DataFrame:
        """Search galaxies using advanced vector addition/subtraction with text and/or images.

        Args:
            text_queries: List of text query strings
            text_weights: List of signed weights for text queries (e.g., 1.0, -1.0, 2.0, -5.0)
            image_queries: List of dicts with 'ra', 'dec', 'fov' keys
            image_weights: List of signed weights for image queries (e.g., 1.0, -1.0, 2.0, -5.0)
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results
        """
        try:
            combined_embedding = None

            # Process text queries
            if text_queries:
                for query, weight in zip(text_queries, text_weights):
                    query_embedding = self.embedding_service.encode_text_query(query)

                    # Apply weight
                    weighted_embedding = query_embedding * weight
                    if combined_embedding is None:
                        combined_embedding = weighted_embedding
                    else:
                        combined_embedding += weighted_embedding

            # Process image queries
            if image_queries:
                if self.image_service is None:
                    raise RuntimeError("Image service not initialized")

                for img_query, weight in zip(image_queries, image_weights):
                    # Encode image
                    image_embedding = self.image_service.encode_image(
                        ra=img_query['ra'],
                        dec=img_query['dec'],
                        fov=img_query.get('fov', 0.025),
                        size=256
                    )

                    # Apply weight
                    weighted_embedding = image_embedding * weight
                    if combined_embedding is None:
                        combined_embedding = weighted_embedding
                    else:
                        combined_embedding += weighted_embedding

            # Nothing to search with if no queries were supplied
            if combined_embedding is None:
                logger.warning("search_advanced called without any queries")
                return pd.DataFrame()

            # Normalize the final combined embedding
            norm = np.linalg.norm(combined_embedding)
            if norm > 0:
                combined_embedding = combined_embedding / norm

            # Build filter
            filter_expr = self._build_rmag_filter(rmag_min, rmag_max)

            # Search Zilliz
            return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
        except ValueError as e:
            # Content moderation triggered - return empty results silently
            if "moderation" in str(e).lower():
                logger.info("Search blocked by content moderation")
                return pd.DataFrame()
            raise
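

# Illustrative end-to-end wiring (never called). This is an assumption about
# how the app layer composes the services; the checkpoint path and query text
# are placeholders, not values taken from this repo.
def _example_wiring() -> None:
    clip_service = CLIPModelService()
    clip_service.load_model("checkpoints/galaxy_clip.pt")  # placeholder path
    embedding_service = EmbeddingService(clip_service)
    search_service = SearchService(
        embedding_service=embedding_service,
        zilliz_service=ZillizService(),
        image_service=ImageProcessingService(),
    )
    results = search_service.search_text("barred spiral galaxy", top_k=20)
    logger.info(f"Got {len(results)} results")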