Spaces:
Runtime error
Runtime error
File size: 7,067 Bytes
63deadc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
from typing import Any, Dict, List, Optional, Type
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
from langchain_core.vectorstores import VectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA
def _get_default_text_splitter() -> TextSplitter:
    """Build the splitter used when no ``text_splitter`` is supplied.

    Returns a ``RecursiveCharacterTextSplitter`` configured with
    ``chunk_size=1000`` and ``chunk_overlap=0``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_overlap=0,
        chunk_size=1000,
    )
    return splitter
class VectorStoreIndexWrapper(BaseModel):
    """Wrapper around a vectorstore for easy access.

    Every query method builds a new retrieval chain over ``vectorstore`` from
    the caller-supplied LLM and runs it on a single question.
    """

    vectorstore: VectorStore

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @staticmethod
    def _require_llm(llm: Optional[BaseLanguageModel]) -> BaseLanguageModel:
        """Return ``llm`` if given; otherwise raise with usage guidance.

        Shared by all query methods so the guard and its message live in one
        place instead of being duplicated per method.

        Raises:
            NotImplementedError: If ``llm`` is None. This exception type is
                kept for backward compatibility with callers that catch it.
        """
        if llm is None:
            raise NotImplementedError(
                "This API has been changed to require an LLM. "
                "Please provide an llm to use for querying the vectorstore.\n"
                "For example,\n"
                "from langchain_openai import OpenAI\n"
                "llm = OpenAI(temperature=0)"
            )
        return llm

    def query(
        self,
        question: str,
        llm: Optional[BaseLanguageModel] = None,
        retriever_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> str:
        """Query the vectorstore.

        Args:
            question: Question passed to the chain under its input key.
            llm: Language model used to build the ``RetrievalQA`` chain.
                Required despite the ``None`` default.
            retriever_kwargs: Keyword arguments for
                ``vectorstore.as_retriever``; ``None`` means no arguments.
            **kwargs: Forwarded to ``RetrievalQA.from_chain_type``.

        Returns:
            The value the chain produces under its output key.

        Raises:
            NotImplementedError: If ``llm`` is None.
        """
        llm = self._require_llm(llm)
        retriever = self.vectorstore.as_retriever(**(retriever_kwargs or {}))
        chain = RetrievalQA.from_chain_type(llm, retriever=retriever, **kwargs)
        return chain.invoke({chain.input_key: question})[chain.output_key]

    async def aquery(
        self,
        question: str,
        llm: Optional[BaseLanguageModel] = None,
        retriever_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> str:
        """Query the vectorstore (async variant of ``query``).

        Args:
            question: Question passed to the chain under its input key.
            llm: Language model used to build the ``RetrievalQA`` chain.
                Required despite the ``None`` default.
            retriever_kwargs: Keyword arguments for
                ``vectorstore.as_retriever``; ``None`` means no arguments.
            **kwargs: Forwarded to ``RetrievalQA.from_chain_type``.

        Returns:
            The value the chain produces under its output key.

        Raises:
            NotImplementedError: If ``llm`` is None.
        """
        llm = self._require_llm(llm)
        retriever = self.vectorstore.as_retriever(**(retriever_kwargs or {}))
        chain = RetrievalQA.from_chain_type(llm, retriever=retriever, **kwargs)
        return (await chain.ainvoke({chain.input_key: question}))[chain.output_key]

    def query_with_sources(
        self,
        question: str,
        llm: Optional[BaseLanguageModel] = None,
        retriever_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> dict:
        """Query the vectorstore and get back sources.

        Args:
            question: Question passed to the chain under its question key.
            llm: Language model used to build the
                ``RetrievalQAWithSourcesChain``. Required despite the
                ``None`` default.
            retriever_kwargs: Keyword arguments for
                ``vectorstore.as_retriever``; ``None`` means no arguments.
            **kwargs: Forwarded to
                ``RetrievalQAWithSourcesChain.from_chain_type``.

        Returns:
            The full output mapping returned by the chain.

        Raises:
            NotImplementedError: If ``llm`` is None.
        """
        llm = self._require_llm(llm)
        retriever = self.vectorstore.as_retriever(**(retriever_kwargs or {}))
        chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm, retriever=retriever, **kwargs
        )
        return chain.invoke({chain.question_key: question})

    async def aquery_with_sources(
        self,
        question: str,
        llm: Optional[BaseLanguageModel] = None,
        retriever_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> dict:
        """Query the vectorstore and get back sources (async variant).

        Args:
            question: Question passed to the chain under its question key.
            llm: Language model used to build the
                ``RetrievalQAWithSourcesChain``. Required despite the
                ``None`` default.
            retriever_kwargs: Keyword arguments for
                ``vectorstore.as_retriever``; ``None`` means no arguments.
            **kwargs: Forwarded to
                ``RetrievalQAWithSourcesChain.from_chain_type``.

        Returns:
            The full output mapping returned by the chain.

        Raises:
            NotImplementedError: If ``llm`` is None.
        """
        llm = self._require_llm(llm)
        retriever = self.vectorstore.as_retriever(**(retriever_kwargs or {}))
        chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm, retriever=retriever, **kwargs
        )
        return await chain.ainvoke({chain.question_key: question})
def _get_in_memory_vectorstore() -> Type[VectorStore]:
    """Return the ``InMemoryVectorStore`` class, warning that it is not persistent.

    Used as the default factory for ``VectorstoreIndexCreator.vectorstore_cls``.

    Returns:
        The ``InMemoryVectorStore`` class from ``langchain_community``.

    Raises:
        ImportError: If ``langchain_community`` cannot be imported.
    """
    import warnings

    try:
        from langchain_community.vectorstores.inmemory import InMemoryVectorStore
    except ImportError as e:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError(
            "Please install langchain-community to use the InMemoryVectorStore."
        ) from e
    # Each literal fragment ends with a space; adjacent string literals are
    # concatenated without a separator, so omitting it fuses words together.
    warnings.warn(
        "Using InMemoryVectorStore as the default vectorstore. "
        "This memory store won't persist data. You should explicitly "
        "specify a vectorstore when using VectorstoreIndexCreator"
    )
    return InMemoryVectorStore
class VectorstoreIndexCreator(BaseModel):
    """Build a ``VectorStoreIndexWrapper`` from loaders or raw documents.

    Documents are split with ``text_splitter`` and the resulting chunks are
    passed, together with ``embedding`` and ``vectorstore_kwargs``, to
    ``vectorstore_cls.from_documents`` (or its async counterpart).
    """

    vectorstore_cls: Type[VectorStore] = Field(
        default_factory=_get_in_memory_vectorstore
    )
    embedding: Embeddings
    text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter)
    vectorstore_kwargs: dict = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def from_loaders(self, loaders: List[BaseLoader]) -> VectorStoreIndexWrapper:
        """Call ``load()`` on each loader in order and index every document."""
        collected = [doc for source in loaders for doc in source.load()]
        return self.from_documents(collected)

    async def afrom_loaders(
        self, loaders: List[BaseLoader]
    ) -> VectorStoreIndexWrapper:
        """Drain each loader's ``alazy_load()`` in order and index the result."""
        collected = [
            doc for source in loaders async for doc in source.alazy_load()
        ]
        return await self.afrom_documents(collected)

    def from_documents(self, documents: List[Document]) -> VectorStoreIndexWrapper:
        """Split ``documents`` into chunks and index them in a new vectorstore."""
        chunks = self.text_splitter.split_documents(documents)
        store = self.vectorstore_cls.from_documents(
            chunks, self.embedding, **self.vectorstore_kwargs
        )
        return VectorStoreIndexWrapper(vectorstore=store)

    async def afrom_documents(
        self, documents: List[Document]
    ) -> VectorStoreIndexWrapper:
        """Async variant of ``from_documents``; the split itself stays synchronous."""
        chunks = self.text_splitter.split_documents(documents)
        store = await self.vectorstore_cls.afrom_documents(
            chunks, self.embedding, **self.vectorstore_kwargs
        )
        return VectorStoreIndexWrapper(vectorstore=store)
|