"""Database management for fabric-to-espanso.""" from typing import Optional, List, Dict import logging import time from qdrant_client import QdrantClient from qdrant_client.http import models, exceptions from qdrant_client.http.models import Distance, VectorParams, PointStruct from .config import config from .exceptions import DatabaseConnectionError, CollectionError, DatabaseInitializationError, ConfigurationError logger = logging.getLogger('fabric_to_espanso') def get_dense_vector_name(client: QdrantClient, collection_name: str) -> str: """ Get the name of the dense vector from the collection configuration. Args: client: Initialized Qdrant client collection_name: Name of the collection Returns: Name of the dense vector as used in the collection """ try: return list(client.get_collection(collection_name).config.params.vectors.keys())[0] except (IndexError, AttributeError) as e: logger.warning(f"Could not get dense vector name: {e}") # Fallback to a default name return "fast-multilingual-e5-large" def get_sparse_vector_name(client: QdrantClient, collection_name: str) -> str: """ Get the name of the sparse vector from the collection configuration. Args: client: Initialized Qdrant client collection_name: Name of the collection Returns: Name of the sparse vector as used in the collection """ try: return list(client.get_collection(collection_name).config.params.sparse_vectors.keys())[0] except (IndexError, AttributeError) as e: logger.warning(f"Could not get sparse vector name: {e}") # Fallback to a default name return "fast-sparse-splade_pp_en_v1" def create_database_connection(url: Optional[str] = None, api_key: Optional[str] = None) -> QdrantClient: """Create a database connection. Args: url: Optional database URL. If not provided, uses configuration. Returns: QdrantClient: Connected database client Raises: DatabaseConnectionError: If connection fails after retries """ url = url or config.database.url for attempt in range(config.database.max_retries + 1): try: client = QdrantClient( url=url, timeout=config.database.timeout, api_key=api_key ) # Test connection client.get_collections() return client except Exception as e: if attempt == config.database.max_retries: raise DatabaseConnectionError( f"Failed to connect to database at {url} after " f"{config.database.max_retries} attempts: {str(e)}" ) from e logger.warning( f"Connection attempt {attempt + 1} failed, retrying in " f"{config.database.retry_delay} seconds..." ) time.sleep(config.database.retry_delay) def initialize_qdrant_database( url: str = config.database.url, api_key: Optional[str] = "", collection_name: str = config.embedding.collection_name, use_fastembed: bool = config.embedding.use_fastembed, dense_model: str = config.embedding.dense_model_name, sparse_model: str = config.embedding.sparse_model_name ) -> QdrantClient: """Initialize the Qdrant database for storing markdown file information. Args: collection_name: Name of the collection to initialize use_fastembed: Whether to use FastEmbed for embeddings embed_model: Name of the embedding model to use Returns: QdrantClient: Initialized database client Raises: DatabaseInitializationError: If initialization fails CollectionError: If collection creation fails ConfigurationError: If configuration is invalid """ try: # Validate configuration config.validate() # Create database connection client = create_database_connection(url=url, api_key=api_key) client.set_model(dense_model) client.set_sparse_model(sparse_model) # Check if collection exists collections = client.get_collections() collection_names = [c.name for c in collections.collections] if collection_name not in collection_names: logger.info(f"Creating new collection: {collection_name}") # Create collection with appropriate vector configuration if use_fastembed: vectors_config = client.get_fastembed_vector_params() sparse_vectors_config = client.get_fastembed_sparse_vector_params() else: print("Creating database without Fastembed not implemented yet.") raise NotImplementedError() try: client.create_collection( collection_name=collection_name, vectors_config=vectors_config, sparse_vectors_config=sparse_vectors_config, on_disk_payload=True ) except exceptions.UnexpectedResponse as e: raise CollectionError( f"Failed to create collection {collection_name}: {str(e)}" ) from e # Create indexes for efficient searching for field_name, field_type in [ ("filename", models.PayloadSchemaType.KEYWORD), ("date", models.PayloadSchemaType.DATETIME) ]: client.create_payload_index( collection_name=collection_name, field_name=field_name, field_schema=field_type ) logger.info(f"Created indexes for collection {collection_name}") # Log collection status collection_info = client.get_collection(collection_name) logger.info( f"Collection {collection_name} ready with " f"{collection_info.points_count} points" ) return client except Exception as e: logger.error(f"Database initialization failed: {str(e)}", exc_info=True) if isinstance(e, (DatabaseConnectionError, CollectionError)): raise raise DatabaseInitializationError(str(e)) from e def validate_database_payload( client: QdrantClient, collection_name: str, ) -> Dict: """Validate the payload of all points in the Qdrant database. Args: client: Initialized Qdrant client collection_name: Name of the collection to validate """ # First validate existing points in database logger.info("Validating existing database points...") offset = None while True: scroll_result = client.scroll( collection_name=collection_name, limit=5, # Process in batches of 5 offset=offset ) points, offset = scroll_result for point in points: try: fixed_payload = validate_point_payload(point.payload, point.id) if fixed_payload != point.payload: # Update point with fixed payload point_struct = PointStruct( id=point.id, vector=point.vector, payload=fixed_payload ) client.upsert(collection_name=collection_name, points=[point_struct]) logger.info(f"Fixed and updated point {point.id} in database") except ConfigurationError as e: logger.error(str(e)) if not offset: # No more points to process break logger.info("Database validation completed") def validate_point_payload(payload: dict, point_id: Optional[str] = None) -> dict: """Validate and fix point payload fields. Only use if somehow many points have become corrupted. Args: payload (dict): Point payload to validate point_id (str, optional): ID of the point for logging purposes Returns: dict: Validated and potentially fixed payload Raises: ConfigurationError: If required fields are missing and cannot be fixed """ print(f"Validating point {point_id if point_id else ''}") from .exceptions import ConfigurationError # Check for critical fields if 'filename' not in payload or 'content' not in payload: error_msg = f"Point {point_id if point_id else ''} is missing critical fields: " error_msg += "'filename' and/or 'content' are required and cannot be defaulted" raise ConfigurationError(error_msg) # Copy payload to avoid modifying the original fixed_payload = payload.copy() # Apply defaults and fixes for non-critical fields if 'purpose' not in fixed_payload or not fixed_payload['purpose']: fixed_payload['purpose'] = fixed_payload['content'] logger.warning(f"Point {point_id if point_id else ''}: 'purpose' was missing, set to content value") if 'filesize' not in fixed_payload: fixed_payload['filesize'] = self.required_fields_defaults['filesize'] logger.warning(f"Point {point_id if point_id else ''}: 'filesize' was missing, set to {self.required_fields_defaults['filesize']}") if 'trigger' not in fixed_payload: fixed_payload['trigger'] = self.required_fields_defaults['trigger'] logger.warning(f"Point {point_id if point_id else ''}: 'trigger' was missing, set to {self.required_fields_defaults['trigger']}") return fixed_payload