"""Utility functions for AION Search."""

import json
import logging
from datetime import datetime
from typing import Dict, Any, Optional

from src.config import CUTOUT_FOV, CUTOUT_SIZE, VCU_COST_PER_MILLION
from src.hf_logging import log_query_event, SESSION_ID

logger = logging.getLogger(__name__)


def cutout_url(ra: float, dec: float, fov: float = CUTOUT_FOV, size: int = CUTOUT_SIZE) -> str:
    """Generate a Legacy Survey cutout URL from RA/Dec coordinates.

    Args:
        ra: Right Ascension in degrees
        dec: Declination in degrees
        fov: Field of view in degrees
        size: Image size in pixels

    Returns:
        URL string for the cutout image
    """
return (
f"https://alasky.cds.unistra.fr/hips-image-services/hips2fits"
f"?hips=CDS/P/DESI-Legacy-Surveys/DR10/color"
f"&ra={ra}&dec={dec}&fov={fov}&width={size}&height={size}&format=jpg"
)
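

# Illustrative usage (not executed; the coordinates are arbitrary example values):
#   cutout_url(150.1, 2.2)
# returns a single hips2fits URL of the form
#   https://alasky.cds.unistra.fr/hips-image-services/hips2fits?hips=CDS/P/DESI-Legacy-Surveys/DR10/color&ra=150.1&dec=2.2&fov=<CUTOUT_FOV>&width=<CUTOUT_SIZE>&height=<CUTOUT_SIZE>&format=jpg
# with fov/width/height filled in from the src.config defaults.

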
def log_zilliz_query(
query_type: str,
query_info: Dict[str, Any],
result_count: int,
query_time: float,
cost_vcu: int = 0,
request_id: Optional[str] = None,
error_occurred: bool = False,
error_message: Optional[str] = None,
error_type: Optional[str] = None
) -> None:
"""Print Zilliz query info to terminal and log to HF dataset.
Args:
query_type: Type of query (e.g., "vector_search", "text_search")
query_info: Dictionary containing query details
result_count: Number of results returned
query_time: Query execution time in seconds
cost_vcu: Cost in vCU units
request_id: Unique ID for this request
error_occurred: Whether an error occurred
error_message: Error message if error_occurred is True
error_type: Type of error if error_occurred is True
"""
timestamp = datetime.now().isoformat()
# Convert vCU cost to dollars
cost_usd = (cost_vcu / 1e6) * VCU_COST_PER_MILLION
log_data = {
"timestamp": timestamp,
"query_type": query_type,
"query_info": query_info,
"result_count": result_count,
"query_time_seconds": query_time,
"cost_vCU": cost_vcu,
"cost_usd": cost_usd
}
# Print to terminal
print("\n" + "="*80)
print(f"ZILLIZ QUERY: {query_type}")
print("="*80)
print(json.dumps(log_data, indent=2))
print("="*80 + "\n")
logger.info(
f"{result_count} results in {query_time:.3f}s | "
f"{cost_vcu} vCU (${cost_usd:.6f})"
)
# Log Zilliz stats to HF dataset
try:
payload = {
"log_type": "zilliz_query_stats",
"timestamp": timestamp,
"query_type": query_type,
"query_info": query_info,
"result_count": result_count,
"query_time_seconds": query_time,
"cost_vcu": cost_vcu,
"cost_usd": cost_usd,
"error_occurred": error_occurred,
}
if request_id:
payload["request_id"] = request_id
if error_occurred:
payload["error_message"] = error_message
payload["error_type"] = error_type
log_query_event(payload)
except Exception as e:
logger.error(f"Failed to send Zilliz stats to HF dataset: {e}")
def format_galaxy_count(count: int) -> str:
"""Format galaxy count with thousands separator.
Args:
count: Number of galaxies
Returns:
Formatted string (e.g., "259,636 galaxies")
"""
return f"{count:,} galaxies"


def build_query_xml(
    text_queries: Optional[list] = None,
    text_weights: Optional[list] = None,
    image_queries: Optional[list] = None,
    image_weights: Optional[list] = None,
    rmag_min: Optional[float] = None,
    rmag_max: Optional[float] = None
) -> str:
    """Build an XML representation of a query according to the aql.md specification.

    Args:
        text_queries: List of text query strings
        text_weights: List of weight magnitudes for text queries (e.g., 1.0, -1.0, 2.0, -5.0)
        image_queries: List of dicts with 'ra', 'dec', 'fov' keys
        image_weights: List of weight magnitudes for image queries (e.g., 1.0, -1.0, 2.0, -5.0)
        rmag_min: Minimum r_mag filter value
        rmag_max: Maximum r_mag filter value

    Returns:
        XML string representation of the query (single line)
    """
xml_parts = ['<query>']
# Add text queries
    if text_queries:
xml_parts.append('<text>')
for query, weight in zip(text_queries, text_weights):
xml_parts.append('<term>')
xml_parts.append(f'<weight>{weight}</weight>')
xml_parts.append(f'<content>{query}</content>')
xml_parts.append('</term>')
xml_parts.append('</text>')
# Add image queries
    if image_queries:
xml_parts.append('<image>')
for img_query, weight in zip(image_queries, image_weights):
xml_parts.append('<reference>')
xml_parts.append(f'<ra>{img_query["ra"]}</ra>')
xml_parts.append(f'<dec>{img_query["dec"]}</dec>')
xml_parts.append(f'<fov>{img_query["fov"]}</fov>')
xml_parts.append(f'<weight>{weight}</weight>')
xml_parts.append('</reference>')
xml_parts.append('</image>')
# Add filters
if rmag_min is not None or rmag_max is not None:
xml_parts.append('<filters>')
if rmag_min is not None and rmag_max is not None:
xml_parts.append('<filter>')
xml_parts.append('<column>r_mag</column>')
xml_parts.append('<operator>between</operator>')
xml_parts.append(f'<value_min>{rmag_min}</value_min>')
xml_parts.append(f'<value_max>{rmag_max}</value_max>')
xml_parts.append('</filter>')
elif rmag_min is not None:
xml_parts.append('<filter>')
xml_parts.append('<column>r_mag</column>')
xml_parts.append('<operator>gte</operator>')
xml_parts.append(f'<value>{rmag_min}</value>')
xml_parts.append('</filter>')
elif rmag_max is not None:
xml_parts.append('<filter>')
xml_parts.append('<column>r_mag</column>')
xml_parts.append('<operator>lte</operator>')
xml_parts.append(f'<value>{rmag_max}</value>')
xml_parts.append('</filter>')
xml_parts.append('</filters>')
xml_parts.append('</query>')
return ''.join(xml_parts)
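

# Illustrative output (example inputs only):
#   build_query_xml(text_queries=["spiral galaxy"], text_weights=[1.0],
#                   rmag_min=16.0, rmag_max=19.5)
# returns, on a single line (wrapped here across comment lines for readability):
#   <query><text><term><weight>1.0</weight><content>spiral galaxy</content></term></text>
#   <filters><filter><column>r_mag</column><operator>between</operator>
#   <value_min>16.0</value_min><value_max>19.5</value_max></filter></filters></query>

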
def log_query_to_csv(
query_xml: str,
csv_path: str = "logs/query_log.csv",
request_id: Optional[str] = None,
error_occurred: bool = False,
error_message: Optional[str] = None,
error_type: Optional[str] = None
) -> None:
"""Print query XML to terminal and log to HF dataset.
Args:
query_xml: XML string representation of the query
csv_path: Deprecated parameter (kept for backward compatibility)
request_id: Unique ID for this request
error_occurred: Whether an error occurred during search
error_message: Error message if error_occurred is True
error_type: Type of error if error_occurred is True
"""
timestamp = datetime.now().isoformat()
# Print query to terminal
print("\n" + "="*80)
print(f"QUERY EXECUTED AT: {timestamp}")
print("="*80)
print(query_xml)
print("="*80 + "\n")
logger.info(f"Query printed to terminal")
# Log to HF dataset
try:
payload = {
"log_type": "aql_query",
"timestamp": timestamp,
"query_xml": query_xml,
"error_occurred": error_occurred,
}
if request_id:
payload["request_id"] = request_id
if error_occurred:
payload["error_message"] = error_message
payload["error_type"] = error_type
log_query_event(payload)
except Exception as e:
logger.error(f"Failed to send query log to HF dataset: {e}")


def log_click_event(
    request_id: Optional[str],
    rank: int,
    primary_key: str,
    ra: float,
    dec: float,
    r_mag: float,
    distance: float
) -> None:
    """Log a galaxy tile click event to the HF dataset.

    Args:
        request_id: Unique ID for the search request that produced this galaxy
        rank: Position in search results (0-indexed)
        primary_key: Primary key of the clicked galaxy
        ra: Right ascension in degrees
        dec: Declination in degrees
        r_mag: r-band magnitude
        distance: Cosine similarity score
    """
try:
payload = {
"log_type": "click_event",
"rank": rank,
"primary_key": primary_key,
"ra": ra,
"dec": dec,
"r_mag": r_mag,
"distance": distance,
}
if request_id:
payload["request_id"] = request_id
log_query_event(payload)
logger.info(f"Logged click event: rank={rank}, primary_key={primary_key}")
except Exception as e:
logger.error(f"Failed to log click event: {e}")