Spaces:
Runtime error
Runtime error
File size: 5,947 Bytes
60444f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from qdrant_client import QdrantClient
from qdrant_client.http import models
from pathlib import Path
import os
from dotenv import load_dotenv
# Pull variables from a .env file into os.environ so the Qdrant settings
# (QDRANT_URL / QDRANT_API_KEY) read via os.getenv() below can come from it.
load_dotenv()

# Compared against the "schema_version" payload field of existing points;
# bump it whenever the collection layout changes so outdated collections
# are dropped and rebuilt.
CURRENT_SCHEMA_VERSION: str = "1.2"

# Dimensionality of every stored vector (CLIP embedding size).
VECTOR_SIZE: int = 512
class QdrantClientSingleton:
    """Lazily created, class-level shared QdrantClient plus collection setup helpers.

    get_instance() uses Qdrant Cloud when both QDRANT_URL and QDRANT_API_KEY are
    set in the environment. If either is missing, or the cloud client cannot be
    created/reached, it falls back to an embedded store in ./qdrant_data
    (relative to the current working directory).
    """

    # Cached client; stays None until the first get_instance() call.
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared QdrantClient, creating it on the first call.

        Returns:
            The cached QdrantClient (cloud or local).
        """
        if cls._instance is not None:
            return cls._instance

        qdrant_url = os.getenv('QDRANT_URL')
        qdrant_api_key = os.getenv('QDRANT_API_KEY')
        print(f"QDRANT_URL: {qdrant_url}")
        # Show only a short suffix: enough to tell keys apart in logs without
        # leaking a meaningful part of the secret.
        print(f"QDRANT_API_KEY: {'***' + qdrant_api_key[-4:] if qdrant_api_key else 'None'}")

        if qdrant_url and qdrant_api_key:
            cls._instance = cls._connect_cloud(qdrant_url, qdrant_api_key)
        else:
            print("Cloud credentials not found, using local Qdrant storage")
            cls._instance = cls._create_local_client()

        # Print collections for debugging (once, right after the client is created).
        try:
            collections = cls._instance.get_collections().collections
            print(f"Available collections: {[col.name for col in collections]}")
        except Exception as e:
            print(f"Error getting collections: {e}")
        return cls._instance

    @classmethod
    def _connect_cloud(cls, qdrant_url: str, qdrant_api_key: str):
        """Create a Qdrant Cloud client, or a local one if the cloud client is unusable."""
        print(f"Initializing Qdrant Cloud client: {qdrant_url}")
        try:
            client = QdrantClient(
                url=qdrant_url,
                api_key=qdrant_api_key,
            )
            # QdrantClient(url=...) may not talk to the server when constructed,
            # so issue a cheap request here; otherwise an unreachable host or a
            # rejected key would only surface later and the fallback below
            # would never trigger.
            client.get_collections()
        except Exception as e:
            print(f"Failed to connect to Qdrant Cloud: {e}")
            print("Falling back to local storage")
            return cls._create_local_client()
        print("Successfully connected to Qdrant Cloud")
        return client

    @staticmethod
    def _create_local_client():
        """Open an embedded QdrantClient backed by ./qdrant_data, creating the folder if needed."""
        storage_path = Path("qdrant_data").absolute()
        storage_path.mkdir(exist_ok=True)
        return QdrantClient(path=str(storage_path))

    @staticmethod
    def _collection_exists(client: QdrantClient, collection_name: str) -> bool:
        """Return True if `client` currently lists a collection named `collection_name`."""
        return any(
            collection.name == collection_name
            for collection in client.get_collections().collections
        )

    @classmethod
    def initialize_collection(cls, collection_name: str):
        """Ensure `collection_name` exists and matches CURRENT_SCHEMA_VERSION.

        A missing collection is created. An existing one is checked; on a
        schema-version mismatch it is deleted and recreated, which drops all
        of its points.
        """
        client = cls.get_instance()
        if cls._collection_exists(client, collection_name):
            cls._check_and_update_schema(client, collection_name)
        else:
            cls._create_collection(client, collection_name)

    @classmethod
    def _create_collection(cls, client: QdrantClient, collection_name: str):
        """Create a new collection with the current vector config and payload indexes."""
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=VECTOR_SIZE,
                distance=models.Distance.COSINE,
            ),
            on_disk_payload=True,  # keep point payloads on disk rather than in RAM
            optimizers_config=models.OptimizersConfigDiff(
                # NOTE(review): this was commented as "index immediately", but
                # Qdrant's docs describe indexing_threshold=0 as *disabling*
                # vector (HNSW) indexing. Confirm which behaviour is intended.
                indexing_threshold=0,
            ),
        )

        # Payload indexes so filtering on these fields is efficient.
        payload_indexes = (
            ("image_id", models.PayloadSchemaType.KEYWORD),
            ("path", models.PayloadSchemaType.KEYWORD),
            ("root_folder", models.PayloadSchemaType.KEYWORD),
            ("schema_version", models.PayloadSchemaType.KEYWORD),
            ("indexed_at", models.PayloadSchemaType.INTEGER),
        )
        for field_name, field_schema in payload_indexes:
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=field_schema,
            )
        print(f"Created collection {collection_name} with schema version {CURRENT_SCHEMA_VERSION}")

    @classmethod
    def _check_and_update_schema(cls, client: QdrantClient, collection_name: str):
        """Recreate `collection_name` if its data was written under another schema version.

        Only the first point returned by scroll() is inspected. A point without
        a "schema_version" payload field counts as version "0.0". Recreating the
        collection deletes every point in it.

        Raises:
            The original error if the check fails while the collection still exists.
        """
        try:
            points, _next_offset = client.scroll(
                collection_name=collection_name,
                limit=1,
                with_payload=True,
            )
            if not points:
                print(f"Collection {collection_name} is empty")
                return

            point_version = (points[0].payload or {}).get("schema_version", "0.0")
            if point_version == CURRENT_SCHEMA_VERSION:
                print(f"Collection {collection_name} schema is up to date (version {CURRENT_SCHEMA_VERSION})")
                return

            print(f"Schema version mismatch: {point_version} != {CURRENT_SCHEMA_VERSION}")
            print(f"Collection {collection_name} needs to be recreated")
            client.delete_collection(collection_name=collection_name)
            cls._create_collection(client, collection_name)
        except Exception as e:
            print(f"Error checking schema: {e}")
            # Rebuild only when the collection is really gone (e.g. deleted above
            # but re-creation failed). Creating over a collection that still
            # exists would itself be rejected and hide the real error, so
            # re-raise the original exception in that case.
            if cls._collection_exists(client, collection_name):
                raise
            cls._create_collection(client, collection_name)