File size: 5,656 Bytes
ac49be7
87285ff
 
ac49be7
 
 
 
 
87285ff
 
 
 
 
 
 
ac49be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c973dd
ac49be7
 
 
87285ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Azure AI Search: https://python.langchain.com/docs/integrations/vectorstores/azuresearch
import os

# Azure AI Search imports
from langchain_community.vectorstores.azuresearch import AzureSearch


# Load environment variables
# python-dotenv is an optional dependency: when it is installed, variables from a
# local .env file are loaded into the process environment; when it is missing the
# module relies on variables that are already set.
try:
    from dotenv import load_dotenv
except ImportError:
    # Only the missing-package case is tolerated. A bare ``except`` here would also
    # swallow KeyboardInterrupt/SystemExit and hide real failures while loading.
    pass
else:
    load_dotenv()


class AzureSearchWrapper:
    """
    Wrapper class for Azure AI Search vectorstore to handle filter conversion.

    Callers pass dictionary-style filters (the shape used with other providers) and
    the wrapper turns them into an OData filter string before calling the wrapped
    vectorstore. Attributes that are not defined on the wrapper are looked up on the
    wrapped vectorstore.

    Supported dictionary shapes (several keys are combined with ``and``):
        {"field": "value"}             -> field eq 'value'
        {"field": ["a", "b"]}          -> (field eq 'a' or field eq 'b')
        {"field_exclude": "value"}     -> field ne 'value'
        {"field_exclude": ["a", "b"]}  -> (field ne 'a' and field ne 'b')

    Every value is rendered as a quoted string literal, so the conversion is only
    meaningful for string-typed fields.
    """

    # Key suffix that turns an equality filter into an exclusion filter.
    _EXCLUDE_SUFFIX = "_exclude"

    def __init__(self, azure_search_vectorstore):
        self.vectorstore = azure_search_vectorstore

    def __getattr__(self, name):
        """
        Delegate all other attributes to the wrapped vectorstore.

        Raises:
            AttributeError: If ``vectorstore`` itself has not been set yet, or the
                wrapped vectorstore does not have the attribute.
        """
        # __getattr__ only runs when normal lookup failed. If the missing name is
        # ``vectorstore`` (instance created without __init__, e.g. by copy/pickle),
        # reading self.vectorstore below would re-enter this method forever.
        if name == "vectorstore":
            raise AttributeError(name)
        return getattr(self.vectorstore, name)

    @staticmethod
    def _odata_literal(value):
        """Render ``value`` as a single-quoted OData string literal.

        Embedded single quotes are doubled so the value cannot terminate the literal.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _convert_dict_filter_to_odata(self, filter_dict):
        """
        Convert dictionary-style filters to Azure Search OData filter format.

        Args:
            filter_dict (dict | str | None): Dictionary-style filter. A non-empty
                string is assumed to already be an OData filter and is returned as-is.

        Returns:
            str | None: OData filter string, or None when there is nothing to filter on
        """
        if not filter_dict:
            return None
        if isinstance(filter_dict, str):
            return filter_dict

        suffix = self._EXCLUDE_SUFFIX
        conditions = []

        for key, value in filter_dict.items():
            if key.endswith(suffix):
                # Strip only the trailing marker; str.replace would also remove
                # "_exclude" occurring earlier in the field name.
                field, operator, joiner = key[: -len(suffix)], "ne", " and "
            else:
                field, operator, joiner = key, "eq", " or "

            if isinstance(value, (list, tuple)):
                if not value:
                    # Excluding nothing adds no constraint. Matching "any of nothing"
                    # can match no document.
                    # NOTE(review): relies on the bare boolean literal ``false`` being
                    # accepted as a filter expression by the service -- confirm.
                    if operator == "eq":
                        conditions.append("false")
                    continue
                clauses = [
                    f"{field} {operator} {self._odata_literal(item)}" for item in value
                ]
                if len(clauses) == 1:
                    conditions.append(clauses[0])
                else:
                    conditions.append(f"({joiner.join(clauses)})")
            else:
                conditions.append(f"{field} {operator} {self._odata_literal(value)}")

        return " and ".join(conditions) if conditions else None

    def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
        """
        Override similarity_search_with_score to convert filters.

        The call is served by the wrapped store's ``hybrid_search_with_score``.

        Args:
            query: Query text
            k: Number of results to request (default: 4)
            filter: Dictionary-style filter, OData string, or None
            **kwargs: Passed through to the wrapped vectorstore

        Returns:
            Whatever the wrapped ``hybrid_search_with_score`` returns
        """
        return self.vectorstore.hybrid_search_with_score(
            query=query,
            k=k,
            filters=self._convert_dict_filter_to_odata(filter),
            **kwargs,
        )

    def similarity_search(self, query, k=4, filter=None, **kwargs):
        """
        Override similarity_search to convert filters.

        Args:
            query: Query text
            k: Number of results to request (default: 4)
            filter: Dictionary-style filter, OData string, or None
            **kwargs: Passed through to the wrapped vectorstore

        Returns:
            Whatever the wrapped ``similarity_search`` returns
        """
        # Forward under ``filters`` -- the same keyword this class uses for
        # hybrid_search_with_score. Sent as ``filter`` the converted string would
        # not reach the AzureSearch search methods, which read ``filters``.
        return self.vectorstore.similarity_search(
            query=query,
            k=k,
            filters=self._convert_dict_filter_to_odata(filter),
            **kwargs,
        )

    def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
        """
        Override similarity_search_by_vector to convert filters.

        Args:
            embedding: Query embedding
            k: Number of results to request (default: 4)
            filter: Dictionary-style filter, OData string, or None
            **kwargs: Passed through to the wrapped vectorstore

        Returns:
            Whatever the wrapped ``similarity_search_by_vector`` returns
        """
        # NOTE(review): still forwarded as ``filter=``; confirm which keyword the
        # wrapped store's similarity_search_by_vector expects (if it implements it).
        return self.vectorstore.similarity_search_by_vector(
            embedding=embedding,
            k=k,
            filter=self._convert_dict_filter_to_odata(filter),
            **kwargs,
        )

    def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
        """
        Override as_retriever to handle filter conversion in search_kwargs.

        A ``"filter"`` entry in ``search_kwargs`` is converted and forwarded under
        the ``"filters"`` key; an explicit ``"filters"`` entry supplied by the caller
        is kept. The caller's dictionary is never modified.

        Args:
            search_type: Search type forwarded to the wrapped vectorstore
            search_kwargs: Optional search arguments for the retriever
            **kwargs: Passed through to the wrapped vectorstore

        Returns:
            Whatever the wrapped ``as_retriever`` returns
        """
        if search_kwargs is not None:
            forwarded = dict(search_kwargs)  # Don't modify the original
            if "filter" in forwarded:
                converted = self._convert_dict_filter_to_odata(forwarded.pop("filter"))
                if converted is not None:
                    forwarded.setdefault("filters", converted)
            kwargs["search_kwargs"] = forwarded
        # When search_kwargs is None it is omitted entirely so the wrapped store
        # applies its own default instead of receiving an explicit None.
        return self.vectorstore.as_retriever(search_type=search_type, **kwargs)


def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
    """
    Build an Azure AI Search vectorstore and wrap it for filter compatibility.

    The service endpoint and key are read from the AI_SEARCH_INDEX_ENDPOINT and
    AI_SEARCH_KEY environment variables.

    Args:
        embeddings: Embeddings object; its ``embed_query`` method is handed to the store
        text_key: The key for text content in the payload (default: "content")
        index_name: The name of the Azure Search index (required despite the None default)

    Returns:
        AzureSearchWrapper: The AzureSearch instance wrapped for filter conversion

    Raises:
        ValueError: If either environment variable is unset/empty or index_name is falsy
    """
    endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
    api_key = os.getenv("AI_SEARCH_KEY")

    # Checked in this order so the first missing setting is the one reported.
    required_settings = (
        (endpoint, "AI_SEARCH_INDEX_ENDPOINT environment variable is required"),
        (api_key, "AI_SEARCH_KEY environment variable is required"),
        (index_name, "index_name must be provided for Azure Search"),
    )
    for setting, complaint in required_settings:
        if not setting:
            raise ValueError(complaint)

    # Construct the store, then hand it to the wrapper that converts filters.
    return AzureSearchWrapper(
        AzureSearch(
            azure_search_endpoint=endpoint,
            azure_search_key=api_key,
            index_name=index_name,
            embedding_function=embeddings.embed_query,
            content_key=text_key,
        )
    )