# Hugging Face Spaces page status at capture time: Runtime error
from qdrant_client import QdrantClient | |
from qdrant_client.http import models | |
from pathlib import Path | |
import os | |
from dotenv import load_dotenv | |
# Copy variables from a local .env file (when one exists) into os.environ so the
# os.getenv() lookups in QdrantClientSingleton can see them.
load_dotenv()

# Dimensionality of the stored embedding vectors (512 = CLIP embedding size).
VECTOR_SIZE = 512

# Written into each point's payload as "schema_version"; an existing collection whose
# sampled point carries a different value is dropped and recreated. Bump on schema changes.
CURRENT_SCHEMA_VERSION = "1.2"
class QdrantClientSingleton:
    """Lazily created, process-wide ``QdrantClient`` plus collection bootstrap helpers.

    ``get_instance`` connects to Qdrant Cloud when both ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` are set and the server answers a ``get_collections``
    request; otherwise it falls back to an embedded client persisted under
    ``./qdrant_data``. All public methods are classmethods, so callers use the
    class directly (``QdrantClientSingleton.get_instance()``); calling them on
    an instance also works and still shares the single client.
    """

    # The shared client; None until get_instance() first runs.
    _instance = None

    @staticmethod
    def _mask_secret(secret, visible=4):
        """Return a log-safe rendering of *secret*.

        Returns ``"None"`` when *secret* is unset/empty, otherwise ``"***"``
        followed by at most its last *visible* characters.
        """
        if not secret:
            return "None"
        # Reveal a suffix only when it is a small fraction of the secret, so
        # short keys are never mostly printed.
        tail = secret[-visible:] if len(secret) > 3 * visible else ""
        return "***" + tail

    @staticmethod
    def _create_local_client():
        """Create an embedded QdrantClient persisted under ./qdrant_data (created if missing)."""
        storage_path = Path("qdrant_data").absolute()
        storage_path.mkdir(exist_ok=True)
        return QdrantClient(path=str(storage_path))

    @classmethod
    def get_instance(cls):
        """Return the shared client, creating it on first use.

        Uses Qdrant Cloud when ``QDRANT_URL`` and ``QDRANT_API_KEY`` are both
        set and reachable; otherwise uses local on-disk storage.

        NOTE(review): there is no locking, so concurrent first calls from
        several threads could each build a client.
        """
        if cls._instance is not None:
            return cls._instance

        qdrant_url = os.getenv('QDRANT_URL')
        qdrant_api_key = os.getenv('QDRANT_API_KEY')
        print(f"QDRANT_URL: {qdrant_url}")
        # Never log most of the API key; only a short suffix for identification.
        print(f"QDRANT_API_KEY: {cls._mask_secret(qdrant_api_key)}")

        if qdrant_url and qdrant_api_key:
            print(f"Initializing Qdrant Cloud client: {qdrant_url}")
            try:
                client = QdrantClient(
                    url=qdrant_url,
                    api_key=qdrant_api_key,
                )
                # Constructing a remote client may not contact the server, so
                # issue a cheap request to surface a bad URL/key here, where
                # the local fallback can still take over.
                client.get_collections()
                print("Successfully connected to Qdrant Cloud")
            except Exception as e:
                print(f"Failed to connect to Qdrant Cloud: {e}")
                print("Falling back to local storage")
                client = cls._create_local_client()
        else:
            print("Cloud credentials not found, using local Qdrant storage")
            client = cls._create_local_client()

        cls._instance = client

        # Debug aid: list what the chosen backend already contains.
        try:
            collections = cls._instance.get_collections().collections
            print(f"Available collections: {[col.name for col in collections]}")
        except Exception as e:
            print(f"Error getting collections: {e}")
        return cls._instance

    @classmethod
    def initialize_collection(cls, collection_name: str):
        """Ensure *collection_name* exists and matches CURRENT_SCHEMA_VERSION.

        A missing collection is created. An existing one goes through
        ``_check_and_update_schema``, which drops and recreates it (deleting
        all its points) when the sampled point's schema version differs.
        """
        client = cls.get_instance()
        existing_names = {col.name for col in client.get_collections().collections}
        if collection_name in existing_names:
            cls._check_and_update_schema(client, collection_name)
        else:
            cls._create_collection(client, collection_name)

    @classmethod
    def _create_collection(cls, client: "QdrantClient", collection_name: str):
        """Create *collection_name* with cosine VECTOR_SIZE-dim vectors and payload indexes."""
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=VECTOR_SIZE,
                distance=models.Distance.COSINE,
            ),
            # Keep payloads on disk instead of in RAM (vector storage is not
            # affected by this flag).
            on_disk_payload=True,
            optimizers_config=models.OptimizersConfigDiff(
                # NOTE(review): the original comment read "Index immediately",
                # but the Qdrant docs describe indexing_threshold=0 as
                # *disabling* vector (HNSW) indexing. Value kept unchanged;
                # confirm the intended setting.
                indexing_threshold=0,
            ),
        )
        # Payload indexes for filtering on these fields.
        payload_indexes = (
            ("image_id", models.PayloadSchemaType.KEYWORD),
            ("path", models.PayloadSchemaType.KEYWORD),
            ("root_folder", models.PayloadSchemaType.KEYWORD),
            ("schema_version", models.PayloadSchemaType.KEYWORD),
            ("indexed_at", models.PayloadSchemaType.INTEGER),
        )
        for field_name, field_schema in payload_indexes:
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=field_schema,
            )
        print(f"Created collection {collection_name} with schema version {CURRENT_SCHEMA_VERSION}")

    @classmethod
    def _check_and_update_schema(cls, client: "QdrantClient", collection_name: str):
        """Recreate *collection_name* if its stored schema version is outdated.

        Samples one point; if its ``schema_version`` payload (default "0.0")
        differs from CURRENT_SCHEMA_VERSION, the collection is deleted and
        recreated empty. An empty collection is left untouched.

        Raises:
            Whatever the client raised, if the check fails while the
            collection still exists (it is not silently recreated over).
        """
        try:
            points, _next_offset = client.scroll(
                collection_name=collection_name,
                limit=1,
                with_payload=True,
            )
            if not points:
                print(f"Collection {collection_name} is empty")
                return

            point_version = (points[0].payload or {}).get("schema_version", "0.0")
            if point_version != CURRENT_SCHEMA_VERSION:
                print(f"Schema version mismatch: {point_version} != {CURRENT_SCHEMA_VERSION}")
                print(f"Collection {collection_name} needs to be recreated")
                client.delete_collection(collection_name=collection_name)
                cls._create_collection(client, collection_name)
            else:
                print(f"Collection {collection_name} schema is up to date (version {CURRENT_SCHEMA_VERSION})")
        except Exception as e:
            print(f"Error checking schema: {e}")
            # Recreating is only possible if the collection is actually gone
            # (e.g. deleted above before re-creation failed); otherwise
            # create_collection would just fail with a misleading error, so
            # surface the real one instead.
            existing_names = {col.name for col in client.get_collections().collections}
            if collection_name in existing_names:
                raise
            cls._create_collection(client, collection_name)