from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from datetime import datetime
import tempfile
import os

# Local modules
from retriever import BuildRetriever, db_dir
from mods.bm25s_retriever import BM25SRetriever


def ProcessFile(file_path, search_type: str = "dense", compute_mode: str = "remote"):
    """
    Wrapper function to process a file for dense or sparse search

    Args:
        file_path: File to process
        search_type: Type of search to use. Options: "dense", "sparse"
        compute_mode: Compute mode for embeddings ("remote" or "local")
    """
    # Preprocess: remove quoted lines and handle email boundaries
    temp_fd, cleaned_temp_file = tempfile.mkstemp(suffix=".txt", prefix="preproc_")
    # Close the OS-level descriptor from mkstemp; the file is reopened by path below
    os.close(temp_fd)
    with open(file_path, "r", encoding="utf-8", errors="ignore") as infile, open(
        cleaned_temp_file, "w", encoding="utf-8"
    ) as outfile:
        for line in infile:
            # Skip quoted lines (those starting with '>', possibly after whitespace)
            if line.lstrip().startswith(">"):
                continue
            outfile.write(line)
    # Truncate line count and line length per email to avoid an error in openai/_base_client.py:
    # BadRequestError: Error code: 400 - 'message': 'Requested 312872 tokens, max 300000 tokens per request', 'type': 'max_tokens_per_request'
    temp_fd2, truncated_temp_file = tempfile.mkstemp(suffix=".txt", prefix="truncated_")
    os.close(temp_fd2)
    with open(cleaned_temp_file, "r", encoding="utf-8") as infile:
        content = infile.read()
    # Split into emails using '\n\n\nFrom' as the separator
    emails = content.split("\n\n\nFrom")
    processed_emails = []
    for email in emails:
        lines = email.splitlines()
        # Truncate each line to 1000 characters and each email to 200 lines
        # NOTE: 1000 characters is reasonable for a long non-word-wrapped paragraph
        truncated_lines = [line[:1000] for line in lines[:200]]
        # Mark emails that were truncated
        if len(lines) > len(truncated_lines):
            truncated_lines.append("[Email truncated]")
        processed_emails.append("\n".join(truncated_lines))
    # Join emails back together with '\n\n\nFrom'
    result = "\n\n\nFrom".join(processed_emails)
    # Prepend two blank lines to the first email so all emails have the same formatting
    # (needed for removing prepended source file names in evals)
    result = "\n\n" + result
    with open(truncated_temp_file, "w", encoding="utf-8") as outfile:
        outfile.write(result)
    try:
        if search_type == "sparse":
            # Handle sparse search with BM25
            ProcessFileSparse(truncated_temp_file, file_path)
        elif search_type == "dense":
            # Handle dense search with ChromaDB
            ProcessFileDense(truncated_temp_file, file_path, compute_mode)
        else:
            raise ValueError(f"Unsupported search type: {search_type}")
    finally:
        # Clean up the temporary files
        try:
            os.remove(cleaned_temp_file)
            os.remove(truncated_temp_file)
        except Exception:
            pass
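

# Illustrative note (not part of the original pipeline): the '\n\n\nFrom'
# boundary convention assumes two blank lines precede each "From" line in the
# archive, so split() and join() round-trip the content losslessly:
# >>> content = "From alice\nHi\n\n\nFrom bob\nHello"
# >>> content.split("\n\n\nFrom")
# ['From alice\nHi', ' bob\nHello']
# >>> "\n\n\nFrom".join(content.split("\n\n\nFrom")) == content
# True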


def ProcessFileDense(cleaned_temp_file, file_path, compute_mode):
    """
    Process file for dense vector search using ChromaDB
    """
    # Get a retriever instance
    retriever = BuildRetriever(compute_mode, "dense")
    # Load cleaned text file
    loader = TextLoader(cleaned_temp_file)
    documents = loader.load()
    # Use original file path for "source" key in metadata
    documents[0].metadata["source"] = file_path
    # Add file timestamp to metadata
    mod_time = os.path.getmtime(file_path)
    timestamp = datetime.fromtimestamp(mod_time).isoformat()
    documents[0].metadata["timestamp"] = timestamp
    ## Add documents to vectorstore in a single call
    # retriever.add_documents(documents)
    # Instead, split the document into batches to stay under ChromaDB's batch limit
    # https://github.com/chroma-core/chroma/issues/1049
    # https://cookbook.chromadb.dev/strategies/batching
    batch_size = 1000
    # Split emails
    emails = documents[0].page_content.split("\n\n\nFrom")
    # NOTE: documents_batch aliases documents, so each iteration rewrites the
    # page_content of the same Document object before adding it
    documents_batch = documents
    for i in range(0, len(emails), batch_size):
        emails_batch = emails[i : i + batch_size]
        # Join emails back together
        page_content = "\n\n\nFrom".join(emails_batch)
        documents_batch[0].page_content = page_content
        # Add documents to vectorstore
        retriever.add_documents(documents_batch)
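

# A minimal sanity check after dense indexing; a sketch only, assuming
# BuildRetriever returns a LangChain VectorStoreRetriever, whose invoke()
# runs a similarity search. The query string is a placeholder.
def CheckDenseIndex(compute_mode: str = "remote"):
    retriever = BuildRetriever(compute_mode, "dense")
    docs = retriever.invoke("example query")
    print(f"CheckDenseIndex: retrieved {len(docs)} chunks")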


def ProcessFileSparse(cleaned_temp_file, file_path):
    """
    Process file for sparse search using BM25
    """
    # Load text file to document
    loader = TextLoader(cleaned_temp_file)
    documents = loader.load()
    # Split archive file into emails for BM25
    # Split on two blank lines followed by "From"; chunk_size=1 prevents
    # re-merging, so each email becomes its own chunk regardless of size
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\nFrom"], chunk_size=1, chunk_overlap=0
    )
    ## Using 'EmailFrom' as the separator (requires preprocessing)
    # splitter = RecursiveCharacterTextSplitter(separators=["EmailFrom"])
    emails = splitter.split_documents(documents)
    # Use original file path for "source" key in metadata
    for email in emails:
        email.metadata["source"] = file_path
    # Create or update BM25 index
    bm25_persist_directory = f"{db_dir}/bm25"
    try:
        # Update BM25 index if it exists
        retriever = BM25SRetriever.from_persisted_directory(bm25_persist_directory)
        # Get new emails - ones which have not been indexed
        new_emails = [email for email in emails if email not in retriever.docs]
        if len(new_emails) > 0:
            # Create new BM25 index with all emails
            # NOTE: Adding new documents to an existing index is not possible:
            # https://github.com/xhluca/bm25s/discussions/20
            all_emails = retriever.docs + new_emails
            BM25SRetriever.from_documents(
                documents=all_emails,
                persist_directory=bm25_persist_directory,
            )
            print(f"BM25S: added {len(new_emails)} new emails from {file_path}")
        else:
            print(f"BM25S: no change for {file_path}")
    except (FileNotFoundError, OSError):
        # Create new BM25 index
        BM25SRetriever.from_documents(
            documents=emails,
            persist_directory=bm25_persist_directory,
        )
        print(f"BM25S: started with {len(emails)} emails from {file_path}")