import os
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.docstore.document import Document
import uuid

class VectorStore:
    def __init__(self):
        # Initialize embedding function
        self.embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

        # Initialize persistent ChromaDB client
        self.client = chromadb.PersistentClient(path="./chroma_db")

        # Create or get collection
        self.collection = self.client.get_or_create_collection(
            name="research_documents",
            embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name="all-MiniLM-L6-v2"
            )
        )

        # Initialize LangChain vector store, reusing the client above so that
        # only one Chroma client opens ./chroma_db
        self.vector_store = Chroma(
            client=self.client,
            collection_name="research_documents",
            embedding_function=self.embedding_function,
            persist_directory="./chroma_db"
        )

        # Initialize text splitter used to chunk documents before indexing
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
    def add_documents(self, documents):
        """Add documents to the vector store"""
        try:
            # Split documents into overlapping chunks
            split_docs = []
            for doc in documents:
                splits = self.text_splitter.split_text(doc.page_content)
                for i, split in enumerate(splits):
                    split_docs.append(Document(
                        page_content=split,
                        metadata={**doc.metadata, "chunk": i}
                    ))

            # Add chunks to the vector store with unique IDs
            ids = [str(uuid.uuid4()) for _ in split_docs]
            self.vector_store.add_documents(split_docs, ids=ids)
            return {"status": "success", "count": len(split_docs)}
        except Exception as e:
            return {"status": "error", "message": str(e)}
    def search(self, query, k=5):
        """Search for relevant documents"""
        try:
            # Perform similarity search
            docs = self.vector_store.similarity_search(query, k=k)
            return {"status": "success", "documents": docs}
        except Exception as e:
            return {"status": "error", "message": str(e)}
    def delete_collection(self):
        """Delete the entire collection and recreate an empty one"""
        try:
            self.client.delete_collection("research_documents")
            self.collection = self.client.get_or_create_collection(
                name="research_documents",
                embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
                    model_name="all-MiniLM-L6-v2"
                )
            )
            # Rebuild the LangChain wrapper so it points at the new collection
            self.vector_store = Chroma(
                client=self.client,
                collection_name="research_documents",
                embedding_function=self.embedding_function,
                persist_directory="./chroma_db"
            )
            return {"status": "success"}
        except Exception as e:
            return {"status": "error", "message": str(e)}
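

# Minimal usage sketch (not part of the class above): it assumes the caller
# builds LangChain Document objects and that the "all-MiniLM-L6-v2" model can
# be downloaded locally. The "sample.pdf" metadata is purely illustrative.
if __name__ == "__main__":
    store = VectorStore()

    docs = [
        Document(
            page_content="Retrieval-augmented generation pairs a vector index with an LLM.",
            metadata={"source": "sample.pdf", "page": 1},
        )
    ]

    # Chunk and index the documents
    added = store.add_documents(docs)
    print(added)  # e.g. {"status": "success", "count": 1}

    # Query the index
    results = store.search("What is retrieval-augmented generation?", k=3)
    if results["status"] == "success":
        for doc in results["documents"]:
            print(doc.metadata.get("chunk"), doc.page_content[:80])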