timeki's picture
switch_vectorstore_to_azure_ai_search (#30)
ac49be7 verified
# Azure AI Search: https://python.langchain.com/docs/integrations/vectorstores/azuresearch
import os
# Azure AI Search imports
from langchain_community.vectorstores.azuresearch import AzureSearch
# Load environment variables
# python-dotenv is treated as an optional dependency: if it is not installed we
# simply rely on whatever is already present in the process environment.
# Only the import is guarded, and only against ImportError, so that genuine
# failures inside load_dotenv() (and Ctrl-C / SystemExit) are not silently hidden.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()
class AzureSearchWrapper:
    """
    Wrapper around an Azure AI Search vectorstore that converts filters.

    Dictionary-style filters (as used with other vectorstore providers) are
    translated into OData filter strings before being handed to the wrapped
    vectorstore. Every attribute that is not defined on the wrapper itself is
    delegated to the wrapped vectorstore.

    Supported dictionary filter shapes:
        {"field": "value"}              -> field eq 'value'
        {"field": ["a", "b"]}           -> (field eq 'a' or field eq 'b')
        {"field_exclude": "value"}      -> field ne 'value'
        {"field_exclude": ["a", "b"]}   -> (field ne 'a' and field ne 'b')
    Several keys are combined with ``and``.
    """

    # Keys ending with this suffix are exclusions on the field named by the
    # part of the key before the suffix.
    _EXCLUDE_SUFFIX = "_exclude"

    def __init__(self, azure_search_vectorstore):
        """Store the vectorstore instance that all calls are forwarded to."""
        self.vectorstore = azure_search_vectorstore

    def __getattr__(self, name):
        """Delegate all other attributes to the wrapped vectorstore."""
        # __getattr__ only runs when normal lookup fails. If "vectorstore"
        # itself is missing (instance created without __init__, e.g. during
        # copy/unpickle), looking it up again here would recurse forever.
        if name == "vectorstore":
            raise AttributeError(name)
        return getattr(self.vectorstore, name)

    @staticmethod
    def _odata_literal(value):
        """Render ``value`` as a quoted OData string literal.

        Embedded single quotes are doubled, which is how OData escapes them;
        without this a value such as ``O'Brien`` would break the filter.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _convert_dict_filter_to_odata(self, filter_dict):
        """
        Convert dictionary-style filters to Azure Search OData filter format.

        Args:
            filter_dict (dict | str | None): Dictionary-style filter. A string
                is assumed to already be an OData expression and is returned
                unchanged.

        Returns:
            str | None: OData filter string, or None when there is nothing to
            filter on (empty/None input, or only empty exclusion lists).
        """
        if not filter_dict:
            return None
        if isinstance(filter_dict, str):
            # Already an OData expression; nothing to convert.
            return filter_dict

        conditions = []
        for key, value in filter_dict.items():
            if key.endswith(self._EXCLUDE_SUFFIX):
                # Strip only the trailing suffix; the field name itself may
                # legitimately contain "_exclude" elsewhere.
                field = key[: -len(self._EXCLUDE_SUFFIX)]
                operator, joiner = "ne", " and "
                # Excluding nothing places no restriction on the results.
                when_empty = None
            else:
                field = key
                operator, joiner = "eq", " or "
                # "Field is one of <nothing>" can never match, so emit the
                # boolean literal instead of an invalid empty group "()".
                when_empty = "false"

            values = list(value) if isinstance(value, (list, tuple)) else [value]
            comparisons = [
                f"{field} {operator} {self._odata_literal(item)}" for item in values
            ]

            if not comparisons:
                if when_empty is not None:
                    conditions.append(when_empty)
            elif len(comparisons) == 1:
                conditions.append(comparisons[0])
            else:
                # Parenthesise so the group combines correctly with the
                # top-level "and" between keys.
                conditions.append(f"({joiner.join(comparisons)})")

        return " and ".join(conditions) if conditions else None

    def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
        """Convert ``filter`` and run the wrapped store's hybrid_search_with_score.

        Args:
            query: Query passed through to the wrapped vectorstore.
            k: Number of results requested (default 4).
            filter: Dictionary-style filter, OData string, or None.
            **kwargs: Forwarded unchanged to the wrapped vectorstore.

        Returns:
            Whatever the wrapped store's ``hybrid_search_with_score`` returns.
        """
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)
        return self.vectorstore.hybrid_search_with_score(
            query=query, k=k, filters=filter, **kwargs
        )

    def similarity_search(self, query, k=4, filter=None, **kwargs):
        """Convert ``filter`` and call the wrapped store's similarity_search.

        NOTE(review): the converted filter is forwarded as ``filter=`` here,
        whereas similarity_search_with_score forwards it as ``filters=``.
        Confirm which keyword the wrapped vectorstore actually honours.
        """
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)
        return self.vectorstore.similarity_search(
            query=query, k=k, filter=filter, **kwargs
        )

    def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
        """Convert ``filter`` and call the wrapped store's similarity_search_by_vector.

        NOTE(review): forwards the filter as ``filter=``; confirm the wrapped
        vectorstore accepts that keyword.
        """
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)
        return self.vectorstore.similarity_search_by_vector(
            embedding=embedding, k=k, filter=filter, **kwargs
        )

    def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
        """Build a retriever from the wrapped store, converting any filter first.

        If ``search_kwargs`` contains a non-None "filter" entry, it is
        converted to OData in a shallow copy so the caller's dict is not
        modified. Everything else is forwarded unchanged.

        NOTE(review): the converted value stays under the "filter" key;
        confirm the retriever of the wrapped vectorstore reads that key.
        """
        if search_kwargs and "filter" in search_kwargs:
            search_kwargs = search_kwargs.copy()  # Don't modify the original
            if search_kwargs["filter"] is not None:
                search_kwargs["filter"] = self._convert_dict_filter_to_odata(
                    search_kwargs["filter"]
                )
        return self.vectorstore.as_retriever(
            search_type=search_type, search_kwargs=search_kwargs, **kwargs
        )
def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
    """
    Build an Azure AI Search vectorstore and wrap it for filter compatibility.

    The service endpoint and key are read from the AI_SEARCH_INDEX_ENDPOINT
    and AI_SEARCH_KEY environment variables.

    Args:
        embeddings: Embeddings object; its ``embed_query`` method is passed to
            AzureSearch as the embedding function.
        text_key: Forwarded to AzureSearch as ``content_key`` (default: "content").
        index_name: Name of the Azure Search index; must be provided.

    Returns:
        AzureSearchWrapper: The AzureSearch instance wrapped in AzureSearchWrapper.

    Raises:
        ValueError: If either environment variable is unset/empty or
            ``index_name`` is not given.
    """
    endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
    api_key = os.getenv("AI_SEARCH_KEY")

    # Checked in this order; the first missing setting is the one reported.
    required_settings = (
        (endpoint, "AI_SEARCH_INDEX_ENDPOINT environment variable is required"),
        (api_key, "AI_SEARCH_KEY environment variable is required"),
        (index_name, "index_name must be provided for Azure Search"),
    )
    for setting, complaint in required_settings:
        if not setting:
            raise ValueError(complaint)

    store = AzureSearch(
        azure_search_endpoint=endpoint,
        azure_search_key=api_key,
        index_name=index_name,
        embedding_function=embeddings.embed_query,
        content_key=text_key,
    )
    # The wrapper translates dictionary-style filters before delegating.
    return AzureSearchWrapper(store)