"""Backend services for AION Search."""
import time
import logging
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import requests
from typing import List, Optional
from openai import OpenAI
from src.config import (
ZILLIZ_BEARER,
ZILLIZ_ENDPOINT,
ZILLIZ_COLLECTION_NAME,
ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
ZILLIZ_ANNS_FIELD,
ZILLIZ_PRIMARY_KEY,
ZILLIZ_OUTPUT_FIELDS,
COLLECTION_CONFIGS,
OPENAI_API_KEY,
OPENAI_EMBEDDING_MODEL,
CLIP_NORMALIZE_EPS,
DEFAULT_TOP_K,
)
from src.utils import cutout_url, log_zilliz_query
logger = logging.getLogger(__name__)
class CLIPModelService:
"""Service for managing CLIP model loading and inference."""
def __init__(self):
self.model = None
self.device = None
self.loaded = False
def load_model(self, checkpoint_path: str) -> None:
"""Load the CLIP model from checkpoint.
Args:
checkpoint_path: Path to the CLIP model checkpoint file
"""
logger.info(f"Loading CLIP model from {checkpoint_path}...")
from clip.models.clip_model import GalaxyClipModel
# Set device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
model_config = checkpoint['model_config']
# Initialize model with saved configuration
self.model = GalaxyClipModel(
image_input_dim=model_config['image_input_dim'],
text_input_dim=model_config['text_input_dim'],
embedding_dim=model_config['embedding_dim'],
use_mean_embeddings=model_config.get('use_mean_embeddings', True)
)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.model.eval()
self.loaded = True
logger.info("CLIP model loaded successfully")
def encode_text(self, text_embedding: np.ndarray) -> np.ndarray:
"""Project text embedding through CLIP text projector.
Args:
text_embedding: OpenAI text embedding (1536-dim)
Returns:
CLIP-projected embedding (1024-dim)
"""
if not self.loaded:
raise RuntimeError("CLIP model not loaded. Call load_model() first.")
with torch.no_grad():
text_tensor = torch.from_numpy(text_embedding).float().unsqueeze(0).to(self.device)
clip_features = self.model.text_projector(text_tensor)
# Normalize as per CLIP
clip_features = F.normalize(clip_features, dim=-1, eps=CLIP_NORMALIZE_EPS)
query_embedding = clip_features.cpu().numpy().squeeze(0)
return query_embedding
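# Hedged usage sketch for CLIPModelService (illustration only, not app code):
# the checkpoint path is hypothetical, and the random 1536-dim vector stands in
# for an OPENAI_EMBEDDING_MODEL output, which is what encode_text() expects.
#
#     clip_service = CLIPModelService()
#     clip_service.load_model("checkpoints/galaxy_clip.pt")  # hypothetical path
#     fake_text_embedding = np.random.rand(1536).astype(np.float32)
#     vec = clip_service.encode_text(fake_text_embedding)
#     assert vec.shape == (1024,)  # CLIP-projected and L2-normalized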
class ImageProcessingService:
"""Service for retrieving pre-existing image embeddings from Zilliz."""
def __init__(self):
pass
def encode_image(self, ra: float, dec: float, fov: float = 0.025, size: int = 256) -> np.ndarray:
"""Query Zilliz for pre-existing embedding at the given coordinates.
Args:
ra: Right ascension in degrees
dec: Declination in degrees
fov: Field of view in degrees (used to define search box)
size: Image size in pixels (unused, kept for API compatibility)
Returns:
Pre-existing AION-Search embedding vector (1024-dim) from Zilliz
"""
logger.info(f"Querying Zilliz for pre-existing embedding at RA={ra}, Dec={dec}, FoV={fov}")
# Calculate bounding box based on field of view
ra_min = ra - fov/2
ra_max = ra + fov/2
dec_min = dec - fov/2
dec_max = dec + fov/2
        # Build filter expression for the coordinate range (a simple box in
        # RA/Dec; wrap-around at RA = 0/360 deg is not handled)
        filter_expr = f"ra > {ra_min} AND ra < {ra_max} AND dec > {dec_min} AND dec < {dec_max}"
# Get the ANNS field for the image search collection
image_search_config = COLLECTION_CONFIGS.get(ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME)
image_anns_field = image_search_config["anns_field"]
# Prepare query payload - always use the image search collection (legacy)
payload = {
"collectionName": ZILLIZ_IMAGE_SEARCH_COLLECTION_NAME,
"filter": filter_expr,
"outputFields": [image_anns_field],
"limit": 1
}
headers = {
"Authorization": f"Bearer {ZILLIZ_BEARER}",
"Accept": "application/json",
"Content-Type": "application/json"
}
try:
# Use query endpoint (replace /search with /query)
query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            # Bounded timeout so a hung connection cannot stall the request
            response = requests.post(query_endpoint, json=payload, headers=headers, timeout=30)
response.raise_for_status()
result = response.json()
if result.get("code") == 0 and "data" in result:
data = result["data"]
if data and len(data) > 0:
# Extract the embedding from the first result using the image search ANNS field
embedding = data[0].get(image_anns_field)
if embedding:
embedding_array = np.array(embedding, dtype=np.float32)
logger.info(f"Retrieved pre-existing embedding with shape: {embedding_array.shape}")
return embedding_array
else:
logger.error(f"No embedding field found in result: {data[0].keys()}")
raise RuntimeError(f"No embedding found at coordinates RA={ra}, Dec={dec}")
else:
logger.error(f"No galaxies found at coordinates RA={ra}, Dec={dec} with FoV={fov}")
raise RuntimeError(f"No galaxies found at coordinates RA={ra}, Dec={dec}")
else:
logger.error(f"Zilliz query failed: {result}")
raise RuntimeError(f"Failed to query Zilliz: {result}")
except Exception as e:
logger.error(f"Error querying Zilliz for embedding: {e}")
raise
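# Hedged usage sketch for ImageProcessingService: the coordinates below are
# illustrative, not a known catalog entry; the call succeeds only if a galaxy
# with a stored embedding falls inside the FoV box in Zilliz.
#
#     image_service = ImageProcessingService()
#     emb = image_service.encode_image(ra=150.0, dec=2.2, fov=0.025)
#     print(emb.shape)  # expected (1024,), per the collection's ANNS field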
class EmbeddingService:
"""Service for encoding text queries into embeddings."""
def __init__(self, clip_service: CLIPModelService):
self.clip_service = clip_service
self.openai_client = None
def _get_openai_client(self) -> OpenAI:
"""Get or create OpenAI client."""
if self.openai_client is None:
if not OPENAI_API_KEY:
raise ValueError("OPENAI_API_KEY environment variable not set")
self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
return self.openai_client
def _moderate_content(self, text: str) -> bool:
"""Check if text content is appropriate using OpenAI Moderation API.
Args:
text: Text to moderate
Returns:
True if content is safe, False if flagged
"""
try:
client = self._get_openai_client()
response = client.moderations.create(input=text)
# If any category is flagged, reject the content
if response.results[0].flagged:
logger.warning(f"Content moderation flagged input")
return False
return True
except Exception as e:
logger.error(f"Moderation API error: {e}")
# On error, allow the content through (fail open)
return True
def encode_text_query(self, query: str) -> np.ndarray:
"""Encode text query using OpenAI embeddings + CLIP text projector.
Args:
query: Text search query
Returns:
CLIP embedding vector
"""
# Moderate content first
if not self._moderate_content(query):
raise ValueError("Content moderation filter triggered")
client = self._get_openai_client()
# Get OpenAI text embedding
response = client.embeddings.create(
input=query,
model=OPENAI_EMBEDDING_MODEL
)
text_embedding = np.array(response.data[0].embedding)
# Project through CLIP text projector
return self.clip_service.encode_text(text_embedding)
def encode_vector_queries(
self,
queries: List[str],
operations: List[str]
) -> np.ndarray:
"""Encode multiple text queries and combine them using vector addition/subtraction.
Args:
queries: List of text queries
operations: List of operations ('+' or '-') for each query
Returns:
Combined normalized embedding vector
"""
# Moderate all queries first
for query in queries:
if not self._moderate_content(query):
raise ValueError("Content moderation filter triggered")
client = self._get_openai_client()
# Get all embeddings at once for efficiency
response = client.embeddings.create(
input=queries,
model=OPENAI_EMBEDDING_MODEL
)
# Initialize combined embedding
combined_embedding = None
# Process each embedding with its operation
for embedding_data, operation in zip(response.data, operations):
text_embedding = np.array(embedding_data.embedding)
# Project through CLIP text projector
query_embedding = self.clip_service.encode_text(text_embedding)
# Apply operation
if combined_embedding is None:
combined_embedding = query_embedding if operation == "+" else -query_embedding
else:
if operation == "+":
combined_embedding += query_embedding
else:
combined_embedding -= query_embedding
# Normalize the final combined embedding
norm = np.linalg.norm(combined_embedding)
if norm > 0:
combined_embedding = combined_embedding / norm
return combined_embedding
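# Hedged usage sketch for EmbeddingService vector arithmetic: the queries are
# illustrative. "+" adds a concept and "-" subtracts one; the combined vector
# is re-normalized before it is used for search.
#
#     embedding_service = EmbeddingService(clip_service)
#     vec = embedding_service.encode_vector_queries(
#         queries=["spiral galaxy", "strong central bar"],
#         operations=["+", "-"],
#     )
#     assert abs(np.linalg.norm(vec) - 1.0) < 1e-5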
class ZillizService:
"""Service for interacting with Zilliz vector database."""
def get_collection_count(self) -> int:
"""Get the total number of entities in the collection.
Returns:
Total count of entities in the collection
"""
logger.info("Getting collection count from Zilliz...")
# Use query endpoint with count to get total entities
payload = {
"collectionName": ZILLIZ_COLLECTION_NAME,
"filter": "", # Empty filter to count all entities
"outputFields": ["count(*)"]
}
headers = {
"Authorization": f"Bearer {ZILLIZ_BEARER}",
"Accept": "application/json",
"Content-Type": "application/json"
}
try:
# Use the query endpoint (replace /search with /query in the endpoint)
query_endpoint = ZILLIZ_ENDPOINT.replace("/search", "/query")
            response = requests.post(query_endpoint, json=payload, headers=headers, timeout=30)
response.raise_for_status()
result = response.json()
if result.get("code") == 0 and "data" in result:
# The count should be in the response data
data = result["data"]
if data and len(data) > 0:
count = data[0].get("count(*)", 0)
logger.info(f"Collection count: {count:,}")
return count
else:
logger.error(f"Failed to get collection count: {result}")
return 0
except Exception as e:
logger.error(f"Error getting collection count: {e}")
return 0
    def search(self, query_embedding: np.ndarray, top_k: int = DEFAULT_TOP_K, filter_expr: Optional[str] = None) -> pd.DataFrame:
"""Search Zilliz for top-k most similar galaxies.
Args:
query_embedding: Query embedding vector
top_k: Number of results to return
filter_expr: Optional filter expression for filtering results
Returns:
DataFrame with search results
"""
logger.info("Querying Zilliz...")
start_time = time.time()
# Prepare the search payload
payload = {
"collectionName": ZILLIZ_COLLECTION_NAME,
"data": [query_embedding.tolist()],
"annsField": ZILLIZ_ANNS_FIELD,
"limit": top_k,
"outputFields": ZILLIZ_OUTPUT_FIELDS
}
# Add filter if provided
if filter_expr:
payload["filter"] = filter_expr
logger.info(f"Applying filter: {filter_expr}")
headers = {
"Authorization": f"Bearer {ZILLIZ_BEARER}",
"Accept": "application/json",
"Content-Type": "application/json"
}
try:
            response = requests.post(ZILLIZ_ENDPOINT, json=payload, headers=headers, timeout=30)
response.raise_for_status()
result = response.json()
if result.get("code") == 0 and "data" in result:
# Extract cost from response
cost_vcu = result.get("cost", 0)
# Convert to DataFrame
data_list = result["data"]
df = pd.DataFrame(data_list)
# Add cutout URLs
if not df.empty:
df["cutout_url"] = [cutout_url(ra, dec) for ra, dec in zip(df["ra"], df["dec"])]
query_time = time.time() - start_time
# Log the query
log_zilliz_query(
query_type="vector_search",
query_info={
"top_k": top_k,
"embedding_dim": len(query_embedding)
},
result_count=len(df),
query_time=query_time,
cost_vcu=cost_vcu
)
return df
else:
logger.error(f"Zilliz search failed: {result}")
return pd.DataFrame()
except Exception as e:
logger.error(f"Zilliz search error: {e}")
return pd.DataFrame()
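# Hedged usage sketch for ZillizService: assumes the ZILLIZ_* config is set and
# the query vector matches the collection's embedding dimension.
#
#     zilliz_service = ZillizService()
#     total = zilliz_service.get_collection_count()
#     df = zilliz_service.search(vec, top_k=10, filter_expr="r_mag <= 18")
#     # result columns follow ZILLIZ_OUTPUT_FIELDS, plus the derived "cutout_url"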
class SearchService:
"""High-level search orchestration service."""
def __init__(
self,
embedding_service: EmbeddingService,
zilliz_service: ZillizService,
        image_service: Optional[ImageProcessingService] = None
):
self.embedding_service = embedding_service
self.zilliz_service = zilliz_service
self.image_service = image_service
    def _build_rmag_filter(self, rmag_min: Optional[float] = None, rmag_max: Optional[float] = None) -> Optional[str]:
"""Build r_mag filter expression.
Args:
rmag_min: Minimum r_mag value (inclusive)
rmag_max: Maximum r_mag value (inclusive)
Returns:
Filter expression string, or None if no filter
"""
filter_parts = []
if rmag_min is not None:
filter_parts.append(f"r_mag >= {rmag_min}")
if rmag_max is not None:
filter_parts.append(f"r_mag <= {rmag_max}")
if filter_parts:
return " AND ".join(filter_parts)
return None
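    # Worked examples: _build_rmag_filter(16.0, 18.0) returns
    # "r_mag >= 16.0 AND r_mag <= 18.0"; _build_rmag_filter(None, 18.0)
    # returns "r_mag <= 18.0"; _build_rmag_filter() returns None (no filter).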
def search_text(self, query: str, top_k: int = DEFAULT_TOP_K, rmag_min=None, rmag_max=None) -> pd.DataFrame:
"""Search galaxies using text query.
Args:
query: Text search query
top_k: Number of results to return
rmag_min: Minimum r_mag value (inclusive)
rmag_max: Maximum r_mag value (inclusive)
Returns:
DataFrame with search results
"""
try:
# Encode query
query_embedding = self.embedding_service.encode_text_query(query)
# Build filter
filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
# Search Zilliz
return self.zilliz_service.search(query_embedding, top_k, filter_expr)
except ValueError as e:
# Content moderation triggered - return empty results silently
if "moderation" in str(e).lower():
logger.info("Search blocked by content moderation")
return pd.DataFrame()
raise
def search_vector(
self,
queries: List[str],
operations: List[str],
top_k: int = DEFAULT_TOP_K,
rmag_min=None,
rmag_max=None
) -> pd.DataFrame:
"""Search galaxies using vector addition/subtraction.
Args:
queries: List of text queries
operations: List of operations ('+' or '-') for each query
top_k: Number of results to return
rmag_min: Minimum r_mag value (inclusive)
rmag_max: Maximum r_mag value (inclusive)
Returns:
DataFrame with search results
"""
try:
# Encode and combine vectors
combined_embedding = self.embedding_service.encode_vector_queries(queries, operations)
# Build filter
filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
# Search Zilliz
return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
except ValueError as e:
# Content moderation triggered - return empty results silently
if "moderation" in str(e).lower():
logger.info("Search blocked by content moderation")
return pd.DataFrame()
raise
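    # Hedged example of a compositional query, "mergers without tidal tails":
    #
    #     df = search_service.search_vector(
    #         queries=["merging galaxies", "tidal tails"],
    #         operations=["+", "-"],
    #         top_k=20,
    #         rmag_max=19.0,
    #     )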
def search_advanced(
self,
        text_queries: Optional[List[str]] = None,
        text_weights: Optional[List[float]] = None,
        image_queries: Optional[List[dict]] = None,
        image_weights: Optional[List[float]] = None,
top_k: int = DEFAULT_TOP_K,
rmag_min=None,
rmag_max=None
) -> pd.DataFrame:
"""Search galaxies using advanced vector addition/subtraction with text and/or images.
Args:
            text_queries: List of text query strings
            text_weights: List of signed weights for the text queries (e.g., 1.0, -1.0, 2.0, -5.0); defaults to 1.0 each
            image_queries: List of dicts with 'ra', 'dec', 'fov' keys
            image_weights: List of signed weights for the image queries (e.g., 1.0, -1.0, 2.0, -5.0); defaults to 1.0 each
top_k: Number of results to return
rmag_min: Minimum r_mag value (inclusive)
rmag_max: Maximum r_mag value (inclusive)
Returns:
DataFrame with search results
"""
try:
combined_embedding = None
            # Process text queries (weights default to 1.0 when not supplied)
            if text_queries:
                text_weights = text_weights or [1.0] * len(text_queries)
                for query, weight in zip(text_queries, text_weights):
query_embedding = self.embedding_service.encode_text_query(query)
# Apply weight
weighted_embedding = query_embedding * weight
if combined_embedding is None:
combined_embedding = weighted_embedding
else:
combined_embedding += weighted_embedding
            # Process image queries (weights default to 1.0 when not supplied)
            if image_queries:
                if self.image_service is None:
                    raise RuntimeError("Image service not initialized")
                image_weights = image_weights or [1.0] * len(image_queries)
                for img_query, weight in zip(image_queries, image_weights):
# Encode image
image_embedding = self.image_service.encode_image(
ra=img_query['ra'],
dec=img_query['dec'],
fov=img_query.get('fov', 0.025),
size=256
)
# Apply weight
weighted_embedding = image_embedding * weight
if combined_embedding is None:
combined_embedding = weighted_embedding
else:
combined_embedding += weighted_embedding
            # No usable queries were supplied; return empty results rather than
            # sending a null embedding to Zilliz
            if combined_embedding is None:
                logger.warning("search_advanced called with no queries")
                return pd.DataFrame()
            # Normalize the final combined embedding
            norm = np.linalg.norm(combined_embedding)
            if norm > 0:
                combined_embedding = combined_embedding / norm
# Build filter
filter_expr = self._build_rmag_filter(rmag_min, rmag_max)
# Search Zilliz
return self.zilliz_service.search(combined_embedding, top_k, filter_expr)
except ValueError as e:
# Content moderation triggered - return empty results silently
if "moderation" in str(e).lower():
logger.info("Search blocked by content moderation")
return pd.DataFrame()
raise
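if __name__ == "__main__":
    # Minimal end-to-end smoke test (a sketch, not the app's wiring): assumes
    # OPENAI_API_KEY and the ZILLIZ_* settings are configured, and uses a
    # hypothetical checkpoint path.
    logging.basicConfig(level=logging.INFO)
    clip_service = CLIPModelService()
    clip_service.load_model("checkpoints/galaxy_clip.pt")  # hypothetical path
    embedding_service = EmbeddingService(clip_service)
    search_service = SearchService(
        embedding_service,
        ZillizService(),
        ImageProcessingService(),
    )
    results = search_service.search_text("edge-on spiral galaxy", top_k=5)
    print(results.head())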