# Source: myspace134v/modules/rag/vector_store.py
# Last commit: bb60cf1 by rdune71 - "Add RAG capability with document upload and management"
import os
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.docstore.document import Document
import uuid
class VectorStore:
    """Persistent, Chroma-backed store of chunked documents for retrieval (RAG).

    Documents are split into overlapping text chunks, embedded with a
    sentence-transformers model, and stored in a ChromaDB collection on disk.
    Public methods report failures through a status dict rather than raising.
    """

    # Defaults match the values this class originally hard-coded.
    DEFAULT_PERSIST_DIRECTORY = "./chroma_db"
    DEFAULT_COLLECTION_NAME = "research_documents"
    DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(
        self,
        *,
        persist_directory=DEFAULT_PERSIST_DIRECTORY,
        collection_name=DEFAULT_COLLECTION_NAME,
        model_name=DEFAULT_MODEL_NAME,
        chunk_size=1000,
        chunk_overlap=200,
    ):
        """Open (or create) the on-disk collection and set up chunking.

        Args:
            persist_directory: Directory the ChromaDB client persists to.
            collection_name: Name of the ChromaDB collection to use.
            model_name: sentence-transformers model used for embeddings.
            chunk_size: Target maximum chunk length, measured with ``len``.
            chunk_overlap: Length shared between consecutive chunks.
        """
        self.persist_directory = persist_directory
        self.collection_name = collection_name
        self.model_name = model_name
        # Embeddings used by the LangChain wrapper for add/search.
        self.embedding_function = SentenceTransformerEmbeddings(model_name=model_name)
        self.client = chromadb.PersistentClient(path=persist_directory)
        # ChromaDB-native embedding function for the raw collection handle.
        # It is built once so that delete_collection() does not reload the model.
        self._collection_embedding_function = (
            embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
        )
        self.collection = self._get_or_create_collection()
        self.vector_store = self._build_vector_store()
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )

    def _get_or_create_collection(self):
        """Return the raw ChromaDB collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self._collection_embedding_function,
        )

    def _build_vector_store(self):
        """Return a LangChain ``Chroma`` wrapper over this store's collection."""
        # Share self.client so the wrapper and the raw collection handle use one
        # database client. persist_directory is passed through unchanged in
        # case the installed LangChain version reads it.
        return Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=self.embedding_function,
            persist_directory=self.persist_directory,
        )

    def add_documents(self, documents):
        """Split documents into chunks and add the chunks to the vector store.

        Each chunk keeps its source document's metadata plus a ``"chunk"`` key
        holding the chunk's position within that document. Every chunk gets
        a fresh random UUID as its id.

        Args:
            documents: Iterable of LangChain ``Document`` objects (anything
                with ``page_content`` and ``metadata`` attributes).

        Returns:
            ``{"status": "success", "count": <chunks added>}`` on success, or
            ``{"status": "error", "message": <str>}`` if anything fails.
        """
        try:
            split_docs = []
            for doc in documents:
                chunks = self.text_splitter.split_text(doc.page_content)
                split_docs.extend(
                    Document(page_content=chunk, metadata={**doc.metadata, "chunk": i})
                    for i, chunk in enumerate(chunks)
                )
            # Nothing to index (no documents, or only empty content). Skip the
            # store call instead of handing ChromaDB an empty batch, which its
            # id validation rejects.
            if not split_docs:
                return {"status": "success", "count": 0}
            ids = [str(uuid.uuid4()) for _ in split_docs]
            self.vector_store.add_documents(split_docs, ids=ids)
            return {"status": "success", "count": len(split_docs)}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def search(self, query, k=5):
        """Return the ``k`` chunks most similar to ``query``.

        Args:
            query: Free-text query to embed and search with.
            k: Maximum number of chunks to return.

        Returns:
            ``{"status": "success", "documents": [Document, ...]}`` on success,
            or ``{"status": "error", "message": <str>}`` if the search fails.
        """
        try:
            docs = self.vector_store.similarity_search(query, k=k)
            return {"status": "success", "documents": docs}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def delete_collection(self):
        """Delete every stored chunk by dropping and recreating the collection.

        Returns:
            ``{"status": "success"}`` on success, or
            ``{"status": "error", "message": <str>}`` if anything fails.
        """
        try:
            self.client.delete_collection(self.collection_name)
            self.collection = self._get_or_create_collection()
            # LangChain's Chroma resolves its collection once, at construction.
            # Rebuild it so later add/search calls target the new collection
            # rather than the deleted one.
            self.vector_store = self._build_vector_store()
            return {"status": "success"}
        except Exception as e:
            return {"status": "error", "message": str(e)}