Spaces:
Running
Running
| """Backend services for AION Search.""" | |
| import time | |
| import logging | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import pandas as pd | |
| import requests | |
| from typing import List | |
| from openai import OpenAI | |
| from src.config import ( | |
| ZILLIZ_BEARER, | |
| ZILLIZ_ENDPOINT, | |
| ZILLIZ_COLLECTION_NAME, | |
| ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME, | |
| ZILLIZ_ANNS_FIELD, | |
| ZILLIZ_PRIMARY_KEY, | |
| ZILLIZ_OUTPUT_FIELDS, | |
| COLLECTION_CONFIGS, | |
| OPENAI_API_KEY, | |
| OPENAI_EMBEDDING_MODEL, | |
| CLIP_NORMALIZE_EPS, | |
| DEFAULT_TOP_K, | |
| ) | |
| from src.utils import cutout_url, log_zilliz_query | |
| logger = logging.getLogger(__name__) | |
class CLIPModelService:
    """Hold a CLIP model instance and expose its text projection head.

    The model is not loaded on construction; call ``load_model`` once before
    using ``encode_text``.
    """

    def __init__(self):
        # Populated by load_model(); `loaded` gates encode_text().
        self.model = None
        self.device = None
        self.loaded = False

    def load_model(self, checkpoint_path: str) -> None:
        """Build the model described by a checkpoint and restore its weights.

        The checkpoint is expected to be a mapping with a ``model_config``
        entry (constructor arguments) and a ``model_state_dict`` entry.

        Args:
            checkpoint_path: Path to the CLIP model checkpoint file
        """
        logger.info(f"Loading CLIP model from {checkpoint_path}...")

        # Imported inside the method, so the dependency is only needed once
        # a model is actually loaded.
        from clip.models.clip_model import GalaxyClipModel

        # Prefer the GPU when one is visible to torch.
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; only point this at checkpoints from a trusted source.
        checkpoint = torch.load(
            checkpoint_path,
            map_location=self.device,
            weights_only=False,
        )
        saved_config = checkpoint['model_config']

        # Re-create the architecture with the hyper-parameters it was saved with.
        constructor_kwargs = {
            'image_input_dim': saved_config['image_input_dim'],
            'text_input_dim': saved_config['text_input_dim'],
            'embedding_dim': saved_config['embedding_dim'],
            'use_mean_embeddings': saved_config.get('use_mean_embeddings', True),
        }
        self.model = GalaxyClipModel(**constructor_kwargs)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()  # inference mode

        self.loaded = True
        logger.info("CLIP model loaded successfully")

    def encode_text(self, text_embedding: np.ndarray) -> np.ndarray:
        """Map a text embedding into the CLIP space and L2-normalise it.

        Args:
            text_embedding: 1-D text embedding (the original documentation
                describes it as a 1536-dim OpenAI embedding)

        Returns:
            Normalised CLIP-projected embedding as a 1-D numpy array

        Raises:
            RuntimeError: If ``load_model`` has not been called successfully.
        """
        if not self.loaded:
            raise RuntimeError("CLIP model not loaded. Call load_model() first.")

        with torch.no_grad():
            # Add a leading batch axis of size 1 for the projector.
            batch = torch.from_numpy(text_embedding).float()
            batch = batch.unsqueeze(0).to(self.device)

            projected = self.model.text_projector(batch)

            # Normalize as per CLIP
            unit = F.normalize(projected, dim=-1, eps=CLIP_NORMALIZE_EPS)

            # Back to numpy on the CPU, dropping the batch axis again.
            return unit.cpu().numpy().squeeze(0)
class ImageProcessingService:
    """Service for retrieving pre-existing image embeddings from Zilliz."""

    # Seconds to wait on Zilliz before giving up. Without an explicit timeout
    # requests.post() can block indefinitely on a stalled connection.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self):
        pass

    @staticmethod
    def _build_box_filter(ra: float, dec: float, fov: float) -> str:
        """Return the filter expression for a fov-wide box centred on (ra, dec).

        The box is a plain coordinate range with exclusive bounds.
        NOTE(review): RA wrap-around is not handled here — confirm whether
        callers can pass positions close to the RA boundary.
        """
        half_fov = fov / 2
        ra_min = ra - half_fov
        ra_max = ra + half_fov
        dec_min = dec - half_fov
        dec_max = dec + half_fov
        return f"ra > {ra_min} AND ra < {ra_max} AND dec > {dec_min} AND dec < {dec_max}"

    def encode_image(self, ra: float, dec: float, fov: float = 0.025, size: int = 256) -> np.ndarray:
        """Query Zilliz for a pre-existing embedding at the given coordinates.

        Args:
            ra: Right ascension in degrees
            dec: Declination in degrees
            fov: Field of view in degrees (used to define the search box)
            size: Image size in pixels (unused, kept for API compatibility)

        Returns:
            The stored embedding of the first row Zilliz returns for the box,
            as a float32 numpy array.

        Raises:
            RuntimeError: If the image search collection is not configured,
                the query is rejected, no row falls inside the box, or the
                row has no embedding.
            requests.RequestException: On HTTP errors or when the request
                exceeds ``REQUEST_TIMEOUT_SECONDS``.
        """
        logger.info(f"Querying Zilliz for pre-existing embedding at RA={ra}, Dec={dec}, FoV={fov}")

        filter_expr = self._build_box_filter(ra, dec, fov)

        # Get the ANNS field for the image search collection. Fail with a
        # clear message instead of a TypeError/KeyError if it is missing.
        image_search_config = COLLECTION_CONFIGS.get(ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME)
        if not image_search_config or "anns_field" not in image_search_config:
            raise RuntimeError(
                f"No 'anns_field' configured for collection {ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME!r}"
            )
        image_anns_field = image_search_config["anns_field"]

        # Prepare query payload - always use the image search collection (legacy)
        payload = {
            "collectionName": ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
            "filter": filter_expr,
            "outputFields": [image_anns_field],
            # Only one row is requested; the filter carries no ordering, so
            # this is whichever match Zilliz returns first.
            "limit": 1,
        }
        headers = {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

        try:
            # Use query endpoint (replace /search with /query)
            query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            response = requests.post(
                query_endpoint,
                json=payload,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            result = response.json()

            if result.get("code") != 0 or "data" not in result:
                logger.error(f"Zilliz query failed: {result}")
                raise RuntimeError(f"Failed to query Zilliz: {result}")

            data = result["data"]
            if not data:
                logger.error(f"No galaxies found at coordinates RA={ra}, Dec={dec} with FoV={fov}")
                raise RuntimeError(f"No galaxies found at coordinates RA={ra}, Dec={dec}")

            # Extract the embedding from the first result using the image search ANNS field
            embedding = data[0].get(image_anns_field)
            if not embedding:
                logger.error(f"No embedding field found in result: {data[0].keys()}")
                raise RuntimeError(f"No embedding found at coordinates RA={ra}, Dec={dec}")

            embedding_array = np.array(embedding, dtype=np.float32)
            logger.info(f"Retrieved pre-existing embedding with shape: {embedding_array.shape}")
            return embedding_array
        except Exception as e:
            # Log every failure (including the RuntimeErrors raised above)
            # and let the caller decide how to react.
            logger.error(f"Error querying Zilliz for embedding: {e}")
            raise
class EmbeddingService:
    """Service for encoding text queries into embeddings."""

    def __init__(self, clip_service: "CLIPModelService"):
        self.clip_service = clip_service
        # Created lazily by _get_openai_client().
        self.openai_client = None

    def _get_openai_client(self) -> "OpenAI":
        """Get or create the OpenAI client.

        Raises:
            ValueError: If OPENAI_API_KEY is not configured.
        """
        if self.openai_client is None:
            if not OPENAI_API_KEY:
                raise ValueError("OPENAI_API_KEY environment variable not set")
            self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        return self.openai_client

    def _moderate_content(self, text: str) -> bool:
        """Check if text content is appropriate using the OpenAI Moderation API.

        Args:
            text: Text to moderate

        Returns:
            True if content is safe, False if flagged. Any error while
            calling the API also yields True (fail open).
        """
        try:
            client = self._get_openai_client()
            response = client.moderations.create(input=text)
            # If any category is flagged, reject the content
            if response.results[0].flagged:
                logger.warning("Content moderation flagged input")
                return False
            return True
        except Exception as e:
            logger.error(f"Moderation API error: {e}")
            # On error, allow the content through (fail open)
            return True

    @staticmethod
    def _combine_embeddings(embeddings, operations) -> np.ndarray:
        """Add/subtract embeddings pairwise with operations and L2-normalise.

        An operation of "+" adds the embedding; any other value subtracts it.
        The inputs are not modified. A zero-length result is returned as is,
        since it cannot be normalised.
        """
        combined = None
        for embedding, operation in zip(embeddings, operations):
            signed = embedding if operation == "+" else -embedding
            combined = signed if combined is None else combined + signed

        norm = np.linalg.norm(combined)
        if norm > 0:
            combined = combined / norm
        return combined

    def encode_text_query(self, query: str) -> np.ndarray:
        """Encode text query using OpenAI embeddings + CLIP text projector.

        Args:
            query: Text search query

        Returns:
            CLIP embedding vector

        Raises:
            ValueError: If the moderation check flags the query.
        """
        # Moderate content first
        if not self._moderate_content(query):
            raise ValueError("Content moderation filter triggered")

        client = self._get_openai_client()
        # Get OpenAI text embedding
        response = client.embeddings.create(
            input=query,
            model=OPENAI_EMBEDDING_MODEL,
        )
        text_embedding = np.array(response.data[0].embedding)
        # Project through CLIP text projector
        return self.clip_service.encode_text(text_embedding)

    def encode_vector_queries(
        self,
        queries: List[str],
        operations: List[str]
    ) -> np.ndarray:
        """Encode multiple text queries and combine them using vector addition/subtraction.

        Args:
            queries: List of text queries
            operations: List of operations ('+' or '-') for each query

        Returns:
            Combined normalized embedding vector

        Raises:
            ValueError: If ``queries`` is empty, if ``queries`` and
                ``operations`` differ in length, or if the moderation check
                flags a query.
        """
        # Validate before spending any API calls. zip() would otherwise
        # silently drop the unmatched tail, and an empty input would reach
        # the normalisation step with nothing to normalise.
        if not queries:
            raise ValueError("At least one query is required")
        if len(queries) != len(operations):
            raise ValueError(
                f"Got {len(queries)} queries but {len(operations)} operations; "
                "each query needs exactly one operation"
            )

        # Moderate all queries first
        for query in queries:
            if not self._moderate_content(query):
                raise ValueError("Content moderation filter triggered")

        client = self._get_openai_client()
        # Get all embeddings at once for efficiency
        response = client.embeddings.create(
            input=queries,
            model=OPENAI_EMBEDDING_MODEL,
        )

        # Project each embedding through the CLIP text projector
        projected = [
            self.clip_service.encode_text(np.array(item.embedding))
            for item in response.data
        ]
        return self._combine_embeddings(projected, operations)
class ZillizService:
    """Service for interacting with Zilliz vector database."""

    # Seconds to wait on Zilliz before giving up. Without an explicit timeout
    # requests.post() can block indefinitely on a stalled connection.
    REQUEST_TIMEOUT_SECONDS = 30

    @staticmethod
    def _request_headers() -> dict:
        """Return the HTTP headers shared by every Zilliz request."""
        return {
            "Authorization": f"Bearer {ZILLIZ_BEARER}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    def get_collection_count(self) -> int:
        """Get the total number of entities in the collection.

        Returns:
            Total count of entities in the collection, or 0 when the count
            could not be obtained (the failure is logged, not raised).
        """
        logger.info("Getting collection count from Zilliz...")

        # Use query endpoint with count to get total entities
        payload = {
            "collectionName": ZILLIZ_COLLECTION_NAME,
            "filter": "",  # Empty filter to count all entities
            "outputFields": ["count(*)"],
        }

        try:
            # Use the query endpoint (replace /search with /query in the endpoint)
            query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            response = requests.post(
                query_endpoint,
                json=payload,
                headers=self._request_headers(),
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            result = response.json()

            if result.get("code") == 0 and "data" in result:
                # The count should be in the response data
                data = result["data"]
                if data:
                    count = data[0].get("count(*)", 0)
                    logger.info(f"Collection count: {count:,}")
                    return count
                # A successful response without rows used to fall through and
                # return None; report it and honour the int return type.
                logger.error(f"Collection count response contained no rows: {result}")
                return 0

            logger.error(f"Failed to get collection count: {result}")
            return 0
        except Exception as e:
            # Best effort: callers get 0 rather than an exception.
            logger.error(f"Error getting collection count: {e}")
            return 0

    def search(self, query_embedding: np.ndarray, top_k: int = DEFAULT_TOP_K, filter_expr: str = None) -> pd.DataFrame:
        """Search Zilliz for top-k most similar galaxies.

        Args:
            query_embedding: Query embedding vector
            top_k: Number of results to return
            filter_expr: Optional filter expression for filtering results

        Returns:
            DataFrame with search results plus a ``cutout_url`` column when
            rows were returned. An empty DataFrame is returned when the
            request fails or Zilliz reports an error (the failure is logged,
            not raised).
        """
        logger.info("Querying Zilliz...")
        start_time = time.time()

        # Prepare the search payload
        payload = {
            "collectionName": ZILLIZ_COLLECTION_NAME,
            "data": [query_embedding.tolist()],
            "annsField": ZILLIZ_ANNS_FIELD,
            "limit": top_k,
            "outputFields": ZILLIZ_OUTPUT_FIELDS,
        }

        # Add filter if provided
        if filter_expr:
            payload["filter"] = filter_expr
            logger.info(f"Applying filter: {filter_expr}")

        try:
            response = requests.post(
                ZILLIZ_ENDPOINT,
                json=payload,
                headers=self._request_headers(),
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            result = response.json()

            if result.get("code") != 0 or "data" not in result:
                logger.error(f"Zilliz search failed: {result}")
                return pd.DataFrame()

            # Extract cost from response
            cost_vcu = result.get("cost", 0)

            # Convert to DataFrame
            df = pd.DataFrame(result["data"])

            # Add cutout URLs
            if not df.empty:
                df["cutout_url"] = [cutout_url(ra, dec) for ra, dec in zip(df["ra"], df["dec"])]

            query_time = time.time() - start_time

            # Log the query
            log_zilliz_query(
                query_type="vector_search",
                query_info={
                    "top_k": top_k,
                    "embedding_dim": len(query_embedding),
                },
                result_count=len(df),
                query_time=query_time,
                cost_vcu=cost_vcu,
            )
            return df
        except Exception as e:
            # Best effort: callers get an empty frame rather than an exception.
            logger.error(f"Zilliz search error: {e}")
            return pd.DataFrame()
class SearchService:
    """High-level search orchestration service."""

    def __init__(
        self,
        embedding_service: EmbeddingService,
        zilliz_service: ZillizService,
        image_service: 'ImageProcessingService' = None
    ):
        self.embedding_service = embedding_service
        self.zilliz_service = zilliz_service
        # Optional; only needed by search_advanced() with image queries.
        self.image_service = image_service

    @staticmethod
    def _blocked_by_moderation(error: ValueError) -> bool:
        """Tell a content-moderation rejection apart from other ValueErrors.

        The rejection is recognised by the word "moderation" in its message.
        """
        return "moderation" in str(error).lower()

    @staticmethod
    def _resolve_weights(items, weights, label: str):
        """Return one weight per item, defaulting to 1.0 when none are given.

        Raises:
            ValueError: If weights are given but their count differs from the
                number of items.
        """
        if weights is None:
            return [1.0] * len(items)
        if len(weights) != len(items):
            raise ValueError(
                f"Got {len(items)} {label} queries but {len(weights)} {label} weights"
            )
        return weights

    def _build_rmag_filter(self, rmag_min=None, rmag_max=None) -> str:
        """Build r_mag filter expression.

        Args:
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            Filter expression string, or None if neither bound is given
        """
        filter_parts = []
        if rmag_min is not None:
            filter_parts.append(f"r_mag >= {rmag_min}")
        if rmag_max is not None:
            filter_parts.append(f"r_mag <= {rmag_max}")
        if filter_parts:
            return " AND ".join(filter_parts)
        return None

    def search_text(self, query: str, top_k: int = DEFAULT_TOP_K, rmag_min=None, rmag_max=None) -> pd.DataFrame:
        """Search galaxies using text query.

        Args:
            query: Text search query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results; empty when the query is blocked
            by content moderation.
        """
        try:
            # Encode query
            query_embedding = self.embedding_service.encode_text_query(query)
            # Build filter
            filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
            # Search Zilliz
            return self.zilliz_service.search(query_embedding, top_k, filter_expr)
        except ValueError as e:
            # Content moderation triggered - return empty results silently
            if self._blocked_by_moderation(e):
                logger.info("Search blocked by content moderation")
                return pd.DataFrame()
            raise

    def search_vector(
        self,
        queries: List[str],
        operations: List[str],
        top_k: int = DEFAULT_TOP_K,
        rmag_min=None,
        rmag_max=None
    ) -> pd.DataFrame:
        """Search galaxies using vector addition/subtraction.

        Args:
            queries: List of text queries
            operations: List of operations ('+' or '-') for each query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results; empty when a query is blocked by
            content moderation.
        """
        try:
            # Encode and combine vectors
            combined_embedding = self.embedding_service.encode_vector_queries(queries, operations)
            # Build filter
            filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
            # Search Zilliz
            return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
        except ValueError as e:
            # Content moderation triggered - return empty results silently
            if self._blocked_by_moderation(e):
                logger.info("Search blocked by content moderation")
                return pd.DataFrame()
            raise

    def search_advanced(
        self,
        text_queries: List[str] = None,
        text_weights: List[float] = None,
        image_queries: List[dict] = None,
        image_weights: List[float] = None,
        top_k: int = DEFAULT_TOP_K,
        rmag_min=None,
        rmag_max=None
    ) -> pd.DataFrame:
        """Search galaxies using advanced vector addition/subtraction with text and/or images.

        Args:
            text_queries: List of text query strings
            text_weights: List of signed weights for text queries
                (e.g., 1.0, -1.0, 2.0, -5.0); defaults to 1.0 per query
            image_queries: List of dicts with 'ra', 'dec' and optional 'fov' keys
            image_weights: List of signed weights for image queries
                (e.g., 1.0, -1.0, 2.0, -5.0); defaults to 1.0 per query
            top_k: Number of results to return
            rmag_min: Minimum r_mag value (inclusive)
            rmag_max: Maximum r_mag value (inclusive)

        Returns:
            DataFrame with search results; empty when a text query is
            blocked by content moderation.

        Raises:
            ValueError: If no query is supplied at all, or a weight list does
                not match the length of its query list.
            RuntimeError: If image queries are supplied but no image service
                was configured.
        """
        try:
            text_queries = text_queries or []
            image_queries = image_queries or []

            # Validate everything before spending API calls. Previously a
            # missing weight list crashed inside zip(), a short one silently
            # dropped queries, and an empty request passed None to Zilliz.
            if not text_queries and not image_queries:
                raise ValueError("At least one text or image query is required")
            text_weights = self._resolve_weights(text_queries, text_weights, "text")
            image_weights = self._resolve_weights(image_queries, image_weights, "image")
            if image_queries and self.image_service is None:
                raise RuntimeError("Image service not initialized")

            combined_embedding = None

            # Process text queries
            for query, weight in zip(text_queries, text_weights):
                weighted_embedding = self.embedding_service.encode_text_query(query) * weight
                if combined_embedding is None:
                    combined_embedding = weighted_embedding
                else:
                    combined_embedding = combined_embedding + weighted_embedding

            # Process image queries
            for img_query, weight in zip(image_queries, image_weights):
                image_embedding = self.image_service.encode_image(
                    ra=img_query['ra'],
                    dec=img_query['dec'],
                    fov=img_query.get('fov', 0.025),
                    size=256
                )
                weighted_embedding = image_embedding * weight
                if combined_embedding is None:
                    combined_embedding = weighted_embedding
                else:
                    combined_embedding = combined_embedding + weighted_embedding

            # Normalize the final combined embedding; a zero vector (weights
            # that cancel out) cannot be normalised and is passed on as is.
            norm = np.linalg.norm(combined_embedding)
            if norm > 0:
                combined_embedding = combined_embedding / norm

            # Build filter
            filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
            # Search Zilliz
            return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
        except ValueError as e:
            # Content moderation triggered - return empty results silently
            if self._blocked_by_moderation(e):
                logger.info("Search blocked by content moderation")
                return pd.DataFrame()
            raise