|
|
|
import os |
|
|
|
|
|
from langchain_community.vectorstores.azuresearch import AzureSearch |
|
|
|
|
|
|
|
# Load variables from a local .env file when python-dotenv is installed.
# The package is optional: without it we rely solely on the real process
# environment. Only a missing package is tolerated; the previous bare
# `except:` also hid KeyboardInterrupt/SystemExit and genuine .env errors.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()
|
|
|
|
|
class AzureSearchWrapper:
    """
    Wrapper class for Azure AI Search vectorstore to handle filter conversion.

    This wrapper automatically converts dictionary-style filters to Azure Search
    OData filter format, ensuring seamless compatibility when switching from
    other providers. The converted expression is forwarded to the wrapped store
    under the ``filters`` keyword, which is the keyword the LangChain
    ``AzureSearch`` store reads its OData expression from.

    Dict filter conventions (multiple keys are AND-ed together):
        {"field": "v"}                -> field eq 'v'
        {"field": ["a", "b"]}         -> (field eq 'a' or field eq 'b')
        {"field_exclude": "v"}        -> field ne 'v'
        {"field_exclude": ["a", "b"]} -> (field ne 'a' and field ne 'b')

    A ``str`` filter is treated as an already-formed OData expression and is
    passed through unchanged. Any attribute not defined on this class is
    delegated to the wrapped vectorstore.
    """

    # Key suffix marking an exclusion ("not equal") condition.
    _EXCLUDE_SUFFIX = "_exclude"

    def __init__(self, azure_search_vectorstore):
        self.vectorstore = azure_search_vectorstore

    def __getattr__(self, name):
        """Delegate all other attributes to the wrapped vectorstore."""
        # __getattr__ only runs when normal lookup fails. If `vectorstore`
        # itself is missing (e.g. an instance built via __new__ during
        # copy/unpickle), looking it up here would recurse forever.
        if name == "vectorstore":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.vectorstore, name)

    @staticmethod
    def _quote(value):
        """Render ``value`` as an OData string literal.

        OData escapes a single quote inside a string literal by doubling it.
        Without this, values such as "O'Brien" would break the expression.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _join_clauses(self, field, operator, values, connector):
        """Build ``field <operator> '<v>'`` clauses for ``values`` and join them.

        A single clause is returned bare. Several clauses are parenthesised so
        the group binds correctly when AND-ed with other conditions.
        """
        clauses = [f"{field} {operator} {self._quote(v)}" for v in values]
        if len(clauses) == 1:
            return clauses[0]
        return "(" + f" {connector} ".join(clauses) + ")"

    def _convert_dict_filter_to_odata(self, filter_dict):
        """
        Convert dictionary-style filters to Azure Search OData filter format.

        Args:
            filter_dict (dict | str | None): Dictionary-style filter. A string
                is assumed to already be OData and is returned unchanged.

        Returns:
            str | None: OData filter string, or None when there is nothing
            to filter on.

        Raises:
            ValueError: if a non-exclude key maps to an empty list. "Match any
                of no values" cannot be expressed; the old code emitted the
                invalid clause "()".
        """
        if not filter_dict:
            return None
        if isinstance(filter_dict, str):
            return filter_dict

        conditions = []
        for key, value in filter_dict.items():
            values = value if isinstance(value, list) else [value]
            if key.endswith(self._EXCLUDE_SUFFIX):
                # Strip only the trailing suffix. str.replace would also drop
                # "_exclude" appearing inside the field name itself.
                field = key[: -len(self._EXCLUDE_SUFFIX)]
                if not values:
                    continue  # excluding nothing adds no constraint
                conditions.append(self._join_clauses(field, "ne", values, "and"))
            else:
                if not values:
                    raise ValueError(
                        f"Empty value list for filter key {key!r}: "
                        "cannot build an OData 'any of' condition"
                    )
                conditions.append(self._join_clauses(key, "eq", values, "or"))

        return " and ".join(conditions) if conditions else None

    def _apply_filter(self, raw_filter, kwargs):
        """Convert ``raw_filter`` and store it in ``kwargs`` under ``filters``.

        ``kwargs`` is mutated in place and returned. A None filter leaves it
        untouched, so an Azure-native ``filters=`` argument still works.

        Raises:
            ValueError: if both ``raw_filter`` and a non-None ``filters`` are
                supplied.
        """
        if raw_filter is None:
            return kwargs
        if kwargs.get("filters") is not None:
            raise ValueError("Pass either 'filter' or 'filters', not both")
        kwargs["filters"] = self._convert_dict_filter_to_odata(raw_filter)
        return kwargs

    def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
        """Override similarity_search_with_score to convert filters.

        NOTE: this routes through the store's hybrid (keyword + vector) search,
        as the original implementation did.
        """
        kwargs = self._apply_filter(filter, kwargs)
        return self.vectorstore.hybrid_search_with_score(query=query, k=k, **kwargs)

    def similarity_search(self, query, k=4, filter=None, **kwargs):
        """Override similarity_search to convert filters."""
        kwargs = self._apply_filter(filter, kwargs)
        return self.vectorstore.similarity_search(query=query, k=k, **kwargs)

    def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
        """Override similarity_search_by_vector to convert filters."""
        kwargs = self._apply_filter(filter, kwargs)
        return self.vectorstore.similarity_search_by_vector(
            embedding=embedding, k=k, **kwargs
        )

    def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
        """Override as_retriever to handle filter conversion in search_kwargs.

        A ``"filter"`` entry in ``search_kwargs`` is converted to OData and
        moved to ``"filters"``. The caller's dict is never mutated.
        """
        if search_kwargs and "filter" in search_kwargs:
            search_kwargs = dict(search_kwargs)
            self._apply_filter(search_kwargs.pop("filter"), search_kwargs)

        return self.vectorstore.as_retriever(
            search_type=search_type, search_kwargs=search_kwargs, **kwargs
        )
|
|
|
|
|
def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
    """
    Build an Azure AI Search vectorstore wrapped for dict-filter compatibility.

    Connection settings are read from the environment: the endpoint from
    ``AI_SEARCH_INDEX_ENDPOINT`` and the API key from ``AI_SEARCH_KEY``.

    Args:
        embeddings: Embeddings object; its ``embed_query`` method is handed to
            the store as the embedding function.
        text_key: Name of the index field holding the document text
            (default: "content").
        index_name: Name of the Azure Search index (required).

    Returns:
        AzureSearchWrapper: The store wrapped so that dict-style filters work.

    Raises:
        ValueError: If either environment variable is missing/empty, or
            ``index_name`` is falsy. Checks run in that order.
    """
    endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
    api_key = os.getenv("AI_SEARCH_KEY")

    # Each entry pairs a required setting with the error raised when it is
    # absent; the first missing one wins.
    required_settings = (
        (endpoint, "AI_SEARCH_INDEX_ENDPOINT environment variable is required"),
        (api_key, "AI_SEARCH_KEY environment variable is required"),
        (index_name, "index_name must be provided for Azure Search"),
    )
    for setting, problem in required_settings:
        if not setting:
            raise ValueError(problem)

    store = AzureSearch(
        azure_search_endpoint=endpoint,
        azure_search_key=api_key,
        index_name=index_name,
        embedding_function=embeddings.embed_query,
        content_key=text_key,
    )
    return AzureSearchWrapper(store)
|
|
|
|