# Main retriever modules
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_chroma import Chroma
from langchain.retrievers import ParentDocumentRetriever, EnsembleRetriever
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from typing import Any, Optional
import chromadb
import torch
import os
import re

# To use OpenAI models (remote)
from langchain_openai import OpenAIEmbeddings

# To use Hugging Face models (local)
# from langchain_huggingface import HuggingFaceEmbeddings
# For more control over BGE and Nomic embeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

# Local modules
from mods.bm25s_retriever import BM25SRetriever
from mods.file_system import LocalFileStore

# Database directory
db_dir = "db"

# Embedding model
embedding_model_id = "nomic-ai/nomic-embed-text-v1.5"

def BuildRetriever(
    compute_mode,
    search_type: str = "hybrid",
    top_k=6,
    start_year=None,
    end_year=None,
    embedding_ckpt_dir=None,
):
    """
    Build a retriever instance.

    By default, all retriever types are configured to return up to 6
    documents, for fair comparison in evals.

    Args:
        compute_mode: Compute mode for embeddings ("remote" or "local")
        search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
        top_k: Number of documents to retrieve for "dense" and "sparse"
        start_year: Start year (optional)
        end_year: End year (optional)
        embedding_ckpt_dir: Directory for embedding model checkpoint (optional)
    """
    if search_type == "dense":
        if not (start_year or end_year):
            # No year filtering, so directly use the base retriever
            return BuildRetrieverDense(
                compute_mode, top_k=top_k, embedding_ckpt_dir=embedding_ckpt_dir
            )
        else:
            # Get 1000 documents, then keep the top_k after filtering by year
            base_retriever = BuildRetrieverDense(
                compute_mode, top_k=1000, embedding_ckpt_dir=embedding_ckpt_dir
            )
            return TopKRetriever(
                base_retriever=base_retriever,
                top_k=top_k,
                start_year=start_year,
                end_year=end_year,
            )
    elif search_type == "sparse":
        if not (start_year or end_year):
            return BuildRetrieverSparse(top_k=top_k)
        else:
            base_retriever = BuildRetrieverSparse(top_k=1000)
            return TopKRetriever(
                base_retriever=base_retriever,
                top_k=top_k,
                start_year=start_year,
                end_year=end_year,
            )
    elif search_type == "hybrid":
        # Hybrid search (dense + sparse) using the ensemble method
        # https://python.langchain.com/api_reference/langchain/retrievers/langchain.retrievers.ensemble.EnsembleRetriever.html
        # Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
        # https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
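        # e.g. top_k=5 -> 2 documents from the dense retriever, 3 from the sparse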
        dense_retriever = BuildRetriever(
            compute_mode,
            "dense",
            (top_k // 2),
            start_year,
            end_year,
            embedding_ckpt_dir,
        )
        sparse_retriever = BuildRetriever(
            compute_mode,
            "sparse",
            -(top_k // -2),
            start_year,
            end_year,
            embedding_ckpt_dir,
        )
        ensemble_retriever = EnsembleRetriever(
            retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
        )
        return ensemble_retriever
    else:
        raise ValueError(f"Unsupported search type: {search_type}")
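
# Example usage (a sketch; assumes the BM25 index and Chroma vector store
# under db/ have already been built, and that OPENAI_API_KEY is set for
# remote embeddings; the query string is illustrative):
#
#   retriever = BuildRetriever("remote", search_type="hybrid", top_k=6)
#   docs = retriever.invoke("How do I reshape a wide data frame to long format?")
#   for doc in docs:
#       print(doc.metadata["source"])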

def BuildRetrieverSparse(top_k=6):
    """
    Build sparse retriever instance

    Args:
        top_k: Number of documents to retrieve
    """
    # BM25 persistent directory
    bm25_persist_directory = f"{db_dir}/bm25"
    if not os.path.exists(bm25_persist_directory):
        os.makedirs(bm25_persist_directory)
    # Use BM25 sparse search
    retriever = BM25SRetriever.from_persisted_directory(
        path=bm25_persist_directory,
        k=top_k,
    )
    return retriever
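
# Example usage (a sketch; assumes a BM25 index has already been persisted
# to db/bm25; the query string is illustrative):
#
#   sparse = BuildRetrieverSparse(top_k=4)
#   docs = sparse.invoke("glm convergence warning")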

def BuildRetrieverDense(compute_mode: str, top_k=6, embedding_ckpt_dir=None):
    """
    Build dense retriever instance with ChromaDB vectorstore

    Args:
        compute_mode: Compute mode for embeddings ("remote" or "local")
        top_k: Number of documents to retrieve
        embedding_ckpt_dir: Directory for embedding model checkpoint (optional)
    """
    # Don't try to use local models without a GPU
    if compute_mode == "local" and not torch.cuda.is_available():
        raise Exception("Local embeddings selected without GPU")
    # Define embedding model
    if compute_mode == "remote":
        embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
    elif compute_mode == "local":
        # embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5", show_progress=True)
        # https://python.langchain.com/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceBgeEmbeddings.html
        model_kwargs = {
            "device": "cuda",
            "trust_remote_code": True,
        }
        encode_kwargs = {"normalize_embeddings": True}
        # Use the embedding model ID, or a checkpoint directory if given
        id_or_dir = embedding_ckpt_dir if embedding_ckpt_dir else embedding_model_id
        # Nomic embeddings expect task-prefix instructions on queries and documents
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name=id_or_dir,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
            query_instruction="search_query:",
            embed_instruction="search_document:",
        )
    else:
        raise ValueError(f"Unsupported compute mode: {compute_mode}")
    # Create vector store
    client_settings = chromadb.config.Settings(anonymized_telemetry=False)
    persist_directory = f"{db_dir}/chroma_{compute_mode}"
    vectorstore = Chroma(
        collection_name="R-help",
        embedding_function=embedding_function,
        client_settings=client_settings,
        persist_directory=persist_directory,
    )
    # The storage layer for the parent documents
    file_store = f"{db_dir}/file_store_{compute_mode}"
    byte_store = LocalFileStore(file_store)
    # Text splitter for child documents
    child_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", " ", ""],
        chunk_size=1000,
        chunk_overlap=100,
    )
    # Text splitter for parent documents: chunk_size=1 forces a split at every
    # occurrence of the separator, so each email in a monthly archive becomes
    # its own parent document
    parent_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\nFrom"], chunk_size=1, chunk_overlap=0
    )
    # Instantiate a retriever
    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        # NOTE: https://github.com/langchain-ai/langchain/issues/9345
        # Define byte_store = LocalFileStore(file_store) and use byte_store
        # instead of docstore in ParentDocumentRetriever
        byte_store=byte_store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        # Get top k documents
        search_kwargs={"k": top_k},
    )
    return retriever
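
# Example usage (a sketch; assumes the Chroma collection and parent-document
# file store under db/ were populated in advance; the query string is
# illustrative):
#
#   dense = BuildRetrieverDense("remote", top_k=4)
#   docs = dense.invoke("difference between lapply and sapply")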

class TopKRetriever(BaseRetriever):
    """Retriever that wraps a base retriever and returns the top k documents,
    optionally matching given start and/or end years."""

    # Code adapted from langchain/retrievers/contextual_compression.py
    base_retriever: RetrieverLike
    """Base retriever to use for getting relevant documents."""
    top_k: int = 6
    """Number of documents to return."""
    start_year: Optional[int] = None
    """Earliest year to include (optional)."""
    end_year: Optional[int] = None
    """Latest year to include (optional)."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Return the top k documents within the start and end years, if given.

        Returns:
            Sequence of documents
        """
        # Run the search with the base retriever
        filtered_docs = retrieved_docs = self.base_retriever.invoke(
            query, config={"callbacks": run_manager.get_child()}, **kwargs
        )
        if retrieved_docs:
            # Get the sources (file names) and parse out the years,
            # e.g. "R-help/2013-October.txt" -> 2013
            sources = [doc.metadata["source"] for doc in retrieved_docs]
            years = [
                int(re.sub(r"-[A-Za-z]+\.txt", "", source.replace("R-help/", "")))
                for source in sources
            ]
            # Filter by year
            if self.start_year:
                in_range = after_start = [year >= self.start_year for year in years]
            if self.end_year:
                in_range = before_end = [year <= self.end_year for year in years]
            if self.start_year and self.end_year:
                in_range = [
                    after and before for after, before in zip(after_start, before_end)
                ]
            if self.start_year or self.end_year:
                # Keep docs whose year falls in the start-end range
                filtered_docs = [
                    doc for doc, keep in zip(retrieved_docs, in_range) if keep
                ]
            # Return the top k docs
            return filtered_docs[: self.top_k]
        else:
            return []
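
# Example usage (a sketch; wraps a base retriever so only documents from the
# requested years are kept, then truncates to top_k; the query string is
# illustrative):
#
#   base = BuildRetrieverSparse(top_k=1000)
#   retriever = TopKRetriever(
#       base_retriever=base, top_k=6, start_year=2010, end_year=2015
#   )
#   docs = retriever.invoke("memory usage of large data frames")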