File size: 5,947 Bytes
60444f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from qdrant_client import QdrantClient
from qdrant_client.http import models
from pathlib import Path
import os
from dotenv import load_dotenv

# Load environment variables (python-dotenv) before anything reads them;
# QdrantClientSingleton.get_instance() looks up QDRANT_URL and QDRANT_API_KEY
# via os.getenv.
load_dotenv()

# Schema version this code expects. QdrantClientSingleton compares it with the
# "schema_version" payload field of a sample point, and a mismatch causes the
# collection to be deleted and recreated -- increment this when schema changes.
CURRENT_SCHEMA_VERSION = "1.2"  # Increment this when schema changes
# Dimensionality passed as the vector size when a collection is created.
VECTOR_SIZE = 512  # CLIP embedding size

class QdrantClientSingleton:
    """Holder for one lazily created, shared ``QdrantClient``.

    ``get_instance()`` builds the client on first use: a remote client when
    both ``QDRANT_URL`` and ``QDRANT_API_KEY`` are set in the environment,
    otherwise a client backed by the local ``qdrant_data`` directory.
    ``initialize_collection()`` makes sure a collection exists and matches
    ``CURRENT_SCHEMA_VERSION``.
    """

    # The shared client; stays None until get_instance() is first called.
    _instance = None

    # Payload fields that receive an index when a collection is created,
    # as (field name, PayloadSchemaType member name) pairs, in creation order.
    _PAYLOAD_INDEXES = (
        ("image_id", "KEYWORD"),
        ("path", "KEYWORD"),
        ("root_folder", "KEYWORD"),
        ("schema_version", "KEYWORD"),
        ("indexed_at", "INTEGER"),
    )

    @staticmethod
    def _mask_secret(secret):
        """Return a log-safe placeholder for *secret*.

        No part of the secret is echoed: the result is ``"***"`` when a value
        is present and ``"None"`` when it is missing or empty.
        """
        return "***" if secret else "None"

    @staticmethod
    def _create_local_client():
        """Create a client that stores its data in ``./qdrant_data``."""
        storage_path = Path("qdrant_data").absolute()
        storage_path.mkdir(exist_ok=True)
        return QdrantClient(path=str(storage_path))

    @staticmethod
    def _collection_exists(client: "QdrantClient", collection_name: str) -> bool:
        """Return True if *collection_name* is among the client's collections."""
        return any(
            collection.name == collection_name
            for collection in client.get_collections().collections
        )

    @classmethod
    def get_instance(cls):
        """Return the shared client, creating it on the first call.

        Returns:
            The ``QdrantClient`` stored on the class.
        """
        if cls._instance is None:
            # Check if we have cloud credentials
            qdrant_url = os.getenv('QDRANT_URL')
            qdrant_api_key = os.getenv('QDRANT_API_KEY')

            print(f"QDRANT_URL: {qdrant_url}")
            # Only report whether a key is set; never print any part of it.
            print(f"QDRANT_API_KEY: {cls._mask_secret(qdrant_api_key)}")

            if qdrant_url and qdrant_api_key:
                print(f"Initializing Qdrant Cloud client: {qdrant_url}")
                try:
                    # NOTE(review): constructing the client may not contact
                    # the server, so this fallback might only cover
                    # constructor errors -- confirm.
                    cls._instance = QdrantClient(
                        url=qdrant_url,
                        api_key=qdrant_api_key,
                    )
                    print("Successfully connected to Qdrant Cloud")
                except Exception as e:
                    print(f"Failed to connect to Qdrant Cloud: {e}")
                    print("Falling back to local storage")
                    cls._instance = cls._create_local_client()
            else:
                # Fallback to local storage
                print("Cloud credentials not found, using local Qdrant storage")
                cls._instance = cls._create_local_client()

            # Print collections for debugging; failure here is non-fatal.
            try:
                collections = cls._instance.get_collections().collections
                print(f"Available collections: {[col.name for col in collections]}")
            except Exception as e:
                print(f"Error getting collections: {e}")

        return cls._instance

    @classmethod
    def initialize_collection(cls, collection_name: str):
        """Ensure *collection_name* exists with the current schema version.

        A missing collection is created; an existing one has its schema
        version checked and is recreated on mismatch.
        """
        client = cls.get_instance()

        if cls._collection_exists(client, collection_name):
            # Check schema version and update if necessary
            cls._check_and_update_schema(client, collection_name)
        else:
            # Create new collection with current schema version
            cls._create_collection(client, collection_name)

    @classmethod
    def _create_collection(cls, client: "QdrantClient", collection_name: str):
        """Create a new collection with the current schema version.

        The collection uses cosine distance over ``VECTOR_SIZE``-dimensional
        vectors, followed by one payload index per ``_PAYLOAD_INDEXES`` entry.
        """
        # First create the collection with basic config
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=VECTOR_SIZE,
                distance=models.Distance.COSINE
            ),
            on_disk_payload=True,  # Keep the payload (not the vectors) on disk
            optimizers_config=models.OptimizersConfigDiff(
                # NOTE(review): the Qdrant docs describe 0 as *disabling*
                # vector indexing rather than indexing immediately -- confirm
                # this is the intended setting.
                indexing_threshold=0
            )
        )

        # Then create payload indexes for efficient searching
        for field_name, schema_name in cls._PAYLOAD_INDEXES:
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=getattr(models.PayloadSchemaType, schema_name)
            )

        print(f"Created collection {collection_name} with schema version {CURRENT_SCHEMA_VERSION}")

    @classmethod
    def _check_and_update_schema(cls, client: "QdrantClient", collection_name: str):
        """Check collection schema version and update if necessary.

        One point is sampled and its ``schema_version`` payload field is
        compared with ``CURRENT_SCHEMA_VERSION``. On mismatch the collection
        is deleted and recreated, which discards its existing points.

        If the sample cannot be read, the collection is recreated only when
        it no longer exists; an existing collection is left untouched.
        Errors from the delete/recreate step propagate to the caller.
        """
        try:
            # Get a sample point to check schema version
            sample = client.scroll(
                collection_name=collection_name,
                limit=1,
                with_payload=True
            )[0]
        except Exception as e:
            print(f"Error checking schema: {e}")
            # Creating over an existing collection would only raise a second,
            # misleading error, so recreate solely when it is really gone.
            if cls._collection_exists(client, collection_name):
                print(f"Leaving existing collection {collection_name} untouched")
            else:
                cls._create_collection(client, collection_name)
            return

        if not sample:
            print(f"Collection {collection_name} is empty")
            return

        # Check schema version of existing data; a point without any payload
        # is treated like one without a version.
        payload = sample[0].payload or {}
        point_version = payload.get("schema_version", "0.0")
        if point_version == CURRENT_SCHEMA_VERSION:
            print(f"Collection {collection_name} schema is up to date (version {CURRENT_SCHEMA_VERSION})")
            return

        print(f"Schema version mismatch: {point_version} != {CURRENT_SCHEMA_VERSION}")
        print(f"Collection {collection_name} needs to be recreated")

        # Recreate collection with new schema
        client.delete_collection(collection_name=collection_name)
        cls._create_collection(client, collection_name)